diff --git a/setup.cfg b/setup.cfg index 78dc458..18b36bd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,6 +17,7 @@ install_requires = aiohttp dataset ipython + openai [options.packages.find] where = src diff --git a/src/app.egg-info/PKG-INFO b/src/app.egg-info/PKG-INFO index 7a62d0b..e856895 100644 --- a/src/app.egg-info/PKG-INFO +++ b/src/app.egg-info/PKG-INFO @@ -10,3 +10,4 @@ Description-Content-Type: text/markdown Requires-Dist: aiohttp Requires-Dist: dataset Requires-Dist: ipython +Requires-Dist: openai diff --git a/src/app.egg-info/SOURCES.txt b/src/app.egg-info/SOURCES.txt index 33692c7..b813992 100644 --- a/src/app.egg-info/SOURCES.txt +++ b/src/app.egg-info/SOURCES.txt @@ -2,11 +2,13 @@ pyproject.toml setup.cfg src/app/__init__.py src/app/__main__.py +src/app/agent.py src/app/app.py src/app/args.py src/app/cli.py src/app/kim.py src/app/repl.py +src/app/rpc.py src/app/server.py src/app/tests.py src/app.egg-info/PKG-INFO diff --git a/src/app.egg-info/requires.txt b/src/app.egg-info/requires.txt index 7228e0b..bf2defb 100644 --- a/src/app.egg-info/requires.txt +++ b/src/app.egg-info/requires.txt @@ -1,3 +1,4 @@ aiohttp dataset ipython +openai diff --git a/src/app/agent.py b/src/app/agent.py new file mode 100644 index 0000000..7221213 --- /dev/null +++ b/src/app/agent.py @@ -0,0 +1,214 @@ +""" +Written in 2024 by retoor@molodetz.nl. + +MIT license. Enjoy! + +You'll need a paid OpenAI account, named a project in it, requested an api key and created an assistant. +URL's to all these pages are described in the class for convenience. + +The API keys described in this document are fake but are in the correct format for educational purposes. + +How to start: + - sudo apt install python3.12-venv python3-pip -y + - python3 -m venv .venv + - . .venv/bin/activate + - pip install openapi + +This file is to be used as part of your project or a standalone after doing +some modifications at the end of the file. +""" + +try: + import os + import sys + + sys.path.append(os.getcwd()) + import env + + API_KEY = env.API_KEY + ASSISTANT_ID = env.ASSISTANT_ID +except: + pass + + +import asyncio +import functools +from collections.abc import Generator +from typing import Optional + +from openai import OpenAI + + +class Agent: + """ + This class translates into an instance a single user session with its own memory. + + The messages property of this class is a list containing the full chat history about + what the user said and what the assistant (agent) said. This can be used in future to continue + where you left off. Format is described in the docs of __init__ function below. + + Introduction API usage for if you want to extend this class: + https://platform.openai.com/docs/api-reference/introduction + """ + + def __init__( + self, api_key: str, assistant_id: int, messages: Optional[list] = None + ): + """ + You can find and create API keys here: + https://platform.openai.com/api-keys + + You can find assistant_id (agent_id) here. It is the id that starts with 'asst_', not your custom name: + https://platform.openai.com/assistants/ + + Messages are optional in this format, this is to keep a message history that you can later use again: + [ + {"role": "user", "message": "What is choking the chicken?"}, + {"role": "assistant", "message": "Lucky for the cock."} + ] + """ + + self.assistant_id = assistant_id + self.api_key = api_key + self.client = OpenAI(api_key=self.api_key) + self.messages = messages or [] + self.thread = self.client.beta.threads.create(messages=self.messages) + + async def dalle2( + self, prompt: str, width: Optional[int] = 512, height: Optional[int] = 512 + ) -> dict: + """ + In my opinion dall-e-2 produces unusual results. + Sizes: 256x256, 512x512 or 1024x1024. + """ + result = self.client.images.generate( + model="dall-e-2", prompt=prompt, n=1, size=f"{width}x{height}" + ) + return result + + @property + async def models(self): + """ + List models in dict format. That's more convenient than the original + list method because this can be directly converted to json to be used + in your front end or api. That's not the original result which is a + custom list with unserializable models. + """ + return [ + { + "id": model.id, + "owned_by": model.owned_by, + "object": model.object, + "created": model.created, + } + for model in self.client.models.list() + ] + + async def dalle3( + self, prompt: str, height: Optional[int] = 1024, width: Optional[int] = 1024 + ) -> dict: + """ + Sadly only big sizes allowed. Is more pricy. + Sizes: 1024x1024, 1792x1024, or 1024x1792. + """ + result = self.client.images.generate( + model="dall-e-3", prompt=prompt, n=1, size=f"{width}x{height}" + ) + print(result) + return result + + async def chat( + self, message: str, interval: Optional[float] = 0.2 + ) -> Generator[None, None, str]: + """ + Chat with the agent. It yields on given interval to inform the caller it' still busy so you can + update the user with live status. It doesn't hang. You can use this fully async with other + instances of this class. + + This function also updates the self.messages list with chat history for later use. + """ + message_object = {"role": "user", "content": message} + self.messages.append(message_object) + self.client.beta.threads.messages.create( + self.thread.id, + role=message_object["role"], + content=message_object["content"], + ) + run = self.client.beta.threads.runs.create( + thread_id=self.thread.id, assistant_id=self.assistant_id + ) + + while run.status != "completed": + run = self.client.beta.threads.runs.retrieve( + thread_id=self.thread.id, run_id=run.id + ) + yield None + await asyncio.sleep(interval) + + response_messages = self.client.beta.threads.messages.list( + thread_id=self.thread.id + ).data + last_message = response_messages[0].content[0].text.value + self.messages.append({"role": "assistant", "content": last_message}) + print(last_message) + yield str(last_message) + + async def chatp(self, message: str) -> str: + """ + Just like regular chat function but with progress indication and returns string directly. + This is handy for interactive usage or for a process log. + """ + asyncio.get_event_loop() + print("Processing", end="") + async for message in self.chat(message): + if not message: + print(".", end="", flush=True) + continue + print("") + break + return message + + async def read_line(self, ps: Optional[str] = "> "): + """ + Non blocking read_line. + Blocking read line can break web socket connections. + That's why. + """ + loop = asyncio.get_event_loop() + patched_input = functools.partial(input, ps) + return await loop.run_in_executor(None, patched_input) + + async def cli(self): + """ + Interactive client. Can be used on terminal by user or a different process. + The bottom new line is so that a process can check for \n\n to check if it's end response + and there's nothing left to wait for and thus can send next prompt if the '>' shows. + """ + while True: + try: + message = await self.read_line("> ") + if not message.strip(): + continue + response = await self.chatp(message) + print(response.content[0].text.value) + print("") + except KeyboardInterrupt: + print("Exiting..") + break + + +async def main(): + """ + Example main function. The keys here are not real but look exactly like + the real ones for example purposes and that you're sure your key is in the + right format. + """ + agent = Agent(api_key=API_KEY, assistant_id=ASSISTANT_ID) + + # Run interactive chat + await agent.cli() + + +if __name__ == "__main__": + # Only gets executed by direct execution of script. Not when important. + asyncio.run(main()) diff --git a/src/app/app.py b/src/app/app.py index 2c8d204..cf3b8f5 100644 --- a/src/app/app.py +++ b/src/app/app.py @@ -7,6 +7,9 @@ import uuid import dataset from aiohttp import web +from app.agent import Agent +from app.rpc import Application as RPCApplication + from . import log @@ -18,7 +21,7 @@ def get_timestamp(): return formatted_datetime -class BaseApplication(web.Application): +class BaseApplication(RPCApplication, web.Application): def __init__( self, @@ -37,6 +40,7 @@ class BaseApplication(web.Application): middlewares.append(self.request_middleware) middlewares.append(self.base64_auth_middleware) middlewares.append(self.session_middleware) + self.agents = {} super().__init__(middlewares=middlewares, *args, **kwargs) def run(self, *args, **kwargs): @@ -48,6 +52,24 @@ class BaseApplication(web.Application): async def authenticate(self, username, password): return self.basic_username == username and self.basic_password == password + async def agent_create_thread(self, api_key, assistent_id): + agent = Agent(api_key, assistent_id) + self.agents[str(agent.thread.id)] = agent + return str(agent.thread.id) + + async def rpc_agent_create_thread(self, api_key, assistent_id): + return await self.agent_create_thread(api_key, assistent_id) + + async def agent_prompt(self, thread_id, message): + try: + agent = self.agents[str(thread_id)] + return await agent.chat(message) + except Exception as ex: + return str(ex) + + async def rpc_agent_prompt(self, thread_id, message): + return await self.agent_prompt(str(thread_id), message) + @web.middleware async def base64_auth_middleware(self, request, handler): auth_header = request.headers.get("Authorization") @@ -126,6 +148,14 @@ class WebDbApplication(BaseApplication): self.router.add_post("/db/delete", self.delete_handler) self.router.add_post("/db/get", self.get_handler) self.router.add_post("/db/set", self.set_handler) + self.rpc_set = self.set + self.rpc_get = self.get + self.rpc_insert = self.insert + self.rpc_update = self.update + self.rpc_upsert = self.upsert + self.rpc_find = self.find + self.rpc_fine_one = self.find_one + self.rpc_delete = self.delete async def set_handler(self, request): obj = await request.json() @@ -191,10 +221,10 @@ class WebDbApplication(BaseApplication): async def insert(self, table_name, data): return self.db[table_name].insert(data) - async def update(self, table_name, data, where): - return self.db[table_name].update(data, where) + async def update(self, table_name, data, where=None): + return self.db[table_name].update(data, where or {}) - async def upsert(self, table_name, data, keys): + async def upsert(self, table_name, data, keys=None): return self.db[table_name].upsert(data, keys or []) async def find(self, table_name, filters=None): @@ -210,7 +240,8 @@ class WebDbApplication(BaseApplication): except ValueError: return None - async def delete(self, table_name, where): + async def delete(self, table_name, where=None): + where = where or {} return self.db[table_name].delete(**where) diff --git a/src/app/rpc.py b/src/app/rpc.py new file mode 100644 index 0000000..cd0ec8b --- /dev/null +++ b/src/app/rpc.py @@ -0,0 +1,252 @@ + +from xmlrpc.server import resolve_dotted_attribute +from xmlrpc.client import Fault, dumps, loads, gzip_encode, gzip_decode, ServerProxy,MultiCall +from functools import partial +from inspect import signature +from aiohttp import web +from datetime import datetime + + +class AsyncSimpleXMLRPCDispatcher: + """ + Original not async version of this class is in the original python std lib: + https://github.com/python/cpython/blob/main/Lib/xmlrpc/server.py. + + use_builtin_types=True allows the use of bytes-object which is preferred + because else it's a custom xmlrpc.client.Binary which sucks. + """ + + def __init__(self, instance,allow_none=True, encoding="utf-8", use_builtin_types=True): + self.setup_rpc(allow_none=allow_none, encoding=encoding, use_builtin_types=True) + self.register_instance(instance,True) + + def setup_rpc(self, allow_none=True, encoding="utf-8", + use_builtin_types=True): + self.funcs = {} + self.instance = None + self.allow_none = allow_none + self.encoding = encoding or 'utf-8' + self.use_builtin_types = use_builtin_types + + def register_instance(self, instance, allow_dotted_names=True): + self.instance = instance + self.allow_dotted_names = allow_dotted_names + self.register_multicall_functions() + self.register_introspection_functions() + + def register_function(self, function=None, name=None): + if function is None: + return partial(self.register_function, name=name) + + if name is None: + name = function.__name__ + self.funcs[name] = function + + return function + + def register_introspection_functions(self): + self.funcs.update({'system.listMethods' : self.system_listMethods, + 'system.methodSignature' : self.system_methodSignature, + 'system.methodHelp' : self.system_methodHelp}) + + def register_multicall_functions(self): + self.funcs.update({'system.multicall' : self.system_multicall}) + + async def _marshaled_dispatch(self, data, dispatch_method = None, path = None): + try: + params, method = loads(data, use_builtin_types=self.use_builtin_types) + + if dispatch_method is not None: + response = dispatch_method(method, params) + else: + response = await self._dispatch(method, params) + response = (response,) + response = dumps(response, methodresponse=1, + allow_none=self.allow_none, encoding=self.encoding) + except Fault as fault: + response = dumps(fault, allow_none=self.allow_none, + encoding=self.encoding) + except BaseException as exc: + response = dumps( + Fault(1, "%s:%s" % (type(exc), exc)), + encoding=self.encoding, allow_none=self.allow_none, + ) + + return response.encode(self.encoding, 'xmlcharrefreplace') + + def system_listMethods(self): + methods = set(self.funcs.keys()) + if self.instance is not None: + if hasattr(self.instance, '_listMethods'): + methods |= set(self.instance._listMethods()) + elif not hasattr(self.instance, '_dispatch'): + methods |= set(list_public_methods(self.instance)) + return sorted(methods) + + def system_methodSignature(self, method_name): + return 'signatures not supported' + + def system_methodHelp(self, method_name): + method = None + if method_name in self.funcs: + method = self.funcs[method_name] + elif self.instance is not None: + if hasattr(self.instance, '_methodHelp'): + return self.instance._methodHelp(method_name) + elif not hasattr(self.instance, '_dispatch'): + try: + method = resolve_dotted_attribute( + self.instance, + method_name, + self.allow_dotted_names + ) + except AttributeError: + pass + if method is None: + return "" + else: + return pydoc.getdoc(method) + + async def system_multicall(self, call_list): + results = [] + for call in call_list: + method_name = call['methodName'] + params = call['params'] + + try: + results.append([await self._dispatch(method_name, params)]) + except Fault as fault: + results.append( + {'faultCode' : fault.faultCode, + 'faultString' : fault.faultString} + ) + except BaseException as exc: + results.append( + {'faultCode' : 1, + 'faultString' : "%s:%s" % (type(exc), exc)} + ) + return results + + async def _dispatch(self, method, params): + try: + func = self.funcs[method] + except KeyError: + pass + else: + if func is not None: + return await func(*params) + raise Exception('method "%s" is not supported' % method) + + if self.instance is not None: + if hasattr(self.instance, '_dispatch'): + return await self.instance._dispatch(method, params) + + try: + func = resolve_dotted_attribute( + self.instance, + method, + self.allow_dotted_names + ) + except AttributeError: + pass + else: + if func is not None: + return await func(*params) + + raise Exception('method "%s" is not supported' % method) + +def rpc_wrap_instance(obj): + + class Session: + + def __init__(self,data=None): + self._data = data or {} + + async def get(self, key, default=None): + return self._data.get(key,default) + + async def set(self,key, value): + self._data[key] = value + + async def delete(self, key): + try: + del self._data[key] + return True + except KeyError: + return False + + async def exists(self,key): + return key in self._data + + class Instance: + + def __init__(self, _self): + self._self = self + self.session = Session() + + def __get__(self,key): + return getattr(self._self,key) + + def ping(self,*args,**kwargs): + return dict( + args=args, + kwargs=kwargs, + timestamp=str(datetime.now()) + ) + + instance = Instance(obj) + + for attr in dir(obj): + if attr == 'rpc_handler': + continue + if attr.startswith("rpc_") and callable(getattr(obj, attr)): + setattr(instance,attr[4:], getattr(obj,attr)) + + return instance + + +class Application(web.Application): + + def __init__(self, url=None,host=None,port=None, *args, **kwargs): + self.host = host + self.port = port + self._url = url + self._rpc = None + if self.rpc_url: + self._rpc = ServerProxy(self.rpc_url) + super().__init__(*args, **kwargs) + self.arpc = rpc_wrap_instance(self) + self.rpc_dispatcher = AsyncSimpleXMLRPCDispatcher(self.arpc) + self.router.add_post("/rpc", self.rpc_handler) + + def __get__(self, key): + if self._rpc: + return getattr(self._rpc,key) + return getattr(self.arpc,key) + + @property + def url(self): + if self._url: + return self._url + return "http://{}:{}".format(self.host,self.port) + + @property + def rpc_url(self): + return self.url.rstrip("/") + "/rpc" + + def connect(self, url): + return ServerProxy(url) + + def multicall(self, url): + return MultiCall(self.connect(url)) + + @property + def rpc(self): + if not self._rpc: + self._rpc = ServerProxy(url or self.rpc_url) + return self._rpc + + async def rpc_handler(self, request): + request_body = await request.text() + response_body = await self.rpc_dispatcher._marshaled_dispatch(request_body) + return web.Response(text=response_body.decode())