This commit is contained in:
parent
378f64fc75
commit
dbd4fe22ed
@ -478,59 +478,38 @@ void connection_connect_to_upstream(connection_t *client, const char *data, size
|
||||
size_t len_to_send = data_len;
|
||||
char *modified_request = NULL;
|
||||
|
||||
if (route->rewrite_host) {
|
||||
char new_host_header[512];
|
||||
const char *old_host_header_start = NULL;
|
||||
const char *old_host_header_end = NULL;
|
||||
if (route->rewrite_host) {
|
||||
char new_host_header[512];
|
||||
const char *old_host_header_start = NULL;
|
||||
const char *old_host_header_end = NULL;
|
||||
|
||||
const char *current = data;
|
||||
const char *end = data + data_len;
|
||||
while(current < end) {
|
||||
const char* line_end = memchr(current, '\n', end - current);
|
||||
if (!line_end) break;
|
||||
|
||||
if (strncasecmp(current, "Host", 4) == 0) {
|
||||
const char *colon = current + 4;
|
||||
while (colon < line_end && (*colon == ' ' || *colon == '\t')) {
|
||||
colon++;
|
||||
if (http_find_header_line_bounds(data, data_len, "Host", &old_host_header_start, &old_host_header_end)) {
|
||||
int is_default_port = (route->use_ssl && route->upstream_port == 443) ||
|
||||
(!route->use_ssl && route->upstream_port == 80);
|
||||
if (is_default_port) {
|
||||
snprintf(new_host_header, sizeof(new_host_header), "Host: %s\r\n", route->upstream_host);
|
||||
} else {
|
||||
snprintf(new_host_header, sizeof(new_host_header), "Host: %s:%d\r\n", route->upstream_host, route->upstream_port);
|
||||
}
|
||||
if (colon < line_end && *colon == ':') {
|
||||
old_host_header_start = current;
|
||||
old_host_header_end = line_end + 1;
|
||||
break;
|
||||
size_t new_host_len = strlen(new_host_header);
|
||||
size_t old_host_len = old_host_header_end - old_host_header_start;
|
||||
|
||||
len_to_send = data_len - old_host_len + new_host_len;
|
||||
modified_request = malloc(len_to_send + 1);
|
||||
if (modified_request) {
|
||||
char* p = modified_request;
|
||||
size_t prefix_len = old_host_header_start - data;
|
||||
memcpy(p, data, prefix_len);
|
||||
p += prefix_len;
|
||||
memcpy(p, new_host_header, new_host_len);
|
||||
p += new_host_len;
|
||||
size_t suffix_len = data_len - (old_host_header_end - data);
|
||||
memcpy(p, old_host_header_end, suffix_len);
|
||||
data_to_send = modified_request;
|
||||
}
|
||||
}
|
||||
|
||||
current = line_end + 1;
|
||||
}
|
||||
|
||||
if (old_host_header_start) {
|
||||
int is_default_port = (route->use_ssl && route->upstream_port == 443) ||
|
||||
(!route->use_ssl && route->upstream_port == 80);
|
||||
if (is_default_port) {
|
||||
snprintf(new_host_header, sizeof(new_host_header), "Host: %s\r\n", route->upstream_host);
|
||||
} else {
|
||||
snprintf(new_host_header, sizeof(new_host_header), "Host: %s:%d\r\n", route->upstream_host, route->upstream_port);
|
||||
}
|
||||
size_t new_host_len = strlen(new_host_header);
|
||||
size_t old_host_len = old_host_header_end - old_host_header_start;
|
||||
|
||||
len_to_send = data_len - old_host_len + new_host_len;
|
||||
modified_request = malloc(len_to_send + 1);
|
||||
if (modified_request) {
|
||||
char* p = modified_request;
|
||||
size_t prefix_len = old_host_header_start - data;
|
||||
memcpy(p, data, prefix_len);
|
||||
p += prefix_len;
|
||||
memcpy(p, new_host_header, new_host_len);
|
||||
p += new_host_len;
|
||||
size_t suffix_len = data_len - (old_host_header_end - data);
|
||||
memcpy(p, old_host_header_end, suffix_len);
|
||||
data_to_send = modified_request;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (buffer_ensure_capacity(&up->write_buf, len_to_send) == 0) {
|
||||
memcpy(up->write_buf.data, data_to_send, len_to_send);
|
||||
up->write_buf.tail = len_to_send;
|
||||
|
||||
32
src/http.c
32
src/http.c
@ -251,3 +251,35 @@ int http_rewrite_content_length(char *headers, size_t *headers_len, size_t max_l
|
||||
*headers_len = new_total_len;
|
||||
return 1;
|
||||
}
|
||||
|
||||
int http_find_header_line_bounds(const char* data, size_t len, const char* name, const char** line_start, const char** line_end) {
|
||||
size_t name_len = strlen(name);
|
||||
const char* end = data + len;
|
||||
|
||||
for (const char* p = data; p < end; ) {
|
||||
const char* current_line_end = memchr(p, '\n', end - p);
|
||||
if (!current_line_end) break;
|
||||
|
||||
const char* line_actual_end = current_line_end;
|
||||
if (line_actual_end > p && *(line_actual_end - 1) == '\r') {
|
||||
line_actual_end--;
|
||||
}
|
||||
|
||||
// Check for header name followed by optional whitespace and a colon
|
||||
if ((size_t)(line_actual_end - p) > name_len && strncasecmp(p, name, name_len) == 0) {
|
||||
const char* colon_pos = p + name_len;
|
||||
while (colon_pos < line_actual_end && (*colon_pos == ' ' || *colon_pos == '\t')) {
|
||||
colon_pos++;
|
||||
}
|
||||
if (colon_pos < line_actual_end && *colon_pos == ':') {
|
||||
*line_start = p;
|
||||
*line_end = current_line_end + 1; // Include the \n (and \r if present)
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
p = current_line_end + 1;
|
||||
}
|
||||
*line_start = NULL;
|
||||
*line_end = NULL;
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -11,5 +11,6 @@ int http_detect_binary_content(const char *data, size_t len);
|
||||
long http_get_content_length(const char *headers, size_t headers_len);
|
||||
int http_find_headers_end(const char *data, size_t len, size_t *headers_end);
|
||||
int http_rewrite_content_length(char *headers, size_t *headers_len, size_t max_len, long new_length);
|
||||
int http_find_header_line_bounds(const char* data, size_t len, const char* name, const char** line_start, const char** line_end);
|
||||
|
||||
#endif
|
||||
|
||||
Loading…
Reference in New Issue
Block a user