234 lines
8.9 KiB
Python
234 lines
8.9 KiB
Python
|
import asyncio
|
||
|
import getpass
|
||
|
import sys
|
||
|
import logging
|
||
|
import traceback
|
||
|
from typing import Dict, Any
|
||
|
|
||
|
import aiohttp
|
||
|
|
||
|
from snekbot.rpc import RPC
|
||
|
from snekbot.bot import Bot
|
||
|
|
||
|
# Process-wide logging setup: records at INFO and above, one timestamped
# line per record.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Logger used by the CLI client code in this file.
logger: logging.Logger = logging.getLogger("snek_cli")


def markdown_to_ansi(text: str) -> str:
    """Render a small subset of Markdown as ANSI terminal escape sequences.

    Transformations, applied in this order:

    * ATX headings ``#`` to ``#####``: level 1 bold + underline, level 2
      bold, levels 3-5 shades of cyan;
    * horizontal rules (a line of three or more ``-``): a grey 40-char line;
    * ordered list items (``1. item``): coloured number;
    * unordered list items (``-``, ``*`` or ``+`` bullet): coloured ``•``,
      keeping the item's leading indentation so nested lists stay nested;
    * inline code spans (`` `code` ``): grey on a dark background; their
      contents are shielded from the bold/italic rules below;
    * bold (``**text**``), then italic (``*text*``).

    Line-level patterns only match spaces and tabs, never newlines, so blank
    lines around rules and list items are preserved and separate lines are
    never merged into one.

    Args:
        text: Markdown source, possibly spanning multiple lines.

    Returns:
        *text* with every recognised construct replaced by its ANSI form.
    """
    import re

    reset = "\033[0m"
    bullet_style = "\033[38;5;33m"
    # SGR prefix per heading level; the heading pattern only matches 1-5 hashes.
    heading_styles = {
        1: "\033[1;4m",  # Bold + Underline
        2: "\033[1m",
        3: "\033[1;36m",
        4: "\033[36m",
        5: "\033[2;36m",
    }

    def heading_sub(match: re.Match) -> str:
        content = match.group(2).strip()
        style = heading_styles.get(len(match.group(1)))
        if style is None:
            return content
        return f"{style}{content}{reset}"

    # [ \t] instead of \s: with MULTILINE, \s would also consume newlines and
    # glue a bare "#"/"-"/"1." line onto the following line.
    text = re.sub(r"^(#{1,5})[ \t]+(.+)$", heading_sub, text, flags=re.MULTILINE)

    text = re.sub(
        r"^[ \t]*-{3,}[ \t]*$",
        "\033[38;5;244m" + "─" * 40 + reset,
        text,
        flags=re.MULTILINE,
    )

    text = re.sub(
        r"^(\d+)\.[ \t]+(.*)$",
        lambda m: f"{bullet_style}{m.group(1)}.{reset} {m.group(2)}",
        text,
        flags=re.MULTILINE,
    )

    text = re.sub(
        r"^([ \t]*)[-*+][ \t]+(.*)$",
        lambda m: f"{m.group(1)}{bullet_style}•{reset} {m.group(2)}",
        text,
        flags=re.MULTILINE,
    )

    # Swap inline code spans for NUL-delimited placeholders so the bold and
    # italic rules cannot rewrite '*' characters inside code; restore after.
    code_spans = []

    def stash_code(match: re.Match) -> str:
        code_spans.append(f"\033[38;5;244m\033[48;5;236m {match.group(1)} {reset}")
        return f"\x00{len(code_spans) - 1}\x00"

    def restore_code(match: re.Match) -> str:
        index = int(match.group(1))
        # Leave anything that is not one of our placeholders untouched.
        return code_spans[index] if index < len(code_spans) else match.group(0)

    text = re.sub(r"`([^`]+)`", stash_code, text)
    text = re.sub(r"\*\*(.*?)\*\*", r"\033[1m\1\033[0m", text)
    text = re.sub(r"\*(.*?)\*", r"\033[3m\1\033[0m", text)
    return re.sub(r"\x00(\d+)\x00", restore_code, text)


