#include "r.h"
#include "db_utils.h"
#include "line.h"
#include "markdown.h"
#include "openai.h"
#include "tools.h"
#include "utils.h"
#include <locale.h>
#include <signal.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <time.h>
/* Count of SIGINTs in the current double-press window; modified only by handle_sigint. */
volatile sig_atomic_t sigint_count = 0;
/* Time of the first SIGINT in the current window; read and written only by handle_sigint. */
time_t first_sigint_time = 0;
/* When true, render() formats replies with parse_markdown_to_ansi; the --nh flag clears it. */
bool SYNTAX_HIGHLIGHT_ENABLED = true;
/* Set by the --api flag. Not read in this file; presumably consumed by another module — confirm. */
bool API_MODE = false;
/* Forward declarations of the file-local helpers defined below. */
static void render(const char *content);
static bool openai_include(const char *path);
static char *get_prompt_from_stdin(char *prompt);
static char *get_prompt_from_args(int argc, char **argv);
static bool try_prompt(int argc, char *argv[]);
static void repl(void);
static void init(void);
static void handle_sigint(int sig);
/*
 * Capture the output of the `env` command (run through popen) as one
 * NUL-terminated string.
 *
 * Returns a heap-allocated buffer the caller must free, or NULL when the
 * command cannot be started, memory runs out, or reading its output fails.
 * Diagnostics are written to stderr.
 */
char *get_env_string(void) {
    FILE *fp = popen("env", "r");
    if (fp == NULL) {
        perror("popen failed");
        return NULL;
    }
    size_t buffer_size = 1024;
    size_t total_size = 0;
    char *output = malloc(buffer_size);
    if (output == NULL) {
        perror("malloc failed");
        pclose(fp);
        return NULL;
    }
    size_t bytes_read;
    /*
     * fread never reads past buffer_size, and the buffer is doubled as soon
     * as it fills, so total_size < buffer_size holds after every iteration and
     * there is always room for the terminator.
     */
    while ((bytes_read = fread(output + total_size, 1, buffer_size - total_size, fp)) > 0) {
        total_size += bytes_read;
        if (total_size >= buffer_size) {
            buffer_size *= 2;
            char *temp = realloc(output, buffer_size);
            if (temp == NULL) {
                perror("realloc failed");
                free(output);
                pclose(fp);
                return NULL;
            }
            output = temp;
        }
    }
    /* fread returns 0 both at EOF and on error; don't hand back truncated output. */
    if (ferror(fp)) {
        fprintf(stderr, "Error: failed to read output of env.\n");
        free(output);
        pclose(fp);
        return NULL;
    }
    output[total_size] = '\0';
    pclose(fp);
    return output;
}
/*
 * Read all of standard input into a NUL-terminated string.
 *
 * `prompt` must be a heap buffer obtained from malloc, at least one byte long.
 * Ownership passes to this function: the buffer is resized with realloc as
 * input arrives, so the caller must use the returned pointer in place of
 * `prompt` and free it when done. Input of any length is accepted.
 *
 * On allocation failure an error is printed and an empty string is returned,
 * discarding partial input rather than sending a truncated prompt. The
 * returned buffer must still be freed.
 */
static char *get_prompt_from_stdin(char *prompt) {
    size_t capacity = 64 * 1024;
    char *buffer = realloc(prompt, capacity);
    if (buffer == NULL) {
        /* realloc failed, so `prompt` is untouched and still valid. */
        fprintf(stderr, "Error: Memory allocation failed.\n");
        prompt[0] = '\0';
        return prompt;
    }
    size_t length = 0;
    int c;
    while ((c = getchar()) != EOF) {
        /* Always keep one spare byte for the terminator. */
        if (length + 1 >= capacity) {
            char *grown = realloc(buffer, capacity * 2);
            if (grown == NULL) {
                fprintf(stderr, "Error: Memory allocation failed.\n");
                buffer[0] = '\0';
                return buffer;
            }
            buffer = grown;
            capacity *= 2;
        }
        buffer[length++] = (char)c;
    }
    buffer[length] = '\0';
    return buffer;
}
/*
 * Parse the command line and build the one-shot prompt, if any.
 *
 * Recognised flags are applied as side effects:
 *   --stdin         read the prompt from standard input; any free-text
 *                   arguments are then sent via openai_system instead
 *   --verbose       set is_verbose
 *   --py FILE       pass FILE through expand_home_directory, then include it
 *   --context FILE  include FILE via openai_include
 *   --free          call auth_free()
 *   --api           set API_MODE
 *   --nh            disable syntax highlighting
 * All other arguments are joined with single spaces, and a '.' is appended
 * after the last word.
 *
 * Returns a heap-allocated prompt the caller must free, or NULL when there is
 * nothing to send or an error occurred (reported on stderr).
 */
