286 lines
12 KiB
Python
286 lines
12 KiB
Python
|
import asyncio
|
||
|
import aiohttp
|
||
|
import json
|
||
|
import argparse
|
||
|
import sys
|
||
|
from typing import Dict, List, Any, Optional
|
||
|
import ais
|
||
|
|
||
|
class OpenAPIFormatter:
    """Turn an OpenAPI 3 spec into a JSON-serialisable instruction payload for an LLM.

    The formatter lists the spec's operations, resolves local ``$ref`` pointers
    into ``#/components/...`` and fabricates placeholder request examples so the
    model can see what a well-formed call looks like.
    """

    # HTTP verbs exposed to the LLM; other path-item keys (``parameters``,
    # ``summary``, ``head``, ``options``, ...) are skipped.
    _HTTP_METHODS = ("GET", "POST", "PUT", "DELETE", "PATCH")

    def __init__(self, openapi_spec: Dict[str, Any]):
        """Store the spec and cache the sections the generators read.

        Args:
            openapi_spec: Parsed OpenAPI document (a dict as loaded from JSON).
        """
        self.spec = openapi_spec
        # ``or {}`` tolerates sections that are present but explicitly null.
        info = openapi_spec.get("info") or {}
        self.title = info.get("title", "API")
        self.description = info.get("description", "")
        self.paths = openapi_spec.get("paths") or {}
        self.components = openapi_spec.get("components") or {}
        self.schemas = self.components.get("schemas") or {}

    @staticmethod
    def _ref_name(ref: str, prefix: str) -> str:
        """Return the component name addressed by *ref*, a JSON pointer under *prefix*.

        Decodes the RFC 6901 escapes ``~1`` (``/``) and ``~0`` (``~``).
        """
        return ref[len(prefix):].replace("~1", "/").replace("~0", "~")

    def _resolve_component(self, obj: Any, section: str) -> Any:
        """Follow one ``$ref`` into ``#/components/<section>/``; otherwise return *obj* unchanged.

        An unknown component name resolves to ``{}``.
        """
        prefix = f"#/components/{section}/"
        ref = obj.get("$ref") if isinstance(obj, dict) else None
        if isinstance(ref, str) and ref.startswith(prefix):
            return (self.components.get(section) or {}).get(self._ref_name(ref, prefix), {})
        return obj

    def extract_endpoints(self) -> List[Dict[str, str]]:
        """List every supported operation in the spec.

        Returns:
            One dict per operation with ``endpoint`` (the path template),
            upper-case ``method``, ``summary`` and ``description`` (empty
            strings when absent).
        """
        endpoints = []
        for path, methods in self.paths.items():
            for method, details in methods.items():
                if method.upper() not in self._HTTP_METHODS:
                    continue
                endpoints.append({
                    "endpoint": path,
                    "method": method.upper(),
                    "summary": details.get("summary", ""),
                    "description": details.get("description", ""),
                })
        return endpoints

    def get_request_body_schema(self, path: str, method: str) -> Optional[Dict]:
        """Return the resolved ``application/json`` request-body schema of an operation.

        Args:
            path: Path template exactly as it appears under ``paths``.
            method: HTTP method, any case.

        Returns:
            The schema with its top-level ``$ref`` resolved, or ``None`` when the
            operation is unknown or declares no JSON body schema.
        """
        try:
            operation = self.paths[path][method.lower()]
        except KeyError:
            return None
        # requestBody may itself point into #/components/requestBodies/.
        request_body = self._resolve_component(operation.get("requestBody") or {}, "requestBodies")
        content = request_body.get("content") or {}
        json_content = content.get("application/json") or {}
        schema = json_content.get("schema") or {}
        return self.resolve_schema_ref(schema) if schema else None

    def get_parameters(self, path: str, method: str) -> List[Dict]:
        """Return the parameters that apply to one operation, with ``$ref``s resolved.

        Parameters declared on the path item apply to every operation under that
        path. An operation-level parameter with the same ``name`` and ``in``
        replaces the path-level one.

        Args:
            path: Path template exactly as it appears under ``paths``.
            method: HTTP method, any case.

        Returns:
            Dicts with ``name``, ``in``, ``description``, ``required`` and the
            resolved ``schema``; ``[]`` if the path or method is not in the spec.
        """
        try:
            path_item = self.paths[path]
            operation = path_item[method.lower()]
        except KeyError:
            return []
        declared = list(path_item.get("parameters") or []) + list(operation.get("parameters") or [])

        resolved: List[Dict] = []
        position: Dict[tuple, int] = {}  # (name, in) -> index in `resolved`
        for raw in declared:
            param = self._resolve_component(raw, "parameters")
            entry = {
                "name": param.get("name"),
                "in": param.get("in"),
                "description": param.get("description", ""),
                "required": param.get("required", False),
                "schema": self.resolve_schema_ref(param.get("schema", {})),
            }
            key = (entry["name"], entry["in"])
            if entry["name"] is not None and key in position:
                # Operation-level declaration overrides the path-level one.
                resolved[position[key]] = entry
                continue
            if entry["name"] is not None:
                position[key] = len(resolved)
            resolved.append(entry)
        return resolved

    def resolve_schema_ref(self, schema: Dict) -> Dict:
        """Resolve a top-level ``$ref`` into ``#/components/schemas/``.

        A chain of references (a schema that is just another ``$ref``) is
        followed to the first concrete schema. A reference cycle, a non-string
        ``$ref``, or a reference outside ``#/components/schemas/`` stops
        resolution, and the schema reached so far is returned. Unknown names
        resolve to ``{}``. Nested ``$ref``s inside the result are left as-is.
        """
        prefix = "#/components/schemas/"
        visited = set()
        while isinstance(schema, dict) and "$ref" in schema:
            ref = schema["$ref"]
            if not isinstance(ref, str) or not ref.startswith(prefix) or ref in visited:
                break
            visited.add(ref)
            schema = self.schemas.get(self._ref_name(ref, prefix), {})
        return schema

    def generate_example_body(self, schema: Dict, _seen: Optional[frozenset] = None) -> Dict:
        """Build a placeholder JSON object for an object schema.

        Args:
            schema: Object schema, possibly a ``$ref``.
            _seen: Internal. Schema references already being expanded on the
                current branch. A reference met again yields ``{}``, so
                self-referential schemas terminate instead of recursing forever.

        Returns:
            Dict mapping each property name to an example value. Nested objects
            and arrays of objects are expanded recursively.
        """
        if not schema:
            return {}
        seen = _seen or frozenset()
        ref = schema.get("$ref") if isinstance(schema, dict) else None
        if isinstance(ref, str):
            if ref in seen:
                return {}
            seen = seen | {ref}
        schema = self.resolve_schema_ref(schema)
        if not isinstance(schema, dict):
            return {}

        example = {}
        for prop_name, raw_prop in (schema.get("properties") or {}).items():
            prop_schema = self.resolve_schema_ref(raw_prop)
            prop_type = prop_schema.get("type", "string")
            if prop_type == "array":
                raw_items = prop_schema.get("items", {})
                items_schema = self.resolve_schema_ref(raw_items)
                if items_schema.get("type") == "object":
                    # Recurse on the unresolved schema so its $ref lands in `seen`.
                    example[prop_name] = [self.generate_example_body(raw_items, seen)]
                else:
                    example[prop_name] = [self.get_example_value(items_schema)]
            elif prop_type == "object":
                example[prop_name] = self.generate_example_body(raw_prop, seen)
            else:
                example[prop_name] = self.get_example_value(prop_schema)
        return example

    def get_example_value(self, schema: Dict) -> Any:
        """Return a placeholder value for a scalar (or array) schema.

        Precedence: the schema's ``example``, then the first of its ``examples``
        list, then the first ``enum`` member. Otherwise the value is derived
        from ``type`` (default ``"string"``). String placeholders are nudged by
        keywords in ``description``/``title``. Unknown types give
        ``"example_value"``.
        """
        if not schema:
            return "example_value"
        if "example" in schema:
            return schema["example"]
        examples = schema.get("examples")
        if isinstance(examples, list) and examples:
            return examples[0]
        enum = schema.get("enum")
        if isinstance(enum, list) and enum:
            return enum[0]

        prop_type = schema.get("type", "string")
        # `or ""` because specs sometimes carry explicit nulls in these fields.
        description = (schema.get("description") or "").lower()
        title = (schema.get("title") or "").lower()

        if prop_type == "string":
            if "name" in description or "name" in title:
                return "Example Name"
            if "type" in description or "type" in title:
                return "example_type"
            if "query" in description or "search" in description:
                return "search_term"
            return "example_value"
        if prop_type == "integer":
            return 1
        if prop_type == "number":
            return 1.0
        if prop_type == "boolean":
            return True
        if prop_type == "array":
            return []
        return "example_value"

    def generate_instructions(self) -> Dict[str, Any]:
        """Assemble the full instruction payload under the ``llm_api_instruction`` key.

        The payload contains the response-format description, example calls,
        the endpoint list, the rules the LLM must follow and per-endpoint
        details.
        """
        endpoints = self.extract_endpoints()
        example_single = self.generate_single_call_example(endpoints)
        example_multiple = self.generate_multiple_call_example(endpoints)
        instructions = {
            "llm_api_instruction": {
                "title": f"API Call Instructions for {self.title}",
                "overview": f"When you need to make API calls to {self.title}, respond with a JSON array containing API call objects. Each object represents one API call to execute.",
                "api_description": self.description,
                "required_response_format": {
                    "description": "Your response must be a valid JSON array of objects in this exact format:",
                    "format": [
                        {
                            "endpoint": "string - endpoint path with path parameters substituted",
                            "method": "string - HTTP method (GET, POST, PUT, DELETE, etc.)",
                            "query": "object - query parameters (omit if not needed)",
                            "body": "object - request body (omit if not needed)"
                        }
                    ],
                    "example_single_call": example_single,
                    "example_multiple_calls": example_multiple
                },
                "available_endpoints": endpoints,
                "critical_rules": [
                    "ALWAYS return a JSON array, even for single calls: [{...}]",
                    "Use EXACT endpoint paths from the available_endpoints list, substituting path parameters {param} with actual values",
                    "For query parameters ('in': 'query' in endpoint_details), include them in the 'query' field",
                    "For path parameters ('in': 'path'), substitute directly in the endpoint path",
                    "Include ALL required parameters and fields",
                    "Match data types from schemas in endpoint_details",
                    "Replace example values with actual data from user's request",
                    "Omit 'query' and 'body' fields if not needed",
                    "Ensure your JSON is valid and parseable"
                ],
                "endpoint_details": self.generate_endpoint_details()
            }
        }
        return instructions

    def generate_single_call_example(self, endpoints: List[Dict]) -> List[Dict]:
        """Return a one-call example built from the first GET (else first POST) endpoint.

        Falls back to a fixed ``/example`` GET when neither exists.
        """
        for wanted in ("GET", "POST"):
            candidates = [ep for ep in endpoints if ep["method"] == wanted]
            if candidates:
                return self._generate_example(candidates[0])
        return [{"endpoint": "/example", "method": "GET"}]

    def generate_multiple_call_example(self, endpoints: List[Dict]) -> List[Dict]:
        """Return example calls for (at most) the first two endpoints.

        Falls back to a fixed ``/example`` GET when there are no endpoints.
        """
        examples = [self._generate_example(endpoint)[0] for endpoint in endpoints[:2]]
        return examples if examples else [{"endpoint": "/example", "method": "GET"}]

    def _generate_example(self, endpoint: Dict) -> List[Dict]:
        """Build a single example call (wrapped in a list) for one endpoint entry.

        Path parameters are substituted into the path template, query
        parameters go into ``query``, and a JSON body example is added when the
        operation declares one.
        """
        path = endpoint["endpoint"]
        method = endpoint["method"]
        params = self.get_parameters(path, method)
        path_params = {p["name"]: self.get_example_value(p["schema"]) for p in params if p["in"] == "path"}
        query_params = {p["name"]: self.get_example_value(p["schema"]) for p in params if p["in"] == "query"}

        # Literal replacement instead of str.format: it cannot raise on
        # undeclared placeholders or stray braces, and substitutes what it knows.
        endpoint_str = path
        for name, value in path_params.items():
            endpoint_str = endpoint_str.replace(f"{{{name}}}", str(value))

        example = {"endpoint": endpoint_str, "method": method}
        if query_params:
            example["query"] = query_params
        schema = self.get_request_body_schema(path, method)
        if schema:
            example["body"] = self.generate_example_body(schema)
        return [example]

    def generate_endpoint_details(self) -> Dict[str, Any]:
        """Map ``"METHOD /path"`` to summary, description, parameters and body schema.

        ``request_body_schema`` is ``None`` when the operation has no JSON body.
        """
        details = {}
        for path, methods in self.paths.items():
            for method, info in methods.items():
                if method.upper() not in self._HTTP_METHODS:
                    continue
                schema = self.get_request_body_schema(path, method)
                details[f"{method.upper()} {path}"] = {
                    "summary": info.get("summary", ""),
                    "description": info.get("description", ""),
                    "parameters": self.get_parameters(path, method),
                    "request_body_schema": schema or None,
                }
        return details
