]> git.sesse.net Git - cubemap/blobdiff - server.cpp
Stop leaking TLS contexts.
[cubemap] / server.cpp
index 608ed6b70b504e0d0ff4152f77568b6a5fac1236..7f07483390a8b2846013e594668831b48d5b884b 100644 (file)
@@ -17,6 +17,9 @@
 #include <utility>
 #include <vector>
 
+#include "tlse.h"
+
+#include "acceptor.h"
 #include "accesslog.h"
 #include "log.h"
 #include "metacube2.h"
@@ -204,14 +207,15 @@ CubemapStateProto Server::serialize()
        return serialized;
 }
 
-void Server::add_client_deferred(int sock)
+void Server::add_client_deferred(int sock, Acceptor *acceptor)
 {
        MutexLock lock(&queued_clients_mutex);
-       queued_add_clients.push_back(sock);
+       queued_add_clients.push_back(std::make_pair(sock, acceptor));
 }
 
-void Server::add_client(int sock)
+void Server::add_client(int sock, Acceptor *acceptor)
 {
+       const bool is_tls = acceptor->is_tls();
        pair<map<int, Client>::iterator, bool> ret =
                clients.insert(make_pair(sock, Client(sock)));
        assert(ret.second == true);  // Should not already exist.
@@ -230,13 +234,32 @@ void Server::add_client(int sock)
 
        // Start listening on data from this socket.
        epoll_event ev;
-       ev.events = EPOLLIN | EPOLLET | EPOLLRDHUP;
+       if (is_tls) {
+               // Even in the initial state (READING_REQUEST), TLS needs to
+               // send data for the handshake, and thus might end up needing
+               // to know about EPOLLOUT.
+               ev.events = EPOLLIN | EPOLLOUT | EPOLLET | EPOLLRDHUP;
+       } else {
+               // EPOLLOUT will be added once we go out of READING_REQUEST.
+               ev.events = EPOLLIN | EPOLLET | EPOLLRDHUP;
+       }
        ev.data.u64 = reinterpret_cast<uint64_t>(client_ptr);
        if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, sock, &ev) == -1) {
                log_perror("epoll_ctl(EPOLL_CTL_ADD)");
                exit(1);
        }
 
+       if (is_tls) {
+               assert(tls_server_contexts.count(acceptor));
+               client_ptr->tls_context = tls_accept(tls_server_contexts[acceptor]);
+               if (client_ptr->tls_context == NULL) {
+                       log(ERROR, "tls_accept() failed");
+                       close_client(client_ptr);
+                       return;
+               }
+               tls_make_exportable(client_ptr->tls_context, 1);
+       }
+
        process_client(client_ptr);
 }
 
@@ -264,7 +287,12 @@ void Server::add_client_from_serialized(const ClientProto &client)
        // Start listening on data from this socket.
        epoll_event ev;
        if (client.state() == Client::READING_REQUEST) {
-               ev.events = EPOLLIN | EPOLLET | EPOLLRDHUP;
+               // See the corresponding comment in Server::add_client().
+               if (client.has_tls_context()) {
+                       ev.events = EPOLLIN | EPOLLOUT | EPOLLET | EPOLLRDHUP;
+               } else {
+                       ev.events = EPOLLIN | EPOLLET | EPOLLRDHUP;
+               }
        } else {
                // If we don't have more data for this client, we'll be putting it into
                // the sleeping array again soon.
@@ -373,6 +401,24 @@ void Server::add_gen204(const std::string &url, const std::string &allow_origin)
        ping_url_map[url] = allow_origin;
 }
 
