#ifndef PGS_SOCK_H
#define PGS_SOCK_H
#include "protocol.h"
#include <arpa/inet.h>
#include <errno.h>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/in.h>
#include <poll.h>
#include <pthread.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/epoll.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
#include "py.h"
/* Number of slots in the fd-indexed `connections` table (so descriptors must
   be below this value) and the maximum number of events serve() fetches per
   epoll_wait() call. */
#define MAX_EVENTS 8096
/* Byte size of the I/O buffers used by forward_data() and sock_read_until(). */
#define BUFFER_SIZE 1024

/*
 * Per-descriptor proxy state. serve() looks a slot up by the fd an epoll
 * event fired on, so a client socket and its upstream socket each own a slot.
 */
typedef struct {
  /* Processing state: PS_SNIFF after accept, then derived from the sniffed
     protocol (e.g. PS_HTTP_READ_HEADER or PS_STREAM). */
  PROTOCOL_STATUS status;
  /* Result of protocol_sniff() on the client socket. */
  PROTOCOL_NAME protocol_name;
  /* Accepted client socket. */
  int client_fd;
  /* Upstream socket obtained from py_on_connect(); -1 until one is opened.
     In the upstream socket's own slot this holds that slot's fd. */
  int upstream_fd;
  /* sock_read() treats buffer[buffer_offset .. buffer_size) as pending,
     not-yet-returned bytes. NOTE(review): nothing in this header fills it. */
  char *buffer;
  size_t buffer_size;
  size_t buffer_offset;
  /* Header block returned by py_http_intercept_headers(); ownership/lifetime
     is not established in this header. */
  char *http_intercepted_headers;
} connection_t;

/* Listening socket; assigned by serve(). */
int listen_fd = 0;
/* epoll instance; assigned by serve(). */
int epoll_fd = 0;
/*
 * Connection table indexed by file descriptor (client or upstream).
 * Each row holds exactly one connection_t, so `connections[fd]` decays to a
 * `connection_t *` for that descriptor's slot and the existing
 * `connections[fd]->field` access pattern keeps working, while only one
 * struct is reserved per descriptor.
 * NOTE(review): these are definitions in a header, so it can only be
 * included from a single translation unit.
 */
connection_t connections[MAX_EVENTS][1] = {0};

/* Lifecycle hooks declared here; their definitions are not in this header.
   NOTE(review): presumably set up / tear down process-wide socket state —
   confirm against the implementation. */
int sock_init(void);
void sock_exit(void);

/*
 * Put `fd` into non-blocking mode while keeping its other file status flags.
 * Failure of either fcntl() call is fatal: the error is reported with
 * perror() and the process exits with EXIT_FAILURE.
 */
void set_nonblocking(int fd) {
  const char *failed_step = NULL;
  const int current_flags = fcntl(fd, F_GETFL, 0);

  if (current_flags == -1) {
    failed_step = "fcntl get";
  } else if (fcntl(fd, F_SETFL, current_flags | O_NONBLOCK) == -1) {
    failed_step = "fcntl set";
  }

  if (failed_step != NULL) {
    perror(failed_step);
    exit(EXIT_FAILURE);
  }
}

/*
 * Create a fresh, unconnected IPv4 TCP socket.
 * Returns the new descriptor, or -1 if socket() fails (errno set by socket()).
 */
int prepare_upstream() {
  return socket(AF_INET, SOCK_STREAM, 0);
}

/*
 * Start a non-blocking TCP connection to an IPv4 `host` (dotted-quad string,
 * no name resolution) on `port`.
 *
 * The connect may still be in progress when this returns; the caller must
 * wait for writability before treating the socket as connected.
 *
 * Returns the socket descriptor, or -1 if the address is invalid, the socket
 * cannot be created, or connect() fails immediately.
 */
