When re-execing, try with --test-config first, and revert if it fails.
[cubemap] / main.cpp
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <stdint.h>
#include <assert.h>
#include <getopt.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <sys/socket.h>
#include <pthread.h>
#include <sys/types.h>
#include <sys/ioctl.h>
#include <sys/poll.h>
#include <sys/time.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <signal.h>
#include <errno.h>
#include <ctype.h>
#include <fcntl.h>
#include <vector>
#include <string>
#include <map>
#include <set>
23
24 #include "acceptor.h"
25 #include "config.h"
26 #include "markpool.h"
27 #include "metacube.h"
28 #include "parse.h"
29 #include "server.h"
30 #include "serverpool.h"
31 #include "input.h"
32 #include "stats.h"
33 #include "version.h"
34 #include "state.pb.h"
35
36 using namespace std;
37
// Global server pool. Allocated in main() and also used by create_streams().
ServerPool *servers = NULL;
// Raised by the SIGHUP handler below. main() polls this flag to find out
// when it is time to serialize its state and re-exec.
volatile bool hupped = false;

// SIGHUP handler. It only raises the flag; all real work is left to
// the polling loop in main().
void hup(int ignored)
{
        (void)ignored;  // The signal number is of no interest.
        hupped = true;
}
45
46 // Serialize the given state to a file descriptor, and return the (still open)
47 // descriptor.
48 int make_tempfile(const CubemapStateProto &state)
49 {
50         char tmpl[] = "/tmp/cubemapstate.XXXXXX";
51         int state_fd = mkstemp(tmpl);
52         if (state_fd == -1) {
53                 perror("mkstemp");
54                 exit(1);
55         }
56
57         string serialized;
58         state.SerializeToString(&serialized);
59
60         const char *ptr = serialized.data();
61         size_t to_write = serialized.size();
62         while (to_write > 0) {
63                 ssize_t ret = write(state_fd, ptr, to_write);
64                 if (ret == -1) {
65                         perror("write");
66                         exit(1);
67                 }
68
69                 ptr += ret;
70                 to_write -= ret;
71         }
72
73         return state_fd;
74 }
75
76 CubemapStateProto collect_state(const timeval &serialize_start,
77                                 const vector<Acceptor *> acceptors,
78                                 const vector<Input *> inputs,
79                                 ServerPool *servers)
80 {
81         CubemapStateProto state = servers->serialize();  // Fills streams() and clients().
82         state.set_serialize_start_sec(serialize_start.tv_sec);
83         state.set_serialize_start_usec(serialize_start.tv_usec);
84         
85         for (size_t i = 0; i < acceptors.size(); ++i) {
86                 state.add_acceptors()->MergeFrom(acceptors[i]->serialize());
87         }
88
89         for (size_t i = 0; i < inputs.size(); ++i) {
90                 state.add_inputs()->MergeFrom(inputs[i]->serialize());
91         }
92
93         return state;
94 }
95
96 // Read the state back from the file descriptor made by make_tempfile,
97 // and close it.
98 CubemapStateProto read_tempfile(int state_fd)
99 {
100         if (lseek(state_fd, 0, SEEK_SET) == -1) {
101                 perror("lseek");
102                 exit(1);
103         }
104
105         string serialized;
106         char buf[4096];
107         for ( ;; ) {
108                 ssize_t ret = read(state_fd, buf, sizeof(buf));
109                 if (ret == -1) {
110                         perror("read");
111                         exit(1);
112                 }
113                 if (ret == 0) {
114                         // EOF.
115                         break;
116                 }
117
118                 serialized.append(string(buf, buf + ret));
119         }
120
121         close(state_fd);  // Implicitly deletes the file.
122
123         CubemapStateProto state;
124         if (!state.ParseFromString(serialized)) {
125                 fprintf(stderr, "PANIC: Failed deserialization of state.\n");
126                 exit(1);
127         }
128
129         return state;
130 }
131         
132 // Find all port statements in the configuration file, and create acceptors for htem.
133 vector<Acceptor *> create_acceptors(
134         const Config &config,
135         map<int, Acceptor *> *deserialized_acceptors)
136 {
137         vector<Acceptor *> acceptors;
138         for (unsigned i = 0; i < config.acceptors.size(); ++i) {
139                 const AcceptorConfig &acceptor_config = config.acceptors[i];
140                 Acceptor *acceptor = NULL;
141                 map<int, Acceptor *>::iterator deserialized_acceptor_it =
142                         deserialized_acceptors->find(acceptor_config.port);
143                 if (deserialized_acceptor_it != deserialized_acceptors->end()) {
144                         acceptor = deserialized_acceptor_it->second;
145                         deserialized_acceptors->erase(deserialized_acceptor_it);
146                 } else {
147                         int server_sock = create_server_socket(acceptor_config.port, TCP_SOCKET);
148                         acceptor = new Acceptor(server_sock, acceptor_config.port);
149                 }
150                 acceptor->run();
151                 acceptors.push_back(acceptor);
152         }
153
154         // Close all acceptors that are no longer in the configuration file.
155         for (map<int, Acceptor *>::iterator acceptor_it = deserialized_acceptors->begin();
156              acceptor_it != deserialized_acceptors->end();
157              ++acceptor_it) {
158                 acceptor_it->second->close_socket();
159                 delete acceptor_it->second;
160         }
161
162         return acceptors;
163 }
164
165 // Find all streams in the configuration file, and create inputs for them.
166 vector<Input *> create_inputs(const Config &config,
167                               map<string, Input *> *deserialized_inputs)
168 {
169         vector<Input *> inputs;
170         for (unsigned i = 0; i < config.streams.size(); ++i) {
171                 const StreamConfig &stream_config = config.streams[i];
172                 if (stream_config.src.empty()) {
173                         continue;
174                 }
175
176                 string stream_id = stream_config.stream_id;
177                 string src = stream_config.src;
178
179                 Input *input = NULL;
180                 map<string, Input *>::iterator deserialized_input_it =
181                         deserialized_inputs->find(stream_id);
182                 if (deserialized_input_it != deserialized_inputs->end()) {
183                         input = deserialized_input_it->second;
184                         if (input->get_url() != src) {
185                                 fprintf(stderr, "INFO: Stream '%s' has changed URL from '%s' to '%s', restarting input.\n",
186                                         stream_id.c_str(), input->get_url().c_str(), src.c_str());
187                                 input->close_socket();
188                                 delete input;
189                                 input = NULL;
190                         }
191                         deserialized_inputs->erase(deserialized_input_it);
192                 }
193                 if (input == NULL) {
194                         input = create_input(stream_id, src);
195                         if (input == NULL) {
196                                 fprintf(stderr, "ERROR: did not understand URL '%s', clients will not get any data.\n",
197                                         src.c_str());
198                                 continue;
199                         }
200                 }
201                 input->run();
202                 inputs.push_back(input);
203         }
204         return inputs;
205 }
206
207 void create_streams(const Config &config,
208                     const set<string> &deserialized_stream_ids,
209                     map<string, Input *> *deserialized_inputs)
210 {
211         vector<MarkPool *> mark_pools;  // FIXME: leak
212         for (unsigned i = 0; i < config.mark_pools.size(); ++i) {
213                 const MarkPoolConfig &mp_config = config.mark_pools[i];
214                 mark_pools.push_back(new MarkPool(mp_config.from, mp_config.to));
215         }
216
217         set<string> expecting_stream_ids = deserialized_stream_ids;
218         for (unsigned i = 0; i < config.streams.size(); ++i) {
219                 const StreamConfig &stream_config = config.streams[i];
220                 if (deserialized_stream_ids.count(stream_config.stream_id) == 0) {
221                         servers->add_stream(stream_config.stream_id);
222                 }
223                 expecting_stream_ids.erase(stream_config.stream_id);
224
225                 if (stream_config.mark_pool != -1) {
226                         servers->set_mark_pool(stream_config.stream_id,
227                                                mark_pools[stream_config.mark_pool]);
228                 }
229         }
230
231         // Warn about any servers we've lost.
232         // TODO: Make an option (delete=yes?) to actually shut down streams.
233         for (set<string>::const_iterator stream_it = expecting_stream_ids.begin();
234              stream_it != expecting_stream_ids.end();
235              ++stream_it) {
236                 string stream_id = *stream_it;
237                 fprintf(stderr, "WARNING: stream '%s' disappeared from the configuration file.\n",
238                         stream_id.c_str());
239                 fprintf(stderr, "         It will not be deleted, but clients will not get any new inputs.\n");
240                 if (deserialized_inputs->count(stream_id) != 0) {
241                         delete (*deserialized_inputs)[stream_id];
242                         deserialized_inputs->erase(stream_id);
243                 }
244         }
245 }
246         
// Run "<argv0> --test-config <config_filename>" as a child process and wait
// for it to finish.
//
// Returns true only if the child exited normally with status 0. Returns
// false if fork() or waitpid() fails, if the program cannot be executed
// (the child then exits with status 1), or if the child exits non-zero or
// is killed by a signal.
bool dry_run_config(const std::string &argv0, const std::string &config_filename)
{
        pid_t pid = fork();
        switch (pid) {
        case -1:
                perror("fork()");
                return false;
        case 0:
                // Child. It has its own copy of the strings, so c_str() can be
                // used directly. The argument list of a variadic exec call must
                // end in a null *pointer*, hence the cast.
                execlp(argv0.c_str(), argv0.c_str(), "--test-config",
                       config_filename.c_str(), (char *)NULL);
                perror(argv0.c_str());
                _exit(1);
        default:
                // Parent.
                break;
        }

        // Wait for the child, retrying if a signal interrupts the wait.
        int status;
        pid_t err;
        do {
                err = waitpid(pid, &status, 0);
        } while (err == -1 && errno == EINTR);

        if (err == -1) {
                perror("waitpid()");
                return false;
        }

        return (WIFEXITED(status) && WEXITSTATUS(status) == 0);
}
285
286 int main(int argc, char **argv)
287 {
288         // Parse options.
289         int state_fd = -1;
290         bool test_config = false;
291         for ( ;; ) {
292                 static const option long_options[] = {
293                         { "state", required_argument, 0, 's' },
294                         { "test-config", no_argument, 0, 't' },
295                 };
296                 int option_index = 0;
297                 int c = getopt_long (argc, argv, "s:t", long_options, &option_index);
298      
299                 if (c == -1) {
300                         break;
301                 }
302                 switch (c) {
303                 case 's':
304                         state_fd = atoi(optarg);
305                         break;
306                 case 't':
307                         test_config = true;
308                         break;
309                 default:
310                         assert(false);
311                 }
312         }
313
314         string config_filename = "cubemap.config";
315         if (optind < argc) {
316                 config_filename = argv[optind++];
317         }
318
319         Config config;
320         if (!parse_config(config_filename, &config)) {
321                 exit(1);
322         }
323         if (test_config) {
324                 exit(0);
325         }
326
327 start:
328         fprintf(stderr, "\nCubemap " SERVER_VERSION " starting.\n");
329         servers = new ServerPool(config.num_servers);
330
331         CubemapStateProto loaded_state;
332         struct timeval serialize_start;
333         set<string> deserialized_stream_ids;
334         map<string, Input *> deserialized_inputs;
335         map<int, Acceptor *> deserialized_acceptors;
336         if (state_fd != -1) {
337                 fprintf(stderr, "Deserializing state from previous process... ");
338                 loaded_state = read_tempfile(state_fd);
339
340                 serialize_start.tv_sec = loaded_state.serialize_start_sec();
341                 serialize_start.tv_usec = loaded_state.serialize_start_usec();
342
343                 // Deserialize the streams.
344                 for (int i = 0; i < loaded_state.streams_size(); ++i) {
345                         servers->add_stream_from_serialized(loaded_state.streams(i));
346                         deserialized_stream_ids.insert(loaded_state.streams(i).stream_id());
347                 }
348
349                 // Deserialize the inputs. Note that we don't actually add them to any state yet.
350                 for (int i = 0; i < loaded_state.inputs_size(); ++i) {
351                         deserialized_inputs.insert(make_pair(
352                                 loaded_state.inputs(i).stream_id(),
353                                 create_input(loaded_state.inputs(i))));
354                 } 
355
356                 // Deserialize the acceptors.
357                 for (int i = 0; i < loaded_state.acceptors_size(); ++i) {
358                         deserialized_acceptors.insert(make_pair(
359                                 loaded_state.acceptors(i).port(),
360                                 new Acceptor(loaded_state.acceptors(i))));
361                 }
362
363                 fprintf(stderr, "done.\n");
364         }
365
366         // Find all streams in the configuration file, and create them.
367         create_streams(config, deserialized_stream_ids, &deserialized_inputs);
368
369         servers->run();
370
371         vector<Acceptor *> acceptors = create_acceptors(config, &deserialized_acceptors);
372         vector<Input *> inputs = create_inputs(config, &deserialized_inputs);
373         
374         // All deserialized inputs should now have been taken care of, one way or the other.
375         assert(deserialized_inputs.empty());
376         
377         // Put back the existing clients. It doesn't matter which server we
378         // allocate them to, so just do round-robin. However, we need to add
379         // them after the mark pools have been set up.
380         for (int i = 0; i < loaded_state.clients_size(); ++i) {
381                 servers->add_client_from_serialized(loaded_state.clients(i));
382         }
383
384         // Start writing statistics.
385         StatsThread *stats_thread = NULL;
386         if (!config.stats_file.empty()) {
387                 stats_thread = new StatsThread(config.stats_file, config.stats_interval);
388                 stats_thread->run();
389         } else if (config.stats_interval != -1) {
390                 fprintf(stderr, "WARNING: 'stats_interval' given, but no 'stats_file'. No statistics will be written.\n");
391         }
392
393         signal(SIGHUP, hup);
394         
395         struct timeval server_start;
396         gettimeofday(&server_start, NULL);
397         if (state_fd != -1) {
398                 // Measure time from we started deserializing (below) to now, when basically everything
399                 // is up and running. This is, in other words, a conservative estimate of how long our
400                 // “glitch” period was, not counting of course reconnects if the configuration changed.
401                 double glitch_time = server_start.tv_sec - serialize_start.tv_sec +
402                         1e-6 * (server_start.tv_usec - serialize_start.tv_usec);
403                 fprintf(stderr, "Re-exec happened in approx. %.0f ms.\n", glitch_time * 1000.0);
404         }
405
406         while (!hupped) {
407                 usleep(100000);
408         }
409
410         // OK, we've been HUPed. Time to shut down everything, serialize, and re-exec.
411         gettimeofday(&serialize_start, NULL);
412
413         if (stats_thread != NULL) {
414                 stats_thread->stop();
415         }
416         for (size_t i = 0; i < acceptors.size(); ++i) {
417                 acceptors[i]->stop();
418         }
419         for (size_t i = 0; i < inputs.size(); ++i) {
420                 inputs[i]->stop();
421         }
422         servers->stop();
423
424         fprintf(stderr, "Serializing state and re-execing...\n");
425         state_fd = make_tempfile(collect_state(
426                 serialize_start, acceptors, inputs, servers));
427         delete servers;
428
429         if (!dry_run_config(argv[0], config_filename)) {
430                 fprintf(stderr, "ERROR: %s --test-config failed. Restarting old version instead of new.\n", argv[0]);
431                 hupped = false;
432                 goto start;
433         }
434          
435         char buf[16];
436         sprintf(buf, "%d", state_fd);
437
438         for ( ;; ) {
439                 execlp(argv[0], argv[0], config_filename.c_str(), "--state", buf, NULL);
440                 perror("execlp");
441                 fprintf(stderr, "PANIC: re-exec of %s failed. Waiting 0.2 seconds and trying again...\n", argv[0]);
442                 usleep(200000);
443         }
444 }