diff --git a/chat.h b/chat.h index 4fb2306..b86a189 100644 --- a/chat.h +++ b/chat.h @@ -31,6 +31,7 @@ #ifndef R_PROMPT_H #define R_PROMPT_H +#include "tools.h" #include <json-c/json.h> #include "messages.h" #include "http.h" @@ -56,10 +57,17 @@ void chat_free() { char *chat_json(char *role, char *message) { chat_free(); - message_add(role, message); - struct json_object *root_object = json_object_new_object(); json_object_object_add(root_object, "model", json_object_new_string(prompt_model)); + + if (role != NULL && message != NULL) { + message_add(role, message); + + json_object_object_add(root_object, "tools", tools_descriptions()); + } + + + json_object_object_add(root_object, "messages", message_list()); json_object_object_add(root_object, "max_tokens", json_object_new_int(prompt_max_tokens)); json_object_object_add(root_object, "temperature", json_object_new_double(prompt_temperature)); @@ -67,4 +75,4 @@ char *chat_json(char *role, char *message) { return (char *)json_object_to_json_string_ext(root_object, JSON_C_TO_STRING_PRETTY); } -#endif \ No newline at end of file +#endif diff --git a/http.h b/http.h index c5456ea..bf7a9f2 100644 --- a/http.h +++ b/http.h @@ -127,6 +127,39 @@ size_t hex_to_int(const char *hex) { return result; } +bool http_has_header(char * http_headers, char * http_header_name){ + char search_for[strlen(http_header_name)+10]; + search_for[0] = '\0'; + strcpy(search_for,http_header_name); + strcat(search_for,": "); + return strstr(http_headers, search_for) != NULL; +} + +char * http_header_get_str(char * http_headers, char * http_header_name){ + char search_for[strlen(http_header_name)+10]; + search_for[0] = '\0'; + strcpy(search_for,http_header_name); + strcat(search_for,": "); + char * header = strstr(http_headers, search_for); + if(header == NULL){ + return NULL; + } + header += strlen(search_for); + char * end = strstr(header, "\r\n"); + *end = '\0'; + char * result = (char *)malloc(end - header + 1); + strcpy(result, header); + return result; +} +long http_header_get_long(char * http_headers, char * http_header_name){ + char * str_value = http_header_get_str(http_headers, http_header_name); + if(str_value == NULL) + return 0; + long long_value = atol(str_value); + free(str_value); + return long_value; +} + char *https_post(char *url, char *data) { url_t parsed_url; parse_url(url, &parsed_url); @@ -142,13 +175,13 @@ char *https_post(char *url, char *data) { SSL_set_fd(ssl, sock); int buffer_size = 1024 * 1024; - char *buffer = malloc(buffer_size); + char *buffer = (char *)malloc(buffer_size); size_t chunk_size_total = 0; if (SSL_connect(ssl) <= 0) { ERR_print_errors_fp(stderr); } else { size_t len = strlen(data); - char *request = malloc(len + 1024 * 1024); + char *request = (char *)malloc(len + 1024 * 1024); sprintf(request, "POST %s HTTP/1.1\r\n" "Content-Length: %ld\r\n" @@ -164,6 +197,24 @@ char *https_post(char *url, char *data) { char *headers = read_until_ssl(ssl, "\r\n\r\n"); (void)headers; + long content_length = http_header_get_long(headers, "Content-Length"); + if (content_length){ + if( content_length > buffer_size) { + buffer_size = content_length; + buffer = (char *)realloc(buffer, buffer_size); + } + size_t bytes_read = 0; + while(bytes_read < content_length) { + int bytes_read_chunk = SSL_read(ssl, buffer + bytes_read, buffer_size - bytes_read); + if(bytes_read_chunk <= 0){ + free(buffer); + return NULL; + } + bytes_read += bytes_read_chunk; + } + return buffer; + } + size_t actual_buffer_size = buffer_size; while (true) { char *header = read_until_ssl(ssl, "\r\n"); @@ -173,7 +224,7 @@ char *https_post(char *url, char *data) { size_t remaining = chunk_size; while (remaining > 0) { size_t to_read = (remaining < buffer_size) ? remaining : buffer_size; - buffer = realloc(buffer, actual_buffer_size + to_read + 1); + buffer = (char *)realloc(buffer, actual_buffer_size + to_read + 1); actual_buffer_size += to_read; size_t bytes_read = SSL_read(ssl, buffer + chunk_size_total, to_read); chunk_size_total += bytes_read; @@ -371,4 +422,4 @@ char *http_get(char *url) { return buffer; } -#endif \ No newline at end of file +#endif diff --git a/http_curl.h b/http_curl.h index 39bd1db..72f804a 100644 --- a/http_curl.h +++ b/http_curl.h @@ -78,4 +78,40 @@ char *curl_post(const char *url, const char *data) { } return NULL; } -#endif \ No newline at end of file + + +char *curl_get(const char *url) { + CURL *curl; + CURLcode res; + struct ResponseBuffer response = {malloc(1), 0}; + + if (!response.data) return NULL; + + curl = curl_easy_init(); + if (curl) { + struct curl_slist *headers = NULL; + curl_easy_setopt(curl, CURLOPT_URL, url); + headers = curl_slist_append(headers, "Content-Type: application/json"); + + char *bearer_header = malloc(1337); + sprintf(bearer_header, "Authorization: Bearer %s", resolve_api_key()); + headers = curl_slist_append(headers, bearer_header); + free(bearer_header); + + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, (void *)&response); + + res = curl_easy_perform(curl); + if (res != CURLE_OK) { + fprintf(stderr, "An error occurred: %s\n", curl_easy_strerror(res)); + } + + curl_slist_free_all(headers); + curl_easy_cleanup(curl); + } + + return response.data; +} + +#endif diff --git a/messages.h b/messages.h index df4ae3a..b2fa837 100644 --- a/messages.h +++ b/messages.h @@ -12,6 +12,8 @@ #ifndef R_MESSAGES_H #define R_MESSAGES_H +#include <string.h> +#include "tools.h" #include "json-c/json.h" struct json_object *message_array = NULL; @@ -23,11 +25,38 @@ struct json_object *message_list() { return message_array; } +struct json_object *message_add_tool_call(struct json_object *message) { + struct json_object *messages = message_list(); + json_object_array_add(messages, message); + return message; +} +struct json_object *message_add_tool_result(char * tool_call_id, char * tool_result) { + struct json_object *messages = message_list(); + struct json_object * message = json_object_new_object(); + + json_object_object_add(message, "tool_call_id", json_object_new_string(tool_call_id)); + json_object_object_add(message, "tool_result", json_object_new_string(tool_result)); + + json_object_array_add(messages, message); + return message; + +} + struct json_object *message_add(const char *role, const char *content) { struct json_object *messages = message_list(); struct json_object *message = json_object_new_object(); json_object_object_add(message, "role", json_object_new_string(role)); + + json_object_object_add(message, "content", json_object_new_string(content)); + + if(!strcmp(role, "user")) { + json_object_object_add(message, "tools", tools_descriptions()); + } + + + + json_object_array_add(messages, message); return message; } @@ -43,4 +72,4 @@ void message_free() { } } -#endif \ No newline at end of file +#endif diff --git a/openai.h b/openai.h index 4c90651..b08bb31 100644 --- a/openai.h +++ b/openai.h @@ -31,10 +31,13 @@ bool openai_system(char* content) { return is_done; } -char* openai_chat(char* role, char* content) { - char* url = "https://api.openai.com/v1/chat/completions"; - char* data = chat_json(role, content); + +struct json_object * chat_message(char * url, char * data){ + printf("QQQ %s\n",data); + char* result = curl_post(url, data); + printf("<<%s>>",result); + struct json_object* parsed_json = json_tokener_parse(result); if (!parsed_json) { fprintf(stderr, "Failed to parse JSON.\n"); @@ -56,20 +59,50 @@ char* openai_chat(char* role, char* content) { return NULL; } - struct json_object* message_object; + struct json_object* message_object; if (!json_object_object_get_ex(first_choice, "message", &message_object)) { fprintf(stderr, "Failed to get 'message' object.\n"); json_object_put(parsed_json); return NULL; } + return message_object; +} +char* openai_chat(char* role, char* content) { + char* url = "https://api.openai.com/v1/chat/completions"; + char* data = chat_json(role, content); + + json_object * message_object = chat_message(url, data); + + struct json_object *tool_calls; + json_object_object_get_ex(message_object, "tool_calls", &tool_calls); + + if (tool_calls) { + message_add_tool_call(message_object); + // char* tool_calls_str = (char *)json_object_to_json_string(tool_calls); + json_object * tool_call_results = tools_execute(tool_calls); + int array_len = json_object_array_length(tool_call_results); + // message_add_tool_call(tool_call_results); + for(int i = 0; i < array_len; i++){ + + json_object * tool_call_result = json_object_array_get_idx(tool_call_results, i); + message_add_tool_call(tool_call_result); + } + + char * tool_calls_result_str = chat_json(NULL,NULL); + + message_object = chat_message(url, tool_calls_result_str); + message_add_tool_call(message_object); + + } char* content_str = (char*)json_object_get_string(json_object_object_get(message_object, "content")); - message_add("assistant", content_str); - free(data); - free(result); + + + //message_add("assistant", content_str); + // free(data); + char* final_result = strdup(content_str); - json_object_put(parsed_json); return final_result; } -#endif \ No newline at end of file +#endif diff --git a/tools.h b/tools.h new file mode 100644 index 0000000..1b7f1b8 --- /dev/null +++ b/tools.h @@ -0,0 +1,134 @@ +#ifndef R_TOOLS_H +#define R_TOOLS_H + +#include <json-c/json.h> +#include <json-c/json_object.h> +#include <string.h> +#include "http_curl.h" + +struct json_object * tool_description_http_get(); + +struct json_object * tools_descriptions(){ + struct json_object *root = json_object_new_array(); + json_object_array_add(root, tool_description_http_get()); + return root; +} + + +char * tool_function_http_get(char * url) { + return curl_get(url); +} + +struct json_object * tool_description_http_get(){ + // Create root object + struct json_object *root = json_object_new_object(); + json_object_object_add(root, "type", json_object_new_string("function")); + + // Create function object + struct json_object *function = json_object_new_object(); + json_object_object_add(function, "name", json_object_new_string("http_fetch")); + json_object_object_add(function, "description", json_object_new_string("Get the contents of an url.")); + + // Create parameters object + struct json_object *parameters = json_object_new_object(); + json_object_object_add(parameters, "type", json_object_new_string("object")); + + // Create properties object + struct json_object *properties = json_object_new_object(); + struct json_object *url = json_object_new_object(); + json_object_object_add(url, "type", json_object_new_string("string")); + json_object_object_add(url, "description", json_object_new_string("fetch url contents.")); + json_object_object_add(properties, "url", url); + + json_object_object_add(parameters, "properties", properties); + + // Add required array + struct json_object *required = json_object_new_array(); + json_object_array_add(required, json_object_new_string("url")); + json_object_object_add(parameters, "required", required); + + // Add additionalProperties as false + json_object_object_add(parameters, "additionalProperties", json_object_new_boolean(0)); + + // Add parameters to function object + json_object_object_add(function, "parameters", parameters); + + // Add strict mode + json_object_object_add(function, "strict", json_object_new_boolean(1)); + + // Add function object to root + json_object_object_add(root, "function", function); + + + return root; +} + +struct json_object * tools_execute(struct json_object * tools_array){ + //struct json_object * tools_result = json_object_new_object(); + struct json_object * tools_result_messages = json_object_new_array(); + int array_len = json_object_array_length(tools_array); + // Iterate over array + for (int i = 0; i < array_len; i++) { + struct json_object *obj = json_object_array_get_idx(tools_array, i); + struct json_object * tool_result = json_object_new_object(); + + + json_object_object_add(tool_result, "tool_call_id", json_object_new_string(json_object_get_string(json_object_object_get(obj, "id")))); + + json_object_object_add(tool_result, "role", json_object_new_string("tool")); + + // Get "id" + struct json_object *id_obj; + if (json_object_object_get_ex(obj, "id", &id_obj)) { + + + printf("ID: %s\n", json_object_get_string(id_obj)); + } + + // Get "type" + struct json_object *type_obj; + if (json_object_object_get_ex(obj, "type", &type_obj)) { + printf("Type: %s\n", json_object_get_string(type_obj)); + } + + // Get "function" object + struct json_object *function_obj; + if (json_object_object_get_ex(obj, "function", &function_obj)) { + // Get function "name" + struct json_object *name_obj; + char * function_name = NULL; + if (json_object_object_get_ex(function_obj, "name", &name_obj)) { + function_name = (char *)json_object_get_string(name_obj); + printf("Function Name: %s\n", json_object_get_string(name_obj)); + } + + // Get function "arguments" + if(!strcmp(function_name,"http_fetch")){ + + + struct json_object *arguments_obj; + if (json_object_object_get_ex(function_obj, "arguments", &arguments_obj)) { + struct json_object *arguments = json_tokener_parse(json_object_get_string(arguments_obj));; + struct json_object *url_obj; + if (json_object_object_get_ex(arguments, "url", &url_obj)) { + char * url = (char *)json_object_get_string(url_obj); + char * http_result = curl_get(url); + printf("URL: %s\n", json_object_get_string(url_obj)); + printf("content: %s\n",http_result); + json_object_object_add(tool_result, "content", json_object_new_string(http_result)); + } + + printf("Function Arguments: %s\n", json_object_get_string(arguments_obj)); + } + } + json_object_array_add(tools_result_messages, tool_result); + } + + printf("\n"); // Add spacing for readability + } + + //json_object_object_add(tools_result, "messages", tools_result_messages); + return tools_result_messages; +} + +#endif