static char *get_prompt_from_args(int argc, char **argv) {
    enum {
        ARGS_TEXT_CAPACITY = 1024 * 1024,
        STDIN_PROMPT_CAPACITY = 10 * 1024 * 1024 + 1
    };
    char *args_text = malloc(ARGS_TEXT_CAPACITY);
    if (!args_text) {
        fprintf(stderr, "Error: Memory allocation failed.\n");
        return NULL;
    }
    /* malloc does not zero memory; the text is built by appending, so terminate it first. */
    args_text[0] = '\0';
    size_t text_len = 0;
    bool get_from_std_in = false;

    for (int i = 1; i < argc; i++) {
        const char *arg = argv[i];
        if (strcmp(arg, "--stdin") == 0) {
            fprintf(stderr, "Reading from stdin.\n");
            get_from_std_in = true;
        } else if (strcmp(arg, "--verbose") == 0) {
            is_verbose = true;
        } else if (strcmp(arg, "--py") == 0 && i + 1 < argc) {
            char *py_file_path = expand_home_directory(argv[++i]);
            fprintf(stderr, "Including \"%s\".\n", py_file_path);
            if (!openai_include(py_file_path))
                fprintf(stderr, "Warning: could not read \"%s\".\n", py_file_path);
            free(py_file_path);
        } else if (strcmp(arg, "--free") == 0) {
            auth_free();
        } else if (strcmp(arg, "--context") == 0 && i + 1 < argc) {
            char *context_file_path = argv[++i];
            fprintf(stderr, "Including \"%s\".\n", context_file_path);
            if (!openai_include(context_file_path))
                fprintf(stderr, "Warning: could not read \"%s\".\n", context_file_path);
        } else if (strcmp(arg, "--api") == 0) {
            API_MODE = true;
        } else if (strcmp(arg, "--nh") == 0) {
            SYNTAX_HIGHLIGHT_ENABLED = false;
            fprintf(stderr, "Syntax highlighting disabled.\n");
        } else {
            size_t arg_len = strlen(arg);
            size_t separator = text_len > 0 ? 1 : 0;
            /* Reserve room for the separator, the word, the final '.' and the terminator. */
            if (text_len + separator + arg_len + 2 > ARGS_TEXT_CAPACITY) {
                fprintf(stderr, "Error: Prompt is too long.\n");
                free(args_text);
                return NULL;
            }
            if (separator)
                args_text[text_len++] = ' ';
            memcpy(args_text + text_len, arg, arg_len);
            text_len += arg_len;
            args_text[text_len] = '\0';
        }
    }
    /* Terminate the sentence after the last word, even if flags followed it. */
    if (text_len > 0) {
        args_text[text_len++] = '.';
        args_text[text_len] = '\0';
    }

    if (get_from_std_in) {
        if (text_len > 0)
            openai_system(args_text);
        free(args_text);
        char *prompt = malloc(STDIN_PROMPT_CAPACITY);
        if (!prompt) {
            fprintf(stderr, "Error: Memory allocation failed.\n");
            return NULL;
        }
        prompt = get_prompt_from_stdin(prompt);
        if (!prompt || !*prompt) {
            free(prompt);
            return NULL;
        }
        return prompt;
    }

    if (text_len == 0) {
        free(args_text);
        return NULL;
    }
    return args_text;
}
/*
 * Answer a single non-interactive prompt taken from the command line (or
 * from stdin with --stdin), rendering the reply.
 *
 * Returns true when a prompt was answered. Returns false when there was no
 * prompt, in which case the caller starts the interactive REPL.
 *
 * If the server returns no response, the process exits with EXIT_FAILURE
 * instead of falling through to the REPL. With --stdin the input stream has
 * already been read to EOF, so an interactive session could not work anyway.
 */