int connect_upstream(const char *host, int port) {
  if (host == NULL || port <= 0 || port > 65535) {
    fprintf(stderr, "connect_upstream: invalid address %s:%d\n",
            host != NULL ? host : "(null)", port);
    return -1;
  }

  struct sockaddr_in server_addr;
  memset(&server_addr, 0, sizeof(server_addr));
  server_addr.sin_family = AF_INET;
  server_addr.sin_port = htons((uint16_t)port);

  int parsed = inet_pton(AF_INET, host, &server_addr.sin_addr);
  if (parsed == 0) {
    /* inet_pton() does not set errno for a malformed string, so perror()
       would report an unrelated error here. */
    fprintf(stderr, "inet_pton: '%s' is not a valid IPv4 address\n", host);
    return -1;
  }
  if (parsed < 0) {
    perror("inet_pton");
    return -1;
  }

  int sockfd = socket(AF_INET, SOCK_STREAM, 0);
  if (sockfd == -1) {
    perror("socket");
    return -1;
  }

  set_nonblocking(sockfd);

  if (connect(sockfd, (struct sockaddr *)&server_addr, sizeof(server_addr)) ==
      -1) {
    /* EINPROGRESS: normal for a non-blocking socket. EINTR: POSIX specifies
       that the connection is then established asynchronously as well. */
    if (errno != EINPROGRESS && errno != EINTR) {
      perror("connect");
      close(sockfd);
      return -1;
    }
  }

  return sockfd;
}

/*
 * Open a non-blocking IPv4 TCP listening socket bound to INADDR_ANY:`port`
 * with SO_REUSEADDR enabled and a SOMAXCONN backlog.
 *
 * Returns the listening descriptor, or -1 after reporting the failing call
 * with perror() (the partially set-up socket is closed).
 */
int create_listening_socket(int port) {
  /* Named `sock` so it does not shadow the global listen_fd. */
  int sock = socket(AF_INET, SOCK_STREAM, 0);
  if (sock == -1) {
    perror("socket");
    return -1;
  }

  const int reuse = 1;
  struct sockaddr_in bind_addr;
  memset(&bind_addr, 0, sizeof bind_addr);
  bind_addr.sin_family = AF_INET;
  bind_addr.sin_addr.s_addr = INADDR_ANY;
  bind_addr.sin_port = htons(port);

  const char *failed_call = NULL;
  if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof reuse) == -1) {
    failed_call = "setsockopt";
  } else if (bind(sock, (struct sockaddr *)&bind_addr, sizeof bind_addr) ==
             -1) {
    failed_call = "bind";
  } else if (listen(sock, SOMAXCONN) == -1) {
    failed_call = "listen";
  }

  if (failed_call != NULL) {
    perror(failed_call);
    close(sock);
    return -1;
  }

  set_nonblocking(sock);
  return sock;
}

char *sock_read(int fd, char *buf, size_t size) {
  connection_t *conn = connections[fd];
  size_t left_in_buffer = conn->buffer_size - conn->buffer_offset;
  size_t bytes_to_read = size > left_in_buffer ? left_in_buffer : size;
  ssize_t bytes_read = 0;
  char *buffer;
  buffer[size] = 0;
  if (bytes_to_read) {
    bytes_read = recv(fd, buffer, size, 0);
    buffer[bytes_read] = 0;
  }
  memcpy(buf, conn->buffer + conn->buffer_offset, bytes_to_read);

  if (bytes_read > 0) {
    return buf;
  } else if (bytes_read == 0) {
    printf("Connection closed by remote (fd=%d)\n", fd);
  } else {
    perror("read");
  }
  return NULL;
}

/*
 * Deregister from `epoll_fd` and close the client and upstream descriptors
 * recorded in `conn`, then mark each one as closed (-1) in `conn`.
 *
 * Clearing the fields makes a repeated call on the same slot a no-op, instead
 * of closing the same fd numbers again; by then accept() may already have
 * handed those numbers to a different client.
 *
 * NOTE(review): the peer's slot (connections[other fd]) still records both
 * fd numbers, so a stale event for the peer in the same epoll_wait() batch
 * can still reach them.
 */