|
||
|
|
||
|
async def fetch_openapi_spec(url: str) -> Dict[str, Any]:
    """Download and parse an OpenAPI document.

    Args:
        url: Absolute URL of the spec (e.g. ``.../openapi.json``).

    Returns:
        The parsed JSON document.

    Raises:
        RuntimeError: If the server answers with a status other than 200.
    """
    # ClientTimeout is aiohttp's supported form; a bare number is legacy.
    timeout = aiohttp.ClientTimeout(total=30)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.get(url) as resp:
            if resp.status != 200:
                text = await resp.text()
                raise RuntimeError(f"Failed to fetch OpenAPI spec from {url}: {resp.status} {text}")
            # content_type=None: parse the body even if it is not served as application/json.
            return await resp.json(content_type=None)
|
||
|
|
||
|
async def get_system_message(url: str) -> Dict[str, Any]:
    """Fetch the OpenAPI spec at *url* and return the LLM instruction payload built from it."""
    spec = await fetch_openapi_spec(url)
    return OpenAPIFormatter(spec).generate_instructions()
|
||
|
|
||
|
async def prompt(message, system, model="gpt-4o"):
    """Ask the LLM to turn *message* into API calls.

    Args:
        message: The user's natural-language request.
        system: Instruction payload from ``get_system_message``. Only its
            ``llm_api_instruction`` entry is sent, serialised as JSON, as the
            system message.
        model: NOTE(review): currently unused. The client is always created
            with model "gemma". Confirm which model is intended before wiring
            this through.

    Returns:
        Whatever ``ais.AIS.chat`` returns (presumably the raw reply text; TODO confirm).
    """
    instruction_json = json.dumps(system['llm_api_instruction'])
    llm = ais.AIS(model="gemma", system_message=instruction_json)
    # NOTE(review): chat() is called synchronously here and may block the event loop.
    return llm.chat(message)