static bool try_prompt(int argc, char *argv[]) {
    char *prompt = get_prompt_from_args(argc, argv);
    if (!prompt)
        return false;
    char *response = openai_chat("user", prompt);
    if (!response) {
        fprintf(stderr, "Could not get response from server\n");
        free(prompt);
        exit(EXIT_FAILURE);
    }
    render(response);
    free(response);
    free(prompt);
    return true;
}
/*
 * Load the file at `path` with read_file and pass its contents to
 * openai_system. Returns true if the file was read, false otherwise.
 */
static bool openai_include(const char *path) {
    char *contents = read_file(path);
    bool loaded = contents != NULL;
    if (loaded) {
        openai_system(contents);
    }
    free(contents); /* free(NULL) is a no-op */
    return loaded;
}
/*
 * Output a model reply. With syntax highlighting disabled (--nh) the text is
 * written verbatim to stdout. Otherwise it is handed to parse_markdown_to_ansi,
 * which presumably prints the formatted result itself.
 */
static void render(const char *content) {
    if (!SYNTAX_HIGHLIGHT_ENABLED) {
        fputs(content, stdout);
        return;
    }
    parse_markdown_to_ansi(content);
}
/* True when `line` is exactly `word`, optionally followed by whitespace. */
static bool repl_is_exact_command(const char *line, const char *word) {
    size_t word_len = strlen(word);
    if (strncmp(line, word, word_len) != 0)
        return false;
    const char *rest = line + word_len;
    return rest[strspn(rest, " \t\r\n")] == '\0';
}

/*
 * Interactive read-eval-print loop.
 *
 * Local commands:
 *   !dump            print message_json()
 *   !verbose         toggle is_verbose
 *   !tools           print tools_descriptions() as JSON
 *   !models          print openai_fetch_models()
 *   !model [NAME]    set (if NAME is given) and print the prompt model
 *   exit             terminate the process
 * !dump and !verbose are not added to history; every other non-empty line is.
 * Any other line is sent to the model. While a reply contains "_STEP_", the
 * loop automatically sends "continue".
 *
 * Returns when line_read yields NULL.
 */
static void repl(void) {
    line_init();
    while (true) {
        char *line = line_read("> ");
        /* NOTE(review): NULL is assumed to mean end of input (e.g. Ctrl-D);
         * re-reading after it would spin forever, so leave the loop. */
        if (!line)
            return;
        if (!*line)
            continue;
        if (!strncmp(line, "!dump", 5)) {
            printf("%s\n", message_json());
            continue;
        }
        if (!strncmp(line, "!verbose", 8)) {
            is_verbose = !is_verbose;
            fprintf(stderr, "%s\n", is_verbose ? "Verbose mode enabled" : "Verbose mode disabled");
            continue;
        }
        if (*line != '\n')
            line_add_history(line);
        if (!strncmp(line, "!tools", 6)) {
            printf("Available tools: %s\n", json_object_to_json_string(tools_descriptions()));
            continue;
        }
        /* Must be tested before "!model", which is a prefix of it. */
        if (!strncmp(line, "!models", 7)) {
            printf("Available models: %s\n", openai_fetch_models());
            continue;
        }
        if (!strncmp(line, "!model", 6)) {
            if (line[6] == ' ')
                set_prompt_model(line + 7);
            printf("Current model: %s\n", get_prompt_model());
            continue;
        }
        /* Exact match only, so a message such as "exit strategies?" is sent, not obeyed. */
        if (repl_is_exact_command(line, "exit"))
            exit(0);

        /* Send the line, then keep sending "continue" while the reply asks for another step. */
        char *message = (*line != '\n') ? line : NULL;
        while (message) {
            char *response = openai_chat("user", message);
            if (!response)
                exit(0);
            render(response);
            printf("\n");
            message = strstr(response, "_STEP_") ? "continue" : NULL;
            free(response);
        }
    }
}
/*
 * Build the system message that describes the local SQLite database, with
 * `schema` inserted verbatim ("unknown" if NULL). The buffer is sized
 * exactly, so large schemas are never truncated.
 * Returns a heap string the caller must free, or NULL on failure.
 */
