This commit is contained in:
retoor 2025-06-07 05:47:54 +02:00
parent f182c2209e
commit 389d417c1b
45 changed files with 2158 additions and 1235 deletions

View File

@ -1,27 +1,31 @@
import click
import uvloop
from aiohttp import web
import asyncio
from snek.app import Application
from IPython import start_ipython
import sqlite3
import pathlib
import shutil
import shutil
import sqlite3
import click
from aiohttp import web
from IPython import start_ipython
from snek.app import Application
@click.group()
def cli():
pass
@cli.command()
@click.option('--db_path',default="snek.db", help='Database to initialize if not exists.')
@click.option('--source',default=None, help='Database to initialize if not exists.')
def init(db_path,source):
@click.option(
"--db_path", default="snek.db", help="Database to initialize if not exists."
)
@click.option("--source", default=None, help="Database to initialize if not exists.")
def init(db_path, source):
if source and pathlib.Path(source).exists():
print(f"Copying {source} to {db_path}")
shutil.copy2(source,db_path)
shutil.copy2(source, db_path)
print("Database initialized.")
return
if pathlib.Path(db_path).exists():
return
print(f"Initializing database at {db_path}")
@ -33,25 +37,44 @@ def init(db_path,source):
db.close()
print("Database initialized.")
@cli.command()
@click.option('--port', default=8081, show_default=True, help='Port to run the application on')
@click.option('--host', default='0.0.0.0', show_default=True, help='Host to run the application on')
@click.option('--db_path', default='snek.db', show_default=True, help='Database path for the application')
def serve(port, host, db_path):
#init(db_path)
#asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
web.run_app(
Application(db_path=f"sqlite:///{db_path}"), port=port, host=host
)
@cli.command()
@click.option('--db_path', default='snek.db', show_default=True, help='Database path for the application')
@click.option(
"--port", default=8081, show_default=True, help="Port to run the application on"
)
@click.option(
"--host",
default="0.0.0.0",
show_default=True,
help="Host to run the application on",
)
@click.option(
"--db_path",
default="snek.db",
show_default=True,
help="Database path for the application",
)
def serve(port, host, db_path):
# init(db_path)
# asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
web.run_app(Application(db_path=f"sqlite:///{db_path}"), port=port, host=host)
@cli.command()
@click.option(
"--db_path",
default="snek.db",
show_default=True,
help="Database path for the application",
)
def shell(db_path):
app = Application(db_path=f"sqlite:///{db_path}")
start_ipython(argv=[], user_ns={'app': app})
start_ipython(argv=[], user_ns={"app": app})
def main():
cli()
if __name__ == "__main__":
main()

View File

