This commit is contained in:
retoor 2025-06-07 05:47:54 +02:00
parent f182c2209e
commit 389d417c1b
45 changed files with 2158 additions and 1235 deletions

View File

@ -1,27 +1,31 @@
import click
import uvloop
from aiohttp import web
import asyncio
from snek.app import Application
from IPython import start_ipython
import sqlite3
import pathlib import pathlib
import shutil import shutil
import sqlite3
import click
from aiohttp import web
from IPython import start_ipython
from snek.app import Application
@click.group() @click.group()
def cli(): def cli():
pass pass
@cli.command() @cli.command()
@click.option('--db_path',default="snek.db", help='Database to initialize if not exists.') @click.option(
@click.option('--source',default=None, help='Database to initialize if not exists.') "--db_path", default="snek.db", help="Database to initialize if not exists."
def init(db_path,source): )
@click.option("--source", default=None, help="Database to initialize if not exists.")
def init(db_path, source):
if source and pathlib.Path(source).exists(): if source and pathlib.Path(source).exists():
print(f"Copying {source} to {db_path}") print(f"Copying {source} to {db_path}")
shutil.copy2(source,db_path) shutil.copy2(source, db_path)
print("Database initialized.") print("Database initialized.")
return return
if pathlib.Path(db_path).exists(): if pathlib.Path(db_path).exists():
return return
print(f"Initializing database at {db_path}") print(f"Initializing database at {db_path}")
@ -33,25 +37,44 @@ def init(db_path,source):
db.close() db.close()
print("Database initialized.") print("Database initialized.")
@cli.command()
@click.option('--port', default=8081, show_default=True, help='Port to run the application on')
@click.option('--host', default='0.0.0.0', show_default=True, help='Host to run the application on')
@click.option('--db_path', default='snek.db', show_default=True, help='Database path for the application')
def serve(port, host, db_path):
#init(db_path)
#asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
web.run_app(
Application(db_path=f"sqlite:///{db_path}"), port=port, host=host
)
@cli.command() @cli.command()
@click.option('--db_path', default='snek.db', show_default=True, help='Database path for the application') @click.option(
"--port", default=8081, show_default=True, help="Port to run the application on"
)
@click.option(
"--host",
default="0.0.0.0",
show_default=True,
help="Host to run the application on",
)
@click.option(
"--db_path",
default="snek.db",
show_default=True,
help="Database path for the application",
)
def serve(port, host, db_path):
# init(db_path)
# asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
web.run_app(Application(db_path=f"sqlite:///{db_path}"), port=port, host=host)
@cli.command()
@click.option(
"--db_path",
default="snek.db",
show_default=True,
help="Database path for the application",
)
def shell(db_path): def shell(db_path):
app = Application(db_path=f"sqlite:///{db_path}") app = Application(db_path=f"sqlite:///{db_path}")
start_ipython(argv=[], user_ns={'app': app}) start_ipython(argv=[], user_ns={"app": app})
def main(): def main():
cli() cli()
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -4,13 +4,15 @@ import pathlib
import ssl import ssl
import uuid import uuid
from datetime import datetime from datetime import datetime
from snek import snode from snek import snode
from snek.view.threads import ThreadsView from snek.view.threads import ThreadsView
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
from ipaddress import ip_address
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from ipaddress import ip_address
import IP2Location
from aiohttp import web from aiohttp import web
from aiohttp_session import ( from aiohttp_session import (
get_session as session_get, get_session as session_get,
@ -20,54 +22,56 @@ from aiohttp_session import (
from aiohttp_session.cookie_storage import EncryptedCookieStorage from aiohttp_session.cookie_storage import EncryptedCookieStorage
from app.app import Application as BaseApplication from app.app import Application as BaseApplication
from jinja2 import FileSystemLoader from jinja2 import FileSystemLoader
import IP2Location
from snek.sssh import start_ssh_server
from snek.docs.app import Application as DocsApplication
from snek.mapper import get_mappers from snek.mapper import get_mappers
from snek.service import get_services from snek.service import get_services
from snek.sgit import GitApplication
from snek.sssh import start_ssh_server
from snek.system import http from snek.system import http
from snek.system.cache import Cache from snek.system.cache import Cache
from snek.system.markdown import MarkdownExtension from snek.system.markdown import MarkdownExtension
from snek.system.middleware import auth_middleware, cors_middleware, csp_middleware from snek.system.middleware import auth_middleware, cors_middleware, csp_middleware
from snek.system.profiler import profiler_handler from snek.system.profiler import profiler_handler
from snek.system.template import EmojiExtension, LinkifyExtension, PythonExtension from snek.system.template import (
EmojiExtension,
LinkifyExtension,
PythonExtension,
sanitize_html,
)
from snek.view.about import AboutHTMLView, AboutMDView from snek.view.about import AboutHTMLView, AboutMDView
from snek.view.avatar import AvatarView from snek.view.avatar import AvatarView
from snek.view.channel import ChannelAttachmentView, ChannelView
from snek.view.docs import DocsHTMLView, DocsMDView from snek.view.docs import DocsHTMLView, DocsMDView
from snek.view.drive import DriveView from snek.view.drive import DriveApiView, DriveView
from snek.view.drive import DriveApiView
from snek.view.index import IndexView from snek.view.index import IndexView
from snek.view.login import LoginView from snek.view.login import LoginView
from snek.view.push import PushView
from snek.view.logout import LogoutView from snek.view.logout import LogoutView
from snek.view.push import PushView
from snek.view.register import RegisterView from snek.view.register import RegisterView
from snek.view.rpc import RPCView
from snek.view.repository import RepositoryView from snek.view.repository import RepositoryView
from snek.view.rpc import RPCView
from snek.view.search_user import SearchUserView from snek.view.search_user import SearchUserView
from snek.view.settings.repositories import RepositoriesIndexView from snek.view.settings.containers import (
from snek.view.settings.repositories import RepositoriesCreateView ContainersCreateView,
from snek.view.settings.repositories import RepositoriesUpdateView ContainersDeleteView,
from snek.view.settings.repositories import RepositoriesDeleteView ContainersIndexView,
ContainersUpdateView,
)
from snek.view.settings.index import SettingsIndexView from snek.view.settings.index import SettingsIndexView
from snek.view.settings.profile import SettingsProfileView from snek.view.settings.profile import SettingsProfileView
from snek.view.settings.repositories import (
RepositoriesCreateView,
RepositoriesDeleteView,
RepositoriesIndexView,
RepositoriesUpdateView,
)
from snek.view.stats import StatsView from snek.view.stats import StatsView
from snek.view.status import StatusView from snek.view.status import StatusView
from snek.view.terminal import TerminalSocketView, TerminalView from snek.view.terminal import TerminalSocketView, TerminalView
from snek.view.upload import UploadView from snek.view.upload import UploadView
from snek.view.user import UserView from snek.view.user import UserView
from snek.view.web import WebView from snek.view.web import WebView
from snek.view.channel import ChannelAttachmentView
from snek.view.channel import ChannelView
from snek.view.settings.containers import (
ContainersIndexView,
ContainersCreateView,
ContainersUpdateView,
ContainersDeleteView,
)
from snek.webdav import WebdavApplication from snek.webdav import WebdavApplication
from snek.system.template import sanitize_html
from snek.sgit import GitApplication
SESSION_KEY = b"c79a0c5fda4b424189c427d28c9f7c34" SESSION_KEY = b"c79a0c5fda4b424189c427d28c9f7c34"
from snek.system.template import whitelist_attributes from snek.system.template import whitelist_attributes
@ -94,7 +98,7 @@ async def ip2location_middleware(request, handler):
if not user: if not user:
return response return response
location = request.app.ip2location.get(ip) location = request.app.ip2location.get(ip)
original_city = user["city"] user["city"]
if user["city"] != location.city: if user["city"] != location.city:
user["country_long"] = location.country user["country_long"] = location.country
user["country_short"] = locaion.country_short user["country_short"] = locaion.country_short
@ -121,7 +125,7 @@ class Application(BaseApplication):
cors_middleware, cors_middleware,
web.normalize_path_middleware(merge_slashes=True), web.normalize_path_middleware(merge_slashes=True),
ip2location_middleware, ip2location_middleware,
csp_middleware csp_middleware,
] ]
self.template_path = pathlib.Path(__file__).parent.joinpath("templates") self.template_path = pathlib.Path(__file__).parent.joinpath("templates")
self.static_path = pathlib.Path(__file__).parent.joinpath("static") self.static_path = pathlib.Path(__file__).parent.joinpath("static")
@ -139,7 +143,7 @@ class Application(BaseApplication):
self.jinja2_env.add_extension(LinkifyExtension) self.jinja2_env.add_extension(LinkifyExtension)
self.jinja2_env.add_extension(PythonExtension) self.jinja2_env.add_extension(PythonExtension)
self.jinja2_env.add_extension(EmojiExtension) self.jinja2_env.add_extension(EmojiExtension)
self.jinja2_env.filters['sanitize'] = sanitize_html self.jinja2_env.filters["sanitize"] = sanitize_html
self.time_start = datetime.now() self.time_start = datetime.now()
self.ssh_host = "0.0.0.0" self.ssh_host = "0.0.0.0"
self.ssh_port = 2242 self.ssh_port = 2242
@ -272,8 +276,12 @@ class Application(BaseApplication):
self.router.add_get("/http-photo", self.handle_http_photo) self.router.add_get("/http-photo", self.handle_http_photo)
self.router.add_get("/rpc.ws", RPCView) self.router.add_get("/rpc.ws", RPCView)
self.router.add_get("/c/{channel:.*}", ChannelView) self.router.add_get("/c/{channel:.*}", ChannelView)
self.router.add_view("/channel/{channel_uid}/attachment.bin",ChannelAttachmentView) self.router.add_view(
self.router.add_view("/channel/attachment/{relative_url:.*}",ChannelAttachmentView) "/channel/{channel_uid}/attachment.bin", ChannelAttachmentView
)
self.router.add_view(
"/channel/attachment/{relative_url:.*}", ChannelAttachmentView
)
self.router.add_view("/channel/{channel}.html", WebView) self.router.add_view("/channel/{channel}.html", WebView)
self.router.add_view("/threads.html", ThreadsView) self.router.add_view("/threads.html", ThreadsView)
self.router.add_view("/terminal.ws", TerminalSocketView) self.router.add_view("/terminal.ws", TerminalSocketView)
@ -284,15 +292,29 @@ class Application(BaseApplication):
self.router.add_view("/stats.json", StatsView) self.router.add_view("/stats.json", StatsView)
self.router.add_view("/user/{user}.html", UserView) self.router.add_view("/user/{user}.html", UserView)
self.router.add_view("/repository/{username}/{repository}", RepositoryView) self.router.add_view("/repository/{username}/{repository}", RepositoryView)
self.router.add_view("/repository/{username}/{repository}/{path:.*}", RepositoryView) self.router.add_view(
"/repository/{username}/{repository}/{path:.*}", RepositoryView
)
self.router.add_view("/settings/repositories/index.html", RepositoriesIndexView) self.router.add_view("/settings/repositories/index.html", RepositoriesIndexView)
self.router.add_view("/settings/repositories/create.html", RepositoriesCreateView) self.router.add_view(
self.router.add_view("/settings/repositories/repository/{name}/update.html", RepositoriesUpdateView) "/settings/repositories/create.html", RepositoriesCreateView
self.router.add_view("/settings/repositories/repository/{name}/delete.html", RepositoriesDeleteView) )
self.router.add_view(
"/settings/repositories/repository/{name}/update.html",
RepositoriesUpdateView,
)
self.router.add_view(
"/settings/repositories/repository/{name}/delete.html",
RepositoriesDeleteView,
)
self.router.add_view("/settings/containers/index.html", ContainersIndexView) self.router.add_view("/settings/containers/index.html", ContainersIndexView)
self.router.add_view("/settings/containers/create.html", ContainersCreateView) self.router.add_view("/settings/containers/create.html", ContainersCreateView)
self.router.add_view("/settings/containers/container/{uid}/update.html", ContainersUpdateView) self.router.add_view(
self.router.add_view("/settings/containers/container/{uid}/delete.html", ContainersDeleteView) "/settings/containers/container/{uid}/update.html", ContainersUpdateView
)
self.router.add_view(
"/settings/containers/container/{uid}/delete.html", ContainersDeleteView
)
self.webdav = WebdavApplication(self) self.webdav = WebdavApplication(self)
self.git = GitApplication(self) self.git = GitApplication(self)
self.add_subapp("/webdav", self.webdav) self.add_subapp("/webdav", self.webdav)
@ -302,9 +324,9 @@ class Application(BaseApplication):
async def handle_test(self, request): async def handle_test(self, request):
return await whitelist_attributes(self.render_template( return await whitelist_attributes(
"test.html", request, context={"name": "retoor"} self.render_template("test.html", request, context={"name": "retoor"})
)) )
async def handle_http_get(self, request: web.Request): async def handle_http_get(self, request: web.Request):
url = request.query.get("url") url = request.query.get("url")
@ -375,8 +397,8 @@ class Application(BaseApplication):
self.jinja2_env.loader = self.original_loader self.jinja2_env.loader = self.original_loader
#rendered.text = whitelist_attributes(rendered.text) # rendered.text = whitelist_attributes(rendered.text)
#rendered.headers['Content-Lenght'] = len(rendered.text) # rendered.headers['Content-Lenght'] = len(rendered.text)
return rendered return rendered
async def static_handler(self, request): async def static_handler(self, request):
@ -423,7 +445,7 @@ app = Application(db_path="sqlite:///snek.db")
async def main(): async def main():
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_context.load_cert_chain('cert.pem', 'key.pem') ssl_context.load_cert_chain("cert.pem", "key.pem")
await web._run_app(app, port=8081, host="0.0.0.0", ssl_context=ssl_context) await web._run_app(app, port=8081, host="0.0.0.0", ssl_context=ssl_context)

View File

@ -1,6 +1,7 @@
import asyncio import asyncio
import sys import sys
class LoadBalancer: class LoadBalancer:
def __init__(self, backend_ports): def __init__(self, backend_ports):
self.backend_ports = backend_ports self.backend_ports = backend_ports
@ -8,27 +9,29 @@ class LoadBalancer:
self.client_counts = [0] * len(backend_ports) self.client_counts = [0] * len(backend_ports)
self.lock = asyncio.Lock() self.lock = asyncio.Lock()
async def start_backend_servers(self,port,workers): async def start_backend_servers(self, port, workers):
for x in range(workers): for x in range(workers):
port += 1 port += 1
process = await asyncio.create_subprocess_exec( process = await asyncio.create_subprocess_exec(
sys.executable, sys.executable,
sys.argv[0], sys.argv[0],
'backend', "backend",
str(port), str(port),
stdout=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE stderr=asyncio.subprocess.PIPE,
) )
port += 1 port += 1
self.backend_processes.append(process) self.backend_processes.append(process)
print(f"Started backend server on port {(port-1)/port} with PID {process.pid}") print(
f"Started backend server on port {(port-1)/port} with PID {process.pid}"
)
async def handle_client(self, reader, writer): async def handle_client(self, reader, writer):
async with self.lock: async with self.lock:
min_clients = min(self.client_counts) min_clients = min(self.client_counts)
server_index = self.client_counts.index(min_clients) server_index = self.client_counts.index(min_clients)
self.client_counts[server_index] += 1 self.client_counts[server_index] += 1
backend = ('127.0.0.1', self.backend_ports[server_index]) backend = ("127.0.0.1", self.backend_ports[server_index])
try: try:
backend_reader, backend_writer = await asyncio.open_connection(*backend) backend_reader, backend_writer = await asyncio.open_connection(*backend)
@ -62,10 +65,10 @@ class LoadBalancer:
for i, count in enumerate(self.client_counts): for i, count in enumerate(self.client_counts):
print(f"Server {self.backend_ports[i]}: {count} clients") print(f"Server {self.backend_ports[i]}: {count} clients")
async def start(self, host='0.0.0.0', port=8081,workers=5): async def start(self, host="0.0.0.0", port=8081, workers=5):
await self.start_backend_servers(port,workers) await self.start_backend_servers(port, workers)
server = await asyncio.start_server(self.handle_client, host, port) server = await asyncio.start_server(self.handle_client, host, port)
monitor_task = asyncio.create_task(self.monitor()) asyncio.create_task(self.monitor())
# Handle shutdown gracefully # Handle shutdown gracefully
try: try:
@ -80,6 +83,7 @@ class LoadBalancer:
await asyncio.gather(*(p.wait() for p in self.backend_processes)) await asyncio.gather(*(p.wait() for p in self.backend_processes))
print("Backend processes terminated.") print("Backend processes terminated.")
async def backend_echo_server(port): async def backend_echo_server(port):
async def handle_echo(reader, writer): async def handle_echo(reader, writer):
try: try:
@ -94,10 +98,11 @@ async def backend_echo_server(port):
finally: finally:
writer.close() writer.close()
server = await asyncio.start_server(handle_echo, '127.0.0.1', port) server = await asyncio.start_server(handle_echo, "127.0.0.1", port)
print(f"Backend echo server running on port {port}") print(f"Backend echo server running on port {port}")
await server.serve_forever() await server.serve_forever()
async def main(): async def main():
backend_ports = [8001, 8003, 8005, 8006] backend_ports = [8001, 8003, 8005, 8006]
# Launch backend echo servers # Launch backend echo servers
@ -105,19 +110,20 @@ async def main():
lb = LoadBalancer(backend_ports) lb = LoadBalancer(backend_ports)
await lb.start() await lb.start()
if __name__ == "__main__": if __name__ == "__main__":
if len(sys.argv) > 1: if len(sys.argv) > 1:
if sys.argv[1] == 'backend': if sys.argv[1] == "backend":
port = int(sys.argv[2]) port = int(sys.argv[2])
from snek.app import Application from snek.app import Application
snek = Application(port=port) snek = Application(port=port)
web.run_app(snek, port=port, host='127.0.0.1') web.run_app(snek, port=port, host="127.0.0.1")
elif sys.argv[1] == 'sync': elif sys.argv[1] == "sync":
from snek.sync import app
web.run_app(snek, port=port, host='127.0.0.1') web.run_app(snek, port=port, host="127.0.0.1")
else: else:
try: try:
asyncio.run(main()) asyncio.run(main())
except KeyboardInterrupt: except KeyboardInterrupt:
print("Shutting down...") print("Shutting down...")

View File

@ -1,22 +1,21 @@
import functools import functools
from snek.mapper.channel import ChannelMapper from snek.mapper.channel import ChannelMapper
from snek.mapper.channel_attachment import ChannelAttachmentMapper
from snek.mapper.channel_member import ChannelMemberMapper from snek.mapper.channel_member import ChannelMemberMapper
from snek.mapper.channel_message import ChannelMessageMapper from snek.mapper.channel_message import ChannelMessageMapper
from snek.mapper.container import ContainerMapper
from snek.mapper.drive import DriveMapper from snek.mapper.drive import DriveMapper
from snek.mapper.drive_item import DriveItemMapper from snek.mapper.drive_item import DriveItemMapper
from snek.mapper.notification import NotificationMapper from snek.mapper.notification import NotificationMapper
from snek.mapper.push import PushMapper
from snek.mapper.repository import RepositoryMapper
from snek.mapper.user import UserMapper from snek.mapper.user import UserMapper
from snek.mapper.user_property import UserPropertyMapper from snek.mapper.user_property import UserPropertyMapper
from snek.mapper.repository import RepositoryMapper
from snek.mapper.push import PushMapper
from snek.mapper.channel_attachment import ChannelAttachmentMapper
from snek.mapper.container import ContainerMapper
from snek.system.object import Object from snek.system.object import Object
@functools.cache @functools.cache
def get_mappers(app=None): def get_mappers(app=None):
return Object( return Object(
**{ **{

View File

@ -1,6 +1,7 @@
from snek.model.container import Container from snek.model.container import Container
from snek.system.mapper import BaseMapper from snek.system.mapper import BaseMapper
class ContainerMapper(BaseMapper): class ContainerMapper(BaseMapper):
model_class = Container model_class = Container
table_name = "container" table_name = "container"

View File

@ -1,19 +1,19 @@
import functools import functools
from snek.model.channel import ChannelModel from snek.model.channel import ChannelModel
from snek.model.channel_attachment import ChannelAttachmentModel
from snek.model.channel_member import ChannelMemberModel from snek.model.channel_member import ChannelMemberModel
# from snek.model.channel_message import ChannelMessageModel # from snek.model.channel_message import ChannelMessageModel
from snek.model.channel_message import ChannelMessageModel from snek.model.channel_message import ChannelMessageModel
from snek.model.container import Container
from snek.model.drive import DriveModel from snek.model.drive import DriveModel
from snek.model.drive_item import DriveItemModel from snek.model.drive_item import DriveItemModel
from snek.model.notification import NotificationModel from snek.model.notification import NotificationModel
from snek.model.push_registration import PushRegistrationModel
from snek.model.repository import RepositoryModel
from snek.model.user import UserModel from snek.model.user import UserModel
from snek.model.user_property import UserPropertyModel from snek.model.user_property import UserPropertyModel
from snek.model.repository import RepositoryModel
from snek.model.channel_attachment import ChannelAttachmentModel
from snek.model.push_registration import PushRegistrationModel
from snek.model.container import Container
from snek.system.object import Object from snek.system.object import Object

View File

@ -1,4 +1,3 @@
from snek.system.model import BaseModel
from snek.system.model import BaseModel, ModelField from snek.system.model import BaseModel, ModelField
@ -11,6 +10,6 @@ class ChannelAttachmentModel(BaseModel):
user_uid = ModelField(name="user_uid", required=True, kind=str) user_uid = ModelField(name="user_uid", required=True, kind=str)
mime_type = ModelField(name="type", required=True, kind=str) mime_type = ModelField(name="type", required=True, kind=str)
relative_url = ModelField(name="relative_url", required=True, kind=str) relative_url = ModelField(name="relative_url", required=True, kind=str)
resource_type = ModelField(name="resource_type", required=True, kind=str,value="file") resource_type = ModelField(
name="resource_type", required=True, kind=str, value="file"
)

View File

@ -1,6 +1,8 @@
from datetime import datetime, timezone
from snek.model.user import UserModel from snek.model.user import UserModel
from snek.system.model import BaseModel, ModelField from snek.system.model import BaseModel, ModelField
from datetime import datetime,timezone
class ChannelMessageModel(BaseModel): class ChannelMessageModel(BaseModel):
channel_uid = ModelField(name="channel_uid", required=True, kind=str) channel_uid = ModelField(name="channel_uid", required=True, kind=str)
@ -10,7 +12,11 @@ class ChannelMessageModel(BaseModel):
is_final = ModelField(name="is_final", required=True, kind=bool, value=True) is_final = ModelField(name="is_final", required=True, kind=bool, value=True)
def get_seconds_since_last_update(self): def get_seconds_since_last_update(self):
return int((datetime.now(timezone.utc) - datetime.fromisoformat(self["updated_at"])).total_seconds()) return int(
(
datetime.now(timezone.utc) - datetime.fromisoformat(self["updated_at"])
).total_seconds()
)
async def get_user(self) -> UserModel: async def get_user(self) -> UserModel:
return await self.app.services.user.get(uid=self["user_uid"]) return await self.app.services.user.get(uid=self["user_uid"])

View File

@ -1,5 +1,6 @@
from snek.system.model import BaseModel, ModelField from snek.system.model import BaseModel, ModelField
class Container(BaseModel): class Container(BaseModel):
id = ModelField(name="id", required=True, kind=str) id = ModelField(name="id", required=True, kind=str)
name = ModelField(name="name", required=True, kind=str) name = ModelField(name="name", required=True, kind=str)

View File

@ -9,7 +9,9 @@ class DriveItemModel(BaseModel):
path = ModelField(name="path", required=True, kind=str) path = ModelField(name="path", required=True, kind=str)
file_type = ModelField(name="file_type", required=True, kind=str) file_type = ModelField(name="file_type", required=True, kind=str)
file_size = ModelField(name="file_size", required=True, kind=int) file_size = ModelField(name="file_size", required=True, kind=int)
is_available = ModelField(name="is_available", required=True, kind=bool, initial_value=True) is_available = ModelField(
name="is_available", required=True, kind=bool, initial_value=True
)
@property @property
def extension(self): def extension(self):

View File

@ -1,14 +1,10 @@
from snek.model.user import UserModel
from snek.system.model import BaseModel, ModelField from snek.system.model import BaseModel, ModelField
class RepositoryModel(BaseModel): class RepositoryModel(BaseModel):
user_uid = ModelField(name="user_uid", required=True, kind=str) user_uid = ModelField(name="user_uid", required=True, kind=str)
name = ModelField(name="name", required=True, kind=str) name = ModelField(name="name", required=True, kind=str)
is_private = ModelField(name="is_private", required=False, kind=bool) is_private = ModelField(name="is_private", required=False, kind=bool)

View File

@ -30,7 +30,7 @@ class UserModel(BaseModel):
last_ping = ModelField(name="last_ping", required=False, kind=str) last_ping = ModelField(name="last_ping", required=False, kind=str)
is_admin = ModelField(name="is_admin", required=False, kind=bool) is_admin = ModelField(name="is_admin", required=False, kind=bool)
country_short = ModelField(name="country_short", required=False, kind=str) country_short = ModelField(name="country_short", required=False, kind=str)
country_long = ModelField(name="country_long", required=False, kind=str) country_long = ModelField(name="country_long", required=False, kind=str)
city = ModelField(name="city", required=False, kind=str) city = ModelField(name="city", required=False, kind=str)

View File

@ -1,51 +1,54 @@
import snek.serpentarium import time
import time
from concurrent.futures import ProcessPoolExecutor from concurrent.futures import ProcessPoolExecutor
import snek.serpentarium
durations = [] durations = []
def task1(): def task1():
global durations global durations
client = snek.serpentarium.DatasetWrapper() client = snek.serpentarium.DatasetWrapper()
start=time.time() start = time.time()
for x in range(1500): for x in range(1500):
client['a'].delete() client["a"].delete()
client['a'].insert({"foo": x}) client["a"].insert({"foo": x})
client['a'].find(foo=x) client["a"].find(foo=x)
client['a'].find_one(foo=x) client["a"].find_one(foo=x)
client['a'].count() client["a"].count()
#print(client['a'].find(foo=x) ) # print(client['a'].find(foo=x) )
#print(client['a'].find_one(foo=x) ) # print(client['a'].find_one(foo=x) )
#print(client['a'].count()) # print(client['a'].count())
client.close() client.close()
duration1 = f"{time.time()-start}" duration1 = f"{time.time()-start}"
durations.append(duration1) durations.append(duration1)
print(durations) print(durations)
with ProcessPoolExecutor(max_workers=4) as executor: with ProcessPoolExecutor(max_workers=4) as executor:
tasks = [executor.submit(task1), tasks = [
executor.submit(task1),
executor.submit(task1),
executor.submit(task1), executor.submit(task1),
executor.submit(task1), executor.submit(task1),
executor.submit(task1)
] ]
for task in tasks: for task in tasks:
task.result() task.result()
import dataset import dataset
client = dataset.connect("sqlite:///snek.db") client = dataset.connect("sqlite:///snek.db")
start=time.time() start = time.time()
for x in range(1500): for x in range(1500):
client['a'].delete() client["a"].delete()
client['a'].insert({"foo": x}) client["a"].insert({"foo": x})
print([dict(row) for row in client['a'].find(foo=x)]) print([dict(row) for row in client["a"].find(foo=x)])
print(dict(client['a'].find_one(foo=x) )) print(dict(client["a"].find_one(foo=x)))
print(client['a'].count()) print(client["a"].count())
duration2 = f"{time.time()-start}" duration2 = f"{time.time()-start}"
print(duration1,duration2) print(duration1, duration2)

View File

@ -1,22 +1,22 @@
import functools import functools
from snek.service.channel import ChannelService from snek.service.channel import ChannelService
from snek.service.channel_attachment import ChannelAttachmentService
from snek.service.channel_member import ChannelMemberService from snek.service.channel_member import ChannelMemberService
from snek.service.channel_message import ChannelMessageService from snek.service.channel_message import ChannelMessageService
from snek.service.chat import ChatService from snek.service.chat import ChatService
from snek.service.container import ContainerService
from snek.service.db import DBService
from snek.service.drive import DriveService from snek.service.drive import DriveService
from snek.service.drive_item import DriveItemService from snek.service.drive_item import DriveItemService
from snek.service.notification import NotificationService from snek.service.notification import NotificationService
from snek.service.push import PushService
from snek.service.repository import RepositoryService
from snek.service.socket import SocketService from snek.service.socket import SocketService
from snek.service.user import UserService from snek.service.user import UserService
from snek.service.push import PushService
from snek.service.user_property import UserPropertyService from snek.service.user_property import UserPropertyService
from snek.service.util import UtilService from snek.service.util import UtilService
from snek.service.repository import RepositoryService
from snek.service.channel_attachment import ChannelAttachmentService
from snek.service.container import ContainerService
from snek.system.object import Object from snek.system.object import Object
from snek.service.db import DBService
@functools.cache @functools.cache

View File

@ -1,9 +1,9 @@
import pathlib
from datetime import datetime from datetime import datetime
from snek.system.model import now from snek.system.model import now
from snek.system.service import BaseService from snek.system.service import BaseService
import pathlib
class ChannelService(BaseService): class ChannelService(BaseService):
mapper_name = "channel" mapper_name = "channel"
@ -17,12 +17,10 @@ class ChannelService(BaseService):
pass pass
return folder return folder
async def get_attachment_folder(self, channel_uid,ensure=False): async def get_attachment_folder(self, channel_uid, ensure=False):
path = pathlib.Path(f"./drive/{channel_uid}/attachments") path = pathlib.Path(f"./drive/{channel_uid}/attachments")
if ensure: if ensure:
path.mkdir( path.mkdir(parents=True, exist_ok=True)
parents=True, exist_ok=True
)
return path return path
async def get(self, uid=None, **kwargs): async def get(self, uid=None, **kwargs):
@ -78,7 +76,10 @@ class ChannelService(BaseService):
return channel return channel
async def get_recent_users(self, channel_uid): async def get_recent_users(self, channel_uid):
async for user in self.query("SELECT user.uid, user.username,user.color,user.last_ping,user.nick FROM channel_member INNER JOIN user ON user.uid = channel_member.user_uid WHERE channel_uid=:channel_uid AND user.last_ping >= datetime('now', '-3 minutes') ORDER BY last_ping DESC LIMIT 30", {"channel_uid": channel_uid}): async for user in self.query(
"SELECT user.uid, user.username,user.color,user.last_ping,user.nick FROM channel_member INNER JOIN user ON user.uid = channel_member.user_uid WHERE channel_uid=:channel_uid AND user.last_ping >= datetime('now', '-3 minutes') ORDER BY last_ping DESC LIMIT 30",
{"channel_uid": channel_uid},
):
yield user yield user
async def get_users(self, channel_uid): async def get_users(self, channel_uid):

View File

@ -1,25 +1,25 @@
from snek.system.service import BaseService
import urllib.parse
import pathlib
import mimetypes import mimetypes
import uuid
from snek.system.service import BaseService
class ChannelAttachmentService(BaseService): class ChannelAttachmentService(BaseService):
mapper_name="channel_attachment" mapper_name = "channel_attachment"
async def create_file(self, channel_uid, user_uid, name): async def create_file(self, channel_uid, user_uid, name):
attachment = await self.new() attachment = await self.new()
attachment["channel_uid"] = channel_uid attachment["channel_uid"] = channel_uid
attachment['user_uid'] = user_uid attachment["user_uid"] = user_uid
attachment["name"] = name attachment["name"] = name
attachment["mime_type"] = mimetypes.guess_type(name)[0] attachment["mime_type"] = mimetypes.guess_type(name)[0]
attachment['resource_type'] = "file" attachment["resource_type"] = "file"
real_file_name = f"{attachment['uid']}-{name}" real_file_name = f"{attachment['uid']}-{name}"
attachment["relative_url"] = (f"{attachment['uid']}-{name}") attachment["relative_url"] = f"{attachment['uid']}-{name}"
attachment_folder = await self.services.channel.get_attachment_folder(channel_uid) attachment_folder = await self.services.channel.get_attachment_folder(
channel_uid
)
attachment_path = attachment_folder.joinpath(real_file_name) attachment_path = attachment_folder.joinpath(real_file_name)
attachment["path"] = str(attachment_path) attachment["path"] = str(attachment_path)
if await self.save(attachment): if await self.save(attachment):
return attachment return attachment
raise Exception(f"Failed to create channel attachment: {attachment.errors}.") raise Exception(f"Failed to create channel attachment: {attachment.errors}.")

View File

@ -11,7 +11,7 @@ class ChannelMessageService(BaseService):
model["channel_uid"] = channel_uid model["channel_uid"] = channel_uid
model["user_uid"] = user_uid model["user_uid"] = user_uid
model["message"] = message model["message"] = message
model['is_final'] = is_final model["is_final"] = is_final
context = {} context = {}
@ -56,7 +56,7 @@ class ChannelMessageService(BaseService):
async def save(self, model): async def save(self, model):
context = {} context = {}
context.update(model.record) context.update(model.record)
user = await self.app.services.user.get(model['user_uid']) user = await self.app.services.user.get(model["user_uid"])
context.update( context.update(
{ {
"user_uid": user["uid"], "user_uid": user["uid"],

View File

@ -13,13 +13,13 @@ class ChatService(BaseService):
channel["last_message_on"] = now() channel["last_message_on"] = now()
await self.services.channel.save(channel) await self.services.channel.save(channel)
await self.services.socket.broadcast( await self.services.socket.broadcast(
channel['uid'], channel["uid"],
{ {
"message": channel_message["message"], "message": channel_message["message"],
"html": channel_message["html"], "html": channel_message["html"],
"user_uid": user['uid'], "user_uid": user["uid"],
"color": user["color"], "color": user["color"],
"channel_uid": channel['uid'], "channel_uid": channel["uid"],
"created_at": channel_message["created_at"], "created_at": channel_message["created_at"],
"updated_at": channel_message["updated_at"], "updated_at": channel_message["updated_at"],
"username": user["username"], "username": user["username"],
@ -32,14 +32,12 @@ class ChatService(BaseService):
self.services.notification.create_channel_message(message_uid) self.services.notification.create_channel_message(message_uid)
) )
async def send(self, user_uid, channel_uid, message, is_final=True): async def send(self, user_uid, channel_uid, message, is_final=True):
channel = await self.services.channel.get(uid=channel_uid) channel = await self.services.channel.get(uid=channel_uid)
if not channel: if not channel:
raise Exception("Channel not found.") raise Exception("Channel not found.")
channel_message = await self.services.channel_message.create( channel_message = await self.services.channel_message.create(
channel_uid, user_uid, message,is_final channel_uid, user_uid, message, is_final
) )
channel_message_uid = channel_message["uid"] channel_message_uid = channel_message["uid"]

View File

@ -1,36 +1,51 @@
from snek.system.docker import ComposeFileManager
from snek.system.service import BaseService from snek.system.service import BaseService
from snek.system.docker import ComposeFileManager
class ContainerService(BaseService): class ContainerService(BaseService):
mapper_name = "container" mapper_name = "container"
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.compose_path = "snek-container-compose.yml" self.compose_path = "snek-container-compose.yml"
self.compose = ComposeFileManager(self.compose_path) self.compose = ComposeFileManager(self.compose_path)
async def get_instances(self): async def get_instances(self):
return list(self.compose.list_instances()) return list(self.compose.list_instances())
async def get_container_name(self, channel_uid): async def get_container_name(self, channel_uid):
return f"channel-{channel_uid}" return f"channel-{channel_uid}"
async def create(self, channel_uid, image="ubuntu:latest", command=None, cpus=1, memory='1024m', ports=None, volumes=None): async def create(
self,
channel_uid,
image="ubuntu:latest",
command=None,
cpus=1,
memory="1024m",
ports=None,
volumes=None,
):
name = await self.get_container_name(channel_uid) name = await self.get_container_name(channel_uid)
self.compose.create_instance( self.compose.create_instance(
name, name,
image, image,
command, command,
cpus, cpus,
memory, memory,
ports, ports,
["./"+str((await self.services.channel.get_home_folder(channel_uid))) +":"+ "/home/ubuntu"] [
"./"
+ str(await self.services.channel.get_home_folder(channel_uid))
+ ":"
+ "/home/ubuntu"
],
) )
async def create2(self, id, name, status, resources=None, user_uid=None, path=None, readonly=False): async def create2(
self, id, name, status, resources=None, user_uid=None, path=None, readonly=False
):
model = await self.new() model = await self.new()
model["id"] = id model["id"] = id
model["name"] = name model["name"] = name

View File

@ -1,23 +1,21 @@
from snek.system.service import BaseService import dataset
import dataset
import uuid from snek.system.service import BaseService
from datetime import datetime
class DBService(BaseService): class DBService(BaseService):
async def get_db(self, user_uid): async def get_db(self, user_uid):
home_folder = await self.app.services.user.get_home_folder(user_uid) home_folder = await self.app.services.user.get_home_folder(user_uid)
home_folder.mkdir(parents=True, exist_ok=True) home_folder.mkdir(parents=True, exist_ok=True)
db_path = home_folder.joinpath("snek/user.db") db_path = home_folder.joinpath("snek/user.db")
db_path.parent.mkdir(parents=True, exist_ok=True) db_path.parent.mkdir(parents=True, exist_ok=True)
return dataset.connect("sqlite:///" + str(db_path)) return dataset.connect("sqlite:///" + str(db_path))
async def insert(self, user_uid, table_name, values): async def insert(self, user_uid, table_name, values):
db = await self.get_db(user_uid) db = await self.get_db(user_uid)
return db[table_name].insert(values) return db[table_name].insert(values)
async def update(self, user_uid, table_name, values, filters): async def update(self, user_uid, table_name, values, filters):
db = await self.get_db(user_uid) db = await self.get_db(user_uid)
@ -33,7 +31,7 @@ class DBService(BaseService):
async def find(self, user_uid, table_name, kwargs): async def find(self, user_uid, table_name, kwargs):
db = await self.get_db(user_uid) db = await self.get_db(user_uid)
kwargs['_limit'] = kwargs.get('_limit', 30) kwargs["_limit"] = kwargs.get("_limit", 30)
return [dict(row) for row in db[table_name].find(**kwargs)] return [dict(row) for row in db[table_name].find(**kwargs)]
async def get(self, user_uid, table_name, filters): async def get(self, user_uid, table_name, filters):
@ -45,14 +43,13 @@ class DBService(BaseService):
except ValueError: except ValueError:
return None return None
async def delete(self, user_uid, table_name, filters): async def delete(self, user_uid, table_name, filters):
db = await self.get_db(user_uid) db = await self.get_db(user_uid)
if not filters: if not filters:
filters = {} filters = {}
return db[table_name].delete(**filters) return db[table_name].delete(**filters)
async def query(self, sql,values): async def query(self, sql, values):
db = await self.app.db db = await self.app.db
return [dict(row) for row in db.query(sql, values or {})] return [dict(row) for row in db.query(sql, values or {})]
@ -62,8 +59,6 @@ class DBService(BaseService):
filters = {} filters = {}
return bool(db[table_name].find_one(**filters)) return bool(db[table_name].find_one(**filters))
async def count(self, user_uid, table_name, filters): async def count(self, user_uid, table_name, filters):
db = await self.get_db(user_uid) db = await self.get_db(user_uid)
if not filters: if not filters:

View File

@ -1,6 +1,7 @@
from snek.system.markdown import strip_markdown
from snek.system.model import now from snek.system.model import now
from snek.system.service import BaseService from snek.system.service import BaseService
from snek.system.markdown import strip_markdown
class NotificationService(BaseService): class NotificationService(BaseService):
mapper_name = "notification" mapper_name = "notification"
@ -34,7 +35,7 @@ class NotificationService(BaseService):
uid=channel_message_uid uid=channel_message_uid
) )
if not channel_message["is_final"]: if not channel_message["is_final"]:
return return
user = await self.services.user.get(uid=channel_message["user_uid"]) user = await self.services.user.get(uid=channel_message["user_uid"])
self.app.db.begin() self.app.db.begin()
async for channel_member in self.services.channel_member.find( async for channel_member in self.services.channel_member.find(

View File

@ -1,25 +1,23 @@
import base64
import json import json
import os.path
import aiohttp
from snek.system.service import BaseService
import random import random
import time import time
import base64
import uuid import uuid
from functools import cache
from pathlib import Path from pathlib import Path
import jwt
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend
from urllib.parse import urlparse from urllib.parse import urlparse
import os.path
import aiohttp
import jwt
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.ciphers.aead import AESGCM from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives.hashes import SHA256 from cryptography.hazmat.primitives.hashes import SHA256
from cryptography.hazmat.primitives.kdf.hkdf import HKDF from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from snek.system.service import BaseService
# The only reason to persist the keys is to be able to use them in the web push # The only reason to persist the keys is to be able to use them in the web push
PRIVATE_KEY_FILE = Path("./notification-private.pem") PRIVATE_KEY_FILE = Path("./notification-private.pem")
@ -268,4 +266,4 @@ class PushService(BaseService):
raise Exception( raise Exception(
f"Failed to register push subscription for user {user_uid} with endpoint {endpoint}" f"Failed to register push subscription for user {user_uid} with endpoint {endpoint}"
) )

View File

@ -1,13 +1,17 @@
from snek.system.service import BaseService import asyncio
import asyncio
import shutil import shutil
from snek.system.service import BaseService
class RepositoryService(BaseService): class RepositoryService(BaseService):
mapper_name = "repository" mapper_name = "repository"
async def delete(self, user_uid, name): async def delete(self, user_uid, name):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
repository_path = (await self.services.user.get_repository_path(user_uid)).joinpath(name) repository_path = (
await self.services.user.get_repository_path(user_uid)
).joinpath(name)
try: try:
await loop.run_in_executor(None, shutil.rmtree, repository_path) await loop.run_in_executor(None, shutil.rmtree, repository_path)
except Exception as ex: except Exception as ex:
@ -15,7 +19,6 @@ class RepositoryService(BaseService):
await super().delete(user_uid=user_uid, name=name) await super().delete(user_uid=user_uid, name=name)
async def exists(self, user_uid, name, **kwargs): async def exists(self, user_uid, name, **kwargs):
kwargs["user_uid"] = user_uid kwargs["user_uid"] = user_uid
kwargs["name"] = name kwargs["name"] = name
@ -29,18 +32,16 @@ class RepositoryService(BaseService):
repository_path = str(repository_path) repository_path = str(repository_path)
if not repository_path.endswith(".git"): if not repository_path.endswith(".git"):
repository_path += ".git" repository_path += ".git"
command = ['git', 'init', '--bare', repository_path] command = ["git", "init", "--bare", repository_path]
process = await asyncio.subprocess.create_subprocess_exec( process = await asyncio.subprocess.create_subprocess_exec(
*command, *command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
) )
stdout, stderr = await process.communicate() stdout, stderr = await process.communicate()
return process.returncode == 0 return process.returncode == 0
async def create(self, user_uid, name,is_private=False): async def create(self, user_uid, name, is_private=False):
if await self.exists(user_uid=user_uid, name=name): if await self.exists(user_uid=user_uid, name=name):
return False return False
if not await self.init(user_uid=user_uid, name=name): if not await self.init(user_uid=user_uid, name=name):
return False return False

View File

@ -1,12 +1,14 @@
import asyncio
import logging
from datetime import datetime
from snek.model.user import UserModel from snek.model.user import UserModel
from snek.system.service import BaseService from snek.system.service import BaseService
from datetime import datetime
import json
import asyncio
import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from snek.system.model import now from snek.system.model import now
class SocketService(BaseService): class SocketService(BaseService):
class Socket: class Socket:
@ -15,7 +17,6 @@ class SocketService(BaseService):
self.is_connected = True self.is_connected = True
self.user = user self.user = user
async def send_json(self, data): async def send_json(self, data):
if not self.is_connected: if not self.is_connected:
return False return False
@ -40,7 +41,6 @@ class SocketService(BaseService):
self.users = {} self.users = {}
self.subscriptions = {} self.subscriptions = {}
self.last_update = str(datetime.now()) self.last_update = str(datetime.now())
async def user_availability_service(self): async def user_availability_service(self):
logger.info("User availability update service started.") logger.info("User availability update service started.")
@ -50,14 +50,15 @@ class SocketService(BaseService):
for s in self.sockets: for s in self.sockets:
if not s.is_connected: if not s.is_connected:
continue continue
if not s.user in users_updated: if s.user not in users_updated:
s.user["last_ping"] = now() s.user["last_ping"] = now()
await self.app.services.user.save(s.user) await self.app.services.user.save(s.user)
users_updated.append(s.user) users_updated.append(s.user)
logger.info(f"Updated user availability for {len(users_updated)} online users.") logger.info(
f"Updated user availability for {len(users_updated)} online users."
)
await asyncio.sleep(60) await asyncio.sleep(60)
async def add(self, ws, user_uid): async def add(self, ws, user_uid):
s = self.Socket(ws, await self.app.services.user.get(uid=user_uid)) s = self.Socket(ws, await self.app.services.user.get(uid=user_uid))
self.sockets.add(s) self.sockets.add(s)
@ -81,7 +82,6 @@ class SocketService(BaseService):
count += 1 count += 1
return count return count
async def broadcast(self, channel_uid, message): async def broadcast(self, channel_uid, message):
await self._broadcast(channel_uid, message) await self._broadcast(channel_uid, message)
@ -102,4 +102,3 @@ class SocketService(BaseService):
await s.close() await s.close()
logger.info(f"Removed socket for user {s.user['username']}") logger.info(f"Removed socket for user {s.user['username']}")
self.sockets.remove(s) self.sockets.remove(s)

View File

@ -32,14 +32,14 @@ class UserService(BaseService):
user["color"] = await self.services.util.random_light_hex_color() user["color"] = await self.services.util.random_light_hex_color()
return await super().save(user) return await super().save(user)
def authenticate_sync(self,username,password): def authenticate_sync(self, username, password):
user = self.get_by_username_sync(username) user = self.get_by_username_sync(username)
if not user: if not user:
return False return False
if not security.verify_sync(password, user["password"]): if not security.verify_sync(password, user["password"]):
return False return False
return True return True
async def authenticate(self, username, password): async def authenticate(self, username, password):
success = await self.validate_login(username, password) success = await self.validate_login(username, password)
@ -61,14 +61,12 @@ class UserService(BaseService):
return None return None
return path return path
async def get_template_path(self, user_uid): async def get_template_path(self, user_uid):
path = pathlib.Path(f"./drive/{user_uid}/snek/templates") path = pathlib.Path(f"./drive/{user_uid}/snek/templates")
if not path.exists(): if not path.exists():
return None return None
return path return path
def get_by_username_sync(self, username): def get_by_username_sync(self, username):
user = self.mapper.db["user"].find_one(username=username, deleted_at=None) user = self.mapper.db["user"].find_one(username=username, deleted_at=None)
return dict(user) return dict(user)

View File

@ -1,48 +1,51 @@
import asyncio
import base64
import json
import logging
import os import os
import aiohttp import shutil
import tempfile
from aiohttp import web from aiohttp import web
import shutil logging.basicConfig(
import json level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
import tempfile )
import asyncio logger = logging.getLogger("git_server")
import logging
import base64
import pathlib
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger('git_server')
class GitApplication(web.Application): class GitApplication(web.Application):
def __init__(self, parent=None): def __init__(self, parent=None):
#import git # import git
#globals()['git'] = git # globals()['git'] = git
self.parent = parent self.parent = parent
super().__init__(client_max_size=1024*1024*1024*5) super().__init__(client_max_size=1024 * 1024 * 1024 * 5)
self.add_routes([ self.add_routes(
web.post('/create/{repo_name}', self.create_repository), [
web.delete('/delete/{repo_name}', self.delete_repository), web.post("/create/{repo_name}", self.create_repository),
web.get('/clone/{repo_name}', self.clone_repository), web.delete("/delete/{repo_name}", self.delete_repository),
web.post('/push/{repo_name}', self.push_repository), web.get("/clone/{repo_name}", self.clone_repository),
web.post('/pull/{repo_name}', self.pull_repository), web.post("/push/{repo_name}", self.push_repository),
web.get('/status/{repo_name}', self.status_repository), web.post("/pull/{repo_name}", self.pull_repository),
#web.get('/list', self.list_repositories), web.get("/status/{repo_name}", self.status_repository),
web.get('/branches/{repo_name}', self.list_branches), # web.get('/list', self.list_repositories),
web.post('/branches/{repo_name}', self.create_branch), web.get("/branches/{repo_name}", self.list_branches),
web.get('/log/{repo_name}', self.commit_log), web.post("/branches/{repo_name}", self.create_branch),
web.get('/file/{repo_name}/{file_path:.*}', self.file_content), web.get("/log/{repo_name}", self.commit_log),
web.get('/{path:.+}/info/refs', self.git_smart_http), web.get("/file/{repo_name}/{file_path:.*}", self.file_content),
web.post('/{path:.+}/git-upload-pack', self.git_smart_http), web.get("/{path:.+}/info/refs", self.git_smart_http),
web.post('/{path:.+}/git-receive-pack', self.git_smart_http), web.post("/{path:.+}/git-upload-pack", self.git_smart_http),
web.get('/{repo_name}.git/info/refs', self.git_smart_http), web.post("/{path:.+}/git-receive-pack", self.git_smart_http),
web.post('/{repo_name}.git/git-upload-pack', self.git_smart_http), web.get("/{repo_name}.git/info/refs", self.git_smart_http),
web.post('/{repo_name}.git/git-receive-pack', self.git_smart_http), web.post("/{repo_name}.git/git-upload-pack", self.git_smart_http),
]) web.post("/{repo_name}.git/git-receive-pack", self.git_smart_http),
]
)
async def check_basic_auth(self, request): async def check_basic_auth(self, request):
auth_header = request.headers.get("Authorization", "") auth_header = request.headers.get("Authorization", "")
if not auth_header.startswith("Basic "): if not auth_header.startswith("Basic "):
return None,None return None, None
encoded_creds = auth_header.split("Basic ")[1] encoded_creds = auth_header.split("Basic ")[1]
decoded_creds = base64.b64decode(encoded_creds).decode() decoded_creds = base64.b64decode(encoded_creds).decode()
username, password = decoded_creds.split(":", 1) username, password = decoded_creds.split(":", 1)
@ -50,27 +53,31 @@ class GitApplication(web.Application):
username=username, password=password username=username, password=password
) )
if not request["user"]: if not request["user"]:
return None,None return None, None
request["repository_path"] = await self.parent.services.user.get_repository_path( request["repository_path"] = (
request["user"]["uid"] await self.parent.services.user.get_repository_path(request["user"]["uid"])
) )
return request["user"]['username'],request["repository_path"] return request["user"]["username"], request["repository_path"]
@staticmethod @staticmethod
def require_auth(handler): def require_auth(handler):
async def wrapped(self, request, *args, **kwargs): async def wrapped(self, request, *args, **kwargs):
username, repository_path = await self.check_basic_auth(request) username, repository_path = await self.check_basic_auth(request)
if not username or not repository_path: if not username or not repository_path:
return web.Response(status=401, headers={'WWW-Authenticate': 'Basic'}, text='Authentication required') return web.Response(
request['username'] = username status=401,
request['repository_path'] = repository_path headers={"WWW-Authenticate": "Basic"},
text="Authentication required",
)
request["username"] = username
request["repository_path"] = repository_path
return await handler(self, request, *args, **kwargs) return await handler(self, request, *args, **kwargs)
return wrapped return wrapped
def repo_path(self, repository_path, repo_name): def repo_path(self, repository_path, repo_name):
return repository_path.joinpath(repo_name + '.git') return repository_path.joinpath(repo_name + ".git")
def check_repo_exists(self, repository_path, repo_name): def check_repo_exists(self, repository_path, repo_name):
repo_dir = self.repo_path(repository_path, repo_name) repo_dir = self.repo_path(repository_path, repo_name)
@ -80,10 +87,10 @@ class GitApplication(web.Application):
@require_auth @require_auth
async def create_repository(self, request): async def create_repository(self, request):
username = request['username'] username = request["username"]
repo_name = request.match_info['repo_name'] repo_name = request.match_info["repo_name"]
repository_path = request['repository_path'] repository_path = request["repository_path"]
if not repo_name or '/' in repo_name or '..' in repo_name: if not repo_name or "/" in repo_name or ".." in repo_name:
return web.Response(text="Invalid repository name", status=400) return web.Response(text="Invalid repository name", status=400)
repo_dir = self.repo_path(repository_path, repo_name) repo_dir = self.repo_path(repository_path, repo_name)
if os.path.exists(repo_dir): if os.path.exists(repo_dir):
@ -98,9 +105,9 @@ class GitApplication(web.Application):
@require_auth @require_auth
async def delete_repository(self, request): async def delete_repository(self, request):
username = request['username'] username = request["username"]
repo_name = request.match_info['repo_name'] repo_name = request.match_info["repo_name"]
repository_path = request['repository_path'] repository_path = request["repository_path"]
error_response = self.check_repo_exists(repository_path, repo_name) error_response = self.check_repo_exists(repository_path, repo_name)
if error_response: if error_response:
return error_response return error_response
@ -115,9 +122,9 @@ class GitApplication(web.Application):
@require_auth @require_auth
async def clone_repository(self, request): async def clone_repository(self, request):
username = request['username'] request["username"]
repo_name = request.match_info['repo_name'] repo_name = request.match_info["repo_name"]
repository_path = request['repository_path'] repository_path = request["repository_path"]
error_response = self.check_repo_exists(repository_path, repo_name) error_response = self.check_repo_exists(repository_path, repo_name)
if error_response: if error_response:
return error_response return error_response
@ -126,15 +133,15 @@ class GitApplication(web.Application):
response_data = { response_data = {
"repository": repo_name, "repository": repo_name,
"clone_command": f"git clone {clone_url}", "clone_command": f"git clone {clone_url}",
"clone_url": clone_url "clone_url": clone_url,
} }
return web.json_response(response_data) return web.json_response(response_data)
@require_auth @require_auth
async def push_repository(self, request): async def push_repository(self, request):
username = request['username'] username = request["username"]
repo_name = request.match_info['repo_name'] repo_name = request.match_info["repo_name"]
repository_path = request['repository_path'] repository_path = request["repository_path"]
error_response = self.check_repo_exists(repository_path, repo_name) error_response = self.check_repo_exists(repository_path, repo_name)
if error_response: if error_response:
return error_response return error_response
@ -142,34 +149,40 @@ class GitApplication(web.Application):
data = await request.json() data = await request.json()
except json.JSONDecodeError: except json.JSONDecodeError:
return web.Response(text="Invalid JSON data", status=400) return web.Response(text="Invalid JSON data", status=400)
commit_message = data.get('commit_message', 'Update from server') commit_message = data.get("commit_message", "Update from server")
branch = data.get('branch', 'main') branch = data.get("branch", "main")
changes = data.get('changes', []) changes = data.get("changes", [])
if not changes: if not changes:
return web.Response(text="No changes provided", status=400) return web.Response(text="No changes provided", status=400)
with tempfile.TemporaryDirectory() as temp_dir: with tempfile.TemporaryDirectory() as temp_dir:
temp_repo = git.Repo.clone_from(self.repo_path(repository_path, repo_name), temp_dir) temp_repo = git.Repo.clone_from(
self.repo_path(repository_path, repo_name), temp_dir
)
for change in changes: for change in changes:
file_path = os.path.join(temp_dir, change.get('file', '')) file_path = os.path.join(temp_dir, change.get("file", ""))
content = change.get('content', '') content = change.get("content", "")
os.makedirs(os.path.dirname(file_path), exist_ok=True) os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, 'w') as f: with open(file_path, "w") as f:
f.write(content) f.write(content)
temp_repo.git.add(A=True) temp_repo.git.add(A=True)
if not temp_repo.config_reader().has_section('user'): if not temp_repo.config_reader().has_section("user"):
temp_repo.config_writer().set_value("user", "name", "Git Server").release() temp_repo.config_writer().set_value(
temp_repo.config_writer().set_value("user", "email", "git@server.local").release() "user", "name", "Git Server"
).release()
temp_repo.config_writer().set_value(
"user", "email", "git@server.local"
).release()
temp_repo.index.commit(commit_message) temp_repo.index.commit(commit_message)
origin = temp_repo.remote('origin') origin = temp_repo.remote("origin")
origin.push(refspec=f"{branch}:{branch}") origin.push(refspec=f"{branch}:{branch}")
logger.info(f"Pushed to repository: {repo_name} for user {username}") logger.info(f"Pushed to repository: {repo_name} for user {username}")
return web.Response(text=f"Successfully pushed changes to {repo_name}") return web.Response(text=f"Successfully pushed changes to {repo_name}")
@require_auth @require_auth
async def pull_repository(self, request): async def pull_repository(self, request):
username = request['username'] username = request["username"]
repo_name = request.match_info['repo_name'] repo_name = request.match_info["repo_name"]
repository_path = request['repository_path'] repository_path = request["repository_path"]
error_response = self.check_repo_exists(repository_path, repo_name) error_response = self.check_repo_exists(repository_path, repo_name)
if error_response: if error_response:
return error_response return error_response
@ -177,13 +190,15 @@ class GitApplication(web.Application):
data = await request.json() data = await request.json()
except json.JSONDecodeError: except json.JSONDecodeError:
data = {} data = {}
remote_url = data.get('remote_url') remote_url = data.get("remote_url")
branch = data.get('branch', 'main') branch = data.get("branch", "main")
if not remote_url: if not remote_url:
return web.Response(text="Remote URL is required", status=400) return web.Response(text="Remote URL is required", status=400)
with tempfile.TemporaryDirectory() as temp_dir: with tempfile.TemporaryDirectory() as temp_dir:
try: try:
local_repo = git.Repo.clone_from(self.repo_path(repository_path, repo_name), temp_dir) local_repo = git.Repo.clone_from(
self.repo_path(repository_path, repo_name), temp_dir
)
remote_name = "pull_source" remote_name = "pull_source"
try: try:
remote = local_repo.create_remote(remote_name, remote_url) remote = local_repo.create_remote(remote_name, remote_url)
@ -192,38 +207,46 @@ class GitApplication(web.Application):
remote.set_url(remote_url) remote.set_url(remote_url)
remote.fetch() remote.fetch()
local_repo.git.merge(f"{remote_name}/{branch}") local_repo.git.merge(f"{remote_name}/{branch}")
origin = local_repo.remote('origin') origin = local_repo.remote("origin")
origin.push() origin.push()
logger.info(f"Pulled to repository {repo_name} from {remote_url} for user {username}") logger.info(
return web.Response(text=f"Successfully pulled changes from {remote_url} to {repo_name}") f"Pulled to repository {repo_name} from {remote_url} for user {username}"
)
return web.Response(
text=f"Successfully pulled changes from {remote_url} to {repo_name}"
)
except Exception as e: except Exception as e:
logger.error(f"Error pulling to {repo_name}: {str(e)}") logger.error(f"Error pulling to {repo_name}: {str(e)}")
return web.Response(text=f"Error pulling changes: {str(e)}", status=500) return web.Response(text=f"Error pulling changes: {str(e)}", status=500)
@require_auth @require_auth
async def status_repository(self, request): async def status_repository(self, request):
username = request['username'] request["username"]
repo_name = request.match_info['repo_name'] repo_name = request.match_info["repo_name"]
repository_path = request['repository_path'] repository_path = request["repository_path"]
error_response = self.check_repo_exists(repository_path, repo_name) error_response = self.check_repo_exists(repository_path, repo_name)
if error_response: if error_response:
return error_response return error_response
with tempfile.TemporaryDirectory() as temp_dir: with tempfile.TemporaryDirectory() as temp_dir:
try: try:
temp_repo = git.Repo.clone_from(self.repo_path(repository_path, repo_name), temp_dir) temp_repo = git.Repo.clone_from(
self.repo_path(repository_path, repo_name), temp_dir
)
branches = [b.name for b in temp_repo.branches] branches = [b.name for b in temp_repo.branches]
active_branch = temp_repo.active_branch.name active_branch = temp_repo.active_branch.name
commits = [] commits = []
for commit in list(temp_repo.iter_commits(max_count=5)): for commit in list(temp_repo.iter_commits(max_count=5)):
commits.append({ commits.append(
"id": commit.hexsha, {
"author": f"{commit.author.name} <{commit.author.email}>", "id": commit.hexsha,
"date": commit.committed_datetime.isoformat(), "author": f"{commit.author.name} <{commit.author.email}>",
"message": commit.message "date": commit.committed_datetime.isoformat(),
}) "message": commit.message,
}
)
files = [] files = []
for root, dirs, filenames in os.walk(temp_dir): for root, dirs, filenames in os.walk(temp_dir):
if '.git' in root: if ".git" in root:
continue continue
for filename in filenames: for filename in filenames:
full_path = os.path.join(root, filename) full_path = os.path.join(root, filename)
@ -234,50 +257,58 @@ class GitApplication(web.Application):
"branches": branches, "branches": branches,
"active_branch": active_branch, "active_branch": active_branch,
"recent_commits": commits, "recent_commits": commits,
"files": files "files": files,
} }
return web.json_response(status_info) return web.json_response(status_info)
except Exception as e: except Exception as e:
logger.error(f"Error getting status for {repo_name}: {str(e)}") logger.error(f"Error getting status for {repo_name}: {str(e)}")
return web.Response(text=f"Error getting repository status: {str(e)}", status=500) return web.Response(
text=f"Error getting repository status: {str(e)}", status=500
)
@require_auth @require_auth
async def list_repositories(self, request): async def list_repositories(self, request):
username = request['username'] request["username"]
try: try:
repos = [] repos = []
user_dir = self.REPO_DIR user_dir = self.REPO_DIR
if os.path.exists(user_dir): if os.path.exists(user_dir):
for item in os.listdir(user_dir): for item in os.listdir(user_dir):
item_path = os.path.join(user_dir, item) item_path = os.path.join(user_dir, item)
if os.path.isdir(item_path) and item.endswith('.git'): if os.path.isdir(item_path) and item.endswith(".git"):
repos.append(item[:-4]) repos.append(item[:-4])
if request.query.get('format') == 'json': if request.query.get("format") == "json":
return web.json_response({"repositories": repos}) return web.json_response({"repositories": repos})
else: else:
return web.Response(text="\n".join(repos) if repos else "No repositories found") return web.Response(
text="\n".join(repos) if repos else "No repositories found"
)
except Exception as e: except Exception as e:
logger.error(f"Error listing repositories: {str(e)}") logger.error(f"Error listing repositories: {str(e)}")
return web.Response(text=f"Error listing repositories: {str(e)}", status=500) return web.Response(
text=f"Error listing repositories: {str(e)}", status=500
)
@require_auth @require_auth
async def list_branches(self, request): async def list_branches(self, request):
username = request['username'] request["username"]
repo_name = request.match_info['repo_name'] repo_name = request.match_info["repo_name"]
repository_path = request['repository_path'] repository_path = request["repository_path"]
error_response = self.check_repo_exists(repository_path, repo_name) error_response = self.check_repo_exists(repository_path, repo_name)
if error_response: if error_response:
return error_response return error_response
with tempfile.TemporaryDirectory() as temp_dir: with tempfile.TemporaryDirectory() as temp_dir:
temp_repo = git.Repo.clone_from(self.repo_path(repository_path, repo_name), temp_dir) temp_repo = git.Repo.clone_from(
self.repo_path(repository_path, repo_name), temp_dir
)
branches = [b.name for b in temp_repo.branches] branches = [b.name for b in temp_repo.branches]
return web.json_response({"branches": branches}) return web.json_response({"branches": branches})
@require_auth @require_auth
async def create_branch(self, request): async def create_branch(self, request):
username = request['username'] username = request["username"]
repo_name = request.match_info['repo_name'] repo_name = request.match_info["repo_name"]
repository_path = request['repository_path'] repository_path = request["repository_path"]
error_response = self.check_repo_exists(repository_path, repo_name) error_response = self.check_repo_exists(repository_path, repo_name)
if error_response: if error_response:
return error_response return error_response
@ -285,145 +316,168 @@ class GitApplication(web.Application):
data = await request.json() data = await request.json()
except json.JSONDecodeError: except json.JSONDecodeError:
return web.Response(text="Invalid JSON data", status=400) return web.Response(text="Invalid JSON data", status=400)
branch_name = data.get('branch_name') branch_name = data.get("branch_name")
start_point = data.get('start_point', 'HEAD') start_point = data.get("start_point", "HEAD")
if not branch_name: if not branch_name:
return web.Response(text="Branch name is required", status=400) return web.Response(text="Branch name is required", status=400)
with tempfile.TemporaryDirectory() as temp_dir: with tempfile.TemporaryDirectory() as temp_dir:
try: try:
temp_repo = git.Repo.clone_from(self.repo_path(repository_path, repo_name), temp_dir) temp_repo = git.Repo.clone_from(
self.repo_path(repository_path, repo_name), temp_dir
)
temp_repo.git.branch(branch_name, start_point) temp_repo.git.branch(branch_name, start_point)
temp_repo.git.push('origin', branch_name) temp_repo.git.push("origin", branch_name)
logger.info(f"Created branch {branch_name} in repository {repo_name} for user {username}") logger.info(
f"Created branch {branch_name} in repository {repo_name} for user {username}"
)
return web.Response(text=f"Created branch {branch_name}") return web.Response(text=f"Created branch {branch_name}")
except Exception as e: except Exception as e:
logger.error(f"Error creating branch {branch_name} in {repo_name}: {str(e)}") logger.error(
f"Error creating branch {branch_name} in {repo_name}: {str(e)}"
)
return web.Response(text=f"Error creating branch: {str(e)}", status=500) return web.Response(text=f"Error creating branch: {str(e)}", status=500)
@require_auth @require_auth
async def commit_log(self, request): async def commit_log(self, request):
username = request['username'] request["username"]
repo_name = request.match_info['repo_name'] repo_name = request.match_info["repo_name"]
repository_path = request['repository_path'] repository_path = request["repository_path"]
error_response = self.check_repo_exists(repository_path, repo_name) error_response = self.check_repo_exists(repository_path, repo_name)
if error_response: if error_response:
return error_response return error_response
try: try:
limit = int(request.query.get('limit', 10)) limit = int(request.query.get("limit", 10))
branch = request.query.get('branch', 'main') branch = request.query.get("branch", "main")
except ValueError: except ValueError:
return web.Response(text="Invalid limit parameter", status=400) return web.Response(text="Invalid limit parameter", status=400)
with tempfile.TemporaryDirectory() as temp_dir: with tempfile.TemporaryDirectory() as temp_dir:
try: try:
temp_repo = git.Repo.clone_from(self.repo_path(repository_path, repo_name), temp_dir) temp_repo = git.Repo.clone_from(
self.repo_path(repository_path, repo_name), temp_dir
)
commits = [] commits = []
try: try:
for commit in list(temp_repo.iter_commits(branch, max_count=limit)): for commit in list(temp_repo.iter_commits(branch, max_count=limit)):
commits.append({ commits.append(
"id": commit.hexsha, {
"short_id": commit.hexsha[:7], "id": commit.hexsha,
"author": f"{commit.author.name} <{commit.author.email}>", "short_id": commit.hexsha[:7],
"date": commit.committed_datetime.isoformat(), "author": f"{commit.author.name} <{commit.author.email}>",
"message": commit.message.strip() "date": commit.committed_datetime.isoformat(),
}) "message": commit.message.strip(),
}
)
except git.GitCommandError as e: except git.GitCommandError as e:
if "unknown revision or path" in str(e): if "unknown revision or path" in str(e):
commits = [] commits = []
else: else:
raise raise
return web.json_response({ return web.json_response(
"repository": repo_name, {"repository": repo_name, "branch": branch, "commits": commits}
"branch": branch, )
"commits": commits
})
except Exception as e: except Exception as e:
logger.error(f"Error getting commit log for {repo_name}: {str(e)}") logger.error(f"Error getting commit log for {repo_name}: {str(e)}")
return web.Response(text=f"Error getting commit log: {str(e)}", status=500) return web.Response(
text=f"Error getting commit log: {str(e)}", status=500
)
@require_auth @require_auth
async def file_content(self, request): async def file_content(self, request):
username = request['username'] request["username"]
repo_name = request.match_info['repo_name'] repo_name = request.match_info["repo_name"]
file_path = request.match_info.get('file_path', '') file_path = request.match_info.get("file_path", "")
branch = request.query.get('branch', 'main') branch = request.query.get("branch", "main")
repository_path = request['repository_path'] repository_path = request["repository_path"]
error_response = self.check_repo_exists(repository_path, repo_name) error_response = self.check_repo_exists(repository_path, repo_name)
if error_response: if error_response:
return error_response return error_response
with tempfile.TemporaryDirectory() as temp_dir: with tempfile.TemporaryDirectory() as temp_dir:
try: try:
temp_repo = git.Repo.clone_from(self.repo_path(repository_path, repo_name), temp_dir) temp_repo = git.Repo.clone_from(
self.repo_path(repository_path, repo_name), temp_dir
)
try: try:
temp_repo.git.checkout(branch) temp_repo.git.checkout(branch)
except git.GitCommandError: except git.GitCommandError:
return web.Response(text=f"Branch '{branch}' not found", status=404) return web.Response(text=f"Branch '{branch}' not found", status=404)
file_full_path = os.path.join(temp_dir, file_path) file_full_path = os.path.join(temp_dir, file_path)
if not os.path.exists(file_full_path): if not os.path.exists(file_full_path):
return web.Response(text=f"File '{file_path}' not found", status=404) return web.Response(
text=f"File '{file_path}' not found", status=404
)
if os.path.isdir(file_full_path): if os.path.isdir(file_full_path):
files = os.listdir(file_full_path) files = os.listdir(file_full_path)
return web.json_response({ return web.json_response(
"repository": repo_name, {
"path": file_path, "repository": repo_name,
"type": "directory", "path": file_path,
"contents": files "type": "directory",
}) "contents": files,
}
)
else: else:
try: try:
with open(file_full_path, 'r') as f: with open(file_full_path) as f:
content = f.read() content = f.read()
return web.Response(text=content) return web.Response(text=content)
except UnicodeDecodeError: except UnicodeDecodeError:
return web.Response(text=f"Cannot display binary file content for '{file_path}'", status=400) return web.Response(
text=f"Cannot display binary file content for '{file_path}'",
status=400,
)
except Exception as e: except Exception as e:
logger.error(f"Error getting file content from {repo_name}: {str(e)}") logger.error(f"Error getting file content from {repo_name}: {str(e)}")
return web.Response(text=f"Error getting file content: {str(e)}", status=500) return web.Response(
text=f"Error getting file content: {str(e)}", status=500
)
@require_auth @require_auth
async def git_smart_http(self, request): async def git_smart_http(self, request):
username = request['username'] request["username"]
repository_path = request['repository_path'] repository_path = request["repository_path"]
path = request.path path = request.path
async def get_repository_path(): async def get_repository_path():
req_path = path.lstrip('/') req_path = path.lstrip("/")
if req_path.endswith('/info/refs'): if req_path.endswith("/info/refs"):
repo_name = req_path[:-len('/info/refs')] repo_name = req_path[: -len("/info/refs")]
elif req_path.endswith('/git-upload-pack'): elif req_path.endswith("/git-upload-pack"):
repo_name = req_path[:-len('/git-upload-pack')] repo_name = req_path[: -len("/git-upload-pack")]
elif req_path.endswith('/git-receive-pack'): elif req_path.endswith("/git-receive-pack"):
repo_name = req_path[:-len('/git-receive-pack')] repo_name = req_path[: -len("/git-receive-pack")]
else: else:
repo_name = req_path repo_name = req_path
if repo_name.endswith('.git'): if repo_name.endswith(".git"):
repo_name = repo_name[:-4] repo_name = repo_name[:-4]
repo_name = repo_name[4:] repo_name = repo_name[4:]
repo_dir = repository_path.joinpath(repo_name + ".git") repo_dir = repository_path.joinpath(repo_name + ".git")
logger.info(f"Resolved repo path: {repo_dir}") logger.info(f"Resolved repo path: {repo_dir}")
return repo_dir return repo_dir
async def handle_info_refs(service): async def handle_info_refs(service):
repo_path = await get_repository_path() repo_path = await get_repository_path()
logger.info(f"handle_info_refs: {repo_path}") logger.info(f"handle_info_refs: {repo_path}")
if not os.path.exists(repo_path): if not os.path.exists(repo_path):
return web.Response(text="Repository not found", status=404) return web.Response(text="Repository not found", status=404)
cmd = [service, '--stateless-rpc', '--advertise-refs', str(repo_path)] cmd = [service, "--stateless-rpc", "--advertise-refs", str(repo_path)]
try: try:
process = await asyncio.create_subprocess_exec( process = await asyncio.create_subprocess_exec(
*cmd, *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
) )
stdout, stderr = await process.communicate() stdout, stderr = await process.communicate()
if process.returncode != 0: if process.returncode != 0:
logger.error(f"Git command failed: {stderr.decode()}") logger.error(f"Git command failed: {stderr.decode()}")
return web.Response(text=f"Git error: {stderr.decode()}", status=500) return web.Response(
text=f"Git error: {stderr.decode()}", status=500
)
response = web.StreamResponse( response = web.StreamResponse(
status=200, status=200,
reason='OK', reason="OK",
headers={ headers={
'Content-Type': f'application/x-{service}-advertisement', "Content-Type": f"application/x-{service}-advertisement",
'Cache-Control': 'no-cache' "Cache-Control": "no-cache",
} },
) )
await response.prepare(request) await response.prepare(request)
packet = f"# service={service}\n" packet = f"# service={service}\n"
@ -435,48 +489,58 @@ class GitApplication(web.Application):
except Exception as e: except Exception as e:
logger.error(f"Error handling info/refs: {str(e)}") logger.error(f"Error handling info/refs: {str(e)}")
return web.Response(text=f"Server error: {str(e)}", status=500) return web.Response(text=f"Server error: {str(e)}", status=500)
async def handle_service_rpc(service): async def handle_service_rpc(service):
repo_path = await get_repository_path() repo_path = await get_repository_path()
logger.info(f"handle_service_rpc: {repo_path}") logger.info(f"handle_service_rpc: {repo_path}")
if not os.path.exists(repo_path): if not os.path.exists(repo_path):
return web.Response(text="Repository not found", status=404) return web.Response(text="Repository not found", status=404)
if not request.headers.get('Content-Type') == f'application/x-{service}-request': if (
not request.headers.get("Content-Type")
== f"application/x-{service}-request"
):
return web.Response(text="Invalid Content-Type", status=403) return web.Response(text="Invalid Content-Type", status=403)
body = await request.read() body = await request.read()
cmd = [service, '--stateless-rpc', str(repo_path)] cmd = [service, "--stateless-rpc", str(repo_path)]
try: try:
process = await asyncio.create_subprocess_exec( process = await asyncio.create_subprocess_exec(
*cmd, *cmd,
stdin=asyncio.subprocess.PIPE, stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE stderr=asyncio.subprocess.PIPE,
) )
stdout, stderr = await process.communicate(input=body) stdout, stderr = await process.communicate(input=body)
if process.returncode != 0: if process.returncode != 0:
logger.error(f"Git command failed: {stderr.decode()}") logger.error(f"Git command failed: {stderr.decode()}")
return web.Response(text=f"Git error: {stderr.decode()}", status=500) return web.Response(
text=f"Git error: {stderr.decode()}", status=500
)
return web.Response( return web.Response(
body=stdout, body=stdout, content_type=f"application/x-{service}-result"
content_type=f'application/x-{service}-result'
) )
except Exception as e: except Exception as e:
logger.error(f"Error handling service RPC: {str(e)}") logger.error(f"Error handling service RPC: {str(e)}")
return web.Response(text=f"Server error: {str(e)}", status=500) return web.Response(text=f"Server error: {str(e)}", status=500)
if request.method == 'GET' and path.endswith('/info/refs'):
service = request.query.get('service') if request.method == "GET" and path.endswith("/info/refs"):
if service in ('git-upload-pack', 'git-receive-pack'): service = request.query.get("service")
if service in ("git-upload-pack", "git-receive-pack"):
return await handle_info_refs(service) return await handle_info_refs(service)
else: else:
return web.Response(text="Smart HTTP requires service parameter", status=400) return web.Response(
elif request.method == 'POST' and '/git-upload-pack' in path: text="Smart HTTP requires service parameter", status=400
return await handle_service_rpc('git-upload-pack') )
elif request.method == 'POST' and '/git-receive-pack' in path: elif request.method == "POST" and "/git-upload-pack" in path:
return await handle_service_rpc('git-receive-pack') return await handle_service_rpc("git-upload-pack")
elif request.method == "POST" and "/git-receive-pack" in path:
return await handle_service_rpc("git-receive-pack")
return web.Response(text="Not found", status=404) return web.Response(text="Not found", status=404)
if __name__ == '__main__':
if __name__ == "__main__":
try: try:
import uvloop import uvloop
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
logger.info("Using uvloop for improved performance") logger.info("Using uvloop for improved performance")
except ImportError: except ImportError:

View File

@ -2,28 +2,28 @@ import aiohttp
ENABLED = False ENABLED = False
import aiohttp
import asyncio import asyncio
from aiohttp import web import json
import sqlite3 import sqlite3
import dataset import aiohttp
from aiohttp import web
from sqlalchemy import event from sqlalchemy import event
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
import json
queue = asyncio.Queue() queue = asyncio.Queue()
class State: class State:
do_not_sync = False do_not_sync = False
async def sync_service(app): async def sync_service(app):
if not ENABLED: if not ENABLED:
return return
session = aiohttp.ClientSession() session = aiohttp.ClientSession()
async with session.ws_connect('http://localhost:3131/ws') as ws: async with session.ws_connect("http://localhost:3131/ws") as ws:
async def receive(): async def receive():
queries_synced = 0 queries_synced = 0
@ -31,7 +31,7 @@ async def sync_service(app):
if msg.type == aiohttp.WSMsgType.TEXT: if msg.type == aiohttp.WSMsgType.TEXT:
try: try:
data = json.loads(msg.data) data = json.loads(msg.data)
State.do_not_sync = True State.do_not_sync = True
app.db.execute(*data) app.db.execute(*data)
app.db.commit() app.db.commit()
State.do_not_sync = False State.do_not_sync = False
@ -41,21 +41,25 @@ async def sync_service(app):
await app.services.socket.broadcast_event() await app.services.socket.broadcast_event()
except Exception as e: except Exception as e:
print(e) print(e)
pass pass
#print(f"Received: {msg.data}") # print(f"Received: {msg.data}")
elif msg.type == aiohttp.WSMsgType.ERROR: elif msg.type == aiohttp.WSMsgType.ERROR:
break break
async def write(): async def write():
while True: while True:
msg = await queue.get() msg = await queue.get()
await ws.send_str(json.dumps(msg,default=str)) await ws.send_str(json.dumps(msg, default=str))
queue.task_done() queue.task_done()
await asyncio.gather(receive(), write()) await asyncio.gather(receive(), write())
await session.close() await session.close()
queries_queued = 0 queries_queued = 0
# Attach a listener to log all executed statements # Attach a listener to log all executed statements
@event.listens_for(Engine, "before_cursor_execute") @event.listens_for(Engine, "before_cursor_execute")
def before_cursor_execute(conn, cursor, statement, parameters, context, executemany): def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
@ -63,7 +67,7 @@ def before_cursor_execute(conn, cursor, statement, parameters, context, executem
return return
global queries_queued global queries_queued
if State.do_not_sync: if State.do_not_sync:
print(statement,parameters) print(statement, parameters)
return return
if statement.startswith("SELECT"): if statement.startswith("SELECT"):
return return
@ -71,17 +75,18 @@ def before_cursor_execute(conn, cursor, statement, parameters, context, executem
queries_queued += 1 queries_queued += 1
print("Queries queued: " + str(queries_queued)) print("Queries queued: " + str(queries_queued))
async def websocket_handler(request): async def websocket_handler(request):
queries_broadcasted = 0 queries_broadcasted = 0
ws = web.WebSocketResponse() ws = web.WebSocketResponse()
await ws.prepare(request) await ws.prepare(request)
request.app['websockets'].append(ws) request.app["websockets"].append(ws)
async for msg in ws: async for msg in ws:
if msg.type == aiohttp.WSMsgType.TEXT: if msg.type == aiohttp.WSMsgType.TEXT:
for client in request.app['websockets']: for client in request.app["websockets"]:
if client != ws: if client != ws:
await client.send_str(msg.data) await client.send_str(msg.data)
cursor = request.app['db'].cursor() cursor = request.app["db"].cursor()
data = json.loads(msg.data) data = json.loads(msg.data)
queries_broadcasted += 1 queries_broadcasted += 1
@ -89,28 +94,32 @@ async def websocket_handler(request):
cursor.close() cursor.close()
print("Queries broadcasted: " + str(queries_broadcasted)) print("Queries broadcasted: " + str(queries_broadcasted))
elif msg.type == aiohttp.WSMsgType.ERROR: elif msg.type == aiohttp.WSMsgType.ERROR:
print(f'WebSocket connection closed with exception {ws.exception()}') print(f"WebSocket connection closed with exception {ws.exception()}")
request.app['websockets'].remove(ws) request.app["websockets"].remove(ws)
return ws return ws
app = web.Application()
app['websockets'] = []
app.router.add_get('/ws', websocket_handler) app = web.Application()
app["websockets"] = []
app.router.add_get("/ws", websocket_handler)
async def on_startup(app): async def on_startup(app):
app['db'] = sqlite3.connect('snek.db') app["db"] = sqlite3.connect("snek.db")
print("Server starting...") print("Server starting...")
async def on_cleanup(app): async def on_cleanup(app):
for ws in app['websockets']: for ws in app["websockets"]:
await ws.close() await ws.close()
app['db'].close() app["db"].close()
app.on_startup.append(on_startup) app.on_startup.append(on_startup)
app.on_cleanup.append(on_cleanup) app.on_cleanup.append(on_cleanup)
if __name__ == '__main__': if __name__ == "__main__":
web.run_app(app, host='127.0.0.1', port=3131) web.run_app(app, host="127.0.0.1", port=3131)

View File

@ -1,25 +1,30 @@
import asyncssh
import logging import logging
from pathlib import Path from pathlib import Path
global _app import asyncssh
global _app
def set_app(app): def set_app(app):
global _app global _app
_app = app _app = app
def get_app(): def get_app():
return _app return _app
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
roots = {} roots = {}
class SFTPServer(asyncssh.SFTPServer): class SFTPServer(asyncssh.SFTPServer):
def __init__(self, chan: asyncssh.SSHServerChannel): def __init__(self, chan: asyncssh.SSHServerChannel):
self.root = get_app().services.user.get_home_folder_by_username( self.root = get_app().services.user.get_home_folder_by_username(
chan.get_extra_info('username') chan.get_extra_info("username")
) )
self.root.mkdir(exist_ok=True) self.root.mkdir(exist_ok=True)
self.root = str(self.root) self.root = str(self.root)
@ -30,20 +35,22 @@ class SFTPServer(asyncssh.SFTPServer):
logger.debug(f"Mapping client path {path} to {mapped_path}") logger.debug(f"Mapping client path {path} to {mapped_path}")
return str(mapped_path).encode() return str(mapped_path).encode()
class SSHServer(asyncssh.SSHServer): class SSHServer(asyncssh.SSHServer):
def password_auth_supported(self): def password_auth_supported(self):
return True return True
def validate_password(self, username, password): def validate_password(self, username, password):
logger.debug(f"Validating credentials for user {username}") logger.debug(f"Validating credentials for user {username}")
result = get_app().services.user.authenticate_sync(username,password) result = get_app().services.user.authenticate_sync(username, password)
logger.info(f"Validating credentials for user {username}: {result}") logger.info(f"Validating credentials for user {username}: {result}")
return result return result
async def start_ssh_server(app,host,port):
set_app(app) async def start_ssh_server(app, host, port):
set_app(app)
logger.info("Starting SFTP server setup") logger.info("Starting SFTP server setup")
host_key_path = Path("drive") / ".ssh" / "sftp_server_key" host_key_path = Path("drive") / ".ssh" / "sftp_server_key"
host_key_path.parent.mkdir(exist_ok=True, parents=True) host_key_path.parent.mkdir(exist_ok=True, parents=True)
try: try:
@ -63,14 +70,12 @@ async def start_ssh_server(app,host,port):
x = await asyncssh.listen( x = await asyncssh.listen(
host=host, host=host,
port=port, port=port,
#process_factory=handle_client, # process_factory=handle_client,
server_host_keys=[key], server_host_keys=[key],
server_factory=SSHServer, server_factory=SSHServer,
sftp_factory=SFTPServer sftp_factory=SFTPServer,
) )
return x return x
except Exception as e: except Exception:
logger.warning(f"Failed to start SFTP server. Already running.") logger.warning(f"Failed to start SFTP server. Already running.")
pass pass

View File

@ -1,58 +1,71 @@
import yaml
import copy import copy
import yaml
class ComposeFileManager: class ComposeFileManager:
def __init__(self, compose_path='docker-compose.yml'): def __init__(self, compose_path="docker-compose.yml"):
self.compose_path = compose_path self.compose_path = compose_path
self._load() self._load()
def _load(self): def _load(self):
try: try:
with open(self.compose_path, 'r') as f: with open(self.compose_path) as f:
self.compose = yaml.safe_load(f) or {} self.compose = yaml.safe_load(f) or {}
except FileNotFoundError: except FileNotFoundError:
self.compose = {'version': '3', 'services': {}} self.compose = {"version": "3", "services": {}}
def _save(self): def _save(self):
with open(self.compose_path, 'w') as f: with open(self.compose_path, "w") as f:
yaml.dump(self.compose, f, default_flow_style=False) yaml.dump(self.compose, f, default_flow_style=False)
def list_instances(self): def list_instances(self):
return list(self.compose.get('services', {}).keys()) return list(self.compose.get("services", {}).keys())
def create_instance(self, name, image, command=None, cpus=None, memory=None, ports=None, volumes=None): def create_instance(
self,
name,
image,
command=None,
cpus=None,
memory=None,
ports=None,
volumes=None,
):
service = { service = {
'image': image, "image": image,
} }
if command: if command:
service['command'] = command service["command"] = command
if cpus or memory: if cpus or memory:
service['deploy'] = {'resources': {'limits': {}}} service["deploy"] = {"resources": {"limits": {}}}
if cpus: if cpus:
service['deploy']['resources']['limits']['cpus'] = str(cpus) service["deploy"]["resources"]["limits"]["cpus"] = str(cpus)
if memory: if memory:
service['deploy']['resources']['limits']['memory'] = str(memory) service["deploy"]["resources"]["limits"]["memory"] = str(memory)
if ports: if ports:
service['ports'] = [f"{host}:{container}" for container, host in ports.items()] service["ports"] = [
f"{host}:{container}" for container, host in ports.items()
]
if volumes: if volumes:
service['volumes'] = volumes service["volumes"] = volumes
self.compose.setdefault('services', {})[name] = service self.compose.setdefault("services", {})[name] = service
self._save() self._save()
def remove_instance(self, name): def remove_instance(self, name):
if name in self.compose.get('services', {}): if name in self.compose.get("services", {}):
del self.compose['services'][name] del self.compose["services"][name]
self._save() self._save()
def get_instance(self, name): def get_instance(self, name):
return self.compose.get('services', {}).get(name) return self.compose.get("services", {}).get(name)
def duplicate_instance(self, name, new_name): def duplicate_instance(self, name, new_name):
orig = self.get_instance(name) orig = self.get_instance(name)
if not orig: if not orig:
raise ValueError(f"No such instance: {name}") raise ValueError(f"No such instance: {name}")
self.compose['services'][new_name] = copy.deepcopy(orig) self.compose["services"][new_name] = copy.deepcopy(orig)
self._save() self._save()
def update_instance(self, name, **kwargs): def update_instance(self, name, **kwargs):
@ -62,11 +75,12 @@ class ComposeFileManager:
for k, v in kwargs.items(): for k, v in kwargs.items():
if v is not None: if v is not None:
service[k] = v service[k] = v
self.compose['services'][name] = service self.compose["services"][name] = service
self._save() self._save()
# Storage size is not tracked in compose files; would need Docker API for that. # Storage size is not tracked in compose files; would need Docker API for that.
# Example usage: # Example usage:
# mgr = ComposeFileManager() # mgr = ComposeFileManager()
# mgr.create_instance('web', 'nginx:latest', cpus=1, memory='512m', ports={80:8080}, volumes=['./data:/data']) # mgr.create_instance('web', 'nginx:latest', cpus=1, memory='512m', ports={80:8080}, volumes=['./data:/data'])

View File

@ -1,6 +1,7 @@
DEFAULT_LIMIT = 30 DEFAULT_LIMIT = 30
import asyncio
import typing import typing
import asyncio
from snek.system.model import BaseModel from snek.system.model import BaseModel
@ -12,21 +13,21 @@ class BaseMapper:
def __init__(self, app): def __init__(self, app):
self.app = app self.app = app
self.semaphore = asyncio.Semaphore(1) self.semaphore = asyncio.Semaphore(1)
self.default_limit = self.__class__.default_limit self.default_limit = self.__class__.default_limit
@property @property
def db(self): def db(self):
return self.app.db return self.app.db
@property @property
def loop(self): def loop(self):
return asyncio.get_event_loop() return asyncio.get_event_loop()
async def run_in_executor(self, func, *args, **kwargs): async def run_in_executor(self, func, *args, **kwargs):
async with self.semaphore: async with self.semaphore:
return func(*args, **kwargs) return func(*args, **kwargs)
#return await self.loop.run_in_executor(None, lambda: func(*args, **kwargs)) # return await self.loop.run_in_executor(None, lambda: func(*args, **kwargs))
async def new(self): async def new(self):
return self.model_class(mapper=self, app=self.app) return self.model_class(mapper=self, app=self.app)
@ -38,8 +39,8 @@ class BaseMapper:
async def get(self, uid: str = None, **kwargs) -> BaseModel: async def get(self, uid: str = None, **kwargs) -> BaseModel:
if uid: if uid:
kwargs["uid"] = uid kwargs["uid"] = uid
record = await self.run_in_executor(self.table.find_one,**kwargs) record = await self.run_in_executor(self.table.find_one, **kwargs)
if not record: if not record:
return None return None
record = dict(record) record = dict(record)
@ -50,7 +51,7 @@ class BaseMapper:
return await self.model_class.from_record(mapper=self, record=record) return await self.model_class.from_record(mapper=self, record=record)
async def exists(self, **kwargs): async def exists(self, **kwargs):
return await self.run_in_executor(self.table.exists,**kwargs) return await self.run_in_executor(self.table.exists, **kwargs)
async def count(self, **kwargs) -> int: async def count(self, **kwargs) -> int:
return await self.run_in_executor(self.table.count, **kwargs) return await self.run_in_executor(self.table.count, **kwargs)
@ -71,7 +72,7 @@ class BaseMapper:
yield model yield model
async def query(self, sql, *args): async def query(self, sql, *args):
for record in await self.run_in_executor(self.db.query,sql, *args): for record in await self.run_in_executor(self.db.query, sql, *args):
yield dict(record) yield dict(record)
async def update(self, model): async def update(self, model):

View File

@ -1,5 +1,5 @@
# Original source: https://brandonjay.dev/posts/2021/render-markdown-html-in-python-with-jinja2 # Original source: https://brandonjay.dev/posts/2021/render-markdown-html-in-python-with-jinja2
import re import re
from types import SimpleNamespace from types import SimpleNamespace
from app.cache import time_cache_async from app.cache import time_cache_async
@ -13,19 +13,20 @@ from pygments.lexers import get_lexer_by_name
def strip_markdown(md_text): def strip_markdown(md_text):
# Remove code blocks ( # Remove code blocks (
md_text = re.sub(r'[\s\S]?```', '', md_text) md_text = re.sub(r"[\s\S]?```", "", md_text)
md_text = re.sub(r'^\s{4,}.$', '', md_text, flags=re.MULTILINE) md_text = re.sub(r"^\s{4,}.$", "", md_text, flags=re.MULTILINE)
md_text = re.sub(r'^\s{0,3}#{1,6}\s+', '', md_text, flags=re.MULTILINE) md_text = re.sub(r"^\s{0,3}#{1,6}\s+", "", md_text, flags=re.MULTILINE)
md_text = re.sub(r'!\[.?\]\(.?\)', '', md_text) md_text = re.sub(r"!\[.?\]\(.?\)", "", md_text)
md_text = re.sub(r'\[([^\]]+)\]\(.?\)', r'\1', md_text) md_text = re.sub(r"\[([^\]]+)\]\(.?\)", r"\1", md_text)
md_text = re.sub(r'(\*|_){1,3}(.+?)\1{1,3}', r'\2', md_text) md_text = re.sub(r"(\*|_){1,3}(.+?)\1{1,3}", r"\2", md_text)
md_text = re.sub(r'^\s{0,3}>+\s?', '', md_text, flags=re.MULTILINE) md_text = re.sub(r"^\s{0,3}>+\s?", "", md_text, flags=re.MULTILINE)
md_text = re.sub(r'^(\s)(\-{3,}|_{3,}|\{3,})\s$', '', md_text, flags=re.MULTILINE) md_text = re.sub(r"^(\s)(\-{3,}|_{3,}|\{3,})\s$", "", md_text, flags=re.MULTILINE)
md_text = re.sub(r'[`~>#+\-=]', '', md_text) md_text = re.sub(r"[`~>#+\-=]", "", md_text)
md_text = re.sub(r'\s+', ' ', md_text) md_text = re.sub(r"\s+", " ", md_text)
return md_text.strip() return md_text.strip()
class MarkdownRenderer(HTMLRenderer): class MarkdownRenderer(HTMLRenderer):
_allow_harmful_protocols = False _allow_harmful_protocols = False
@ -40,7 +41,7 @@ class MarkdownRenderer(HTMLRenderer):
formatter = html.HtmlFormatter() formatter = html.HtmlFormatter()
self.env.globals["highlight_styles"] = formatter.get_style_defs() self.env.globals["highlight_styles"] = formatter.get_style_defs()
#def _escape(self, str): # def _escape(self, str):
# return str ##escape(str) # return str ##escape(str)
def get_lexer(self, lang, default="bash"): def get_lexer(self, lang, default="bash"):

View File

@ -6,9 +6,9 @@
# MIT License: This code is distributed under the MIT License. # MIT License: This code is distributed under the MIT License.
from aiohttp import web
import secrets import secrets
from aiohttp import web
csp_policy = ( csp_policy = (
"default-src 'self'; " "default-src 'self'; "
@ -22,14 +22,16 @@ csp_policy = (
def generate_nonce(): def generate_nonce():
return secrets.token_hex(16) return secrets.token_hex(16)
@web.middleware @web.middleware
async def csp_middleware(request, handler): async def csp_middleware(request, handler):
response = await handler(request) response = await handler(request)
return response
nonce = generate_nonce()
response.headers['Content-Security-Policy'] = csp_policy.format(nonce=nonce)
return response return response
nonce = generate_nonce()
response.headers["Content-Security-Policy"] = csp_policy.format(nonce=nonce)
return response
@web.middleware @web.middleware
async def no_cors_middleware(request, handler): async def no_cors_middleware(request, handler):

View File

@ -63,9 +63,11 @@ def hash_sync(data: str, salt: str = DEFAULT_SALT) -> str:
obj = hashlib.sha256(salted) obj = hashlib.sha256(salted)
return obj.hexdigest() return obj.hexdigest()
async def hash(data: str, salt: str = DEFAULT_SALT) -> str: async def hash(data: str, salt: str = DEFAULT_SALT) -> str:
return hash_sync(data, salt) return hash_sync(data, salt)
def verify_sync(string: str, hashed: str) -> bool: def verify_sync(string: str, hashed: str) -> bool:
"""Verify if the given string matches the hashed value. """Verify if the given string matches the hashed value.
@ -78,5 +80,6 @@ def verify_sync(string: str, hashed: str) -> bool:
""" """
return hash_sync(string) == hashed return hash_sync(string) == hashed
async def verify(string: str, hashed: str) -> bool: async def verify(string: str, hashed: str) -> bool:
return verify_sync(string, hashed) return verify_sync(string, hashed)

View File

@ -1,17 +1,16 @@
import mimetypes import mimetypes
import re import re
from functools import lru_cache
from types import SimpleNamespace from types import SimpleNamespace
from urllib.parse import urlparse, parse_qs from urllib.parse import parse_qs, urlparse
from app.cache import time_cache
import bleach
import emoji import emoji
import requests import requests
from app.cache import time_cache
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from jinja2 import TemplateSyntaxError, nodes from jinja2 import TemplateSyntaxError, nodes
from jinja2.ext import Extension from jinja2.ext import Extension
from jinja2.nodes import Const from jinja2.nodes import Const
import bleach
emoji.EMOJI_DATA['<img src="/emoji/snek1.gif" />'] = { emoji.EMOJI_DATA['<img src="/emoji/snek1.gif" />'] = {
"en": ":snek1:", "en": ":snek1:",
@ -81,13 +80,29 @@ emoji.EMOJI_DATA[
ALLOWED_TAGS = list(bleach.sanitizer.ALLOWED_TAGS) + [ ALLOWED_TAGS = list(bleach.sanitizer.ALLOWED_TAGS) + [
"img", "video", "audio", "source", "iframe", "picture", "span" "img",
"video",
"audio",
"source",
"iframe",
"picture",
"span",
] ]
ALLOWED_ATTRIBUTES = { ALLOWED_ATTRIBUTES = {
**bleach.sanitizer.ALLOWED_ATTRIBUTES, **bleach.sanitizer.ALLOWED_ATTRIBUTES,
"img": ["src", "alt", "title", "width", "height"], "img": ["src", "alt", "title", "width", "height"],
"a": ["href", "title", "target", "rel", "referrerpolicy", "class"], "a": ["href", "title", "target", "rel", "referrerpolicy", "class"],
"iframe": ["src", "width", "height", "frameborder", "allow", "allowfullscreen", "title", "referrerpolicy", "style"], "iframe": [
"src",
"width",
"height",
"frameborder",
"allow",
"allowfullscreen",
"title",
"referrerpolicy",
"style",
],
"video": ["src", "controls", "width", "height"], "video": ["src", "controls", "width", "height"],
"audio": ["src", "controls"], "audio": ["src", "controls"],
"source": ["src", "type"], "source": ["src", "type"],
@ -96,9 +111,6 @@ ALLOWED_ATTRIBUTES = {
} }
def sanitize_html(value): def sanitize_html(value):
return bleach.clean( return bleach.clean(
value, value,
@ -109,7 +121,6 @@ def sanitize_html(value):
) )
def set_link_target_blank(text): def set_link_target_blank(text):
soup = BeautifulSoup(text, "html.parser") soup = BeautifulSoup(text, "html.parser")
@ -118,26 +129,43 @@ def set_link_target_blank(text):
element.attrs["rel"] = "noopener noreferrer" element.attrs["rel"] = "noopener noreferrer"
element.attrs["referrerpolicy"] = "no-referrer" element.attrs["referrerpolicy"] = "no-referrer"
element.attrs["href"] = element.attrs["href"].strip(".").strip(",") element.attrs["href"] = element.attrs["href"].strip(".").strip(",")
return str(soup) return str(soup)
SAFE_ATTRIBUTES = { SAFE_ATTRIBUTES = {
'href', 'src', 'alt', 'title', 'width', 'height', 'style', 'id', 'class', "href",
'rel', 'type', 'name', 'value', 'placeholder', 'aria-hidden', 'aria-label', 'srcset' "src",
"alt",
"title",
"width",
"height",
"style",
"id",
"class",
"rel",
"type",
"name",
"value",
"placeholder",
"aria-hidden",
"aria-label",
"srcset",
} }
def whitelist_attributes(html): def whitelist_attributes(html):
soup = BeautifulSoup(html, 'html.parser') soup = BeautifulSoup(html, "html.parser")
for tag in soup.find_all(): for tag in soup.find_all():
if hasattr(tag, 'attrs'): if hasattr(tag, "attrs"):
if tag.name in ['script','form','input']: if tag.name in ["script", "form", "input"]:
tag.replace_with('') tag.replace_with("")
continue continue
attrs = dict(tag.attrs) attrs = dict(tag.attrs)
for attr in list(attrs): for attr in list(attrs):
# Check if attribute is in the safe list or is a data-* attribute # Check if attribute is in the safe list or is a data-* attribute
if not (attr in SAFE_ATTRIBUTES or attr.startswith('data-')): if not (attr in SAFE_ATTRIBUTES or attr.startswith("data-")):
del tag.attrs[attr] del tag.attrs[attr]
return str(soup) return str(soup)
@ -236,12 +264,12 @@ def enrich_image_rendering(text):
for element in soup.find_all("img"): for element in soup.find_all("img"):
if element.attrs["src"].startswith("/"): if element.attrs["src"].startswith("/"):
element.attrs["src"] += "?width=240&height=240" element.attrs["src"] += "?width=240&height=240"
picture_template = f''' picture_template = f"""
<picture> <picture>
<source srcset="{element.attrs["src"]}" type="{mimetypes.guess_type(element.attrs["src"])[0]}" /> <source srcset="{element.attrs["src"]}" type="{mimetypes.guess_type(element.attrs["src"])[0]}" />
<source srcset="{element.attrs["src"]}&format=webp" type="image/webp" /> <source srcset="{element.attrs["src"]}&format=webp" type="image/webp" />
<img src="{element.attrs["src"]}&format=png" title="{element.attrs["src"]}" alt="{element.attrs["src"]}" /> <img src="{element.attrs["src"]}&format=png" title="{element.attrs["src"]}" alt="{element.attrs["src"]}" />
</picture>''' </picture>"""
element.replace_with(BeautifulSoup(picture_template, "html.parser")) element.replace_with(BeautifulSoup(picture_template, "html.parser"))
return str(soup) return str(soup)
@ -285,7 +313,7 @@ def linkify_https(text):
return set_link_target_blank(str(soup)) return set_link_target_blank(str(soup))
@time_cache(timeout=60*60) @time_cache(timeout=60 * 60)
def get_url_content(url): def get_url_content(url):
try: try:
response = requests.get(url, timeout=5) response = requests.get(url, timeout=5)
@ -368,12 +396,11 @@ def embed_url(text):
page_video = get_element_options(None, "video", "video", "video") page_video = get_element_options(None, "video", "video", "video")
page_audio = get_element_options(None, "audio", "audio", "audio") page_audio = get_element_options(None, "audio", "audio", "audio")
preview_size = ( (
get_element_options(None, None, None, "card") get_element_options(None, None, None, "card")
or "summary_large_image" or "summary_large_image"
) )
attachment_base = BeautifulSoup(str(element), "html.parser") attachment_base = BeautifulSoup(str(element), "html.parser")
attachments[original_link_name] = attachment_base attachments[original_link_name] = attachment_base
@ -412,20 +439,28 @@ def embed_url(text):
) )
description_element.append( description_element.append(
BeautifulSoup(f'<strong class="page-name">{page_name}</strong>', "html.parser") BeautifulSoup(
f'<strong class="page-name">{page_name}</strong>',
"html.parser",
)
) )
description_element.append( description_element.append(
BeautifulSoup(f"<p class='page-description'>{page_description or "No description available."}</p>", "html.parser") BeautifulSoup(
f"<p class='page-description'>{page_description or "No description available."}</p>",
"html.parser",
)
) )
description_element.append( description_element.append(
BeautifulSoup(f"<p class='page-original-link'>{original_link_name}</p>", "html.parser") BeautifulSoup(
f"<p class='page-original-link'>{original_link_name}</p>",
"html.parser",
)
) )
render_element.append(description_element_base) render_element.append(description_element_base)
for attachment in attachments.values(): for attachment in attachments.values():
soup.append(attachment) soup.append(attachment)

View File

@ -1,22 +1,24 @@
class WebSocketClient: class WebSocketClient:
def __init__(self, hostname, port): def __init__(self, hostname, port):
self.buffer = b'' self.buffer = b""
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.hostname = hostname self.hostname = hostname
self.port = port self.port = port
self.connect() self.connect()
def __getattr__(self, method, *args, **kwargs): def __getattr__(self, method, *args, **kwargs):
if method in self.__dict__.keys(): if method in self.__dict__.keys():
return self.__dict__[method] return self.__dict__[method]
def call(*args, **kwargs): def call(*args, **kwargs):
self.write(json.dumps({'method': method, 'args': args, 'kwargs': kwargs})) self.write(json.dumps({"method": method, "args": args, "kwargs": kwargs}))
return json.loads(self.read()) return json.loads(self.read())
return call return call
def connect(self): def connect(self):
self.socket.connect((self.hostname, self.port)) self.socket.connect((self.hostname, self.port))
key = base64.b64encode(b'1234123412341234').decode('utf-8') key = base64.b64encode(b"1234123412341234").decode("utf-8")
handshake = ( handshake = (
f"GET /db HTTP/1.1\r\n" f"GET /db HTTP/1.1\r\n"
f"Host: localhost:3131\r\n" f"Host: localhost:3131\r\n"
@ -25,57 +27,58 @@ class WebSocketClient:
f"Sec-WebSocket-Key: {key}\r\n" f"Sec-WebSocket-Key: {key}\r\n"
f"Sec-WebSocket-Version: 13\r\n\r\n" f"Sec-WebSocket-Version: 13\r\n\r\n"
) )
self.socket.sendall(handshake.encode('utf-8')) self.socket.sendall(handshake.encode("utf-8"))
response = self.read_until(b'\r\n\r\n') response = self.read_until(b"\r\n\r\n")
if b'101 Switching Protocols' not in response: if b"101 Switching Protocols" not in response:
raise Exception("Failed to connect to WebSocket") raise Exception("Failed to connect to WebSocket")
def write(self, message): def write(self, message):
message_bytes = message.encode('utf-8') message_bytes = message.encode("utf-8")
length = len(message_bytes) length = len(message_bytes)
if length <= 125: if length <= 125:
self.socket.sendall(b'\x81' + bytes([length]) + message_bytes) self.socket.sendall(b"\x81" + bytes([length]) + message_bytes)
elif length >= 126 and length <= 65535: elif length >= 126 and length <= 65535:
self.socket.sendall(b'\x81' + bytes([126]) + length.to_bytes(2, 'big') + message_bytes) self.socket.sendall(
b"\x81" + bytes([126]) + length.to_bytes(2, "big") + message_bytes
)
else: else:
self.socket.sendall(b'\x81' + bytes([127]) + length.to_bytes(8, 'big') + message_bytes) self.socket.sendall(
b"\x81" + bytes([127]) + length.to_bytes(8, "big") + message_bytes
)
def read_until(self, delimiter): def read_until(self, delimiter):
while True: while True:
find_pos = self.buffer.find(delimiter) find_pos = self.buffer.find(delimiter)
if find_pos != -1: if find_pos != -1:
data = self.buffer[:find_pos+4] data = self.buffer[: find_pos + 4]
self.buffer = self.buffer[find_pos+4:] self.buffer = self.buffer[find_pos + 4 :]
return data return data
chunk = self.socket.recv(1024) chunk = self.socket.recv(1024)
if not chunk: if not chunk:
return None return None
self.buffer += chunk self.buffer += chunk
def read_exactly(self, length): def read_exactly(self, length):
while len(self.buffer) < length: while len(self.buffer) < length:
chunk = self.socket.recv(length - len(self.buffer)) chunk = self.socket.recv(length - len(self.buffer))
if not chunk: if not chunk:
return None return None
self.buffer += chunk self.buffer += chunk
response = self.buffer[: length] response = self.buffer[:length]
self.buffer = self.buffer[length:] self.buffer = self.buffer[length:]
return response return response
def read(self): def read(self):
frame = None frame = None
frame = self.read_exactly(2) frame = self.read_exactly(2)
length = frame[1] & 127 length = frame[1] & 127
if length == 126: if length == 126:
length = int.from_bytes(self.read_exactly(2), 'big') length = int.from_bytes(self.read_exactly(2), "big")
elif length == 127: elif length == 127:
length = int.from_bytes(self.read_exactly(8), 'big') length = int.from_bytes(self.read_exactly(8), "big")
message = self.read_exactly(length) message = self.read_exactly(length)
return message return message
def close(self): def close(self):
self.socket.close() self.socket.close()

View File

@ -31,21 +31,17 @@ from multiavatar import multiavatar
from snek.system.view import BaseView from snek.system.view import BaseView
from snek.view.avatar_animal import generate_avatar_with_options from snek.view.avatar_animal import generate_avatar_with_options
import functools
class AvatarView(BaseView): class AvatarView(BaseView):
login_required = False login_required = False
def __init__(self, *args,**kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args,**kwargs) super().__init__(*args, **kwargs)
self.avatars = {} self.avatars = {}
async def get(self): async def get(self):
type_ = self.request.query.get("type") type_ = self.request.query.get("type")
type_match = { type_match = {"animal": self.get_animal, "default": self.get_default}
"animal": self.get_animal,
"default": self.get_default
}
handler = type_match.get(type_, self.get_default) handler = type_match.get(type_, self.get_default)
return await handler() return await handler()
@ -64,10 +60,9 @@ class AvatarView(BaseView):
avatar = generate_avatar_with_options(self.request.query) avatar = generate_avatar_with_options(self.request.query)
await self.app.set(key, avatar) await self.app.set(key, avatar)
return web.Response(text=avatar, content_type="image/svg+xml") return web.Response(text=avatar, content_type="image/svg+xml")
except Exception as e: except Exception:
pass pass
def _get(self, uid): def _get(self, uid):
avatar = generate_avatar_with_options(self.request.query) avatar = generate_avatar_with_options(self.request.query)
return avatar return avatar

File diff suppressed because it is too large Load Diff

View File

@ -31,13 +31,12 @@ from multiavatar import multiavatar
from snek.system.view import BaseView from snek.system.view import BaseView
from snek.view.avatar_animal import generate_avatar_with_options from snek.view.avatar_animal import generate_avatar_with_options
import functools
class AvatarView(BaseView): class AvatarView(BaseView):
login_required = False login_required = False
def __init__(self, *args,**kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args,**kwargs) super().__init__(*args, **kwargs)
self.avatars = {} self.avatars = {}
async def get(self): async def get(self):
@ -45,10 +44,9 @@ class AvatarView(BaseView):
while True: while True:
try: try:
return web.Response(text=self._get(uid), content_type="image/svg+xml") return web.Response(text=self._get(uid), content_type="image/svg+xml")
except Exception as e: except Exception:
pass pass
def _get(self, uid): def _get(self, uid):
if uid in self.avatars: if uid in self.avatars:
return self.avatars[uid] return self.avatars[uid]

View File

@ -1,15 +1,15 @@
import asyncio import asyncio
import mimetypes import mimetypes
import pathlib
import urllib.parse
from os.path import isfile from os.path import isfile
from PIL import Image
import pillow_heif.HeifImagePlugin
from snek.system.view import BaseView
import aiofiles import aiofiles
from aiohttp import web from aiohttp import web
import pathlib from PIL import Image
import urllib.parse
from snek.system.view import BaseView
class ChannelAttachmentView(BaseView): class ChannelAttachmentView(BaseView):
async def get(self): async def get(self):
@ -73,18 +73,21 @@ class ChannelAttachmentView(BaseView):
# response.write is async but image.save is not # response.write is async but image.save is not
naughty_steal = response.write naughty_steal = response.write
tasks = [] tasks = []
def sync_writer(*args, **kwargs): def sync_writer(*args, **kwargs):
tasks.append(naughty_steal(*args, **kwargs)) tasks.append(naughty_steal(*args, **kwargs))
return True return True
setattr(response, "write", sync_writer) setattr(response, "write", sync_writer)
image.save(response, format=format_, quality=100, optimize=True, save_all=True) image.save(
response, format=format_, quality=100, optimize=True, save_all=True
)
setattr(response, "write", naughty_steal) setattr(response, "write", naughty_steal)
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
return response return response
else: else:
response = web.FileResponse(channel_attachment["path"]) response = web.FileResponse(channel_attachment["path"])
@ -127,7 +130,9 @@ class ChannelAttachmentView(BaseView):
attachment_records = [] attachment_records = []
for attachment in attachments: for attachment in attachments:
attachment_record = attachment.record attachment_record = attachment.record
attachment_record['relative_url'] = urllib.parse.quote(attachment_record['relative_url']) attachment_record["relative_url"] = urllib.parse.quote(
attachment_record["relative_url"]
)
attachment_records.append(attachment_record) attachment_records.append(attachment_record)
return web.json_response( return web.json_response(
@ -138,23 +143,37 @@ class ChannelAttachmentView(BaseView):
} }
) )
class ChannelView(BaseView): class ChannelView(BaseView):
async def get(self): async def get(self):
channel_name = self.request.match_info.get("channel") channel_name = self.request.match_info.get("channel")
if(channel_name is None): if channel_name is None:
return web.HTTPNotFound() return web.HTTPNotFound()
channel = await self.services.channel.get(label="#" + channel_name) channel = await self.services.channel.get(label="#" + channel_name)
if(channel is None): if channel is None:
channel = await self.services.channel.get(label=channel_name) channel = await self.services.channel.get(label=channel_name)
channel = await self.services.channel.get(channel_name) channel = await self.services.channel.get(channel_name)
if(channel is None): if channel is None:
channel = await self.services.channel.get(label=channel_name) channel = await self.services.channel.get(label=channel_name)
if(channel is None): if channel is None:
user = await self.services.user.get(uid=self.session.get("uid")) user = await self.services.user.get(uid=self.session.get("uid"))
is_listed = self.request.query.get("listed", False) == "true" is_listed = self.request.query.get("listed", False) == "true"
is_private = self.request.query.get("private", False) == "true" is_private = self.request.query.get("private", False) == "true"
channel = await self.services.channel.create(label=channel_name,created_by_uid=user['uid'],description="No description provided.",tag="user",is_private=is_private,is_listed=is_listed) channel = await self.services.channel.create(
channel_member = await self.services.channel_member.create(channel_uid=channel['uid'],user_uid=user['uid'],is_moderator=True,is_read_only=False,is_muted=False,is_banned=False) label=channel_name,
created_by_uid=user["uid"],
return web.HTTPFound("/channel/{}.html".format(channel["uid"])) description="No description provided.",
tag="user",
is_private=is_private,
is_listed=is_listed,
)
await self.services.channel_member.create(
channel_uid=channel["uid"],
user_uid=user["uid"],
is_moderator=True,
is_read_only=False,
is_muted=False,
is_banned=False,
)
return web.HTTPFound("/channel/{}.html".format(channel["uid"]))

View File

@ -1,20 +1,15 @@
import mimetypes
import os
import urllib.parse
from datetime import datetime
from pathlib import Path
from urllib.parse import quote, unquote
from aiohttp import web from aiohttp import web
from snek.system.view import BaseView from snek.system.view import BaseView
import os
import mimetypes
from aiohttp import web
from urllib.parse import unquote, quote
from datetime import datetime
from aiohttp import web
from pathlib import Path
import mimetypes, urllib.parse
class DriveView(BaseView): class DriveView(BaseView):
async def get(self): async def get(self):
@ -27,7 +22,7 @@ class DriveView(BaseView):
return web.HTTPNotFound(reason="Path not found") return web.HTTPNotFound(reason="Path not found")
if target.is_dir(): if target.is_dir():
return await self.render_template("drive.html",{"path": rel_path}) return await self.render_template("drive.html", {"path": rel_path})
if target.is_file(): if target.is_file():
return web.FileResponse(target) return web.FileResponse(target)
@ -37,7 +32,7 @@ class DriveApiView(BaseView):
target = await self.services.user.get_home_folder(self.session.get("uid")) target = await self.services.user.get_home_folder(self.session.get("uid"))
rel = self.request.query.get("path", "") rel = self.request.query.get("path", "")
offset = int(self.request.query.get("offset", 0)) offset = int(self.request.query.get("offset", 0))
limit = int(self.request.query.get("limit", 20)) limit = int(self.request.query.get("limit", 20))
if rel: if rel:
target = target.joinpath(rel) target = target.joinpath(rel)
@ -47,58 +42,54 @@ class DriveApiView(BaseView):
if target.is_dir(): if target.is_dir():
entries = [] entries = []
for p in sorted(target.iterdir(), key=lambda p: (p.is_file(), p.name.lower())): for p in sorted(
target.iterdir(), key=lambda p: (p.is_file(), p.name.lower())
):
item_path = (Path(rel) / p.name).as_posix() item_path = (Path(rel) / p.name).as_posix()
mime = mimetypes.guess_type(p.name)[0] if p.is_file() else "inode/directory" mime = (
url = (self.request.url.with_path(f"/drive/{urllib.parse.quote(item_path)}") mimetypes.guess_type(p.name)[0]
if p.is_file() else None) if p.is_file()
entries.append({ else "inode/directory"
"name": p.name, )
"type": "directory" if p.is_dir() else "file", url = (
"mimetype": mime, self.request.url.with_path(
"size": p.stat().st_size if p.is_file() else None, f"/drive/{urllib.parse.quote(item_path)}"
"path": item_path, )
"url": url, if p.is_file()
}) else None
import json )
entries.append(
{
"name": p.name,
"type": "directory" if p.is_dir() else "file",
"mimetype": mime,
"size": p.stat().st_size if p.is_file() else None,
"path": item_path,
"url": url,
}
)
import json
total = len(entries) total = len(entries)
items = entries[offset:offset+limit] items = entries[offset : offset + limit]
return web.json_response({ return web.json_response(
"items": json.loads(json.dumps(items,default=str)), {
"pagination": {"offset": offset, "limit": limit, "total": total} "items": json.loads(json.dumps(items, default=str)),
}) "pagination": {"offset": offset, "limit": limit, "total": total},
}
)
url = self.request.url.with_path(f"/drive/{urllib.parse.quote(rel)}") url = self.request.url.with_path(f"/drive/{urllib.parse.quote(rel)}")
return web.json_response({ return web.json_response(
"name": target.name, {
"type": "file", "name": target.name,
"mimetype": mimetypes.guess_type(target.name)[0], "type": "file",
"size": target.stat().st_size, "mimetype": mimetypes.guess_type(target.name)[0],
"path": rel, "size": target.stat().st_size,
"url": str(url), "path": rel,
}) "url": str(url),
}
)
class DriveView222(BaseView): class DriveView222(BaseView):
@ -124,7 +115,11 @@ class DriveView222(BaseView):
entry_path = os.path.join(dir_path, entry) entry_path = os.path.join(dir_path, entry)
stat = os.stat(entry_path) stat = os.stat(entry_path)
is_dir = os.path.isdir(entry_path) is_dir = os.path.isdir(entry_path)
mimetype = None if is_dir else (mimetypes.guess_type(entry_path)[0] or "application/octet-stream") mimetype = (
None
if is_dir
else (mimetypes.guess_type(entry_path)[0] or "application/octet-stream")
)
size = stat.st_size if not is_dir else None size = stat.st_size if not is_dir else None
created_at = datetime.fromtimestamp(stat.st_ctime).isoformat() created_at = datetime.fromtimestamp(stat.st_ctime).isoformat()
updated_at = datetime.fromtimestamp(stat.st_mtime).isoformat() updated_at = datetime.fromtimestamp(stat.st_mtime).isoformat()
@ -155,15 +150,20 @@ class DriveView222(BaseView):
start = (page - 1) * page_size start = (page - 1) * page_size
end = start + page_size end = start + page_size
paged_entries = entries[start:end] paged_entries = entries[start:end]
details = [await self.entry_details(full_path, entry, rel_path) for entry in paged_entries] details = [
return web.json_response({ await self.entry_details(full_path, entry, rel_path)
"path": rel_path, for entry in paged_entries
"absolute_url": abs_url, ]
"entries": details, return web.json_response(
"total": len(entries), {
"page": page, "path": rel_path,
"page_size": page_size, "absolute_url": abs_url,
}) "entries": details,
"total": len(entries),
"page": page,
"page_size": page_size,
}
)
else: else:
with open(full_path, "rb") as f: with open(full_path, "rb") as f:
content = f.read() content = f.read()
@ -180,14 +180,18 @@ class DriveView222(BaseView):
data = await self.request.post() data = await self.request.post()
if data.get("type") == "dir": if data.get("type") == "dir":
os.makedirs(full_path) os.makedirs(full_path)
return web.json_response({"status": "created", "type": "dir", "absolute_url": abs_url}) return web.json_response(
{"status": "created", "type": "dir", "absolute_url": abs_url}
)
else: else:
file_field = data.get("file") file_field = data.get("file")
if not file_field: if not file_field:
raise web.HTTPBadRequest(reason="No file uploaded") raise web.HTTPBadRequest(reason="No file uploaded")
with open(full_path, "wb") as f: with open(full_path, "wb") as f:
f.write(file_field.file.read()) f.write(file_field.file.read())
return web.json_response({"status": "created", "type": "file", "absolute_url": abs_url}) return web.json_response(
{"status": "created", "type": "file", "absolute_url": abs_url}
)
async def put(self): async def put(self):
rel_path = self.request.match_info.get("rel_path", "") rel_path = self.request.match_info.get("rel_path", "")
@ -210,10 +214,14 @@ class DriveView222(BaseView):
raise web.HTTPNotFound(reason="Path not found") raise web.HTTPNotFound(reason="Path not found")
if os.path.isdir(full_path): if os.path.isdir(full_path):
os.rmdir(full_path) os.rmdir(full_path)
return web.json_response({"status": "deleted", "type": "dir", "absolute_url": abs_url}) return web.json_response(
{"status": "deleted", "type": "dir", "absolute_url": abs_url}
)
else: else:
os.remove(full_path) os.remove(full_path)
return web.json_response({"status": "deleted", "type": "file", "absolute_url": abs_url}) return web.json_response(
{"status": "deleted", "type": "file", "absolute_url": abs_url}
)
class DriveViewi2(BaseView): class DriveViewi2(BaseView):
@ -223,20 +231,17 @@ class DriveViewi2(BaseView):
async def get(self): async def get(self):
drive_uid = self.request.match_info.get("drive") drive_uid = self.request.match_info.get("drive")
before = self.request.query.get("before") before = self.request.query.get("before")
filters = {} filters = {}
if before: if before:
filters["created_at__lt"] = before filters["created_at__lt"] = before
if drive_uid: if drive_uid:
filters['drive_uid'] = drive_uid filters["drive_uid"] = drive_uid
drive = await self.services.drive.get(uid=drive_uid) drive = await self.services.drive.get(uid=drive_uid)
drive_items = [] drive_items = []
async for item in self.services.drive_item.find(**filters): async for item in self.services.drive_item.find(**filters):
record = item.record record = item.record
record["url"] = "/drive.bin/" + record["uid"] + "." + item.extension record["url"] = "/drive.bin/" + record["uid"] + "." + item.extension

View File

@ -82,4 +82,4 @@ class PushView(BaseFormView):
print(post_notification.text) print(post_notification.text)
print(post_notification.headers) print(post_notification.headers)
return await self.json_response({ "registered": True }) return await self.json_response({"registered": True})

View File

@ -1,14 +1,13 @@
import os import asyncio
import mimetypes import mimetypes
import os
import urllib.parse import urllib.parse
from pathlib import Path from pathlib import Path
import humanize import humanize
from aiohttp import web from aiohttp import web
from snek.system.view import BaseView from snek.system.view import BaseView
import asyncio
class BareRepoNavigator: class BareRepoNavigator:
@ -25,18 +24,18 @@ class BareRepoNavigator:
except Exception as e: except Exception as e:
print(f"Error opening repository: {str(e)}") print(f"Error opening repository: {str(e)}")
sys.exit(1) sys.exit(1)
self.repo_path = repo_path self.repo_path = repo_path
self.branches = list(self.repo.branches) self.branches = list(self.repo.branches)
self.current_branch = None self.current_branch = None
self.current_commit = None self.current_commit = None
self.current_path = "" self.current_path = ""
self.history = [] self.history = []
def get_branches(self): def get_branches(self):
"""Return a list of branch names in the repository.""" """Return a list of branch names in the repository."""
return [branch.name for branch in self.branches] return [branch.name for branch in self.branches]
def set_branch(self, branch_name): def set_branch(self, branch_name):
"""Set the current branch.""" """Set the current branch."""
try: try:
@ -47,23 +46,27 @@ class BareRepoNavigator:
return True return True
except IndexError: except IndexError:
return False return False
def get_commits(self, count=10): def get_commits(self, count=10):
"""Get the latest commits on the current branch.""" """Get the latest commits on the current branch."""
if not self.current_branch: if not self.current_branch:
return [] return []
commits = [] commits = []
for commit in self.repo.iter_commits(self.current_branch, max_count=count): for commit in self.repo.iter_commits(self.current_branch, max_count=count):
commits.append({ commits.append(
'hash': commit.hexsha, {
'short_hash': commit.hexsha[:7], "hash": commit.hexsha,
'message': commit.message.strip(), "short_hash": commit.hexsha[:7],
'author': commit.author.name, "message": commit.message.strip(),
'date': datetime.fromtimestamp(commit.committed_date).strftime('%Y-%m-%d %H:%M:%S') "author": commit.author.name,
}) "date": datetime.fromtimestamp(commit.committed_date).strftime(
"%Y-%m-%d %H:%M:%S"
),
}
)
return commits return commits
def set_commit(self, commit_hash): def set_commit(self, commit_hash):
"""Set the current commit by hash.""" """Set the current commit by hash."""
try: try:
@ -73,48 +76,48 @@ class BareRepoNavigator:
return True return True
except ValueError: except ValueError:
return False return False
def list_directory(self, path=""): def list_directory(self, path=""):
"""List the contents of a directory in the current commit.""" """List the contents of a directory in the current commit."""
if not self.current_commit: if not self.current_commit:
return {'dirs': [], 'files': []} return {"dirs": [], "files": []}
dirs = [] dirs = []
files = [] files = []
try: try:
# Get the tree at the current path # Get the tree at the current path
if path: if path:
tree = self.current_commit.tree[path] tree = self.current_commit.tree[path]
if not hasattr(tree, 'trees'): # It's a blob, not a tree if not hasattr(tree, "trees"): # It's a blob, not a tree
return {'dirs': [], 'files': [path]} return {"dirs": [], "files": [path]}
else: else:
tree = self.current_commit.tree tree = self.current_commit.tree
# List directories and files # List directories and files
for item in tree: for item in tree:
if item.type == 'tree': if item.type == "tree":
item_path = os.path.join(path, item.name) if path else item.name item_path = os.path.join(path, item.name) if path else item.name
dirs.append(item_path) dirs.append(item_path)
elif item.type == 'blob': elif item.type == "blob":
item_path = os.path.join(path, item.name) if path else item.name item_path = os.path.join(path, item.name) if path else item.name
files.append(item_path) files.append(item_path)
dirs.sort() dirs.sort()
files.sort() files.sort()
return {'dirs': dirs, 'files': files} return {"dirs": dirs, "files": files}
except KeyError: except KeyError:
return {'dirs': [], 'files': []} return {"dirs": [], "files": []}
def get_file_content(self, file_path): def get_file_content(self, file_path):
"""Get the content of a file in the current commit.""" """Get the content of a file in the current commit."""
if not self.current_commit: if not self.current_commit:
return None return None
try: try:
blob = self.current_commit.tree[file_path] blob = self.current_commit.tree[file_path]
return blob.data_stream.read().decode('utf-8', errors='replace') return blob.data_stream.read().decode("utf-8", errors="replace")
except (KeyError, UnicodeDecodeError): except (KeyError, UnicodeDecodeError):
try: try:
# Try to get as binary if text decoding fails # Try to get as binary if text decoding fails
@ -122,12 +125,12 @@ class BareRepoNavigator:
return blob.data_stream.read() return blob.data_stream.read()
except: except:
return None return None
def navigate_to(self, path): def navigate_to(self, path):
"""Navigate to a specific path, updating the current path.""" """Navigate to a specific path, updating the current path."""
if not self.current_commit: if not self.current_commit:
return False return False
try: try:
if path: if path:
self.current_commit.tree[path] # Check if path exists self.current_commit.tree[path] # Check if path exists
@ -136,7 +139,7 @@ class BareRepoNavigator:
return True return True
except KeyError: except KeyError:
return False return False
def navigate_back(self): def navigate_back(self):
"""Navigate back to the previous path.""" """Navigate back to the previous path."""
if self.history: if self.history:
@ -145,12 +148,13 @@ class BareRepoNavigator:
return False return False
class RepositoryView(BaseView): class RepositoryView(BaseView):
login_required = True login_required = True
def checkout_bare_repo(self, bare_repo_path: Path, target_path: Path, ref: str = 'HEAD'): def checkout_bare_repo(
self, bare_repo_path: Path, target_path: Path, ref: str = "HEAD"
):
repo = Repo(bare_repo_path) repo = Repo(bare_repo_path)
assert repo.bare, "Repository is not bare." assert repo.bare, "Repository is not bare."
@ -159,23 +163,22 @@ class RepositoryView(BaseView):
for blob in tree.traverse(): for blob in tree.traverse():
target_file = target_path / blob.path target_file = target_path / blob.path
target_file.parent.mkdir(parents=True, exist_ok=True) target_file.parent.mkdir(parents=True, exist_ok=True)
print(blob.path) print(blob.path)
with open(target_file, 'wb') as f: with open(target_file, "wb") as f:
f.write(blob.data_stream.read()) f.write(blob.data_stream.read())
async def get(self): async def get(self):
base_repo_path = Path("drive/repositories") base_repo_path = Path("drive/repositories")
authenticated_user_id = self.session.get("uid") authenticated_user_id = self.session.get("uid")
username = self.request.match_info.get('username') username = self.request.match_info.get("username")
repo_name = self.request.match_info.get('repository') repo_name = self.request.match_info.get("repository")
rel_path = self.request.match_info.get('path', '') rel_path = self.request.match_info.get("path", "")
user = None user = None
if not username.count("-") == 4: if not username.count("-") == 4:
user = await self.app.services.user.get(username=username) user = await self.app.services.user.get(username=username)
@ -185,22 +188,24 @@ class RepositoryView(BaseView):
else: else:
user = await self.app.services.user.get(uid=username) user = await self.app.services.user.get(uid=username)
repo = await self.app.services.repository.get(name=repo_name, user_uid=user["uid"]) repo = await self.app.services.repository.get(
name=repo_name, user_uid=user["uid"]
)
if not repo: if not repo:
return web.Response(text="404 Not Found", status=404) return web.Response(text="404 Not Found", status=404)
if repo['is_private'] and authenticated_user_id != repo['uid']: if repo["is_private"] and authenticated_user_id != repo["uid"]:
return web.Response(text="404 Not Found", status=404) return web.Response(text="404 Not Found", status=404)
repo_root_base = (base_repo_path / user['uid'] / (repo_name + ".git")).resolve() repo_root_base = (base_repo_path / user["uid"] / (repo_name + ".git")).resolve()
repo_root = (base_repo_path / user['uid'] / repo_name).resolve() repo_root = (base_repo_path / user["uid"] / repo_name).resolve()
try: try:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
await loop.run_in_executor(None, await loop.run_in_executor(
self.checkout_bare_repo, repo_root_base, repo_root None, self.checkout_bare_repo, repo_root_base, repo_root
) )
except: except:
pass pass
if not repo_root.exists() or not repo_root.is_dir(): if not repo_root.exists() or not repo_root.is_dir():
return web.Response(text="404 Not Found", status=404) return web.Response(text="404 Not Found", status=404)
@ -211,24 +216,35 @@ class RepositoryView(BaseView):
return web.Response(text="404 Not Found", status=404) return web.Response(text="404 Not Found", status=404)
if abs_path.is_dir(): if abs_path.is_dir():
return web.Response(text=self.render_directory(abs_path, username, repo_name, safe_rel_path), content_type='text/html') return web.Response(
text=self.render_directory(
abs_path, username, repo_name, safe_rel_path
),
content_type="text/html",
)
else: else:
return web.Response(text=self.render_file(abs_path), content_type='text/html') return web.Response(
text=self.render_file(abs_path), content_type="text/html"
)
def render_directory(self, abs_path, username, repo_name, safe_rel_path): def render_directory(self, abs_path, username, repo_name, safe_rel_path):
entries = sorted(abs_path.iterdir(), key=lambda p: (not p.is_dir(), p.name.lower())) entries = sorted(
abs_path.iterdir(), key=lambda p: (not p.is_dir(), p.name.lower())
)
items = [] items = []
if safe_rel_path: if safe_rel_path:
parent_path = Path(safe_rel_path).parent parent_path = Path(safe_rel_path).parent
parent_link = f"/repository/{username}/{repo_name}/{parent_path}".rstrip('/') parent_link = f"/repository/{username}/{repo_name}/{parent_path}".rstrip(
"/"
)
items.append(f'<li><a href="{parent_link}">⬅️ ..</a></li>') items.append(f'<li><a href="{parent_link}">⬅️ ..</a></li>')
for entry in entries: for entry in entries:
link_path = urllib.parse.quote(str(Path(safe_rel_path) / entry.name)) link_path = urllib.parse.quote(str(Path(safe_rel_path) / entry.name))
link = f"/repository/{username}/{repo_name}/{link_path}".rstrip('/') link = f"/repository/{username}/{repo_name}/{link_path}".rstrip("/")
display = entry.name + ('/' if entry.is_dir() else '') display = entry.name + ("/" if entry.is_dir() else "")
size = '' if entry.is_dir() else humanize.naturalsize(entry.stat().st_size) size = "" if entry.is_dir() else humanize.naturalsize(entry.stat().st_size)
icon = self.get_icon(entry) icon = self.get_icon(entry)
items.append(f'<li>{icon} <a href="{link}">{display}</a> {size}</li>') items.append(f'<li>{icon} <a href="{link}">{display}</a> {size}</li>')
@ -247,19 +263,24 @@ class RepositoryView(BaseView):
def render_file(self, abs_path): def render_file(self, abs_path):
try: try:
with open(abs_path, 'r', encoding='utf-8', errors='ignore') as f: with open(abs_path, encoding="utf-8", errors="ignore") as f:
content = f.read() content = f.read()
return f"<pre>{content}</pre>" return f"<pre>{content}</pre>"
except Exception as e: except Exception as e:
return f"<h1>Error</h1><pre>{e}</pre>" return f"<h1>Error</h1><pre>{e}</pre>"
def get_icon(self, file): def get_icon(self, file):
if file.is_dir(): return "📁" if file.is_dir():
mime = mimetypes.guess_type(file.name)[0] or '' return "📁"
if mime.startswith("image"): return "🖼️" mime = mimetypes.guess_type(file.name)[0] or ""
if mime.startswith("text"): return "📄" if mime.startswith("image"):
if mime.startswith("audio"): return "🎵" return "🖼️"
if mime.startswith("video"): return "🎬" if mime.startswith("text"):
if file.name.endswith(".py"): return "🐍" return "📄"
if mime.startswith("audio"):
return "🎵"
if mime.startswith("video"):
return "🎬"
if file.name.endswith(".py"):
return "🐍"
return "📦" return "📦"

View File

@ -7,19 +7,20 @@
# MIT License: Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions. # MIT License: Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions.
import json
import traceback
import asyncio import asyncio
import json
import logging
import traceback
from aiohttp import web from aiohttp import web
from snek.system.model import now from snek.system.model import now
from snek.system.profiler import Profiler from snek.system.profiler import Profiler
from snek.system.view import BaseView from snek.system.view import BaseView
import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class RPCView(BaseView): class RPCView(BaseView):
class RPCApi: class RPCApi:
def __init__(self, view, ws): def __init__(self, view, ws):
@ -32,7 +33,7 @@ class RPCView(BaseView):
async def _session_ensure(self): async def _session_ensure(self):
uid = await self.view.session_get("uid") uid = await self.view.session_get("uid")
if not uid in self.user_session: if uid not in self.user_session:
self.user_session[uid] = { self.user_session[uid] = {
"said_hello": False, "said_hello": False,
} }
@ -50,45 +51,54 @@ class RPCView(BaseView):
self._require_login() self._require_login()
return await self.services.db.insert(self.user_uid, table_name, record) return await self.services.db.insert(self.user_uid, table_name, record)
async def db_update(self, table_name, record): async def db_update(self, table_name, record):
self._require_login() self._require_login()
return await self.services.db.update(self.user_uid, table_name, record) return await self.services.db.update(self.user_uid, table_name, record)
async def set_typing(self,channel_uid,color=None):
async def set_typing(self, channel_uid, color=None):
self._require_login() self._require_login()
user = await self.services.user.get(self.user_uid) user = await self.services.user.get(self.user_uid)
if not color: if not color:
color = user["color"] color = user["color"]
return await self.services.socket.broadcast(channel_uid, { return await self.services.socket.broadcast(
"channel_uid": "293ecf12-08c9-494b-b423-48ba1a2d12c2", channel_uid,
"event": "set_typing", {
"data": { "channel_uid": "293ecf12-08c9-494b-b423-48ba1a2d12c2",
"event":"set_typing", "event": "set_typing",
"user_uid": user['uid'], "data": {
"username": user["username"], "event": "set_typing",
"nick": user["nick"], "user_uid": user["uid"],
"channel_uid": channel_uid, "username": user["username"],
"color": color "nick": user["nick"],
} "channel_uid": channel_uid,
}) "color": color,
},
},
)
async def db_delete(self, table_name, record): async def db_delete(self, table_name, record):
self._require_login() self._require_login()
return await self.services.db.delete(self.user_uid, table_name, record) return await self.services.db.delete(self.user_uid, table_name, record)
async def db_get(self, table_name, record): async def db_get(self, table_name, record):
self._require_login() self._require_login()
return await self.services.db.get(self.user_uid, table_name, record) return await self.services.db.get(self.user_uid, table_name, record)
async def db_find(self, table_name, record): async def db_find(self, table_name, record):
self._require_login() self._require_login()
return await self.services.db.find(self.user_uid, table_name, record) return await self.services.db.find(self.user_uid, table_name, record)
async def db_upsert(self, table_name, record,keys):
async def db_upsert(self, table_name, record, keys):
self._require_login() self._require_login()
return await self.services.db.upsert(self.user_uid, table_name, record,keys) return await self.services.db.upsert(
self.user_uid, table_name, record, keys
)
async def db_query(self, table_name, args): async def db_query(self, table_name, args):
self._require_login() self._require_login()
return await self.services.db.query(self.user_uid, table_name, sql, args) return await self.services.db.query(self.user_uid, table_name, sql, args)
@property @property
def user_uid(self): def user_uid(self):
return self.view.session.get("uid") return self.view.session.get("uid")
@ -195,52 +205,60 @@ class RPCView(BaseView):
async def finalize_message(self, message_uid): async def finalize_message(self, message_uid):
self._require_login() self._require_login()
message = await self.services.channel_message.get(message_uid) message = await self.services.channel_message.get(message_uid)
if not message: if not message:
return False return False
if message["user_uid"] != self.user_uid: if message["user_uid"] != self.user_uid:
raise Exception("Not allowed") raise Exception("Not allowed")
if not message["is_final"]:
if not message['is_final']: await self.services.chat.finalize(message["uid"])
await self.services.chat.finalize(message['uid'])
return True return True
async def update_message_text(self,message_uid, text): async def update_message_text(self, message_uid, text):
self._require_login() self._require_login()
message = await self.services.channel_message.get(message_uid) message = await self.services.channel_message.get(message_uid)
if message["user_uid"] != self.user_uid: if message["user_uid"] != self.user_uid:
raise Exception("Not allowed") raise Exception("Not allowed")
if message.get_seconds_since_last_update() > 3: if message.get_seconds_since_last_update() > 3:
await self.finalize_message(message["uid"]) await self.finalize_message(message["uid"])
return {"error": "Message too old","seconds_since_last_update": message.get_seconds_since_last_update(),"success": False} return {
"error": "Message too old",
"seconds_since_last_update": message.get_seconds_since_last_update(),
"success": False,
}
message['message'] = text message["message"] = text
if not text: if not text:
message['deleted_at'] = now() message["deleted_at"] = now()
else: else:
message['deleted_at'] = None message["deleted_at"] = None
await self.services.channel_message.save(message) await self.services.channel_message.save(message)
data = message.record data = message.record
data['text'] = message["message"] data["text"] = message["message"]
data['message_uid'] = message_uid data["message_uid"] = message_uid
await self.services.socket.broadcast(
message["channel_uid"],
{
"channel_uid": message["channel_uid"],
"event": "update_message_text",
"data": message.record,
},
)
await self.services.socket.broadcast(message["channel_uid"], {
"channel_uid": message["channel_uid"],
"event": "update_message_text",
"data": message.record
})
return {"success": True} return {"success": True}
async def send_message(self, channel_uid, message,is_final=True): async def send_message(self, channel_uid, message, is_final=True):
self._require_login() self._require_login()
message = await self.services.chat.send(self.user_uid, channel_uid, message,is_final) message = await self.services.chat.send(
self.user_uid, channel_uid, message, is_final
)
return message["uid"] return message["uid"]
async def echo(self, *args): async def echo(self, *args):
@ -330,7 +348,8 @@ class RPCView(BaseView):
self._require_login() self._require_login()
results = [ results = [
record async for record in self.services.channel.get_recent_users(channel_uid) record
async for record in self.services.channel.get_recent_users(channel_uid)
] ]
results = sorted(results, key=lambda x: x["nick"]) results = sorted(results, key=lambda x: x["nick"])
return results return results
@ -371,7 +390,6 @@ class RPCView(BaseView):
await self.services.socket.send_to_user(self.user_uid, call) await self.services.socket.send_to_user(self.user_uid, call)
self._scheduled.remove(call) self._scheduled.remove(call)
async def ping(self, callId, *args): async def ping(self, callId, *args):
if self.user_uid: if self.user_uid:
user = await self.services.user.get(uid=self.user_uid) user = await self.services.user.get(uid=self.user_uid)
@ -379,17 +397,23 @@ class RPCView(BaseView):
await self.services.user.save(user) await self.services.user.save(user)
return {"pong": args} return {"pong": args}
async def stars_render(self, channel_uid, message): async def stars_render(self, channel_uid, message):
for user in await self.get_online_users(channel_uid): for user in await self.get_online_users(channel_uid):
try: try:
await self.services.socket.send_to_user(user['uid'], dict(event="stars_render", data={"channel_uid": channel_uid, "message":message})) await self.services.socket.send_to_user(
user["uid"],
{
"event": "stars_render",
"data": {"channel_uid": channel_uid, "message": message},
},
)
except Exception as ex: except Exception as ex:
print(ex) print(ex)
async def get(self): async def get(self):
scheduled = [] scheduled = []
async def schedule(uid, seconds, call): async def schedule(uid, seconds, call):
scheduled.append(call) scheduled.append(call)
await asyncio.sleep(seconds) await asyncio.sleep(seconds)
@ -408,16 +432,18 @@ class RPCView(BaseView):
await self.services.socket.subscribe( await self.services.socket.subscribe(
ws, subscription["channel_uid"], self.request.session.get("uid") ws, subscription["channel_uid"], self.request.session.get("uid")
) )
if not scheduled and self.request.app.uptime_seconds < 5: if not scheduled and self.request.app.uptime_seconds < 5:
await schedule(self.request.session.get("uid"),0,{"event":"refresh", "data": { await schedule(
"message": "Finishing deployment"} self.request.session.get("uid"),
} 0,
) {"event": "refresh", "data": {"message": "Finishing deployment"}},
await schedule(self.request.session.get("uid"),15,{"event": "deployed", "data": {
"uptime": self.request.app.uptime}
}
) )
await schedule(
self.request.session.get("uid"),
15,
{"event": "deployed", "data": {"uptime": self.request.app.uptime}},
)
rpc = RPCView.RPCApi(self, ws) rpc = RPCView.RPCApi(self, ws)
async for msg in ws: async for msg in ws:
if msg.type == web.WSMsgType.TEXT: if msg.type == web.WSMsgType.TEXT:

View File

@ -1,45 +1,48 @@
import asyncio
from aiohttp import web from aiohttp import web
from snek.system.view import BaseFormView from snek.system.view import BaseFormView
import pathlib
class ContainersIndexView(BaseFormView): class ContainersIndexView(BaseFormView):
login_required = True login_required = True
async def get(self): async def get(self):
user_uid = self.session.get("uid") user_uid = self.session.get("uid")
containers = [] containers = []
async for container in self.services.container.find(user_uid=user_uid): async for container in self.services.container.find(user_uid=user_uid):
containers.append(container.record) containers.append(container.record)
user = await self.services.user.get(uid=self.session.get("uid")) user = await self.services.user.get(uid=self.session.get("uid"))
return await self.render_template("settings/containers/index.html", {"containers": containers, "user": user}) return await self.render_template(
"settings/containers/index.html", {"containers": containers, "user": user}
)
class ContainersCreateView(BaseFormView): class ContainersCreateView(BaseFormView):
login_required = True login_required = True
async def get(self): async def get(self):
return await self.render_template("settings/containers/create.html") return await self.render_template("settings/containers/create.html")
async def post(self): async def post(self):
data = await self.request.post() data = await self.request.post()
container = await self.services.container.create( await self.services.container.create(
user_uid=self.session.get("uid"), user_uid=self.session.get("uid"),
name=data['name'], name=data["name"],
status=data['status'], status=data["status"],
resources=data.get('resources', ''), resources=data.get("resources", ""),
path=data.get('path', ''), path=data.get("path", ""),
readonly=bool(data.get('readonly', False)) readonly=bool(data.get("readonly", False)),
) )
return web.HTTPFound("/settings/containers/index.html") return web.HTTPFound("/settings/containers/index.html")
class ContainersUpdateView(BaseFormView): class ContainersUpdateView(BaseFormView):
login_required = True login_required = True
@ -51,40 +54,43 @@ class ContainersUpdateView(BaseFormView):
) )
if not container: if not container:
return web.HTTPNotFound() return web.HTTPNotFound()
return await self.render_template("settings/containers/update.html", {"container": container.record}) return await self.render_template(
"settings/containers/update.html", {"container": container.record}
)
async def post(self): async def post(self):
data = await self.request.post() data = await self.request.post()
container = await self.services.container.get( container = await self.services.container.get(
user_uid=self.session.get("uid"), uid=self.request.match_info["uid"] user_uid=self.session.get("uid"), uid=self.request.match_info["uid"]
) )
container['status'] = data['status'] container["status"] = data["status"]
container['resources'] = data.get('resources', '') container["resources"] = data.get("resources", "")
container['path'] = data.get('path', '') container["path"] = data.get("path", "")
container['readonly'] = bool(data.get('readonly', False)) container["readonly"] = bool(data.get("readonly", False))
await self.services.container.save(container) await self.services.container.save(container)
return web.HTTPFound("/settings/containers/index.html") return web.HTTPFound("/settings/containers/index.html")
class ContainersDeleteView(BaseFormView): class ContainersDeleteView(BaseFormView):
login_required = True login_required = True
async def get(self): async def get(self):
container = await self.services.container.get( container = await self.services.container.get(
user_uid=self.session.get("uid"), uid=self.request.match_info["uid"] user_uid=self.session.get("uid"), uid=self.request.match_info["uid"]
) )
if not container: if not container:
return web.HTTPNotFound() return web.HTTPNotFound()
return await self.render_template("settings/containers/delete.html", {"container": container.record}) return await self.render_template(
"settings/containers/delete.html", {"container": container.record}
)
async def post(self): async def post(self):
user_uid = self.session.get("uid") user_uid = self.session.get("uid")
uid = self.request.match_info["uid"] uid = self.request.match_info["uid"]
container = await self.services.container.get( container = await self.services.container.get(user_uid=user_uid, uid=uid)
user_uid=user_uid, uid=uid
)
if not container: if not container:
return web.HTTPNotFound() return web.HTTPNotFound()
await self.services.container.delete(user_uid=user_uid, uid=uid) await self.services.container.delete(user_uid=user_uid, uid=uid)

View File

@ -1,26 +1,26 @@
import asyncio
from aiohttp import web from aiohttp import web
from snek.system.view import BaseFormView from snek.system.view import BaseFormView
import pathlib
class RepositoriesIndexView(BaseFormView): class RepositoriesIndexView(BaseFormView):
login_required = True login_required = True
async def get(self): async def get(self):
user_uid = self.session.get("uid") user_uid = self.session.get("uid")
repositories = [] repositories = []
async for repository in self.services.repository.find(user_uid=user_uid): async for repository in self.services.repository.find(user_uid=user_uid):
repositories.append(repository.record) repositories.append(repository.record)
user = await self.services.user.get(uid=self.session.get("uid")) user = await self.services.user.get(uid=self.session.get("uid"))
return await self.render_template("settings/repositories/index.html", {"repositories": repositories, "user": user}) return await self.render_template(
"settings/repositories/index.html",
{"repositories": repositories, "user": user},
)
class RepositoriesCreateView(BaseFormView): class RepositoriesCreateView(BaseFormView):
@ -28,14 +28,19 @@ class RepositoriesCreateView(BaseFormView):
login_required = True login_required = True
async def get(self): async def get(self):
return await self.render_template("settings/repositories/create.html") return await self.render_template("settings/repositories/create.html")
async def post(self): async def post(self):
data = await self.request.post() data = await self.request.post()
repository = await self.services.repository.create(user_uid=self.session.get("uid"), name=data['name'], is_private=int(data.get('is_private',0))) await self.services.repository.create(
user_uid=self.session.get("uid"),
name=data["name"],
is_private=int(data.get("is_private", 0)),
)
return web.HTTPFound("/settings/repositories/index.html") return web.HTTPFound("/settings/repositories/index.html")
class RepositoriesUpdateView(BaseFormView): class RepositoriesUpdateView(BaseFormView):
login_required = True login_required = True
@ -47,40 +52,41 @@ class RepositoriesUpdateView(BaseFormView):
) )
if not repository: if not repository:
return web.HTTPNotFound() return web.HTTPNotFound()
return await self.render_template("settings/repositories/update.html", {"repository": repository.record}) return await self.render_template(
"settings/repositories/update.html", {"repository": repository.record}
)
async def post(self): async def post(self):
data = await self.request.post() data = await self.request.post()
repository = await self.services.repository.get( repository = await self.services.repository.get(
user_uid=self.session.get("uid"), name=self.request.match_info["name"] user_uid=self.session.get("uid"), name=self.request.match_info["name"]
) )
repository['is_private'] = int(data.get('is_private',0)) repository["is_private"] = int(data.get("is_private", 0))
await self.services.repository.save(repository) await self.services.repository.save(repository)
return web.HTTPFound("/settings/repositories/index.html") return web.HTTPFound("/settings/repositories/index.html")
class RepositoriesDeleteView(BaseFormView): class RepositoriesDeleteView(BaseFormView):
login_required = True login_required = True
async def get(self): async def get(self):
repository = await self.services.repository.get( repository = await self.services.repository.get(
user_uid=self.session.get("uid"), name=self.request.match_info["name"] user_uid=self.session.get("uid"), name=self.request.match_info["name"]
) )
if not repository: if not repository:
return web.HTTPNotFound() return web.HTTPNotFound()
return await self.render_template("settings/repositories/delete.html", {"repository": repository.record}) return await self.render_template(
"settings/repositories/delete.html", {"repository": repository.record}
)
async def post(self): async def post(self):
user_uid = self.session.get("uid") user_uid = self.session.get("uid")
name = self.request.match_info["name"] name = self.request.match_info["name"]
repository = await self.services.repository.get( repository = await self.services.repository.get(user_uid=user_uid, name=name)
user_uid=user_uid, name=name
)
if not repository: if not repository:
return web.HTTPNotFound() return web.HTTPNotFound()
await self.services.repository.delete(user_uid=user_uid, name=name) await self.services.repository.delete(user_uid=user_uid, name=name)
return web.HTTPFound("/settings/repositories/index.html") return web.HTTPFound("/settings/repositories/index.html")