|
import asyncio
import http.client
import json
import os
from typing import Any
|
|
|
|
|
|
class GrokAPIClient:
|
|
def __init__(
|
|
self,
|
|
api_key: str,
|
|
system_message: str | None = None,
|
|
model: str = "grok-3-mini",
|
|
temperature: float = 0.0,
|
|
):
|
|
self.api_key = api_key
|
|
self.model = model
|
|
self.base_url = "api.x.ai"
|
|
self.temperature = temperature
|
|
self._messages: list[dict[str, str]] = []
|
|
if system_message:
|
|
self._messages.append({"role": "system", "content": system_message})
|
|
|
|
def chat_json(self, user_message: str, *, clear_history: bool = False) -> str:
|
|
return self.chat(user_message, clear_history=clear_history, use_json=True)
|
|
|
|
def chat_text(self, user_message: str, *, clear_history: bool = False) -> str:
|
|
return self.chat(user_message, clear_history=clear_history, use_json=False)
|
|
|
|
async def chat_async(self, *args, **kwargs):
|
|
return await asyncio.to_thread(self.chat, *args, **kwargs)
|
|
|
|
def chat(
|
|
self,
|
|
user_message: str,
|
|
*,
|
|
clear_history: bool = False,
|
|
use_json=False,
|
|
temperature: float = None,
|
|
) -> str:
|
|
if clear_history:
|
|
self.reset_history(keep_system=True)
|
|
self._messages.append({"role": "user", "content": user_message})
|
|
conn = http.client.HTTPSConnection(self.base_url)
|
|
headers = {
|
|
"Authorization": f"Bearer {self.api_key}",
|
|
"Content-Type": "application/json",
|
|
}
|
|
if temperature is None:
|
|
temperature = self.temperature
|
|
payload = {
|
|
"model": self.model,
|
|
"messages": self._messages,
|
|
"temperature": temperature,
|
|
}
|
|
conn.request(
|
|
"POST", "/v1/chat/completions", body=json.dumps(payload), headers=headers
|
|
)
|
|
response = conn.getresponse()
|
|
data = response.read()
|
|
try:
|
|
data = json.loads(data.decode())
|
|
except Exception as e:
|
|
print(data, flush=True)
|
|
raise e
|
|
conn.close()
|
|
try:
|
|
assistant_reply = data["choices"][0]["message"]["content"]
|
|
except Exception as e:
|
|
print(e)
|
|
print(data)
|
|
assistant_reply = data
|
|
self._messages.append({"role": "assistant", "content": assistant_reply})
|
|
if use_json:
|
|
return self._force_json(assistant_reply)
|
|
return assistant_reply
|
|
|
|
def _force_json(self, user_message: str) -> str:
|
|
try:
|
|
return json.loads(user_message)
|
|
except json.JSONDecodeError:
|
|
pass
|
|
try:
|
|
return json.loads(user_message.split("\n")[1:-1])
|
|
except json.JSONDecodeError:
|
|
pass
|
|
try:
|
|
index_start = -1
|
|
index_end = -1
|
|
chunks = []
|
|
for index, line in enumerate(user_message.split("\n")):
|
|
if "```json" in line:
|
|
index_start = index + 1
|
|
if index_start != -1 and "```" in line:
|
|
index_end = index - 1
|
|
chunks.append(
|
|
self._force_json(
|
|
user_message.split("\n")[index_start:index_end]
|
|
)
|
|
)
|
|
index_start = -1
|
|
index_end = -1
|
|
if chunks:
|
|
return chunks
|
|
except:
|
|
pass
|
|
return user_message
|
|
|
|
def reset_history(self, *, keep_system: bool = True) -> None:
|
|
if keep_system and self._messages and self._messages[0]["role"] == "system":
|
|
self._messages = [self._messages[0]]
|
|
else:
|
|
self._messages = []
|
|
|
|
@property
|
|
def messages(self) -> list[dict[str, str]]:
|
|
return list(self._messages)
|
|
|
|
|
|
def prompt(
|
|
prompt_str: str, system_message: str = "You are a helpful assistan", use_json=True
|
|
) -> str:
|
|
client = GrokAPIClient(system_message=system_message)
|
|
return client.chat(prompt_str, use_json=use_json)
|