@ -4,13 +4,15 @@ import pathlib
import ssl
import uuid
from datetime import datetime
from snek import snode
from snek.view.threads import ThreadsView
logging.basicConfig(level=logging.DEBUG)
from ipaddress import ip_address
from concurrent.futures import ThreadPoolExecutor
from ipaddress import ip_address
import IP2Location
from aiohttp import web
from aiohttp_session import (
get_session as session_get,
@ -20,54 +22,56 @@ from aiohttp_session import (
from aiohttp_session.cookie_storage import EncryptedCookieStorage
from app.app import Application as BaseApplication
from jinja2 import FileSystemLoader
import IP2Location
from snek.sssh import start_ssh_server
from snek.docs.app import Application as DocsApplication
from snek.mapper import get_mappers
from snek.service import get_services
from snek.sgit import GitApplication
from snek.sssh import start_ssh_server
from snek.system import http
from snek.system.cache import Cache
from snek.system.markdown import MarkdownExtension
from snek.system.middleware import auth_middleware, cors_middleware, csp_middleware
from snek.system.profiler import profiler_handler
from snek.system.template import EmojiExtension, LinkifyExtension, PythonExtension
from snek.system.template import (
EmojiExtension,
LinkifyExtension,
PythonExtension,
sanitize_html,
)
from snek.view.about import AboutHTMLView, AboutMDView
from snek.view.avatar import AvatarView
from snek.view.channel import ChannelAttachmentView, ChannelView
from snek.view.docs import DocsHTMLView, DocsMDView
from snek.view.drive import DriveView
from snek.view.drive import DriveApiView
from snek.view.drive import DriveApiView, DriveView
from snek.view.index import IndexView
from snek.view.login import LoginView
from snek.view.push import PushView
from snek.view.logout import LogoutView
from snek.view.push import PushView
from snek.view.register import RegisterView
from snek.view.rpc import RPCView
from snek.view.repository import RepositoryView
from snek.view.rpc import RPCView
from snek.view.search_user import SearchUserView
from snek.view.settings.repositories import RepositoriesIndexView
from snek.view.settings.repositories import RepositoriesCreateView
from snek.view.settings.repositories import RepositoriesUpdateView
from snek.view.settings.repositories import RepositoriesDeleteView
from snek.view.settings.containers import (
ContainersCreateView,
ContainersDeleteView,
ContainersIndexView,
ContainersUpdateView,
)
from snek.view.settings.index import SettingsIndexView
from snek.view.settings.profile import SettingsProfileView
from snek.view.settings.repositories import (
RepositoriesCreateView,
RepositoriesDeleteView,
RepositoriesIndexView,
RepositoriesUpdateView,
)
from snek.view.stats import StatsView
from snek.view.status import StatusView
from snek.view.terminal import TerminalSocketView, TerminalView
from snek.view.upload import UploadView
from snek.view.user import UserView
from snek.view.web import WebView
from snek.view.channel import ChannelAttachmentView
from snek.view.channel import ChannelView
from snek.view.settings.containers import (
ContainersIndexView,
ContainersCreateView,
ContainersUpdateView,
ContainersDeleteView,
)
from snek.webdav import WebdavApplication
from snek.system.template import sanitize_html
from snek.sgit import GitApplication
SESSION_KEY = b"c79a0c5fda4b424189c427d28c9f7c34"
from snek.system.template import whitelist_attributes
@ -94,7 +98,7 @@ async def ip2location_middleware(request, handler):
if not user:
return response
location = request.app.ip2location.get(ip)
original_city = user["city"]
user["city"]
if user["city"] != location.city:
user["country_long"] = location.country
user["country_short"] = locaion.country_short
@ -121,7 +125,7 @@ class Application(BaseApplication):
cors_middleware,
web.normalize_path_middleware(merge_slashes=True),
ip2location_middleware,
csp_middleware
csp_middleware,
]
self.template_path = pathlib.Path(__file__).parent.joinpath("templates")
self.static_path = pathlib.Path(__file__).parent.joinpath("static")
@ -139,7 +143,7 @@ class Application(BaseApplication):
self.jinja2_env.add_extension(LinkifyExtension)
self.jinja2_env.add_extension(PythonExtension)
self.jinja2_env.add_extension(EmojiExtension)
self.jinja2_env.filters['sanitize'] = sanitize_html
self.jinja2_env.filters["sanitize"] = sanitize_html
self.time_start = datetime.now()
self.ssh_host = "0.0.0.0"
self.ssh_port = 2242
@ -272,8 +276,12 @@ class Application(BaseApplication):
self.router.add_get("/http-photo", self.handle_http_photo)
self.router.add_get("/rpc.ws", RPCView)
self.router.add_get("/c/{channel:.*}", ChannelView)
self.router.add_view("/channel/{channel_uid}/attachment.bin",ChannelAttachmentView)
self.router.add_view("/channel/attachment/{relative_url:.*}",ChannelAttachmentView)
self.router.add_view(
"/channel/{channel_uid}/attachment.bin", ChannelAttachmentView
)
self.router.add_view(
"/channel/attachment/{relative_url:.*}", ChannelAttachmentView
)
self.router.add_view("/channel/{channel}.html", WebView)
self.router.add_view("/threads.html", ThreadsView)
self.router.add_view("/terminal.ws", TerminalSocketView)
@ -284,15 +292,29 @@ class Application(BaseApplication):
self.router.add_view("/stats.json", StatsView)
self.router.add_view("/user/{user}.html", UserView)
self.router.add_view("/repository/{username}/{repository}", RepositoryView)
self.router.add_view("/repository/{username}/{repository}/{path:.*}", RepositoryView)
self.router.add_view(
"/repository/{username}/{repository}/{path:.*}", RepositoryView
)
self.router.add_view("/settings/repositories/index.html", RepositoriesIndexView)
self.router.add_view("/settings/repositories/create.html", RepositoriesCreateView)
self.router.add_view("/settings/repositories/repository/{name}/update.html", RepositoriesUpdateView)
self.router.add_view("/settings/repositories/repository/{name}/delete.html", RepositoriesDeleteView)
self.router.add_view(
"/settings/repositories/create.html", RepositoriesCreateView
)
self.router.add_view(
"/settings/repositories/repository/{name}/update.html",
RepositoriesUpdateView,
)
self.router.add_view(
"/settings/repositories/repository/{name}/delete.html",
RepositoriesDeleteView,
)
self.router.add_view("/settings/containers/index.html", ContainersIndexView)
self.router.add_view("/settings/containers/create.html", ContainersCreateView)
self.router.add_view("/settings/containers/container/{uid}/update.html", ContainersUpdateView)
self.router.add_view("/settings/containers/container/{uid}/delete.html", ContainersDeleteView)
self.router.add_view(
"/settings/containers/container/{uid}/update.html", ContainersUpdateView
)
self.router.add_view(
"/settings/containers/container/{uid}/delete.html", ContainersDeleteView
)
self.webdav = WebdavApplication(self)
self.git = GitApplication(self)
self.add_subapp("/webdav", self.webdav)
@ -302,9 +324,9 @@ class Application(BaseApplication):
async def handle_test(self, request):
return await whitelist_attributes(self.render_template(
"test.html", request, context={"name": "retoor"}
))
return await whitelist_attributes(
self.render_template("test.html", request, context={"name": "retoor"})
)
async def handle_http_get(self, request: web.Request):
url = request.query.get("url")
@ -375,8 +397,8 @@ class Application(BaseApplication):
self.jinja2_env.loader = self.original_loader
#rendered.text = whitelist_attributes(rendered.text)
#rendered.headers['Content-Lenght'] = len(rendered.text)
# rendered.text = whitelist_attributes(rendered.text)
# rendered.headers['Content-Lenght'] = len(rendered.text)
return rendered
async def static_handler(self, request):
@ -423,7 +445,7 @@ app = Application(db_path="sqlite:///snek.db")
async def main():
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_context.load_cert_chain('cert.pem', 'key.pem')
ssl_context.load_cert_chain("cert.pem", "key.pem")
await web._run_app(app, port=8081, host="0.0.0.0", ssl_context=ssl_context)

View File

@ -1,6 +1,7 @@
import asyncio
import sys
class LoadBalancer:
def __init__(self, backend_ports):
self.backend_ports = backend_ports
@ -8,27 +9,29 @@ class LoadBalancer:
self.client_counts = [0] * len(backend_ports)
self.lock = asyncio.Lock()
async def start_backend_servers(self,port,workers):
async def start_backend_servers(self, port, workers):
for x in range(workers):
port += 1
process = await asyncio.create_subprocess_exec(
sys.executable,
sys.argv[0],
'backend',
"backend",
str(port),
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
stderr=asyncio.subprocess.PIPE,
)
port += 1
self.backend_processes.append(process)
print(f"Started backend server on port {(port-1)/port} with PID {process.pid}")
print(
f"Started backend server on port {(port-1)/port} with PID {process.pid}"
)
async def handle_client(self, reader, writer):
async with self.lock:
min_clients = min(self.client_counts)
server_index = self.client_counts.index(min_clients)
self.client_counts[server_index] += 1
backend = ('127.0.0.1', self.backend_ports[server_index])
backend = ("127.0.0.1", self.backend_ports[server_index])
try:
backend_reader, backend_writer = await asyncio.open_connection(*backend)
@ -62,10 +65,10 @@ class LoadBalancer:
for i, count in enumerate(self.client_counts):
print(f"Server {self.backend_ports[i]}: {count} clients")
async def start(self, host='0.0.0.0', port=8081,workers=5):
await self.start_backend_servers(port,workers)
async def start(self, host="0.0.0.0", port=8081, workers=5):
await self.start_backend_servers(port, workers)
server = await asyncio.start_server(self.handle_client, host, port)
monitor_task = asyncio.create_task(self.monitor())
asyncio.create_task(self.monitor())
# Handle shutdown gracefully
try:
@ -80,6 +83,7 @@ class LoadBalancer:
await asyncio.gather(*(p.wait() for p in self.backend_processes))
print("Backend processes terminated.")
async def backend_echo_server(port):
async def handle_echo(reader, writer):
try:
@ -94,10 +98,11 @@ async def backend_echo_server(port):
finally:
writer.close()
server = await asyncio.start_server(handle_echo, '127.0.0.1', port)
server = await asyncio.start_server(handle_echo, "127.0.0.1", port)
print(f"Backend echo server running on port {port}")
await server.serve_forever()
async def main():
backend_ports = [8001, 8003, 8005, 8006]
# Launch backend echo servers
@ -105,19 +110,20 @@ async def main():
lb = LoadBalancer(backend_ports)
await lb.start()
if __name__ == "__main__":
if len(sys.argv) > 1:
if sys.argv[1] == 'backend':
if sys.argv[1] == "backend":
port = int(sys.argv[2])
from snek.app import Application
snek = Application(port=port)
web.run_app(snek, port=port, host='127.0.0.1')
elif sys.argv[1] == 'sync':
from snek.sync import app
web.run_app(snek, port=port, host='127.0.0.1')
web.run_app(snek, port=port, host="127.0.0.1")
elif sys.argv[1] == "sync":
web.run_app(snek, port=port, host="127.0.0.1")
else:
try:
asyncio.run(main())
except KeyboardInterrupt:
print("Shutting down...")

View File

@ -1,22 +1,21 @@
import functools
from snek.mapper.channel import ChannelMapper
from snek.mapper.channel_attachment import ChannelAttachmentMapper
from snek.mapper.channel_member import ChannelMemberMapper
from snek.mapper.channel_message import ChannelMessageMapper
from snek.mapper.container import ContainerMapper
from snek.mapper.drive import DriveMapper
from snek.mapper.drive_item import DriveItemMapper
from snek.mapper.notification import NotificationMapper
from snek.mapper.push import PushMapper
from snek.mapper.repository import RepositoryMapper
from snek.mapper.user import UserMapper
from snek.mapper.user_property import UserPropertyMapper
from snek.mapper.repository import RepositoryMapper
from snek.mapper.push import PushMapper
from snek.mapper.channel_attachment import ChannelAttachmentMapper
from snek.mapper.container import ContainerMapper
from snek.system.object import Object
@functools.cache
def get_mappers(app=None):
return Object(
**{

View File

@ -1,6 +1,7 @@
from snek.model.container import Container
from snek.system.mapper import BaseMapper
class ContainerMapper(BaseMapper):
model_class = Container
table_name = "container"
table_name = "container"

View File

@ -1,19 +1,19 @@
import functools
from snek.model.channel import ChannelModel
from snek.model.channel_attachment import ChannelAttachmentModel
from snek.model.channel_member import ChannelMemberModel
# from snek.model.channel_message import ChannelMessageModel
from snek.model.channel_message import ChannelMessageModel
from snek.model.container import Container
from snek.model.drive import DriveModel
from snek.model.drive_item import DriveItemModel
from snek.model.notification import NotificationModel
from snek.model.push_registration import PushRegistrationModel
from snek.model.repository import RepositoryModel
from snek.model.user import UserModel
from snek.model.user_property import UserPropertyModel
from snek.model.repository import RepositoryModel
from snek.model.channel_attachment import ChannelAttachmentModel
from snek.model.push_registration import PushRegistrationModel
from snek.model.container import Container
from snek.system.object import Object

View File

@ -1,4 +1,3 @@
from snek.system.model import BaseModel
from snek.system.model import BaseModel, ModelField
@ -11,6 +10,6 @@ class ChannelAttachmentModel(BaseModel):
user_uid = ModelField(name="user_uid", required=True, kind=str)
mime_type = ModelField(name="type", required=True, kind=str)
relative_url = ModelField(name="relative_url", required=True, kind=str)
resource_type = ModelField(name="resource_type", required=True, kind=str,value="file")
resource_type = ModelField(
name="resource_type", required=True, kind=str, value="file"
)

View File

@ -1,6 +1,8 @@
from datetime import datetime, timezone
from snek.model.user import UserModel
from snek.system.model import BaseModel, ModelField
from datetime import datetime,timezone
class ChannelMessageModel(BaseModel):
channel_uid = ModelField(name="channel_uid", required=True, kind=str)
@ -10,7 +12,11 @@ class ChannelMessageModel(BaseModel):
is_final = ModelField(name="is_final", required=True, kind=bool, value=True)
def get_seconds_since_last_update(self):
return int((datetime.now(timezone.utc) - datetime.fromisoformat(self["updated_at"])).total_seconds())
return int(
(
datetime.now(timezone.utc) - datetime.fromisoformat(self["updated_at"])
).total_seconds()
)
async def get_user(self) -> UserModel:
return await self.app.services.user.get(uid=self["user_uid"])

View File

@ -1,5 +1,6 @@
from snek.system.model import BaseModel, ModelField
class Container(BaseModel):
id = ModelField(name="id", required=True, kind=str)
name = ModelField(name="name", required=True, kind=str)

View File

@ -9,7 +9,9 @@ class DriveItemModel(BaseModel):
path = ModelField(name="path", required=True, kind=str)
file_type = ModelField(name="file_type", required=True, kind=str)
file_size = ModelField(name="file_size", required=True, kind=int)
is_available = ModelField(name="is_available", required=True, kind=bool, initial_value=True)
is_available = ModelField(
name="is_available", required=True, kind=bool, initial_value=True
)
@property
def extension(self):

View File

@ -1,14 +1,10 @@
from snek.model.user import UserModel
from snek.system.model import BaseModel, ModelField
class RepositoryModel(BaseModel):
user_uid = ModelField(name="user_uid", required=True, kind=str)
name = ModelField(name="name", required=True, kind=str)
is_private = ModelField(name="is_private", required=False, kind=bool)

View File

@ -30,7 +30,7 @@ class UserModel(BaseModel):
last_ping = ModelField(name="last_ping", required=False, kind=str)
is_admin = ModelField(name="is_admin", required=False, kind=bool)
country_short = ModelField(name="country_short", required=False, kind=str)
country_long = ModelField(name="country_long", required=False, kind=str)
city = ModelField(name="city", required=False, kind=str)

View File

@ -1,51 +1,54 @@
import snek.serpentarium
import time
import time
from concurrent.futures import ProcessPoolExecutor
import snek.serpentarium
durations = []
def task1():
global durations
global durations
client = snek.serpentarium.DatasetWrapper()
start=time.time()
start = time.time()
for x in range(1500):
client['a'].delete()
client['a'].insert({"foo": x})
client['a'].find(foo=x)
client['a'].find_one(foo=x)
client['a'].count()
#print(client['a'].find(foo=x) )
#print(client['a'].find_one(foo=x) )
#print(client['a'].count())
client["a"].delete()
client["a"].insert({"foo": x})
client["a"].find(foo=x)
client["a"].find_one(foo=x)
client["a"].count()
# print(client['a'].find(foo=x) )
# print(client['a'].find_one(foo=x) )
# print(client['a'].count())
client.close()
duration1 = f"{time.time()-start}"
durations.append(duration1)
print(durations)
with ProcessPoolExecutor(max_workers=4) as executor:
tasks = [executor.submit(task1),
tasks = [
executor.submit(task1),
executor.submit(task1),
executor.submit(task1),
executor.submit(task1),
executor.submit(task1)
]
for task in tasks:
task.result()
import dataset
import dataset
client = dataset.connect("sqlite:///snek.db")
start=time.time()
start = time.time()
for x in range(1500):
client['a'].delete()
client['a'].insert({"foo": x})
print([dict(row) for row in client['a'].find(foo=x)])
print(dict(client['a'].find_one(foo=x) ))
print(client['a'].count())
client["a"].delete()
client["a"].insert({"foo": x})
print([dict(row) for row in client["a"].find(foo=x)])
print(dict(client["a"].find_one(foo=x)))
print(client["a"].count())
duration2 = f"{time.time()-start}"
print(duration1,duration2)
print(duration1, duration2)

View File

@ -1,22 +1,22 @@
import functools
from snek.service.channel import ChannelService
from snek.service.channel_attachment import ChannelAttachmentService
from snek.service.channel_member import ChannelMemberService
from snek.service.channel_message import ChannelMessageService
from snek.service.chat import ChatService
from snek.service.container import ContainerService
from snek.service.db import DBService
from snek.service.drive import DriveService
from snek.service.drive_item import DriveItemService
from snek.service.notification import NotificationService
from snek.service.push import PushService
from snek.service.repository import RepositoryService
from snek.service.socket import SocketService
from snek.service.user import UserService
from snek.service.push import PushService
from snek.service.user_property import UserPropertyService
from snek.service.util import UtilService
from snek.service.repository import RepositoryService
from snek.service.channel_attachment import ChannelAttachmentService
from snek.service.container import ContainerService
from snek.system.object import Object
from snek.service.db import DBService
@functools.cache

View File

@ -1,9 +1,9 @@
import pathlib
from datetime import datetime
from snek.system.model import now
from snek.system.service import BaseService
import pathlib
class ChannelService(BaseService):
mapper_name = "channel"
@ -17,12 +17,10 @@ class ChannelService(BaseService):
pass
return folder
async def get_attachment_folder(self, channel_uid,ensure=False):
async def get_attachment_folder(self, channel_uid, ensure=False):
path = pathlib.Path(f"./drive/{channel_uid}/attachments")
if ensure:
path.mkdir(
parents=True, exist_ok=True
)
path.mkdir(parents=True, exist_ok=True)
return path
async def get(self, uid=None, **kwargs):
@ -78,7 +76,10 @@ class ChannelService(BaseService):
return channel
async def get_recent_users(self, channel_uid):
async for user in self.query("SELECT user.uid, user.username,user.color,user.last_ping,user.nick FROM channel_member INNER JOIN user ON user.uid = channel_member.user_uid WHERE channel_uid=:channel_uid AND user.last_ping >= datetime('now', '-3 minutes') ORDER BY last_ping DESC LIMIT 30", {"channel_uid": channel_uid}):
async for user in self.query(
"SELECT user.uid, user.username,user.color,user.last_ping,user.nick FROM channel_member INNER JOIN user ON user.uid = channel_member.user_uid WHERE channel_uid=:channel_uid AND user.last_ping >= datetime('now', '-3 minutes') ORDER BY last_ping DESC LIMIT 30",
{"channel_uid": channel_uid},
):
yield user
async def get_users(self, channel_uid):

View File

@ -1,25 +1,25 @@
from snek.system.service import BaseService
import urllib.parse
import pathlib
import mimetypes
import uuid
from snek.system.service import BaseService
class ChannelAttachmentService(BaseService):
mapper_name="channel_attachment"
mapper_name = "channel_attachment"
async def create_file(self, channel_uid, user_uid, name):
attachment = await self.new()
attachment["channel_uid"] = channel_uid
attachment['user_uid'] = user_uid
attachment["user_uid"] = user_uid
attachment["name"] = name
attachment["mime_type"] = mimetypes.guess_type(name)[0]
attachment['resource_type'] = "file"
attachment["resource_type"] = "file"
real_file_name = f"{attachment['uid']}-{name}"
attachment["relative_url"] = (f"{attachment['uid']}-{name}")
attachment_folder = await self.services.channel.get_attachment_folder(channel_uid)
attachment["relative_url"] = f"{attachment['uid']}-{name}"
attachment_folder = await self.services.channel.get_attachment_folder(
channel_uid
)
attachment_path = attachment_folder.joinpath(real_file_name)
attachment["path"] = str(attachment_path)
if await self.save(attachment):
return attachment
raise Exception(f"Failed to create channel attachment: {attachment.errors}.")

View File

@ -11,7 +11,7 @@ class ChannelMessageService(BaseService):
model["channel_uid"] = channel_uid
model["user_uid"] = user_uid
model["message"] = message
model['is_final'] = is_final
model["is_final"] = is_final
context = {}
@ -56,7 +56,7 @@ class ChannelMessageService(BaseService):
async def save(self, model):
context = {}
context.update(model.record)
user = await self.app.services.user.get(model['user_uid'])
user = await self.app.services.user.get(model["user_uid"])
context.update(
{
"user_uid": user["uid"],

View File

@ -13,13 +13,13 @@ class ChatService(BaseService):
channel["last_message_on"] = now()
await self.services.channel.save(channel)
await self.services.socket.broadcast(
channel['uid'],
channel["uid"],
{
"message": channel_message["message"],
"html": channel_message["html"],
"user_uid": user['uid'],
"user_uid": user["uid"],
"color": user["color"],
"channel_uid": channel['uid'],
"channel_uid": channel["uid"],
"created_at": channel_message["created_at"],
"updated_at": channel_message["updated_at"],
"username": user["username"],
@ -32,14 +32,12 @@ class ChatService(BaseService):
self.services.notification.create_channel_message(message_uid)
)
async def send(self, user_uid, channel_uid, message, is_final=True):
channel = await self.services.channel.get(uid=channel_uid)
if not channel:
raise Exception("Channel not found.")
channel_message = await self.services.channel_message.create(
channel_uid, user_uid, message,is_final
channel_uid, user_uid, message, is_final
)
channel_message_uid = channel_message["uid"]

View File

@ -1,36 +1,51 @@
from snek.system.docker import ComposeFileManager
from snek.system.service import BaseService
from snek.system.docker import ComposeFileManager
class ContainerService(BaseService):
mapper_name = "container"
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.compose_path = "snek-container-compose.yml"
self.compose = ComposeFileManager(self.compose_path)
async def get_instances(self):
return list(self.compose.list_instances())
async def get_container_name(self, channel_uid):
return f"channel-{channel_uid}"
async def create(self, channel_uid, image="ubuntu:latest", command=None, cpus=1, memory='1024m', ports=None, volumes=None):
async def create(
self,
channel_uid,
image="ubuntu:latest",
command=None,
cpus=1,
memory="1024m",
ports=None,
volumes=None,
):
name = await self.get_container_name(channel_uid)
self.compose.create_instance(
name,
image,
command,
cpus,
memory,
ports,
["./"+str((await self.services.channel.get_home_folder(channel_uid))) +":"+ "/home/ubuntu"]
name,
image,
command,
cpus,
memory,
ports,
[
"./"
+ str(await self.services.channel.get_home_folder(channel_uid))
+ ":"
+ "/home/ubuntu"
],
)
async def create2(self, id, name, status, resources=None, user_uid=None, path=None, readonly=False):
async def create2(
self, id, name, status, resources=None, user_uid=None, path=None, readonly=False
):
model = await self.new()
model["id"] = id
model["name"] = name

View File

@ -1,23 +1,21 @@
from snek.system.service import BaseService
import dataset
import uuid
import dataset
from snek.system.service import BaseService
from datetime import datetime
class DBService(BaseService):
async def get_db(self, user_uid):
home_folder = await self.app.services.user.get_home_folder(user_uid)
home_folder.mkdir(parents=True, exist_ok=True)
db_path = home_folder.joinpath("snek/user.db")
db_path.parent.mkdir(parents=True, exist_ok=True)
return dataset.connect("sqlite:///" + str(db_path))
return dataset.connect("sqlite:///" + str(db_path))
async def insert(self, user_uid, table_name, values):
db = await self.get_db(user_uid)
return db[table_name].insert(values)
async def update(self, user_uid, table_name, values, filters):
db = await self.get_db(user_uid)
@ -33,7 +31,7 @@ class DBService(BaseService):
async def find(self, user_uid, table_name, kwargs):
db = await self.get_db(user_uid)
kwargs['_limit'] = kwargs.get('_limit', 30)
kwargs["_limit"] = kwargs.get("_limit", 30)
return [dict(row) for row in db[table_name].find(**kwargs)]
async def get(self, user_uid, table_name, filters):
@ -45,14 +43,13 @@ class DBService(BaseService):
except ValueError:
return None
async def delete(self, user_uid, table_name, filters):
db = await self.get_db(user_uid)
if not filters:
filters = {}
return db[table_name].delete(**filters)
async def query(self, sql,values):
async def query(self, sql, values):
db = await self.app.db
return [dict(row) for row in db.query(sql, values or {})]
@ -62,8 +59,6 @@ class DBService(BaseService):
filters = {}
return bool(db[table_name].find_one(**filters))
async def count(self, user_uid, table_name, filters):
db = await self.get_db(user_uid)
if not filters:

View File

@ -1,6 +1,7 @@
from snek.system.markdown import strip_markdown
from snek.system.model import now
from snek.system.service import BaseService
from snek.system.markdown import strip_markdown
class NotificationService(BaseService):
mapper_name = "notification"
@ -34,7 +35,7 @@ class NotificationService(BaseService):
uid=channel_message_uid
)
if not channel_message["is_final"]:
return
return
user = await self.services.user.get(uid=channel_message["user_uid"])
self.app.db.begin()
async for channel_member in self.services.channel_member.find(

View File

@ -1,25 +1,23 @@
import base64
import json
import aiohttp
from snek.system.service import BaseService
import os.path
import random
import time
import base64
import uuid
from functools import cache
from pathlib import Path
import jwt
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend
from urllib.parse import urlparse
import os.path
import aiohttp
import jwt
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives.hashes import SHA256
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from snek.system.service import BaseService
# The only reason to persist the keys is to be able to use them in the web push
PRIVATE_KEY_FILE = Path("./notification-private.pem")
@ -268,4 +266,4 @@ class PushService(BaseService):
raise Exception(
f"Failed to register push subscription for user {user_uid} with endpoint {endpoint}"
)
)

View File

@ -1,13 +1,17 @@
from snek.system.service import BaseService
import asyncio
import asyncio
import shutil
from snek.system.service import BaseService
class RepositoryService(BaseService):
mapper_name = "repository"
async def delete(self, user_uid, name):
loop = asyncio.get_event_loop()
repository_path = (await self.services.user.get_repository_path(user_uid)).joinpath(name)
repository_path = (
await self.services.user.get_repository_path(user_uid)
).joinpath(name)
try:
await loop.run_in_executor(None, shutil.rmtree, repository_path)
except Exception as ex:
@ -15,7 +19,6 @@ class RepositoryService(BaseService):
await super().delete(user_uid=user_uid, name=name)
async def exists(self, user_uid, name, **kwargs):
kwargs["user_uid"] = user_uid
kwargs["name"] = name
@ -29,18 +32,16 @@ class RepositoryService(BaseService):
repository_path = str(repository_path)
if not repository_path.endswith(".git"):
repository_path += ".git"
command = ['git', 'init', '--bare', repository_path]
command = ["git", "init", "--bare", repository_path]
process = await asyncio.subprocess.create_subprocess_exec(
*command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
*command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
stdout, stderr = await process.communicate()
return process.returncode == 0
async def create(self, user_uid, name,is_private=False):
async def create(self, user_uid, name, is_private=False):
if await self.exists(user_uid=user_uid, name=name):
return False
return False
if not await self.init(user_uid=user_uid, name=name):
return False

View File

@ -1,12 +1,14 @@
import asyncio
import logging
from datetime import datetime
from snek.model.user import UserModel
from snek.system.service import BaseService
from datetime import datetime
import json
import asyncio
import logging
logger = logging.getLogger(__name__)
from snek.system.model import now
class SocketService(BaseService):
class Socket:
@ -15,7 +17,6 @@ class SocketService(BaseService):
self.is_connected = True
self.user = user
async def send_json(self, data):
if not self.is_connected:
return False
@ -40,7 +41,6 @@ class SocketService(BaseService):
self.users = {}
self.subscriptions = {}
self.last_update = str(datetime.now())
async def user_availability_service(self):
logger.info("User availability update service started.")
@ -50,14 +50,15 @@ class SocketService(BaseService):
for s in self.sockets:
if not s.is_connected:
continue
if not s.user in users_updated:
if s.user not in users_updated:
s.user["last_ping"] = now()
await self.app.services.user.save(s.user)
users_updated.append(s.user)
logger.info(f"Updated user availability for {len(users_updated)} online users.")
logger.info(
f"Updated user availability for {len(users_updated)} online users."
)
await asyncio.sleep(60)
async def add(self, ws, user_uid):
s = self.Socket(ws, await self.app.services.user.get(uid=user_uid))
self.sockets.add(s)
@ -81,7 +82,6 @@ class SocketService(BaseService):
count += 1
return count
async def broadcast(self, channel_uid, message):
await self._broadcast(channel_uid, message)
@ -102,4 +102,3 @@ class SocketService(BaseService):
await s.close()
logger.info(f"Removed socket for user {s.user['username']}")
self.sockets.remove(s)

View File

@ -32,14 +32,14 @@ class UserService(BaseService):
user["color"] = await self.services.util.random_light_hex_color()
return await super().save(user)
def authenticate_sync(self,username,password):
def authenticate_sync(self, username, password):
user = self.get_by_username_sync(username)
if not user:
return False
if not security.verify_sync(password, user["password"]):
return False
return True
return True
async def authenticate(self, username, password):
success = await self.validate_login(username, password)
@ -61,14 +61,12 @@ class UserService(BaseService):
return None
return path
async def get_template_path(self, user_uid):
path = pathlib.Path(f"./drive/{user_uid}/snek/templates")
if not path.exists():
return None
return path
def get_by_username_sync(self, username):
user = self.mapper.db["user"].find_one(username=username, deleted_at=None)
return dict(user)

View File

@ -1,48 +1,51 @@
import asyncio
import base64
import json
import logging
import os
import aiohttp
import shutil
import tempfile
from aiohttp import web
import shutil
import json
import tempfile
import asyncio
import logging
import base64
import pathlib
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger('git_server')
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("git_server")
class GitApplication(web.Application):
def __init__(self, parent=None):
#import git
#globals()['git'] = git
# import git
# globals()['git'] = git
self.parent = parent
super().__init__(client_max_size=1024*1024*1024*5)
self.add_routes([
web.post('/create/{repo_name}', self.create_repository),
web.delete('/delete/{repo_name}', self.delete_repository),
web.get('/clone/{repo_name}', self.clone_repository),
web.post('/push/{repo_name}', self.push_repository),
web.post('/pull/{repo_name}', self.pull_repository),
web.get('/status/{repo_name}', self.status_repository),
#web.get('/list', self.list_repositories),
web.get('/branches/{repo_name}', self.list_branches),
web.post('/branches/{repo_name}', self.create_branch),
web.get('/log/{repo_name}', self.commit_log),
web.get('/file/{repo_name}/{file_path:.*}', self.file_content),
web.get('/{path:.+}/info/refs', self.git_smart_http),
web.post('/{path:.+}/git-upload-pack', self.git_smart_http),
web.post('/{path:.+}/git-receive-pack', self.git_smart_http),
web.get('/{repo_name}.git/info/refs', self.git_smart_http),
web.post('/{repo_name}.git/git-upload-pack', self.git_smart_http),
web.post('/{repo_name}.git/git-receive-pack', self.git_smart_http),
])
super().__init__(client_max_size=1024 * 1024 * 1024 * 5)
self.add_routes(
[
web.post("/create/{repo_name}", self.create_repository),
web.delete("/delete/{repo_name}", self.delete_repository),
web.get("/clone/{repo_name}", self.clone_repository),
web.post("/push/{repo_name}", self.push_repository),
web.post("/pull/{repo_name}", self.pull_repository),
web.get("/status/{repo_name}", self.status_repository),
# web.get('/list', self.list_repositories),
web.get("/branches/{repo_name}", self.list_branches),
web.post("/branches/{repo_name}", self.create_branch),
web.get("/log/{repo_name}", self.commit_log),
web.get("/file/{repo_name}/{file_path:.*}", self.file_content),
web.get("/{path:.+}/info/refs", self.git_smart_http),
web.post("/{path:.+}/git-upload-pack", self.git_smart_http),
web.post("/{path:.+}/git-receive-pack", self.git_smart_http),
web.get("/{repo_name}.git/info/refs", self.git_smart_http),
web.post("/{repo_name}.git/git-upload-pack", self.git_smart_http),
web.post("/{repo_name}.git/git-receive-pack", self.git_smart_http),
]
)
async def check_basic_auth(self, request):
auth_header = request.headers.get("Authorization", "")
if not auth_header.startswith("Basic "):
return None,None
return None, None
encoded_creds = auth_header.split("Basic ")[1]
decoded_creds = base64.b64decode(encoded_creds).decode()
username, password = decoded_creds.split(":", 1)
@ -50,27 +53,31 @@ class GitApplication(web.Application):
username=username, password=password
)
if not request["user"]:
return None,None
request["repository_path"] = await self.parent.services.user.get_repository_path(
request["user"]["uid"]
return None, None
request["repository_path"] = (
await self.parent.services.user.get_repository_path(request["user"]["uid"])
)
return request["user"]['username'],request["repository_path"]
return request["user"]["username"], request["repository_path"]
@staticmethod
def require_auth(handler):
async def wrapped(self, request, *args, **kwargs):
username, repository_path = await self.check_basic_auth(request)
if not username or not repository_path:
return web.Response(status=401, headers={'WWW-Authenticate': 'Basic'}, text='Authentication required')
request['username'] = username
request['repository_path'] = repository_path
return web.Response(
status=401,
headers={"WWW-Authenticate": "Basic"},
text="Authentication required",
)
request["username"] = username
request["repository_path"] = repository_path
return await handler(self, request, *args, **kwargs)
return wrapped
def repo_path(self, repository_path, repo_name):
return repository_path.joinpath(repo_name + '.git')
return repository_path.joinpath(repo_name + ".git")
def check_repo_exists(self, repository_path, repo_name):
repo_dir = self.repo_path(repository_path, repo_name)
@ -80,10 +87,10 @@ class GitApplication(web.Application):
@require_auth
async def create_repository(self, request):
username = request['username']
repo_name = request.match_info['repo_name']
repository_path = request['repository_path']
if not repo_name or '/' in repo_name or '..' in repo_name:
username = request["username"]
repo_name = request.match_info["repo_name"]
repository_path = request["repository_path"]
if not repo_name or "/" in repo_name or ".." in repo_name:
return web.Response(text="Invalid repository name", status=400)
repo_dir = self.repo_path(repository_path, repo_name)
if os.path.exists(repo_dir):
@ -98,9 +105,9 @@ class GitApplication(web.Application):
@require_auth
async def delete_repository(self, request):
username = request['username']
repo_name = request.match_info['repo_name']
repository_path = request['repository_path']
username = request["username"]
repo_name = request.match_info["repo_name"]
repository_path = request["repository_path"]
error_response = self.check_repo_exists(repository_path, repo_name)
if error_response:
return error_response
@ -115,9 +122,9 @@ class GitApplication(web.Application):
@require_auth
async def clone_repository(self, request):
username = request['username']
repo_name = request.match_info['repo_name']
repository_path = request['repository_path']
request["username"]
repo_name = request.match_info["repo_name"]
repository_path = request["repository_path"]
error_response = self.check_repo_exists(repository_path, repo_name)
if error_response:
return error_response
@ -126,15 +133,15 @@ class GitApplication(web.Application):
response_data = {
"repository": repo_name,
"clone_command": f"git clone {clone_url}",
"clone_url": clone_url
"clone_url": clone_url,
}
return web.json_response(response_data)
@require_auth
async def push_repository(self, request):
username = request['username']
repo_name = request.match_info['repo_name']
repository_path = request['repository_path']
username = request["username"]
repo_name = request.match_info["repo_name"]
repository_path = request["repository_path"]
error_response = self.check_repo_exists(repository_path, repo_name)
if error_response:
return error_response
@ -142,34 +149,40 @@ class GitApplication(web.Application):
data = await request.json()
except json.JSONDecodeError:
return web.Response(text="Invalid JSON data", status=400)
commit_message = data.get('commit_message', 'Update from server')
branch = data.get('branch', 'main')
changes = data.get('changes', [])
commit_message = data.get("commit_message", "Update from server")
branch = data.get("branch", "main")
changes = data.get("changes", [])
if not changes:
return web.Response(text="No changes provided", status=400)
with tempfile.TemporaryDirectory() as temp_dir:
temp_repo = git.Repo.clone_from(self.repo_path(repository_path, repo_name), temp_dir)
temp_repo = git.Repo.clone_from(
self.repo_path(repository_path, repo_name), temp_dir
)
for change in changes:
file_path = os.path.join(temp_dir, change.get('file', ''))
content = change.get('content', '')
file_path = os.path.join(temp_dir, change.get("file", ""))
content = change.get("content", "")
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, 'w') as f:
with open(file_path, "w") as f:
f.write(content)
temp_repo.git.add(A=True)
if not temp_repo.config_reader().has_section('user'):
temp_repo.config_writer().set_value("user", "name", "Git Server").release()
temp_repo.config_writer().set_value("user", "email", "git@server.local").release()
if not temp_repo.config_reader().has_section("user"):
temp_repo.config_writer().set_value(
"user", "name", "Git Server"
).release()
temp_repo.config_writer().set_value(
"user", "email", "git@server.local"
).release()
temp_repo.index.commit(commit_message)
origin = temp_repo.remote('origin')
origin = temp_repo.remote("origin")
origin.push(refspec=f"{branch}:{branch}")
logger.info(f"Pushed to repository: {repo_name} for user {username}")
return web.Response(text=f"Successfully pushed changes to {repo_name}")
@require_auth
async def pull_repository(self, request):
username = request['username']
repo_name = request.match_info['repo_name']
repository_path = request['repository_path']
username = request["username"]
repo_name = request.match_info["repo_name"]
repository_path = request["repository_path"]
error_response = self.check_repo_exists(repository_path, repo_name)
if error_response:
return error_response
@ -177,13 +190,15 @@ class GitApplication(web.Application):
data = await request.json()
except json.JSONDecodeError:
data = {}
remote_url = data.get('remote_url')
branch = data.get('branch', 'main')
remote_url = data.get("remote_url")
branch = data.get("branch", "main")
if not remote_url:
return web.Response(text="Remote URL is required", status=400)
with tempfile.TemporaryDirectory() as temp_dir:
try:
local_repo = git.Repo.clone_from(self.repo_path(repository_path, repo_name), temp_dir)
local_repo = git.Repo.clone_from(
self.repo_path(repository_path, repo_name), temp_dir
)
remote_name = "pull_source"
try:
remote = local_repo.create_remote(remote_name, remote_url)
@ -192,38 +207,46 @@ class GitApplication(web.Application):
remote.set_url(remote_url)
remote.fetch()
local_repo.git.merge(f"{remote_name}/{branch}")
origin = local_repo.remote('origin')
origin = local_repo.remote("origin")
origin.push()
logger.info(f"Pulled to repository {repo_name} from {remote_url} for user {username}")
return web.Response(text=f"Successfully pulled changes from {remote_url} to {repo_name}")
logger.info(
f"Pulled to repository {repo_name} from {remote_url} for user {username}"
)
return web.Response(
text=f"Successfully pulled changes from {remote_url} to {repo_name}"
)
except Exception as e:
logger.error(f"Error pulling to {repo_name}: {str(e)}")
return web.Response(text=f"Error pulling changes: {str(e)}", status=500)
@require_auth
async def status_repository(self, request):
username = request['username']
repo_name = request.match_info['repo_name']
repository_path = request['repository_path']
request["username"]
repo_name = request.match_info["repo_name"]
repository_path = request["repository_path"]
error_response = self.check_repo_exists(repository_path, repo_name)
if error_response:
return error_response
with tempfile.TemporaryDirectory() as temp_dir:
try:
temp_repo = git.Repo.clone_from(self.repo_path(repository_path, repo_name), temp_dir)
temp_repo = git.Repo.clone_from(
self.repo_path(repository_path, repo_name), temp_dir
)
branches = [b.name for b in temp_repo.branches]
active_branch = temp_repo.active_branch.name
commits = []
for commit in list(temp_repo.iter_commits(max_count=5)):
commits.append({
"id": commit.hexsha,
"author": f"{commit.author.name} <{commit.author.email}>",
"date": commit.committed_datetime.isoformat(),
"message": commit.message
})
commits.append(
{
"id": commit.hexsha,
"author": f"{commit.author.name} <{commit.author.email}>",
"date": commit.committed_datetime.isoformat(),
"message": commit.message,
}
)
files = []
for root, dirs, filenames in os.walk(temp_dir):
if '.git' in root:
if ".git" in root:
continue
for filename in filenames:
full_path = os.path.join(root, filename)
@ -234,50 +257,58 @@ class GitApplication(web.Application):
"branches": branches,
"active_branch": active_branch,
"recent_commits": commits,
"files": files
"files": files,
}
return web.json_response(status_info)
except Exception as e:
logger.error(f"Error getting status for {repo_name}: {str(e)}")
return web.Response(text=f"Error getting repository status: {str(e)}", status=500)
return web.Response(
text=f"Error getting repository status: {str(e)}", status=500
)
@require_auth
async def list_repositories(self, request):
username = request['username']
request["username"]
try:
repos = []
user_dir = self.REPO_DIR
if os.path.exists(user_dir):
for item in os.listdir(user_dir):
item_path = os.path.join(user_dir, item)
if os.path.isdir(item_path) and item.endswith('.git'):
if os.path.isdir(item_path) and item.endswith(".git"):
repos.append(item[:-4])
if request.query.get('format') == 'json':
if request.query.get("format") == "json":
return web.json_response({"repositories": repos})
else:
return web.Response(text="\n".join(repos) if repos else "No repositories found")
return web.Response(
text="\n".join(repos) if repos else "No repositories found"
)
except Exception as e:
logger.error(f"Error listing repositories: {str(e)}")
return web.Response(text=f"Error listing repositories: {str(e)}", status=500)
return web.Response(
text=f"Error listing repositories: {str(e)}", status=500
)
@require_auth
async def list_branches(self, request):
username = request['username']
repo_name = request.match_info['repo_name']
repository_path = request['repository_path']
request["username"]
repo_name = request.match_info["repo_name"]
repository_path = request["repository_path"]
error_response = self.check_repo_exists(repository_path, repo_name)
if error_response:
return error_response
with tempfile.TemporaryDirectory() as temp_dir:
temp_repo = git.Repo.clone_from(self.repo_path(repository_path, repo_name), temp_dir)
temp_repo = git.Repo.clone_from(
self.repo_path(repository_path, repo_name), temp_dir
)
branches = [b.name for b in temp_repo.branches]
return web.json_response({"branches": branches})
@require_auth
async def create_branch(self, request):
username = request['username']
repo_name = request.match_info['repo_name']
repository_path = request['repository_path']
username = request["username"]
repo_name = request.match_info["repo_name"]
repository_path = request["repository_path"]
error_response = self.check_repo_exists(repository_path, repo_name)
if error_response:
return error_response
@ -285,145 +316,168 @@ class GitApplication(web.Application):
data = await request.json()
except json.JSONDecodeError:
return web.Response(text="Invalid JSON data", status=400)
branch_name = data.get('branch_name')
start_point = data.get('start_point', 'HEAD')
branch_name = data.get("branch_name")
start_point = data.get("start_point", "HEAD")
if not branch_name:
return web.Response(text="Branch name is required", status=400)
with tempfile.TemporaryDirectory() as temp_dir:
try:
temp_repo = git.Repo.clone_from(self.repo_path(repository_path, repo_name), temp_dir)
temp_repo = git.Repo.clone_from(
self.repo_path(repository_path, repo_name), temp_dir
)
temp_repo.git.branch(branch_name, start_point)
temp_repo.git.push('origin', branch_name)
logger.info(f"Created branch {branch_name} in repository {repo_name} for user {username}")
temp_repo.git.push("origin", branch_name)
logger.info(
f"Created branch {branch_name} in repository {repo_name} for user {username}"
)
return web.Response(text=f"Created branch {branch_name}")
except Exception as e:
logger.error(f"Error creating branch {branch_name} in {repo_name}: {str(e)}")
logger.error(
f"Error creating branch {branch_name} in {repo_name}: {str(e)}"
)
return web.Response(text=f"Error creating branch: {str(e)}", status=500)
@require_auth
async def commit_log(self, request):
username = request['username']
repo_name = request.match_info['repo_name']
repository_path = request['repository_path']
request["username"]
repo_name = request.match_info["repo_name"]
repository_path = request["repository_path"]
error_response = self.check_repo_exists(repository_path, repo_name)
if error_response:
return error_response
try:
limit = int(request.query.get('limit', 10))
branch = request.query.get('branch', 'main')
limit = int(request.query.get("limit", 10))
branch = request.query.get("branch", "main")
except ValueError:
return web.Response(text="Invalid limit parameter", status=400)
with tempfile.TemporaryDirectory() as temp_dir:
try:
temp_repo = git.Repo.clone_from(self.repo_path(repository_path, repo_name), temp_dir)
temp_repo = git.Repo.clone_from(
self.repo_path(repository_path, repo_name), temp_dir
)
commits = []
try:
for commit in list(temp_repo.iter_commits(branch, max_count=limit)):
commits.append({
"id": commit.hexsha,
"short_id": commit.hexsha[:7],
"author": f"{commit.author.name} <{commit.author.email}>",
"date": commit.committed_datetime.isoformat(),
"message": commit.message.strip()
})
commits.append(
{
"id": commit.hexsha,
"short_id": commit.hexsha[:7],
"author": f"{commit.author.name} <{commit.author.email}>",
"date": commit.committed_datetime.isoformat(),
"message": commit.message.strip(),
}
)
except git.GitCommandError as e:
if "unknown revision or path" in str(e):
commits = []
else:
raise
return web.json_response({
"repository": repo_name,
"branch": branch,
"commits": commits
})
return web.json_response(
{"repository": repo_name, "branch": branch, "commits": commits}
)
except Exception as e:
logger.error(f"Error getting commit log for {repo_name}: {str(e)}")
return web.Response(text=f"Error getting commit log: {str(e)}", status=500)
return web.Response(
text=f"Error getting commit log: {str(e)}", status=500
)
@require_auth
async def file_content(self, request):
username = request['username']
repo_name = request.match_info['repo_name']
file_path = request.match_info.get('file_path', '')
branch = request.query.get('branch', 'main')
repository_path = request['repository_path']
request["username"]
repo_name = request.match_info["repo_name"]
file_path = request.match_info.get("file_path", "")
branch = request.query.get("branch", "main")
repository_path = request["repository_path"]
error_response = self.check_repo_exists(repository_path, repo_name)
if error_response:
return error_response
with tempfile.TemporaryDirectory() as temp_dir:
try:
temp_repo = git.Repo.clone_from(self.repo_path(repository_path, repo_name), temp_dir)
temp_repo = git.Repo.clone_from(
self.repo_path(repository_path, repo_name), temp_dir
)
try:
temp_repo.git.checkout(branch)
except git.GitCommandError:
return web.Response(text=f"Branch '{branch}' not found", status=404)
file_full_path = os.path.join(temp_dir, file_path)
if not os.path.exists(file_full_path):
return web.Response(text=f"File '{file_path}' not found", status=404)
return web.Response(
text=f"File '{file_path}' not found", status=404
)
if os.path.isdir(file_full_path):
files = os.listdir(file_full_path)
return web.json_response({
"repository": repo_name,
"path": file_path,
"type": "directory",
"contents": files
})
return web.json_response(
{
"repository": repo_name,
"path": file_path,
"type": "directory",
"contents": files,
}
)
else:
try:
with open(file_full_path, 'r') as f:
with open(file_full_path) as f:
content = f.read()
return web.Response(text=content)
except UnicodeDecodeError:
return web.Response(text=f"Cannot display binary file content for '{file_path}'", status=400)
return web.Response(
text=f"Cannot display binary file content for '{file_path}'",
status=400,
)
except Exception as e:
logger.error(f"Error getting file content from {repo_name}: {str(e)}")
return web.Response(text=f"Error getting file content: {str(e)}", status=500)
return web.Response(
text=f"Error getting file content: {str(e)}", status=500
)
@require_auth
async def git_smart_http(self, request):
username = request['username']
repository_path = request['repository_path']
request["username"]
repository_path = request["repository_path"]
path = request.path
async def get_repository_path():
req_path = path.lstrip('/')
if req_path.endswith('/info/refs'):
repo_name = req_path[:-len('/info/refs')]
elif req_path.endswith('/git-upload-pack'):
repo_name = req_path[:-len('/git-upload-pack')]
elif req_path.endswith('/git-receive-pack'):
repo_name = req_path[:-len('/git-receive-pack')]
req_path = path.lstrip("/")
if req_path.endswith("/info/refs"):
repo_name = req_path[: -len("/info/refs")]
elif req_path.endswith("/git-upload-pack"):
repo_name = req_path[: -len("/git-upload-pack")]
elif req_path.endswith("/git-receive-pack"):
repo_name = req_path[: -len("/git-receive-pack")]
else:
repo_name = req_path
if repo_name.endswith('.git'):
if repo_name.endswith(".git"):
repo_name = repo_name[:-4]
repo_name = repo_name[4:]
repo_dir = repository_path.joinpath(repo_name + ".git")
logger.info(f"Resolved repo path: {repo_dir}")
return repo_dir
return repo_dir
async def handle_info_refs(service):
repo_path = await get_repository_path()
logger.info(f"handle_info_refs: {repo_path}")
if not os.path.exists(repo_path):
return web.Response(text="Repository not found", status=404)
cmd = [service, '--stateless-rpc', '--advertise-refs', str(repo_path)]
cmd = [service, "--stateless-rpc", "--advertise-refs", str(repo_path)]
try:
process = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
*cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
stdout, stderr = await process.communicate()
if process.returncode != 0:
logger.error(f"Git command failed: {stderr.decode()}")
return web.Response(text=f"Git error: {stderr.decode()}", status=500)
return web.Response(
text=f"Git error: {stderr.decode()}", status=500
)
response = web.StreamResponse(
status=200,
reason='OK',
reason="OK",
headers={
'Content-Type': f'application/x-{service}-advertisement',
'Cache-Control': 'no-cache'
}
"Content-Type": f"application/x-{service}-advertisement",
"Cache-Control": "no-cache",
},
)
await response.prepare(request)
packet = f"# service={service}\n"
@ -435,48 +489,58 @@ class GitApplication(web.Application):
except Exception as e:
logger.error(f"Error handling info/refs: {str(e)}")
return web.Response(text=f"Server error: {str(e)}", status=500)
async def handle_service_rpc(service):
repo_path = await get_repository_path()
logger.info(f"handle_service_rpc: {repo_path}")
if not os.path.exists(repo_path):
return web.Response(text="Repository not found", status=404)
if not request.headers.get('Content-Type') == f'application/x-{service}-request':
if (
not request.headers.get("Content-Type")
== f"application/x-{service}-request"
):
return web.Response(text="Invalid Content-Type", status=403)
body = await request.read()
cmd = [service, '--stateless-rpc', str(repo_path)]
cmd = [service, "--stateless-rpc", str(repo_path)]
try:
process = await asyncio.create_subprocess_exec(
*cmd,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await process.communicate(input=body)
if process.returncode != 0:
logger.error(f"Git command failed: {stderr.decode()}")
return web.Response(text=f"Git error: {stderr.decode()}", status=500)
return web.Response(
text=f"Git error: {stderr.decode()}", status=500
)
return web.Response(
body=stdout,
content_type=f'application/x-{service}-result'
body=stdout, content_type=f"application/x-{service}-result"
)
except Exception as e:
logger.error(f"Error handling service RPC: {str(e)}")
return web.Response(text=f"Server error: {str(e)}", status=500)
if request.method == 'GET' and path.endswith('/info/refs'):
service = request.query.get('service')
if service in ('git-upload-pack', 'git-receive-pack'):
if request.method == "GET" and path.endswith("/info/refs"):
service = request.query.get("service")
if service in ("git-upload-pack", "git-receive-pack"):
return await handle_info_refs(service)
else:
return web.Response(text="Smart HTTP requires service parameter", status=400)
elif request.method == 'POST' and '/git-upload-pack' in path:
return await handle_service_rpc('git-upload-pack')
elif request.method == 'POST' and '/git-receive-pack' in path:
return await handle_service_rpc('git-receive-pack')
return web.Response(
text="Smart HTTP requires service parameter", status=400
)
elif request.method == "POST" and "/git-upload-pack" in path:
return await handle_service_rpc("git-upload-pack")
elif request.method == "POST" and "/git-receive-pack" in path:
return await handle_service_rpc("git-receive-pack")
return web.Response(text="Not found", status=404)
if __name__ == '__main__':
if __name__ == "__main__":
try:
import uvloop
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
logger.info("Using uvloop for improved performance")
except ImportError:

View File

@ -2,28 +2,28 @@ import aiohttp
ENABLED = False
import aiohttp
import asyncio
from aiohttp import web
import json
import sqlite3
import dataset
import aiohttp
from aiohttp import web
from sqlalchemy import event
from sqlalchemy.engine import Engine
import json
queue = asyncio.Queue()
class State:
do_not_sync = False
do_not_sync = False
async def sync_service(app):
if not ENABLED:
return
return
session = aiohttp.ClientSession()
async with session.ws_connect('http://localhost:3131/ws') as ws:
async with session.ws_connect("http://localhost:3131/ws") as ws:
async def receive():
queries_synced = 0
@ -31,7 +31,7 @@ async def sync_service(app):
if msg.type == aiohttp.WSMsgType.TEXT:
try:
data = json.loads(msg.data)
State.do_not_sync = True
State.do_not_sync = True
app.db.execute(*data)
app.db.commit()
State.do_not_sync = False
@ -41,21 +41,25 @@ async def sync_service(app):
await app.services.socket.broadcast_event()
except Exception as e:
print(e)
pass
#print(f"Received: {msg.data}")
pass
# print(f"Received: {msg.data}")
elif msg.type == aiohttp.WSMsgType.ERROR:
break
async def write():
while True:
msg = await queue.get()
await ws.send_str(json.dumps(msg,default=str))
await ws.send_str(json.dumps(msg, default=str))
queue.task_done()
await asyncio.gather(receive(), write())
await session.close()
queries_queued = 0
# Attach a listener to log all executed statements
@event.listens_for(Engine, "before_cursor_execute")
def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
@ -63,7 +67,7 @@ def before_cursor_execute(conn, cursor, statement, parameters, context, executem
return
global queries_queued
if State.do_not_sync:
print(statement,parameters)
print(statement, parameters)
return
if statement.startswith("SELECT"):
return
@ -71,17 +75,18 @@ def before_cursor_execute(conn, cursor, statement, parameters, context, executem
queries_queued += 1
print("Queries queued: " + str(queries_queued))
async def websocket_handler(request):
queries_broadcasted = 0
queries_broadcasted = 0
ws = web.WebSocketResponse()
await ws.prepare(request)
request.app['websockets'].append(ws)
request.app["websockets"].append(ws)
async for msg in ws:
if msg.type == aiohttp.WSMsgType.TEXT:
for client in request.app['websockets']:
for client in request.app["websockets"]:
if client != ws:
await client.send_str(msg.data)
cursor = request.app['db'].cursor()
cursor = request.app["db"].cursor()
data = json.loads(msg.data)
queries_broadcasted += 1
@ -89,28 +94,32 @@ async def websocket_handler(request):
cursor.close()
print("Queries broadcasted: " + str(queries_broadcasted))
elif msg.type == aiohttp.WSMsgType.ERROR:
print(f'WebSocket connection closed with exception {ws.exception()}')
print(f"WebSocket connection closed with exception {ws.exception()}")
request.app['websockets'].remove(ws)
request.app["websockets"].remove(ws)
return ws
app = web.Application()
app['websockets'] = []
app.router.add_get('/ws', websocket_handler)
app = web.Application()
app["websockets"] = []
app.router.add_get("/ws", websocket_handler)
async def on_startup(app):
app['db'] = sqlite3.connect('snek.db')
app["db"] = sqlite3.connect("snek.db")
print("Server starting...")
async def on_cleanup(app):
for ws in app['websockets']:
for ws in app["websockets"]:
await ws.close()
app['db'].close()
app["db"].close()
app.on_startup.append(on_startup)
app.on_cleanup.append(on_cleanup)
if __name__ == '__main__':
web.run_app(app, host='127.0.0.1', port=3131)
if __name__ == "__main__":
web.run_app(app, host="127.0.0.1", port=3131)

View File

@ -1,25 +1,30 @@
import asyncssh
import logging
from pathlib import Path
global _app
import asyncssh
global _app
def set_app(app):
global _app
_app = app
_app = app
def get_app():
return _app
logger = logging.getLogger(__name__)
roots = {}
class SFTPServer(asyncssh.SFTPServer):
def __init__(self, chan: asyncssh.SSHServerChannel):
self.root = get_app().services.user.get_home_folder_by_username(
chan.get_extra_info('username')
chan.get_extra_info("username")
)
self.root.mkdir(exist_ok=True)
self.root = str(self.root)
@ -30,20 +35,22 @@ class SFTPServer(asyncssh.SFTPServer):
logger.debug(f"Mapping client path {path} to {mapped_path}")
return str(mapped_path).encode()
class SSHServer(asyncssh.SSHServer):
def password_auth_supported(self):
return True
def validate_password(self, username, password):
logger.debug(f"Validating credentials for user {username}")
result = get_app().services.user.authenticate_sync(username,password)
result = get_app().services.user.authenticate_sync(username, password)
logger.info(f"Validating credentials for user {username}: {result}")
return result
async def start_ssh_server(app,host,port):
set_app(app)
async def start_ssh_server(app, host, port):
set_app(app)
logger.info("Starting SFTP server setup")
host_key_path = Path("drive") / ".ssh" / "sftp_server_key"
host_key_path.parent.mkdir(exist_ok=True, parents=True)
try:
@ -63,14 +70,12 @@ async def start_ssh_server(app,host,port):
x = await asyncssh.listen(
host=host,
port=port,
#process_factory=handle_client,
# process_factory=handle_client,
server_host_keys=[key],
server_factory=SSHServer,
sftp_factory=SFTPServer
sftp_factory=SFTPServer,
)
return x
except Exception as e:
except Exception:
logger.warning(f"Failed to start SFTP server. Already running.")
pass
pass

View File

@ -1,58 +1,71 @@
import yaml
import copy
import yaml
class ComposeFileManager:
def __init__(self, compose_path='docker-compose.yml'):
def __init__(self, compose_path="docker-compose.yml"):
self.compose_path = compose_path
self._load()
def _load(self):
try:
with open(self.compose_path, 'r') as f:
with open(self.compose_path) as f:
self.compose = yaml.safe_load(f) or {}
except FileNotFoundError:
self.compose = {'version': '3', 'services': {}}
self.compose = {"version": "3", "services": {}}
def _save(self):
with open(self.compose_path, 'w') as f:
with open(self.compose_path, "w") as f:
yaml.dump(self.compose, f, default_flow_style=False)
def list_instances(self):
return list(self.compose.get('services', {}).keys())
return list(self.compose.get("services", {}).keys())
def create_instance(self, name, image, command=None, cpus=None, memory=None, ports=None, volumes=None):
def create_instance(
self,
name,
image,
command=None,
cpus=None,
memory=None,
ports=None,
volumes=None,
):
service = {
'image': image,
"image": image,
}
if command:
service['command'] = command
service["command"] = command
if cpus or memory:
service['deploy'] = {'resources': {'limits': {}}}
service["deploy"] = {"resources": {"limits": {}}}
if cpus:
service['deploy']['resources']['limits']['cpus'] = str(cpus)
service["deploy"]["resources"]["limits"]["cpus"] = str(cpus)
if memory:
service['deploy']['resources']['limits']['memory'] = str(memory)
service["deploy"]["resources"]["limits"]["memory"] = str(memory)
if ports:
service['ports'] = [f"{host}:{container}" for container, host in ports.items()]
service["ports"] = [
f"{host}:{container}" for container, host in ports.items()
]
if volumes:
service['volumes'] = volumes
service["volumes"] = volumes
self.compose.setdefault('services', {})[name] = service
self.compose.setdefault("services", {})[name] = service
self._save()
def remove_instance(self, name):
if name in self.compose.get('services', {}):
del self.compose['services'][name]
if name in self.compose.get("services", {}):
del self.compose["services"][name]
self._save()
def get_instance(self, name):
return self.compose.get('services', {}).get(name)
return self.compose.get("services", {}).get(name)
def duplicate_instance(self, name, new_name):
orig = self.get_instance(name)
if not orig:
raise ValueError(f"No such instance: {name}")
self.compose['services'][new_name] = copy.deepcopy(orig)
self.compose["services"][new_name] = copy.deepcopy(orig)
self._save()
def update_instance(self, name, **kwargs):
@ -62,11 +75,12 @@ class ComposeFileManager:
for k, v in kwargs.items():
if v is not None:
service[k] = v
self.compose['services'][name] = service
self.compose["services"][name] = service
self._save()
# Storage size is not tracked in compose files; would need Docker API for that.
# Example usage:
# mgr = ComposeFileManager()
# mgr.create_instance('web', 'nginx:latest', cpus=1, memory='512m', ports={80:8080}, volumes=['./data:/data'])

View File

@ -1,6 +1,7 @@
DEFAULT_LIMIT = 30
import asyncio
import typing
import asyncio
from snek.system.model import BaseModel
@ -12,21 +13,21 @@ class BaseMapper:
def __init__(self, app):
self.app = app
self.semaphore = asyncio.Semaphore(1)
self.semaphore = asyncio.Semaphore(1)
self.default_limit = self.__class__.default_limit
@property
def db(self):
return self.app.db
@property
@property
def loop(self):
return asyncio.get_event_loop()
async def run_in_executor(self, func, *args, **kwargs):
async with self.semaphore:
return func(*args, **kwargs)
#return await self.loop.run_in_executor(None, lambda: func(*args, **kwargs))
# return await self.loop.run_in_executor(None, lambda: func(*args, **kwargs))
async def new(self):
return self.model_class(mapper=self, app=self.app)
@ -38,8 +39,8 @@ class BaseMapper:
async def get(self, uid: str = None, **kwargs) -> BaseModel:
if uid:
kwargs["uid"] = uid
record = await self.run_in_executor(self.table.find_one,**kwargs)
record = await self.run_in_executor(self.table.find_one, **kwargs)
if not record:
return None
record = dict(record)
@ -50,7 +51,7 @@ class BaseMapper:
return await self.model_class.from_record(mapper=self, record=record)
async def exists(self, **kwargs):
return await self.run_in_executor(self.table.exists,**kwargs)
return await self.run_in_executor(self.table.exists, **kwargs)
async def count(self, **kwargs) -> int:
return await self.run_in_executor(self.table.count, **kwargs)
@ -71,7 +72,7 @@ class BaseMapper:
yield model
async def query(self, sql, *args):
for record in await self.run_in_executor(self.db.query,sql, *args):
for record in await self.run_in_executor(self.db.query, sql, *args):
yield dict(record)
async def update(self, model):

View File

@ -1,5 +1,5 @@
# Original source: https://brandonjay.dev/posts/2021/render-markdown-html-in-python-with-jinja2
import re
import re
from types import SimpleNamespace
from app.cache import time_cache_async
@ -13,19 +13,20 @@ from pygments.lexers import get_lexer_by_name
def strip_markdown(md_text):
# Remove code blocks (
md_text = re.sub(r'[\s\S]?```', '', md_text)
md_text = re.sub(r'^\s{4,}.$', '', md_text, flags=re.MULTILINE)
md_text = re.sub(r'^\s{0,3}#{1,6}\s+', '', md_text, flags=re.MULTILINE)
md_text = re.sub(r'!\[.?\]\(.?\)', '', md_text)
md_text = re.sub(r'\[([^\]]+)\]\(.?\)', r'\1', md_text)
md_text = re.sub(r'(\*|_){1,3}(.+?)\1{1,3}', r'\2', md_text)
md_text = re.sub(r'^\s{0,3}>+\s?', '', md_text, flags=re.MULTILINE)
md_text = re.sub(r'^(\s)(\-{3,}|_{3,}|\{3,})\s$', '', md_text, flags=re.MULTILINE)
md_text = re.sub(r'[`~>#+\-=]', '', md_text)
md_text = re.sub(r'\s+', ' ', md_text)
# Remove code blocks (
md_text = re.sub(r"[\s\S]?```", "", md_text)
md_text = re.sub(r"^\s{4,}.$", "", md_text, flags=re.MULTILINE)
md_text = re.sub(r"^\s{0,3}#{1,6}\s+", "", md_text, flags=re.MULTILINE)
md_text = re.sub(r"!\[.?\]\(.?\)", "", md_text)
md_text = re.sub(r"\[([^\]]+)\]\(.?\)", r"\1", md_text)
md_text = re.sub(r"(\*|_){1,3}(.+?)\1{1,3}", r"\2", md_text)
md_text = re.sub(r"^\s{0,3}>+\s?", "", md_text, flags=re.MULTILINE)
md_text = re.sub(r"^(\s)(\-{3,}|_{3,}|\{3,})\s$", "", md_text, flags=re.MULTILINE)
md_text = re.sub(r"[`~>#+\-=]", "", md_text)
md_text = re.sub(r"\s+", " ", md_text)
return md_text.strip()
class MarkdownRenderer(HTMLRenderer):
_allow_harmful_protocols = False
@ -40,7 +41,7 @@ class MarkdownRenderer(HTMLRenderer):
formatter = html.HtmlFormatter()
self.env.globals["highlight_styles"] = formatter.get_style_defs()
#def _escape(self, str):
# def _escape(self, str):
# return str ##escape(str)
def get_lexer(self, lang, default="bash"):

View File

@ -6,9 +6,9 @@
# MIT License: This code is distributed under the MIT License.
from aiohttp import web
import secrets
from aiohttp import web
csp_policy = (
"default-src 'self'; "
@ -22,14 +22,16 @@ csp_policy = (
def generate_nonce():
return secrets.token_hex(16)
@web.middleware
async def csp_middleware(request, handler):
response = await handler(request)
return response
nonce = generate_nonce()
response.headers['Content-Security-Policy'] = csp_policy.format(nonce=nonce)
return response
nonce = generate_nonce()
response.headers["Content-Security-Policy"] = csp_policy.format(nonce=nonce)
return response
@web.middleware
async def no_cors_middleware(request, handler):

View File

@ -63,9 +63,11 @@ def hash_sync(data: str, salt: str = DEFAULT_SALT) -> str:
obj = hashlib.sha256(salted)
return obj.hexdigest()
async def hash(data: str, salt: str = DEFAULT_SALT) -> str:
return hash_sync(data, salt)
def verify_sync(string: str, hashed: str) -> bool:
"""Verify if the given string matches the hashed value.
@ -78,5 +80,6 @@ def verify_sync(string: str, hashed: str) -> bool:
"""
return hash_sync(string) == hashed
async def verify(string: str, hashed: str) -> bool:
return verify_sync(string, hashed)

View File

@ -1,17 +1,16 @@
import mimetypes
import re
from functools import lru_cache
from types import SimpleNamespace
from urllib.parse import urlparse, parse_qs
from app.cache import time_cache
from urllib.parse import parse_qs, urlparse
import bleach
import emoji
import requests
from app.cache import time_cache
from bs4 import BeautifulSoup
from jinja2 import TemplateSyntaxError, nodes
from jinja2.ext import Extension
from jinja2.nodes import Const
import bleach
emoji.EMOJI_DATA['<img src="/emoji/snek1.gif" />'] = {
"en": ":snek1:",
@ -81,13 +80,29 @@ emoji.EMOJI_DATA[
ALLOWED_TAGS = list(bleach.sanitizer.ALLOWED_TAGS) + [
"img", "video", "audio", "source", "iframe", "picture", "span"
"img",
"video",
"audio",
"source",
"iframe",
"picture",
"span",
]
ALLOWED_ATTRIBUTES = {
**bleach.sanitizer.ALLOWED_ATTRIBUTES,
"img": ["src", "alt", "title", "width", "height"],
"a": ["href", "title", "target", "rel", "referrerpolicy", "class"],
"iframe": ["src", "width", "height", "frameborder", "allow", "allowfullscreen", "title", "referrerpolicy", "style"],
"iframe": [
"src",
"width",
"height",
"frameborder",
"allow",
"allowfullscreen",
"title",
"referrerpolicy",
"style",
],
"video": ["src", "controls", "width", "height"],
"audio": ["src", "controls"],
"source": ["src", "type"],
@ -96,9 +111,6 @@ ALLOWED_ATTRIBUTES = {
}
def sanitize_html(value):
return bleach.clean(
value,
@ -109,7 +121,6 @@ def sanitize_html(value):
)
def set_link_target_blank(text):
soup = BeautifulSoup(text, "html.parser")
@ -118,26 +129,43 @@ def set_link_target_blank(text):
element.attrs["rel"] = "noopener noreferrer"
element.attrs["referrerpolicy"] = "no-referrer"
element.attrs["href"] = element.attrs["href"].strip(".").strip(",")
return str(soup)
SAFE_ATTRIBUTES = {
'href', 'src', 'alt', 'title', 'width', 'height', 'style', 'id', 'class',
'rel', 'type', 'name', 'value', 'placeholder', 'aria-hidden', 'aria-label', 'srcset'
"href",
"src",
"alt",
"title",
"width",
"height",
"style",
"id",
"class",
"rel",
"type",
"name",
"value",
"placeholder",
"aria-hidden",
"aria-label",
"srcset",
}
def whitelist_attributes(html):
soup = BeautifulSoup(html, 'html.parser')
soup = BeautifulSoup(html, "html.parser")
for tag in soup.find_all():
if hasattr(tag, 'attrs'):
if tag.name in ['script','form','input']:
tag.replace_with('')
if hasattr(tag, "attrs"):
if tag.name in ["script", "form", "input"]:
tag.replace_with("")
continue
attrs = dict(tag.attrs)
for attr in list(attrs):
# Check if attribute is in the safe list or is a data-* attribute
if not (attr in SAFE_ATTRIBUTES or attr.startswith('data-')):
if not (attr in SAFE_ATTRIBUTES or attr.startswith("data-")):
del tag.attrs[attr]
return str(soup)
@ -236,12 +264,12 @@ def enrich_image_rendering(text):
for element in soup.find_all("img"):
if element.attrs["src"].startswith("/"):
element.attrs["src"] += "?width=240&height=240"
picture_template = f'''
picture_template = f"""
<picture>
<source srcset="{element.attrs["src"]}" type="{mimetypes.guess_type(element.attrs["src"])[0]}" />
<source srcset="{element.attrs["src"]}&format=webp" type="image/webp" />
<img src="{element.attrs["src"]}&format=png" title="{element.attrs["src"]}" alt="{element.attrs["src"]}" />
</picture>'''
</picture>"""
element.replace_with(BeautifulSoup(picture_template, "html.parser"))
return str(soup)
@ -285,7 +313,7 @@ def linkify_https(text):
return set_link_target_blank(str(soup))
@time_cache(timeout=60*60)
@time_cache(timeout=60 * 60)
def get_url_content(url):
try:
response = requests.get(url, timeout=5)
@ -368,12 +396,11 @@ def embed_url(text):
page_video = get_element_options(None, "video", "video", "video")
page_audio = get_element_options(None, "audio", "audio", "audio")
preview_size = (
(
get_element_options(None, None, None, "card")
or "summary_large_image"
)
attachment_base = BeautifulSoup(str(element), "html.parser")
attachments[original_link_name] = attachment_base
@ -412,20 +439,28 @@ def embed_url(text):
)
description_element.append(
BeautifulSoup(f'<strong class="page-name">{page_name}</strong>', "html.parser")
BeautifulSoup(
f'<strong class="page-name">{page_name}</strong>',
"html.parser",
)
)
description_element.append(
BeautifulSoup(f"<p class='page-description'>{page_description or "No description available."}</p>", "html.parser")
BeautifulSoup(
f"<p class='page-description'>{page_description or "No description available."}</p>",
"html.parser",
)
)
description_element.append(
BeautifulSoup(f"<p class='page-original-link'>{original_link_name}</p>", "html.parser")
BeautifulSoup(
f"<p class='page-original-link'>{original_link_name}</p>",
"html.parser",
)
)
render_element.append(description_element_base)
for attachment in attachments.values():
soup.append(attachment)

View File

@ -1,22 +1,24 @@
class WebSocketClient:
def __init__(self, hostname, port):
self.buffer = b''
self.buffer = b""
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.hostname = hostname
self.hostname = hostname
self.port = port
self.connect()
self.connect()
def __getattr__(self, method, *args, **kwargs):
if method in self.__dict__.keys():
return self.__dict__[method]
def call(*args, **kwargs):
self.write(json.dumps({'method': method, 'args': args, 'kwargs': kwargs}))
self.write(json.dumps({"method": method, "args": args, "kwargs": kwargs}))
return json.loads(self.read())
return call
def connect(self):
self.socket.connect((self.hostname, self.port))
key = base64.b64encode(b'1234123412341234').decode('utf-8')
key = base64.b64encode(b"1234123412341234").decode("utf-8")
handshake = (
f"GET /db HTTP/1.1\r\n"
f"Host: localhost:3131\r\n"
@ -25,57 +27,58 @@ class WebSocketClient:
f"Sec-WebSocket-Key: {key}\r\n"
f"Sec-WebSocket-Version: 13\r\n\r\n"
)
self.socket.sendall(handshake.encode('utf-8'))
response = self.read_until(b'\r\n\r\n')
if b'101 Switching Protocols' not in response:
self.socket.sendall(handshake.encode("utf-8"))
response = self.read_until(b"\r\n\r\n")
if b"101 Switching Protocols" not in response:
raise Exception("Failed to connect to WebSocket")
def write(self, message):
message_bytes = message.encode('utf-8')
message_bytes = message.encode("utf-8")
length = len(message_bytes)
if length <= 125:
self.socket.sendall(b'\x81' + bytes([length]) + message_bytes)
self.socket.sendall(b"\x81" + bytes([length]) + message_bytes)
elif length >= 126 and length <= 65535:
self.socket.sendall(b'\x81' + bytes([126]) + length.to_bytes(2, 'big') + message_bytes)
self.socket.sendall(
b"\x81" + bytes([126]) + length.to_bytes(2, "big") + message_bytes
)
else:
self.socket.sendall(b'\x81' + bytes([127]) + length.to_bytes(8, 'big') + message_bytes)
self.socket.sendall(
b"\x81" + bytes([127]) + length.to_bytes(8, "big") + message_bytes
)
def read_until(self, delimiter):
def read_until(self, delimiter):
while True:
find_pos = self.buffer.find(delimiter)
if find_pos != -1:
data = self.buffer[:find_pos+4]
self.buffer = self.buffer[find_pos+4:]
return data
data = self.buffer[: find_pos + 4]
self.buffer = self.buffer[find_pos + 4 :]
return data
chunk = self.socket.recv(1024)
if not chunk:
return None
self.buffer += chunk
def read_exactly(self, length):
while len(self.buffer) < length:
chunk = self.socket.recv(length - len(self.buffer))
if not chunk:
return None
self.buffer += chunk
response = self.buffer[: length]
self.buffer += chunk
response = self.buffer[:length]
self.buffer = self.buffer[length:]
return response
def read(self):
frame = None
frame = None
frame = self.read_exactly(2)
length = frame[1] & 127
if length == 126:
length = int.from_bytes(self.read_exactly(2), 'big')
length = int.from_bytes(self.read_exactly(2), "big")
elif length == 127:
length = int.from_bytes(self.read_exactly(8), 'big')
length = int.from_bytes(self.read_exactly(8), "big")
message = self.read_exactly(length)
return message
def close(self):
self.socket.close()

View File

@ -31,21 +31,17 @@ from multiavatar import multiavatar
from snek.system.view import BaseView
from snek.view.avatar_animal import generate_avatar_with_options
import functools
class AvatarView(BaseView):
login_required = False
def __init__(self, *args,**kwargs):
super().__init__(*args,**kwargs)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.avatars = {}
async def get(self):
type_ = self.request.query.get("type")
type_match = {
"animal": self.get_animal,
"default": self.get_default
}
type_match = {"animal": self.get_animal, "default": self.get_default}
handler = type_match.get(type_, self.get_default)
return await handler()
@ -64,10 +60,9 @@ class AvatarView(BaseView):
avatar = generate_avatar_with_options(self.request.query)
await self.app.set(key, avatar)
return web.Response(text=avatar, content_type="image/svg+xml")
except Exception as e:
except Exception:
pass
def _get(self, uid):
avatar = generate_avatar_with_options(self.request.query)
return avatar

File diff suppressed because it is too large Load Diff

View File

@ -31,13 +31,12 @@ from multiavatar import multiavatar
from snek.system.view import BaseView
from snek.view.avatar_animal import generate_avatar_with_options
import functools
class AvatarView(BaseView):
login_required = False
def __init__(self, *args,**kwargs):
super().__init__(*args,**kwargs)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.avatars = {}
async def get(self):
@ -45,10 +44,9 @@ class AvatarView(BaseView):
while True:
try:
return web.Response(text=self._get(uid), content_type="image/svg+xml")
except Exception as e:
except Exception:
pass
def _get(self, uid):
if uid in self.avatars:
return self.avatars[uid]

View File

@ -1,15 +1,15 @@
import asyncio
import mimetypes
import pathlib
import urllib.parse
from os.path import isfile
from PIL import Image
import pillow_heif.HeifImagePlugin
from snek.system.view import BaseView
import aiofiles
from aiohttp import web
import pathlib
import urllib.parse
from PIL import Image
from snek.system.view import BaseView
class ChannelAttachmentView(BaseView):
async def get(self):
@ -73,18 +73,21 @@ class ChannelAttachmentView(BaseView):
# response.write is async but image.save is not
naughty_steal = response.write
tasks = []
def sync_writer(*args, **kwargs):
tasks.append(naughty_steal(*args, **kwargs))
return True
setattr(response, "write", sync_writer)
image.save(response, format=format_, quality=100, optimize=True, save_all=True)
image.save(
response, format=format_, quality=100, optimize=True, save_all=True
)
setattr(response, "write", naughty_steal)
await asyncio.gather(*tasks)
return response
else:
response = web.FileResponse(channel_attachment["path"])
@ -127,7 +130,9 @@ class ChannelAttachmentView(BaseView):
attachment_records = []
for attachment in attachments:
attachment_record = attachment.record
attachment_record['relative_url'] = urllib.parse.quote(attachment_record['relative_url'])
attachment_record["relative_url"] = urllib.parse.quote(
attachment_record["relative_url"]
)
attachment_records.append(attachment_record)
return web.json_response(
@ -138,23 +143,37 @@ class ChannelAttachmentView(BaseView):
}
)
class ChannelView(BaseView):
async def get(self):
channel_name = self.request.match_info.get("channel")
if(channel_name is None):
if channel_name is None:
return web.HTTPNotFound()
channel = await self.services.channel.get(label="#" + channel_name)
if(channel is None):
if channel is None:
channel = await self.services.channel.get(label=channel_name)
channel = await self.services.channel.get(channel_name)
if(channel is None):
if channel is None:
channel = await self.services.channel.get(label=channel_name)
if(channel is None):
user = await self.services.user.get(uid=self.session.get("uid"))
if channel is None:
user = await self.services.user.get(uid=self.session.get("uid"))
is_listed = self.request.query.get("listed", False) == "true"
is_private = self.request.query.get("private", False) == "true"
channel = await self.services.channel.create(label=channel_name,created_by_uid=user['uid'],description="No description provided.",tag="user",is_private=is_private,is_listed=is_listed)
channel_member = await self.services.channel_member.create(channel_uid=channel['uid'],user_uid=user['uid'],is_moderator=True,is_read_only=False,is_muted=False,is_banned=False)
return web.HTTPFound("/channel/{}.html".format(channel["uid"]))
channel = await self.services.channel.create(
label=channel_name,
created_by_uid=user["uid"],
description="No description provided.",
tag="user",
is_private=is_private,
is_listed=is_listed,
)
await self.services.channel_member.create(
channel_uid=channel["uid"],
user_uid=user["uid"],
is_moderator=True,
is_read_only=False,
is_muted=False,
is_banned=False,
)
return web.HTTPFound("/channel/{}.html".format(channel["uid"]))

View File

@ -1,20 +1,15 @@
import mimetypes
import os
import urllib.parse
from datetime import datetime
from pathlib import Path
from urllib.parse import quote, unquote
from aiohttp import web
from snek.system.view import BaseView
import os
import mimetypes
from aiohttp import web
from urllib.parse import unquote, quote
from datetime import datetime
from aiohttp import web
from pathlib import Path
import mimetypes, urllib.parse
class DriveView(BaseView):
async def get(self):
@ -27,7 +22,7 @@ class DriveView(BaseView):
return web.HTTPNotFound(reason="Path not found")
if target.is_dir():
return await self.render_template("drive.html",{"path": rel_path})
return await self.render_template("drive.html", {"path": rel_path})
if target.is_file():
return web.FileResponse(target)
@ -37,7 +32,7 @@ class DriveApiView(BaseView):
target = await self.services.user.get_home_folder(self.session.get("uid"))
rel = self.request.query.get("path", "")
offset = int(self.request.query.get("offset", 0))
limit = int(self.request.query.get("limit", 20))
limit = int(self.request.query.get("limit", 20))
if rel:
target = target.joinpath(rel)
@ -47,58 +42,54 @@ class DriveApiView(BaseView):
if target.is_dir():
entries = []
for p in sorted(target.iterdir(), key=lambda p: (p.is_file(), p.name.lower())):
for p in sorted(
target.iterdir(), key=lambda p: (p.is_file(), p.name.lower())
):
item_path = (Path(rel) / p.name).as_posix()
mime = mimetypes.guess_type(p.name)[0] if p.is_file() else "inode/directory"
url = (self.request.url.with_path(f"/drive/{urllib.parse.quote(item_path)}")
if p.is_file() else None)
entries.append({
"name": p.name,
"type": "directory" if p.is_dir() else "file",
"mimetype": mime,
"size": p.stat().st_size if p.is_file() else None,
"path": item_path,
"url": url,
})
import json
mime = (
mimetypes.guess_type(p.name)[0]
if p.is_file()
else "inode/directory"
)
url = (
self.request.url.with_path(
f"/drive/{urllib.parse.quote(item_path)}"
)
if p.is_file()
else None
)
entries.append(
{
"name": p.name,
"type": "directory" if p.is_dir() else "file",
"mimetype": mime,
"size": p.stat().st_size if p.is_file() else None,
"path": item_path,
"url": url,
}
)
import json
total = len(entries)
items = entries[offset:offset+limit]
return web.json_response({
"items": json.loads(json.dumps(items,default=str)),
"pagination": {"offset": offset, "limit": limit, "total": total}
})
items = entries[offset : offset + limit]
return web.json_response(
{
"items": json.loads(json.dumps(items, default=str)),
"pagination": {"offset": offset, "limit": limit, "total": total},
}
)
url = self.request.url.with_path(f"/drive/{urllib.parse.quote(rel)}")
return web.json_response({
"name": target.name,
"type": "file",
"mimetype": mimetypes.guess_type(target.name)[0],
"size": target.stat().st_size,
"path": rel,
"url": str(url),
})
return web.json_response(
{
"name": target.name,
"type": "file",
"mimetype": mimetypes.guess_type(target.name)[0],
"size": target.stat().st_size,
"path": rel,
"url": str(url),
}
)
class DriveView222(BaseView):
@ -124,7 +115,11 @@ class DriveView222(BaseView):
entry_path = os.path.join(dir_path, entry)
stat = os.stat(entry_path)
is_dir = os.path.isdir(entry_path)
mimetype = None if is_dir else (mimetypes.guess_type(entry_path)[0] or "application/octet-stream")
mimetype = (
None
if is_dir
else (mimetypes.guess_type(entry_path)[0] or "application/octet-stream")
)
size = stat.st_size if not is_dir else None
created_at = datetime.fromtimestamp(stat.st_ctime).isoformat()
updated_at = datetime.fromtimestamp(stat.st_mtime).isoformat()
@ -155,15 +150,20 @@ class DriveView222(BaseView):
start = (page - 1) * page_size
end = start + page_size
paged_entries = entries[start:end]
details = [await self.entry_details(full_path, entry, rel_path) for entry in paged_entries]
return web.json_response({
"path": rel_path,
"absolute_url": abs_url,
"entries": details,
"total": len(entries),
"page": page,
"page_size": page_size,
})
details = [
await self.entry_details(full_path, entry, rel_path)
for entry in paged_entries
]
return web.json_response(
{
"path": rel_path,
"absolute_url": abs_url,
"entries": details,
"total": len(entries),
"page": page,
"page_size": page_size,
}
)
else:
with open(full_path, "rb") as f:
content = f.read()
@ -180,14 +180,18 @@ class DriveView222(BaseView):
data = await self.request.post()
if data.get("type") == "dir":
os.makedirs(full_path)
return web.json_response({"status": "created", "type": "dir", "absolute_url": abs_url})
return web.json_response(
{"status": "created", "type": "dir", "absolute_url": abs_url}
)
else:
file_field = data.get("file")
if not file_field:
raise web.HTTPBadRequest(reason="No file uploaded")
with open(full_path, "wb") as f:
f.write(file_field.file.read())
return web.json_response({"status": "created", "type": "file", "absolute_url": abs_url})
return web.json_response(
{"status": "created", "type": "file", "absolute_url": abs_url}
)
async def put(self):
rel_path = self.request.match_info.get("rel_path", "")
@ -210,10 +214,14 @@ class DriveView222(BaseView):
raise web.HTTPNotFound(reason="Path not found")
if os.path.isdir(full_path):
os.rmdir(full_path)
return web.json_response({"status": "deleted", "type": "dir", "absolute_url": abs_url})
return web.json_response(
{"status": "deleted", "type": "dir", "absolute_url": abs_url}
)
else:
os.remove(full_path)
return web.json_response({"status": "deleted", "type": "file", "absolute_url": abs_url})
return web.json_response(
{"status": "deleted", "type": "file", "absolute_url": abs_url}
)
class DriveViewi2(BaseView):
@ -223,20 +231,17 @@ class DriveViewi2(BaseView):
async def get(self):
drive_uid = self.request.match_info.get("drive")
before = self.request.query.get("before")
filters = {}
filters = {}
if before:
filters["created_at__lt"] = before
if drive_uid:
filters['drive_uid'] = drive_uid
filters["drive_uid"] = drive_uid
drive = await self.services.drive.get(uid=drive_uid)
drive_items = []
async for item in self.services.drive_item.find(**filters):
record = item.record
record["url"] = "/drive.bin/" + record["uid"] + "." + item.extension

View File

@ -82,4 +82,4 @@ class PushView(BaseFormView):
print(post_notification.text)
print(post_notification.headers)
return await self.json_response({ "registered": True })
return await self.json_response({"registered": True})

View File

@ -1,14 +1,13 @@
import os
import asyncio
import mimetypes
import os
import urllib.parse
from pathlib import Path
import humanize
from aiohttp import web
from snek.system.view import BaseView
import asyncio
class BareRepoNavigator:
@ -25,18 +24,18 @@ class BareRepoNavigator:
except Exception as e:
print(f"Error opening repository: {str(e)}")
sys.exit(1)
self.repo_path = repo_path
self.branches = list(self.repo.branches)
self.current_branch = None
self.current_commit = None
self.current_path = ""
self.history = []
def get_branches(self):
"""Return a list of branch names in the repository."""
return [branch.name for branch in self.branches]
def set_branch(self, branch_name):
"""Set the current branch."""
try:
@ -47,23 +46,27 @@ class BareRepoNavigator:
return True
except IndexError:
return False
def get_commits(self, count=10):
"""Get the latest commits on the current branch."""
if not self.current_branch:
return []
commits = []
for commit in self.repo.iter_commits(self.current_branch, max_count=count):
commits.append({
'hash': commit.hexsha,
'short_hash': commit.hexsha[:7],
'message': commit.message.strip(),
'author': commit.author.name,
'date': datetime.fromtimestamp(commit.committed_date).strftime('%Y-%m-%d %H:%M:%S')
})
commits.append(
{
"hash": commit.hexsha,
"short_hash": commit.hexsha[:7],
"message": commit.message.strip(),
"author": commit.author.name,
"date": datetime.fromtimestamp(commit.committed_date).strftime(
"%Y-%m-%d %H:%M:%S"
),
}
)
return commits
def set_commit(self, commit_hash):
"""Set the current commit by hash."""
try:
@ -73,48 +76,48 @@ class BareRepoNavigator:
return True
except ValueError:
return False
def list_directory(self, path=""):
"""List the contents of a directory in the current commit."""
if not self.current_commit:
return {'dirs': [], 'files': []}
return {"dirs": [], "files": []}
dirs = []
files = []
try:
# Get the tree at the current path
if path:
tree = self.current_commit.tree[path]
if not hasattr(tree, 'trees'): # It's a blob, not a tree
return {'dirs': [], 'files': [path]}
if not hasattr(tree, "trees"): # It's a blob, not a tree
return {"dirs": [], "files": [path]}
else:
tree = self.current_commit.tree
# List directories and files
for item in tree:
if item.type == 'tree':
if item.type == "tree":
item_path = os.path.join(path, item.name) if path else item.name
dirs.append(item_path)
elif item.type == 'blob':
elif item.type == "blob":
item_path = os.path.join(path, item.name) if path else item.name
files.append(item_path)
dirs.sort()
files.sort()
return {'dirs': dirs, 'files': files}
return {"dirs": dirs, "files": files}
except KeyError:
return {'dirs': [], 'files': []}
return {"dirs": [], "files": []}
def get_file_content(self, file_path):
"""Get the content of a file in the current commit."""
if not self.current_commit:
return None
try:
blob = self.current_commit.tree[file_path]
return blob.data_stream.read().decode('utf-8', errors='replace')
return blob.data_stream.read().decode("utf-8", errors="replace")
except (KeyError, UnicodeDecodeError):
try:
# Try to get as binary if text decoding fails
@ -122,12 +125,12 @@ class BareRepoNavigator:
return blob.data_stream.read()
except:
return None
def navigate_to(self, path):
"""Navigate to a specific path, updating the current path."""
if not self.current_commit:
return False
try:
if path:
self.current_commit.tree[path] # Check if path exists
@ -136,7 +139,7 @@ class BareRepoNavigator:
return True
except KeyError:
return False
def navigate_back(self):
"""Navigate back to the previous path."""
if self.history:
@ -145,12 +148,13 @@ class BareRepoNavigator:
return False
class RepositoryView(BaseView):
login_required = True
def checkout_bare_repo(self, bare_repo_path: Path, target_path: Path, ref: str = 'HEAD'):
def checkout_bare_repo(
self, bare_repo_path: Path, target_path: Path, ref: str = "HEAD"
):
repo = Repo(bare_repo_path)
assert repo.bare, "Repository is not bare."
@ -159,23 +163,22 @@ class RepositoryView(BaseView):
for blob in tree.traverse():
target_file = target_path / blob.path
target_file.parent.mkdir(parents=True, exist_ok=True)
print(blob.path)
with open(target_file, 'wb') as f:
with open(target_file, "wb") as f:
f.write(blob.data_stream.read())
async def get(self):
base_repo_path = Path("drive/repositories")
base_repo_path = Path("drive/repositories")
authenticated_user_id = self.session.get("uid")
username = self.request.match_info.get('username')
repo_name = self.request.match_info.get('repository')
rel_path = self.request.match_info.get('path', '')
username = self.request.match_info.get("username")
repo_name = self.request.match_info.get("repository")
rel_path = self.request.match_info.get("path", "")
user = None
if not username.count("-") == 4:
user = await self.app.services.user.get(username=username)
@ -185,22 +188,24 @@ class RepositoryView(BaseView):
else:
user = await self.app.services.user.get(uid=username)
repo = await self.app.services.repository.get(name=repo_name, user_uid=user["uid"])
repo = await self.app.services.repository.get(
name=repo_name, user_uid=user["uid"]
)
if not repo:
return web.Response(text="404 Not Found", status=404)
if repo['is_private'] and authenticated_user_id != repo['uid']:
return web.Response(text="404 Not Found", status=404)
if repo["is_private"] and authenticated_user_id != repo["uid"]:
return web.Response(text="404 Not Found", status=404)
repo_root_base = (base_repo_path / user['uid'] / (repo_name + ".git")).resolve()
repo_root = (base_repo_path / user['uid'] / repo_name).resolve()
repo_root_base = (base_repo_path / user["uid"] / (repo_name + ".git")).resolve()
repo_root = (base_repo_path / user["uid"] / repo_name).resolve()
try:
loop = asyncio.get_event_loop()
await loop.run_in_executor(None,
self.checkout_bare_repo, repo_root_base, repo_root
await loop.run_in_executor(
None, self.checkout_bare_repo, repo_root_base, repo_root
)
except:
pass
if not repo_root.exists() or not repo_root.is_dir():
return web.Response(text="404 Not Found", status=404)
@ -211,24 +216,35 @@ class RepositoryView(BaseView):
return web.Response(text="404 Not Found", status=404)
if abs_path.is_dir():
return web.Response(text=self.render_directory(abs_path, username, repo_name, safe_rel_path), content_type='text/html')
return web.Response(
text=self.render_directory(
abs_path, username, repo_name, safe_rel_path
),
content_type="text/html",
)
else:
return web.Response(text=self.render_file(abs_path), content_type='text/html')
return web.Response(
text=self.render_file(abs_path), content_type="text/html"
)
def render_directory(self, abs_path, username, repo_name, safe_rel_path):
entries = sorted(abs_path.iterdir(), key=lambda p: (not p.is_dir(), p.name.lower()))
entries = sorted(
abs_path.iterdir(), key=lambda p: (not p.is_dir(), p.name.lower())
)
items = []
if safe_rel_path:
parent_path = Path(safe_rel_path).parent
parent_link = f"/repository/{username}/{repo_name}/{parent_path}".rstrip('/')
parent_link = f"/repository/{username}/{repo_name}/{parent_path}".rstrip(
"/"
)
items.append(f'<li><a href="{parent_link}">⬅️ ..</a></li>')
for entry in entries:
link_path = urllib.parse.quote(str(Path(safe_rel_path) / entry.name))
link = f"/repository/{username}/{repo_name}/{link_path}".rstrip('/')
display = entry.name + ('/' if entry.is_dir() else '')
size = '' if entry.is_dir() else humanize.naturalsize(entry.stat().st_size)
link = f"/repository/{username}/{repo_name}/{link_path}".rstrip("/")
display = entry.name + ("/" if entry.is_dir() else "")
size = "" if entry.is_dir() else humanize.naturalsize(entry.stat().st_size)
icon = self.get_icon(entry)
items.append(f'<li>{icon} <a href="{link}">{display}</a> {size}</li>')
@ -247,19 +263,24 @@ class RepositoryView(BaseView):
def render_file(self, abs_path):
try:
with open(abs_path, 'r', encoding='utf-8', errors='ignore') as f:
with open(abs_path, encoding="utf-8", errors="ignore") as f:
content = f.read()
return f"<pre>{content}</pre>"
except Exception as e:
return f"<h1>Error</h1><pre>{e}</pre>"
def get_icon(self, file):
if file.is_dir(): return "📁"
mime = mimetypes.guess_type(file.name)[0] or ''
if mime.startswith("image"): return "🖼️"
if mime.startswith("text"): return "📄"
if mime.startswith("audio"): return "🎵"
if mime.startswith("video"): return "🎬"
if file.name.endswith(".py"): return "🐍"
if file.is_dir():
return "📁"
mime = mimetypes.guess_type(file.name)[0] or ""
if mime.startswith("image"):
return "🖼️"
if mime.startswith("text"):
return "📄"
if mime.startswith("audio"):
return "🎵"
if mime.startswith("video"):
return "🎬"
if file.name.endswith(".py"):
return "🐍"
return "📦"

View File

@ -7,19 +7,20 @@
# MIT License: Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions.
import json
import traceback
import asyncio
import json
import logging
import traceback
from aiohttp import web
from snek.system.model import now
from snek.system.profiler import Profiler
from snek.system.view import BaseView
import logging
logger = logging.getLogger(__name__)
class RPCView(BaseView):
class RPCApi:
def __init__(self, view, ws):
@ -32,7 +33,7 @@ class RPCView(BaseView):
async def _session_ensure(self):
uid = await self.view.session_get("uid")
if not uid in self.user_session:
if uid not in self.user_session:
self.user_session[uid] = {
"said_hello": False,
}
@ -50,45 +51,54 @@ class RPCView(BaseView):
self._require_login()
return await self.services.db.insert(self.user_uid, table_name, record)
async def db_update(self, table_name, record):
self._require_login()
return await self.services.db.update(self.user_uid, table_name, record)
async def set_typing(self,channel_uid,color=None):
async def set_typing(self, channel_uid, color=None):
self._require_login()
user = await self.services.user.get(self.user_uid)
if not color:
color = user["color"]
return await self.services.socket.broadcast(channel_uid, {
"channel_uid": "293ecf12-08c9-494b-b423-48ba1a2d12c2",
"event": "set_typing",
"data": {
"event":"set_typing",
"user_uid": user['uid'],
"username": user["username"],
"nick": user["nick"],
"channel_uid": channel_uid,
"color": color
}
})
return await self.services.socket.broadcast(
channel_uid,
{
"channel_uid": "293ecf12-08c9-494b-b423-48ba1a2d12c2",
"event": "set_typing",
"data": {
"event": "set_typing",
"user_uid": user["uid"],
"username": user["username"],
"nick": user["nick"],
"channel_uid": channel_uid,
"color": color,
},
},
)
async def db_delete(self, table_name, record):
self._require_login()
return await self.services.db.delete(self.user_uid, table_name, record)
async def db_get(self, table_name, record):
self._require_login()
return await self.services.db.get(self.user_uid, table_name, record)
async def db_find(self, table_name, record):
self._require_login()
return await self.services.db.find(self.user_uid, table_name, record)
async def db_upsert(self, table_name, record,keys):
async def db_upsert(self, table_name, record, keys):
self._require_login()
return await self.services.db.upsert(self.user_uid, table_name, record,keys)
return await self.services.db.upsert(
self.user_uid, table_name, record, keys
)
async def db_query(self, table_name, args):
self._require_login()
return await self.services.db.query(self.user_uid, table_name, sql, args)
@property
def user_uid(self):
return self.view.session.get("uid")
@ -195,52 +205,60 @@ class RPCView(BaseView):
async def finalize_message(self, message_uid):
self._require_login()
message = await self.services.channel_message.get(message_uid)
message = await self.services.channel_message.get(message_uid)
if not message:
return False
return False
if message["user_uid"] != self.user_uid:
raise Exception("Not allowed")
if not message['is_final']:
await self.services.chat.finalize(message['uid'])
if not message["is_final"]:
await self.services.chat.finalize(message["uid"])
return True
async def update_message_text(self,message_uid, text):
async def update_message_text(self, message_uid, text):
self._require_login()
message = await self.services.channel_message.get(message_uid)
message = await self.services.channel_message.get(message_uid)
if message["user_uid"] != self.user_uid:
raise Exception("Not allowed")
if message.get_seconds_since_last_update() > 3:
await self.finalize_message(message["uid"])
return {"error": "Message too old","seconds_since_last_update": message.get_seconds_since_last_update(),"success": False}
return {
"error": "Message too old",
"seconds_since_last_update": message.get_seconds_since_last_update(),
"success": False,
}
message['message'] = text
message["message"] = text
if not text:
message['deleted_at'] = now()
message["deleted_at"] = now()
else:
message['deleted_at'] = None
message["deleted_at"] = None
await self.services.channel_message.save(message)
data = message.record
data['text'] = message["message"]
data['message_uid'] = message_uid
data = message.record
data["text"] = message["message"]
data["message_uid"] = message_uid
await self.services.socket.broadcast(
message["channel_uid"],
{
"channel_uid": message["channel_uid"],
"event": "update_message_text",
"data": message.record,
},
)
await self.services.socket.broadcast(message["channel_uid"], {
"channel_uid": message["channel_uid"],
"event": "update_message_text",
"data": message.record
})
return {"success": True}
async def send_message(self, channel_uid, message,is_final=True):
async def send_message(self, channel_uid, message, is_final=True):
self._require_login()
message = await self.services.chat.send(self.user_uid, channel_uid, message,is_final)
message = await self.services.chat.send(
self.user_uid, channel_uid, message, is_final
)
return message["uid"]
async def echo(self, *args):
@ -330,7 +348,8 @@ class RPCView(BaseView):
self._require_login()
results = [
record async for record in self.services.channel.get_recent_users(channel_uid)
record
async for record in self.services.channel.get_recent_users(channel_uid)
]
results = sorted(results, key=lambda x: x["nick"])
return results
@ -371,7 +390,6 @@ class RPCView(BaseView):
await self.services.socket.send_to_user(self.user_uid, call)
self._scheduled.remove(call)
async def ping(self, callId, *args):
if self.user_uid:
user = await self.services.user.get(uid=self.user_uid)
@ -379,17 +397,23 @@ class RPCView(BaseView):
await self.services.user.save(user)
return {"pong": args}
async def stars_render(self, channel_uid, message):
for user in await self.get_online_users(channel_uid):
try:
await self.services.socket.send_to_user(user['uid'], dict(event="stars_render", data={"channel_uid": channel_uid, "message":message}))
await self.services.socket.send_to_user(
user["uid"],
{
"event": "stars_render",
"data": {"channel_uid": channel_uid, "message": message},
},
)
except Exception as ex:
print(ex)
async def get(self):
scheduled = []
async def schedule(uid, seconds, call):
scheduled.append(call)
await asyncio.sleep(seconds)
@ -408,16 +432,18 @@ class RPCView(BaseView):
await self.services.socket.subscribe(
ws, subscription["channel_uid"], self.request.session.get("uid")
)
if not scheduled and self.request.app.uptime_seconds < 5:
await schedule(self.request.session.get("uid"),0,{"event":"refresh", "data": {
"message": "Finishing deployment"}
}
)
await schedule(self.request.session.get("uid"),15,{"event": "deployed", "data": {
"uptime": self.request.app.uptime}
}
if not scheduled and self.request.app.uptime_seconds < 5:
await schedule(
self.request.session.get("uid"),
0,
{"event": "refresh", "data": {"message": "Finishing deployment"}},
)
await schedule(
self.request.session.get("uid"),
15,
{"event": "deployed", "data": {"uptime": self.request.app.uptime}},
)
rpc = RPCView.RPCApi(self, ws)
async for msg in ws:
if msg.type == web.WSMsgType.TEXT:

View File

@ -1,45 +1,48 @@
import asyncio
from aiohttp import web
from snek.system.view import BaseFormView
import pathlib
class ContainersIndexView(BaseFormView):
login_required = True
async def get(self):
user_uid = self.session.get("uid")
containers = []
async for container in self.services.container.find(user_uid=user_uid):
containers.append(container.record)
user = await self.services.user.get(uid=self.session.get("uid"))
return await self.render_template("settings/containers/index.html", {"containers": containers, "user": user})
return await self.render_template(
"settings/containers/index.html", {"containers": containers, "user": user}
)
class ContainersCreateView(BaseFormView):
login_required = True
async def get(self):
return await self.render_template("settings/containers/create.html")
async def post(self):
data = await self.request.post()
container = await self.services.container.create(
await self.services.container.create(
user_uid=self.session.get("uid"),
name=data['name'],
status=data['status'],
resources=data.get('resources', ''),
path=data.get('path', ''),
readonly=bool(data.get('readonly', False))
name=data["name"],
status=data["status"],
resources=data.get("resources", ""),
path=data.get("path", ""),
readonly=bool(data.get("readonly", False)),
)
return web.HTTPFound("/settings/containers/index.html")
class ContainersUpdateView(BaseFormView):
login_required = True
@ -51,40 +54,43 @@ class ContainersUpdateView(BaseFormView):
)
if not container:
return web.HTTPNotFound()
return await self.render_template("settings/containers/update.html", {"container": container.record})
return await self.render_template(
"settings/containers/update.html", {"container": container.record}
)
async def post(self):
data = await self.request.post()
container = await self.services.container.get(
user_uid=self.session.get("uid"), uid=self.request.match_info["uid"]
)
container['status'] = data['status']
container['resources'] = data.get('resources', '')
container['path'] = data.get('path', '')
container['readonly'] = bool(data.get('readonly', False))
container["status"] = data["status"]
container["resources"] = data.get("resources", "")
container["path"] = data.get("path", "")
container["readonly"] = bool(data.get("readonly", False))
await self.services.container.save(container)
return web.HTTPFound("/settings/containers/index.html")
class ContainersDeleteView(BaseFormView):
login_required = True
async def get(self):
container = await self.services.container.get(
user_uid=self.session.get("uid"), uid=self.request.match_info["uid"]
)
if not container:
return web.HTTPNotFound()
return await self.render_template("settings/containers/delete.html", {"container": container.record})
return await self.render_template(
"settings/containers/delete.html", {"container": container.record}
)
async def post(self):
user_uid = self.session.get("uid")
uid = self.request.match_info["uid"]
container = await self.services.container.get(
user_uid=user_uid, uid=uid
)
container = await self.services.container.get(user_uid=user_uid, uid=uid)
if not container:
return web.HTTPNotFound()
await self.services.container.delete(user_uid=user_uid, uid=uid)

View File

@ -1,26 +1,26 @@
import asyncio
from aiohttp import web
from snek.system.view import BaseFormView
import pathlib
class RepositoriesIndexView(BaseFormView):
login_required = True
async def get(self):
user_uid = self.session.get("uid")
repositories = []
async for repository in self.services.repository.find(user_uid=user_uid):
repositories.append(repository.record)
user = await self.services.user.get(uid=self.session.get("uid"))
return await self.render_template("settings/repositories/index.html", {"repositories": repositories, "user": user})
return await self.render_template(
"settings/repositories/index.html",
{"repositories": repositories, "user": user},
)
class RepositoriesCreateView(BaseFormView):
@ -28,14 +28,19 @@ class RepositoriesCreateView(BaseFormView):
login_required = True
async def get(self):
return await self.render_template("settings/repositories/create.html")
async def post(self):
data = await self.request.post()
repository = await self.services.repository.create(user_uid=self.session.get("uid"), name=data['name'], is_private=int(data.get('is_private',0)))
await self.services.repository.create(
user_uid=self.session.get("uid"),
name=data["name"],
is_private=int(data.get("is_private", 0)),
)
return web.HTTPFound("/settings/repositories/index.html")
class RepositoriesUpdateView(BaseFormView):
login_required = True
@ -47,40 +52,41 @@ class RepositoriesUpdateView(BaseFormView):
)
if not repository:
return web.HTTPNotFound()
return await self.render_template("settings/repositories/update.html", {"repository": repository.record})
return await self.render_template(
"settings/repositories/update.html", {"repository": repository.record}
)
async def post(self):
data = await self.request.post()
repository = await self.services.repository.get(
user_uid=self.session.get("uid"), name=self.request.match_info["name"]
)
repository['is_private'] = int(data.get('is_private',0))
repository["is_private"] = int(data.get("is_private", 0))
await self.services.repository.save(repository)
return web.HTTPFound("/settings/repositories/index.html")
class RepositoriesDeleteView(BaseFormView):
login_required = True
async def get(self):
repository = await self.services.repository.get(
user_uid=self.session.get("uid"), name=self.request.match_info["name"]
)
if not repository:
return web.HTTPNotFound()
return await self.render_template("settings/repositories/delete.html", {"repository": repository.record})
return await self.render_template(
"settings/repositories/delete.html", {"repository": repository.record}
)
async def post(self):
user_uid = self.session.get("uid")
name = self.request.match_info["name"]
repository = await self.services.repository.get(
user_uid=user_uid, name=name
)
repository = await self.services.repository.get(user_uid=user_uid, name=name)
if not repository:
return web.HTTPNotFound()
await self.services.repository.delete(user_uid=user_uid, name=name)
return web.HTTPFound("/settings/repositories/index.html")