void close_connection(int epoll_fd, connection_t *conn) {
  if (conn->client_fd != -1) {
    epoll_ctl(epoll_fd, EPOLL_CTL_DEL, conn->client_fd, NULL);
    close(conn->client_fd);
    conn->client_fd = -1;
  }
  if (conn->upstream_fd != -1) {
    epoll_ctl(epoll_fd, EPOLL_CTL_DEL, conn->upstream_fd, NULL);
    close(conn->upstream_fd);
    conn->upstream_fd = -1;
  }
}

/*
 * Move one chunk (at most BUFFER_SIZE bytes) from `from_fd` to `to_fd`.
 *
 * Performs a single recv() on from_fd, then writes everything received to
 * to_fd. Short writes are retried so no bytes are silently dropped. When to_fd
 * is non-blocking and momentarily full, it waits up to
 * FORWARD_WRITE_TIMEOUT_MS for it to become writable. MSG_NOSIGNAL keeps a
 * vanished peer from raising SIGPIPE; send() then fails with EPIPE instead.
 *
 * Returns the number of bytes forwarded (> 0), 0 when from_fd reported
 * end-of-stream, or -1 when reading or writing failed.
 */
int forward_data(int from_fd, int to_fd) {
  enum { FORWARD_WRITE_TIMEOUT_MS = 5000 };
  /* Per-call stack buffer: no shared static state between calls. */
  char buffer[BUFFER_SIZE];

  ssize_t bytes_read = recv(from_fd, buffer, sizeof(buffer), 0);
  if (bytes_read == 0) {
    printf("Connection closed by remote (fd=%d)\n", from_fd);
    return 0;
  }
  if (bytes_read < 0) {
    perror("read");
    return -1;
  }

  size_t sent = 0;
  while (sent < (size_t)bytes_read) {
    ssize_t n = send(to_fd, buffer + sent, (size_t)bytes_read - sent,
                     MSG_NOSIGNAL);
    if (n > 0) {
      sent += (size_t)n;
      continue;
    }
    if (n < 0 && errno == EINTR) {
      continue;
    }
    if (n < 0 && (errno == EAGAIN || errno == EWOULDBLOCK)) {
      struct pollfd pfd = {.fd = to_fd, .events = POLLOUT, .revents = 0};
      int ready = poll(&pfd, 1, FORWARD_WRITE_TIMEOUT_MS);
      if (ready > 0 || (ready < 0 && errno == EINTR)) {
        continue;
      }
      if (ready == 0) {
        fprintf(stderr, "write: fd=%d not writable within %d ms\n", to_fd,
                FORWARD_WRITE_TIMEOUT_MS);
        return -1;
      }
    }
    perror("write");
    return -1;
  }
  return (int)bytes_read;
}

/*
 * Send exactly `length` bytes starting at `content` on `fd`.
 *
 * Each send() is given only the bytes still outstanding, so short writes
 * never resend (or read past) data. The call is retried on EINTR. On EAGAIN
 * (non-blocking fd) it waits up to SEND_ALL_TIMEOUT_MS for writability.
 * MSG_NOSIGNAL turns a closed peer into an EPIPE error instead of SIGPIPE.
 *
 * Returns `length` on success (0 when length <= 0), or -1 on failure.
 */
ssize_t sock_send_all(int fd, char * content, ssize_t length) {
  enum { SEND_ALL_TIMEOUT_MS = 5000 };
  ssize_t bytes_total_sent = 0;

  while (bytes_total_sent < length) {
    ssize_t bytes_sent = send(fd, content + bytes_total_sent,
                              (size_t)(length - bytes_total_sent),
                              MSG_NOSIGNAL);
    if (bytes_sent > 0) {
      bytes_total_sent += bytes_sent;
      continue;
    }
    if (bytes_sent < 0 && errno == EINTR) {
      continue;
    }
    if (bytes_sent < 0 && (errno == EAGAIN || errno == EWOULDBLOCK)) {
      struct pollfd pfd = {.fd = fd, .events = POLLOUT, .revents = 0};
      int ready = poll(&pfd, 1, SEND_ALL_TIMEOUT_MS);
      if (ready > 0 || (ready < 0 && errno == EINTR)) {
        continue;
      }
    }
    return -1;
  }
  return bytes_total_sent;
}