|
||
|
|
||
|
async def make_call(session, BASE_URL, call):
    """Execute one LLM-proposed API call and return the decoded response.

    Args:
        session: An open ``aiohttp.ClientSession`` (or anything with a
            compatible ``request`` async context manager).
        BASE_URL: Scheme and host that the call's ``endpoint`` is appended to.
        call: Dict with ``endpoint`` and ``method``, plus optional ``query``
            (dict of query parameters; ``None`` values are dropped) and
            ``body`` (sent as JSON except for GET, HEAD and DELETE).

    Returns:
        The response body parsed as JSON, or the raw text if it is not valid JSON.
    """
    url_ = f"{BASE_URL.rstrip('/')}/{call['endpoint'].lstrip('/')}"
    method = call['method'].upper()
    query = call.get('query')
    # The LLM may emit a non-object "query"; only a dict can become query params.
    params = {k: v for k, v in query.items() if v is not None} if isinstance(query, dict) else {}
    json_payload = call.get('body') if method not in ["GET", "HEAD", "DELETE"] else None
    # Diagnostics go to stderr so stdout carries only the script's result.
    print(call, file=sys.stderr, flush=True)
    async with session.request(method, url_, params=params, json=json_payload) as resp:
        text = await resp.text()
        # Decode JSON ourselves rather than wrapping resp.json() in a broad
        # except. This also accepts JSON served with a non-JSON content type.
        try:
            response = json.loads(text)
        except ValueError:
            response = text
        full_url = str(resp.url)
        print(f"HTTP Query: {method} {full_url} with body: {json_payload}, got response: {response} ({resp.status})", flush=True, file=sys.stderr)
        return response