static char *build_database_prompt(const char *schema) {
    static const char format[] =
        "# LOCAL DATABASE\n"
        "You have a local database that you can mutate using the query tool and "
        "the get and set tool. "
        "If you set a value using the tool, make sure that the key is stemmed and lowercased to prevent double entries. "
        "Dialect is sqlite. This is the schema in json format: %s. ";
    if (!schema)
        schema = "unknown";
    int needed = snprintf(NULL, 0, format, schema);
    if (needed < 0)
        return NULL;
    size_t size = (size_t)needed + 1;
    char *text = malloc(size);
    if (text)
        snprintf(text, size, format, schema);
    return text;
}

/*
 * One-time start-up:
 *   - make stdout unbuffered;
 *   - initialise line editing, authentication and the local database;
 *   - seed the conversation with system messages: the database description,
 *     the environment message from get_env_system_message, and the first
 *     .rcontext.txt found (working directory, then "~/.rcontext.txt").
 */
static void init(void) {
    setbuf(stdout, NULL);
    line_init();
    auth_init();
    db_initialize();
    char *schema = db_get_schema();
    char *database_prompt = build_database_prompt(schema);
    free(schema);
    fprintf(stderr, "Loading... 📨");
    if (database_prompt) {
        openai_system(database_prompt);
        free(database_prompt);
    } else {
        fprintf(stderr, "\nError: Memory allocation failed.\n");
    }
    char *env_system_message = get_env_system_message();
    /* NOTE(review): an empty-but-non-NULL result is not freed; confirm whether
     * get_env_system_message always returns heap memory before freeing it unconditionally. */
    if (env_system_message && *env_system_message) {
        openai_system(env_system_message);
        free(env_system_message);
    }
    /* NOTE(review): "~" is only honoured if read_file expands it (fopen does not); confirm. */
    if (!openai_include(".rcontext.txt")) {
        openai_include("~/.rcontext.txt");
    }
    /* Return to column 0 and erase the whole "Loading..." line. */
    fprintf(stderr, "\r\033[K");
}
/*
 * SIGINT handler. Every Ctrl-C writes a newline. A second Ctrl-C whose time()
 * stamp is at most one second after the first exits the process; otherwise
 * the window restarts from the latest press.
 *
 * Output uses write(2), which is async-signal-safe; printf is not.
 * NOTE(review): exit() is also not async-signal-safe. It is kept so atexit
 * cleanup still runs; switch to _exit() if no such cleanup is needed.
 */
static void handle_sigint(int sig) {
    (void)sig;
    time_t current_time = time(NULL);
    static const char newline = '\n';
    if (write(STDOUT_FILENO, &newline, 1) < 0) {
        /* Nothing useful can be done about a failed write inside a handler. */
    }
    if (sigint_count == 0) {
        first_sigint_time = current_time;
        sigint_count++;
    } else if (difftime(current_time, first_sigint_time) <= 1) {
        exit(0);
    } else {
        sigint_count = 1;
        first_sigint_time = current_time;
    }
}
/*
 * Entry point:
 *   1. install the Ctrl-C handler;
 *   2. run start-up;
 *   3. add the environment listing as a system message;
 *   4. answer a one-shot prompt if the arguments (or stdin) provide one,
 *      otherwise start the interactive REPL.
 */
int main(int argc, char *argv[]) {
    signal(SIGINT, handle_sigint);
    init();
    /*
     * NOTE(review): this passes the process's complete environment to
     * openai_system, which presumably forwards it to the remote model.
     * Environment variables commonly hold API keys and other credentials;
     * consider filtering sensitive names or making this opt-in.
     */
    char *env_string = get_env_string();
    if (env_string && *env_string)
        openai_system(env_string);
    free(env_string); /* get_env_string returns malloc'd memory or NULL */
    if (!try_prompt(argc, argv))
        repl();
    return 0;
}