bool handle_connect(struct epoll_event event, int epoll_fd) {
  struct sockaddr_in client_addr;
  socklen_t client_len = sizeof(client_addr);
  int client_fd =
      accept(listen_fd, (struct sockaddr *)&client_addr, &client_len);
  if (client_fd == -1) {
    perror("accept");
    return false;
  }
  set_nonblocking(client_fd);

  struct epoll_event client_event;
  client_event.events = EPOLLIN | EPOLLOUT | EPOLLERR | EPOLLHUP;
  client_event.data.ptr = connections[client_fd];
  client_event.data.fd = client_fd;
  connections[client_fd]->upstream_fd = -1;
  connections[client_fd]->client_fd = client_fd;
  connections[client_fd]->status = PS_SNIFF;
  connections[client_fd]->protocol_name = PS_NONE;
  printf("New connection: client_fd=%d\n", client_fd);

  epoll_ctl(epoll_fd, EPOLL_CTL_ADD, client_fd, &client_event);
  return true;
}

/*
 * Log which client/upstream descriptor pair is being torn down, then hand
 * the slot to close_connection() to deregister and close them.
 */
void handle_close(int epoll_fd, connection_t *conn) {
  const int client = conn->client_fd;
  const int upstream = conn->upstream_fd;

  printf("Connection closed: client_fd=%d, upstream_fd=%d\n", client, upstream);
  close_connection(epoll_fd, conn);
}

/*
 * Read from `fd` one byte at a time until the last `until_len` bytes read
 * equal `until`. Bytes are taken one at a time so nothing after the
 * terminator is consumed from the socket.
 *
 * Returns a heap-allocated, NUL-terminated copy of everything read, including
 * the terminator; the caller owns it and must free() it. Reading also stops
 * early (returning what was read) when:
 *   - the peer closes the connection;
 *   - recv() fails (e.g. EAGAIN on a non-blocking socket);
 *   - READ_UNTIL_LIMIT bytes have been collected.
 * `until` must point to at least `until_len` bytes. If until is NULL or
 * until_len <= 0, an empty string is returned without reading.
 * Returns NULL only if the initial allocation fails.
 */
char * sock_read_until(int fd, char * until, int until_len) {
  /* Upper bound on growth: the input comes from the network. */
  enum { READ_UNTIL_LIMIT = 64 * BUFFER_SIZE };

  size_t capacity = BUFFER_SIZE;
  char *result = malloc(capacity);
  if (result == NULL) {
    perror("malloc");
    return NULL;
  }
  size_t total = 0;
  result[0] = '\0';
  if (until == NULL || until_len <= 0) {
    return result;
  }
  const size_t needle_len = (size_t)until_len;

  for (;;) {
    /* Keep room for the next byte plus the terminating NUL. */
    if (total + 1 >= capacity) {
      if (capacity >= (size_t)READ_UNTIL_LIMIT) {
        fprintf(stderr, "sock_read_until: no terminator within %d bytes (fd=%d)\n",
                READ_UNTIL_LIMIT, fd);
        break;
      }
      size_t grown_capacity = capacity * 2;
      if (grown_capacity > (size_t)READ_UNTIL_LIMIT) {
        grown_capacity = (size_t)READ_UNTIL_LIMIT;
      }
      char *grown = realloc(result, grown_capacity);
      if (grown == NULL) {
        perror("realloc");
        break;
      }
      result = grown;
      capacity = grown_capacity;
    }

    char byte;
    ssize_t n = recv(fd, &byte, 1, 0);
    if (n < 0 && errno == EINTR) {
      continue;
    }
    if (n <= 0) {
      break;
    }
    result[total++] = byte;

    /* Compare the tail of what has been read so far against the terminator. */
    if (total >= needle_len &&
        memcmp(result + total - needle_len, until, needle_len) == 0) {
      break;
    }
  }
  result[total] = '\0';
  return result;
}