|
||
|
|
||
|
def _coerce_calls(raw: Any) -> List[Dict[str, Any]]:
    """Normalise an LLM reply into a list of call dicts.

    Accepts an already-decoded list or dict, or a JSON string (optionally
    wrapped in a Markdown code fence). A reply that cannot be decoded yields
    ``[]``. Entries that are not dicts with both ``endpoint`` and ``method``
    are dropped.
    """
    if isinstance(raw, str):
        text = raw.strip()
        if text.startswith("```"):
            # Drop the opening fence line (``` or ```json) and the closing fence.
            text = text.partition("\n")[2].rsplit("```", 1)[0]
        try:
            raw = json.loads(text)
        except ValueError:
            return []
    if isinstance(raw, dict):
        raw = [raw]
    if not isinstance(raw, list):
        return []
    return [c for c in raw if isinstance(c, dict) and "endpoint" in c and "method" in c]


async def run_service(url, p):
    """Let the LLM plan API calls for prompt *p* against one service, then execute them.

    Args:
        url: URL of the service's OpenAPI document. Calls are sent to its
            scheme://host.
        p: Natural-language request for the LLM.

    Returns:
        One response per executed call, in order. ``[]`` if the reply contained
        no usable calls.
    """
    from urllib.parse import urlparse

    parsed_url = urlparse(url)
    # NOTE(review): any path prefix in `url` (e.g. /api/memory/) and the spec's
    # `servers` entry are ignored; endpoints are joined to scheme://host only.
    base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"

    system_message = await get_system_message(url)
    calls = _coerce_calls(await prompt(p, system_message))
    if not calls:
        return []
    async with aiohttp.ClientSession() as session:
        return list(await asyncio.gather(*(make_call(session, base_url, call) for call in calls)))
|
||
|
|
||
|
# OpenAPI documents queried by run_services() when no URL list is given.
DEFAULT_URLS = ["https://ada.molodetz.nl/api/memory/openapi.json"]


async def run_services(p, urls=None):
    """Run prompt *p* against every service in *urls* concurrently.

    Falls back to ``DEFAULT_URLS`` when *urls* is ``None`` or empty. Returns one
    result list per URL, in input order.
    """
    targets = urls or DEFAULT_URLS
    pending = [asyncio.create_task(run_service(target, p)) for target in targets]
    return await asyncio.gather(*pending)
|
||
|
|
||
|
if __name__ == "__main__":
    # Example:
    # python3 openapi.py --url=http://localhost:8000/openapi.json --prompt "Make a note for X and make a note for Y"
    parser = argparse.ArgumentParser(description="API Caller")
    parser.add_argument('--url', type=str, default="https://tools.molodetz.online/openapi.json", help="URL of the OpenAPI spec; API calls go to its scheme and host")
    # Required: without a prompt the LLM would be sent None.
    parser.add_argument('--prompt', type=str, required=True, help="Prompt for LLM")
    args = parser.parse_args()
    print(asyncio.run(run_service(args.url, args.prompt)))
|