Update.

Commit 389d417c1b (parent f182c2209e)

@@ -1,27 +1,31 @@
import click
import uvloop
from aiohttp import web
import asyncio
from snek.app import Application
from IPython import start_ipython
import sqlite3
import pathlib
import shutil
import shutil
import sqlite3

import click
from aiohttp import web
from IPython import start_ipython

from snek.app import Application


@click.group()
def cli():
    pass


@cli.command()
@click.option('--db_path',default="snek.db", help='Database to initialize if not exists.')
@click.option('--source',default=None, help='Database to initialize if not exists.')
def init(db_path,source):
@click.option(
    "--db_path", default="snek.db", help="Database to initialize if not exists."
)
@click.option("--source", default=None, help="Database to initialize if not exists.")
def init(db_path, source):
    if source and pathlib.Path(source).exists():
        print(f"Copying {source} to {db_path}")
        shutil.copy2(source,db_path)
        shutil.copy2(source, db_path)
        print("Database initialized.")
        return


    if pathlib.Path(db_path).exists():
        return
    print(f"Initializing database at {db_path}")
@@ -33,25 +37,44 @@ def init(db_path,source):
    db.close()
    print("Database initialized.")

@cli.command()
@click.option('--port', default=8081, show_default=True, help='Port to run the application on')
@click.option('--host', default='0.0.0.0', show_default=True, help='Host to run the application on')
@click.option('--db_path', default='snek.db', show_default=True, help='Database path for the application')
def serve(port, host, db_path):
    #init(db_path)
    #asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
    web.run_app(
        Application(db_path=f"sqlite:///{db_path}"), port=port, host=host
    )

@cli.command()
@click.option('--db_path', default='snek.db', show_default=True, help='Database path for the application')
@click.option(
    "--port", default=8081, show_default=True, help="Port to run the application on"
)
@click.option(
    "--host",
    default="0.0.0.0",
    show_default=True,
    help="Host to run the application on",
)
@click.option(
    "--db_path",
    default="snek.db",
    show_default=True,
    help="Database path for the application",
)
def serve(port, host, db_path):
    # init(db_path)
    # asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
    web.run_app(Application(db_path=f"sqlite:///{db_path}"), port=port, host=host)


@cli.command()
@click.option(
    "--db_path",
    default="snek.db",
    show_default=True,
    help="Database path for the application",
)
def shell(db_path):
    app = Application(db_path=f"sqlite:///{db_path}")
    start_ipython(argv=[], user_ns={'app': app})
    start_ipython(argv=[], user_ns={"app": app})


def main():
    cli()


if __name__ == "__main__":
    main()

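As a quick sanity check of the reorganized CLI above, a minimal sketch using click's built-in test runner; this is illustrative only (not part of the commit) and exercises the `cli` group defined in this file without touching a real database:

from click.testing import CliRunner

# Invoke the click group defined above; --help has no side effects.
runner = CliRunner()
result = runner.invoke(cli, ["--help"])
print(result.output)          # lists the init/serve/shell commands
assert result.exit_code == 0
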
102 src/snek/app.py
@ -4,13 +4,15 @@ import pathlib
|
||||
import ssl
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from snek import snode
|
||||
from snek.view.threads import ThreadsView
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
from ipaddress import ip_address
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from ipaddress import ip_address
|
||||
|
||||
import IP2Location
|
||||
from aiohttp import web
|
||||
from aiohttp_session import (
|
||||
get_session as session_get,
|
||||
@ -20,54 +22,56 @@ from aiohttp_session import (
|
||||
from aiohttp_session.cookie_storage import EncryptedCookieStorage
|
||||
from app.app import Application as BaseApplication
|
||||
from jinja2 import FileSystemLoader
|
||||
import IP2Location
|
||||
from snek.sssh import start_ssh_server
|
||||
|
||||
from snek.docs.app import Application as DocsApplication
|
||||
from snek.mapper import get_mappers
|
||||
from snek.service import get_services
|
||||
from snek.sgit import GitApplication
|
||||
from snek.sssh import start_ssh_server
|
||||
from snek.system import http
|
||||
from snek.system.cache import Cache
|
||||
from snek.system.markdown import MarkdownExtension
|
||||
from snek.system.middleware import auth_middleware, cors_middleware, csp_middleware
|
||||
from snek.system.profiler import profiler_handler
|
||||
from snek.system.template import EmojiExtension, LinkifyExtension, PythonExtension
|
||||
from snek.system.template import (
|
||||
EmojiExtension,
|
||||
LinkifyExtension,
|
||||
PythonExtension,
|
||||
sanitize_html,
|
||||
)
|
||||
from snek.view.about import AboutHTMLView, AboutMDView
|
||||
from snek.view.avatar import AvatarView
|
||||
from snek.view.channel import ChannelAttachmentView, ChannelView
|
||||
from snek.view.docs import DocsHTMLView, DocsMDView
|
||||
from snek.view.drive import DriveView
|
||||
from snek.view.drive import DriveApiView
|
||||
from snek.view.drive import DriveApiView, DriveView
|
||||
from snek.view.index import IndexView
|
||||
from snek.view.login import LoginView
|
||||
from snek.view.push import PushView
|
||||
from snek.view.logout import LogoutView
|
||||
from snek.view.push import PushView
|
||||
from snek.view.register import RegisterView
|
||||
from snek.view.rpc import RPCView
|
||||
from snek.view.repository import RepositoryView
|
||||
from snek.view.rpc import RPCView
|
||||
from snek.view.search_user import SearchUserView
|
||||
from snek.view.settings.repositories import RepositoriesIndexView
|
||||
from snek.view.settings.repositories import RepositoriesCreateView
|
||||
from snek.view.settings.repositories import RepositoriesUpdateView
|
||||
from snek.view.settings.repositories import RepositoriesDeleteView
|
||||
from snek.view.settings.containers import (
|
||||
ContainersCreateView,
|
||||
ContainersDeleteView,
|
||||
ContainersIndexView,
|
||||
ContainersUpdateView,
|
||||
)
|
||||
from snek.view.settings.index import SettingsIndexView
|
||||
from snek.view.settings.profile import SettingsProfileView
|
||||
from snek.view.settings.repositories import (
|
||||
RepositoriesCreateView,
|
||||
RepositoriesDeleteView,
|
||||
RepositoriesIndexView,
|
||||
RepositoriesUpdateView,
|
||||
)
|
||||
from snek.view.stats import StatsView
|
||||
from snek.view.status import StatusView
|
||||
from snek.view.terminal import TerminalSocketView, TerminalView
|
||||
from snek.view.upload import UploadView
|
||||
from snek.view.user import UserView
|
||||
from snek.view.web import WebView
|
||||
from snek.view.channel import ChannelAttachmentView
|
||||
from snek.view.channel import ChannelView
|
||||
from snek.view.settings.containers import (
|
||||
ContainersIndexView,
|
||||
ContainersCreateView,
|
||||
ContainersUpdateView,
|
||||
ContainersDeleteView,
|
||||
)
|
||||
from snek.webdav import WebdavApplication
|
||||
from snek.system.template import sanitize_html
|
||||
from snek.sgit import GitApplication
|
||||
|
||||
SESSION_KEY = b"c79a0c5fda4b424189c427d28c9f7c34"
|
||||
from snek.system.template import whitelist_attributes
|
||||
@ -94,7 +98,7 @@ async def ip2location_middleware(request, handler):
|
||||
if not user:
|
||||
return response
|
||||
location = request.app.ip2location.get(ip)
|
||||
original_city = user["city"]
|
||||
user["city"]
|
||||
if user["city"] != location.city:
|
||||
user["country_long"] = location.country
|
||||
user["country_short"] = locaion.country_short
|
||||
@ -121,7 +125,7 @@ class Application(BaseApplication):
|
||||
cors_middleware,
|
||||
web.normalize_path_middleware(merge_slashes=True),
|
||||
ip2location_middleware,
|
||||
csp_middleware
|
||||
csp_middleware,
|
||||
]
|
||||
self.template_path = pathlib.Path(__file__).parent.joinpath("templates")
|
||||
self.static_path = pathlib.Path(__file__).parent.joinpath("static")
|
||||
@ -139,7 +143,7 @@ class Application(BaseApplication):
|
||||
self.jinja2_env.add_extension(LinkifyExtension)
|
||||
self.jinja2_env.add_extension(PythonExtension)
|
||||
self.jinja2_env.add_extension(EmojiExtension)
|
||||
self.jinja2_env.filters['sanitize'] = sanitize_html
|
||||
self.jinja2_env.filters["sanitize"] = sanitize_html
|
||||
self.time_start = datetime.now()
|
||||
self.ssh_host = "0.0.0.0"
|
||||
self.ssh_port = 2242
|
||||
@ -272,8 +276,12 @@ class Application(BaseApplication):
|
||||
self.router.add_get("/http-photo", self.handle_http_photo)
|
||||
self.router.add_get("/rpc.ws", RPCView)
|
||||
self.router.add_get("/c/{channel:.*}", ChannelView)
|
||||
self.router.add_view("/channel/{channel_uid}/attachment.bin",ChannelAttachmentView)
|
||||
self.router.add_view("/channel/attachment/{relative_url:.*}",ChannelAttachmentView)
|
||||
self.router.add_view(
|
||||
"/channel/{channel_uid}/attachment.bin", ChannelAttachmentView
|
||||
)
|
||||
self.router.add_view(
|
||||
"/channel/attachment/{relative_url:.*}", ChannelAttachmentView
|
||||
)
|
||||
self.router.add_view("/channel/{channel}.html", WebView)
|
||||
self.router.add_view("/threads.html", ThreadsView)
|
||||
self.router.add_view("/terminal.ws", TerminalSocketView)
|
||||
@ -284,15 +292,29 @@ class Application(BaseApplication):
|
||||
self.router.add_view("/stats.json", StatsView)
|
||||
self.router.add_view("/user/{user}.html", UserView)
|
||||
self.router.add_view("/repository/{username}/{repository}", RepositoryView)
|
||||
self.router.add_view("/repository/{username}/{repository}/{path:.*}", RepositoryView)
|
||||
self.router.add_view(
|
||||
"/repository/{username}/{repository}/{path:.*}", RepositoryView
|
||||
)
|
||||
self.router.add_view("/settings/repositories/index.html", RepositoriesIndexView)
|
||||
self.router.add_view("/settings/repositories/create.html", RepositoriesCreateView)
|
||||
self.router.add_view("/settings/repositories/repository/{name}/update.html", RepositoriesUpdateView)
|
||||
self.router.add_view("/settings/repositories/repository/{name}/delete.html", RepositoriesDeleteView)
|
||||
self.router.add_view(
|
||||
"/settings/repositories/create.html", RepositoriesCreateView
|
||||
)
|
||||
self.router.add_view(
|
||||
"/settings/repositories/repository/{name}/update.html",
|
||||
RepositoriesUpdateView,
|
||||
)
|
||||
self.router.add_view(
|
||||
"/settings/repositories/repository/{name}/delete.html",
|
||||
RepositoriesDeleteView,
|
||||
)
|
||||
self.router.add_view("/settings/containers/index.html", ContainersIndexView)
|
||||
self.router.add_view("/settings/containers/create.html", ContainersCreateView)
|
||||
self.router.add_view("/settings/containers/container/{uid}/update.html", ContainersUpdateView)
|
||||
self.router.add_view("/settings/containers/container/{uid}/delete.html", ContainersDeleteView)
|
||||
self.router.add_view(
|
||||
"/settings/containers/container/{uid}/update.html", ContainersUpdateView
|
||||
)
|
||||
self.router.add_view(
|
||||
"/settings/containers/container/{uid}/delete.html", ContainersDeleteView
|
||||
)
|
||||
self.webdav = WebdavApplication(self)
|
||||
self.git = GitApplication(self)
|
||||
self.add_subapp("/webdav", self.webdav)
|
||||
@ -302,9 +324,9 @@ class Application(BaseApplication):
|
||||
|
||||
async def handle_test(self, request):
|
||||
|
||||
return await whitelist_attributes(self.render_template(
|
||||
"test.html", request, context={"name": "retoor"}
|
||||
))
|
||||
return await whitelist_attributes(
|
||||
self.render_template("test.html", request, context={"name": "retoor"})
|
||||
)
|
||||
|
||||
async def handle_http_get(self, request: web.Request):
|
||||
url = request.query.get("url")
|
||||
@ -375,8 +397,8 @@ class Application(BaseApplication):
|
||||
|
||||
self.jinja2_env.loader = self.original_loader
|
||||
|
||||
#rendered.text = whitelist_attributes(rendered.text)
|
||||
#rendered.headers['Content-Lenght'] = len(rendered.text)
|
||||
# rendered.text = whitelist_attributes(rendered.text)
|
||||
# rendered.headers['Content-Lenght'] = len(rendered.text)
|
||||
return rendered
|
||||
|
||||
async def static_handler(self, request):
|
||||
@ -423,7 +445,7 @@ app = Application(db_path="sqlite:///snek.db")
|
||||
|
||||
async def main():
|
||||
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||
ssl_context.load_cert_chain('cert.pem', 'key.pem')
|
||||
ssl_context.load_cert_chain("cert.pem", "key.pem")
|
||||
await web._run_app(app, port=8081, host="0.0.0.0", ssl_context=ssl_context)
|
||||
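For orientation, the middleware entries wired up in this file (cors_middleware, ip2location_middleware, csp_middleware, ...) follow aiohttp's standard middleware shape; a minimal self-contained sketch, not part of the commit:

from aiohttp import web

@web.middleware
async def example_middleware(request, handler):
    # Runs around every handler: call the handler, then adjust the response.
    response = await handler(request)
    response.headers["X-Example"] = "1"   # illustrative header only
    return response

app = web.Application(middlewares=[example_middleware])
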
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
|
||||
class LoadBalancer:
|
||||
def __init__(self, backend_ports):
|
||||
self.backend_ports = backend_ports
|
||||
@ -8,27 +9,29 @@ class LoadBalancer:
|
||||
self.client_counts = [0] * len(backend_ports)
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
async def start_backend_servers(self,port,workers):
|
||||
async def start_backend_servers(self, port, workers):
|
||||
for x in range(workers):
|
||||
port += 1
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
sys.executable,
|
||||
sys.argv[0],
|
||||
'backend',
|
||||
"backend",
|
||||
str(port),
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
port += 1
|
||||
self.backend_processes.append(process)
|
||||
print(f"Started backend server on port {(port-1)/port} with PID {process.pid}")
|
||||
print(
|
||||
f"Started backend server on port {(port-1)/port} with PID {process.pid}"
|
||||
)
|
||||
|
||||
async def handle_client(self, reader, writer):
|
||||
async with self.lock:
|
||||
min_clients = min(self.client_counts)
|
||||
server_index = self.client_counts.index(min_clients)
|
||||
self.client_counts[server_index] += 1
|
||||
backend = ('127.0.0.1', self.backend_ports[server_index])
|
||||
backend = ("127.0.0.1", self.backend_ports[server_index])
|
||||
try:
|
||||
backend_reader, backend_writer = await asyncio.open_connection(*backend)
|
||||
|
||||
@ -62,10 +65,10 @@ class LoadBalancer:
|
||||
for i, count in enumerate(self.client_counts):
|
||||
print(f"Server {self.backend_ports[i]}: {count} clients")
|
||||
|
||||
async def start(self, host='0.0.0.0', port=8081,workers=5):
|
||||
await self.start_backend_servers(port,workers)
|
||||
async def start(self, host="0.0.0.0", port=8081, workers=5):
|
||||
await self.start_backend_servers(port, workers)
|
||||
server = await asyncio.start_server(self.handle_client, host, port)
|
||||
monitor_task = asyncio.create_task(self.monitor())
|
||||
asyncio.create_task(self.monitor())
|
||||
|
||||
# Handle shutdown gracefully
|
||||
try:
|
||||
@ -80,6 +83,7 @@ class LoadBalancer:
|
||||
await asyncio.gather(*(p.wait() for p in self.backend_processes))
|
||||
print("Backend processes terminated.")
|
||||
|
||||
|
||||
async def backend_echo_server(port):
|
||||
async def handle_echo(reader, writer):
|
||||
try:
|
||||
@ -94,10 +98,11 @@ async def backend_echo_server(port):
|
||||
finally:
|
||||
writer.close()
|
||||
|
||||
server = await asyncio.start_server(handle_echo, '127.0.0.1', port)
|
||||
server = await asyncio.start_server(handle_echo, "127.0.0.1", port)
|
||||
print(f"Backend echo server running on port {port}")
|
||||
await server.serve_forever()
|
||||
|
||||
|
||||
async def main():
|
||||
backend_ports = [8001, 8003, 8005, 8006]
|
||||
# Launch backend echo servers
|
||||
@ -105,19 +110,20 @@ async def main():
|
||||
lb = LoadBalancer(backend_ports)
|
||||
await lb.start()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) > 1:
|
||||
if sys.argv[1] == 'backend':
|
||||
if sys.argv[1] == "backend":
|
||||
port = int(sys.argv[2])
|
||||
from snek.app import Application
|
||||
|
||||
snek = Application(port=port)
|
||||
web.run_app(snek, port=port, host='127.0.0.1')
|
||||
elif sys.argv[1] == 'sync':
|
||||
from snek.sync import app
|
||||
web.run_app(snek, port=port, host='127.0.0.1')
|
||||
web.run_app(snek, port=port, host="127.0.0.1")
|
||||
elif sys.argv[1] == "sync":
|
||||
|
||||
web.run_app(snek, port=port, host="127.0.0.1")
|
||||
else:
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
print("Shutting down...")
|
||||
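A hypothetical driver for the LoadBalancer class above (port and worker values are arbitrary examples; the class is assumed to be importable from this module):

import asyncio

async def demo():
    lb = LoadBalancer([8001, 8003, 8005, 8006])
    # Spawns the backend worker processes, then accepts client connections.
    await lb.start(host="127.0.0.1", port=9090, workers=2)

# asyncio.run(demo())   # left commented out: it starts real servers and subprocesses
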
|
||||
|
@ -1,22 +1,21 @@
|
||||
import functools
|
||||
|
||||
from snek.mapper.channel import ChannelMapper
|
||||
from snek.mapper.channel_attachment import ChannelAttachmentMapper
|
||||
from snek.mapper.channel_member import ChannelMemberMapper
|
||||
from snek.mapper.channel_message import ChannelMessageMapper
|
||||
from snek.mapper.container import ContainerMapper
|
||||
from snek.mapper.drive import DriveMapper
|
||||
from snek.mapper.drive_item import DriveItemMapper
|
||||
from snek.mapper.notification import NotificationMapper
|
||||
from snek.mapper.push import PushMapper
|
||||
from snek.mapper.repository import RepositoryMapper
|
||||
from snek.mapper.user import UserMapper
|
||||
from snek.mapper.user_property import UserPropertyMapper
|
||||
from snek.mapper.repository import RepositoryMapper
|
||||
from snek.mapper.push import PushMapper
|
||||
from snek.mapper.channel_attachment import ChannelAttachmentMapper
|
||||
from snek.mapper.container import ContainerMapper
|
||||
from snek.system.object import Object
|
||||
|
||||
|
||||
@functools.cache
|
||||
|
||||
def get_mappers(app=None):
|
||||
return Object(
|
||||
**{
|
||||
|
@ -1,6 +1,7 @@
|
||||
from snek.model.container import Container
|
||||
from snek.system.mapper import BaseMapper
|
||||
|
||||
|
||||
class ContainerMapper(BaseMapper):
|
||||
model_class = Container
|
||||
table_name = "container"
|
||||
table_name = "container"
|
||||
|
@ -1,19 +1,19 @@
|
||||
import functools
|
||||
|
||||
from snek.model.channel import ChannelModel
|
||||
from snek.model.channel_attachment import ChannelAttachmentModel
|
||||
from snek.model.channel_member import ChannelMemberModel
|
||||
|
||||
# from snek.model.channel_message import ChannelMessageModel
|
||||
from snek.model.channel_message import ChannelMessageModel
|
||||
from snek.model.container import Container
|
||||
from snek.model.drive import DriveModel
|
||||
from snek.model.drive_item import DriveItemModel
|
||||
from snek.model.notification import NotificationModel
|
||||
from snek.model.push_registration import PushRegistrationModel
|
||||
from snek.model.repository import RepositoryModel
|
||||
from snek.model.user import UserModel
|
||||
from snek.model.user_property import UserPropertyModel
|
||||
from snek.model.repository import RepositoryModel
|
||||
from snek.model.channel_attachment import ChannelAttachmentModel
|
||||
from snek.model.push_registration import PushRegistrationModel
|
||||
from snek.model.container import Container
|
||||
from snek.system.object import Object
|
||||
|
||||
|
||||
|
@ -1,4 +1,3 @@
|
||||
from snek.system.model import BaseModel
|
||||
from snek.system.model import BaseModel, ModelField
|
||||
|
||||
|
||||
@ -11,6 +10,6 @@ class ChannelAttachmentModel(BaseModel):
|
||||
user_uid = ModelField(name="user_uid", required=True, kind=str)
|
||||
mime_type = ModelField(name="type", required=True, kind=str)
|
||||
relative_url = ModelField(name="relative_url", required=True, kind=str)
|
||||
resource_type = ModelField(name="resource_type", required=True, kind=str,value="file")
|
||||
|
||||
|
||||
resource_type = ModelField(
|
||||
name="resource_type", required=True, kind=str, value="file"
|
||||
)
|
||||
|
@ -1,6 +1,8 @@
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from snek.model.user import UserModel
|
||||
from snek.system.model import BaseModel, ModelField
|
||||
from datetime import datetime,timezone
|
||||
|
||||
|
||||
class ChannelMessageModel(BaseModel):
|
||||
channel_uid = ModelField(name="channel_uid", required=True, kind=str)
|
||||
@ -10,7 +12,11 @@ class ChannelMessageModel(BaseModel):
|
||||
is_final = ModelField(name="is_final", required=True, kind=bool, value=True)
|
||||
|
||||
def get_seconds_since_last_update(self):
|
||||
return int((datetime.now(timezone.utc) - datetime.fromisoformat(self["updated_at"])).total_seconds())
|
||||
return int(
|
||||
(
|
||||
datetime.now(timezone.utc) - datetime.fromisoformat(self["updated_at"])
|
||||
).total_seconds()
|
||||
)
|
||||
|
||||
async def get_user(self) -> UserModel:
|
||||
return await self.app.services.user.get(uid=self["user_uid"])
|
||||
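The get_seconds_since_last_update helper above boils down to subtracting a stored ISO-8601 timestamp from the current UTC time; a standalone illustration with an invented sample value:

from datetime import datetime, timezone

updated_at = "2024-01-01T00:00:00+00:00"   # invented sample value
elapsed = int(
    (datetime.now(timezone.utc) - datetime.fromisoformat(updated_at)).total_seconds()
)
print(elapsed)   # whole seconds since the stored timestamp
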
|
@ -1,5 +1,6 @@
|
||||
from snek.system.model import BaseModel, ModelField
|
||||
|
||||
|
||||
class Container(BaseModel):
|
||||
id = ModelField(name="id", required=True, kind=str)
|
||||
name = ModelField(name="name", required=True, kind=str)
|
||||
|
@ -9,7 +9,9 @@ class DriveItemModel(BaseModel):
|
||||
path = ModelField(name="path", required=True, kind=str)
|
||||
file_type = ModelField(name="file_type", required=True, kind=str)
|
||||
file_size = ModelField(name="file_size", required=True, kind=int)
|
||||
is_available = ModelField(name="is_available", required=True, kind=bool, initial_value=True)
|
||||
is_available = ModelField(
|
||||
name="is_available", required=True, kind=bool, initial_value=True
|
||||
)
|
||||
|
||||
@property
|
||||
def extension(self):
|
||||
|
@ -1,14 +1,10 @@
|
||||
from snek.model.user import UserModel
|
||||
from snek.system.model import BaseModel, ModelField
|
||||
|
||||
|
||||
class RepositoryModel(BaseModel):
|
||||
|
||||
user_uid = ModelField(name="user_uid", required=True, kind=str)
|
||||
|
||||
|
||||
name = ModelField(name="name", required=True, kind=str)
|
||||
|
||||
is_private = ModelField(name="is_private", required=False, kind=bool)
|
||||
|
||||
|
||||
|
||||
|
@ -30,7 +30,7 @@ class UserModel(BaseModel):
|
||||
last_ping = ModelField(name="last_ping", required=False, kind=str)
|
||||
|
||||
is_admin = ModelField(name="is_admin", required=False, kind=bool)
|
||||
|
||||
|
||||
country_short = ModelField(name="country_short", required=False, kind=str)
|
||||
country_long = ModelField(name="country_long", required=False, kind=str)
|
||||
city = ModelField(name="city", required=False, kind=str)
|
||||
|
@ -1,51 +1,54 @@
|
||||
import snek.serpentarium
|
||||
|
||||
import time
|
||||
|
||||
import time
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
|
||||
import snek.serpentarium
|
||||
|
||||
durations = []
|
||||
|
||||
|
||||
def task1():
|
||||
global durations
|
||||
global durations
|
||||
client = snek.serpentarium.DatasetWrapper()
|
||||
|
||||
start=time.time()
|
||||
start = time.time()
|
||||
for x in range(1500):
|
||||
|
||||
client['a'].delete()
|
||||
client['a'].insert({"foo": x})
|
||||
client['a'].find(foo=x)
|
||||
client['a'].find_one(foo=x)
|
||||
client['a'].count()
|
||||
#print(client['a'].find(foo=x) )
|
||||
#print(client['a'].find_one(foo=x) )
|
||||
#print(client['a'].count())
|
||||
|
||||
client["a"].delete()
|
||||
client["a"].insert({"foo": x})
|
||||
client["a"].find(foo=x)
|
||||
client["a"].find_one(foo=x)
|
||||
client["a"].count()
|
||||
# print(client['a'].find(foo=x) )
|
||||
# print(client['a'].find_one(foo=x) )
|
||||
# print(client['a'].count())
|
||||
client.close()
|
||||
duration1 = f"{time.time()-start}"
|
||||
durations.append(duration1)
|
||||
print(durations)
|
||||
|
||||
|
||||
with ProcessPoolExecutor(max_workers=4) as executor:
|
||||
tasks = [executor.submit(task1),
|
||||
tasks = [
|
||||
executor.submit(task1),
|
||||
executor.submit(task1),
|
||||
executor.submit(task1),
|
||||
executor.submit(task1),
|
||||
executor.submit(task1)
|
||||
]
|
||||
for task in tasks:
|
||||
task.result()
|
||||
|
||||
|
||||
import dataset
|
||||
import dataset
|
||||
|
||||
client = dataset.connect("sqlite:///snek.db")
|
||||
start=time.time()
|
||||
start = time.time()
|
||||
for x in range(1500):
|
||||
|
||||
client['a'].delete()
|
||||
client['a'].insert({"foo": x})
|
||||
print([dict(row) for row in client['a'].find(foo=x)])
|
||||
print(dict(client['a'].find_one(foo=x) ))
|
||||
print(client['a'].count())
|
||||
client["a"].delete()
|
||||
client["a"].insert({"foo": x})
|
||||
print([dict(row) for row in client["a"].find(foo=x)])
|
||||
print(dict(client["a"].find_one(foo=x)))
|
||||
print(client["a"].count())
|
||||
duration2 = f"{time.time()-start}"
|
||||
|
||||
print(duration1,duration2)
|
||||
print(duration1, duration2)
|
||||
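As an aside on the timing pattern in this benchmark, time.perf_counter is the usual choice for measuring elapsed intervals; a minimal self-contained sketch of the same measure-and-report approach, with an invented stand-in workload:

import time

start = time.perf_counter()
total = sum(range(1_000_000))              # stand-in workload
duration = f"{time.perf_counter() - start}"
print(duration, total)
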
|
@ -1,22 +1,22 @@
|
||||
import functools
|
||||
|
||||
from snek.service.channel import ChannelService
|
||||
from snek.service.channel_attachment import ChannelAttachmentService
|
||||
from snek.service.channel_member import ChannelMemberService
|
||||
from snek.service.channel_message import ChannelMessageService
|
||||
from snek.service.chat import ChatService
|
||||
from snek.service.container import ContainerService
|
||||
from snek.service.db import DBService
|
||||
from snek.service.drive import DriveService
|
||||
from snek.service.drive_item import DriveItemService
|
||||
from snek.service.notification import NotificationService
|
||||
from snek.service.push import PushService
|
||||
from snek.service.repository import RepositoryService
|
||||
from snek.service.socket import SocketService
|
||||
from snek.service.user import UserService
|
||||
from snek.service.push import PushService
|
||||
from snek.service.user_property import UserPropertyService
|
||||
from snek.service.util import UtilService
|
||||
from snek.service.repository import RepositoryService
|
||||
from snek.service.channel_attachment import ChannelAttachmentService
|
||||
from snek.service.container import ContainerService
|
||||
from snek.system.object import Object
|
||||
from snek.service.db import DBService
|
||||
|
||||
|
||||
@functools.cache
|
||||
|
@ -1,9 +1,9 @@
|
||||
import pathlib
|
||||
from datetime import datetime
|
||||
|
||||
from snek.system.model import now
|
||||
from snek.system.service import BaseService
|
||||
|
||||
import pathlib
|
||||
|
||||
class ChannelService(BaseService):
|
||||
mapper_name = "channel"
|
||||
@ -17,12 +17,10 @@ class ChannelService(BaseService):
|
||||
pass
|
||||
return folder
|
||||
|
||||
async def get_attachment_folder(self, channel_uid,ensure=False):
|
||||
async def get_attachment_folder(self, channel_uid, ensure=False):
|
||||
path = pathlib.Path(f"./drive/{channel_uid}/attachments")
|
||||
if ensure:
|
||||
path.mkdir(
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
|
||||
async def get(self, uid=None, **kwargs):
|
||||
@ -78,7 +76,10 @@ class ChannelService(BaseService):
|
||||
return channel
|
||||
|
||||
async def get_recent_users(self, channel_uid):
|
||||
async for user in self.query("SELECT user.uid, user.username,user.color,user.last_ping,user.nick FROM channel_member INNER JOIN user ON user.uid = channel_member.user_uid WHERE channel_uid=:channel_uid AND user.last_ping >= datetime('now', '-3 minutes') ORDER BY last_ping DESC LIMIT 30", {"channel_uid": channel_uid}):
|
||||
async for user in self.query(
|
||||
"SELECT user.uid, user.username,user.color,user.last_ping,user.nick FROM channel_member INNER JOIN user ON user.uid = channel_member.user_uid WHERE channel_uid=:channel_uid AND user.last_ping >= datetime('now', '-3 minutes') ORDER BY last_ping DESC LIMIT 30",
|
||||
{"channel_uid": channel_uid},
|
||||
):
|
||||
yield user
|
||||
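    # Hypothetical call site for get_recent_users above (illustrative only;
    # assumes a configured snek Application exposing app.services.channel):
    #
    #   async for user in app.services.channel.get_recent_users(channel_uid):
    #       print(user["username"], user["last_ping"])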
|
||||
async def get_users(self, channel_uid):
|
||||
|
@ -1,25 +1,25 @@
|
||||
from snek.system.service import BaseService
|
||||
import urllib.parse
|
||||
import pathlib
|
||||
import mimetypes
|
||||
import uuid
|
||||
|
||||
from snek.system.service import BaseService
|
||||
|
||||
|
||||
class ChannelAttachmentService(BaseService):
|
||||
mapper_name="channel_attachment"
|
||||
mapper_name = "channel_attachment"
|
||||
|
||||
async def create_file(self, channel_uid, user_uid, name):
|
||||
attachment = await self.new()
|
||||
attachment["channel_uid"] = channel_uid
|
||||
attachment['user_uid'] = user_uid
|
||||
attachment["user_uid"] = user_uid
|
||||
attachment["name"] = name
|
||||
attachment["mime_type"] = mimetypes.guess_type(name)[0]
|
||||
attachment['resource_type'] = "file"
|
||||
attachment["resource_type"] = "file"
|
||||
real_file_name = f"{attachment['uid']}-{name}"
|
||||
attachment["relative_url"] = (f"{attachment['uid']}-{name}")
|
||||
attachment_folder = await self.services.channel.get_attachment_folder(channel_uid)
|
||||
attachment["relative_url"] = f"{attachment['uid']}-{name}"
|
||||
attachment_folder = await self.services.channel.get_attachment_folder(
|
||||
channel_uid
|
||||
)
|
||||
attachment_path = attachment_folder.joinpath(real_file_name)
|
||||
attachment["path"] = str(attachment_path)
|
||||
if await self.save(attachment):
|
||||
return attachment
|
||||
raise Exception(f"Failed to create channel attachment: {attachment.errors}.")
|
||||
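For reference, the mime_type assignment in create_file above relies on the standard-library mimetypes module; a minimal illustration (the file name is invented):

import mimetypes

mime_type, _encoding = mimetypes.guess_type("report.pdf")
print(mime_type)   # "application/pdf" on a standard mimetypes database
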
|
||||
|
@ -11,7 +11,7 @@ class ChannelMessageService(BaseService):
|
||||
model["channel_uid"] = channel_uid
|
||||
model["user_uid"] = user_uid
|
||||
model["message"] = message
|
||||
model['is_final'] = is_final
|
||||
model["is_final"] = is_final
|
||||
|
||||
context = {}
|
||||
|
||||
@ -56,7 +56,7 @@ class ChannelMessageService(BaseService):
|
||||
async def save(self, model):
|
||||
context = {}
|
||||
context.update(model.record)
|
||||
user = await self.app.services.user.get(model['user_uid'])
|
||||
user = await self.app.services.user.get(model["user_uid"])
|
||||
context.update(
|
||||
{
|
||||
"user_uid": user["uid"],
|
||||
|
@ -13,13 +13,13 @@ class ChatService(BaseService):
|
||||
channel["last_message_on"] = now()
|
||||
await self.services.channel.save(channel)
|
||||
await self.services.socket.broadcast(
|
||||
channel['uid'],
|
||||
channel["uid"],
|
||||
{
|
||||
"message": channel_message["message"],
|
||||
"html": channel_message["html"],
|
||||
"user_uid": user['uid'],
|
||||
"user_uid": user["uid"],
|
||||
"color": user["color"],
|
||||
"channel_uid": channel['uid'],
|
||||
"channel_uid": channel["uid"],
|
||||
"created_at": channel_message["created_at"],
|
||||
"updated_at": channel_message["updated_at"],
|
||||
"username": user["username"],
|
||||
@ -32,14 +32,12 @@ class ChatService(BaseService):
|
||||
self.services.notification.create_channel_message(message_uid)
|
||||
)
|
||||
|
||||
|
||||
|
||||
async def send(self, user_uid, channel_uid, message, is_final=True):
|
||||
channel = await self.services.channel.get(uid=channel_uid)
|
||||
if not channel:
|
||||
raise Exception("Channel not found.")
|
||||
channel_message = await self.services.channel_message.create(
|
||||
channel_uid, user_uid, message,is_final
|
||||
channel_uid, user_uid, message, is_final
|
||||
)
|
||||
channel_message_uid = channel_message["uid"]
|
||||
|
||||
|
@ -1,36 +1,51 @@
|
||||
from snek.system.docker import ComposeFileManager
|
||||
from snek.system.service import BaseService
|
||||
from snek.system.docker import ComposeFileManager
|
||||
|
||||
|
||||
class ContainerService(BaseService):
|
||||
mapper_name = "container"
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
self.compose_path = "snek-container-compose.yml"
|
||||
self.compose = ComposeFileManager(self.compose_path)
|
||||
|
||||
|
||||
|
||||
async def get_instances(self):
|
||||
return list(self.compose.list_instances())
|
||||
|
||||
async def get_container_name(self, channel_uid):
|
||||
return f"channel-{channel_uid}"
|
||||
|
||||
async def create(self, channel_uid, image="ubuntu:latest", command=None, cpus=1, memory='1024m', ports=None, volumes=None):
|
||||
async def create(
|
||||
self,
|
||||
channel_uid,
|
||||
image="ubuntu:latest",
|
||||
command=None,
|
||||
cpus=1,
|
||||
memory="1024m",
|
||||
ports=None,
|
||||
volumes=None,
|
||||
):
|
||||
name = await self.get_container_name(channel_uid)
|
||||
self.compose.create_instance(
|
||||
name,
|
||||
image,
|
||||
command,
|
||||
cpus,
|
||||
memory,
|
||||
ports,
|
||||
["./"+str((await self.services.channel.get_home_folder(channel_uid))) +":"+ "/home/ubuntu"]
|
||||
name,
|
||||
image,
|
||||
command,
|
||||
cpus,
|
||||
memory,
|
||||
ports,
|
||||
[
|
||||
"./"
|
||||
+ str(await self.services.channel.get_home_folder(channel_uid))
|
||||
+ ":"
|
||||
+ "/home/ubuntu"
|
||||
],
|
||||
)
|
||||
|
||||
async def create2(self, id, name, status, resources=None, user_uid=None, path=None, readonly=False):
|
||||
async def create2(
|
||||
self, id, name, status, resources=None, user_uid=None, path=None, readonly=False
|
||||
):
|
||||
model = await self.new()
|
||||
model["id"] = id
|
||||
model["name"] = name
|
||||
|
@ -1,23 +1,21 @@
|
||||
from snek.system.service import BaseService
|
||||
import dataset
|
||||
import uuid
|
||||
import dataset
|
||||
|
||||
from snek.system.service import BaseService
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
class DBService(BaseService):
|
||||
|
||||
|
||||
async def get_db(self, user_uid):
|
||||
|
||||
|
||||
home_folder = await self.app.services.user.get_home_folder(user_uid)
|
||||
home_folder.mkdir(parents=True, exist_ok=True)
|
||||
db_path = home_folder.joinpath("snek/user.db")
|
||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
return dataset.connect("sqlite:///" + str(db_path))
|
||||
|
||||
return dataset.connect("sqlite:///" + str(db_path))
|
||||
|
||||
async def insert(self, user_uid, table_name, values):
|
||||
db = await self.get_db(user_uid)
|
||||
return db[table_name].insert(values)
|
||||
|
||||
|
||||
async def update(self, user_uid, table_name, values, filters):
|
||||
db = await self.get_db(user_uid)
|
||||
@ -33,7 +31,7 @@ class DBService(BaseService):
|
||||
|
||||
async def find(self, user_uid, table_name, kwargs):
|
||||
db = await self.get_db(user_uid)
|
||||
kwargs['_limit'] = kwargs.get('_limit', 30)
|
||||
kwargs["_limit"] = kwargs.get("_limit", 30)
|
||||
return [dict(row) for row in db[table_name].find(**kwargs)]
|
||||
|
||||
async def get(self, user_uid, table_name, filters):
|
||||
@ -45,14 +43,13 @@ class DBService(BaseService):
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
async def delete(self, user_uid, table_name, filters):
|
||||
db = await self.get_db(user_uid)
|
||||
if not filters:
|
||||
filters = {}
|
||||
return db[table_name].delete(**filters)
|
||||
|
||||
async def query(self, sql,values):
|
||||
async def query(self, sql, values):
|
||||
db = await self.app.db
|
||||
return [dict(row) for row in db.query(sql, values or {})]
|
||||
|
||||
@ -62,8 +59,6 @@ class DBService(BaseService):
|
||||
filters = {}
|
||||
return bool(db[table_name].find_one(**filters))
|
||||
|
||||
|
||||
|
||||
async def count(self, user_uid, table_name, filters):
|
||||
db = await self.get_db(user_uid)
|
||||
if not filters:
|
||||
|
@ -1,6 +1,7 @@
|
||||
from snek.system.markdown import strip_markdown
|
||||
from snek.system.model import now
|
||||
from snek.system.service import BaseService
|
||||
from snek.system.markdown import strip_markdown
|
||||
|
||||
|
||||
class NotificationService(BaseService):
|
||||
mapper_name = "notification"
|
||||
@ -34,7 +35,7 @@ class NotificationService(BaseService):
|
||||
uid=channel_message_uid
|
||||
)
|
||||
if not channel_message["is_final"]:
|
||||
return
|
||||
return
|
||||
user = await self.services.user.get(uid=channel_message["user_uid"])
|
||||
self.app.db.begin()
|
||||
async for channel_member in self.services.channel_member.find(
|
||||
|
@ -1,25 +1,23 @@
|
||||
import base64
|
||||
import json
|
||||
|
||||
import aiohttp
|
||||
from snek.system.service import BaseService
|
||||
import os.path
|
||||
import random
|
||||
import time
|
||||
import base64
|
||||
import uuid
|
||||
from functools import cache
|
||||
from pathlib import Path
|
||||
|
||||
import jwt
|
||||
from cryptography.hazmat.primitives.asymmetric import ec
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from urllib.parse import urlparse
|
||||
import os.path
|
||||
|
||||
import aiohttp
|
||||
import jwt
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import ec
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
from cryptography.hazmat.primitives.hashes import SHA256
|
||||
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
|
||||
|
||||
from snek.system.service import BaseService
|
||||
|
||||
# The only reason to persist the keys is to be able to use them in the web push
|
||||
|
||||
PRIVATE_KEY_FILE = Path("./notification-private.pem")
|
||||
@ -268,4 +266,4 @@ class PushService(BaseService):
|
||||
|
||||
raise Exception(
|
||||
f"Failed to register push subscription for user {user_uid} with endpoint {endpoint}"
|
||||
)
|
||||
)
|
||||
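The imports in this push service pull in EC keys, HKDF, and AES-GCM as used for Web Push (VAPID-style) message signing and encryption. As a generic, hedged sketch of the kind of key material such a service persists, and not the project's actual routine, one could generate and serialize a P-256 private key like this:

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ec

# Generate a P-256 key and serialize it to PEM, the format a
# "notification-private.pem" style file would typically hold.
private_key = ec.generate_private_key(ec.SECP256R1())
pem = private_key.private_bytes(
    encoding=serialization.Encoding.PEM,
    format=serialization.PrivateFormat.PKCS8,
    encryption_algorithm=serialization.NoEncryption(),
)
print(pem.decode())
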
|
@ -1,13 +1,17 @@
|
||||
from snek.system.service import BaseService
|
||||
import asyncio
|
||||
import asyncio
|
||||
import shutil
|
||||
|
||||
from snek.system.service import BaseService
|
||||
|
||||
|
||||
class RepositoryService(BaseService):
|
||||
mapper_name = "repository"
|
||||
|
||||
async def delete(self, user_uid, name):
|
||||
loop = asyncio.get_event_loop()
|
||||
repository_path = (await self.services.user.get_repository_path(user_uid)).joinpath(name)
|
||||
repository_path = (
|
||||
await self.services.user.get_repository_path(user_uid)
|
||||
).joinpath(name)
|
||||
try:
|
||||
await loop.run_in_executor(None, shutil.rmtree, repository_path)
|
||||
except Exception as ex:
|
||||
@ -15,7 +19,6 @@ class RepositoryService(BaseService):
|
||||
|
||||
await super().delete(user_uid=user_uid, name=name)
|
||||
|
||||
|
||||
async def exists(self, user_uid, name, **kwargs):
|
||||
kwargs["user_uid"] = user_uid
|
||||
kwargs["name"] = name
|
||||
@ -29,18 +32,16 @@ class RepositoryService(BaseService):
|
||||
repository_path = str(repository_path)
|
||||
if not repository_path.endswith(".git"):
|
||||
repository_path += ".git"
|
||||
command = ['git', 'init', '--bare', repository_path]
|
||||
command = ["git", "init", "--bare", repository_path]
|
||||
process = await asyncio.subprocess.create_subprocess_exec(
|
||||
*command,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE
|
||||
*command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
stdout, stderr = await process.communicate()
|
||||
return process.returncode == 0
|
||||
|
||||
async def create(self, user_uid, name,is_private=False):
|
||||
async def create(self, user_uid, name, is_private=False):
|
||||
if await self.exists(user_uid=user_uid, name=name):
|
||||
return False
|
||||
return False
|
||||
|
||||
if not await self.init(user_uid=user_uid, name=name):
|
||||
return False
|
||||
|
@ -1,12 +1,14 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from snek.model.user import UserModel
|
||||
from snek.system.service import BaseService
|
||||
from datetime import datetime
|
||||
import json
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from snek.system.model import now
|
||||
|
||||
|
||||
class SocketService(BaseService):
|
||||
|
||||
class Socket:
|
||||
@ -15,7 +17,6 @@ class SocketService(BaseService):
|
||||
self.is_connected = True
|
||||
self.user = user
|
||||
|
||||
|
||||
async def send_json(self, data):
|
||||
if not self.is_connected:
|
||||
return False
|
||||
@ -40,7 +41,6 @@ class SocketService(BaseService):
|
||||
self.users = {}
|
||||
self.subscriptions = {}
|
||||
self.last_update = str(datetime.now())
|
||||
|
||||
|
||||
async def user_availability_service(self):
|
||||
logger.info("User availability update service started.")
|
||||
@ -50,14 +50,15 @@ class SocketService(BaseService):
|
||||
for s in self.sockets:
|
||||
if not s.is_connected:
|
||||
continue
|
||||
if not s.user in users_updated:
|
||||
if s.user not in users_updated:
|
||||
s.user["last_ping"] = now()
|
||||
await self.app.services.user.save(s.user)
|
||||
users_updated.append(s.user)
|
||||
logger.info(f"Updated user availability for {len(users_updated)} online users.")
|
||||
logger.info(
|
||||
f"Updated user availability for {len(users_updated)} online users."
|
||||
)
|
||||
await asyncio.sleep(60)
|
||||
|
||||
|
||||
async def add(self, ws, user_uid):
|
||||
s = self.Socket(ws, await self.app.services.user.get(uid=user_uid))
|
||||
self.sockets.add(s)
|
||||
@ -81,7 +82,6 @@ class SocketService(BaseService):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
async def broadcast(self, channel_uid, message):
|
||||
await self._broadcast(channel_uid, message)
|
||||
|
||||
@ -102,4 +102,3 @@ class SocketService(BaseService):
|
||||
await s.close()
|
||||
logger.info(f"Removed socket for user {s.user['username']}")
|
||||
self.sockets.remove(s)
|
||||
|
||||
|
@ -32,14 +32,14 @@ class UserService(BaseService):
|
||||
user["color"] = await self.services.util.random_light_hex_color()
|
||||
return await super().save(user)
|
||||
|
||||
def authenticate_sync(self,username,password):
|
||||
def authenticate_sync(self, username, password):
|
||||
user = self.get_by_username_sync(username)
|
||||
|
||||
|
||||
if not user:
|
||||
return False
|
||||
if not security.verify_sync(password, user["password"]):
|
||||
return False
|
||||
return True
|
||||
return True
|
||||
|
||||
async def authenticate(self, username, password):
|
||||
success = await self.validate_login(username, password)
|
||||
@ -61,14 +61,12 @@ class UserService(BaseService):
|
||||
return None
|
||||
return path
|
||||
|
||||
|
||||
|
||||
async def get_template_path(self, user_uid):
|
||||
path = pathlib.Path(f"./drive/{user_uid}/snek/templates")
|
||||
if not path.exists():
|
||||
return None
|
||||
return path
|
||||
|
||||
|
||||
def get_by_username_sync(self, username):
|
||||
user = self.mapper.db["user"].find_one(username=username, deleted_at=None)
|
||||
return dict(user)
|
||||
|
422 src/snek/sgit.py
@ -1,48 +1,51 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import aiohttp
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
import shutil
|
||||
import json
|
||||
import tempfile
|
||||
import asyncio
|
||||
import logging
|
||||
import base64
|
||||
import pathlib
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger('git_server')
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
logger = logging.getLogger("git_server")
|
||||
|
||||
|
||||
class GitApplication(web.Application):
|
||||
def __init__(self, parent=None):
|
||||
#import git
|
||||
#globals()['git'] = git
|
||||
# import git
|
||||
# globals()['git'] = git
|
||||
self.parent = parent
|
||||
super().__init__(client_max_size=1024*1024*1024*5)
|
||||
self.add_routes([
|
||||
web.post('/create/{repo_name}', self.create_repository),
|
||||
web.delete('/delete/{repo_name}', self.delete_repository),
|
||||
web.get('/clone/{repo_name}', self.clone_repository),
|
||||
web.post('/push/{repo_name}', self.push_repository),
|
||||
web.post('/pull/{repo_name}', self.pull_repository),
|
||||
web.get('/status/{repo_name}', self.status_repository),
|
||||
#web.get('/list', self.list_repositories),
|
||||
web.get('/branches/{repo_name}', self.list_branches),
|
||||
web.post('/branches/{repo_name}', self.create_branch),
|
||||
web.get('/log/{repo_name}', self.commit_log),
|
||||
web.get('/file/{repo_name}/{file_path:.*}', self.file_content),
|
||||
web.get('/{path:.+}/info/refs', self.git_smart_http),
|
||||
web.post('/{path:.+}/git-upload-pack', self.git_smart_http),
|
||||
web.post('/{path:.+}/git-receive-pack', self.git_smart_http),
|
||||
web.get('/{repo_name}.git/info/refs', self.git_smart_http),
|
||||
web.post('/{repo_name}.git/git-upload-pack', self.git_smart_http),
|
||||
web.post('/{repo_name}.git/git-receive-pack', self.git_smart_http),
|
||||
])
|
||||
|
||||
super().__init__(client_max_size=1024 * 1024 * 1024 * 5)
|
||||
self.add_routes(
|
||||
[
|
||||
web.post("/create/{repo_name}", self.create_repository),
|
||||
web.delete("/delete/{repo_name}", self.delete_repository),
|
||||
web.get("/clone/{repo_name}", self.clone_repository),
|
||||
web.post("/push/{repo_name}", self.push_repository),
|
||||
web.post("/pull/{repo_name}", self.pull_repository),
|
||||
web.get("/status/{repo_name}", self.status_repository),
|
||||
# web.get('/list', self.list_repositories),
|
||||
web.get("/branches/{repo_name}", self.list_branches),
|
||||
web.post("/branches/{repo_name}", self.create_branch),
|
||||
web.get("/log/{repo_name}", self.commit_log),
|
||||
web.get("/file/{repo_name}/{file_path:.*}", self.file_content),
|
||||
web.get("/{path:.+}/info/refs", self.git_smart_http),
|
||||
web.post("/{path:.+}/git-upload-pack", self.git_smart_http),
|
||||
web.post("/{path:.+}/git-receive-pack", self.git_smart_http),
|
||||
web.get("/{repo_name}.git/info/refs", self.git_smart_http),
|
||||
web.post("/{repo_name}.git/git-upload-pack", self.git_smart_http),
|
||||
web.post("/{repo_name}.git/git-receive-pack", self.git_smart_http),
|
||||
]
|
||||
)
|
||||
|
||||
async def check_basic_auth(self, request):
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if not auth_header.startswith("Basic "):
|
||||
return None,None
|
||||
return None, None
|
||||
encoded_creds = auth_header.split("Basic ")[1]
|
||||
decoded_creds = base64.b64decode(encoded_creds).decode()
|
||||
username, password = decoded_creds.split(":", 1)
|
||||
@ -50,27 +53,31 @@ class GitApplication(web.Application):
|
||||
username=username, password=password
|
||||
)
|
||||
if not request["user"]:
|
||||
return None,None
|
||||
request["repository_path"] = await self.parent.services.user.get_repository_path(
|
||||
request["user"]["uid"]
|
||||
return None, None
|
||||
request["repository_path"] = (
|
||||
await self.parent.services.user.get_repository_path(request["user"]["uid"])
|
||||
)
|
||||
|
||||
return request["user"]['username'],request["repository_path"]
|
||||
|
||||
return request["user"]["username"], request["repository_path"]
|
||||
|
||||
@staticmethod
|
||||
def require_auth(handler):
|
||||
async def wrapped(self, request, *args, **kwargs):
|
||||
username, repository_path = await self.check_basic_auth(request)
|
||||
if not username or not repository_path:
|
||||
return web.Response(status=401, headers={'WWW-Authenticate': 'Basic'}, text='Authentication required')
|
||||
request['username'] = username
|
||||
request['repository_path'] = repository_path
|
||||
return web.Response(
|
||||
status=401,
|
||||
headers={"WWW-Authenticate": "Basic"},
|
||||
text="Authentication required",
|
||||
)
|
||||
request["username"] = username
|
||||
request["repository_path"] = repository_path
|
||||
return await handler(self, request, *args, **kwargs)
|
||||
|
||||
return wrapped
|
||||
|
||||
def repo_path(self, repository_path, repo_name):
|
||||
return repository_path.joinpath(repo_name + '.git')
|
||||
return repository_path.joinpath(repo_name + ".git")
|
||||
|
||||
def check_repo_exists(self, repository_path, repo_name):
|
||||
repo_dir = self.repo_path(repository_path, repo_name)
|
||||
@ -80,10 +87,10 @@ class GitApplication(web.Application):
|
||||
|
||||
@require_auth
|
||||
async def create_repository(self, request):
|
||||
username = request['username']
|
||||
repo_name = request.match_info['repo_name']
|
||||
repository_path = request['repository_path']
|
||||
if not repo_name or '/' in repo_name or '..' in repo_name:
|
||||
username = request["username"]
|
||||
repo_name = request.match_info["repo_name"]
|
||||
repository_path = request["repository_path"]
|
||||
if not repo_name or "/" in repo_name or ".." in repo_name:
|
||||
return web.Response(text="Invalid repository name", status=400)
|
||||
repo_dir = self.repo_path(repository_path, repo_name)
|
||||
if os.path.exists(repo_dir):
|
||||
@ -98,9 +105,9 @@ class GitApplication(web.Application):
|
||||
|
||||
@require_auth
|
||||
async def delete_repository(self, request):
|
||||
username = request['username']
|
||||
repo_name = request.match_info['repo_name']
|
||||
repository_path = request['repository_path']
|
||||
username = request["username"]
|
||||
repo_name = request.match_info["repo_name"]
|
||||
repository_path = request["repository_path"]
|
||||
error_response = self.check_repo_exists(repository_path, repo_name)
|
||||
if error_response:
|
||||
return error_response
|
||||
@ -115,9 +122,9 @@ class GitApplication(web.Application):
|
||||
|
||||
@require_auth
|
||||
async def clone_repository(self, request):
|
||||
username = request['username']
|
||||
repo_name = request.match_info['repo_name']
|
||||
repository_path = request['repository_path']
|
||||
request["username"]
|
||||
repo_name = request.match_info["repo_name"]
|
||||
repository_path = request["repository_path"]
|
||||
error_response = self.check_repo_exists(repository_path, repo_name)
|
||||
if error_response:
|
||||
return error_response
|
||||
@ -126,15 +133,15 @@ class GitApplication(web.Application):
|
||||
response_data = {
|
||||
"repository": repo_name,
|
||||
"clone_command": f"git clone {clone_url}",
|
||||
"clone_url": clone_url
|
||||
"clone_url": clone_url,
|
||||
}
|
||||
return web.json_response(response_data)
|
||||
|
||||
@require_auth
|
||||
async def push_repository(self, request):
|
||||
username = request['username']
|
||||
repo_name = request.match_info['repo_name']
|
||||
repository_path = request['repository_path']
|
||||
username = request["username"]
|
||||
repo_name = request.match_info["repo_name"]
|
||||
repository_path = request["repository_path"]
|
||||
error_response = self.check_repo_exists(repository_path, repo_name)
|
||||
if error_response:
|
||||
return error_response
|
||||
@ -142,34 +149,40 @@ class GitApplication(web.Application):
|
||||
data = await request.json()
|
||||
except json.JSONDecodeError:
|
||||
return web.Response(text="Invalid JSON data", status=400)
|
||||
commit_message = data.get('commit_message', 'Update from server')
|
||||
branch = data.get('branch', 'main')
|
||||
changes = data.get('changes', [])
|
||||
commit_message = data.get("commit_message", "Update from server")
|
||||
branch = data.get("branch", "main")
|
||||
changes = data.get("changes", [])
|
||||
if not changes:
|
||||
return web.Response(text="No changes provided", status=400)
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_repo = git.Repo.clone_from(self.repo_path(repository_path, repo_name), temp_dir)
|
||||
temp_repo = git.Repo.clone_from(
|
||||
self.repo_path(repository_path, repo_name), temp_dir
|
||||
)
|
||||
for change in changes:
|
||||
file_path = os.path.join(temp_dir, change.get('file', ''))
|
||||
content = change.get('content', '')
|
||||
file_path = os.path.join(temp_dir, change.get("file", ""))
|
||||
content = change.get("content", "")
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
with open(file_path, 'w') as f:
|
||||
with open(file_path, "w") as f:
|
||||
f.write(content)
|
||||
temp_repo.git.add(A=True)
|
||||
if not temp_repo.config_reader().has_section('user'):
|
||||
temp_repo.config_writer().set_value("user", "name", "Git Server").release()
|
||||
temp_repo.config_writer().set_value("user", "email", "git@server.local").release()
|
||||
if not temp_repo.config_reader().has_section("user"):
|
||||
temp_repo.config_writer().set_value(
|
||||
"user", "name", "Git Server"
|
||||
).release()
|
||||
temp_repo.config_writer().set_value(
|
||||
"user", "email", "git@server.local"
|
||||
).release()
|
||||
temp_repo.index.commit(commit_message)
|
||||
origin = temp_repo.remote('origin')
|
||||
origin = temp_repo.remote("origin")
|
||||
origin.push(refspec=f"{branch}:{branch}")
|
||||
logger.info(f"Pushed to repository: {repo_name} for user {username}")
|
||||
return web.Response(text=f"Successfully pushed changes to {repo_name}")
|
||||
|
||||
@require_auth
|
||||
async def pull_repository(self, request):
|
||||
username = request['username']
|
||||
repo_name = request.match_info['repo_name']
|
||||
repository_path = request['repository_path']
|
||||
username = request["username"]
|
||||
repo_name = request.match_info["repo_name"]
|
||||
repository_path = request["repository_path"]
|
||||
error_response = self.check_repo_exists(repository_path, repo_name)
|
||||
if error_response:
|
||||
return error_response
|
||||
@ -177,13 +190,15 @@ class GitApplication(web.Application):
|
||||
data = await request.json()
|
||||
except json.JSONDecodeError:
|
||||
data = {}
|
||||
remote_url = data.get('remote_url')
|
||||
branch = data.get('branch', 'main')
|
||||
remote_url = data.get("remote_url")
|
||||
branch = data.get("branch", "main")
|
||||
if not remote_url:
|
||||
return web.Response(text="Remote URL is required", status=400)
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
try:
|
||||
local_repo = git.Repo.clone_from(self.repo_path(repository_path, repo_name), temp_dir)
|
||||
local_repo = git.Repo.clone_from(
|
||||
self.repo_path(repository_path, repo_name), temp_dir
|
||||
)
|
||||
remote_name = "pull_source"
|
||||
try:
|
||||
remote = local_repo.create_remote(remote_name, remote_url)
|
||||
@ -192,38 +207,46 @@ class GitApplication(web.Application):
|
||||
remote.set_url(remote_url)
|
||||
remote.fetch()
|
||||
local_repo.git.merge(f"{remote_name}/{branch}")
|
||||
origin = local_repo.remote('origin')
|
||||
origin = local_repo.remote("origin")
|
||||
origin.push()
|
||||
logger.info(f"Pulled to repository {repo_name} from {remote_url} for user {username}")
|
||||
return web.Response(text=f"Successfully pulled changes from {remote_url} to {repo_name}")
|
||||
logger.info(
|
||||
f"Pulled to repository {repo_name} from {remote_url} for user {username}"
|
||||
)
|
||||
return web.Response(
|
||||
text=f"Successfully pulled changes from {remote_url} to {repo_name}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error pulling to {repo_name}: {str(e)}")
|
||||
return web.Response(text=f"Error pulling changes: {str(e)}", status=500)
|
||||
|
||||
@require_auth
|
||||
async def status_repository(self, request):
|
||||
username = request['username']
|
||||
repo_name = request.match_info['repo_name']
|
||||
repository_path = request['repository_path']
|
||||
request["username"]
|
||||
repo_name = request.match_info["repo_name"]
|
||||
repository_path = request["repository_path"]
|
||||
error_response = self.check_repo_exists(repository_path, repo_name)
|
||||
if error_response:
|
||||
return error_response
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
try:
|
||||
temp_repo = git.Repo.clone_from(self.repo_path(repository_path, repo_name), temp_dir)
|
||||
temp_repo = git.Repo.clone_from(
|
||||
self.repo_path(repository_path, repo_name), temp_dir
|
||||
)
|
||||
branches = [b.name for b in temp_repo.branches]
|
||||
active_branch = temp_repo.active_branch.name
|
||||
commits = []
|
||||
for commit in list(temp_repo.iter_commits(max_count=5)):
|
||||
commits.append({
|
||||
"id": commit.hexsha,
|
||||
"author": f"{commit.author.name} <{commit.author.email}>",
|
||||
"date": commit.committed_datetime.isoformat(),
|
||||
"message": commit.message
|
||||
})
|
||||
commits.append(
|
||||
{
|
||||
"id": commit.hexsha,
|
||||
"author": f"{commit.author.name} <{commit.author.email}>",
|
||||
"date": commit.committed_datetime.isoformat(),
|
||||
"message": commit.message,
|
||||
}
|
||||
)
|
||||
files = []
|
||||
for root, dirs, filenames in os.walk(temp_dir):
|
||||
if '.git' in root:
|
||||
if ".git" in root:
|
||||
continue
|
||||
for filename in filenames:
|
||||
full_path = os.path.join(root, filename)
|
||||
@ -234,50 +257,58 @@ class GitApplication(web.Application):
|
||||
"branches": branches,
|
||||
"active_branch": active_branch,
|
||||
"recent_commits": commits,
|
||||
"files": files
|
||||
"files": files,
|
||||
}
|
||||
return web.json_response(status_info)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting status for {repo_name}: {str(e)}")
|
||||
return web.Response(text=f"Error getting repository status: {str(e)}", status=500)
|
||||
return web.Response(
|
||||
text=f"Error getting repository status: {str(e)}", status=500
|
||||
)
|
||||
|
||||
@require_auth
|
||||
async def list_repositories(self, request):
|
||||
username = request['username']
|
||||
request["username"]
|
||||
try:
|
||||
repos = []
|
||||
user_dir = self.REPO_DIR
|
||||
if os.path.exists(user_dir):
|
||||
for item in os.listdir(user_dir):
|
||||
item_path = os.path.join(user_dir, item)
|
||||
if os.path.isdir(item_path) and item.endswith('.git'):
|
||||
if os.path.isdir(item_path) and item.endswith(".git"):
|
||||
repos.append(item[:-4])
|
||||
if request.query.get('format') == 'json':
|
||||
if request.query.get("format") == "json":
|
||||
return web.json_response({"repositories": repos})
|
||||
else:
|
||||
return web.Response(text="\n".join(repos) if repos else "No repositories found")
|
||||
return web.Response(
|
||||
text="\n".join(repos) if repos else "No repositories found"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing repositories: {str(e)}")
|
||||
return web.Response(text=f"Error listing repositories: {str(e)}", status=500)
|
||||
return web.Response(
|
||||
text=f"Error listing repositories: {str(e)}", status=500
|
||||
)
|
||||
|
||||
@require_auth
|
||||
async def list_branches(self, request):
|
||||
username = request['username']
|
||||
repo_name = request.match_info['repo_name']
|
||||
repository_path = request['repository_path']
|
||||
request["username"]
|
||||
repo_name = request.match_info["repo_name"]
|
||||
repository_path = request["repository_path"]
|
||||
error_response = self.check_repo_exists(repository_path, repo_name)
|
||||
if error_response:
|
||||
return error_response
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_repo = git.Repo.clone_from(self.repo_path(repository_path, repo_name), temp_dir)
|
||||
temp_repo = git.Repo.clone_from(
|
||||
self.repo_path(repository_path, repo_name), temp_dir
|
||||
)
|
||||
branches = [b.name for b in temp_repo.branches]
|
||||
return web.json_response({"branches": branches})
|
||||
|
||||
@require_auth
|
||||
async def create_branch(self, request):
|
||||
username = request['username']
|
||||
repo_name = request.match_info['repo_name']
|
||||
repository_path = request['repository_path']
|
||||
username = request["username"]
|
||||
repo_name = request.match_info["repo_name"]
|
||||
repository_path = request["repository_path"]
|
||||
error_response = self.check_repo_exists(repository_path, repo_name)
|
||||
if error_response:
|
||||
return error_response
|
||||
@ -285,145 +316,168 @@ class GitApplication(web.Application):
|
||||
data = await request.json()
|
||||
except json.JSONDecodeError:
|
||||
return web.Response(text="Invalid JSON data", status=400)
|
||||
branch_name = data.get('branch_name')
|
||||
start_point = data.get('start_point', 'HEAD')
|
||||
branch_name = data.get("branch_name")
|
||||
start_point = data.get("start_point", "HEAD")
|
||||
if not branch_name:
|
||||
return web.Response(text="Branch name is required", status=400)
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
try:
|
||||
temp_repo = git.Repo.clone_from(self.repo_path(repository_path, repo_name), temp_dir)
|
||||
temp_repo = git.Repo.clone_from(
|
||||
self.repo_path(repository_path, repo_name), temp_dir
|
||||
)
|
||||
temp_repo.git.branch(branch_name, start_point)
|
||||
temp_repo.git.push('origin', branch_name)
|
||||
logger.info(f"Created branch {branch_name} in repository {repo_name} for user {username}")
|
||||
temp_repo.git.push("origin", branch_name)
|
||||
logger.info(
|
||||
f"Created branch {branch_name} in repository {repo_name} for user {username}"
|
||||
)
|
||||
return web.Response(text=f"Created branch {branch_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating branch {branch_name} in {repo_name}: {str(e)}")
|
||||
logger.error(
|
||||
f"Error creating branch {branch_name} in {repo_name}: {str(e)}"
|
||||
)
|
||||
return web.Response(text=f"Error creating branch: {str(e)}", status=500)
|
||||
|
||||
@require_auth
|
||||
async def commit_log(self, request):
|
||||
username = request['username']
|
||||
repo_name = request.match_info['repo_name']
|
||||
repository_path = request['repository_path']
|
||||
request["username"]
|
||||
repo_name = request.match_info["repo_name"]
|
||||
repository_path = request["repository_path"]
|
||||
error_response = self.check_repo_exists(repository_path, repo_name)
|
||||
if error_response:
|
||||
return error_response
|
||||
try:
|
||||
limit = int(request.query.get('limit', 10))
|
||||
branch = request.query.get('branch', 'main')
|
||||
limit = int(request.query.get("limit", 10))
|
||||
branch = request.query.get("branch", "main")
|
||||
except ValueError:
|
||||
return web.Response(text="Invalid limit parameter", status=400)
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
try:
|
||||
temp_repo = git.Repo.clone_from(self.repo_path(repository_path, repo_name), temp_dir)
|
||||
temp_repo = git.Repo.clone_from(
|
||||
self.repo_path(repository_path, repo_name), temp_dir
|
||||
)
|
||||
commits = []
|
||||
try:
|
||||
for commit in list(temp_repo.iter_commits(branch, max_count=limit)):
|
||||
commits.append({
|
||||
"id": commit.hexsha,
|
||||
"short_id": commit.hexsha[:7],
|
||||
"author": f"{commit.author.name} <{commit.author.email}>",
|
||||
"date": commit.committed_datetime.isoformat(),
|
||||
"message": commit.message.strip()
|
||||
})
|
||||
commits.append(
|
||||
{
|
||||
"id": commit.hexsha,
|
||||
"short_id": commit.hexsha[:7],
|
||||
"author": f"{commit.author.name} <{commit.author.email}>",
|
||||
"date": commit.committed_datetime.isoformat(),
|
||||
"message": commit.message.strip(),
|
||||
}
|
||||
)
|
||||
except git.GitCommandError as e:
|
||||
if "unknown revision or path" in str(e):
|
||||
commits = []
|
||||
else:
|
||||
raise
|
||||
return web.json_response({
|
||||
"repository": repo_name,
|
||||
"branch": branch,
|
||||
"commits": commits
|
||||
})
|
||||
return web.json_response(
|
||||
{"repository": repo_name, "branch": branch, "commits": commits}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting commit log for {repo_name}: {str(e)}")
|
||||
return web.Response(text=f"Error getting commit log: {str(e)}", status=500)
|
||||
return web.Response(
|
||||
text=f"Error getting commit log: {str(e)}", status=500
|
||||
)
|
||||
|
||||
@require_auth
|
||||
async def file_content(self, request):
|
||||
username = request['username']
|
||||
repo_name = request.match_info['repo_name']
|
||||
file_path = request.match_info.get('file_path', '')
|
||||
branch = request.query.get('branch', 'main')
|
||||
repository_path = request['repository_path']
|
||||
request["username"]
|
||||
repo_name = request.match_info["repo_name"]
|
||||
file_path = request.match_info.get("file_path", "")
|
||||
branch = request.query.get("branch", "main")
|
||||
repository_path = request["repository_path"]
|
||||
error_response = self.check_repo_exists(repository_path, repo_name)
|
||||
if error_response:
|
||||
return error_response
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
try:
|
||||
temp_repo = git.Repo.clone_from(self.repo_path(repository_path, repo_name), temp_dir)
|
||||
temp_repo = git.Repo.clone_from(
|
||||
self.repo_path(repository_path, repo_name), temp_dir
|
||||
)
|
||||
try:
|
||||
temp_repo.git.checkout(branch)
|
||||
except git.GitCommandError:
|
||||
return web.Response(text=f"Branch '{branch}' not found", status=404)
|
||||
file_full_path = os.path.join(temp_dir, file_path)
|
||||
if not os.path.exists(file_full_path):
|
||||
return web.Response(text=f"File '{file_path}' not found", status=404)
|
||||
return web.Response(
|
||||
text=f"File '{file_path}' not found", status=404
|
||||
)
|
||||
if os.path.isdir(file_full_path):
|
||||
files = os.listdir(file_full_path)
|
||||
return web.json_response({
|
||||
"repository": repo_name,
|
||||
"path": file_path,
|
||||
"type": "directory",
|
||||
"contents": files
|
||||
})
|
||||
return web.json_response(
|
||||
{
|
||||
"repository": repo_name,
|
||||
"path": file_path,
|
||||
"type": "directory",
|
||||
"contents": files,
|
||||
}
|
||||
)
|
||||
else:
|
||||
try:
|
||||
with open(file_full_path, 'r') as f:
|
||||
with open(file_full_path) as f:
|
||||
content = f.read()
|
||||
return web.Response(text=content)
|
||||
except UnicodeDecodeError:
|
||||
return web.Response(text=f"Cannot display binary file content for '{file_path}'", status=400)
|
||||
return web.Response(
|
||||
text=f"Cannot display binary file content for '{file_path}'",
|
||||
status=400,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting file content from {repo_name}: {str(e)}")
|
||||
return web.Response(text=f"Error getting file content: {str(e)}", status=500)
|
||||
return web.Response(
|
||||
text=f"Error getting file content: {str(e)}", status=500
|
||||
)
|
||||
|
||||
@require_auth
|
||||
async def git_smart_http(self, request):
|
||||
username = request['username']
|
||||
repository_path = request['repository_path']
|
||||
request["username"]
|
||||
repository_path = request["repository_path"]
|
||||
path = request.path
|
||||
|
||||
async def get_repository_path():
|
||||
req_path = path.lstrip('/')
|
||||
if req_path.endswith('/info/refs'):
|
||||
repo_name = req_path[:-len('/info/refs')]
|
||||
elif req_path.endswith('/git-upload-pack'):
|
||||
repo_name = req_path[:-len('/git-upload-pack')]
|
||||
elif req_path.endswith('/git-receive-pack'):
|
||||
repo_name = req_path[:-len('/git-receive-pack')]
|
||||
req_path = path.lstrip("/")
|
||||
if req_path.endswith("/info/refs"):
|
||||
repo_name = req_path[: -len("/info/refs")]
|
||||
elif req_path.endswith("/git-upload-pack"):
|
||||
repo_name = req_path[: -len("/git-upload-pack")]
|
||||
elif req_path.endswith("/git-receive-pack"):
|
||||
repo_name = req_path[: -len("/git-receive-pack")]
|
||||
else:
|
||||
repo_name = req_path
|
||||
if repo_name.endswith('.git'):
|
||||
if repo_name.endswith(".git"):
|
||||
repo_name = repo_name[:-4]
|
||||
repo_name = repo_name[4:]
|
||||
repo_dir = repository_path.joinpath(repo_name + ".git")
|
||||
logger.info(f"Resolved repo path: {repo_dir}")
|
||||
return repo_dir
|
||||
return repo_dir
|
||||
|
||||
async def handle_info_refs(service):
|
||||
repo_path = await get_repository_path()
|
||||
|
||||
|
||||
logger.info(f"handle_info_refs: {repo_path}")
|
||||
if not os.path.exists(repo_path):
|
||||
return web.Response(text="Repository not found", status=404)
|
||||
cmd = [service, '--stateless-rpc', '--advertise-refs', str(repo_path)]
|
||||
cmd = [service, "--stateless-rpc", "--advertise-refs", str(repo_path)]
|
||||
try:
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE
|
||||
*cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
stdout, stderr = await process.communicate()
|
||||
if process.returncode != 0:
|
||||
logger.error(f"Git command failed: {stderr.decode()}")
|
||||
return web.Response(text=f"Git error: {stderr.decode()}", status=500)
|
||||
return web.Response(
|
||||
text=f"Git error: {stderr.decode()}", status=500
|
||||
)
|
||||
response = web.StreamResponse(
|
||||
status=200,
|
||||
reason='OK',
|
||||
reason="OK",
|
||||
headers={
|
||||
'Content-Type': f'application/x-{service}-advertisement',
|
||||
'Cache-Control': 'no-cache'
|
||||
}
|
||||
"Content-Type": f"application/x-{service}-advertisement",
|
||||
"Cache-Control": "no-cache",
|
||||
},
|
||||
)
|
||||
await response.prepare(request)
|
||||
packet = f"# service={service}\n"
|
||||
@ -435,48 +489,58 @@ class GitApplication(web.Application):
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling info/refs: {str(e)}")
|
||||
return web.Response(text=f"Server error: {str(e)}", status=500)
|
||||
|
||||
async def handle_service_rpc(service):
|
||||
repo_path = await get_repository_path()
|
||||
logger.info(f"handle_service_rpc: {repo_path}")
|
||||
if not os.path.exists(repo_path):
|
||||
return web.Response(text="Repository not found", status=404)
|
||||
if not request.headers.get('Content-Type') == f'application/x-{service}-request':
|
||||
if (
|
||||
not request.headers.get("Content-Type")
|
||||
== f"application/x-{service}-request"
|
||||
):
|
||||
return web.Response(text="Invalid Content-Type", status=403)
|
||||
body = await request.read()
|
||||
cmd = [service, '--stateless-rpc', str(repo_path)]
|
||||
cmd = [service, "--stateless-rpc", str(repo_path)]
|
||||
try:
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
stdout, stderr = await process.communicate(input=body)
|
||||
if process.returncode != 0:
|
||||
logger.error(f"Git command failed: {stderr.decode()}")
|
||||
return web.Response(text=f"Git error: {stderr.decode()}", status=500)
|
||||
return web.Response(
|
||||
text=f"Git error: {stderr.decode()}", status=500
|
||||
)
|
||||
return web.Response(
|
||||
body=stdout,
|
||||
content_type=f'application/x-{service}-result'
|
||||
body=stdout, content_type=f"application/x-{service}-result"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling service RPC: {str(e)}")
|
||||
return web.Response(text=f"Server error: {str(e)}", status=500)
|
||||
if request.method == 'GET' and path.endswith('/info/refs'):
|
||||
service = request.query.get('service')
|
||||
if service in ('git-upload-pack', 'git-receive-pack'):
|
||||
|
||||
if request.method == "GET" and path.endswith("/info/refs"):
|
||||
service = request.query.get("service")
|
||||
if service in ("git-upload-pack", "git-receive-pack"):
|
||||
return await handle_info_refs(service)
|
||||
else:
|
||||
return web.Response(text="Smart HTTP requires service parameter", status=400)
|
||||
elif request.method == 'POST' and '/git-upload-pack' in path:
|
||||
return await handle_service_rpc('git-upload-pack')
|
||||
elif request.method == 'POST' and '/git-receive-pack' in path:
|
||||
return await handle_service_rpc('git-receive-pack')
|
||||
return web.Response(
|
||||
text="Smart HTTP requires service parameter", status=400
|
||||
)
|
||||
elif request.method == "POST" and "/git-upload-pack" in path:
|
||||
return await handle_service_rpc("git-upload-pack")
|
||||
elif request.method == "POST" and "/git-receive-pack" in path:
|
||||
return await handle_service_rpc("git-receive-pack")
|
||||
return web.Response(text="Not found", status=404)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
import uvloop
|
||||
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
logger.info("Using uvloop for improved performance")
|
||||
except ImportError:
|
||||
|
@ -2,28 +2,28 @@ import aiohttp
|
||||
|
||||
ENABLED = False
|
||||
|
||||
import aiohttp
|
||||
import asyncio
|
||||
from aiohttp import web
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
|
||||
import dataset
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
import json
|
||||
|
||||
queue = asyncio.Queue()
|
||||
|
||||
|
||||
class State:
|
||||
do_not_sync = False
|
||||
do_not_sync = False
|
||||
|
||||
|
||||
async def sync_service(app):
|
||||
if not ENABLED:
|
||||
return
|
||||
return
|
||||
session = aiohttp.ClientSession()
|
||||
async with session.ws_connect('http://localhost:3131/ws') as ws:
|
||||
async with session.ws_connect("http://localhost:3131/ws") as ws:
|
||||
|
||||
async def receive():
|
||||
|
||||
queries_synced = 0
|
||||
@ -31,7 +31,7 @@ async def sync_service(app):
|
||||
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||
try:
|
||||
data = json.loads(msg.data)
|
||||
State.do_not_sync = True
|
||||
State.do_not_sync = True
|
||||
app.db.execute(*data)
|
||||
app.db.commit()
|
||||
State.do_not_sync = False
|
||||
@ -41,21 +41,25 @@ async def sync_service(app):
|
||||
await app.services.socket.broadcast_event()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pass
|
||||
#print(f"Received: {msg.data}")
|
||||
pass
|
||||
# print(f"Received: {msg.data}")
|
||||
elif msg.type == aiohttp.WSMsgType.ERROR:
|
||||
break
|
||||
|
||||
async def write():
|
||||
while True:
|
||||
msg = await queue.get()
|
||||
await ws.send_str(json.dumps(msg,default=str))
|
||||
await ws.send_str(json.dumps(msg, default=str))
|
||||
queue.task_done()
|
||||
|
||||
await asyncio.gather(receive(), write())
|
||||
|
||||
await session.close()
|
||||
|
||||
|
||||
queries_queued = 0
|
||||
|
||||
|
||||
# Attach a listener to log all executed statements
|
||||
@event.listens_for(Engine, "before_cursor_execute")
|
||||
def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
|
||||
@ -63,7 +67,7 @@ def before_cursor_execute(conn, cursor, statement, parameters, context, executem
|
||||
return
|
||||
global queries_queued
|
||||
if State.do_not_sync:
|
||||
print(statement,parameters)
|
||||
print(statement, parameters)
|
||||
return
|
||||
if statement.startswith("SELECT"):
|
||||
return
|
||||
@ -71,17 +75,18 @@ def before_cursor_execute(conn, cursor, statement, parameters, context, executem
|
||||
queries_queued += 1
|
||||
print("Queries queued: " + str(queries_queued))
|
||||
|
||||
|
||||
async def websocket_handler(request):
|
||||
queries_broadcasted = 0
|
||||
queries_broadcasted = 0
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
request.app['websockets'].append(ws)
|
||||
request.app["websockets"].append(ws)
|
||||
async for msg in ws:
|
||||
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||
for client in request.app['websockets']:
|
||||
for client in request.app["websockets"]:
|
||||
if client != ws:
|
||||
await client.send_str(msg.data)
|
||||
cursor = request.app['db'].cursor()
|
||||
cursor = request.app["db"].cursor()
|
||||
data = json.loads(msg.data)
|
||||
queries_broadcasted += 1
|
||||
|
||||
@ -89,28 +94,32 @@ async def websocket_handler(request):
|
||||
cursor.close()
|
||||
print("Queries broadcasted: " + str(queries_broadcasted))
|
||||
elif msg.type == aiohttp.WSMsgType.ERROR:
|
||||
print(f'WebSocket connection closed with exception {ws.exception()}')
|
||||
print(f"WebSocket connection closed with exception {ws.exception()}")
|
||||
|
||||
request.app['websockets'].remove(ws)
|
||||
request.app["websockets"].remove(ws)
|
||||
return ws
|
||||
|
||||
app = web.Application()
|
||||
app['websockets'] = []
|
||||
|
||||
app.router.add_get('/ws', websocket_handler)
|
||||
app = web.Application()
|
||||
app["websockets"] = []
|
||||
|
||||
app.router.add_get("/ws", websocket_handler)
|
||||
|
||||
|
||||
async def on_startup(app):
|
||||
app['db'] = sqlite3.connect('snek.db')
|
||||
app["db"] = sqlite3.connect("snek.db")
|
||||
print("Server starting...")
|
||||
|
||||
|
||||
async def on_cleanup(app):
|
||||
for ws in app['websockets']:
|
||||
for ws in app["websockets"]:
|
||||
await ws.close()
|
||||
app['db'].close()
|
||||
app["db"].close()
|
||||
|
||||
|
||||
app.on_startup.append(on_startup)
|
||||
app.on_cleanup.append(on_cleanup)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
web.run_app(app, host='127.0.0.1', port=3131)
|
||||
if __name__ == "__main__":
|
||||
web.run_app(app, host="127.0.0.1", port=3131)
|
||||
|
@ -1,25 +1,30 @@
|
||||
import asyncssh
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
global _app
|
||||
import asyncssh
|
||||
|
||||
global _app
|
||||
|
||||
|
||||
def set_app(app):
|
||||
global _app
|
||||
_app = app
|
||||
_app = app
|
||||
|
||||
|
||||
def get_app():
|
||||
return _app
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
roots = {}
|
||||
|
||||
|
||||
class SFTPServer(asyncssh.SFTPServer):
|
||||
|
||||
|
||||
def __init__(self, chan: asyncssh.SSHServerChannel):
|
||||
self.root = get_app().services.user.get_home_folder_by_username(
|
||||
chan.get_extra_info('username')
|
||||
chan.get_extra_info("username")
|
||||
)
|
||||
self.root.mkdir(exist_ok=True)
|
||||
self.root = str(self.root)
|
||||
@ -30,20 +35,22 @@ class SFTPServer(asyncssh.SFTPServer):
|
||||
logger.debug(f"Mapping client path {path} to {mapped_path}")
|
||||
return str(mapped_path).encode()
|
||||
|
||||
|
||||
class SSHServer(asyncssh.SSHServer):
|
||||
def password_auth_supported(self):
|
||||
return True
|
||||
|
||||
def validate_password(self, username, password):
|
||||
logger.debug(f"Validating credentials for user {username}")
|
||||
result = get_app().services.user.authenticate_sync(username,password)
|
||||
result = get_app().services.user.authenticate_sync(username, password)
|
||||
logger.info(f"Validating credentials for user {username}: {result}")
|
||||
return result
|
||||
|
||||
async def start_ssh_server(app,host,port):
|
||||
set_app(app)
|
||||
|
||||
async def start_ssh_server(app, host, port):
|
||||
set_app(app)
|
||||
logger.info("Starting SFTP server setup")
|
||||
|
||||
|
||||
host_key_path = Path("drive") / ".ssh" / "sftp_server_key"
|
||||
host_key_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
try:
|
||||
@ -63,14 +70,12 @@ async def start_ssh_server(app,host,port):
|
||||
x = await asyncssh.listen(
|
||||
host=host,
|
||||
port=port,
|
||||
#process_factory=handle_client,
|
||||
# process_factory=handle_client,
|
||||
server_host_keys=[key],
|
||||
server_factory=SSHServer,
|
||||
sftp_factory=SFTPServer
|
||||
sftp_factory=SFTPServer,
|
||||
)
|
||||
return x
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.warning(f"Failed to start SFTP server. Already running.")
|
||||
pass
|
||||
|
||||
|
||||
pass
|
||||
|
@ -1,58 +1,71 @@
|
||||
import yaml
|
||||
import copy
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
class ComposeFileManager:
|
||||
def __init__(self, compose_path='docker-compose.yml'):
|
||||
def __init__(self, compose_path="docker-compose.yml"):
|
||||
self.compose_path = compose_path
|
||||
self._load()
|
||||
|
||||
def _load(self):
|
||||
try:
|
||||
with open(self.compose_path, 'r') as f:
|
||||
with open(self.compose_path) as f:
|
||||
self.compose = yaml.safe_load(f) or {}
|
||||
except FileNotFoundError:
|
||||
self.compose = {'version': '3', 'services': {}}
|
||||
self.compose = {"version": "3", "services": {}}
|
||||
|
||||
def _save(self):
|
||||
with open(self.compose_path, 'w') as f:
|
||||
with open(self.compose_path, "w") as f:
|
||||
yaml.dump(self.compose, f, default_flow_style=False)
|
||||
|
||||
def list_instances(self):
|
||||
return list(self.compose.get('services', {}).keys())
|
||||
return list(self.compose.get("services", {}).keys())
|
||||
|
||||
def create_instance(self, name, image, command=None, cpus=None, memory=None, ports=None, volumes=None):
|
||||
def create_instance(
|
||||
self,
|
||||
name,
|
||||
image,
|
||||
command=None,
|
||||
cpus=None,
|
||||
memory=None,
|
||||
ports=None,
|
||||
volumes=None,
|
||||
):
|
||||
service = {
|
||||
'image': image,
|
||||
"image": image,
|
||||
}
|
||||
if command:
|
||||
service['command'] = command
|
||||
service["command"] = command
|
||||
if cpus or memory:
|
||||
service['deploy'] = {'resources': {'limits': {}}}
|
||||
service["deploy"] = {"resources": {"limits": {}}}
|
||||
if cpus:
|
||||
service['deploy']['resources']['limits']['cpus'] = str(cpus)
|
||||
service["deploy"]["resources"]["limits"]["cpus"] = str(cpus)
|
||||
if memory:
|
||||
service['deploy']['resources']['limits']['memory'] = str(memory)
|
||||
service["deploy"]["resources"]["limits"]["memory"] = str(memory)
|
||||
if ports:
|
||||
service['ports'] = [f"{host}:{container}" for container, host in ports.items()]
|
||||
service["ports"] = [
|
||||
f"{host}:{container}" for container, host in ports.items()
|
||||
]
|
||||
if volumes:
|
||||
service['volumes'] = volumes
|
||||
service["volumes"] = volumes
|
||||
|
||||
self.compose.setdefault('services', {})[name] = service
|
||||
self.compose.setdefault("services", {})[name] = service
|
||||
self._save()
|
||||
|
||||
def remove_instance(self, name):
|
||||
if name in self.compose.get('services', {}):
|
||||
del self.compose['services'][name]
|
||||
if name in self.compose.get("services", {}):
|
||||
del self.compose["services"][name]
|
||||
self._save()
|
||||
|
||||
def get_instance(self, name):
|
||||
return self.compose.get('services', {}).get(name)
|
||||
return self.compose.get("services", {}).get(name)
|
||||
|
||||
def duplicate_instance(self, name, new_name):
|
||||
orig = self.get_instance(name)
|
||||
if not orig:
|
||||
raise ValueError(f"No such instance: {name}")
|
||||
self.compose['services'][new_name] = copy.deepcopy(orig)
|
||||
self.compose["services"][new_name] = copy.deepcopy(orig)
|
||||
self._save()
|
||||
|
||||
def update_instance(self, name, **kwargs):
|
||||
@ -62,11 +75,12 @@ class ComposeFileManager:
|
||||
for k, v in kwargs.items():
|
||||
if v is not None:
|
||||
service[k] = v
|
||||
self.compose['services'][name] = service
|
||||
self.compose["services"][name] = service
|
||||
self._save()
|
||||
|
||||
# Storage size is not tracked in compose files; would need Docker API for that.
|
||||
|
||||
|
||||
# Example usage:
|
||||
# mgr = ComposeFileManager()
|
||||
# mgr.create_instance('web', 'nginx:latest', cpus=1, memory='512m', ports={80:8080}, volumes=['./data:/data'])
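As noted above, storage usage cannot be read from the compose file itself. A minimal sketch of how per-container disk usage could be looked up instead, assuming the docker SDK package and a reachable daemon; the key names follow Docker's /system/df endpoint:

import docker

def container_storage_bytes(name):
    # Writable-layer size per container, as reported by `docker system df -v`.
    client = docker.from_env()
    for entry in client.df().get("Containers", []):
        if any(n.lstrip("/") == name for n in entry.get("Names", [])):
            return entry.get("SizeRw", 0)
    return None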
|
||||
|
@ -1,6 +1,7 @@
|
||||
DEFAULT_LIMIT = 30
|
||||
import asyncio
|
||||
import typing
|
||||
import asyncio
|
||||
|
||||
from snek.system.model import BaseModel
|
||||
|
||||
|
||||
@ -12,21 +13,21 @@ class BaseMapper:
|
||||
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
self.semaphore = asyncio.Semaphore(1)
|
||||
self.semaphore = asyncio.Semaphore(1)
|
||||
self.default_limit = self.__class__.default_limit
|
||||
|
||||
@property
|
||||
def db(self):
|
||||
return self.app.db
|
||||
|
||||
@property
|
||||
@property
|
||||
def loop(self):
|
||||
return asyncio.get_event_loop()
|
||||
|
||||
async def run_in_executor(self, func, *args, **kwargs):
|
||||
async with self.semaphore:
|
||||
return func(*args, **kwargs)
|
||||
#return await self.loop.run_in_executor(None, lambda: func(*args, **kwargs))
|
||||
# return await self.loop.run_in_executor(None, lambda: func(*args, **kwargs))
|
||||
|
||||
async def new(self):
|
||||
return self.model_class(mapper=self, app=self.app)
|
||||
@ -38,8 +39,8 @@ class BaseMapper:
|
||||
async def get(self, uid: str = None, **kwargs) -> BaseModel:
|
||||
if uid:
|
||||
kwargs["uid"] = uid
|
||||
|
||||
record = await self.run_in_executor(self.table.find_one,**kwargs)
|
||||
|
||||
record = await self.run_in_executor(self.table.find_one, **kwargs)
|
||||
if not record:
|
||||
return None
|
||||
record = dict(record)
|
||||
@ -50,7 +51,7 @@ class BaseMapper:
|
||||
return await self.model_class.from_record(mapper=self, record=record)
|
||||
|
||||
async def exists(self, **kwargs):
|
||||
return await self.run_in_executor(self.table.exists,**kwargs)
|
||||
return await self.run_in_executor(self.table.exists, **kwargs)
|
||||
|
||||
async def count(self, **kwargs) -> int:
|
||||
return await self.run_in_executor(self.table.count, **kwargs)
|
||||
@ -71,7 +72,7 @@ class BaseMapper:
|
||||
yield model
|
||||
|
||||
async def query(self, sql, *args):
|
||||
for record in await self.run_in_executor(self.db.query,sql, *args):
|
||||
for record in await self.run_in_executor(self.db.query, sql, *args):
|
||||
yield dict(record)
|
||||
|
||||
async def update(self, model):
|
||||
|
@ -1,5 +1,5 @@
|
||||
# Original source: https://brandonjay.dev/posts/2021/render-markdown-html-in-python-with-jinja2
|
||||
import re
|
||||
import re
|
||||
from types import SimpleNamespace
|
||||
|
||||
from app.cache import time_cache_async
|
||||
@ -13,19 +13,20 @@ from pygments.lexers import get_lexer_by_name
|
||||
|
||||
|
||||
def strip_markdown(md_text):
|
||||
# Remove code blocks (```...```)
|
||||
md_text = re.sub(r'```[\s\S]*?```', '', md_text)
|
||||
md_text = re.sub(r'^\s{4,}.*$', '', md_text, flags=re.MULTILINE)
|
||||
md_text = re.sub(r'^\s{0,3}#{1,6}\s+', '', md_text, flags=re.MULTILINE)
|
||||
md_text = re.sub(r'!\[.*?\]\(.*?\)', '', md_text)
|
||||
md_text = re.sub(r'\[([^\]]+)\]\(.*?\)', r'\1', md_text)
|
||||
md_text = re.sub(r'(\*|_){1,3}(.+?)\1{1,3}', r'\2', md_text)
|
||||
md_text = re.sub(r'^\s{0,3}>+\s?', '', md_text, flags=re.MULTILINE)
|
||||
md_text = re.sub(r'^(\s*)(\-{3,}|_{3,}|\*{3,})\s*$', '', md_text, flags=re.MULTILINE)
|
||||
md_text = re.sub(r'[`~>#+\-=]', '', md_text)
|
||||
md_text = re.sub(r'\s+', ' ', md_text)
|
||||
# Remove code blocks (```...```)
|
||||
md_text = re.sub(r"[\s\S]?```", "", md_text)
|
||||
md_text = re.sub(r"^\s{4,}.$", "", md_text, flags=re.MULTILINE)
|
||||
md_text = re.sub(r"^\s{0,3}#{1,6}\s+", "", md_text, flags=re.MULTILINE)
|
||||
md_text = re.sub(r"!\[.?\]\(.?\)", "", md_text)
|
||||
md_text = re.sub(r"\[([^\]]+)\]\(.?\)", r"\1", md_text)
|
||||
md_text = re.sub(r"(\*|_){1,3}(.+?)\1{1,3}", r"\2", md_text)
|
||||
md_text = re.sub(r"^\s{0,3}>+\s?", "", md_text, flags=re.MULTILINE)
|
||||
md_text = re.sub(r"^(\s)(\-{3,}|_{3,}|\{3,})\s$", "", md_text, flags=re.MULTILINE)
|
||||
md_text = re.sub(r"[`~>#+\-=]", "", md_text)
|
||||
md_text = re.sub(r"\s+", " ", md_text)
|
||||
return md_text.strip()
|
||||
|
||||
|
||||
class MarkdownRenderer(HTMLRenderer):
|
||||
|
||||
_allow_harmful_protocols = False
|
||||
@ -40,7 +41,7 @@ class MarkdownRenderer(HTMLRenderer):
|
||||
formatter = html.HtmlFormatter()
|
||||
self.env.globals["highlight_styles"] = formatter.get_style_defs()
|
||||
|
||||
#def _escape(self, str):
|
||||
# def _escape(self, str):
|
||||
# return str ##escape(str)
|
||||
|
||||
def get_lexer(self, lang, default="bash"):
|
||||
|
@ -6,9 +6,9 @@
|
||||
|
||||
# MIT License: This code is distributed under the MIT License.
|
||||
|
||||
from aiohttp import web
|
||||
import secrets
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
csp_policy = (
|
||||
"default-src 'self'; "
|
||||
@ -22,14 +22,16 @@ csp_policy = (
|
||||
def generate_nonce():
|
||||
return secrets.token_hex(16)
|
||||
|
||||
|
||||
@web.middleware
|
||||
async def csp_middleware(request, handler):
|
||||
|
||||
response = await handler(request)
|
||||
return response
|
||||
nonce = generate_nonce()
|
||||
response.headers['Content-Security-Policy'] = csp_policy.format(nonce=nonce)
|
||||
return response
|
||||
nonce = generate_nonce()
|
||||
response.headers["Content-Security-Policy"] = csp_policy.format(nonce=nonce)
|
||||
return response
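The csp_policy string itself is truncated in this diff; the .format(nonce=nonce) call above implies a {nonce} placeholder somewhere in it. A hypothetical policy of that shape, using the standard 'nonce-<value>' directive syntax (not the project's actual policy):

# Hypothetical example policy; not the project's real csp_policy.
example_policy = (
    "default-src 'self'; "
    "script-src 'self' 'nonce-{nonce}'; "
    "style-src 'self' 'nonce-{nonce}'"
)
# example_policy.format(nonce=generate_nonce()) then yields a header value like:
# default-src 'self'; script-src 'self' 'nonce-4f2a...'; style-src 'self' 'nonce-4f2a...'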
|
||||
|
||||
|
||||
@web.middleware
|
||||
async def no_cors_middleware(request, handler):
|
||||
|
@ -63,9 +63,11 @@ def hash_sync(data: str, salt: str = DEFAULT_SALT) -> str:
|
||||
obj = hashlib.sha256(salted)
|
||||
return obj.hexdigest()
|
||||
|
||||
|
||||
async def hash(data: str, salt: str = DEFAULT_SALT) -> str:
|
||||
return hash_sync(data, salt)
|
||||
|
||||
|
||||
def verify_sync(string: str, hashed: str) -> bool:
|
||||
"""Verify if the given string matches the hashed value.
|
||||
|
||||
@ -78,5 +80,6 @@ def verify_sync(string: str, hashed: str) -> bool:
|
||||
"""
|
||||
return hash_sync(string) == hashed
|
||||
|
||||
|
||||
async def verify(string: str, hashed: str) -> bool:
|
||||
return verify_sync(string, hashed)
|
||||
|
@ -1,17 +1,16 @@
|
||||
import mimetypes
|
||||
import re
|
||||
from functools import lru_cache
|
||||
from types import SimpleNamespace
|
||||
from urllib.parse import urlparse, parse_qs
|
||||
from app.cache import time_cache
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import bleach
|
||||
import emoji
|
||||
import requests
|
||||
from app.cache import time_cache
|
||||
from bs4 import BeautifulSoup
|
||||
from jinja2 import TemplateSyntaxError, nodes
|
||||
from jinja2.ext import Extension
|
||||
from jinja2.nodes import Const
|
||||
import bleach
|
||||
|
||||
|
||||
emoji.EMOJI_DATA['<img src="/emoji/snek1.gif" />'] = {
|
||||
"en": ":snek1:",
|
||||
@ -81,13 +80,29 @@ emoji.EMOJI_DATA[
|
||||
|
||||
|
||||
ALLOWED_TAGS = list(bleach.sanitizer.ALLOWED_TAGS) + [
|
||||
"img", "video", "audio", "source", "iframe", "picture", "span"
|
||||
"img",
|
||||
"video",
|
||||
"audio",
|
||||
"source",
|
||||
"iframe",
|
||||
"picture",
|
||||
"span",
|
||||
]
|
||||
ALLOWED_ATTRIBUTES = {
|
||||
**bleach.sanitizer.ALLOWED_ATTRIBUTES,
|
||||
"img": ["src", "alt", "title", "width", "height"],
|
||||
"a": ["href", "title", "target", "rel", "referrerpolicy", "class"],
|
||||
"iframe": ["src", "width", "height", "frameborder", "allow", "allowfullscreen", "title", "referrerpolicy", "style"],
|
||||
"iframe": [
|
||||
"src",
|
||||
"width",
|
||||
"height",
|
||||
"frameborder",
|
||||
"allow",
|
||||
"allowfullscreen",
|
||||
"title",
|
||||
"referrerpolicy",
|
||||
"style",
|
||||
],
|
||||
"video": ["src", "controls", "width", "height"],
|
||||
"audio": ["src", "controls"],
|
||||
"source": ["src", "type"],
|
||||
@ -96,9 +111,6 @@ ALLOWED_ATTRIBUTES = {
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def sanitize_html(value):
|
||||
return bleach.clean(
|
||||
value,
|
||||
@ -109,7 +121,6 @@ def sanitize_html(value):
|
||||
)
|
||||
|
||||
|
||||
|
||||
def set_link_target_blank(text):
|
||||
soup = BeautifulSoup(text, "html.parser")
|
||||
|
||||
@ -118,26 +129,43 @@ def set_link_target_blank(text):
|
||||
element.attrs["rel"] = "noopener noreferrer"
|
||||
element.attrs["referrerpolicy"] = "no-referrer"
|
||||
element.attrs["href"] = element.attrs["href"].strip(".").strip(",")
|
||||
|
||||
|
||||
return str(soup)
|
||||
|
||||
|
||||
SAFE_ATTRIBUTES = {
|
||||
'href', 'src', 'alt', 'title', 'width', 'height', 'style', 'id', 'class',
|
||||
'rel', 'type', 'name', 'value', 'placeholder', 'aria-hidden', 'aria-label', 'srcset'
|
||||
"href",
|
||||
"src",
|
||||
"alt",
|
||||
"title",
|
||||
"width",
|
||||
"height",
|
||||
"style",
|
||||
"id",
|
||||
"class",
|
||||
"rel",
|
||||
"type",
|
||||
"name",
|
||||
"value",
|
||||
"placeholder",
|
||||
"aria-hidden",
|
||||
"aria-label",
|
||||
"srcset",
|
||||
}
|
||||
|
||||
|
||||
def whitelist_attributes(html):
|
||||
soup = BeautifulSoup(html, 'html.parser')
|
||||
soup = BeautifulSoup(html, "html.parser")
|
||||
|
||||
for tag in soup.find_all():
|
||||
if hasattr(tag, 'attrs'):
|
||||
if tag.name in ['script','form','input']:
|
||||
tag.replace_with('')
|
||||
if hasattr(tag, "attrs"):
|
||||
if tag.name in ["script", "form", "input"]:
|
||||
tag.replace_with("")
|
||||
continue
|
||||
attrs = dict(tag.attrs)
|
||||
for attr in list(attrs):
|
||||
# Check if attribute is in the safe list or is a data-* attribute
|
||||
if not (attr in SAFE_ATTRIBUTES or attr.startswith('data-')):
|
||||
if not (attr in SAFE_ATTRIBUTES or attr.startswith("data-")):
|
||||
del tag.attrs[attr]
|
||||
return str(soup)
|
||||
|
||||
@ -236,12 +264,12 @@ def enrich_image_rendering(text):
|
||||
for element in soup.find_all("img"):
|
||||
if element.attrs["src"].startswith("/"):
|
||||
element.attrs["src"] += "?width=240&height=240"
|
||||
picture_template = f'''
|
||||
picture_template = f"""
|
||||
<picture>
|
||||
<source srcset="{element.attrs["src"]}" type="{mimetypes.guess_type(element.attrs["src"])[0]}" />
|
||||
<source srcset="{element.attrs["src"]}&format=webp" type="image/webp" />
|
||||
<img src="{element.attrs["src"]}&format=png" title="{element.attrs["src"]}" alt="{element.attrs["src"]}" />
|
||||
</picture>'''
|
||||
</picture>"""
|
||||
element.replace_with(BeautifulSoup(picture_template, "html.parser"))
|
||||
return str(soup)
|
||||
|
||||
@ -285,7 +313,7 @@ def linkify_https(text):
|
||||
return set_link_target_blank(str(soup))
|
||||
|
||||
|
||||
@time_cache(timeout=60*60)
|
||||
@time_cache(timeout=60 * 60)
|
||||
def get_url_content(url):
|
||||
try:
|
||||
response = requests.get(url, timeout=5)
|
||||
@ -368,12 +396,11 @@ def embed_url(text):
|
||||
page_video = get_element_options(None, "video", "video", "video")
|
||||
page_audio = get_element_options(None, "audio", "audio", "audio")
|
||||
|
||||
preview_size = (
|
||||
(
|
||||
get_element_options(None, None, None, "card")
|
||||
or "summary_large_image"
|
||||
)
|
||||
|
||||
|
||||
attachment_base = BeautifulSoup(str(element), "html.parser")
|
||||
attachments[original_link_name] = attachment_base
|
||||
|
||||
@ -412,20 +439,28 @@ def embed_url(text):
|
||||
)
|
||||
|
||||
description_element.append(
|
||||
BeautifulSoup(f'<strong class="page-name">{page_name}</strong>', "html.parser")
|
||||
BeautifulSoup(
|
||||
f'<strong class="page-name">{page_name}</strong>',
|
||||
"html.parser",
|
||||
)
|
||||
)
|
||||
|
||||
description_element.append(
|
||||
BeautifulSoup(f"<p class='page-description'>{page_description or "No description available."}</p>", "html.parser")
|
||||
BeautifulSoup(
|
||||
f"<p class='page-description'>{page_description or "No description available."}</p>",
|
||||
"html.parser",
|
||||
)
|
||||
)
|
||||
|
||||
description_element.append(
|
||||
BeautifulSoup(f"<p class='page-original-link'>{original_link_name}</p>", "html.parser")
|
||||
BeautifulSoup(
|
||||
f"<p class='page-original-link'>{original_link_name}</p>",
|
||||
"html.parser",
|
||||
)
|
||||
)
|
||||
|
||||
render_element.append(description_element_base)
|
||||
|
||||
|
||||
for attachment in attachments.values():
|
||||
soup.append(attachment)
|
||||
|
||||
|
@ -1,22 +1,24 @@
|
||||
class WebSocketClient:
|
||||
def __init__(self, hostname, port):
|
||||
self.buffer = b''
|
||||
self.buffer = b""
|
||||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.hostname = hostname
|
||||
self.hostname = hostname
|
||||
self.port = port
|
||||
self.connect()
|
||||
|
||||
self.connect()
|
||||
|
||||
def __getattr__(self, method, *args, **kwargs):
|
||||
if method in self.__dict__.keys():
|
||||
return self.__dict__[method]
|
||||
|
||||
def call(*args, **kwargs):
|
||||
self.write(json.dumps({'method': method, 'args': args, 'kwargs': kwargs}))
|
||||
self.write(json.dumps({"method": method, "args": args, "kwargs": kwargs}))
|
||||
return json.loads(self.read())
|
||||
|
||||
return call
|
||||
|
||||
def connect(self):
|
||||
self.socket.connect((self.hostname, self.port))
|
||||
key = base64.b64encode(b'1234123412341234').decode('utf-8')
|
||||
key = base64.b64encode(b"1234123412341234").decode("utf-8")
|
||||
handshake = (
|
||||
f"GET /db HTTP/1.1\r\n"
|
||||
f"Host: localhost:3131\r\n"
|
||||
@ -25,57 +27,58 @@ class WebSocketClient:
|
||||
f"Sec-WebSocket-Key: {key}\r\n"
|
||||
f"Sec-WebSocket-Version: 13\r\n\r\n"
|
||||
)
|
||||
self.socket.sendall(handshake.encode('utf-8'))
|
||||
response = self.read_until(b'\r\n\r\n')
|
||||
if b'101 Switching Protocols' not in response:
|
||||
self.socket.sendall(handshake.encode("utf-8"))
|
||||
response = self.read_until(b"\r\n\r\n")
|
||||
if b"101 Switching Protocols" not in response:
|
||||
raise Exception("Failed to connect to WebSocket")
|
||||
|
||||
def write(self, message):
|
||||
message_bytes = message.encode('utf-8')
|
||||
message_bytes = message.encode("utf-8")
|
||||
length = len(message_bytes)
|
||||
if length <= 125:
|
||||
self.socket.sendall(b'\x81' + bytes([length]) + message_bytes)
|
||||
self.socket.sendall(b"\x81" + bytes([length]) + message_bytes)
|
||||
elif length >= 126 and length <= 65535:
|
||||
self.socket.sendall(b'\x81' + bytes([126]) + length.to_bytes(2, 'big') + message_bytes)
|
||||
self.socket.sendall(
|
||||
b"\x81" + bytes([126]) + length.to_bytes(2, "big") + message_bytes
|
||||
)
|
||||
else:
|
||||
self.socket.sendall(b'\x81' + bytes([127]) + length.to_bytes(8, 'big') + message_bytes)
|
||||
|
||||
self.socket.sendall(
|
||||
b"\x81" + bytes([127]) + length.to_bytes(8, "big") + message_bytes
|
||||
)
|
||||
|
||||
def read_until(self, delimiter):
|
||||
def read_until(self, delimiter):
|
||||
while True:
|
||||
find_pos = self.buffer.find(delimiter)
|
||||
if find_pos != -1:
|
||||
data = self.buffer[:find_pos+4]
|
||||
self.buffer = self.buffer[find_pos+4:]
|
||||
return data
|
||||
|
||||
data = self.buffer[: find_pos + 4]
|
||||
self.buffer = self.buffer[find_pos + 4 :]
|
||||
return data
|
||||
|
||||
chunk = self.socket.recv(1024)
|
||||
if not chunk:
|
||||
return None
|
||||
self.buffer += chunk
|
||||
|
||||
|
||||
def read_exactly(self, length):
|
||||
while len(self.buffer) < length:
|
||||
chunk = self.socket.recv(length - len(self.buffer))
|
||||
if not chunk:
|
||||
return None
|
||||
self.buffer += chunk
|
||||
response = self.buffer[: length]
|
||||
self.buffer += chunk
|
||||
response = self.buffer[:length]
|
||||
self.buffer = self.buffer[length:]
|
||||
return response
|
||||
|
||||
def read(self):
|
||||
frame = None
|
||||
frame = None
|
||||
frame = self.read_exactly(2)
|
||||
length = frame[1] & 127
|
||||
if length == 126:
|
||||
length = int.from_bytes(self.read_exactly(2), 'big')
|
||||
length = int.from_bytes(self.read_exactly(2), "big")
|
||||
elif length == 127:
|
||||
length = int.from_bytes(self.read_exactly(8), 'big')
|
||||
length = int.from_bytes(self.read_exactly(8), "big")
|
||||
message = self.read_exactly(length)
|
||||
return message
|
||||
|
||||
|
||||
def close(self):
|
||||
self.socket.close()
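One caveat about the client above: RFC 6455 requires client-to-server frames to be masked, and many servers reject unmasked frames. A sketch of a masked variant of write, under that assumption (not the project's code):

import os

def write_masked(self, message):
    payload = message.encode("utf-8")
    length = len(payload)
    # Set the mask bit (0x80) next to the payload length field.
    if length <= 125:
        header = b"\x81" + bytes([0x80 | length])
    elif length <= 65535:
        header = b"\x81" + bytes([0x80 | 126]) + length.to_bytes(2, "big")
    else:
        header = b"\x81" + bytes([0x80 | 127]) + length.to_bytes(8, "big")
    mask = os.urandom(4)
    masked = bytes(b ^ mask[i % 4] for i, b in enumerate(payload))
    self.socket.sendall(header + mask + masked)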
|
||||
|
||||
|
||||
|
@ -31,21 +31,17 @@ from multiavatar import multiavatar
|
||||
from snek.system.view import BaseView
|
||||
from snek.view.avatar_animal import generate_avatar_with_options
|
||||
|
||||
import functools
|
||||
|
||||
class AvatarView(BaseView):
|
||||
login_required = False
|
||||
|
||||
def __init__(self, *args,**kwargs):
|
||||
super().__init__(*args,**kwargs)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.avatars = {}
|
||||
|
||||
async def get(self):
|
||||
type_ = self.request.query.get("type")
|
||||
type_match = {
|
||||
"animal": self.get_animal,
|
||||
"default": self.get_default
|
||||
}
|
||||
type_match = {"animal": self.get_animal, "default": self.get_default}
|
||||
handler = type_match.get(type_, self.get_default)
|
||||
return await handler()
|
||||
|
||||
@ -64,10 +60,9 @@ class AvatarView(BaseView):
|
||||
avatar = generate_avatar_with_options(self.request.query)
|
||||
await self.app.set(key, avatar)
|
||||
return web.Response(text=avatar, content_type="image/svg+xml")
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _get(self, uid):
|
||||
avatar = generate_avatar_with_options(self.request.query)
|
||||
return avatar
|
||||
|
File diff suppressed because it is too large
@ -31,13 +31,12 @@ from multiavatar import multiavatar
|
||||
from snek.system.view import BaseView
|
||||
from snek.view.avatar_animal import generate_avatar_with_options
|
||||
|
||||
import functools
|
||||
|
||||
class AvatarView(BaseView):
|
||||
login_required = False
|
||||
|
||||
def __init__(self, *args,**kwargs):
|
||||
super().__init__(*args,**kwargs)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.avatars = {}
|
||||
|
||||
async def get(self):
|
||||
@ -45,10 +44,9 @@ class AvatarView(BaseView):
|
||||
while True:
|
||||
try:
|
||||
return web.Response(text=self._get(uid), content_type="image/svg+xml")
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _get(self, uid):
|
||||
if uid in self.avatars:
|
||||
return self.avatars[uid]
|
||||
|
@ -1,15 +1,15 @@
|
||||
import asyncio
|
||||
import mimetypes
|
||||
import pathlib
|
||||
import urllib.parse
|
||||
from os.path import isfile
|
||||
|
||||
from PIL import Image
|
||||
import pillow_heif.HeifImagePlugin
|
||||
|
||||
from snek.system.view import BaseView
|
||||
import aiofiles
|
||||
from aiohttp import web
|
||||
import pathlib
|
||||
import urllib.parse
|
||||
from PIL import Image
|
||||
|
||||
from snek.system.view import BaseView
|
||||
|
||||
|
||||
class ChannelAttachmentView(BaseView):
|
||||
async def get(self):
|
||||
@ -73,18 +73,21 @@ class ChannelAttachmentView(BaseView):
|
||||
# response.write is async but image.save is not
|
||||
naughty_steal = response.write
|
||||
tasks = []
|
||||
|
||||
def sync_writer(*args, **kwargs):
|
||||
tasks.append(naughty_steal(*args, **kwargs))
|
||||
return True
|
||||
|
||||
|
||||
setattr(response, "write", sync_writer)
|
||||
|
||||
image.save(response, format=format_, quality=100, optimize=True, save_all=True)
|
||||
image.save(
|
||||
response, format=format_, quality=100, optimize=True, save_all=True
|
||||
)
|
||||
|
||||
setattr(response, "write", naughty_steal)
|
||||
|
||||
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
|
||||
return response
|
||||
else:
|
||||
response = web.FileResponse(channel_attachment["path"])
|
||||
@ -127,7 +130,9 @@ class ChannelAttachmentView(BaseView):
|
||||
attachment_records = []
|
||||
for attachment in attachments:
|
||||
attachment_record = attachment.record
|
||||
attachment_record['relative_url'] = urllib.parse.quote(attachment_record['relative_url'])
|
||||
attachment_record["relative_url"] = urllib.parse.quote(
|
||||
attachment_record["relative_url"]
|
||||
)
|
||||
attachment_records.append(attachment_record)
|
||||
|
||||
return web.json_response(
|
||||
@ -138,23 +143,37 @@ class ChannelAttachmentView(BaseView):
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class ChannelView(BaseView):
|
||||
async def get(self):
|
||||
channel_name = self.request.match_info.get("channel")
|
||||
if(channel_name is None):
|
||||
if channel_name is None:
|
||||
return web.HTTPNotFound()
|
||||
channel = await self.services.channel.get(label="#" + channel_name)
|
||||
if(channel is None):
|
||||
if channel is None:
|
||||
channel = await self.services.channel.get(label=channel_name)
|
||||
channel = await self.services.channel.get(channel_name)
|
||||
if(channel is None):
|
||||
if channel is None:
|
||||
channel = await self.services.channel.get(label=channel_name)
|
||||
if(channel is None):
|
||||
user = await self.services.user.get(uid=self.session.get("uid"))
|
||||
if channel is None:
|
||||
user = await self.services.user.get(uid=self.session.get("uid"))
|
||||
is_listed = self.request.query.get("listed", False) == "true"
|
||||
is_private = self.request.query.get("private", False) == "true"
|
||||
channel = await self.services.channel.create(label=channel_name,created_by_uid=user['uid'],description="No description provided.",tag="user",is_private=is_private,is_listed=is_listed)
|
||||
channel_member = await self.services.channel_member.create(channel_uid=channel['uid'],user_uid=user['uid'],is_moderator=True,is_read_only=False,is_muted=False,is_banned=False)
|
||||
|
||||
return web.HTTPFound("/channel/{}.html".format(channel["uid"]))
|
||||
channel = await self.services.channel.create(
|
||||
label=channel_name,
|
||||
created_by_uid=user["uid"],
|
||||
description="No description provided.",
|
||||
tag="user",
|
||||
is_private=is_private,
|
||||
is_listed=is_listed,
|
||||
)
|
||||
await self.services.channel_member.create(
|
||||
channel_uid=channel["uid"],
|
||||
user_uid=user["uid"],
|
||||
is_moderator=True,
|
||||
is_read_only=False,
|
||||
is_muted=False,
|
||||
is_banned=False,
|
||||
)
|
||||
|
||||
return web.HTTPFound("/channel/{}.html".format(channel["uid"]))
|
||||
|
@ -1,20 +1,15 @@
|
||||
import mimetypes
|
||||
import os
|
||||
import urllib.parse
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from urllib.parse import quote, unquote
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from snek.system.view import BaseView
|
||||
|
||||
|
||||
import os
|
||||
import mimetypes
|
||||
from aiohttp import web
|
||||
from urllib.parse import unquote, quote
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
|
||||
from aiohttp import web
|
||||
from pathlib import Path
|
||||
import mimetypes, urllib.parse
|
||||
|
||||
class DriveView(BaseView):
|
||||
|
||||
async def get(self):
|
||||
@ -27,7 +22,7 @@ class DriveView(BaseView):
|
||||
return web.HTTPNotFound(reason="Path not found")
|
||||
|
||||
if target.is_dir():
|
||||
return await self.render_template("drive.html",{"path": rel_path})
|
||||
return await self.render_template("drive.html", {"path": rel_path})
|
||||
if target.is_file():
|
||||
return web.FileResponse(target)
|
||||
|
||||
@ -37,7 +32,7 @@ class DriveApiView(BaseView):
|
||||
target = await self.services.user.get_home_folder(self.session.get("uid"))
|
||||
rel = self.request.query.get("path", "")
|
||||
offset = int(self.request.query.get("offset", 0))
|
||||
limit = int(self.request.query.get("limit", 20))
|
||||
limit = int(self.request.query.get("limit", 20))
|
||||
|
||||
if rel:
|
||||
target = target.joinpath(rel)
|
||||
@ -47,58 +42,54 @@ class DriveApiView(BaseView):
|
||||
|
||||
if target.is_dir():
|
||||
entries = []
|
||||
for p in sorted(target.iterdir(), key=lambda p: (p.is_file(), p.name.lower())):
|
||||
for p in sorted(
|
||||
target.iterdir(), key=lambda p: (p.is_file(), p.name.lower())
|
||||
):
|
||||
item_path = (Path(rel) / p.name).as_posix()
|
||||
mime = mimetypes.guess_type(p.name)[0] if p.is_file() else "inode/directory"
|
||||
url = (self.request.url.with_path(f"/drive/{urllib.parse.quote(item_path)}")
|
||||
if p.is_file() else None)
|
||||
entries.append({
|
||||
"name": p.name,
|
||||
"type": "directory" if p.is_dir() else "file",
|
||||
"mimetype": mime,
|
||||
"size": p.stat().st_size if p.is_file() else None,
|
||||
"path": item_path,
|
||||
"url": url,
|
||||
})
|
||||
import json
|
||||
mime = (
|
||||
mimetypes.guess_type(p.name)[0]
|
||||
if p.is_file()
|
||||
else "inode/directory"
|
||||
)
|
||||
url = (
|
||||
self.request.url.with_path(
|
||||
f"/drive/{urllib.parse.quote(item_path)}"
|
||||
)
|
||||
if p.is_file()
|
||||
else None
|
||||
)
|
||||
entries.append(
|
||||
{
|
||||
"name": p.name,
|
||||
"type": "directory" if p.is_dir() else "file",
|
||||
"mimetype": mime,
|
||||
"size": p.stat().st_size if p.is_file() else None,
|
||||
"path": item_path,
|
||||
"url": url,
|
||||
}
|
||||
)
|
||||
import json
|
||||
|
||||
total = len(entries)
|
||||
items = entries[offset:offset+limit]
|
||||
return web.json_response({
|
||||
"items": json.loads(json.dumps(items,default=str)),
|
||||
"pagination": {"offset": offset, "limit": limit, "total": total}
|
||||
})
|
||||
|
||||
items = entries[offset : offset + limit]
|
||||
return web.json_response(
|
||||
{
|
||||
"items": json.loads(json.dumps(items, default=str)),
|
||||
"pagination": {"offset": offset, "limit": limit, "total": total},
|
||||
}
|
||||
)
|
||||
|
||||
url = self.request.url.with_path(f"/drive/{urllib.parse.quote(rel)}")
|
||||
return web.json_response({
|
||||
"name": target.name,
|
||||
"type": "file",
|
||||
"mimetype": mimetypes.guess_type(target.name)[0],
|
||||
"size": target.stat().st_size,
|
||||
"path": rel,
|
||||
"url": str(url),
|
||||
})
|
||||
|
return web.json_response(
|
||||
{
|
||||
"name": target.name,
|
||||
"type": "file",
|
||||
"mimetype": mimetypes.guess_type(target.name)[0],
|
||||
"size": target.stat().st_size,
|
||||
"path": rel,
|
||||
"url": str(url),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class DriveView222(BaseView):
|
||||
@ -124,7 +115,11 @@ class DriveView222(BaseView):
|
||||
entry_path = os.path.join(dir_path, entry)
|
||||
stat = os.stat(entry_path)
|
||||
is_dir = os.path.isdir(entry_path)
|
||||
mimetype = None if is_dir else (mimetypes.guess_type(entry_path)[0] or "application/octet-stream")
|
||||
mimetype = (
|
||||
None
|
||||
if is_dir
|
||||
else (mimetypes.guess_type(entry_path)[0] or "application/octet-stream")
|
||||
)
|
||||
size = stat.st_size if not is_dir else None
|
||||
created_at = datetime.fromtimestamp(stat.st_ctime).isoformat()
|
||||
updated_at = datetime.fromtimestamp(stat.st_mtime).isoformat()
|
||||
@ -155,15 +150,20 @@ class DriveView222(BaseView):
|
||||
start = (page - 1) * page_size
|
||||
end = start + page_size
|
||||
paged_entries = entries[start:end]
|
||||
details = [await self.entry_details(full_path, entry, rel_path) for entry in paged_entries]
|
||||
return web.json_response({
|
||||
"path": rel_path,
|
||||
"absolute_url": abs_url,
|
||||
"entries": details,
|
||||
"total": len(entries),
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
})
|
||||
details = [
|
||||
await self.entry_details(full_path, entry, rel_path)
|
||||
for entry in paged_entries
|
||||
]
|
||||
return web.json_response(
|
||||
{
|
||||
"path": rel_path,
|
||||
"absolute_url": abs_url,
|
||||
"entries": details,
|
||||
"total": len(entries),
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
}
|
||||
)
|
||||
else:
|
||||
with open(full_path, "rb") as f:
|
||||
content = f.read()
|
||||
@ -180,14 +180,18 @@ class DriveView222(BaseView):
|
||||
data = await self.request.post()
|
||||
if data.get("type") == "dir":
|
||||
os.makedirs(full_path)
|
||||
return web.json_response({"status": "created", "type": "dir", "absolute_url": abs_url})
|
||||
return web.json_response(
|
||||
{"status": "created", "type": "dir", "absolute_url": abs_url}
|
||||
)
|
||||
else:
|
||||
file_field = data.get("file")
|
||||
if not file_field:
|
||||
raise web.HTTPBadRequest(reason="No file uploaded")
|
||||
with open(full_path, "wb") as f:
|
||||
f.write(file_field.file.read())
|
||||
return web.json_response({"status": "created", "type": "file", "absolute_url": abs_url})
|
||||
return web.json_response(
|
||||
{"status": "created", "type": "file", "absolute_url": abs_url}
|
||||
)
|
||||
|
||||
async def put(self):
|
||||
rel_path = self.request.match_info.get("rel_path", "")
|
||||
@ -210,10 +214,14 @@ class DriveView222(BaseView):
|
||||
raise web.HTTPNotFound(reason="Path not found")
|
||||
if os.path.isdir(full_path):
|
||||
os.rmdir(full_path)
|
||||
return web.json_response({"status": "deleted", "type": "dir", "absolute_url": abs_url})
|
||||
return web.json_response(
|
||||
{"status": "deleted", "type": "dir", "absolute_url": abs_url}
|
||||
)
|
||||
else:
|
||||
os.remove(full_path)
|
||||
return web.json_response({"status": "deleted", "type": "file", "absolute_url": abs_url})
|
||||
return web.json_response(
|
||||
{"status": "deleted", "type": "file", "absolute_url": abs_url}
|
||||
)
|
||||
|
||||
|
||||
class DriveViewi2(BaseView):
|
||||
@ -223,20 +231,17 @@ class DriveViewi2(BaseView):
|
||||
async def get(self):
|
||||
|
||||
drive_uid = self.request.match_info.get("drive")
|
||||
|
||||
|
||||
before = self.request.query.get("before")
|
||||
filters = {}
|
||||
filters = {}
|
||||
if before:
|
||||
filters["created_at__lt"] = before
|
||||
|
||||
if drive_uid:
|
||||
filters['drive_uid'] = drive_uid
|
||||
filters["drive_uid"] = drive_uid
|
||||
drive = await self.services.drive.get(uid=drive_uid)
|
||||
drive_items = []
|
||||
|
||||
|
||||
|
||||
|
||||
async for item in self.services.drive_item.find(**filters):
|
||||
record = item.record
|
||||
record["url"] = "/drive.bin/" + record["uid"] + "." + item.extension
|
||||
|
@ -82,4 +82,4 @@ class PushView(BaseFormView):
|
||||
print(post_notification.text)
|
||||
print(post_notification.headers)
|
||||
|
||||
return await self.json_response({ "registered": True })
|
||||
return await self.json_response({"registered": True})
|
||||
|
@ -1,14 +1,13 @@
|
||||
import os
|
||||
import asyncio
|
||||
import mimetypes
|
||||
import os
|
||||
import urllib.parse
|
||||
from pathlib import Path
|
||||
|
||||
import humanize
|
||||
from aiohttp import web
|
||||
|
||||
from snek.system.view import BaseView
|
||||
import asyncio
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class BareRepoNavigator:
|
||||
@ -25,18 +24,18 @@ class BareRepoNavigator:
        except Exception as e:
            print(f"Error opening repository: {str(e)}")
            sys.exit(1)

        self.repo_path = repo_path
        self.branches = list(self.repo.branches)
        self.current_branch = None
        self.current_commit = None
        self.current_path = ""
        self.history = []

    def get_branches(self):
        """Return a list of branch names in the repository."""
        return [branch.name for branch in self.branches]

    def set_branch(self, branch_name):
        """Set the current branch."""
        try:
@ -47,23 +46,27 @@ class BareRepoNavigator:
            return True
        except IndexError:
            return False

    def get_commits(self, count=10):
        """Get the latest commits on the current branch."""
        if not self.current_branch:
            return []

        commits = []
        for commit in self.repo.iter_commits(self.current_branch, max_count=count):
            commits.append({
                'hash': commit.hexsha,
                'short_hash': commit.hexsha[:7],
                'message': commit.message.strip(),
                'author': commit.author.name,
                'date': datetime.fromtimestamp(commit.committed_date).strftime('%Y-%m-%d %H:%M:%S')
            })
            commits.append(
                {
                    "hash": commit.hexsha,
                    "short_hash": commit.hexsha[:7],
                    "message": commit.message.strip(),
                    "author": commit.author.name,
                    "date": datetime.fromtimestamp(commit.committed_date).strftime(
                        "%Y-%m-%d %H:%M:%S"
                    ),
                }
            )
        return commits

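Review note: the reformatted `get_commits()` above returns a list of plain dicts. A minimal sketch of how a caller might consume it; the repository path and branch name are hypothetical, and the single-argument constructor is inferred from `self.repo_path = repo_path` above:

navigator = BareRepoNavigator("/srv/git/example.git")  # path is an assumption
if navigator.set_branch("main"):
    for entry in navigator.get_commits(count=5):
        # Each entry carries: hash, short_hash, message, author, date (formatted string).
        print(f'{entry["short_hash"]} {entry["date"]} {entry["author"]}: {entry["message"]}')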
    def set_commit(self, commit_hash):
        """Set the current commit by hash."""
        try:
@ -73,48 +76,48 @@ class BareRepoNavigator:
            return True
        except ValueError:
            return False

    def list_directory(self, path=""):
        """List the contents of a directory in the current commit."""
        if not self.current_commit:
            return {'dirs': [], 'files': []}
            return {"dirs": [], "files": []}

        dirs = []
        files = []

        try:
            # Get the tree at the current path
            if path:
                tree = self.current_commit.tree[path]
                if not hasattr(tree, 'trees'):  # It's a blob, not a tree
                    return {'dirs': [], 'files': [path]}
                if not hasattr(tree, "trees"):  # It's a blob, not a tree
                    return {"dirs": [], "files": [path]}
            else:
                tree = self.current_commit.tree

            # List directories and files
            for item in tree:
                if item.type == 'tree':
                if item.type == "tree":
                    item_path = os.path.join(path, item.name) if path else item.name
                    dirs.append(item_path)
                elif item.type == 'blob':
                elif item.type == "blob":
                    item_path = os.path.join(path, item.name) if path else item.name
                    files.append(item_path)

            dirs.sort()
            files.sort()
            return {'dirs': dirs, 'files': files}
            return {"dirs": dirs, "files": files}

        except KeyError:
            return {'dirs': [], 'files': []}
            return {"dirs": [], "files": []}

    def get_file_content(self, file_path):
        """Get the content of a file in the current commit."""
        if not self.current_commit:
            return None

        try:
            blob = self.current_commit.tree[file_path]
            return blob.data_stream.read().decode('utf-8', errors='replace')
            return blob.data_stream.read().decode("utf-8", errors="replace")
        except (KeyError, UnicodeDecodeError):
            try:
                # Try to get as binary if text decoding fails
@ -122,12 +125,12 @@ class BareRepoNavigator:
                return blob.data_stream.read()
            except:
                return None

    def navigate_to(self, path):
        """Navigate to a specific path, updating the current path."""
        if not self.current_commit:
            return False

        try:
            if path:
                self.current_commit.tree[path]  # Check if path exists
@ -136,7 +139,7 @@ class BareRepoNavigator:
            return True
        except KeyError:
            return False

    def navigate_back(self):
        """Navigate back to the previous path."""
        if self.history:
@ -145,12 +148,13 @@ class BareRepoNavigator:
        return False


class RepositoryView(BaseView):

    login_required = True

    def checkout_bare_repo(self, bare_repo_path: Path, target_path: Path, ref: str = 'HEAD'):
    def checkout_bare_repo(
        self, bare_repo_path: Path, target_path: Path, ref: str = "HEAD"
    ):
        repo = Repo(bare_repo_path)
        assert repo.bare, "Repository is not bare."

@ -159,23 +163,22 @@ class RepositoryView(BaseView):

        for blob in tree.traverse():
            target_file = target_path / blob.path

            target_file.parent.mkdir(parents=True, exist_ok=True)
            print(blob.path)

            with open(target_file, 'wb') as f:
            with open(target_file, "wb") as f:
                f.write(blob.data_stream.read())

    async def get(self):

        base_repo_path = Path("drive/repositories")
        base_repo_path = Path("drive/repositories")

        authenticated_user_id = self.session.get("uid")

        username = self.request.match_info.get('username')
        repo_name = self.request.match_info.get('repository')
        rel_path = self.request.match_info.get('path', '')
        username = self.request.match_info.get("username")
        repo_name = self.request.match_info.get("repository")
        rel_path = self.request.match_info.get("path", "")
        user = None
        if not username.count("-") == 4:
            user = await self.app.services.user.get(username=username)
@ -185,22 +188,24 @@ class RepositoryView(BaseView):
        else:
            user = await self.app.services.user.get(uid=username)

        repo = await self.app.services.repository.get(name=repo_name, user_uid=user["uid"])
        repo = await self.app.services.repository.get(
            name=repo_name, user_uid=user["uid"]
        )
        if not repo:
            return web.Response(text="404 Not Found", status=404)
        if repo['is_private'] and authenticated_user_id != repo['uid']:
            return web.Response(text="404 Not Found", status=404)
        if repo["is_private"] and authenticated_user_id != repo["uid"]:
            return web.Response(text="404 Not Found", status=404)

        repo_root_base = (base_repo_path / user['uid'] / (repo_name + ".git")).resolve()
        repo_root = (base_repo_path / user['uid'] / repo_name).resolve()
        repo_root_base = (base_repo_path / user["uid"] / (repo_name + ".git")).resolve()
        repo_root = (base_repo_path / user["uid"] / repo_name).resolve()
        try:
            loop = asyncio.get_event_loop()
            await loop.run_in_executor(None,
                self.checkout_bare_repo, repo_root_base, repo_root
            await loop.run_in_executor(
                None, self.checkout_bare_repo, repo_root_base, repo_root
            )
        except:
            pass

        if not repo_root.exists() or not repo_root.is_dir():
            return web.Response(text="404 Not Found", status=404)

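Review note: the bare `except: pass` around the checkout call above silences every failure, including the `AssertionError` from the bare-repo check, and `asyncio.get_running_loop()` (or `asyncio.to_thread`, Python 3.9+) is the preferred way to reach the loop from inside a coroutine. A possible alternative, only a sketch and not part of this commit, assuming a module-level `logger`:

try:
    await asyncio.to_thread(self.checkout_bare_repo, repo_root_base, repo_root)
except Exception:
    # Keep serving the last good checkout, but record why the refresh failed.
    logger.exception("Checkout of %s failed", repo_root_base)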
@ -211,24 +216,35 @@ class RepositoryView(BaseView):
            return web.Response(text="404 Not Found", status=404)

        if abs_path.is_dir():
            return web.Response(text=self.render_directory(abs_path, username, repo_name, safe_rel_path), content_type='text/html')
            return web.Response(
                text=self.render_directory(
                    abs_path, username, repo_name, safe_rel_path
                ),
                content_type="text/html",
            )
        else:
            return web.Response(text=self.render_file(abs_path), content_type='text/html')
            return web.Response(
                text=self.render_file(abs_path), content_type="text/html"
            )

    def render_directory(self, abs_path, username, repo_name, safe_rel_path):
        entries = sorted(abs_path.iterdir(), key=lambda p: (not p.is_dir(), p.name.lower()))
        entries = sorted(
            abs_path.iterdir(), key=lambda p: (not p.is_dir(), p.name.lower())
        )
        items = []

        if safe_rel_path:
            parent_path = Path(safe_rel_path).parent
            parent_link = f"/repository/{username}/{repo_name}/{parent_path}".rstrip('/')
            parent_link = f"/repository/{username}/{repo_name}/{parent_path}".rstrip(
                "/"
            )
            items.append(f'<li><a href="{parent_link}">⬅️ ..</a></li>')

        for entry in entries:
            link_path = urllib.parse.quote(str(Path(safe_rel_path) / entry.name))
            link = f"/repository/{username}/{repo_name}/{link_path}".rstrip('/')
            display = entry.name + ('/' if entry.is_dir() else '')
            size = '' if entry.is_dir() else humanize.naturalsize(entry.stat().st_size)
            link = f"/repository/{username}/{repo_name}/{link_path}".rstrip("/")
            display = entry.name + ("/" if entry.is_dir() else "")
            size = "" if entry.is_dir() else humanize.naturalsize(entry.stat().st_size)
            icon = self.get_icon(entry)
            items.append(f'<li>{icon} <a href="{link}">{display}</a> {size}</li>')

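Review note: `render_directory()` interpolates `entry.name` straight into the returned HTML, so a file name containing markup is rendered verbatim; only the href path is quoted. A hedged sketch of the escaping that could be applied, using only the standard library:

import html

display = html.escape(entry.name) + ("/" if entry.is_dir() else "")
items.append(f'<li>{icon} <a href="{link}">{display}</a> {size}</li>')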
@ -247,19 +263,24 @@ class RepositoryView(BaseView):

    def render_file(self, abs_path):
        try:
            with open(abs_path, 'r', encoding='utf-8', errors='ignore') as f:
            with open(abs_path, encoding="utf-8", errors="ignore") as f:
                content = f.read()
            return f"<pre>{content}</pre>"
        except Exception as e:
            return f"<h1>Error</h1><pre>{e}</pre>"

    def get_icon(self, file):
        if file.is_dir(): return "📁"
        mime = mimetypes.guess_type(file.name)[0] or ''
        if mime.startswith("image"): return "🖼️"
        if mime.startswith("text"): return "📄"
        if mime.startswith("audio"): return "🎵"
        if mime.startswith("video"): return "🎬"
        if file.name.endswith(".py"): return "🐍"
        if file.is_dir():
            return "📁"
        mime = mimetypes.guess_type(file.name)[0] or ""
        if mime.startswith("image"):
            return "🖼️"
        if mime.startswith("text"):
            return "📄"
        if mime.startswith("audio"):
            return "🎵"
        if mime.startswith("video"):
            return "🎬"
        if file.name.endswith(".py"):
            return "🐍"
        return "📦"

@ -7,19 +7,20 @@
# MIT License: Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions.

import json
import traceback
import asyncio
import json
import logging
import traceback

from aiohttp import web

from snek.system.model import now
from snek.system.profiler import Profiler
from snek.system.view import BaseView

import logging

logger = logging.getLogger(__name__)


class RPCView(BaseView):
    class RPCApi:
        def __init__(self, view, ws):
@ -32,7 +33,7 @@ class RPCView(BaseView):

        async def _session_ensure(self):
            uid = await self.view.session_get("uid")
            if not uid in self.user_session:
            if uid not in self.user_session:
                self.user_session[uid] = {
                    "said_hello": False,
                }
@ -50,45 +51,54 @@ class RPCView(BaseView):
            self._require_login()

            return await self.services.db.insert(self.user_uid, table_name, record)

        async def db_update(self, table_name, record):
            self._require_login()
            return await self.services.db.update(self.user_uid, table_name, record)
        async def set_typing(self,channel_uid,color=None):

        async def set_typing(self, channel_uid, color=None):
            self._require_login()
            user = await self.services.user.get(self.user_uid)
            if not color:
                color = user["color"]
            return await self.services.socket.broadcast(channel_uid, {
                "channel_uid": "293ecf12-08c9-494b-b423-48ba1a2d12c2",
                "event": "set_typing",
                "data": {
                    "event":"set_typing",
                    "user_uid": user['uid'],
                    "username": user["username"],
                    "nick": user["nick"],
                    "channel_uid": channel_uid,
                    "color": color
                }
            })
            return await self.services.socket.broadcast(
                channel_uid,
                {
                    "channel_uid": "293ecf12-08c9-494b-b423-48ba1a2d12c2",
                    "event": "set_typing",
                    "data": {
                        "event": "set_typing",
                        "user_uid": user["uid"],
                        "username": user["username"],
                        "nick": user["nick"],
                        "channel_uid": channel_uid,
                        "color": color,
                    },
                },
            )

        async def db_delete(self, table_name, record):
            self._require_login()
            return await self.services.db.delete(self.user_uid, table_name, record)

        async def db_get(self, table_name, record):
            self._require_login()
            return await self.services.db.get(self.user_uid, table_name, record)

        async def db_find(self, table_name, record):
            self._require_login()
            return await self.services.db.find(self.user_uid, table_name, record)
        async def db_upsert(self, table_name, record,keys):

        async def db_upsert(self, table_name, record, keys):
            self._require_login()
            return await self.services.db.upsert(self.user_uid, table_name, record,keys)
            return await self.services.db.upsert(
                self.user_uid, table_name, record, keys
            )

        async def db_query(self, table_name, args):
            self._require_login()
            return await self.services.db.query(self.user_uid, table_name, sql, args)

        @property
        def user_uid(self):
            return self.view.session.get("uid")
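Review note: `db_query()` above references a name `sql` that is never defined in its scope, so any call would raise `NameError` before reaching the database layer. If the intent is for the caller to supply the statement, a sketch of the corrected signature (an assumption, not part of this commit) would be:

async def db_query(self, table_name, sql, args):
    self._require_login()
    # Forward the caller-supplied statement instead of the undefined name.
    return await self.services.db.query(self.user_uid, table_name, sql, args)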
@ -195,52 +205,60 @@ class RPCView(BaseView):

        async def finalize_message(self, message_uid):
            self._require_login()
            message = await self.services.channel_message.get(message_uid)
            message = await self.services.channel_message.get(message_uid)
            if not message:
                return False
                return False

            if message["user_uid"] != self.user_uid:
                raise Exception("Not allowed")

            if not message['is_final']:
                await self.services.chat.finalize(message['uid'])

            if not message["is_final"]:
                await self.services.chat.finalize(message["uid"])

            return True

        async def update_message_text(self,message_uid, text):
        async def update_message_text(self, message_uid, text):
            self._require_login()
            message = await self.services.channel_message.get(message_uid)
            message = await self.services.channel_message.get(message_uid)
            if message["user_uid"] != self.user_uid:
                raise Exception("Not allowed")

            if message.get_seconds_since_last_update() > 3:
                await self.finalize_message(message["uid"])
                return {"error": "Message too old","seconds_since_last_update": message.get_seconds_since_last_update(),"success": False}
                return {
                    "error": "Message too old",
                    "seconds_since_last_update": message.get_seconds_since_last_update(),
                    "success": False,
                }

            message['message'] = text
            message["message"] = text
            if not text:
                message['deleted_at'] = now()
                message["deleted_at"] = now()
            else:
                message['deleted_at'] = None
                message["deleted_at"] = None

            await self.services.channel_message.save(message)
            data = message.record
            data['text'] = message["message"]
            data['message_uid'] = message_uid
            data = message.record
            data["text"] = message["message"]
            data["message_uid"] = message_uid

            await self.services.socket.broadcast(
                message["channel_uid"],
                {
                    "channel_uid": message["channel_uid"],
                    "event": "update_message_text",
                    "data": message.record,
                },
            )

            await self.services.socket.broadcast(message["channel_uid"], {
                "channel_uid": message["channel_uid"],
                "event": "update_message_text",
                "data": message.record
            })

            return {"success": True}

        async def send_message(self, channel_uid, message,is_final=True):
        async def send_message(self, channel_uid, message, is_final=True):
            self._require_login()
            message = await self.services.chat.send(self.user_uid, channel_uid, message,is_final)

            message = await self.services.chat.send(
                self.user_uid, channel_uid, message, is_final
            )

            return message["uid"]

        async def echo(self, *args):
@ -330,7 +348,8 @@ class RPCView(BaseView):
            self._require_login()

            results = [
                record async for record in self.services.channel.get_recent_users(channel_uid)
                record
                async for record in self.services.channel.get_recent_users(channel_uid)
            ]
            results = sorted(results, key=lambda x: x["nick"])
            return results
@ -371,7 +390,6 @@ class RPCView(BaseView):
                await self.services.socket.send_to_user(self.user_uid, call)
                self._scheduled.remove(call)

        async def ping(self, callId, *args):
            if self.user_uid:
                user = await self.services.user.get(uid=self.user_uid)
@ -379,17 +397,23 @@ class RPCView(BaseView):
                await self.services.user.save(user)
            return {"pong": args}

        async def stars_render(self, channel_uid, message):

            for user in await self.get_online_users(channel_uid):
                try:
                    await self.services.socket.send_to_user(user['uid'], dict(event="stars_render", data={"channel_uid": channel_uid, "message":message}))
                    await self.services.socket.send_to_user(
                        user["uid"],
                        {
                            "event": "stars_render",
                            "data": {"channel_uid": channel_uid, "message": message},
                        },
                    )
                except Exception as ex:
                    print(ex)

    async def get(self):
        scheduled = []

        async def schedule(uid, seconds, call):
            scheduled.append(call)
            await asyncio.sleep(seconds)
@ -408,16 +432,18 @@ class RPCView(BaseView):
            await self.services.socket.subscribe(
                ws, subscription["channel_uid"], self.request.session.get("uid")
            )
        if not scheduled and self.request.app.uptime_seconds < 5:
            await schedule(self.request.session.get("uid"),0,{"event":"refresh", "data": {
                "message": "Finishing deployment"}
                }
            )
            await schedule(self.request.session.get("uid"),15,{"event": "deployed", "data": {
                "uptime": self.request.app.uptime}
                }
        if not scheduled and self.request.app.uptime_seconds < 5:
            await schedule(
                self.request.session.get("uid"),
                0,
                {"event": "refresh", "data": {"message": "Finishing deployment"}},
            )

            await schedule(
                self.request.session.get("uid"),
                15,
                {"event": "deployed", "data": {"uptime": self.request.app.uptime}},
            )

        rpc = RPCView.RPCApi(self, ws)
        async for msg in ws:
            if msg.type == web.WSMsgType.TEXT:
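Review note: `schedule()` sleeps before sending, and both calls above are awaited inline, so the second one keeps the handler away from `async for msg in ws` for roughly 15 seconds right after a deployment. If that pause is not intentional, a sketch of a non-blocking variant (an assumption, not part of this commit):

if not scheduled and self.request.app.uptime_seconds < 5:
    # Fire-and-forget: let the sleeps run concurrently with the message loop.
    asyncio.create_task(
        schedule(
            self.request.session.get("uid"),
            0,
            {"event": "refresh", "data": {"message": "Finishing deployment"}},
        )
    )
    asyncio.create_task(
        schedule(
            self.request.session.get("uid"),
            15,
            {"event": "deployed", "data": {"uptime": self.request.app.uptime}},
        )
    )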
@ -1,45 +1,48 @@
import asyncio
from aiohttp import web

from snek.system.view import BaseFormView
import pathlib


class ContainersIndexView(BaseFormView):

    login_required = True

    async def get(self):

        user_uid = self.session.get("uid")

        containers = []
        async for container in self.services.container.find(user_uid=user_uid):
            containers.append(container.record)

        user = await self.services.user.get(uid=self.session.get("uid"))

        return await self.render_template("settings/containers/index.html", {"containers": containers, "user": user})
        return await self.render_template(
            "settings/containers/index.html", {"containers": containers, "user": user}
        )


class ContainersCreateView(BaseFormView):

    login_required = True

    async def get(self):

        return await self.render_template("settings/containers/create.html")

    async def post(self):
        data = await self.request.post()
        container = await self.services.container.create(
        await self.services.container.create(
            user_uid=self.session.get("uid"),
            name=data['name'],
            status=data['status'],
            resources=data.get('resources', ''),
            path=data.get('path', ''),
            readonly=bool(data.get('readonly', False))
            name=data["name"],
            status=data["status"],
            resources=data.get("resources", ""),
            path=data.get("path", ""),
            readonly=bool(data.get("readonly", False)),
        )
        return web.HTTPFound("/settings/containers/index.html")


class ContainersUpdateView(BaseFormView):

    login_required = True
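Review note: `ContainersCreateView.post()` above reads five form fields. A minimal sketch of a matching request; the host, the `/settings/containers/create.html` route (guessed from the template name and redirects in this file), and the omitted session cookie are all assumptions:

import asyncio

import aiohttp


async def create_container():
    # Field names mirror the ones read in ContainersCreateView.post().
    async with aiohttp.ClientSession(base_url="https://snek.example") as session:
        await session.post(
            "/settings/containers/create.html",
            data={
                "name": "build-runner",
                "status": "stopped",
                "resources": "",
                "path": "",
                # bool(data.get("readonly", False)) is truthy for any non-empty value
                "readonly": "on",
            },
        )


# asyncio.run(create_container())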
@ -51,40 +54,43 @@ class ContainersUpdateView(BaseFormView):
        )
        if not container:
            return web.HTTPNotFound()
        return await self.render_template("settings/containers/update.html", {"container": container.record})
        return await self.render_template(
            "settings/containers/update.html", {"container": container.record}
        )

    async def post(self):
        data = await self.request.post()
        container = await self.services.container.get(
            user_uid=self.session.get("uid"), uid=self.request.match_info["uid"]
        )
        container['status'] = data['status']
        container['resources'] = data.get('resources', '')
        container['path'] = data.get('path', '')
        container['readonly'] = bool(data.get('readonly', False))
        container["status"] = data["status"]
        container["resources"] = data.get("resources", "")
        container["path"] = data.get("path", "")
        container["readonly"] = bool(data.get("readonly", False))
        await self.services.container.save(container)
        return web.HTTPFound("/settings/containers/index.html")


class ContainersDeleteView(BaseFormView):

    login_required = True

    async def get(self):

        container = await self.services.container.get(
            user_uid=self.session.get("uid"), uid=self.request.match_info["uid"]
        )
        if not container:
            return web.HTTPNotFound()

        return await self.render_template("settings/containers/delete.html", {"container": container.record})
        return await self.render_template(
            "settings/containers/delete.html", {"container": container.record}
        )

    async def post(self):
        user_uid = self.session.get("uid")
        uid = self.request.match_info["uid"]
        container = await self.services.container.get(
            user_uid=user_uid, uid=uid
        )
        container = await self.services.container.get(user_uid=user_uid, uid=uid)
        if not container:
            return web.HTTPNotFound()
        await self.services.container.delete(user_uid=user_uid, uid=uid)
@ -1,26 +1,26 @@
import asyncio
from aiohttp import web

from snek.system.view import BaseFormView
import pathlib


class RepositoriesIndexView(BaseFormView):

    login_required = True

    async def get(self):

        user_uid = self.session.get("uid")

        repositories = []
        async for repository in self.services.repository.find(user_uid=user_uid):
            repositories.append(repository.record)

        user = await self.services.user.get(uid=self.session.get("uid"))

        return await self.render_template("settings/repositories/index.html", {"repositories": repositories, "user": user})

        return await self.render_template(
            "settings/repositories/index.html",
            {"repositories": repositories, "user": user},
        )


class RepositoriesCreateView(BaseFormView):
@ -28,14 +28,19 @@ class RepositoriesCreateView(BaseFormView):
    login_required = True

    async def get(self):

        return await self.render_template("settings/repositories/create.html")

    async def post(self):
        data = await self.request.post()
        repository = await self.services.repository.create(user_uid=self.session.get("uid"), name=data['name'], is_private=int(data.get('is_private',0)))
        await self.services.repository.create(
            user_uid=self.session.get("uid"),
            name=data["name"],
            is_private=int(data.get("is_private", 0)),
        )
        return web.HTTPFound("/settings/repositories/index.html")


class RepositoriesUpdateView(BaseFormView):

    login_required = True
@ -47,40 +52,41 @@ class RepositoriesUpdateView(BaseFormView):
        )
        if not repository:
            return web.HTTPNotFound()
        return await self.render_template("settings/repositories/update.html", {"repository": repository.record})
        return await self.render_template(
            "settings/repositories/update.html", {"repository": repository.record}
        )

    async def post(self):
        data = await self.request.post()
        repository = await self.services.repository.get(
            user_uid=self.session.get("uid"), name=self.request.match_info["name"]
        )
        repository['is_private'] = int(data.get('is_private',0))
        repository["is_private"] = int(data.get("is_private", 0))
        await self.services.repository.save(repository)
        return web.HTTPFound("/settings/repositories/index.html")


class RepositoriesDeleteView(BaseFormView):

    login_required = True

    async def get(self):

        repository = await self.services.repository.get(
            user_uid=self.session.get("uid"), name=self.request.match_info["name"]
        )
        if not repository:
            return web.HTTPNotFound()

        return await self.render_template("settings/repositories/delete.html", {"repository": repository.record})
        return await self.render_template(
            "settings/repositories/delete.html", {"repository": repository.record}
        )

    async def post(self):
        user_uid = self.session.get("uid")
        name = self.request.match_info["name"]
        repository = await self.services.repository.get(
            user_uid=user_uid, name=name
        )
        repository = await self.services.repository.get(user_uid=user_uid, name=name)
        if not repository:
            return web.HTTPNotFound()
        await self.services.repository.delete(user_uid=user_uid, name=name)
        return web.HTTPFound("/settings/repositories/index.html")