/*
 * Read the client's HTTP header block (through the first "\r\n\r\n"), pass
 * it to py_http_intercept_headers(), and send the rewritten headers on to the
 * upstream server.
 *
 * On success the rewritten headers are stored in conn->http_intercepted_headers,
 * the client slot's status becomes PS_NONE, and the upstream slot's status
 * becomes PS_HTTP_READ_HEADER.
 * The connection is closed if any of these happens:
 *   - the headers cannot be read;
 *   - the Python hook returns NULL;
 *   - the headers cannot be fully sent upstream.
 *
 * NOTE(review): `headers` is heap-allocated by sock_read_until() and is not
 * freed here. Free it once it is confirmed that py_http_intercept_headers()
 * does not retain the pointer.
 */
void http_intercept_headers(int epoll_fd, connection_t *conn) {
  char *headers = sock_read_until(conn->client_fd, "\r\n\r\n", 4);
  if (headers == NULL) {
    close_connection(epoll_fd, conn);
    return;
  }

  char *intercepted_headers =
      py_http_intercept_headers(conn->client_fd, headers);
  if (!intercepted_headers) {
    close_connection(epoll_fd, conn);
    return;
  }

  /* The request headers were consumed from the client socket, so they must be
     delivered to the upstream server, not echoed back to the client. */
  ssize_t length = (ssize_t)strlen(intercepted_headers);
  if (sock_send_all(conn->upstream_fd, intercepted_headers, length) != length) {
    fprintf(stderr, "Failed to send intercepted headers (upstream_fd=%d)\n",
            conn->upstream_fd);
    close_connection(epoll_fd, conn);
    return;
  }
  printf("Intercepted headers: %s\n", intercepted_headers);

  conn->http_intercepted_headers = intercepted_headers;
  conn->status = PS_NONE;
  if (conn->upstream_fd >= 0 && conn->upstream_fd < MAX_EVENTS) {
    connections[conn->upstream_fd]->status = PS_HTTP_READ_HEADER;
  }
}


/*
 * Placeholder for HTTP body interception. It currently only writes the
 * marker line "HIERR" to stdout and does not touch the connection.
 */
void http_intercept_body(int epoll_fd, connection_t *conn) {
  (void)epoll_fd;
  (void)conn;
  fputs("HIERR\n", stdout);
}

