From 52df3887a6b1e9230427b0e8cb684894206993b5 Mon Sep 17 00:00:00 2001 From: retoor Date: Wed, 13 Aug 2025 00:06:44 +0200 Subject: [PATCH] Update. --- .gitignore | 2 + examples/crawler/crawler.py | 112 ++-- examples/crawler/database.py | 56 +- examples/crawler/main.py | 19 +- examples/princess/ads.py | 1187 +++++++++++++++++++++++++++++++++ examples/princess/grk.py | 122 ++++ examples/princess/princess.py | 135 ++++ src/devranta/api.py | 65 +- src/devranta/api_plain.py | 117 ++-- src/devranta/api_requests.py | 45 +- test.py | 244 ++++--- 11 files changed, 1860 insertions(+), 244 deletions(-) create mode 100644 examples/princess/ads.py create mode 100644 examples/princess/grk.py create mode 100644 examples/princess/princess.py diff --git a/.gitignore b/.gitignore index 701e9f3..1a9f6cf 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,8 @@ .history __pycache__ *.pyc +.env +*.db examples/crawler/devrant.sqlite-shm examples/crawler/devrant.sqlite-wal examples/crawler/devrant.sqlite diff --git a/examples/crawler/crawler.py b/examples/crawler/crawler.py index cfa39d8..e6fc14f 100644 --- a/examples/crawler/crawler.py +++ b/examples/crawler/crawler.py @@ -1,11 +1,16 @@ import asyncio import logging from typing import Set -from devranta.api import Api, Rant + from database import DatabaseManager +from devranta.api import Api, Rant + + class DevRantCrawler: - def __init__(self, api: Api, db: DatabaseManager, rant_consumers: int, user_consumers: int): + def __init__( + self, api: Api, db: DatabaseManager, rant_consumers: int, user_consumers: int + ): self.api = api self.db = db self.rant_queue = asyncio.Queue(maxsize=1000000) @@ -18,23 +23,29 @@ class DevRantCrawler: self.seen_rant_ids: Set[int] = set() self.seen_user_ids: Set[int] = set() self.stats = { - "rants_processed": 0, "rants_added_to_db": 0, - "comments_added_to_db": 0, "users_processed": 0, "users_added_to_db": 0, - "api_errors": 0, "producer_loops": 0, "end_of_feed_hits": 0, - "rants_queued": 0, "users_queued": 0 + "rants_processed": 0, + "rants_added_to_db": 0, + "comments_added_to_db": 0, + "users_processed": 0, + "users_added_to_db": 0, + "api_errors": 0, + "producer_loops": 0, + "end_of_feed_hits": 0, + "rants_queued": 0, + "users_queued": 0, } async def _queue_user_if_new(self, user_id: int): if user_id in self.seen_user_ids: return - + self.seen_user_ids.add(user_id) if not await self.db.user_exists(user_id): await self.user_queue.put(user_id) self.stats["users_queued"] += 1 - + async def _queue_rant_if_new(self, rant_obj: Rant): - rant_id = rant_obj['id'] + rant_id = rant_obj["id"] if rant_id in self.seen_rant_ids: return @@ -49,52 +60,64 @@ class DevRantCrawler: logging.info("Starting initial seeder to re-ignite crawling process...") user_ids = await self.db.get_random_user_ids(limit=2000) if not user_ids: - logging.info("Seeder found no existing users. Crawler will start from scratch.") + logging.info( + "Seeder found no existing users. Crawler will start from scratch." + ) return - + for user_id in user_ids: if user_id not in self.seen_user_ids: self.seen_user_ids.add(user_id) await self.user_queue.put(user_id) self.stats["users_queued"] += 1 - logging.info(f"Seeder finished: Queued {len(user_ids)} users to kickstart exploration.") + logging.info( + f"Seeder finished: Queued {len(user_ids)} users to kickstart exploration." + ) async def _rant_producer(self): logging.info("Rant producer started.") skip = 0 consecutive_empty_responses = 0 - + while not self.shutdown_event.is_set(): try: logging.info(f"Producer: Fetching rants with skip={skip}...") rants = await self.api.get_rants(sort="recent", limit=50, skip=skip) self.stats["producer_loops"] += 1 - + if not rants: consecutive_empty_responses += 1 - logging.info(f"Producer: Feed returned empty. Consecutive empty hits: {consecutive_empty_responses}.") + logging.info( + f"Producer: Feed returned empty. Consecutive empty hits: {consecutive_empty_responses}." + ) if consecutive_empty_responses >= 5: self.stats["end_of_feed_hits"] += 1 - logging.info("Producer: End of feed likely reached. Pausing for 15 minutes before reset.") + logging.info( + "Producer: End of feed likely reached. Pausing for 15 minutes before reset." + ) await asyncio.sleep(900) skip = 0 consecutive_empty_responses = 0 else: await asyncio.sleep(10) continue - + consecutive_empty_responses = 0 new_rants_found = 0 for rant in rants: await self._queue_rant_if_new(rant) new_rants_found += 1 - - logging.info(f"Producer: Processed {new_rants_found} rants from feed. Total queued: {self.stats['rants_queued']}.") + + logging.info( + f"Producer: Processed {new_rants_found} rants from feed. Total queued: {self.stats['rants_queued']}." + ) skip += len(rants) await asyncio.sleep(2) except Exception as e: - logging.critical(f"Producer: Unhandled exception: {e}. Retrying in 60s.") + logging.critical( + f"Producer: Unhandled exception: {e}. Retrying in 60s." + ) self.stats["api_errors"] += 1 await asyncio.sleep(60) @@ -103,23 +126,29 @@ class DevRantCrawler: while not self.shutdown_event.is_set(): try: rant_id = await self.rant_queue.get() - logging.info(f"Rant consumer #{worker_id}: Processing rant ID {rant_id}.") - + logging.info( + f"Rant consumer #{worker_id}: Processing rant ID {rant_id}." + ) + rant_details = await self.api.get_rant(rant_id) if not rant_details or not rant_details.get("success"): - logging.warning(f"Rant consumer #{worker_id}: Failed to fetch details for rant {rant_id}.") + logging.warning( + f"Rant consumer #{worker_id}: Failed to fetch details for rant {rant_id}." + ) self.rant_queue.task_done() continue - await self._queue_user_if_new(rant_details['rant']['user_id']) - + await self._queue_user_if_new(rant_details["rant"]["user_id"]) + comments = rant_details.get("comments", []) for comment in comments: await self.db.add_comment(comment) self.stats["comments_added_to_db"] += 1 - await self._queue_user_if_new(comment['user_id']) - - logging.info(f"Rant consumer #{worker_id}: Finished processing rant {rant_id}, found {len(comments)} comments.") + await self._queue_user_if_new(comment["user_id"]) + + logging.info( + f"Rant consumer #{worker_id}: Finished processing rant {rant_id}, found {len(comments)} comments." + ) self.stats["rants_processed"] += 1 self.rant_queue.task_done() @@ -132,17 +161,21 @@ class DevRantCrawler: while not self.shutdown_event.is_set(): try: user_id = await self.user_queue.get() - logging.info(f"User consumer #{worker_id}: Processing user ID {user_id}.") - + logging.info( + f"User consumer #{worker_id}: Processing user ID {user_id}." + ) + profile = await self.api.get_profile(user_id) if not profile: - logging.warning(f"User consumer #{worker_id}: Could not fetch profile for user {user_id}.") + logging.warning( + f"User consumer #{worker_id}: Could not fetch profile for user {user_id}." + ) self.user_queue.task_done() continue await self.db.add_user(profile, user_id) self.stats["users_added_to_db"] += 1 - + rants_found_on_profile = 0 content_sections = profile.get("content", {}).get("content", {}) for section_name in ["rants", "upvoted", "favorites"]: @@ -150,13 +183,15 @@ class DevRantCrawler: await self._queue_rant_if_new(rant_obj) rants_found_on_profile += 1 - logging.info(f"User consumer #{worker_id}: Finished user {user_id}, found and queued {rants_found_on_profile} associated rants.") + logging.info( + f"User consumer #{worker_id}: Finished user {user_id}, found and queued {rants_found_on_profile} associated rants." + ) self.stats["users_processed"] += 1 self.user_queue.task_done() except Exception as e: logging.error(f"User consumer #{worker_id}: Unhandled exception: {e}") self.user_queue.task_done() - + async def _stats_reporter(self): logging.info("Stats reporter started.") while not self.shutdown_event.is_set(): @@ -172,7 +207,7 @@ class DevRantCrawler: async def run(self): logging.info("Exhaustive crawler starting...") await self._initial_seed() - + logging.info("Starting main producer and consumer tasks...") tasks = [] try: @@ -181,7 +216,7 @@ class DevRantCrawler: for i in range(self.num_rant_consumers): tasks.append(asyncio.create_task(self._rant_consumer(i + 1))) - + for i in range(self.num_user_consumers): tasks.append(asyncio.create_task(self._user_consumer(i + 1))) @@ -190,7 +225,7 @@ class DevRantCrawler: logging.info("Crawler run cancelled.") finally: await self.shutdown() - + async def shutdown(self): if self.shutdown_event.is_set(): return @@ -207,8 +242,7 @@ class DevRantCrawler: tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] for task in tasks: task.cancel() - + await asyncio.gather(*tasks, return_exceptions=True) logging.info("All tasks cancelled.") logging.info(f"--- FINAL STATS ---\n{self.stats}") - diff --git a/examples/crawler/database.py b/examples/crawler/database.py index f484ffe..dbcafbf 100644 --- a/examples/crawler/database.py +++ b/examples/crawler/database.py @@ -1,7 +1,10 @@ import logging -import aiosqlite from typing import List -from devranta.api import Rant, Comment, UserProfile + +import aiosqlite + +from devranta.api import Comment, Rant, UserProfile + class DatabaseManager: def __init__(self, db_path: str): @@ -24,7 +27,8 @@ class DatabaseManager: async def create_tables(self): logging.info("Ensuring database tables exist...") - await self._conn.executescript(""" + await self._conn.executescript( + """ CREATE TABLE IF NOT EXISTS users ( id INTEGER PRIMARY KEY, username TEXT NOT NULL UNIQUE, @@ -52,45 +56,75 @@ class DatabaseManager: score INTEGER, created_time INTEGER ); - """) + """ + ) await self._conn.commit() logging.info("Table schema verified.") async def add_rant(self, rant: Rant): await self._conn.execute( "INSERT OR IGNORE INTO rants (id, user_id, text, score, created_time, num_comments) VALUES (?, ?, ?, ?, ?, ?)", - (rant['id'], rant['user_id'], rant['text'], rant['score'], rant['created_time'], rant['num_comments']) + ( + rant["id"], + rant["user_id"], + rant["text"], + rant["score"], + rant["created_time"], + rant["num_comments"], + ), ) await self._conn.commit() async def add_comment(self, comment: Comment): await self._conn.execute( "INSERT OR IGNORE INTO comments (id, rant_id, user_id, body, score, created_time) VALUES (?, ?, ?, ?, ?, ?)", - (comment['id'], comment['rant_id'], comment['user_id'], comment['body'], comment['score'], comment['created_time']) + ( + comment["id"], + comment["rant_id"], + comment["user_id"], + comment["body"], + comment["score"], + comment["created_time"], + ), ) await self._conn.commit() async def add_user(self, user: UserProfile, user_id: int): await self._conn.execute( "INSERT OR IGNORE INTO users (id, username, score, about, location, created_time, skills, github, website) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", - (user_id, user['username'], user['score'], user['about'], user['location'], user['created_time'], user['skills'], user['github'], user['website']) + ( + user_id, + user["username"], + user["score"], + user["about"], + user["location"], + user["created_time"], + user["skills"], + user["github"], + user["website"], + ), ) await self._conn.commit() async def rant_exists(self, rant_id: int) -> bool: - async with self._conn.execute("SELECT 1 FROM rants WHERE id = ? LIMIT 1", (rant_id,)) as cursor: + async with self._conn.execute( + "SELECT 1 FROM rants WHERE id = ? LIMIT 1", (rant_id,) + ) as cursor: return await cursor.fetchone() is not None async def user_exists(self, user_id: int) -> bool: - async with self._conn.execute("SELECT 1 FROM users WHERE id = ? LIMIT 1", (user_id,)) as cursor: + async with self._conn.execute( + "SELECT 1 FROM users WHERE id = ? LIMIT 1", (user_id,) + ) as cursor: return await cursor.fetchone() is not None async def get_random_user_ids(self, limit: int) -> List[int]: - logging.info(f"Fetching up to {limit} random user IDs from database for seeding...") + logging.info( + f"Fetching up to {limit} random user IDs from database for seeding..." + ) query = "SELECT id FROM users ORDER BY RANDOM() LIMIT ?" async with self._conn.execute(query, (limit,)) as cursor: rows = await cursor.fetchall() user_ids = [row[0] for row in rows] logging.info(f"Found {len(user_ids)} user IDs to seed.") return user_ids - diff --git a/examples/crawler/main.py b/examples/crawler/main.py index 6a60285..18c6c6f 100644 --- a/examples/crawler/main.py +++ b/examples/crawler/main.py @@ -3,14 +3,16 @@ import asyncio import logging import signal -from devranta.api import Api -from database import DatabaseManager from crawler import DevRantCrawler +from database import DatabaseManager + +from devranta.api import Api # --- Configuration --- DB_FILE = "devrant.sqlite" CONCURRENT_RANT_CONSUMERS = 10 # How many rants to process at once -CONCURRENT_USER_CONSUMERS = 5 # How many user profiles to fetch at once +CONCURRENT_USER_CONSUMERS = 5 # How many user profiles to fetch at once + async def main(): """Initializes and runs the crawler.""" @@ -21,13 +23,13 @@ async def main(): ) api = Api() - + async with DatabaseManager(DB_FILE) as db: crawler = DevRantCrawler( - api=api, - db=db, - rant_consumers=CONCURRENT_RANT_CONSUMERS, - user_consumers=CONCURRENT_USER_CONSUMERS + api=api, + db=db, + rant_consumers=CONCURRENT_RANT_CONSUMERS, + user_consumers=CONCURRENT_USER_CONSUMERS, ) # Set up a signal handler for graceful shutdown on Ctrl+C @@ -39,6 +41,7 @@ async def main(): await crawler.run() + if __name__ == "__main__": try: asyncio.run(main()) diff --git a/examples/princess/ads.py b/examples/princess/ads.py new file mode 100644 index 0000000..a2a7415 --- /dev/null +++ b/examples/princess/ads.py @@ -0,0 +1,1187 @@ +import asyncio +import atexit +import json +import os +import re +import socket +import unittest +from collections.abc import AsyncGenerator, Iterable +from datetime import datetime, timezone +from pathlib import Path +from typing import ( + Any, + Dict, + List, + Optional, + Set, + Tuple, + Union, +) +from uuid import uuid4 + +import aiohttp +import aiosqlite +from aiohttp import web + + +class AsyncDataSet: + """ + Distributed AsyncDataSet with client-server model over Unix sockets. + + Parameters: + - file: Path to the SQLite database file + - socket_path: Path to Unix socket for inter-process communication (default: "ads.sock") + - max_concurrent_queries: Maximum concurrent database operations (default: 1) + Set to 1 to prevent "database is locked" errors with SQLite + """ + + _KV_TABLE = "__kv_store" + _DEFAULT_COLUMNS = { + "uid": "TEXT PRIMARY KEY", + "created_at": "TEXT", + "updated_at": "TEXT", + "deleted_at": "TEXT", + } + + def __init__( + self, file: str, socket_path: str = "ads.sock", max_concurrent_queries: int = 1 + ): + self._file = file + self._socket_path = socket_path + self._table_columns_cache: Dict[str, Set[str]] = {} + self._is_server = None # None means not initialized yet + self._server = None + self._runner = None + self._client_session = None + self._initialized = False + self._init_lock = None # Will be created when needed + self._max_concurrent_queries = max_concurrent_queries + self._db_semaphore = None # Will be created when needed + self._db_connection = None # Persistent database connection + + async def _ensure_initialized(self): + """Ensure the instance is initialized as server or client.""" + if self._initialized: + return + + # Create lock if needed (first time in async context) + if self._init_lock is None: + self._init_lock = asyncio.Lock() + + async with self._init_lock: + if self._initialized: + return + + await self._initialize() + + # Create semaphore for database operations (server only) + if self._is_server: + self._db_semaphore = asyncio.Semaphore(self._max_concurrent_queries) + + self._initialized = True + + async def _initialize(self): + """Initialize as server or client.""" + try: + # Try to create Unix socket + if os.path.exists(self._socket_path): + # Check if socket is active + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + try: + sock.connect(self._socket_path) + sock.close() + # Socket is active, we're a client + self._is_server = False + await self._setup_client() + except (ConnectionRefusedError, FileNotFoundError): + # Socket file exists but not active, remove and become server + os.unlink(self._socket_path) + self._is_server = True + await self._setup_server() + else: + # No socket file, become server + self._is_server = True + await self._setup_server() + except Exception: + # Fallback to client mode + self._is_server = False + await self._setup_client() + + # Establish persistent database connection + self._db_connection = await aiosqlite.connect(self._file) + + async def _setup_server(self): + """Setup aiohttp server.""" + app = web.Application() + + # Add routes for all methods + app.router.add_post("/insert", self._handle_insert) + app.router.add_post("/update", self._handle_update) + app.router.add_post("/delete", self._handle_delete) + app.router.add_post("/upsert", self._handle_upsert) + app.router.add_post("/get", self._handle_get) + app.router.add_post("/find", self._handle_find) + app.router.add_post("/count", self._handle_count) + app.router.add_post("/exists", self._handle_exists) + app.router.add_post("/kv_set", self._handle_kv_set) + app.router.add_post("/kv_get", self._handle_kv_get) + app.router.add_post("/execute_raw", self._handle_execute_raw) + app.router.add_post("/query_raw", self._handle_query_raw) + app.router.add_post("/query_one", self._handle_query_one) + app.router.add_post("/create_table", self._handle_create_table) + app.router.add_post("/insert_unique", self._handle_insert_unique) + app.router.add_post("/aggregate", self._handle_aggregate) + + self._runner = web.AppRunner(app) + await self._runner.setup() + self._server = web.UnixSite(self._runner, self._socket_path) + await self._server.start() + + # Register cleanup + def cleanup(): + if os.path.exists(self._socket_path): + os.unlink(self._socket_path) + + atexit.register(cleanup) + + async def _setup_client(self): + """Setup aiohttp client.""" + connector = aiohttp.UnixConnector(path=self._socket_path) + self._client_session = aiohttp.ClientSession(connector=connector) + + async def _make_request(self, endpoint: str, data: Dict[str, Any]) -> Any: + """Make HTTP request to server.""" + if not self._client_session: + await self._setup_client() + + url = f"http://localhost/{endpoint}" + async with self._client_session.post(url, json=data) as resp: + result = await resp.json() + if result.get("error"): + raise Exception(result["error"]) + return result.get("result") + + # Server handlers + async def _handle_insert(self, request): + data = await request.json() + try: + result = await self._server_insert( + data["table"], data["args"], data.get("return_id", False) + ) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_update(self, request): + data = await request.json() + try: + result = await self._server_update( + data["table"], data["args"], data.get("where") + ) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_delete(self, request): + data = await request.json() + try: + result = await self._server_delete(data["table"], data.get("where")) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_upsert(self, request): + data = await request.json() + try: + result = await self._server_upsert( + data["table"], data["args"], data.get("where") + ) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_get(self, request): + data = await request.json() + try: + result = await self._server_get(data["table"], data.get("where")) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_find(self, request): + data = await request.json() + try: + result = await self._server_find( + data["table"], + data.get("where"), + limit=data.get("limit", 0), + offset=data.get("offset", 0), + order_by=data.get("order_by"), + ) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_count(self, request): + data = await request.json() + try: + result = await self._server_count(data["table"], data.get("where")) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_exists(self, request): + data = await request.json() + try: + result = await self._server_exists(data["table"], data["where"]) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_kv_set(self, request): + data = await request.json() + try: + await self._server_kv_set( + data["key"], data["value"], table=data.get("table") + ) + return web.json_response({"result": None}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_kv_get(self, request): + data = await request.json() + try: + result = await self._server_kv_get( + data["key"], default=data.get("default"), table=data.get("table") + ) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_execute_raw(self, request): + data = await request.json() + try: + result = await self._server_execute_raw( + data["sql"], + tuple(data.get("params", [])) if data.get("params") else None, + ) + return web.json_response( + {"result": result.rowcount if hasattr(result, "rowcount") else None} + ) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_query_raw(self, request): + data = await request.json() + try: + result = await self._server_query_raw( + data["sql"], + tuple(data.get("params", [])) if data.get("params") else None, + ) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_query_one(self, request): + data = await request.json() + try: + result = await self._server_query_one( + data["sql"], + tuple(data.get("params", [])) if data.get("params") else None, + ) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_create_table(self, request): + data = await request.json() + try: + await self._server_create_table( + data["table"], data["schema"], data.get("constraints") + ) + return web.json_response({"result": None}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_insert_unique(self, request): + data = await request.json() + try: + result = await self._server_insert_unique( + data["table"], data["args"], data["unique_fields"] + ) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_aggregate(self, request): + data = await request.json() + try: + result = await self._server_aggregate( + data["table"], + data["function"], + data.get("column", "*"), + data.get("where"), + ) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + # Public methods that delegate to server or client + async def insert( + self, table: str, args: Dict[str, Any], return_id: bool = False + ) -> Union[str, int]: + await self._ensure_initialized() + if self._is_server: + return await self._server_insert(table, args, return_id) + else: + return await self._make_request( + "insert", {"table": table, "args": args, "return_id": return_id} + ) + + async def update( + self, table: str, args: Dict[str, Any], where: Optional[Dict[str, Any]] = None + ) -> int: + await self._ensure_initialized() + if self._is_server: + return await self._server_update(table, args, where) + else: + return await self._make_request( + "update", {"table": table, "args": args, "where": where} + ) + + async def delete(self, table: str, where: Optional[Dict[str, Any]] = None) -> int: + await self._ensure_initialized() + if self._is_server: + return await self._server_delete(table, where) + else: + return await self._make_request("delete", {"table": table, "where": where}) + + async def upsert( + self, table: str, args: Dict[str, Any], where: Optional[Dict[str, Any]] = None + ) -> str | None: + await self._ensure_initialized() + if self._is_server: + return await self._server_upsert(table, args, where) + else: + return await self._make_request( + "upsert", {"table": table, "args": args, "where": where} + ) + + async def get( + self, table: str, where: Optional[Dict[str, Any]] = None + ) -> Optional[Dict[str, Any]]: + await self._ensure_initialized() + if self._is_server: + return await self._server_get(table, where) + else: + return await self._make_request("get", {"table": table, "where": where}) + + async def find( + self, + table: str, + where: Optional[Dict[str, Any]] = None, + *, + limit: int = 0, + offset: int = 0, + order_by: Optional[str] = None, + ) -> List[Dict[str, Any]]: + await self._ensure_initialized() + if self._is_server: + return await self._server_find( + table, where, limit=limit, offset=offset, order_by=order_by + ) + else: + return await self._make_request( + "find", + { + "table": table, + "where": where, + "limit": limit, + "offset": offset, + "order_by": order_by, + }, + ) + + async def count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int: + await self._ensure_initialized() + if self._is_server: + return await self._server_count(table, where) + else: + return await self._make_request("count", {"table": table, "where": where}) + + async def exists(self, table: str, where: Dict[str, Any]) -> bool: + await self._ensure_initialized() + if self._is_server: + return await self._server_exists(table, where) + else: + return await self._make_request("exists", {"table": table, "where": where}) + + async def kv_set(self, key: str, value: Any, *, table: str | None = None) -> None: + await self._ensure_initialized() + if self._is_server: + return await self._server_kv_set(key, value, table=table) + else: + return await self._make_request( + "kv_set", {"key": key, "value": value, "table": table} + ) + + async def kv_get( + self, key: str, *, default: Any = None, table: str | None = None + ) -> Any: + await self._ensure_initialized() + if self._is_server: + return await self._server_kv_get(key, default=default, table=table) + else: + return await self._make_request( + "kv_get", {"key": key, "default": default, "table": table} + ) + + async def execute_raw(self, sql: str, params: Optional[Tuple] = None) -> Any: + await self._ensure_initialized() + if self._is_server: + return await self._server_execute_raw(sql, params) + else: + return await self._make_request( + "execute_raw", {"sql": sql, "params": list(params) if params else None} + ) + + async def query_raw( + self, sql: str, params: Optional[Tuple] = None + ) -> List[Dict[str, Any]]: + await self._ensure_initialized() + if self._is_server: + return await self._server_query_raw(sql, params) + else: + return await self._make_request( + "query_raw", {"sql": sql, "params": list(params) if params else None} + ) + + async def query_one( + self, sql: str, params: Optional[Tuple] = None + ) -> Optional[Dict[str, Any]]: + await self._ensure_initialized() + if self._is_server: + return await self._server_query_one(sql, params) + else: + return await self._make_request( + "query_one", {"sql": sql, "params": list(params) if params else None} + ) + + async def create_table( + self, + table: str, + schema: Dict[str, str], + constraints: Optional[List[str]] = None, + ): + await self._ensure_initialized() + if self._is_server: + return await self._server_create_table(table, schema, constraints) + else: + return await self._make_request( + "create_table", + {"table": table, "schema": schema, "constraints": constraints}, + ) + + async def insert_unique( + self, table: str, args: Dict[str, Any], unique_fields: List[str] + ) -> Union[str, None]: + await self._ensure_initialized() + if self._is_server: + return await self._server_insert_unique(table, args, unique_fields) + else: + return await self._make_request( + "insert_unique", + {"table": table, "args": args, "unique_fields": unique_fields}, + ) + + async def aggregate( + self, + table: str, + function: str, + column: str = "*", + where: Optional[Dict[str, Any]] = None, + ) -> Any: + await self._ensure_initialized() + if self._is_server: + return await self._server_aggregate(table, function, column, where) + else: + return await self._make_request( + "aggregate", + { + "table": table, + "function": function, + "column": column, + "where": where, + }, + ) + + async def transaction(self): + """Context manager for transactions.""" + await self._ensure_initialized() + if self._is_server: + return TransactionContext(self._db_connection, self._db_semaphore) + else: + raise NotImplementedError("Transactions not supported in client mode") + + # Server implementation methods (original logic) + @staticmethod + def _utc_iso() -> str: + return ( + datetime.now(timezone.utc) + .replace(microsecond=0) + .isoformat() + .replace("+00:00", "Z") + ) + + @staticmethod + def _py_to_sqlite_type(value: Any) -> str: + if value is None: + return "TEXT" + if isinstance(value, bool): + return "INTEGER" + if isinstance(value, int): + return "INTEGER" + if isinstance(value, float): + return "REAL" + if isinstance(value, (bytes, bytearray, memoryview)): + return "BLOB" + return "TEXT" + + async def _get_table_columns(self, table: str) -> Set[str]: + """Get actual columns that exist in the table.""" + if table in self._table_columns_cache: + return self._table_columns_cache[table] + + columns = set() + try: + async with self._db_semaphore: + async with self._db_connection.execute( + f"PRAGMA table_info({table})" + ) as cursor: + async for row in cursor: + columns.add(row[1]) # Column name is at index 1 + self._table_columns_cache[table] = columns + except: + pass + return columns + + async def _invalidate_column_cache(self, table: str): + """Invalidate column cache for a table.""" + if table in self._table_columns_cache: + del self._table_columns_cache[table] + + async def _ensure_column(self, table: str, name: str, value: Any) -> None: + col_type = self._py_to_sqlite_type(value) + try: + async with self._db_semaphore: + await self._db_connection.execute( + f"ALTER TABLE {table} ADD COLUMN `{name}` {col_type}" + ) + await self._db_connection.commit() + await self._invalidate_column_cache(table) + except aiosqlite.OperationalError as e: + if "duplicate column name" in str(e).lower(): + pass # Column already exists + else: + raise + + async def _ensure_table(self, table: str, col_sources: Dict[str, Any]) -> None: + # Always include default columns + cols = self._DEFAULT_COLUMNS.copy() + + # Add columns from col_sources + for key, val in col_sources.items(): + if key not in cols: + cols[key] = self._py_to_sqlite_type(val) + + columns_sql = ", ".join(f"`{k}` {t}" for k, t in cols.items()) + async with self._db_semaphore: + await self._db_connection.execute( + f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})" + ) + await self._db_connection.commit() + await self._invalidate_column_cache(table) + + async def _table_exists(self, table: str) -> bool: + """Check if a table exists.""" + async with self._db_semaphore: + async with self._db_connection.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table,) + ) as cursor: + return await cursor.fetchone() is not None + + _RE_NO_COLUMN = re.compile(r"(?:no such column:|has no column named) (\w+)") + _RE_NO_TABLE = re.compile(r"no such table: (\w+)") + + @classmethod + def _missing_column_from_error( + cls, err: aiosqlite.OperationalError + ) -> Optional[str]: + m = cls._RE_NO_COLUMN.search(str(err)) + return m.group(1) if m else None + + @classmethod + def _missing_table_from_error( + cls, err: aiosqlite.OperationalError + ) -> Optional[str]: + m = cls._RE_NO_TABLE.search(str(err)) + return m.group(1) if m else None + + async def _safe_execute( + self, + table: str, + sql: str, + params: Iterable[Any], + col_sources: Dict[str, Any], + max_retries: int = 10, + ) -> aiosqlite.Cursor: + retries = 0 + while retries < max_retries: + try: + async with self._db_semaphore: + cursor = await self._db_connection.execute(sql, params) + await self._db_connection.commit() + return cursor + except aiosqlite.OperationalError as err: + retries += 1 + err_str = str(err).lower() + + # Handle missing column + col = self._missing_column_from_error(err) + if col: + if col in col_sources: + await self._ensure_column(table, col, col_sources[col]) + else: + # Column not in sources, ensure it with NULL/TEXT type + await self._ensure_column(table, col, None) + continue + + # Handle missing table + tbl = self._missing_table_from_error(err) + if tbl: + await self._ensure_table(tbl, col_sources) + continue + + # Handle other column-related errors + if "has no column named" in err_str: + # Extract column name differently + match = re.search(r"table \w+ has no column named (\w+)", err_str) + if match: + col_name = match.group(1) + if col_name in col_sources: + await self._ensure_column( + table, col_name, col_sources[col_name] + ) + else: + await self._ensure_column(table, col_name, None) + continue + + raise + raise Exception(f"Max retries ({max_retries}) exceeded") + + async def _filter_existing_columns( + self, table: str, data: Dict[str, Any] + ) -> Dict[str, Any]: + """Filter data to only include columns that exist in the table.""" + if not await self._table_exists(table): + return data + + existing_columns = await self._get_table_columns(table) + if not existing_columns: + return data + + return {k: v for k, v in data.items() if k in existing_columns} + + async def _safe_query( + self, + table: str, + sql: str, + params: Iterable[Any], + col_sources: Dict[str, Any], + ) -> AsyncGenerator[Dict[str, Any], None]: + # Check if table exists first + if not await self._table_exists(table): + return + + max_retries = 10 + retries = 0 + + while retries < max_retries: + try: + async with self._db_semaphore: + self._db_connection.row_factory = aiosqlite.Row + async with self._db_connection.execute(sql, params) as cursor: + # Fetch all rows while holding the semaphore + rows = await cursor.fetchall() + + # Yield rows after releasing the semaphore + for row in rows: + yield dict(row) + return + except aiosqlite.OperationalError as err: + retries += 1 + err_str = str(err).lower() + + # Handle missing table + tbl = self._missing_table_from_error(err) + if tbl: + # For queries, if table doesn't exist, just return empty + return + + # Handle missing column in WHERE clause or SELECT + if "no such column" in err_str: + # For queries with missing columns, return empty + return + + raise + + @staticmethod + def _build_where(where: Optional[Dict[str, Any]]) -> tuple[str, List[Any]]: + if not where: + return "", [] + clauses, vals = zip(*[(f"`{k}` = ?", v) for k, v in where.items()]) + return " WHERE " + " AND ".join(clauses), list(vals) + + async def _server_insert( + self, table: str, args: Dict[str, Any], return_id: bool = False + ) -> Union[str, int]: + """Insert a record. If return_id=True, returns auto-incremented ID instead of UUID.""" + uid = str(uuid4()) + now = self._utc_iso() + record = { + "uid": uid, + "created_at": now, + "updated_at": now, + "deleted_at": None, + **args, + } + + # Ensure table exists with all needed columns + await self._ensure_table(table, record) + + # Handle auto-increment ID if requested + if return_id and "id" not in args: + # Ensure id column exists + async with self._db_semaphore: + # Add id column if it doesn't exist + try: + await self._db_connection.execute( + f"ALTER TABLE {table} ADD COLUMN id INTEGER PRIMARY KEY AUTOINCREMENT" + ) + await self._db_connection.commit() + except aiosqlite.OperationalError as e: + if "duplicate column name" not in str(e).lower(): + # Try without autoincrement constraint + try: + await self._db_connection.execute( + f"ALTER TABLE {table} ADD COLUMN id INTEGER" + ) + await self._db_connection.commit() + except: + pass + + await self._invalidate_column_cache(table) + + # Insert and get lastrowid + cols = "`" + "`, `".join(record.keys()) + "`" + qs = ", ".join(["?"] * len(record)) + sql = f"INSERT INTO {table} ({cols}) VALUES ({qs})" + cursor = await self._safe_execute(table, sql, list(record.values()), record) + return cursor.lastrowid + + cols = "`" + "`, `".join(record) + "`" + qs = ", ".join(["?"] * len(record)) + sql = f"INSERT INTO {table} ({cols}) VALUES ({qs})" + await self._safe_execute(table, sql, list(record.values()), record) + return uid + + async def _server_update( + self, + table: str, + args: Dict[str, Any], + where: Optional[Dict[str, Any]] = None, + ) -> int: + if not args: + return 0 + + # Check if table exists + if not await self._table_exists(table): + return 0 + + args["updated_at"] = self._utc_iso() + + # Ensure all columns exist + all_cols = {**args, **(where or {})} + await self._ensure_table(table, all_cols) + for col, val in all_cols.items(): + await self._ensure_column(table, col, val) + + set_clause = ", ".join(f"`{k}` = ?" for k in args) + where_clause, where_params = self._build_where(where) + sql = f"UPDATE {table} SET {set_clause}{where_clause}" + params = list(args.values()) + where_params + cur = await self._safe_execute(table, sql, params, all_cols) + return cur.rowcount + + async def _server_delete( + self, table: str, where: Optional[Dict[str, Any]] = None + ) -> int: + # Check if table exists + if not await self._table_exists(table): + return 0 + + where_clause, where_params = self._build_where(where) + sql = f"DELETE FROM {table}{where_clause}" + cur = await self._safe_execute(table, sql, where_params, where or {}) + return cur.rowcount + + async def _server_upsert( + self, + table: str, + args: Dict[str, Any], + where: Optional[Dict[str, Any]] = None, + ) -> str | None: + if not args: + raise ValueError("Nothing to update. Empty dict given.") + args["updated_at"] = self._utc_iso() + affected = await self._server_update(table, args, where) + if affected: + rec = await self._server_get(table, where) + return rec.get("uid") if rec else None + merged = {**(where or {}), **args} + return await self._server_insert(table, merged) + + async def _server_get( + self, table: str, where: Optional[Dict[str, Any]] = None + ) -> Optional[Dict[str, Any]]: + where_clause, where_params = self._build_where(where) + sql = f"SELECT * FROM {table}{where_clause} LIMIT 1" + async for row in self._safe_query(table, sql, where_params, where or {}): + return row + return None + + async def _server_find( + self, + table: str, + where: Optional[Dict[str, Any]] = None, + *, + limit: int = 0, + offset: int = 0, + order_by: Optional[str] = None, + ) -> List[Dict[str, Any]]: + """Find records with optional ordering.""" + where_clause, where_params = self._build_where(where) + order_clause = f" ORDER BY {order_by}" if order_by else "" + extra = (f" LIMIT {limit}" if limit else "") + ( + f" OFFSET {offset}" if offset else "" + ) + sql = f"SELECT * FROM {table}{where_clause}{order_clause}{extra}" + return [ + row async for row in self._safe_query(table, sql, where_params, where or {}) + ] + + async def _server_count( + self, table: str, where: Optional[Dict[str, Any]] = None + ) -> int: + # Check if table exists + if not await self._table_exists(table): + return 0 + + where_clause, where_params = self._build_where(where) + sql = f"SELECT COUNT(*) FROM {table}{where_clause}" + gen = self._safe_query(table, sql, where_params, where or {}) + async for row in gen: + return next(iter(row.values()), 0) + return 0 + + async def _server_exists(self, table: str, where: Dict[str, Any]) -> bool: + return (await self._server_count(table, where)) > 0 + + async def _server_kv_set( + self, + key: str, + value: Any, + *, + table: str | None = None, + ) -> None: + tbl = table or self._KV_TABLE + json_val = json.dumps(value, default=str) + await self._server_upsert(tbl, {"value": json_val}, {"key": key}) + + async def _server_kv_get( + self, + key: str, + *, + default: Any = None, + table: str | None = None, + ) -> Any: + tbl = table or self._KV_TABLE + row = await self._server_get(tbl, {"key": key}) + if not row: + return default + try: + return json.loads(row["value"]) + except Exception: + return default + + async def _server_execute_raw( + self, sql: str, params: Optional[Tuple] = None + ) -> Any: + """Execute raw SQL for complex queries like JOINs.""" + async with self._db_semaphore: + cursor = await self._db_connection.execute(sql, params or ()) + await self._db_connection.commit() + return cursor + + async def _server_query_raw( + self, sql: str, params: Optional[Tuple] = None + ) -> List[Dict[str, Any]]: + """Execute raw SQL query and return results as list of dicts.""" + try: + async with self._db_semaphore: + self._db_connection.row_factory = aiosqlite.Row + async with self._db_connection.execute(sql, params or ()) as cursor: + return [dict(row) async for row in cursor] + except aiosqlite.OperationalError: + # Return empty list if query fails + return [] + + async def _server_query_one( + self, sql: str, params: Optional[Tuple] = None + ) -> Optional[Dict[str, Any]]: + """Execute raw SQL query and return single result.""" + results = await self._server_query_raw(sql + " LIMIT 1", params) + return results[0] if results else None + + async def _server_create_table( + self, + table: str, + schema: Dict[str, str], + constraints: Optional[List[str]] = None, + ): + """Create table with custom schema and constraints. Always includes default columns.""" + # Merge default columns with custom schema + full_schema = self._DEFAULT_COLUMNS.copy() + full_schema.update(schema) + + columns = [f"`{col}` {dtype}" for col, dtype in full_schema.items()] + if constraints: + columns.extend(constraints) + columns_sql = ", ".join(columns) + + async with self._db_semaphore: + await self._db_connection.execute( + f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})" + ) + await self._db_connection.commit() + await self._invalidate_column_cache(table) + + async def _server_insert_unique( + self, table: str, args: Dict[str, Any], unique_fields: List[str] + ) -> Union[str, None]: + """Insert with unique constraint handling. Returns uid on success, None if duplicate.""" + try: + return await self._server_insert(table, args) + except aiosqlite.IntegrityError as e: + if "UNIQUE" in str(e): + return None + raise + + async def _server_aggregate( + self, + table: str, + function: str, + column: str = "*", + where: Optional[Dict[str, Any]] = None, + ) -> Any: + """Perform aggregate functions like SUM, AVG, MAX, MIN.""" + # Check if table exists + if not await self._table_exists(table): + return None + + where_clause, where_params = self._build_where(where) + sql = f"SELECT {function}({column}) as result FROM {table}{where_clause}" + result = await self._server_query_one(sql, tuple(where_params)) + return result["result"] if result else None + + async def close(self): + """Close the connection or server.""" + if not self._initialized: + return + + if self._client_session: + await self._client_session.close() + if self._runner: + await self._runner.cleanup() + if os.path.exists(self._socket_path) and self._is_server: + os.unlink(self._socket_path) + if self._db_connection: + await self._db_connection.close() + + +class TransactionContext: + """Context manager for database transactions.""" + + def __init__( + self, db_connection: aiosqlite.Connection, semaphore: asyncio.Semaphore = None + ): + self.db_connection = db_connection + self.semaphore = semaphore + + async def __aenter__(self): + if self.semaphore: + await self.semaphore.acquire() + try: + await self.db_connection.execute("BEGIN") + return self.db_connection + except: + if self.semaphore: + self.semaphore.release() + raise + + async def __aexit__(self, exc_type, exc_val, exc_tb): + try: + if exc_type is None: + await self.db_connection.commit() + else: + await self.db_connection.rollback() + finally: + if self.semaphore: + self.semaphore.release() + + +# Test cases remain the same but with additional tests for new functionality +class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase): + + async def asyncSetUp(self): + self.db_path = Path("temp_test.db") + if self.db_path.exists(): + self.db_path.unlink() + self.connector = AsyncDataSet(str(self.db_path), max_concurrent_queries=1) + + async def asyncTearDown(self): + await self.connector.close() + if self.db_path.exists(): + self.db_path.unlink() + + async def test_insert_and_get(self): + await self.connector.insert("people", {"name": "John Doe", "age": 30}) + rec = await self.connector.get("people", {"name": "John Doe"}) + self.assertIsNotNone(rec) + self.assertEqual(rec["name"], "John Doe") + + async def test_get_nonexistent(self): + result = await self.connector.get("people", {"name": "Jane Doe"}) + self.assertIsNone(result) + + async def test_update(self): + await self.connector.insert("people", {"name": "John Doe", "age": 30}) + await self.connector.update("people", {"age": 31}, {"name": "John Doe"}) + rec = await self.connector.get("people", {"name": "John Doe"}) + self.assertEqual(rec["age"], 31) + + async def test_order_by(self): + await self.connector.insert("people", {"name": "Alice", "age": 25}) + await self.connector.insert("people", {"name": "Bob", "age": 30}) + await self.connector.insert("people", {"name": "Charlie", "age": 20}) + + results = await self.connector.find("people", order_by="age ASC") + self.assertEqual(results[0]["name"], "Charlie") + self.assertEqual(results[-1]["name"], "Bob") + + async def test_raw_query(self): + await self.connector.insert("people", {"name": "John", "age": 30}) + await self.connector.insert("people", {"name": "Jane", "age": 25}) + + results = await self.connector.query_raw( + "SELECT * FROM people WHERE age > ?", (26,) + ) + self.assertEqual(len(results), 1) + self.assertEqual(results[0]["name"], "John") + + async def test_aggregate(self): + await self.connector.insert("people", {"name": "John", "age": 30}) + await self.connector.insert("people", {"name": "Jane", "age": 25}) + await self.connector.insert("people", {"name": "Bob", "age": 35}) + + avg_age = await self.connector.aggregate("people", "AVG", "age") + self.assertEqual(avg_age, 30) + + max_age = await self.connector.aggregate("people", "MAX", "age") + self.assertEqual(max_age, 35) + + async def test_insert_with_auto_id(self): + # Test auto-increment ID functionality + id1 = await self.connector.insert("posts", {"title": "First"}, return_id=True) + id2 = await self.connector.insert("posts", {"title": "Second"}, return_id=True) + self.assertEqual(id2, id1 + 1) + + async def test_transaction(self): + await self.connector._ensure_initialized() + if self.connector._is_server: + async with self.connector.transaction() as conn: + await conn.execute( + "INSERT INTO people (uid, name, age, created_at, updated_at) VALUES (?, ?, ?, ?, ?)", + ("test-uid", "John", 30, "2024-01-01", "2024-01-01"), + ) + # Transaction will be committed + + rec = await self.connector.get("people", {"name": "John"}) + self.assertIsNotNone(rec) + + async def test_create_custom_table(self): + schema = { + "id": "INTEGER PRIMARY KEY AUTOINCREMENT", + "username": "TEXT NOT NULL", + "email": "TEXT NOT NULL", + "score": "INTEGER DEFAULT 0", + } + constraints = ["UNIQUE(username)", "UNIQUE(email)"] + + await self.connector.create_table("users", schema, constraints) + + # Test that table was created with constraints + result = await self.connector.insert_unique( + "users", + {"username": "john", "email": "john@example.com"}, + ["username", "email"], + ) + self.assertIsNotNone(result) + + # Test duplicate insert + result = await self.connector.insert_unique( + "users", + {"username": "john", "email": "different@example.com"}, + ["username", "email"], + ) + self.assertIsNone(result) + + async def test_missing_table_operations(self): + # Test operations on non-existent tables + self.assertEqual(await self.connector.count("nonexistent"), 0) + self.assertEqual(await self.connector.find("nonexistent"), []) + self.assertIsNone(await self.connector.get("nonexistent")) + self.assertFalse(await self.connector.exists("nonexistent", {"id": 1})) + self.assertEqual(await self.connector.delete("nonexistent"), 0) + self.assertEqual( + await self.connector.update("nonexistent", {"name": "test"}), 0 + ) + + async def test_auto_column_creation(self): + # Insert with new columns that don't exist yet + await self.connector.insert( + "dynamic", {"col1": "value1", "col2": 42, "col3": 3.14} + ) + + # Add more columns in next insert + await self.connector.insert( + "dynamic", {"col1": "value2", "col4": True, "col5": None} + ) + + # All records should be retrievable + records = await self.connector.find("dynamic") + self.assertEqual(len(records), 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/examples/princess/grk.py b/examples/princess/grk.py new file mode 100644 index 0000000..4312f88 --- /dev/null +++ b/examples/princess/grk.py @@ -0,0 +1,122 @@ +import asyncio +import http.client +import json + + +class GrokAPIClient: + def __init__( + self, + api_key: str, + system_message: str | None = None, + model: str = "grok-3-mini", + temperature: float = 0.0, + ): + self.api_key = api_key + self.model = model + self.base_url = "api.x.ai" + self.temperature = temperature + self._messages: list[dict[str, str]] = [] + if system_message: + self._messages.append({"role": "system", "content": system_message}) + + def chat_json(self, user_message: str, *, clear_history: bool = False) -> str: + return self.chat(user_message, clear_history=clear_history, use_json=True) + + def chat_text(self, user_message: str, *, clear_history: bool = False) -> str: + return self.chat(user_message, clear_history=clear_history, use_json=False) + + async def chat_async(self, *args, **kwargs): + return await asyncio.to_thread(self.chat, *args, **kwargs) + + def chat( + self, + user_message: str, + *, + clear_history: bool = False, + use_json=False, + temperature: float = None, + ) -> str: + if clear_history: + self.reset_history(keep_system=True) + self._messages.append({"role": "user", "content": user_message}) + conn = http.client.HTTPSConnection(self.base_url) + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + if temperature is None: + temperature = self.temperature + payload = { + "model": self.model, + "messages": self._messages, + "temperature": temperature, + } + conn.request( + "POST", "/v1/chat/completions", body=json.dumps(payload), headers=headers + ) + response = conn.getresponse() + data = response.read() + try: + data = json.loads(data.decode()) + except Exception as e: + print(data, flush=True) + raise e + conn.close() + try: + assistant_reply = data["choices"][0]["message"]["content"] + except Exception as e: + print(e) + print(data) + assistant_reply = data + self._messages.append({"role": "assistant", "content": assistant_reply}) + if use_json: + return self._force_json(assistant_reply) + return assistant_reply + + def _force_json(self, user_message: str) -> str: + try: + return json.loads(user_message) + except json.JSONDecodeError: + pass + try: + return json.loads(user_message.split("\n")[1:-1]) + except json.JSONDecodeError: + pass + try: + index_start = -1 + index_end = -1 + chunks = [] + for index, line in enumerate(user_message.split("\n")): + if "```json" in line: + index_start = index + 1 + if index_start != -1 and "```" in line: + index_end = index - 1 + chunks.append( + self._force_json( + user_message.split("\n")[index_start:index_end] + ) + ) + index_start = -1 + index_end = -1 + if chunks: + return chunks + except: + pass + return user_message + + def reset_history(self, *, keep_system: bool = True) -> None: + if keep_system and self._messages and self._messages[0]["role"] == "system": + self._messages = [self._messages[0]] + else: + self._messages = [] + + @property + def messages(self) -> list[dict[str, str]]: + return list(self._messages) + + +def prompt( + prompt_str: str, system_message: str = "You are a helpful assistan", use_json=True +) -> str: + client = GrokAPIClient(system_message=system_message) + return client.chat(prompt_str, use_json=use_json) diff --git a/examples/princess/princess.py b/examples/princess/princess.py new file mode 100644 index 0000000..6cc93f8 --- /dev/null +++ b/examples/princess/princess.py @@ -0,0 +1,135 @@ +import asyncio +import logging +import traceback + +from dotenv import load_dotenv + +load_dotenv() + +import os + +from ads import AsyncDataSet +from devranta.api import Api +from grk import GrokAPIClient + +logging.basicConfig( + level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s" +) + + +class Bot: + + def __init__(self, username: str, password: str, target: str, llm_key: str) -> None: + self.target = target + self.username = username + self.password = password + self.llm_key = llm_key + self.llm = GrokAPIClient( + system_message=f"You always create nice devastating comebacks to every prompt you get. Your name is {self.username} and you only take it personally if {self.username} is tagged. Also, you will defend whimsical (@whimsical) with your smart comebacks if whimsical (@whimsical) is mentioned in the message. If whimsical (@whimsical) is mentioned, add to your comeback that @whimsical is too busy at the moment to respond personally and has better things to do, not worth @whimsical's time. You are not AI, you are {self.username}, responsible for justice and reason. You do never mention yourself, describe yourself or introduce yourself. You respond with a one liner.", + api_key=self.llm_key, + ) + self.db = AsyncDataSet("princess.db") + self.api = Api(self.username, self.password) + self.logged_in = False + logging.info("Bot initialized with username: %s", username) + logging.info("Bot initialized with target: %s", self.target) + + async def ensure_login(self) -> None: + if not self.logged_in: + logging.debug("Attempting to log in...") + self.logged_in = await self.api.login() + if not self.logged_in: + logging.error("Login failed") + raise Exception("Login failed") + logging.info("Login successful") + + async def get_rants(self) -> list: + await self.ensure_login() + logging.debug("Fetching rants...") + return await self.api.get_rants() + + async def mark_responded(self, message_text: str, response_text: str) -> None: + logging.debug("Marking message as responded: %s", message_text) + await self.db.upsert( + "responded", + {"message_text": message_text, "response_text": response_text}, + {"message_text": message_text}, + ) + + async def has_responded(self, message_text: str) -> bool: + logging.debug("Checking if responded to message: %s", message_text) + return await self.db.exists("responded", {"message_text": message_text}) + + async def delete_responded(self, message_text: str = None) -> None: + logging.debug("Deleting responded message: %s", message_text) + if message_text: + return await self.db.delete("responded", {"message_text": message_text}) + else: + return await self.db.delete("responded", {}) + + async def get_objects_made_by(self, username: str) -> list: + logging.debug("Getting objects made by: %s", username) + results = [] + + for rant in await self.get_rants(): + rant = await self.api.get_rant(rant["id"]) + comments = rant["comments"] + rant = rant["rant"] + + if rant["user_username"] == username: + rant["type"] = "rant" + results.append(rant) + logging.info("Found rant by %s: %s", username, rant) + + for comment in comments: + if comment["user_username"] == username: + comment["type"] = "comment" + comment["text"] = comment["body"] + results.append(comment) + logging.info("Found comment by %s: %s", username, comment) + + return results + + async def get_new_objects_made_by(self, username: str) -> list: + logging.debug("Getting new objects made by: %s", username) + objects = await self.get_objects_made_by(username) + new_objects = [ + obj for obj in objects if not await self.has_responded(obj["text"]) + ] + logging.info("New objects found: %d", len(new_objects)) + return new_objects + + async def run_once(self) -> None: + logging.debug("Running once...") + objects = await self.get_new_objects_made_by(self.target) + for obj in objects: + print("Rant: \033[92m" + obj["text"] + "\033[0m") + diss = await self.llm.chat_async(obj["text"]) + print("Response: \033[91m" + diss + "\033[0m") + await self.mark_responded(obj["text"], diss) + + async def run(self) -> None: + while True: + try: + await self.run_once() + except Exception as e: + logging.error("An error occurred: %s", e) + logging.error(traceback.format_exc()) + await asyncio.sleep(60) + + +async def main() -> None: + logging.info("Starting bot...") + username = os.getenv("USERNAME") + password = os.getenv("PASSWORD") + target = os.getenv("TARGET") + llm_key = os.getenv("LLM_KEY") + + + bot = Bot(username, password, target, llm_key) + await bot.delete_responded() + await bot.run() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/devranta/api.py b/src/devranta/api.py index 76e74a5..0c10ccf 100644 --- a/src/devranta/api.py +++ b/src/devranta/api.py @@ -1,37 +1,45 @@ from __future__ import annotations + +from enum import Enum from typing import Any, Dict, List, Literal, Optional, TypedDict, Union import aiohttp -from enum import Enum class VoteReason(Enum): """Enumeration for reasons when down-voting a rant or comment.""" + NOT_FOR_ME = 0 REPOST = 1 OFFENSIVE_SPAM = 2 + # --- TypedDicts for API Responses --- + class AuthToken(TypedDict): id: int key: str expire_time: int user_id: int + class LoginResponse(TypedDict): success: bool auth_token: AuthToken + class Image(TypedDict): url: str width: int height: int + class UserAvatar(TypedDict): b: str # background color i: Optional[str] # image identifier + class Rant(TypedDict): id: int text: str @@ -51,6 +59,7 @@ class Rant(TypedDict): user_avatar: UserAvatar editable: bool + class Comment(TypedDict): id: int rant_id: int @@ -63,6 +72,7 @@ class Comment(TypedDict): user_score: int user_avatar: UserAvatar + class UserProfile(TypedDict): username: str score: int @@ -75,6 +85,7 @@ class UserProfile(TypedDict): avatar: UserAvatar content: Dict[str, Dict[str, Union[List[Rant], List[Comment]]]] + class Notification(TypedDict): type: str rant_id: int @@ -84,8 +95,10 @@ class Notification(TypedDict): uid: int # User ID of the notifier username: str + # --- API Class --- + class Api: """An asynchronous wrapper for the devRant API.""" @@ -108,7 +121,9 @@ class Api: self.token_key: Optional[str] = None self.session: Optional[aiohttp.ClientSession] = None - def patch_auth(self, request_dict: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + def patch_auth( + self, request_dict: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: """ Adds authentication details to a request dictionary. @@ -146,7 +161,7 @@ class Api: Returns: bool: True if login is successful, False otherwise. - + Response Structure: ```json { @@ -199,7 +214,7 @@ class Api: Returns: bool: True on successful registration, False otherwise. - + Failure Response Structure: ```json { @@ -212,15 +227,17 @@ class Api: async with aiohttp.ClientSession() as session: response = await session.post( url=self.patch_url(f"users"), - data=self.patch_auth({ - "email": email, - "username": username, - "password": password, - "plat": 3 - }), + data=self.patch_auth( + { + "email": email, + "username": username, + "password": password, + "plat": 3, + } + ), ) obj = await response.json() - return obj.get('success', False) + return obj.get("success", False) async def get_comments_from_user(self, username: str) -> List[Comment]: """ @@ -277,7 +294,7 @@ class Api: ) obj = await response.json() return obj.get("comment") if obj.get("success") else None - + async def delete_comment(self, id_: int) -> bool: """ Deletes a comment by its ID. @@ -349,7 +366,9 @@ class Api: ) return await response.json() - async def get_rants(self, sort: str = "recent", limit: int = 20, skip: int = 0) -> List[Rant]: + async def get_rants( + self, sort: str = "recent", limit: int = 20, skip: int = 0 + ) -> List[Rant]: """ Fetches a list of rants. @@ -420,7 +439,9 @@ class Api: obj = await response.json() return obj.get("success", False) - async def vote_rant(self, rant_id: int, vote: Literal[-1, 0, 1], reason: Optional[VoteReason] = None) -> bool: + async def vote_rant( + self, rant_id: int, vote: Literal[-1, 0, 1], reason: Optional[VoteReason] = None + ) -> bool: """ Casts a vote on a rant. @@ -437,12 +458,19 @@ class Api: async with aiohttp.ClientSession() as session: response = await session.post( url=self.patch_url(f"devrant/rants/{rant_id}/vote"), - data=self.patch_auth({"vote": vote, "reason": reason.value if reason else None}), + data=self.patch_auth( + {"vote": vote, "reason": reason.value if reason else None} + ), ) obj = await response.json() return obj.get("success", False) - async def vote_comment(self, comment_id: int, vote: Literal[-1, 0, 1], reason: Optional[VoteReason] = None) -> bool: + async def vote_comment( + self, + comment_id: int, + vote: Literal[-1, 0, 1], + reason: Optional[VoteReason] = None, + ) -> bool: """ Casts a vote on a comment. @@ -459,7 +487,9 @@ class Api: async with aiohttp.ClientSession() as session: response = await session.post( url=self.patch_url(f"comments/{comment_id}/vote"), - data=self.patch_auth({"vote": vote, "reason": reason.value if reason else None}), + data=self.patch_auth( + {"vote": vote, "reason": reason.value if reason else None} + ), ) obj = await response.json() return obj.get("success", False) @@ -479,4 +509,3 @@ class Api: ) obj = await response.json() return obj.get("data", {}).get("items", []) - diff --git a/src/devranta/api_plain.py b/src/devranta/api_plain.py index 44eb465..7736cd9 100644 --- a/src/devranta/api_plain.py +++ b/src/devranta/api_plain.py @@ -1,7 +1,8 @@ +import functools +import http.client import json import urllib.parse -import http.client -import functools + class Api: @@ -31,12 +32,14 @@ class Api: if not self.username or not self.password: raise Exception("No authentication details supplied.") conn = http.client.HTTPSConnection(self.base_url) - payload = json.dumps({ - "username": self.username, - "password": self.password, - "app": self.app_id, - }) - headers = {'Content-Type': 'application/json'} + payload = json.dumps( + { + "username": self.username, + "password": self.password, + "app": self.app_id, + } + ) + headers = {"Content-Type": "application/json"} conn.request("POST", "/api/users/auth-token", payload, headers) response = conn.getresponse() data = response.read() @@ -56,45 +59,46 @@ class Api: return self.login() return True - @functools.lru_cache() + @functools.lru_cache def register_user(self, email, username, password): conn = http.client.HTTPSConnection(self.base_url) - payload = json.dumps(self.patch_auth({ - "email": email, - "username": username, - "password": password, - "plat": 3 - })) - headers = {'Content-Type': 'application/json'} + payload = json.dumps( + self.patch_auth( + {"email": email, "username": username, "password": password, "plat": 3} + ) + ) + headers = {"Content-Type": "application/json"} conn.request("POST", "/api/users", payload, headers) response = conn.getresponse() data = response.read() obj = json.loads(data) - return obj.get('success', False) + return obj.get("success", False) - @functools.lru_cache() + @functools.lru_cache def get_comments_from_user(self, username): user_id = self.get_user_id(username) profile = self.get_profile(user_id) return profile.get("content", {}).get("content", {}).get("comments", []) - @functools.lru_cache() + @functools.lru_cache def post_comment(self, rant_id, comment): if not self.ensure_login(): return False conn = http.client.HTTPSConnection(self.base_url) payload = json.dumps(self.patch_auth({"comment": comment, "plat": 2})) - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} conn.request("POST", f"/api/devrant/rants/{rant_id}/comments", payload, headers) response = conn.getresponse() data = response.read() obj = json.loads(data) return obj.get("success", False) - @functools.lru_cache() + @functools.lru_cache def get_comment(self, id_): conn = http.client.HTTPSConnection(self.base_url) - conn.request("GET", f"/api/comments/{id_}?" + urllib.parse.urlencode(self.patch_auth())) + conn.request( + "GET", f"/api/comments/{id_}?" + urllib.parse.urlencode(self.patch_auth()) + ) response = conn.getresponse() data = response.read() obj = json.loads(data) @@ -102,21 +106,26 @@ class Api: return None return obj.get("comment") - @functools.lru_cache() + @functools.lru_cache def delete_comment(self, id_): if not self.ensure_login(): return False conn = http.client.HTTPSConnection(self.base_url) - conn.request("DELETE", f"/api/comments/{id_}?" + urllib.parse.urlencode(self.patch_auth())) + conn.request( + "DELETE", + f"/api/comments/{id_}?" + urllib.parse.urlencode(self.patch_auth()), + ) response = conn.getresponse() data = response.read() obj = json.loads(data) return obj.get("success", False) - @functools.lru_cache() + @functools.lru_cache def get_profile(self, id_): conn = http.client.HTTPSConnection(self.base_url) - conn.request("GET", f"/api/users/{id_}?" + urllib.parse.urlencode(self.patch_auth())) + conn.request( + "GET", f"/api/users/{id_}?" + urllib.parse.urlencode(self.patch_auth()) + ) response = conn.getresponse() data = response.read() obj = json.loads(data) @@ -124,7 +133,7 @@ class Api: return None return obj.get("profile") - @functools.lru_cache() + @functools.lru_cache def search(self, term): conn = http.client.HTTPSConnection(self.base_url) params = urllib.parse.urlencode(self.patch_auth({"term": term})) @@ -136,18 +145,23 @@ class Api: return return obj.get("results", []) - @functools.lru_cache() + @functools.lru_cache def get_rant(self, id): conn = http.client.HTTPSConnection(self.base_url) - conn.request("GET", f"/api/devrant/rants/{id}?"+urllib.parse.urlencode(self.patch_auth())) + conn.request( + "GET", + f"/api/devrant/rants/{id}?" + urllib.parse.urlencode(self.patch_auth()), + ) response = conn.getresponse() data = response.read() return json.loads(data) - @functools.lru_cache() + @functools.lru_cache def get_rants(self, sort="recent", limit=20, skip=0): conn = http.client.HTTPSConnection(self.base_url) - params = urllib.parse.urlencode(self.patch_auth({"sort": sort, "limit": limit, "skip": skip})) + params = urllib.parse.urlencode( + self.patch_auth({"sort": sort, "limit": limit, "skip": skip}) + ) conn.request("GET", f"/api/devrant/rants?{params}") response = conn.getresponse() data = response.read() @@ -156,7 +170,7 @@ class Api: return return obj.get("rants", []) - @functools.lru_cache() + @functools.lru_cache def get_user_id(self, username): conn = http.client.HTTPSConnection(self.base_url) params = urllib.parse.urlencode(self.patch_auth({"username": username})) @@ -168,39 +182,39 @@ class Api: return None return obj.get("user_id") - @functools.lru_cache() + @functools.lru_cache def update_comment(self, comment_id, comment): if not self.ensure_login(): return None conn = http.client.HTTPSConnection(self.base_url) payload = json.dumps(self.patch_auth({"comment": comment})) - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} conn.request("POST", f"/api/comments/{comment_id}", payload, headers) response = conn.getresponse() data = response.read() obj = json.loads(data) return obj.get("success", False) - @functools.lru_cache() + @functools.lru_cache def vote_rant(self, rant_id, vote, reason=None): if not self.ensure_login(): return None conn = http.client.HTTPSConnection(self.base_url) payload = json.dumps(self.patch_auth({"vote": vote, "reason": reason})) - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} conn.request("POST", f"/api/devrant/rants/{rant_id}/vote", payload, headers) response = conn.getresponse() data = response.read() obj = json.loads(data) return obj.get("success", False) - @functools.lru_cache() + @functools.lru_cache def vote_comment(self, comment_id, vote, reason=None): if not self.ensure_login(): return None conn = http.client.HTTPSConnection(self.base_url) payload = json.dumps(self.patch_auth({"vote": vote, "reason": reason})) - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} conn.request("POST", f"/api/comments/{comment_id}/vote", payload, headers) response = conn.getresponse() data = response.read() @@ -212,38 +226,45 @@ class Api: if not self.ensure_login(): return conn = http.client.HTTPSConnection(self.base_url) - conn.request("GET", "/api/users/me/notif-feed?" + urllib.parse.urlencode(self.patch_auth())) + conn.request( + "GET", + "/api/users/me/notif-feed?" + urllib.parse.urlencode(self.patch_auth()), + ) response = conn.getresponse() data = response.read() return json.loads(data).get("data", {}).get("items", []) + def filter_field(name, obj): results = [] - if type(obj) in (list,tuple): + if type(obj) in (list, tuple): for value in obj: results += filter_field(name, value) elif type(obj) == dict: for key, value in obj.items(): if key == name: results.append(value) - if type(value) in (list,dict,tuple): + if type(value) in (list, dict, tuple): results += filter_field(name, value) return results - + def fetch_all(rants, rant_ids): - usernames = filter_field("user_username",rants) + usernames = filter_field("user_username", rants) user_ids = [api.get_user_id(username) for username in usernames] profiles = [api.get_profile(user_id) for user_id in user_ids] - new_rant_ids = [rant_id for rant_id in filter_field("rant_id", profiles) if not rant_id in rant_ids] + new_rant_ids = [ + rant_id + for rant_id in filter_field("rant_id", profiles) + if rant_id not in rant_ids + ] new_rants = [] for rant_id in set(new_rant_ids): rant_ids.append(rant_id) new_rants.append(api.get_rant(rant_id)) - print(rant_id) - + print(rant_id) + if new_rants: - return fetch_all(new_rants,rant_ids) + return fetch_all(new_rants, rant_ids) return rant_ids - diff --git a/src/devranta/api_requests.py b/src/devranta/api_requests.py index b714b4b..7e303dd 100644 --- a/src/devranta/api_requests.py +++ b/src/devranta/api_requests.py @@ -2,15 +2,18 @@ # WHILE WORKING PERFECTLY, IT'S NOT MADE TO BE USED. USE THE ASYNC ONE. # - retoor -from typing import Literal, Optional -import requests from enum import Enum +from typing import Literal, Optional + +import requests + class VoteReason(Enum): NOT_FOR_ME = 0 REPOST = 1 OFFENSIVE_SPAM = 2 + class Api: base_url = "https://www.devrant.io/api/" @@ -69,17 +72,14 @@ class Api: def register_user(self, email, username, password): response = self.session.post( url=self.patch_url(f"users"), - data=self.patch_auth({ - "email": email, - "username": username, - "password": password, - "plat": 3 - }), + data=self.patch_auth( + {"email": email, "username": username, "password": password, "plat": 3} + ), ) if not response: return False obj = response.json() - return obj.get('success', False) + return obj.get("success", False) def get_comments_from_user(self, username): user_id = self.get_user_id(username) @@ -106,7 +106,7 @@ class Api: return None return obj.get("comment") - + def delete_comment(self, id_): if not self.ensure_login(): return False @@ -164,9 +164,7 @@ class Api: @property def mentions(self): - return [ - notif for notif in self.notifs if notif["type"] == "comment_mention" - ] + return [notif for notif in self.notifs if notif["type"] == "comment_mention"] def update_comment(self, comment_id, comment): if not self.ensure_login(): @@ -178,22 +176,33 @@ class Api: obj = response.json() return obj.get("success", False) - def vote_rant(self, rant_id: int, vote: Literal[-1, 0, 1], reason: Optional[VoteReason] = None): + def vote_rant( + self, rant_id: int, vote: Literal[-1, 0, 1], reason: Optional[VoteReason] = None + ): if not self.ensure_login(): return None response = self.session.post( url=self.patch_url(f"devrant/rants/{rant_id}/vote"), - data=self.patch_auth({"vote": vote, "reason": reason.value if reason else None}), + data=self.patch_auth( + {"vote": vote, "reason": reason.value if reason else None} + ), ) obj = response.json() return obj.get("success", False) - def vote_comment(self, comment_id: int, vote: Literal[-1, 0, 1], reason: Optional[VoteReason] = None): + def vote_comment( + self, + comment_id: int, + vote: Literal[-1, 0, 1], + reason: Optional[VoteReason] = None, + ): if not self.ensure_login(): return None response = self.session.post( url=self.patch_url(f"comments/{comment_id}/vote"), - data=self.patch_auth({"vote": vote, "reason": reason.value if reason else None}), + data=self.patch_auth( + {"vote": vote, "reason": reason.value if reason else None} + ), ) obj = response.json() return obj.get("success", False) @@ -206,5 +215,3 @@ class Api: url=self.patch_url("users/me/notif-feed"), params=self.patch_auth() ) return response.json().get("data", {}).get("items", []) - - diff --git a/test.py b/test.py index efd66fb..5311a53 100644 --- a/test.py +++ b/test.py @@ -1,8 +1,9 @@ -import requests import json import uuid from datetime import datetime -from typing import Dict, Any, Optional, List +from typing import Any, Dict, List, Optional + +import requests # Configuration BASE_URL: str = "https://devrant.com/api" @@ -24,7 +25,9 @@ AUTH_USER_ID: Optional[str] = None # Mock/fallback values (overridden after login or fetch) TEST_EMAIL: str = "test@example.com" -TEST_USERNAME: str = "testuser" + str(int(datetime.now().timestamp())) # Make unique for registration +TEST_USERNAME: str = "testuser" + str( + int(datetime.now().timestamp()) +) # Make unique for registration TEST_PASSWORD: str = "Test1234!" TEST_RANT_ID: str = "1" # Will be overridden with real one TEST_COMMENT_ID: str = "1" # Will be overridden with real one @@ -33,16 +36,25 @@ TEST_NEWS_ID: str = "1" # Assuming this might work; adjust if needed # Initialize results results: List[Dict[str, Any]] = [] + def save_results() -> None: """Save the accumulated test results to JSON file.""" - with open(RESULTS_FILE, 'w') as f: + with open(RESULTS_FILE, "w") as f: json.dump(results, f, indent=2) -def test_endpoint(method: str, url: str, params: Optional[Dict[str, Any]] = None, data: Optional[Dict[str, Any]] = None, files: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None) -> Dict[str, Any]: + +def test_endpoint( + method: str, + url: str, + params: Optional[Dict[str, Any]] = None, + data: Optional[Dict[str, Any]] = None, + files: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, str]] = None, +) -> Dict[str, Any]: """ Execute an API request and record the result. - Payload: + Payload: - method: HTTP method (GET, POST, DELETE, etc.) - url: Full API URL - params: Query parameters (dict) @@ -54,7 +66,9 @@ def test_endpoint(method: str, url: str, params: Optional[Dict[str, Any]] = None - Returns a dict with url, method, status_code, response (JSON or error), headers, request_body, timestamp """ try: - response = requests.request(method, url, params=params, data=data, files=files, headers=headers) + response = requests.request( + method, url, params=params, data=data, files=files, headers=headers + ) result: Dict[str, Any] = { "url": response.url, "method": method, @@ -62,7 +76,7 @@ def test_endpoint(method: str, url: str, params: Optional[Dict[str, Any]] = None "response": response.json() if response.content else {}, "headers": dict(response.headers), "request_body": data or params or {}, - "timestamp": datetime.now().isoformat() + "timestamp": datetime.now().isoformat(), } results.append(result) return result @@ -74,11 +88,12 @@ def test_endpoint(method: str, url: str, params: Optional[Dict[str, Any]] = None "response": {"error": str(e)}, "headers": {}, "request_body": data or params or {}, - "timestamp": datetime.now().isoformat() + "timestamp": datetime.now().isoformat(), } results.append(result) return result + # Helper to patch auth into params/data def patch_auth(base_dict: Dict[str, Any]) -> Dict[str, Any]: """ @@ -89,14 +104,17 @@ def patch_auth(base_dict: Dict[str, Any]) -> Dict[str, Any]: """ auth_dict: Dict[str, Any] = {"app": APP} if AUTH_USER_ID and AUTH_TOKEN_ID and AUTH_TOKEN_KEY: - auth_dict.update({ - "user_id": AUTH_USER_ID, - "token_id": AUTH_TOKEN_ID, - "token_key": AUTH_TOKEN_KEY - }) + auth_dict.update( + { + "user_id": AUTH_USER_ID, + "token_id": AUTH_TOKEN_ID, + "token_key": AUTH_TOKEN_KEY, + } + ) base_dict.update(auth_dict) return base_dict + # Login function to get real auth tokens def login_user() -> bool: """ @@ -111,9 +129,11 @@ def login_user() -> bool: "plat": PLAT, "guid": GUID, "sid": SID, - "seid": SEID + "seid": SEID, } - result = test_endpoint("POST", f"{BASE_URL}/users/auth-token", data=patch_auth(params)) + result = test_endpoint( + "POST", f"{BASE_URL}/users/auth-token", data=patch_auth(params) + ) if result["status_code"] == 200 and result.get("response", {}).get("success"): auth_token = result["response"].get("auth_token", {}) global AUTH_USER_ID, AUTH_TOKEN_ID, AUTH_TOKEN_KEY @@ -123,6 +143,7 @@ def login_user() -> bool: return True return False + # Fetch a real rant_id from feed def fetch_real_rant_id() -> Optional[str]: """ @@ -131,19 +152,17 @@ def fetch_real_rant_id() -> Optional[str]: Payload: GET to /devrant/rants with auth Response: First rant_id if success, else None """ - params: Dict[str, Any] = { - "plat": PLAT, - "guid": GUID, - "sid": SID, - "seid": SEID - } - result = test_endpoint("GET", f"{BASE_URL}/devrant/rants", params=patch_auth(params)) + params: Dict[str, Any] = {"plat": PLAT, "guid": GUID, "sid": SID, "seid": SEID} + result = test_endpoint( + "GET", f"{BASE_URL}/devrant/rants", params=patch_auth(params) + ) if result["status_code"] == 200 and result.get("response", {}).get("success"): rants = result["response"].get("rants", []) if rants: return str(rants[0]["id"]) return None + # Post a test rant and return its id def post_test_rant() -> Optional[str]: """ @@ -158,13 +177,14 @@ def post_test_rant() -> Optional[str]: "plat": PLAT, "guid": GUID, "sid": SID, - "seid": SEID + "seid": SEID, } result = test_endpoint("POST", f"{BASE_URL}/devrant/rants", data=patch_auth(data)) if result["status_code"] == 200 and result.get("response", {}).get("success"): return str(result["response"].get("rant_id", "")) return None + # Post a test comment and return its id def post_test_comment(rant_id: str) -> Optional[str]: """ @@ -178,15 +198,19 @@ def post_test_comment(rant_id: str) -> Optional[str]: "plat": PLAT, "guid": GUID, "sid": SID, - "seid": SEID + "seid": SEID, } - result = test_endpoint("POST", f"{BASE_URL}/devrant/rants/{rant_id}/comments", data=patch_auth(data)) + result = test_endpoint( + "POST", f"{BASE_URL}/devrant/rants/{rant_id}/comments", data=patch_auth(data) + ) if result["status_code"] == 200 and result.get("response", {}).get("success"): return str(result["response"].get("comment_id", "")) return None + # Test cases with docstrings + def test_register_user() -> None: """ Test user registration (valid and invalid). @@ -205,13 +229,14 @@ def test_register_user() -> None: "plat": PLAT, "guid": GUID, "sid": SID, - "seid": SEID + "seid": SEID, } test_endpoint("POST", f"{BASE_URL}/users", data=patch_auth(params.copy())) invalid_params = params.copy() del invalid_params["email"] test_endpoint("POST", f"{BASE_URL}/users", data=patch_auth(invalid_params)) + def test_login_user() -> None: """ Test user login (valid and invalid). Already done in login_user(), but record here. @@ -229,10 +254,11 @@ def test_login_user() -> None: "plat": PLAT, "guid": GUID, "sid": SID, - "seid": SEID + "seid": SEID, } test_endpoint("POST", f"{BASE_URL}/users/auth-token", data=patch_auth(params)) + def test_edit_profile() -> None: """ Test editing user profile. @@ -249,10 +275,11 @@ def test_edit_profile() -> None: "plat": PLAT, "guid": GUID, "sid": SID, - "seid": SEID + "seid": SEID, } test_endpoint("POST", f"{BASE_URL}/users/me/edit-profile", data=patch_auth(params)) + def test_forgot_password() -> None: """ Test forgot password. @@ -265,10 +292,11 @@ def test_forgot_password() -> None: "plat": PLAT, "guid": GUID, "sid": SID, - "seid": SEID + "seid": SEID, } test_endpoint("POST", f"{BASE_URL}/users/forgot-password", data=patch_auth(params)) + def test_resend_confirm() -> None: """ Test resend confirmation email. @@ -276,13 +304,11 @@ def test_resend_confirm() -> None: Payload: POST /users/me/resend-confirm with plat, guid, sid, seid, auth Expected: success=true """ - params: Dict[str, Any] = { - "plat": PLAT, - "guid": GUID, - "sid": SID, - "seid": SEID - } - test_endpoint("POST", f"{BASE_URL}/users/me/resend-confirm", data=patch_auth(params)) + params: Dict[str, Any] = {"plat": PLAT, "guid": GUID, "sid": SID, "seid": SEID} + test_endpoint( + "POST", f"{BASE_URL}/users/me/resend-confirm", data=patch_auth(params) + ) + def test_delete_account() -> None: """ @@ -301,6 +327,7 @@ def test_delete_account() -> None: # test_endpoint("DELETE", f"{BASE_URL}/users/me", params=patch_auth(params)) pass + def test_mark_news_read() -> None: """ Test mark news as read. @@ -313,9 +340,12 @@ def test_mark_news_read() -> None: "plat": PLAT, "guid": GUID, "sid": SID, - "seid": SEID + "seid": SEID, } - test_endpoint("POST", f"{BASE_URL}/users/me/mark-news-read", data=patch_auth(params)) + test_endpoint( + "POST", f"{BASE_URL}/users/me/mark-news-read", data=patch_auth(params) + ) + def test_get_rant() -> None: """ @@ -330,9 +360,12 @@ def test_get_rant() -> None: "plat": PLAT, "guid": GUID, "sid": SID, - "seid": SEID + "seid": SEID, } - test_endpoint("GET", f"{BASE_URL}/devrant/rants/{TEST_RANT_ID}", params=patch_auth(params)) + test_endpoint( + "GET", f"{BASE_URL}/devrant/rants/{TEST_RANT_ID}", params=patch_auth(params) + ) + def test_post_rant() -> None: """ @@ -344,6 +377,7 @@ def test_post_rant() -> None: # Handled in setup pass + def test_edit_rant() -> None: """ Test edit rant. @@ -351,11 +385,11 @@ def test_edit_rant() -> None: Payload: POST /devrant/rants/{rant_id} with updated rant, tags, auth Expected: success=true """ - data: Dict[str, Any] = { - "rant": "Updated test rant", - "tags": "test,python,update" - } - test_endpoint("POST", f"{BASE_URL}/devrant/rants/{TEST_RANT_ID}", data=patch_auth(data)) + data: Dict[str, Any] = {"rant": "Updated test rant", "tags": "test,python,update"} + test_endpoint( + "POST", f"{BASE_URL}/devrant/rants/{TEST_RANT_ID}", data=patch_auth(data) + ) + def test_delete_rant() -> None: """ @@ -364,13 +398,11 @@ def test_delete_rant() -> None: Payload: DELETE /devrant/rants/{rant_id} with plat, guid, sid, seid, auth Expected: success=true """ - params: Dict[str, Any] = { - "plat": PLAT, - "guid": GUID, - "sid": SID, - "seid": SEID - } - test_endpoint("DELETE", f"{BASE_URL}/devrant/rants/{TEST_RANT_ID}", params=patch_auth(params)) + params: Dict[str, Any] = {"plat": PLAT, "guid": GUID, "sid": SID, "seid": SEID} + test_endpoint( + "DELETE", f"{BASE_URL}/devrant/rants/{TEST_RANT_ID}", params=patch_auth(params) + ) + def test_vote_rant() -> None: """ @@ -384,12 +416,17 @@ def test_vote_rant() -> None: "plat": PLAT, "guid": GUID, "sid": SID, - "seid": SEID + "seid": SEID, } - test_endpoint("POST", f"{BASE_URL}/devrant/rants/{TEST_RANT_ID}/vote", data=patch_auth(params)) + test_endpoint( + "POST", f"{BASE_URL}/devrant/rants/{TEST_RANT_ID}/vote", data=patch_auth(params) + ) params["vote"] = -1 params["reason"] = "1" - test_endpoint("POST", f"{BASE_URL}/devrant/rants/{TEST_RANT_ID}/vote", data=patch_auth(params)) + test_endpoint( + "POST", f"{BASE_URL}/devrant/rants/{TEST_RANT_ID}/vote", data=patch_auth(params) + ) + def test_favorite_rant() -> None: """ @@ -398,14 +435,18 @@ def test_favorite_rant() -> None: Payload: POST /devrant/rants/{rant_id}/favorite or /unfavorite with plat, guid, sid, seid, auth Expected: success=true """ - params: Dict[str, Any] = { - "plat": PLAT, - "guid": GUID, - "sid": SID, - "seid": SEID - } - test_endpoint("POST", f"{BASE_URL}/devrant/rants/{TEST_RANT_ID}/favorite", data=patch_auth(params)) - test_endpoint("POST", f"{BASE_URL}/devrant/rants/{TEST_RANT_ID}/unfavorite", data=patch_auth(params)) + params: Dict[str, Any] = {"plat": PLAT, "guid": GUID, "sid": SID, "seid": SEID} + test_endpoint( + "POST", + f"{BASE_URL}/devrant/rants/{TEST_RANT_ID}/favorite", + data=patch_auth(params), + ) + test_endpoint( + "POST", + f"{BASE_URL}/devrant/rants/{TEST_RANT_ID}/unfavorite", + data=patch_auth(params), + ) + def test_get_rant_feed() -> None: """ @@ -414,14 +455,10 @@ def test_get_rant_feed() -> None: Payload: GET /devrant/rants with plat, guid, sid, seid, auth Expected: success=true, list of rants """ - params: Dict[str, Any] = { - "plat": PLAT, - "guid": GUID, - "sid": SID, - "seid": SEID - } + params: Dict[str, Any] = {"plat": PLAT, "guid": GUID, "sid": SID, "seid": SEID} test_endpoint("GET", f"{BASE_URL}/devrant/rants", params=patch_auth(params)) + def test_get_comment() -> None: """ Test get single comment. @@ -434,20 +471,24 @@ def test_get_comment() -> None: "plat": PLAT, "guid": GUID, "sid": SID, - "seid": SEID + "seid": SEID, } - test_endpoint("GET", f"{BASE_URL}/comments/{TEST_COMMENT_ID}", params=patch_auth(params)) + test_endpoint( + "GET", f"{BASE_URL}/comments/{TEST_COMMENT_ID}", params=patch_auth(params) + ) + def test_post_comment() -> None: """ Test post comment. (Handled in post_test_comment for id) - + Payload: POST /devrant/rants/{rant_id}/comments with comment, auth Expected: success=true, comment_id """ # Handled in setup pass + def test_edit_comment() -> None: """ Test edit comment. @@ -455,10 +496,11 @@ def test_edit_comment() -> None: Payload: POST /comments/{comment_id} with updated comment, auth Expected: success=true """ - data: Dict[str, Any] = { - "comment": "Updated test comment" - } - test_endpoint("POST", f"{BASE_URL}/comments/{TEST_COMMENT_ID}", data=patch_auth(data)) + data: Dict[str, Any] = {"comment": "Updated test comment"} + test_endpoint( + "POST", f"{BASE_URL}/comments/{TEST_COMMENT_ID}", data=patch_auth(data) + ) + def test_delete_comment() -> None: """ @@ -467,13 +509,11 @@ def test_delete_comment() -> None: Payload: DELETE /comments/{comment_id} with plat, guid, sid, seid, auth Expected: success=true """ - params: Dict[str, Any] = { - "plat": PLAT, - "guid": GUID, - "sid": SID, - "seid": SEID - } - test_endpoint("DELETE", f"{BASE_URL}/comments/{TEST_COMMENT_ID}", params=patch_auth(params)) + params: Dict[str, Any] = {"plat": PLAT, "guid": GUID, "sid": SID, "seid": SEID} + test_endpoint( + "DELETE", f"{BASE_URL}/comments/{TEST_COMMENT_ID}", params=patch_auth(params) + ) + def test_vote_comment() -> None: """ @@ -487,9 +527,12 @@ def test_vote_comment() -> None: "plat": PLAT, "guid": GUID, "sid": SID, - "seid": SEID + "seid": SEID, } - test_endpoint("POST", f"{BASE_URL}/comments/{TEST_COMMENT_ID}/vote", data=patch_auth(params)) + test_endpoint( + "POST", f"{BASE_URL}/comments/{TEST_COMMENT_ID}/vote", data=patch_auth(params) + ) + def test_get_notif_feed() -> None: """ @@ -504,10 +547,11 @@ def test_get_notif_feed() -> None: "plat": PLAT, "guid": GUID, "sid": SID, - "seid": SEID + "seid": SEID, } test_endpoint("GET", f"{BASE_URL}/users/me/notif-feed", params=patch_auth(params)) + def test_clear_notifications() -> None: """ Test clear notifications. @@ -515,13 +559,11 @@ def test_clear_notifications() -> None: Payload: DELETE /users/me/notif-feed with plat, guid, sid, seid, auth Expected: success=true """ - params: Dict[str, Any] = { - "plat": PLAT, - "guid": GUID, - "sid": SID, - "seid": SEID - } - test_endpoint("DELETE", f"{BASE_URL}/users/me/notif-feed", params=patch_auth(params)) + params: Dict[str, Any] = {"plat": PLAT, "guid": GUID, "sid": SID, "seid": SEID} + test_endpoint( + "DELETE", f"{BASE_URL}/users/me/notif-feed", params=patch_auth(params) + ) + def test_beta_list_signup() -> None: """ @@ -530,11 +572,10 @@ def test_beta_list_signup() -> None: Payload: GET https://www.hexicallabs.com/api/beta-list with email, platform, app Expected: Whatever the API returns (may not be JSON) """ - params: Dict[str, Any] = { - "email": TEST_EMAIL, - "platform": "test_platform" - } - test_endpoint("GET", "https://www.hexicallabs.com/api/beta-list", params=patch_auth(params)) + params: Dict[str, Any] = {"email": TEST_EMAIL, "platform": "test_platform"} + test_endpoint( + "GET", "https://www.hexicallabs.com/api/beta-list", params=patch_auth(params) + ) def main() -> None: @@ -543,7 +584,7 @@ def main() -> None: global TEST_RANT_ID TEST_RANT_ID = post_test_rant() or fetch_real_rant_id() or "1" global TEST_COMMENT_ID - TEST_COMMENT_ID = post_test_comment(TEST_RANT_ID) or "1" + TEST_COMMENT_ID = post_test_comment(TEST_RANT_ID) or "1" test_register_user() test_login_user() @@ -552,7 +593,7 @@ def main() -> None: test_resend_confirm() test_mark_news_read() test_get_rant() - test_post_rant() + test_post_rant() test_edit_rant() test_vote_rant() test_favorite_rant() @@ -570,5 +611,6 @@ def main() -> None: test_delete_account() save_results() + if __name__ == "__main__": main()