]> git.sesse.net Git - cubemap/blobdiff - server.cpp
Support HTTP/1.1 persistent connections (not that useful yet).
[cubemap] / server.cpp
index 5026deb7ecb6fd911c16888e070fa33e31482a4b..1deb087b056dc2ff1bf3cf426d71fff8c090d29b 100644 (file)
@@ -23,7 +23,6 @@
 #include "accesslog.h"
 #include "log.h"
 #include "metacube2.h"
-#include "mutexlock.h"
 #include "parse.h"
 #include "server.h"
 #include "state.pb.h"
@@ -57,9 +56,6 @@ inline bool is_earlier(timespec a, timespec b)
 
 Server::Server()
 {
-       pthread_mutex_init(&mutex, NULL);
-       pthread_mutex_init(&queued_clients_mutex, NULL);
-
        epoll_fd = epoll_create(1024);  // Size argument is ignored.
        if (epoll_fd == -1) {
                log_perror("epoll_fd");
@@ -69,10 +65,6 @@ Server::Server()
 
 Server::~Server()
 {
-       for (Stream *stream : streams) {
-               delete stream;
-       }
-
        safe_close(epoll_fd);
 }
 
@@ -80,7 +72,7 @@ vector<ClientStats> Server::get_client_stats() const
 {
        vector<ClientStats> ret;
 
-       MutexLock lock(&mutex);
+       lock_guard<mutex> lock(mu);
        for (const auto &fd_and_client : clients) {
                ret.push_back(fd_and_client.second.get_stats());
        }
@@ -104,7 +96,7 @@ void Server::do_work()
                        exit(1);
                }
 
-               MutexLock lock(&mutex);  // We release the mutex between iterations.
+               lock_guard<mutex> lock(mu);  // We release the mutex between iterations.
        
                process_queued_data();
 
@@ -122,7 +114,7 @@ void Server::do_work()
 
                // Process each client where its stream has new data,
                // even if there was no socket activity.
-               for (Stream *stream : streams) {
+               for (unique_ptr<Stream> &stream : streams) {
                        vector<Client *> to_process;
                        swap(stream->to_process, to_process);
                        for (Client *client : to_process) {
@@ -195,7 +187,7 @@ CubemapStateProto Server::serialize()
        for (const auto &fd_and_client : clients) {
                serialized.add_clients()->MergeFrom(fd_and_client.second.serialize());
        }
-       for (Stream *stream : streams) {
+       for (unique_ptr<Stream> &stream : streams) {
                serialized.add_streams()->MergeFrom(stream->serialize());
        }
        return serialized;
@@ -203,7 +195,7 @@ CubemapStateProto Server::serialize()
 
 void Server::add_client_deferred(int sock, Acceptor *acceptor)
 {
-       MutexLock lock(&queued_clients_mutex);
+       lock_guard<mutex> lock(queued_clients_mutex);
        queued_add_clients.push_back(std::make_pair(sock, acceptor));
 }
 
@@ -245,7 +237,7 @@ void Server::add_client(int sock, Acceptor *acceptor)
        if (is_tls) {
                assert(tls_server_contexts.count(acceptor));
                client_ptr->tls_context = tls_accept(tls_server_contexts[acceptor]);
-               if (client_ptr->tls_context == NULL) {
+               if (client_ptr->tls_context == nullptr) {
                        log(ERROR, "tls_accept() failed");
                        close_client(client_ptr);
                        return;
@@ -258,14 +250,14 @@ void Server::add_client(int sock, Acceptor *acceptor)
 
 void Server::add_client_from_serialized(const ClientProto &client)
 {
-       MutexLock lock(&mutex);
+       lock_guard<mutex> lock(mu);
        Stream *stream;
        int stream_index = lookup_stream_by_url(client.url());
        if (stream_index == -1) {
                assert(client.state() != Client::SENDING_DATA);
-               stream = NULL;
+               stream = nullptr;
        } else {
-               stream = streams[stream_index];
+               stream = streams[stream_index].get();
        }
        auto inserted = clients.insert(make_pair(client.sock(), Client(client, stream)));
        assert(inserted.second == true);  // Should not already exist.
@@ -317,51 +309,51 @@ int Server::lookup_stream_by_url(const string &url) const
 
 int Server::add_stream(const string &url, size_t backlog_size, size_t prebuffering_bytes, Stream::Encoding encoding, Stream::Encoding src_encoding)
 {
-       MutexLock lock(&mutex);
+       lock_guard<mutex> lock(mu);
        stream_url_map.insert(make_pair(url, streams.size()));
-       streams.push_back(new Stream(url, backlog_size, prebuffering_bytes, encoding, src_encoding));
+       streams.emplace_back(new Stream(url, backlog_size, prebuffering_bytes, encoding, src_encoding));
        return streams.size() - 1;
 }
 
 int Server::add_stream_from_serialized(const StreamProto &stream, int data_fd)
 {
-       MutexLock lock(&mutex);
+       lock_guard<mutex> lock(mu);
        stream_url_map.insert(make_pair(stream.url(), streams.size()));
-       streams.push_back(new Stream(stream, data_fd));
+       streams.emplace_back(new Stream(stream, data_fd));
        return streams.size() - 1;
 }
        
 void Server::set_backlog_size(int stream_index, size_t new_size)
 {
-       MutexLock lock(&mutex);
+       lock_guard<mutex> lock(mu);
        assert(stream_index >= 0 && stream_index < ssize_t(streams.size()));
        streams[stream_index]->set_backlog_size(new_size);
 }
 
 void Server::set_prebuffering_bytes(int stream_index, size_t new_amount)
 {
-       MutexLock lock(&mutex);
+       lock_guard<mutex> lock(mu);
        assert(stream_index >= 0 && stream_index < ssize_t(streams.size()));
        streams[stream_index]->prebuffering_bytes = new_amount;
 }
        
 void Server::set_encoding(int stream_index, Stream::Encoding encoding)
 {
-       MutexLock lock(&mutex);
+       lock_guard<mutex> lock(mu);
        assert(stream_index >= 0 && stream_index < ssize_t(streams.size()));
        streams[stream_index]->encoding = encoding;
 }
 
 void Server::set_src_encoding(int stream_index, Stream::Encoding encoding)
 {
-       MutexLock lock(&mutex);
+       lock_guard<mutex> lock(mu);
        assert(stream_index >= 0 && stream_index < ssize_t(streams.size()));
        streams[stream_index]->src_encoding = encoding;
 }
        
 void Server::set_header(int stream_index, const string &http_header, const string &stream_header)
 {
-       MutexLock lock(&mutex);
+       lock_guard<mutex> lock(mu);
        assert(stream_index >= 0 && stream_index < ssize_t(streams.size()));
        streams[stream_index]->http_header = http_header;
 
@@ -380,7 +372,7 @@ void Server::set_header(int stream_index, const string &http_header, const strin
        
 void Server::set_pacing_rate(int stream_index, uint32_t pacing_rate)
 {
-       MutexLock lock(&mutex);
+       lock_guard<mutex> lock(mu);
        assert(clients.empty());
        assert(stream_index >= 0 && stream_index < ssize_t(streams.size()));
        streams[stream_index]->pacing_rate = pacing_rate;
@@ -388,7 +380,7 @@ void Server::set_pacing_rate(int stream_index, uint32_t pacing_rate)
 
 void Server::add_gen204(const std::string &url, const std::string &allow_origin)
 {
-       MutexLock lock(&mutex);
+       lock_guard<mutex> lock(mu);
        assert(clients.empty());
        ping_url_map[url] = allow_origin;
 }
@@ -422,7 +414,7 @@ void Server::process_client(Client *client)
 {
        switch (client->state) {
        case Client::READING_REQUEST: {
-               if (client->tls_context != NULL) {
+               if (client->tls_context != nullptr) {
                        if (send_pending_tls_data(client)) {
                                // send_pending_tls_data() hit postconditions #1 or #4.
                                return;
@@ -433,7 +425,7 @@ read_request_again:
                // Try to read more of the request.
                char buf[1024];
                int ret;
-               if (client->tls_context == NULL) {
+               if (client->tls_context == nullptr) {
                        ret = read_nontls_data(client, buf, sizeof(buf));
                        if (ret == -1) {
                                // read_nontls_data() hit postconditions #1 or #2.
@@ -534,9 +526,14 @@ sending_header_or_short_response_again:
                client->header_or_short_response.clear();
 
                if (client->state == Client::SENDING_SHORT_RESPONSE) {
-                       // We're done sending the error, so now close.  
-                       // This is postcondition #1.
-                       close_client(client);
+                       if (more_requests(client)) {
+                               // We're done sending the error, but should keep on reading new requests.
+                               goto read_request_again;
+                       } else {
+                               // We're done sending the error, so now close.
+                               // This is postcondition #1.
+                               close_client(client);
+                       }
                        return;
                }
 
@@ -665,9 +662,9 @@ sending_data_again:
 bool Server::send_pending_tls_data(Client *client)
 {
        // See if there's data from the TLS library to write.
-       if (client->tls_data_to_send == NULL) {
+       if (client->tls_data_to_send == nullptr) {
                client->tls_data_to_send = tls_get_write_buffer(client->tls_context, &client->tls_data_left_to_send);
-               if (client->tls_data_to_send == NULL) {
+               if (client->tls_data_to_send == nullptr) {
                        // Really no data to send.
                        return false;
                }
@@ -696,7 +693,7 @@ send_data_again:
        if (ret > 0 && size_t(ret) == client->tls_data_left_to_send) {
                // All data has been sent, so we don't need to go to sleep.
                tls_buffer_clear(client->tls_context);
-               client->tls_data_to_send = NULL;
+               client->tls_data_to_send = nullptr;
                return false;
        }
 
@@ -794,7 +791,7 @@ read_again:
 void Server::skip_lost_data(Client *client)
 {
        Stream *stream = client->stream;
-       if (stream == NULL) {
+       if (stream == nullptr) {
                return;
        }
        size_t bytes_to_send = stream->bytes_received - client->stream_pos;
@@ -809,6 +806,7 @@ void Server::skip_lost_data(Client *client)
 int Server::parse_request(Client *client)
 {
        vector<string> lines = split_lines(client->request);
+       client->request.clear();
        if (lines.empty()) {
                return 400;  // Bad request (empty).
        }
@@ -826,7 +824,7 @@ int Server::parse_request(Client *client)
        }
 
        vector<string> request_tokens = split_tokens(lines[0]);
-       if (request_tokens.size() < 2) {
+       if (request_tokens.size() < 3) {
                return 400;  // Bad request (empty).
        }
        if (request_tokens[0] != "GET") {
@@ -842,6 +840,24 @@ int Server::parse_request(Client *client)
                client->stream_pos = -1;
        }
 
+       // Figure out if we're supposed to close the socket after we've delivered the response.
+       string protocol = request_tokens[2];
+       if (protocol.find("HTTP/") != 0) {
+               return 400;  // Bad request.
+       }
+       client->close_after_response = false;
+       client->http_11 = true;
+       if (protocol == "HTTP/1.0") {
+               // No persistent connections.
+               client->close_after_response = true;
+               client->http_11 = false;
+       } else {
+               multimap<string, string>::const_iterator connection_it = headers.find("Connection");
+               if (connection_it != headers.end() && connection_it->second == "close") {
+                       client->close_after_response = true;
+               }
+       }
+
        map<string, int>::const_iterator stream_url_map_it = stream_url_map.find(url);
        if (stream_url_map_it == stream_url_map.end()) {
                map<string, string>::const_iterator ping_url_map_it = ping_url_map.find(url);
@@ -852,11 +868,17 @@ int Server::parse_request(Client *client)
                }
        }
 
-       Stream *stream = streams[stream_url_map_it->second];
+       Stream *stream = streams[stream_url_map_it->second].get();
        if (stream->http_header.empty()) {
                return 503;  // Service unavailable.
        }
 
+       // Streams currently never end, so we don't have a content-length,
+       // and can just as well tell the client it's Connection: close
+       // (otherwise, we'd have to implement chunking TE for no good reason).
+       // When we start to support fragments, this will change.
+       client->close_after_response = true;
+
        client->stream = stream;
        if (setsockopt(client->sock, SOL_SOCKET, SO_MAX_PACING_RATE, &client->stream->pacing_rate, sizeof(client->stream->pacing_rate)) == -1) {
                if (client->stream->pacing_rate != ~0U) {
@@ -871,14 +893,22 @@ int Server::parse_request(Client *client)
 void Server::construct_header(Client *client)
 {
        Stream *stream = client->stream;
+       client->header_or_short_response = stream->http_header;
+       if (client->http_11) {
+               assert(client->header_or_short_response.find("HTTP/1.0") == 0);
+               client->header_or_short_response[7] = '1';  // Change to HTTP/1.1.
+               if (client->close_after_response) {
+                       client->header_or_short_response.append("Connection: close\r\n");
+               }
+       } else {
+               assert(client->close_after_response);
+       }
        if (stream->encoding == Stream::STREAM_ENCODING_RAW) {
-               client->header_or_short_response = stream->http_header +
-                       "\r\n" +
-                       stream->stream_header;
+               client->header_or_short_response.append("\r\n");
        } else if (stream->encoding == Stream::STREAM_ENCODING_METACUBE) {
-               client->header_or_short_response = stream->http_header +
-                       "Content-encoding: metacube\r\n" +
-                       "\r\n";
+               client->header_or_short_response.append(
+                       "Content-encoding: metacube\r\n"
+                       "\r\n");
                if (!stream->stream_header.empty()) {
                        metacube2_block_header hdr;
                        memcpy(hdr.sync, METACUBE2_SYNC, sizeof(hdr.sync));
@@ -888,10 +918,10 @@ void Server::construct_header(Client *client)
                        client->header_or_short_response.append(
                                string(reinterpret_cast<char *>(&hdr), sizeof(hdr)));
                }
-               client->header_or_short_response.append(stream->stream_header);
        } else {
                assert(false);
        }
+       client->header_or_short_response.append(stream->stream_header);
 
        // Switch states.
        client->state = Client::SENDING_HEADER;
@@ -901,8 +931,15 @@ void Server::construct_header(Client *client)
 void Server::construct_error(Client *client, int error_code)
 {
        char error[256];
-       snprintf(error, 256, "HTTP/1.0 %d Error\r\nContent-type: text/plain\r\n\r\nSomething went wrong. Sorry.\r\n",
-               error_code);
+       if (client->http_11 && client->close_after_response) {
+               snprintf(error, sizeof(error),
+                       "HTTP/1.1 %d Error\r\nContent-type: text/plain\r\nConnection: close\r\n\r\nSomething went wrong. Sorry.\r\n",
+                       error_code);
+       } else {
+               snprintf(error, sizeof(error),
+                       "HTTP/1.%d %d Error\r\nContent-type: text/plain\r\nContent-length: 30\r\n\r\nSomething went wrong. Sorry.\r\n",
+                       client->http_11, error_code);
+       }
        client->header_or_short_response = error;
 
        // Switch states.
@@ -915,19 +952,20 @@ void Server::construct_204(Client *client)
        map<string, string>::const_iterator ping_url_map_it = ping_url_map.find(client->url);
        assert(ping_url_map_it != ping_url_map.end());
 
-       if (ping_url_map_it->second.empty()) {
-               client->header_or_short_response =
-                       "HTTP/1.0 204 No Content\r\n"
-                       "\r\n";
+       if (client->http_11) {
+               client->header_or_short_response = "HTTP/1.1 204 No Content\r\n";
+               if (client->close_after_response) {
+                       client->header_or_short_response.append("Connection: close\r\n");
+               }
        } else {
-               char response[256];
-               snprintf(response, 256,
-                        "HTTP/1.0 204 No Content\r\n"
-                        "Access-Control-Allow-Origin: %s\r\n"
-                        "\r\n",
-                        ping_url_map_it->second.c_str());
-               client->header_or_short_response = response;
+               client->header_or_short_response = "HTTP/1.0 204 No Content\r\n";
+               assert(client->close_after_response);
+       }
+       if (!ping_url_map_it->second.empty()) {
+               client->header_or_short_response.append("Access-Control-Allow-Origin: ");
+               client->header_or_short_response.append(ping_url_map_it->second);
        }
+       client->header_or_short_response.append("\r\n");
 
        // Switch states.
        client->state = Client::SENDING_SHORT_RESPONSE;
@@ -943,13 +981,13 @@ void delete_from(vector<T> *v, T elem)
        
 void Server::close_client(Client *client)
 {
-       if (epoll_ctl(epoll_fd, EPOLL_CTL_DEL, client->sock, NULL) == -1) {
+       if (epoll_ctl(epoll_fd, EPOLL_CTL_DEL, client->sock, nullptr) == -1) {
                log_perror("epoll_ctl(EPOLL_CTL_DEL)");
                exit(1);
        }
 
        // This client could be sleeping, so we'll need to fix that. (Argh, O(n).)
-       if (client->stream != NULL) {
+       if (client->stream != nullptr) {
                delete_from(&client->stream->sleeping_clients, client);
                delete_from(&client->stream->to_process, client);
        }
@@ -979,10 +1017,28 @@ void Server::change_epoll_events(Client *client, uint32_t events)
        }
 }
 
+bool Server::more_requests(Client *client)
+{
+       if (client->close_after_response) {
+               return false;
+       }
+
+       // Switch states and reset the parsers. We don't reset statistics.
+       client->state = Client::READING_REQUEST;
+       client->url.clear();
+       client->stream = NULL;
+       client->header_or_short_response.clear();
+       client->header_or_short_response_bytes_sent = 0;
+
+       change_epoll_events(client, EPOLLIN | EPOLLET | EPOLLRDHUP);  // No TLS handshake, so no EPOLLOUT needed.
+
+       return true;
+}
+
 void Server::process_queued_data()
 {
        {
-               MutexLock lock(&queued_clients_mutex);
+               lock_guard<mutex> lock(queued_clients_mutex);
 
                for (const pair<int, Acceptor *> &id_and_acceptor : queued_add_clients) {
                        add_client(id_and_acceptor.first, id_and_acceptor.second);
@@ -990,7 +1046,7 @@ void Server::process_queued_data()
                queued_add_clients.clear();
        }
 
-       for (Stream *stream : streams) {
+       for (unique_ptr<Stream> &stream : streams) {
                stream->process_queued_data();
        }
 }