#define PY_SSIZE_T_CLEAN 1
#include "pgs.h"
#include <Python.h>
#include "py.h"
#include "sock.h"
#include <arpa/inet.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>

// Python: connect(host: str, port: int) -> int
//
// Parses a host string and a port number, hands them to connect_upstream(),
// and returns whatever that call yields as a Python int. A result of -1 is
// passed through unchanged, so scripts can keep testing for it.
//
// Raises TypeError (from PyArg_ParseTuple) when the arguments do not match
// "si".
static PyObject *pgs_api_connect(PyObject *self, PyObject *args) {
  char *host;
  int port;
  if (!PyArg_ParseTuple(args, "si", &host, &port)) {
    // An exception is already set here; returning a non-NULL object as well
    // makes the interpreter raise SystemError, so propagate with NULL.
    return NULL;
  }
  // NOTE(review): `host` points into memory owned by the Python str object;
  // presumably connect_upstream() only reads it -- confirm.
  int fd = connect_upstream(host, port);
  return PyLong_FromLong(fd);
}
static PyObject *pgs_api_is_http(PyObject *self, PyObject *args) {
  const char *py_bytes;
  if (!PyArg_ParseTuple(args, "y", &py_bytes)) {
    return NULL;
  }
  bool is_http = strstr(py_bytes, "HTTP/") != NULL;

  return is_http ? Py_True : Py_False;
}
static PyObject *pgs_api_is_ssh(PyObject *self, PyObject *args) {
  const char *py_bytes;
  if (!PyArg_ParseTuple(args, "y", &py_bytes)) {
    return NULL;
  }

  bool is_ssh = strstr(py_bytes, "SSH") != NULL;

  return is_ssh ? Py_True : Py_False;
}
// Python: read(fd: int, length: int) -> bytes
//
// Performs a single read() of at most `length` bytes from `fd` and returns
// exactly the bytes received (possibly fewer than requested, empty when
// read() returns 0). The result is binary-safe: NUL bytes in the data are
// preserved.
//
// Raises ValueError for a negative length, MemoryError if the temporary
// buffer cannot be allocated, and OSError (from errno) when read() fails.
static PyObject *pgs_api_read(PyObject *self, PyObject *args) {
  int fd, length;
  if (!PyArg_ParseTuple(args, "ii", &fd, &length)) {
    return NULL;
  }
  if (length < 0) {
    PyErr_SetString(PyExc_ValueError, "length must be non-negative");
    return NULL;
  }
  // Heap allocation: the size is caller-controlled, so a stack VLA could
  // overflow the stack. The +1 keeps the request non-zero when length == 0.
  char *buffer = malloc((size_t)length + 1);
  if (buffer == NULL) {
    return PyErr_NoMemory();
  }
  ssize_t bytes_read = read(fd, buffer, (size_t)length);
  if (bytes_read < 0) {
    // Capture errno into the exception before free() can disturb it.
    PyErr_SetFromErrno(PyExc_OSError);
    free(buffer);
    return NULL;
  }
  // Explicit length, so the data is not cut short at the first NUL byte.
  PyObject *result = PyBytes_FromStringAndSize(buffer, (Py_ssize_t)bytes_read);
  free(buffer);
  return result;
}

