]> git.sesse.net Git - cubemap/blob - client.cpp
Fix infinite CPU usage on waitpid().
[cubemap] / client.cpp
1 #include <stdio.h>
2 #include <arpa/inet.h>
3 #include <fcntl.h>
4 #include <netinet/in.h>
5 #include <stdint.h>
6 #include <sys/socket.h>
7 #include <unistd.h>
8
9 #include "client.h"
10 #include "log.h"
11 #include "state.pb.h"
12 #include "stream.h"
13
14 #ifndef SO_MAX_PACING_RATE
15 #define SO_MAX_PACING_RATE 47
16 #endif
17
18 using namespace std;
19
20 Client::Client(int sock)
21         : sock(sock)
22 {
23         request.reserve(1024);
24
25         // Find the remote address, and convert it to ASCII.
26         sockaddr_in6 addr;
27         socklen_t addr_len = sizeof(addr);
28
29         if (getpeername(sock, reinterpret_cast<sockaddr *>(&addr), &addr_len) == -1) {
30                 log_perror("getpeername");
31                 remote_addr = "";
32                 return;
33         }
34
35         char buf[INET6_ADDRSTRLEN];
36         if (IN6_IS_ADDR_V4MAPPED(&addr.sin6_addr)) {
37                 // IPv4 address, really.
38                 if (inet_ntop(AF_INET, &addr.sin6_addr.s6_addr32[3], buf, sizeof(buf)) == nullptr) {
39                         log_perror("inet_ntop");
40                         remote_addr = "";
41                 } else {
42                         remote_addr = buf;
43                 }
44         } else {
45                 if (inet_ntop(addr.sin6_family, &addr.sin6_addr, buf, sizeof(buf)) == nullptr) {
46                         log_perror("inet_ntop");
47                         remote_addr = "";
48                 } else {
49                         remote_addr = buf;
50                 }
51         }
52 }
53         
54 Client::Client(const ClientProto &serialized, const vector<shared_ptr<const string>> &short_responses, Stream *stream)
55         : sock(serialized.sock()),
56           remote_addr(serialized.remote_addr()),
57           referer(serialized.referer()),
58           user_agent(serialized.user_agent()),
59           x_playback_session_id(serialized.x_playback_session_id()),
60           state(State(serialized.state())),
61           request(serialized.request()),
62           url(serialized.url()),
63           stream(stream),
64           close_after_response(serialized.close_after_response()),
65           http_11(serialized.http_11()),
66           header_or_short_response_bytes_sent(serialized.header_or_short_response_bytes_sent()),
67           stream_pos(serialized.stream_pos()),
68           stream_pos_end(serialized.stream_pos_end()),
69           bytes_sent(serialized.bytes_sent()),
70           bytes_lost(serialized.bytes_lost()),
71           num_loss_events(serialized.num_loss_events())
72 {
73         // Set back the close-on-exec flag for the socket.
74         // (This can't leak into a child, since we haven't been started yet.)
75         fcntl(sock, F_SETFD, 1);
76
77         if (stream != nullptr) {
78                 if (setsockopt(sock, SOL_SOCKET, SO_MAX_PACING_RATE, &stream->pacing_rate, sizeof(stream->pacing_rate)) == -1) {
79                         if (stream->pacing_rate != ~0U) {
80                                 log_perror("setsockopt(SO_MAX_PACING_RATE)");
81                         }
82                 }
83         }
84
85         if (serialized.has_header_or_short_response_old()) {
86                 // Pre-1.4.0.
87                 header_or_short_response_holder = serialized.header_or_short_response_old();
88                 header_or_short_response = &header_or_short_response_holder;
89         } else if (serialized.has_header_or_short_response_index()) {
90                 assert(size_t(serialized.header_or_short_response_index()) < short_responses.size());
91                 header_or_short_response_ref = short_responses[serialized.header_or_short_response_index()];
92                 header_or_short_response = header_or_short_response_ref.get();
93         }
94         connect_time.tv_sec = serialized.connect_time_sec();
95         connect_time.tv_nsec = serialized.connect_time_nsec();
96
97         in_ktls_mode = false;
98         if (serialized.has_tls_context()) {
99                 tls_context = tls_import_context(
100                         reinterpret_cast<const unsigned char *>(serialized.tls_context().data()),
101                         serialized.tls_context().size());
102                 if (tls_context == nullptr) {
103                         log(WARNING, "tls_import_context() failed, TLS client might not survive across restart");
104                 } else {
105                         tls_data_to_send = tls_get_write_buffer(tls_context, &tls_data_left_to_send);
106
107                         assert(serialized.tls_output_bytes_already_consumed() <= tls_data_left_to_send);
108                         if (serialized.tls_output_bytes_already_consumed() >= tls_data_left_to_send) {
109                                 tls_buffer_clear(tls_context);
110                                 tls_data_to_send = nullptr;
111                         } else {
112                                 tls_data_to_send += serialized.tls_output_bytes_already_consumed();
113                                 tls_data_left_to_send -= serialized.tls_output_bytes_already_consumed();
114                         }
115                         in_ktls_mode = serialized.in_ktls_mode();
116                 }
117         } else {
118                 tls_context = nullptr;
119         }
120 }
121
122 ClientProto Client::serialize(unordered_map<const string *, size_t> *short_response_pool) const
123 {
124         // Unset the close-on-exec flag for the socket.
125         // (This can't leak into a child, since there's only one thread left.)
126         fcntl(sock, F_SETFD, 0);
127
128         ClientProto serialized;
129         serialized.set_sock(sock);
130         serialized.set_remote_addr(remote_addr);
131         serialized.set_referer(referer);
132         serialized.set_user_agent(user_agent);
133         serialized.set_x_playback_session_id(x_playback_session_id);
134         serialized.set_connect_time_sec(connect_time.tv_sec);
135         serialized.set_connect_time_nsec(connect_time.tv_nsec);
136         serialized.set_state(state);
137         serialized.set_request(request);
138         serialized.set_url(url);
139
140         if (header_or_short_response != nullptr) {
141                 // See if this string is already in the pool (deduplicated by the pointer); if not, insert it.
142                 auto iterator_and_inserted = short_response_pool->emplace(
143                         header_or_short_response, short_response_pool->size());
144                 serialized.set_header_or_short_response_index(iterator_and_inserted.first->second);
145         }
146
147         serialized.set_header_or_short_response_bytes_sent(header_or_short_response_bytes_sent);
148         serialized.set_stream_pos(stream_pos);
149         serialized.set_stream_pos_end(stream_pos_end);
150         serialized.set_bytes_sent(bytes_sent);
151         serialized.set_bytes_lost(bytes_lost);
152         serialized.set_num_loss_events(num_loss_events);
153         serialized.set_http_11(http_11);
154         serialized.set_close_after_response(close_after_response);
155
156         if (tls_context != nullptr) {
157                 bool small_version = false;
158                 int required_size = tls_export_context(tls_context, nullptr, 0, small_version);
159                 if (required_size <= 0) {
160                         // Can happen if we're in the middle of the key exchange, unfortunately.
161                         // We'll get an error fairly fast, and this client hasn't started playing
162                         // anything yet, so just log the error and continue.
163                         //
164                         // In theory, we could still rescue it if we had sent _zero_ bytes,
165                         // by doing an entirely new TLS context, but it's an edge case
166                         // that's not really worth it.
167                         log(WARNING, "tls_export_context() failed (returned %d), TLS client might not survive across restart",
168                                 required_size);
169                 } else {
170                         string *serialized_context = serialized.mutable_tls_context();
171                         serialized_context->resize(required_size);
172
173                         int ret = tls_export_context(tls_context,
174                                 reinterpret_cast<unsigned char *>(&(*serialized_context)[0]),
175                                 serialized_context->size(),
176                                 small_version);
177                         assert(ret == required_size);
178
179                         // tls_export_context() has exported the contents of the write buffer, but it doesn't
180                         // know how much of that we've consumed, so we need to figure that out by ourselves.
181                         // In a sense, it's unlikely that this will ever be relevant, though, since TLSe can't
182                         // currently serialize in-progress key exchanges.
183                         unsigned base_tls_data_left_to_send;
184                         const unsigned char *base_tls_data_to_send = tls_get_write_buffer(tls_context, &base_tls_data_left_to_send);
185                         if (base_tls_data_to_send == nullptr) {
186                                 assert(tls_data_to_send == nullptr);
187                         } else {
188                                 assert(tls_data_to_send + tls_data_left_to_send == base_tls_data_to_send + base_tls_data_left_to_send);
189                         }
190                         serialized.set_tls_output_bytes_already_consumed(tls_data_to_send - base_tls_data_to_send);
191                         serialized.set_in_ktls_mode(in_ktls_mode);
192                 }
193         }
194
195         return serialized;
196 }
197
namespace {

// Returns a copy of `str` where every byte that is not printable, or is a '"'
// or '\\', is replaced by a "\xNN" hex escape (two lowercase hex digits);
// all other bytes are copied unchanged. Used by Client::get_stats() for the
// Referer and User-Agent headers. Whether bytes >= 0x80 count as printable
// depends on the current locale (in the default "C" locale they do not).
std::string escape_string(const std::string &str) {
	std::string ret;
	ret.reserve(str.size());
	for (const char ch : str) {
		// <cctype> classifiers require an argument representable as unsigned
		// char (or EOF); a plain char with the high bit set is negative where
		// char is signed, which would be undefined behavior. Convert first.
		const unsigned char uc = static_cast<unsigned char>(ch);
		if (std::isprint(uc) && uc >= 32 && uc != '"' && uc != '\\') {
			ret.push_back(ch);
		} else {
			char buf[16];
			std::snprintf(buf, sizeof(buf), "\\x%02x", static_cast<unsigned>(uc));
			ret += buf;
		}
	}
	return ret;
}

} // namespace
215         
216 ClientStats Client::get_stats() const
217 {
218         ClientStats stats;
219         if (url.empty()) {
220                 stats.url = "-";
221         } else {
222                 stats.url = url;
223         }
224         stats.sock = sock;
225         stats.remote_addr = remote_addr;
226         stats.referer = escape_string(referer);
227         stats.user_agent = escape_string(user_agent);
228         stats.connect_time = connect_time;
229         stats.bytes_sent = bytes_sent;
230         stats.bytes_lost = bytes_lost;
231         stats.num_loss_events = num_loss_events;
232         stats.hls_zombie_key = get_hls_zombie_key();
233         return stats;
234 }