]> git.sesse.net Git - cubemap/blobdiff - client.cpp
Fix serialization of Client::header_or_short_response_bytes_sent.
[cubemap] / client.cpp
index eea41a5e0da43d40b0c468b206158f69198ac1b6..f02050b5588901fd31142558134899f59512464b 100644 (file)
@@ -1,3 +1,4 @@
+#include <stdio.h>
 #include <arpa/inet.h>
 #include <netinet/in.h>
 #include <stdint.h>
@@ -5,7 +6,6 @@
 
 #include "client.h"
 #include "log.h"
-#include "markpool.h"
 #include "state.pb.h"
 #include "stream.h"
 
 using namespace std;
 
 Client::Client(int sock)
-       : sock(sock),
-         fwmark(0),
-         connect_time(time(NULL)),
-         state(Client::READING_REQUEST),
-         stream(NULL),
-         header_or_error_bytes_sent(0),
-         stream_pos(0),
-         bytes_sent(0),
-         bytes_lost(0),
-         num_loss_events(0)
+       : sock(sock)
 {
        request.reserve(1024);
 
@@ -42,14 +33,14 @@ Client::Client(int sock)
        char buf[INET6_ADDRSTRLEN];
        if (IN6_IS_ADDR_V4MAPPED(&addr.sin6_addr)) {
                // IPv4 address, really.
-               if (inet_ntop(AF_INET, &addr.sin6_addr.s6_addr32[3], buf, sizeof(buf)) == NULL) {
+               if (inet_ntop(AF_INET, &addr.sin6_addr.s6_addr32[3], buf, sizeof(buf)) == nullptr) {
                        log_perror("inet_ntop");
                        remote_addr = "";
                } else {
                        remote_addr = buf;
                }
        } else {
-               if (inet_ntop(addr.sin6_family, &addr.sin6_addr, buf, sizeof(buf)) == NULL) {
+               if (inet_ntop(addr.sin6_family, &addr.sin6_addr, buf, sizeof(buf)) == nullptr) {
                        log_perror("inet_ntop");
                        remote_addr = "";
                } else {
@@ -58,56 +49,159 @@ Client::Client(int sock)
        }
 }
        
-Client::Client(const ClientProto &serialized, Stream *stream)
+Client::Client(const ClientProto &serialized, const vector<shared_ptr<const string>> &short_responses, Stream *stream)
        : sock(serialized.sock()),
          remote_addr(serialized.remote_addr()),
-         connect_time(serialized.connect_time()),
+         referer(serialized.referer()),
+         user_agent(serialized.user_agent()),
+         x_playback_session_id(serialized.x_playback_session_id()),
          state(State(serialized.state())),
          request(serialized.request()),
          url(serialized.url()),
          stream(stream),
-         header_or_error(serialized.header_or_error()),
-         header_or_error_bytes_sent(serialized.header_or_error_bytes_sent()),
+         close_after_response(serialized.close_after_response()),
+         http_11(serialized.http_11()),
+         header_or_short_response_bytes_sent(serialized.header_or_short_response_bytes_sent()),
          stream_pos(serialized.stream_pos()),
+         stream_pos_end(serialized.stream_pos_end()),
          bytes_sent(serialized.bytes_sent()),
          bytes_lost(serialized.bytes_lost()),
          num_loss_events(serialized.num_loss_events())
 {
-       if (stream != NULL && stream->mark_pool != NULL) {
-               fwmark = stream->mark_pool->get_mark();
-       } else {
-               fwmark = 0;  // No mark.
-       }
-       if (setsockopt(sock, SOL_SOCKET, SO_MARK, &fwmark, sizeof(fwmark)) == -1) {
-               if (fwmark != 0) {
-                       log_perror("setsockopt(SO_MARK)");
+       if (stream != nullptr) {
+               if (setsockopt(sock, SOL_SOCKET, SO_MAX_PACING_RATE, &stream->pacing_rate, sizeof(stream->pacing_rate)) == -1) {
+                       if (stream->pacing_rate != ~0U) {
+                               log_perror("setsockopt(SO_MAX_PACING_RATE)");
+                       }
                }
-               fwmark = 0;
        }
-       if (setsockopt(sock, SOL_SOCKET, SO_MAX_PACING_RATE, &stream->pacing_rate, sizeof(stream->pacing_rate)) == -1) {
-               if (stream->pacing_rate != ~0U) {
-                       log_perror("setsockopt(SO_MAX_PACING_RATE)");
+
+       if (serialized.has_header_or_short_response_old()) {
+               // Pre-1.4.0.
+               header_or_short_response_holder = serialized.header_or_short_response_old();
+               header_or_short_response = &header_or_short_response_holder;
+       } else if (serialized.has_header_or_short_response_index()) {
+               assert(size_t(serialized.header_or_short_response_index()) < short_responses.size());
+               header_or_short_response_ref = short_responses[serialized.header_or_short_response_index()];
+               header_or_short_response = header_or_short_response_ref.get();
+       }
+       connect_time.tv_sec = serialized.connect_time_sec();
+       connect_time.tv_nsec = serialized.connect_time_nsec();
+
+       in_ktls_mode = false;
+       if (serialized.has_tls_context()) {
+               tls_context = tls_import_context(
+                       reinterpret_cast<const unsigned char *>(serialized.tls_context().data()),
+                       serialized.tls_context().size());
+               if (tls_context == nullptr) {
+                       log(WARNING, "tls_import_context() failed, TLS client might not survive across restart");
+               } else {
+                       tls_data_to_send = tls_get_write_buffer(tls_context, &tls_data_left_to_send);
+
+                       assert(serialized.tls_output_bytes_already_consumed() <= tls_data_left_to_send);
+                       if (serialized.tls_output_bytes_already_consumed() >= tls_data_left_to_send) {
+                               tls_buffer_clear(tls_context);
+                               tls_data_to_send = nullptr;
+                       } else {
+                               tls_data_to_send += serialized.tls_output_bytes_already_consumed();
+                               tls_data_left_to_send -= serialized.tls_output_bytes_already_consumed();
+                       }
+                       in_ktls_mode = serialized.in_ktls_mode();
                }
+       } else {
+               tls_context = nullptr;
        }
 }
 
-ClientProto Client::serialize() const
+ClientProto Client::serialize(unordered_map<const string *, size_t> *short_response_pool) const
 {
        ClientProto serialized;
        serialized.set_sock(sock);
        serialized.set_remote_addr(remote_addr);
-       serialized.set_connect_time(connect_time);
+       serialized.set_referer(referer);
+       serialized.set_user_agent(user_agent);
+       serialized.set_x_playback_session_id(x_playback_session_id);
+       serialized.set_connect_time_sec(connect_time.tv_sec);
+       serialized.set_connect_time_nsec(connect_time.tv_nsec);
        serialized.set_state(state);
        serialized.set_request(request);
        serialized.set_url(url);
-       serialized.set_header_or_error(header_or_error);
-       serialized.set_header_or_error_bytes_sent(serialized.header_or_error_bytes_sent());
+
+       if (header_or_short_response != nullptr) {
+               // See if this string is already in the pool (deduplicated by the pointer); if not, insert it.
+               auto iterator_and_inserted = short_response_pool->emplace(
+                       header_or_short_response, short_response_pool->size());
+               serialized.set_header_or_short_response_index(iterator_and_inserted.first->second);
+       }
+
+       serialized.set_header_or_short_response_bytes_sent(header_or_short_response_bytes_sent);
        serialized.set_stream_pos(stream_pos);
+       serialized.set_stream_pos_end(stream_pos_end);
        serialized.set_bytes_sent(bytes_sent);
        serialized.set_bytes_lost(bytes_lost);
        serialized.set_num_loss_events(num_loss_events);
+       serialized.set_http_11(http_11);
+       serialized.set_close_after_response(close_after_response);
+
+       if (tls_context != nullptr) {
+               bool small_version = false;
+               int required_size = tls_export_context(tls_context, nullptr, 0, small_version);
+               if (required_size <= 0) {
+                       // Can happen if we're in the middle of the key exchange, unfortunately.
+                       // We'll get an error fairly fast, and this client hasn't started playing
+                       // anything yet, so just log the error and continue.
+                       //
+                       // In theory, we could still rescue it if we had sent _zero_ bytes,
+                       // by doing an entirely new TLS context, but it's an edge case
+                       // that's not really worth it.
+                       log(WARNING, "tls_export_context() failed (returned %d), TLS client might not survive across restart",
+                               required_size);
+               } else {
+                       string *serialized_context = serialized.mutable_tls_context();
+                       serialized_context->resize(required_size);
+
+                       int ret = tls_export_context(tls_context,
+                               reinterpret_cast<unsigned char *>(&(*serialized_context)[0]),
+                               serialized_context->size(),
+                               small_version);
+                       assert(ret == required_size);
+
+                       // tls_export_context() has exported the contents of the write buffer, but it doesn't
+                       // know how much of that we've consumed, so we need to figure that out by ourselves.
+                       // In a sense, it's unlikely that this will ever be relevant, though, since TLSe can't
+                       // currently serialize in-progress key exchanges.
+                       unsigned base_tls_data_left_to_send;
+                       const unsigned char *base_tls_data_to_send = tls_get_write_buffer(tls_context, &base_tls_data_left_to_send);
+                       if (base_tls_data_to_send == nullptr) {
+                               assert(tls_data_to_send == nullptr);
+                       } else {
+                               assert(tls_data_to_send + tls_data_left_to_send == base_tls_data_to_send + base_tls_data_left_to_send);
+                       }
+                       serialized.set_tls_output_bytes_already_consumed(tls_data_to_send - base_tls_data_to_send);
+                       serialized.set_in_ktls_mode(in_ktls_mode);
+               }
+       }
+
        return serialized;
 }
+
+namespace {
+
+string escape_string(const string &str) {
+       string ret;
+       for (size_t i = 0; i < str.size(); ++i) {
+               char buf[16];
+               if (isprint(str[i]) && str[i] >= 32 && str[i] != '"' && str[i] != '\\') {
+                       ret.push_back(str[i]);
+               } else {
+                       snprintf(buf, sizeof(buf), "\\x%02x", (unsigned char)str[i]);
+                       ret += buf;
+               }
+       }
+       return ret;
+}
+
+} // namespace
        
 ClientStats Client::get_stats() const
 {
@@ -118,11 +212,13 @@ ClientStats Client::get_stats() const
                stats.url = url;
        }
        stats.sock = sock;
-       stats.fwmark = fwmark;
        stats.remote_addr = remote_addr;
+       stats.referer = escape_string(referer);
+       stats.user_agent = escape_string(user_agent);
        stats.connect_time = connect_time;
        stats.bytes_sent = bytes_sent;
        stats.bytes_lost = bytes_lost;
        stats.num_loss_events = num_loss_events;
+       stats.hls_zombie_key = get_hls_zombie_key();
        return stats;
 }