// Python: peek(fd: int, length: int) -> bytes
//
// Performs a single recv() with MSG_PEEK for at most `length` bytes on `fd`
// and returns exactly the bytes received; MSG_PEEK leaves the data queued on
// the socket. The result is binary-safe: NUL bytes in the data are preserved.
//
// Raises ValueError for a negative length, MemoryError if the temporary
// buffer cannot be allocated, and OSError (from errno) when recv() fails.
static PyObject *pgs_api_peek(PyObject *self, PyObject *args) {
  int fd, length;
  if (!PyArg_ParseTuple(args, "ii", &fd, &length)) {
    return NULL;
  }
  if (length < 0) {
    PyErr_SetString(PyExc_ValueError, "length must be non-negative");
    return NULL;
  }
  // Heap allocation: the size is caller-controlled, so a stack VLA could
  // overflow the stack. The +1 keeps the request non-zero when length == 0.
  char *buffer = malloc((size_t)length + 1);
  if (buffer == NULL) {
    return PyErr_NoMemory();
  }
  ssize_t bytes_read = recv(fd, buffer, (size_t)length, MSG_PEEK);
  if (bytes_read < 0) {
    // Capture errno into the exception before free() can disturb it.
    PyErr_SetFromErrno(PyExc_OSError);
    free(buffer);
    return NULL;
  }
  // Explicit length, so the data is not cut short at the first NUL byte.
  PyObject *result = PyBytes_FromStringAndSize(buffer, (Py_ssize_t)bytes_read);
  free(buffer);
  return result;
}

// Python: write(fd: int, data: str | bytes-like) -> int
//
// Sends all of `data` on `fd`, calling send() repeatedly until every byte has
// been accepted. Returns the total length on success. If any send() call
// returns a value below 1, that value (0 or -1) is returned immediately
// instead; no exception is raised for a failed send.
//
// "s#" yields the buffer together with its byte length, so payloads that
// contain NUL bytes are sent in full and bytes objects are accepted as well
// as str.
static PyObject *pgs_api_write(PyObject *self, PyObject *args) {
  int fd;
  const char *input_data;
  Py_ssize_t input_length;

  if (!PyArg_ParseTuple(args, "is#", &fd, &input_data, &input_length)) {
    return NULL;
  }
  Py_ssize_t bytes_sent_total = 0;
  while (true) {
    // Only the unsent tail is offered to send(); asking for the full length
    // again after a partial send would read past the end of the buffer.
    size_t remaining = (size_t)(input_length - bytes_sent_total);
    ssize_t bytes_sent = send(fd, input_data + bytes_sent_total, remaining, 0);
    if (bytes_sent < 1) {
      return PyLong_FromSsize_t(bytes_sent);
    }
    bytes_sent_total += bytes_sent;
    if (bytes_sent_total == input_length) {
      break;
    }
  }
  return PyLong_FromSsize_t(input_length);
}
// Python: add(a: float, b: float) -> float
// Parses two numbers as C doubles and returns their sum as a Python float.
// Returns NULL (exception set by PyArg_ParseTuple) on bad arguments.
static PyObject *mymodule_add(PyObject *self, PyObject *args) {
  double lhs = 0.0;
  double rhs = 0.0;
  if (PyArg_ParseTuple(args, "dd", &lhs, &rhs)) {
    return PyFloat_FromDouble(lhs + rhs);
  }
  return NULL;
}

// Method table for the module
static PyMethodDef MyModuleMethods[] = {
    {"add", mymodule_add, METH_VARARGS, "Add two numbers"},
    {"connect", pgs_api_connect, METH_VARARGS, "Connect to upstream"},
    {"read", pgs_api_read, METH_VARARGS, "Read fd"},
    {"peek", pgs_api_peek, METH_VARARGS, "Peek fd"},
    {"write", pgs_api_write, METH_VARARGS, "Write fd"},
    {"is_ssh", pgs_api_is_ssh, METH_VARARGS,
     "Check if header contains SSH data."},
    {"is_http", pgs_api_is_http, METH_VARARGS,
     "Check if header contains HTTP data."},
    {NULL, NULL, 0, NULL} // Sentinel
};

// Module definition structure
static struct PyModuleDef mymodule = {
    PyModuleDef_HEAD_INIT,
    "pgs",                     // Module name
    "Pretty Good Server API.", // Module docstring
    -1,                        // Size of per-interpreter state of the module
    MyModuleMethods            // Methods
};

// Module initialization function
PyMODINIT_FUNC PyInit_mymodule(void) { return PyModule_Create(&mymodule); }