]> git.sesse.net Git - cubemap/blobdiff - client.cpp
Update metacube2.h with the latest version (sync with Nageru).
[cubemap] / client.cpp
index 786e17825817e9bfa9360042a099c411646254c7..7c58d515485b03d27e9e26650dacb4c36116781f 100644 (file)
@@ -19,11 +19,15 @@ Client::Client(int sock)
        : sock(sock),
          state(Client::READING_REQUEST),
          stream(NULL),
-         header_or_error_bytes_sent(0),
+         header_or_short_response_bytes_sent(0),
          stream_pos(0),
          bytes_sent(0),
          bytes_lost(0),
-         num_loss_events(0)
+         num_loss_events(0),
+         tls_context(NULL),
+         tls_data_to_send(NULL),
+         tls_data_left_to_send(0),
+         in_ktls_mode(false)
 {
        request.reserve(1024);
 
@@ -70,8 +74,8 @@ Client::Client(const ClientProto &serialized, Stream *stream)
          request(serialized.request()),
          url(serialized.url()),
          stream(stream),
-         header_or_error(serialized.header_or_error()),
-         header_or_error_bytes_sent(serialized.header_or_error_bytes_sent()),
+         header_or_short_response(serialized.header_or_short_response()),
+         header_or_short_response_bytes_sent(serialized.header_or_short_response_bytes_sent()),
          stream_pos(serialized.stream_pos()),
          bytes_sent(serialized.bytes_sent()),
          bytes_lost(serialized.bytes_lost()),
@@ -86,6 +90,30 @@ Client::Client(const ClientProto &serialized, Stream *stream)
        }
        connect_time.tv_sec = serialized.connect_time_sec();
        connect_time.tv_nsec = serialized.connect_time_nsec();
+
+       in_ktls_mode = false;
+       if (serialized.has_tls_context()) {
+               tls_context = tls_import_context(
+                       reinterpret_cast<const unsigned char *>(serialized.tls_context().data()),
+                       serialized.tls_context().size());
+               if (tls_context == NULL) {
+                       log(WARNING, "tls_import_context() failed, TLS client might not survive across restart");
+               } else {
+                       tls_data_to_send = tls_get_write_buffer(tls_context, &tls_data_left_to_send);
+
+                       assert(serialized.tls_output_bytes_already_consumed() <= tls_data_left_to_send);
+                       if (serialized.tls_output_bytes_already_consumed() >= tls_data_left_to_send) {
+                               tls_buffer_clear(tls_context);
+                               tls_data_to_send = NULL;
+                       } else {
+                               tls_data_to_send += serialized.tls_output_bytes_already_consumed();
+                               tls_data_left_to_send -= serialized.tls_output_bytes_already_consumed();
+                       }
+                       in_ktls_mode = serialized.in_ktls_mode();
+               }
+       } else {
+               tls_context = NULL;
+       }
 }
 
 ClientProto Client::serialize() const
@@ -100,12 +128,52 @@ ClientProto Client::serialize() const
        serialized.set_state(state);
        serialized.set_request(request);
        serialized.set_url(url);
-       serialized.set_header_or_error(header_or_error);
-       serialized.set_header_or_error_bytes_sent(serialized.header_or_error_bytes_sent());
+       serialized.set_header_or_short_response(header_or_short_response);
+       serialized.set_header_or_short_response_bytes_sent(serialized.header_or_short_response_bytes_sent());
        serialized.set_stream_pos(stream_pos);
        serialized.set_bytes_sent(bytes_sent);
        serialized.set_bytes_lost(bytes_lost);
        serialized.set_num_loss_events(num_loss_events);
+
+       if (tls_context != NULL) {
+               bool small_version = false;
+               int required_size = tls_export_context(tls_context, NULL, 0, small_version);
+               if (required_size <= 0) {
+                       // Can happen if we're in the middle of the key exchange, unfortunately.
+                       // We'll get an error fairly fast, and this client hasn't started playing
+                       // anything yet, so just log the error and continue.
+                       //
+                       // In theory, we could still rescue it if we had sent _zero_ bytes,
+                       // by doing an entirely new TLS context, but it's an edge case
+                       // that's not really worth it.
+                       log(WARNING, "tls_export_context() failed (returned %d), TLS client might not survive across restart",
+                               required_size);
+               } else {
+                       string *serialized_context = serialized.mutable_tls_context();
+                       serialized_context->resize(required_size);
+
+                       int ret = tls_export_context(tls_context,
+                               reinterpret_cast<unsigned char *>(&(*serialized_context)[0]),
+                               serialized_context->size(),
+                               small_version);
+                       assert(ret == required_size);
+
+                       // tls_export_context() has exported the contents of the write buffer, but it doesn't
+                       // know how much of that we've consumed, so we need to figure that out by ourselves.
+                       // In a sense, it's unlikely that this will ever be relevant, though, since TLSe can't
+                       // currently serialize in-progress key exchanges.
+                       unsigned base_tls_data_left_to_send;
+                       const unsigned char *base_tls_data_to_send = tls_get_write_buffer(tls_context, &base_tls_data_left_to_send);
+                       if (base_tls_data_to_send == NULL) {
+                               assert(tls_data_to_send == NULL);
+                       } else {
+                               assert(tls_data_to_send + tls_data_left_to_send == base_tls_data_to_send + base_tls_data_left_to_send);
+                       }
+                       serialized.set_tls_output_bytes_already_consumed(tls_data_to_send - base_tls_data_to_send);
+                       serialized.set_in_ktls_mode(in_ktls_mode);
+               }
+       }
+
        return serialized;
 }