+void Server::create_tls_context_for_acceptor(const Acceptor *acceptor)
+{
+       assert(acceptor->is_tls());
+
+       bool is_server = true;
+       TLSContext *server_context = tls_create_context(is_server, TLS_V12);
+
+       const string &cert = acceptor->get_certificate_chain();
+       int num_cert = tls_load_certificates(server_context, reinterpret_cast<const unsigned char *>(cert.data()), cert.size());
+       assert(num_cert > 0);  // Should have been checked by config earlier.
+
+       const string &key = acceptor->get_private_key();
+       int num_key = tls_load_private_key(server_context, reinterpret_cast<const unsigned char *>(key.data()), key.size());
+       assert(num_key > 0);  // Should have been checked by config earlier.
+
+       tls_server_contexts.insert(make_pair(acceptor, server_context));
+}
+
 void Server::add_data_deferred(int stream_index, const char *data, size_t bytes, uint16_t metacube_flags)
 {
        assert(stream_index >= 0 && stream_index < ssize_t(streams.size()));
@@ -384,28 +430,29 @@ void Server::process_client(Client *client)
 {
        switch (client->state) {
        case Client::READING_REQUEST: {
+               if (client->tls_context != NULL) {
+                       if (send_pending_tls_data(client)) {
+                               // send_pending_tls_data() hit postconditions #1 or #4.
+                               return;
+                       }
+               }
+
 read_request_again:
                // Try to read more of the request.
                char buf[1024];
                int ret;
-               do {
-                       ret = read(client->sock, buf, sizeof(buf));
-               } while (ret == -1 && errno == EINTR);
-
-               if (ret == -1 && errno == EAGAIN) {
-                       // No more data right now. Nothing to do.
-                       // This is postcondition #2.
-                       return;
-               }
-               if (ret == -1) {
-                       log_perror("read");
-                       close_client(client);
-                       return;
-               }
-               if (ret == 0) {
-                       // OK, the socket is closed.
-                       close_client(client);
-                       return;
+               if (client->tls_context == NULL) {
+                       ret = read_nontls_data(client, buf, sizeof(buf));
+                       if (ret == -1) {
+                               // read_nontls_data() hit postconditions #1 or #2.
+                               return;
+                       }
+               } else {
+                       ret = read_tls_data(client, buf, sizeof(buf));
+                       if (ret == -1) {
+                               // read_tls_data() hit postconditions #1, #2 or #4.
+                               return;
+                       }
                }
 
                RequestParseStatus status = wait_for_double_newline(&client->request, buf, ret);
@@ -429,6 +476,22 @@ read_request_again:
 
                assert(status == RP_FINISHED);
 
+               if (client->tls_context && !client->in_ktls_mode && tls_established(client->tls_context)) {
+                       // We're ready to enter kTLS mode, unless we still have some
+                       // handshake data to send (which then must be sent as non-kTLS).
+                       if (send_pending_tls_data(client)) {
+                               // send_pending_tls_data() hit postconditions #1 or #4.
+                               return;
+                       }
+                       ret = tls_make_ktls(client->tls_context, client->sock);
+                       if (ret < 0) {
+                               log_tls_error("tls_make_ktls", ret);
+                               close_client(client);
+                               return;
+                       }
+                       client->in_ktls_mode = true;
+               }
+
                int error_code = parse_request(client);
                if (error_code == 200) {
                        construct_header(client);
@@ -607,6 +670,133 @@ sending_data_again:
        }
 }
 
+bool Server::send_pending_tls_data(Client *client)
+{
+       // See if there's data from the TLS library to write.
+       if (client->tls_data_to_send == NULL) {
+               client->tls_data_to_send = tls_get_write_buffer(client->tls_context, &client->tls_data_left_to_send);
+               if (client->tls_data_to_send == NULL) {
+                       // Really no data to send.
+                       return false;
+               }
+       }
+
+send_data_again:
+       int ret;
+       do {
+               ret = write(client->sock, client->tls_data_to_send, client->tls_data_left_to_send);
+       } while (ret == -1 && errno == EINTR);
+       assert(ret < 0 || size_t(ret) <= client->tls_data_left_to_send);
+
+       if (ret == -1 && errno == EAGAIN) {
+               // We're out of socket space, so now we're at the “low edge” of epoll's
+               // edge triggering. epoll will tell us when there is more room, so for now,
+               // just return.
+               // This is postcondition #4.
+               return true;
+       }
+       if (ret == -1) {
+               // Error! Postcondition #1.
+               log_perror("write");
+               close_client(client);
+               return true;
+       }
+       if (ret > 0 && size_t(ret) == client->tls_data_left_to_send) {
+               // All data has been sent, so we don't need to go to sleep.
+               tls_buffer_clear(client->tls_context);
+               client->tls_data_to_send = NULL;
+               return false;
+       }
+
+       // More data to send, so try again.
+       client->tls_data_to_send += ret;
+       client->tls_data_left_to_send -= ret;
+       goto send_data_again;
+}
+
+int Server::read_nontls_data(Client *client, char *buf, size_t max_size)
+{
+       int ret;
+       do {
+               ret = read(client->sock, buf, max_size);
+       } while (ret == -1 && errno == EINTR);
+
+       if (ret == -1 && errno == EAGAIN) {
+               // No more data right now. Nothing to do.
+               // This is postcondition #2.
+               return -1;
+       }
+       if (ret == -1) {
+               log_perror("read");
+               close_client(client);
+               return -1;
+       }
+       if (ret == 0) {
+               // OK, the socket is closed.
+               close_client(client);
+               return -1;
+       }
+
+       return ret;
+}
+
+int Server::read_tls_data(Client *client, char *buf, size_t max_size)
+{
+read_again:
+       int ret;
+       do {
+               ret = read(client->sock, buf, max_size);
+       } while (ret == -1 && errno == EINTR);
+
+       if (ret == -1 && errno == EAGAIN) {
+               // No more data right now. Nothing to do.
+               // This is postcondition #2.
+               return -1;
+       }
+       if (ret == -1) {
+               log_perror("read");
+               close_client(client);
+               return -1;
+       }
+       if (ret == 0) {
+               // OK, the socket is closed.
+               close_client(client);
+               return -1;
+       }
+
+       // Give it to the TLS library.
+       int err = tls_consume_stream(client->tls_context, reinterpret_cast<const unsigned char *>(buf), ret, nullptr);
+       if (err < 0) {
+               log_tls_error("tls_consume_stream", err);
+               close_client(client);
+               return -1;
+       }
+       if (err == 0) {
+               // Not consumed any data. See if we can read more.
+               goto read_again;
+       }
+
+       // Read any decrypted data available for us. (We can reuse buf, since it's free now.)
+       ret = tls_read(client->tls_context, reinterpret_cast<unsigned char *>(buf), max_size);
+       if (ret == 0) {
+               // No decrypted data for us yet, but there might be some more handshaking
+               // to send. Do that if needed, then look for more data.
+               if (send_pending_tls_data(client)) {
+                       // send_pending_tls_data() hit postconditions #1 or #4.
+                       return -1;
+               }
+               goto read_again;
+       }
+       if (ret < 0) {
+               log_tls_error("tls_read", ret);
+               close_client(client);
+               return -1;
+       }
+
+       assert(ret > 0);
+       return ret;
+}
+
 // See if there's some data we've lost. Ideally, we should drop to a block boundary,
 // but resync will be the mux's problem.
 void Server::skip_lost_data(Client *client)
@@ -796,6 +986,10 @@ void Server::close_client(Client *client)
                delete_from(&client->stream->to_process, client);
        }
 
+       if (client->tls_context) {
+               tls_destroy_context(client->tls_context);
+       }
+
        // Log to access_log.
        access_log->write(client->get_stats());
 
@@ -811,7 +1005,7 @@ void Server::process_queued_data()
                MutexLock lock(&queued_clients_mutex);
 
                for (size_t i = 0; i < queued_add_clients.size(); ++i) {
-                       add_client(queued_add_clients[i]);
+                       add_client(queued_add_clients[i].first, queued_add_clients[i].second);
                }
                queued_add_clients.clear();
        }