void handle_stream(struct epoll_event event, int epoll_fd, connection_t *conn) {
  if (conn->upstream_fd == -1) {
    conn->upstream_fd = py_on_connect(conn->client_fd);
    if(conn->upstream_fd == -1){
      close_connection(epoll_fd, conn);
      return;
    }
    conn->protocol_name = protocol_sniff(conn->client_fd);
    if(conn->protocol_name == PN_ERROR){
      close_connection(epoll_fd, conn);
      return;
    }else if(conn->protocol_name == PN_HTTP){
      conn->status = PS_HTTP_READ_HEADER;
    }else if(conn->protocol_name == PN_HTTP_CHUNKED){
      conn->status = PS_HTTP_READ_HEADER;
    }else if(conn->protocol_name == PN_HTTP_KEEP_ALIVE){
      conn->status = PS_HTTP_READ_HEADER;
    }else if(conn->protocol_name == PN_HTTP_REQUEST){
      conn->status = PS_HTTP_READ_HEADER;
    }else if(conn->protocol_name == PN_HTTP_WEBSOCKET){
      conn->status = PS_HTTP_READ_HEADER;
    }else if(conn->protocol_name == PN_SSH){
      conn->status = PS_STREAM;
    }else if(conn->protocol_name == PN_RAW){
      conn->status = PS_STREAM;
    }else {
      conn->status = PS_ERROR;
    }

    if (conn->upstream_fd == -1 || conn->status == PS_ERROR) {
      close_connection(epoll_fd, conn);
      return;
    }
    set_nonblocking(conn->upstream_fd);
    struct epoll_event upstream_event;
    upstream_event.events = EPOLLIN | EPOLLOUT | EPOLLERR | EPOLLHUP;
    upstream_event.data.ptr = connections[conn->upstream_fd];
    upstream_event.data.fd = conn->upstream_fd;
    connections[conn->client_fd]->upstream_fd = conn->upstream_fd;
    connections[conn->upstream_fd]->client_fd = conn->client_fd;
    connections[conn->upstream_fd]->upstream_fd = conn->upstream_fd;

    epoll_ctl(epoll_fd, EPOLL_CTL_ADD, conn->upstream_fd, &upstream_event);

    printf("Connected: client_fd=%d, upstream_fd=%d\n", conn->client_fd,
           conn->upstream_fd);

    return;
  }

  if (event.data.fd == conn->client_fd) {
    if(conn->status = PS_HTTP_READ_HEADER){
      http_intercept_headers(epoll_fd, conn);
    }
    if(conn->status == PS_HTTP_READ_BODY){
      http_intercept_body(epoll_fd, conn);
    }
    if (forward_data(conn->client_fd, conn->upstream_fd) < 1) {
      close_connection(epoll_fd, conn);
    }
  } else if (event.data.fd == conn->upstream_fd) {

    if (forward_data(conn->upstream_fd, conn->client_fd) < 1) {
      close_connection(epoll_fd, conn);
    }
  }
}

/*
 * Run the proxy's event loop on `port`. The function:
 *   1. creates the listening socket and an epoll instance (stored in the
 *      listen_fd / epoll_fd globals);
 *   2. dispatches events until epoll_wait() fails;
 *   3. hands new clients to handle_connect();
 *   4. sends error/hang-up events to handle_close();
 *   5. sends readable events to handle_stream().
 *
 * Interrupted waits (EINTR, e.g. a signal delivered to the process) are
 * retried rather than treated as fatal. On setup failure the descriptors
 * opened so far are closed and the globals are reset to -1.
 */
void serve(int port) {

  listen_fd = create_listening_socket(port);
  if (listen_fd == -1) {
    fprintf(stderr, "Failed to create listening socket\n");
    return;
  }

  epoll_fd = epoll_create1(0);
  if (epoll_fd == -1) {
    perror("epoll_create1");
    close(listen_fd);
    listen_fd = -1;
    return;
  }

  struct epoll_event event;
  memset(&event, 0, sizeof(event));
  event.events = EPOLLIN;
  event.data.fd = listen_fd;

  if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, listen_fd, &event) == -1) {
    perror("epoll_ctl");
    close(listen_fd);
    close(epoll_fd);
    listen_fd = -1;
    epoll_fd = -1;
    return;
  }

  struct epoll_event events[MAX_EVENTS];
  memset(events, 0, sizeof(events));

  printf("Pretty Good Server listening on port %d\n", port);

  while (1) {
    int num_events = epoll_wait(epoll_fd, events, MAX_EVENTS, -1);
    if (num_events == -1) {
      if (errno == EINTR) {
        continue; /* a signal interrupted the wait; nothing is wrong */
      }
      perror("epoll_wait");
      break;
    }

    for (int i = 0; i < num_events; i++) {
      const int fd = events[i].data.fd;
      const uint32_t flags = events[i].events;

      if (fd == listen_fd) {
        handle_connect(events[i], epoll_fd);
        continue;
      }

      connection_t *conn = connections[fd];
      if (flags & (EPOLLHUP | EPOLLERR)) {
        handle_close(epoll_fd, conn);
      } else if (flags & EPOLLIN) {
        handle_stream(events[i], epoll_fd, conn);
      }
    }
  }
}
#endif