class CliClient(Bot):
    """Interactive terminal client for the Snek chat service.

    Two coroutines cooperate: ``_network_loop`` connects over a WebSocket,
    logs in and dispatches incoming frames, while ``_input_loop`` reads stdin
    and turns each line into a chat message or a slash command (``/join``,
    ``/channels``, ``/quit``). Messages are printed through ``_display``,
    which renders Markdown with ``markdown_to_ansi``.

    ``url``, ``username``, ``password``, ``ws`` and ``get_channels`` are not
    defined in this class; presumably ``Bot`` provides them -- confirm
    against ``snekbot.bot``.
    """

    # UID of the channel that plain (non-command) input is sent to;
    # None until the user runs /join.
    active_channel_uid: Any = None

    # Channel uid -> channel dict (read for its 'name' key). Created per
    # instance in __init__ and refilled by on_init.
    _channel_cache: Dict[Any, Any]

    # Set to True at the end of on_init; run() waits for it before it
    # starts reading stdin.
    _initialized: bool = False

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Set up per-instance state, then forward all arguments to ``Bot``."""
        # Instance-level dict: a mutable class attribute would be shared by
        # every CliClient in the process.
        self._channel_cache = {}
        self._initialized = False
        super().__init__(*args, **kwargs)

    def _write_prompt(self) -> None:
        """Redraw the ``[channel]>`` input prompt on the current terminal line."""
        # "\r\033[2K" = carriage return + erase the entire line, so redrawing
        # never leaves a duplicate prompt behind, whatever the line's width.
        sys.stdout.write(f"\r\033[2K[{self._get_active_channel_name()}]> ")
        sys.stdout.flush()

    def _display(self, message: str) -> None:
        """Print a Markdown *message* as ANSI text, then redraw the prompt."""
        sys.stdout.write("\r\033[2K" + markdown_to_ansi(message) + "\n")
        self._write_prompt()

    def _get_active_channel_name(self) -> str:
        """Return the active channel's name, or "No Channel" if none is known."""
        if self.active_channel_uid is None:
            return "No Channel"
        channel = self._channel_cache.get(self.active_channel_uid)
        return channel['name'] if channel is not None else "No Channel"

    def _format_chat_line(self, channel_uid: Any, nick: str, message: str) -> str:
        """Build the Markdown line used to show one chat message."""
        channel_name = self._channel_cache.get(channel_uid, {}).get('name', 'unknown')
        return f"**[{channel_name}]** *<{nick}>*: {message}"

    async def send_message(self, channel_uid: Any, message: str) -> bool:
        """Send *message* to *channel_uid* as a ``send_message`` RPC frame.

        The frame is written with ``callId`` None and no reply is read here.
        NOTE(review): the meaning of the trailing ``True`` argument is
        defined by the server -- confirm.

        Returns:
            True if the frame was written; False if the WebSocket is missing,
            already closed, or closed during the write.
        """
        ws = self.ws
        if not ws or ws.closed:
            logger.error("Cannot send message, WebSocket is not connected.")
            return False
        payload: Dict[str, Any] = {
            "method": "send_message",
            "args": [channel_uid, message, True],
            "kwargs": {},
            "callId": None,
        }
        try:
            await ws.send_json(payload)
        except ConnectionError:
            # The socket can start closing between the check above and the write.
            logger.error("Cannot send message, WebSocket is not connected.")
            return False
        return True

    async def _network_loop(self) -> None:
        """Connect, log in, run on_init, then dispatch frames until the socket closes.

        Errors while connecting, logging in, running on_init or receiving end
        the loop; they are logged with their traceback instead of being
        propagated. An error while handling a single frame is logged and the
        loop keeps running.
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.ws_connect(self.url) as ws:
                    self.ws = ws
                    self.rpc = RPC(self.ws)
                    await self.rpc.login(self.username, self.password)
                    self.user: Dict[str, Any] = await self.rpc.get_user(None)
                    await self.on_init()
                    while not self.ws.closed:
                        data = await self.rpc.receive()
                        if not data:
                            break
                        try:
                            await self._dispatch(data)
                        except Exception:
                            # One bad frame or handler must not drop the connection.
                            logger.exception("Error while handling an incoming event.")
        except Exception:
            logger.exception("Network loop disconnected with an error.")

    async def _dispatch(self, data: Any) -> None:
        """Route one received frame to the matching handler.

        A frame whose ``message`` attribute can be ``strip()``-ed is a chat
        message. It goes to on_message only when ``is_final`` is truthy and
        it was not sent by the logged-in user. Any other frame with a string
        ``event`` attribute goes to ``on_<event>(**data.data)`` if such a
        method exists; otherwise it is logged at DEBUG level and ignored.
        """
        try:
            message = data.message.strip()
        except AttributeError:
            message = None  # not a chat frame

        if message is not None:
            # Our own lines are echoed locally by _input_loop, so skip them here.
            if data.is_final and data.username != self.user["username"]:
                await self.on_message(data.username, data.user_nick, data.channel_uid, message)
            return

        event = getattr(data, "event", None)
        if not isinstance(event, str):
            return  # no usable event name: ignore the frame
        # Look the handler up before calling it so an AttributeError raised
        # *inside* a handler is not mistaken for "handler not implemented".
        # NOTE(review): the server-supplied name selects any on_* method.
        handler = getattr(self, "on_" + event, None)
        if handler is None:
            logger.debug("Not implemented event: %s", event)
            return
        await handler(**(data.data or {}))

    async def run(self) -> None:
        """Run the client until the user quits or the connection ends.

        The network loop starts first. Stdin is read only after on_init has
        completed. If the network loop finishes before that (for example
        because login failed), run returns without reading input.
        """
        network_task = asyncio.create_task(self._network_loop())
        while not self._initialized:
            if network_task.done():
                network_task.result()  # surfaces anything _network_loop did not handle
                return
            await asyncio.sleep(0.1)

        input_task = asyncio.create_task(self._input_loop())
        _done, pending = await asyncio.wait(
            [network_task, input_task],
            return_when=asyncio.FIRST_COMPLETED,
        )
        if input_task in pending:
            # The blocking readline() started by _input_loop cannot be
            # interrupted, and asyncio.run() waits for default-executor
            # threads at shutdown -- tell the user why we have not exited.
            sys.stdout.write("\r\033[2KConnection closed. Press Enter to exit.\n")
            sys.stdout.flush()
        for task in pending:
            task.cancel()
        # Let the cancelled tasks run their cleanup (closing the WebSocket and
        # HTTP session) before the event loop shuts down.
        await asyncio.gather(*pending, return_exceptions=True)

    async def on_init(self) -> None:
        """Check the login, cache the channel list and show the welcome text.

        Raises:
            ValueError: If ``self.user`` has no ``username`` entry, which is
                treated as rejected credentials.
        """
        try:
            username = self.user['username']
        except (KeyError, TypeError):
            # TypeError covers a non-dict user record (e.g. None).
            raise ValueError("\033[91mUsername or password incorrect!\033[0m") from None
        logger.info("Successfully logged in as %s.", username)

        channels = await self.get_channels()
        self._channel_cache = {channel['uid']: channel for channel in channels}

        self._display(f"**Welcome, {self.user['nick']}!** Your available commands are:")
        self._display("> /join <channel_name>, /channels, /quit")
        self._initialized = True
        await self.on_idle()

    async def on_message(self, username: str, user_nick: str, channel_uid: Any, message: str) -> None:
        """Ring the terminal bell and show a chat message from another user."""
        sys.stdout.write('\a')
        sys.stdout.flush()
        self._display(self._format_chat_line(channel_uid, user_nick, message))

    async def on_own_message(self, channel_uid: Any, message: str) -> None:
        """Echo a message the local user has just sent."""
        self._display(self._format_chat_line(channel_uid, self.user['nick'], message))

    async def on_idle(self) -> None:
        """Redraw the input prompt."""
        self._write_prompt()

    async def _handle_command(self, line: str) -> bool:
        """Execute one slash command typed by the user.

        Args:
            line: The input line, starting with ``/``.

        Returns:
            False when the client should stop (``/quit``), True otherwise.
        """
        parts = line.split()
        if not parts:
            return True
        command, args = parts[0], parts[1:]

        if command == "/join":
            if not args:
                self._display("> Usage: `/join <channel_name>`")
                return True
            wanted = args[0]
            for uid, channel_data in self._channel_cache.items():
                if channel_data['name'] == wanted:
                    self.active_channel_uid = uid
                    self._display(f"*Active channel set to* **'{wanted}'**.")
                    return True
            self._display(f"*** Error: Channel '{wanted}' not found.")
        elif command == "/channels":
            names = [channel['name'] for channel in self._channel_cache.values()]
            if names:
                self._display(f"*Available channels*: `{'`, `'.join(names)}`")
            else:
                self._display("*No channels available.*")
        elif command == "/quit":
            self._display("*Disconnecting. Goodbye!*")
            return False
        else:
            self._display(f"*** Unknown command: {command}")
        return True

    async def _input_loop(self) -> None:
        """Read stdin until ``/quit`` or end-of-file and act on each line.

        Lines starting with ``/`` are commands. Any other non-blank line is
        sent to the active channel and echoed locally once the send succeeds.
        The WebSocket is closed when the loop ends for any reason, including
        cancellation.
        """
        loop = asyncio.get_running_loop()
        try:
            while True:
                # readline() blocks, so it runs in the default thread pool.
                raw = await loop.run_in_executor(None, sys.stdin.readline)
                if not raw:
                    # readline() returns "" only at end-of-file (e.g. Ctrl-D);
                    # without this the loop would spin on empty reads forever.
                    break
                line = raw.strip()
                if not line:
                    await self.on_idle()
                    continue
                if line.startswith("/"):
                    if not await self._handle_command(line):
                        break
                    continue
                if self.active_channel_uid is None:
                    self._display("> No active channel. Use `/join <channel_name>` to start talking.")
                    continue
                if await self.send_message(self.active_channel_uid, line):
                    await self.on_own_message(self.active_channel_uid, line)
                else:
                    await self.on_idle()
        finally:
            if self.ws and not self.ws.closed:
                await self.ws.close()


def main() -> int:
    """Prompt for credentials and run the interactive client.

    The server URL defaults to the public Snek endpoint and can be
    overridden by passing it as the first command-line argument.

    Returns:
        Process exit status: 0 on a normal exit or user interrupt,
        1 on an unexpected error.
    """
    print("--- Snekbot CLI Client ---")
    url = sys.argv[1] if len(sys.argv) > 1 else "wss://snek.molodetz.nl/rpc.ws"
    try:
        # Prompts are inside the try so Ctrl-C / Ctrl-D at them exits cleanly.
        username = input("Username: ")
        password = getpass.getpass("Password: ")
        client = CliClient(username, password, url=url)
        asyncio.run(client.run())
    except (KeyboardInterrupt, EOFError):
        print("\nClient shut down by user.")
    except Exception as exc:
        # logger.exception keeps the traceback that logger.error would drop.
        logger.exception("An unexpected error occurred: %s", exc)
        return 1
    return 0


if __name__ == "__main__":
    sys.exit(main())