]> git.sesse.net Git - cubemap/blob - main.cpp
Do not crash on invalid src= URLs.
[cubemap] / main.cpp
1 #include <assert.h>
2 #include <errno.h>
3 #include <getopt.h>
4 #include <limits.h>
5 #include <signal.h>
6 #include <stddef.h>
7 #include <stdio.h>
8 #include <stdlib.h>
9 #include <string.h>
10 #include <sys/time.h>
11 #include <sys/wait.h>
12 #include <unistd.h>
13 #include <map>
14 #include <set>
15 #include <string>
16 #include <utility>
17 #include <vector>
18
19 #include "acceptor.h"
20 #include "accesslog.h"
21 #include "config.h"
22 #include "input.h"
23 #include "input_stats.h"
24 #include "log.h"
25 #include "markpool.h"
26 #include "serverpool.h"
27 #include "state.pb.h"
28 #include "stats.h"
29 #include "stream.h"
30 #include "util.h"
31 #include "version.h"
32
33 using namespace std;
34
// Process-wide state shared by the functions in this file.

// Access logger; main() creates a dummy logger when no access log file is configured.
AccessLogThread *access_log = NULL;
// The pool of servers; created in main() with config.num_servers.
ServerPool *servers = NULL;
// Mark pools from the configuration; filled in by create_streams() and
// deleted and cleared by main() during shutdown.
vector<MarkPool *> mark_pools;
// Set by hup() on SIGHUP or SIGINT; main() polls it to know when to shut down.
// NOTE(review): written from a signal handler; volatile sig_atomic_t would be
// the strictly portable type -- confirm no other file declares these as bool.
volatile bool hupped = false;
// Set by hup() on SIGINT only; makes main() exit instead of re-execing.
volatile bool stopped = false;
40
// An input together with a count of how many streams it has been attached to.
struct InputWithRefcount {
        Input *input;
        // Incremented by create_streams() for every stream that uses this input;
        // main() closes and deletes inputs whose count is still zero.
        int refcount;
};
45
46 void hup(int signum)
47 {
48         hupped = true;
49         if (signum == SIGINT) {
50                 stopped = true;
51         }
52 }
53
// Intentionally empty signal handler. main() installs it for SIGUSR1
// ("used in internal signalling"), so the signal is delivered without
// having any effect of its own.
void do_nothing(int /* signum */)
{
}
57
58 CubemapStateProto collect_state(const timeval &serialize_start,
59                                 const vector<Acceptor *> acceptors,
60                                 const multimap<string, InputWithRefcount> inputs,
61                                 ServerPool *servers)
62 {
63         CubemapStateProto state = servers->serialize();  // Fills streams() and clients().
64         state.set_serialize_start_sec(serialize_start.tv_sec);
65         state.set_serialize_start_usec(serialize_start.tv_usec);
66         
67         for (size_t i = 0; i < acceptors.size(); ++i) {
68                 state.add_acceptors()->MergeFrom(acceptors[i]->serialize());
69         }
70
71         for (multimap<string, InputWithRefcount>::const_iterator input_it = inputs.begin();
72              input_it != inputs.end();
73              ++input_it) {
74                 state.add_inputs()->MergeFrom(input_it->second.input->serialize());
75         }
76
77         return state;
78 }
79
80 // Find all port statements in the configuration file, and create acceptors for htem.
81 vector<Acceptor *> create_acceptors(
82         const Config &config,
83         map<int, Acceptor *> *deserialized_acceptors)
84 {
85         vector<Acceptor *> acceptors;
86         for (unsigned i = 0; i < config.acceptors.size(); ++i) {
87                 const AcceptorConfig &acceptor_config = config.acceptors[i];
88                 Acceptor *acceptor = NULL;
89                 map<int, Acceptor *>::iterator deserialized_acceptor_it =
90                         deserialized_acceptors->find(acceptor_config.port);
91                 if (deserialized_acceptor_it != deserialized_acceptors->end()) {
92                         acceptor = deserialized_acceptor_it->second;
93                         deserialized_acceptors->erase(deserialized_acceptor_it);
94                 } else {
95                         int server_sock = create_server_socket(acceptor_config.port, TCP_SOCKET);
96                         acceptor = new Acceptor(server_sock, acceptor_config.port);
97                 }
98                 acceptor->run();
99                 acceptors.push_back(acceptor);
100         }
101
102         // Close all acceptors that are no longer in the configuration file.
103         for (map<int, Acceptor *>::iterator acceptor_it = deserialized_acceptors->begin();
104              acceptor_it != deserialized_acceptors->end();
105              ++acceptor_it) {
106                 acceptor_it->second->close_socket();
107                 delete acceptor_it->second;
108         }
109
110         return acceptors;
111 }
112
113 void create_config_input(const string &src, multimap<string, InputWithRefcount> *inputs)
114 {
115         if (src.empty()) {
116                 return;
117         }
118         if (inputs->count(src) != 0) {
119                 return;
120         }
121
122         InputWithRefcount iwr;
123         iwr.input = create_input(src);
124         if (iwr.input == NULL) {
125                 log(ERROR, "did not understand URL '%s', clients will not get any data.",
126                         src.c_str());
127                 return;
128         }
129         iwr.refcount = 0;
130         inputs->insert(make_pair(src, iwr));
131 }
132
133 // Find all streams in the configuration file, and create inputs for them.
134 void create_config_inputs(const Config &config, multimap<string, InputWithRefcount> *inputs)
135 {
136         for (unsigned i = 0; i < config.streams.size(); ++i) {
137                 const StreamConfig &stream_config = config.streams[i];
138                 create_config_input(stream_config.src, inputs);
139         }
140         for (unsigned i = 0; i < config.udpstreams.size(); ++i) {
141                 const UDPStreamConfig &udpstream_config = config.udpstreams[i];
142                 create_config_input(udpstream_config.src, inputs);
143         }
144 }
145
146 void create_streams(const Config &config,
147                     const set<string> &deserialized_urls,
148                     multimap<string, InputWithRefcount> *inputs)
149 {
150         for (unsigned i = 0; i < config.mark_pools.size(); ++i) {
151                 const MarkPoolConfig &mp_config = config.mark_pools[i];
152                 mark_pools.push_back(new MarkPool(mp_config.from, mp_config.to));
153         }
154
155         // HTTP streams.
156         set<string> expecting_urls = deserialized_urls;
157         for (unsigned i = 0; i < config.streams.size(); ++i) {
158                 const StreamConfig &stream_config = config.streams[i];
159                 int stream_index;
160                 if (deserialized_urls.count(stream_config.url) == 0) {
161                         stream_index = servers->add_stream(stream_config.url,
162                                                            stream_config.backlog_size,
163                                                            Stream::Encoding(stream_config.encoding));
164                 } else {
165                         stream_index = servers->lookup_stream_by_url(stream_config.url);
166                         assert(stream_index != -1);
167                         servers->set_backlog_size(stream_index, stream_config.backlog_size);
168                         servers->set_encoding(stream_index,
169                                               Stream::Encoding(stream_config.encoding));
170                 }
171                 expecting_urls.erase(stream_config.url);
172
173                 if (stream_config.mark_pool != -1) {
174                         servers->set_mark_pool(stream_index, mark_pools[stream_config.mark_pool]);
175                 }
176
177                 string src = stream_config.src;
178                 if (!src.empty()) {
179                         multimap<string, InputWithRefcount>::iterator input_it = inputs->find(src);
180                         if (input_it != inputs->end()) {
181                                 input_it->second.input->add_destination(stream_index);
182                                 ++input_it->second.refcount;
183                         }
184                 }
185         }
186
187         // Warn about any HTTP servers we've lost.
188         // TODO: Make an option (delete=yes?) to actually shut down streams.
189         for (set<string>::const_iterator stream_it = expecting_urls.begin();
190              stream_it != expecting_urls.end();
191              ++stream_it) {
192                 string url = *stream_it;
193                 log(WARNING, "stream '%s' disappeared from the configuration file. "
194                              "It will not be deleted, but clients will not get any new inputs.",
195                              url.c_str());
196         }
197
198         // UDP streams.
199         for (unsigned i = 0; i < config.udpstreams.size(); ++i) {
200                 const UDPStreamConfig &udpstream_config = config.udpstreams[i];
201                 MarkPool *mark_pool = NULL;
202                 if (udpstream_config.mark_pool != -1) {
203                         mark_pool = mark_pools[udpstream_config.mark_pool];
204                 }
205                 int stream_index = servers->add_udpstream(udpstream_config.dst, mark_pool);
206
207                 string src = udpstream_config.src;
208                 if (!src.empty()) {
209                         multimap<string, InputWithRefcount>::iterator input_it = inputs->find(src);
210                         assert(input_it != inputs->end());
211                         input_it->second.input->add_destination(stream_index);
212                         ++input_it->second.refcount;
213                 }
214         }
215 }
216         
217 void open_logs(const vector<LogConfig> &log_destinations)
218 {
219         for (size_t i = 0; i < log_destinations.size(); ++i) {
220                 if (log_destinations[i].type == LogConfig::LOG_TYPE_FILE) {
221                         add_log_destination_file(log_destinations[i].filename);
222                 } else if (log_destinations[i].type == LogConfig::LOG_TYPE_CONSOLE) {
223                         add_log_destination_console();
224                 } else if (log_destinations[i].type == LogConfig::LOG_TYPE_SYSLOG) {
225                         add_log_destination_syslog();
226                 } else {
227                         assert(false);
228                 }
229         }
230         start_logging();
231 }
232         
233 bool dry_run_config(const std::string &argv0, const std::string &config_filename)
234 {
235         char *argv0_copy = strdup(argv0.c_str());
236         char *config_filename_copy = strdup(config_filename.c_str());
237
238         pid_t pid = fork();
239         switch (pid) {
240         case -1:
241                 log_perror("fork()");
242                 free(argv0_copy);
243                 free(config_filename_copy);
244                 return false;
245         case 0:
246                 // Child.
247                 execlp(argv0_copy, argv0_copy, "--test-config", config_filename_copy, NULL);
248                 log_perror(argv0_copy);
249                 _exit(1);
250         default:
251                 // Parent.
252                 break;
253         }
254                 
255         free(argv0_copy);
256         free(config_filename_copy);
257
258         int status;
259         pid_t err;
260         do {
261                 err = waitpid(pid, &status, 0);
262         } while (err == -1 && errno == EINTR);
263
264         if (err == -1) {
265                 log_perror("waitpid()");
266                 return false;
267         }       
268
269         return (WIFEXITED(status) && WEXITSTATUS(status) == 0);
270 }
271
272 int main(int argc, char **argv)
273 {
274         signal(SIGHUP, hup);
275         signal(SIGINT, hup);
276         signal(SIGUSR1, do_nothing);  // Used in internal signalling.
277         signal(SIGPIPE, SIG_IGN);
278         
279         // Parse options.
280         int state_fd = -1;
281         bool test_config = false;
282         for ( ;; ) {
283                 static const option long_options[] = {
284                         { "state", required_argument, 0, 's' },
285                         { "test-config", no_argument, 0, 't' },
286                         { 0, 0, 0, 0 }
287                 };
288                 int option_index = 0;
289                 int c = getopt_long(argc, argv, "s:t", long_options, &option_index);
290      
291                 if (c == -1) {
292                         break;
293                 }
294                 switch (c) {
295                 case 's':
296                         state_fd = atoi(optarg);
297                         break;
298                 case 't':
299                         test_config = true;
300                         break;
301                 default:
302                         fprintf(stderr, "Unknown option '%s'\n", argv[option_index]);
303                         exit(1);
304                 }
305         }
306
307         string config_filename = "cubemap.config";
308         if (optind < argc) {
309                 config_filename = argv[optind++];
310         }
311
312         // Canonicalize argv[0] and config_filename.
313         char argv0_canon[PATH_MAX];
314         char config_filename_canon[PATH_MAX];
315
316         if (realpath(argv[0], argv0_canon) == NULL) {
317                 log_perror(argv[0]);
318                 exit(1);
319         }
320         if (realpath(config_filename.c_str(), config_filename_canon) == NULL) {
321                 log_perror(config_filename.c_str());
322                 exit(1);
323         }
324
325         // Now parse the configuration file.
326         Config config;
327         if (!parse_config(config_filename_canon, &config)) {
328                 exit(1);
329         }
330         if (test_config) {
331                 exit(0);
332         }
333         
334         // Ideally we'd like to daemonize only when we've started up all threads etc.,
335         // but daemon() forks, which is not good in multithreaded software, so we'll
336         // have to do it here.
337         if (config.daemonize) {
338                 if (daemon(0, 0) == -1) {
339                         log_perror("daemon");
340                         exit(1);
341                 }
342         }
343
344 start:
345         // Open logs as soon as possible.
346         open_logs(config.log_destinations);
347
348         log(INFO, "Cubemap " SERVER_VERSION " starting.");
349         if (config.access_log_file.empty()) {
350                 // Create a dummy logger.
351                 access_log = new AccessLogThread();
352         } else {
353                 access_log = new AccessLogThread(config.access_log_file);
354         }
355         access_log->run();
356
357         servers = new ServerPool(config.num_servers);
358
359         CubemapStateProto loaded_state;
360         struct timeval serialize_start;
361         set<string> deserialized_urls;
362         map<int, Acceptor *> deserialized_acceptors;
363         multimap<string, InputWithRefcount> inputs;  // multimap due to older versions without deduplication.
364         if (state_fd != -1) {
365                 log(INFO, "Deserializing state from previous process...");
366                 string serialized;
367                 if (!read_tempfile(state_fd, &serialized)) {
368                         exit(1);
369                 }
370                 if (!loaded_state.ParseFromString(serialized)) {
371                         log(ERROR, "Failed deserialization of state.");
372                         exit(1);
373                 }
374
375                 serialize_start.tv_sec = loaded_state.serialize_start_sec();
376                 serialize_start.tv_usec = loaded_state.serialize_start_usec();
377
378                 // Deserialize the streams.
379                 for (int i = 0; i < loaded_state.streams_size(); ++i) {
380                         const StreamProto &stream = loaded_state.streams(i);
381
382                         vector<int> data_fds;
383                         for (int j = 0; j < stream.data_fds_size(); ++j) {
384                                 data_fds.push_back(stream.data_fds(j));
385                         }
386
387                         // Older versions stored the data once in the protobuf instead of
388                         // sending around file descriptors.
389                         if (data_fds.empty() && stream.has_data()) {
390                                 data_fds.push_back(make_tempfile(stream.data()));
391                         }
392
393                         servers->add_stream_from_serialized(stream, data_fds);
394                         deserialized_urls.insert(stream.url());
395                 }
396
397                 // Deserialize the inputs. Note that we don't actually add them to any stream yet.
398                 for (int i = 0; i < loaded_state.inputs_size(); ++i) {
399                         InputWithRefcount iwr;
400                         iwr.input = create_input(loaded_state.inputs(i));
401                         iwr.refcount = 0;
402                         inputs.insert(make_pair(loaded_state.inputs(i).url(), iwr));
403                 } 
404
405                 // Deserialize the acceptors.
406                 for (int i = 0; i < loaded_state.acceptors_size(); ++i) {
407                         deserialized_acceptors.insert(make_pair(
408                                 loaded_state.acceptors(i).port(),
409                                 new Acceptor(loaded_state.acceptors(i))));
410                 }
411
412                 log(INFO, "Deserialization done.");
413         }
414
415         // Add any new inputs coming from the config.
416         create_config_inputs(config, &inputs);
417         
418         // Find all streams in the configuration file, create them, and connect to the inputs.
419         create_streams(config, deserialized_urls, &inputs);
420         vector<Acceptor *> acceptors = create_acceptors(config, &deserialized_acceptors);
421         
422         // Put back the existing clients. It doesn't matter which server we
423         // allocate them to, so just do round-robin. However, we need to add
424         // them after the mark pools have been set up.
425         for (int i = 0; i < loaded_state.clients_size(); ++i) {
426                 servers->add_client_from_serialized(loaded_state.clients(i));
427         }
428         
429         servers->run();
430
431         // Now delete all inputs that are longer in use, and start the others.
432         for (multimap<string, InputWithRefcount>::iterator input_it = inputs.begin();
433              input_it != inputs.end(); ) {
434                 if (input_it->second.refcount == 0) {
435                         log(WARNING, "Input '%s' no longer in use, closing.",
436                             input_it->first.c_str());
437                         input_it->second.input->close_socket();
438                         delete input_it->second.input;
439                         inputs.erase(input_it++);
440                 } else {
441                         input_it->second.input->run();
442                         ++input_it;
443                 }
444         }
445
446         // Start writing statistics.
447         StatsThread *stats_thread = NULL;
448         if (!config.stats_file.empty()) {
449                 stats_thread = new StatsThread(config.stats_file, config.stats_interval);
450                 stats_thread->run();
451         }
452
453         InputStatsThread *input_stats_thread = NULL;
454         if (!config.input_stats_file.empty()) {
455                 vector<Input*> inputs_no_refcount;
456                 for (multimap<string, InputWithRefcount>::iterator input_it = inputs.begin();
457                      input_it != inputs.end(); ++input_it) {
458                         inputs_no_refcount.push_back(input_it->second.input);
459                 }
460
461                 input_stats_thread = new InputStatsThread(config.input_stats_file, config.input_stats_interval, inputs_no_refcount);
462                 input_stats_thread->run();
463         }
464
465         struct timeval server_start;
466         gettimeofday(&server_start, NULL);
467         if (state_fd != -1) {
468                 // Measure time from we started deserializing (below) to now, when basically everything
469                 // is up and running. This is, in other words, a conservative estimate of how long our
470                 // “glitch” period was, not counting of course reconnects if the configuration changed.
471                 double glitch_time = server_start.tv_sec - serialize_start.tv_sec +
472                         1e-6 * (server_start.tv_usec - serialize_start.tv_usec);
473                 log(INFO, "Re-exec happened in approx. %.0f ms.", glitch_time * 1000.0);
474         }
475
476         while (!hupped) {
477                 usleep(100000);
478         }
479
480         // OK, we've been HUPed. Time to shut down everything, serialize, and re-exec.
481         gettimeofday(&serialize_start, NULL);
482
483         if (input_stats_thread != NULL) {
484                 input_stats_thread->stop();
485                 delete input_stats_thread;
486         }
487         if (stats_thread != NULL) {
488                 stats_thread->stop();
489                 delete stats_thread;
490         }
491         for (size_t i = 0; i < acceptors.size(); ++i) {
492                 acceptors[i]->stop();
493         }
494         for (multimap<string, InputWithRefcount>::iterator input_it = inputs.begin();
495              input_it != inputs.end();
496              ++input_it) {
497                 input_it->second.input->stop();
498         }
499         servers->stop();
500
501         CubemapStateProto state;
502         if (stopped) {
503                 log(INFO, "Shutting down.");
504         } else {
505                 log(INFO, "Serializing state and re-execing...");
506                 state = collect_state(
507                         serialize_start, acceptors, inputs, servers);
508                 string serialized;
509                 state.SerializeToString(&serialized);
510                 state_fd = make_tempfile(serialized);
511                 if (state_fd == -1) {
512                         exit(1);
513                 }
514         }
515         delete servers;
516
517         for (unsigned i = 0; i < mark_pools.size(); ++i) {
518                 delete mark_pools[i];
519         }
520         mark_pools.clear();
521
522         access_log->stop();
523         delete access_log;
524         shut_down_logging();
525
526         if (stopped) {
527                 exit(0);
528         }
529
530         // OK, so the signal was SIGHUP. Check that the new config is okay, then exec the new binary.
531         if (!dry_run_config(argv0_canon, config_filename_canon)) {
532                 open_logs(config.log_destinations);
533                 log(ERROR, "%s --test-config failed. Restarting old version instead of new.", argv[0]);
534                 hupped = false;
535                 shut_down_logging();
536                 goto start;
537         }
538          
539         char buf[16];
540         sprintf(buf, "%d", state_fd);
541
542         for ( ;; ) {
543                 execlp(argv0_canon, argv0_canon, config_filename_canon, "--state", buf, NULL);
544                 open_logs(config.log_destinations);
545                 log_perror("execlp");
546                 log(ERROR, "re-exec of %s failed. Waiting 0.2 seconds and trying again...", argv0_canon);
547                 shut_down_logging();
548                 usleep(200000);
549         }
550 }