]> git.sesse.net Git - cubemap/blob - main.cpp
Split configuration parsing out cleanly from initialization.
[cubemap] / main.cpp
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <stdint.h>
#include <assert.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <sys/socket.h>
#include <pthread.h>
#include <sys/types.h>
#include <sys/ioctl.h>
#include <sys/poll.h>
#include <sys/time.h>
#include <signal.h>
#include <errno.h>
#include <ctype.h>
#include <fcntl.h>
#include <vector>
#include <string>
#include <map>
#include <set>
20
21 #include "acceptor.h"
22 #include "config.h"
23 #include "markpool.h"
24 #include "metacube.h"
25 #include "parse.h"
26 #include "server.h"
27 #include "serverpool.h"
28 #include "input.h"
29 #include "stats.h"
30 #include "version.h"
31 #include "state.pb.h"
32
33 using namespace std;
34
35 ServerPool *servers = NULL;
36 volatile bool hupped = false;
37
38 void hup(int ignored)
39 {
40         hupped = true;
41 }
42
43 // Serialize the given state to a file descriptor, and return the (still open)
44 // descriptor.
45 int make_tempfile(const CubemapStateProto &state)
46 {
47         char tmpl[] = "/tmp/cubemapstate.XXXXXX";
48         int state_fd = mkstemp(tmpl);
49         if (state_fd == -1) {
50                 perror("mkstemp");
51                 exit(1);
52         }
53
54         string serialized;
55         state.SerializeToString(&serialized);
56
57         const char *ptr = serialized.data();
58         size_t to_write = serialized.size();
59         while (to_write > 0) {
60                 ssize_t ret = write(state_fd, ptr, to_write);
61                 if (ret == -1) {
62                         perror("write");
63                         exit(1);
64                 }
65
66                 ptr += ret;
67                 to_write -= ret;
68         }
69
70         return state_fd;
71 }
72
73 CubemapStateProto collect_state(const timeval &serialize_start,
74                                 const vector<Acceptor *> acceptors,
75                                 const vector<Input *> inputs,
76                                 ServerPool *servers,
77                                 int num_servers)
78 {
79         CubemapStateProto state;
80         state.set_serialize_start_sec(serialize_start.tv_sec);
81         state.set_serialize_start_usec(serialize_start.tv_usec);
82         
83         for (size_t i = 0; i < acceptors.size(); ++i) {
84                 state.add_acceptors()->MergeFrom(acceptors[i]->serialize());
85         }
86
87         for (size_t i = 0; i < inputs.size(); ++i) {
88                 state.add_inputs()->MergeFrom(inputs[i]->serialize());
89         }
90
91         for (int i = 0; i < num_servers; ++i) { 
92                 CubemapStateProto local_state = servers->get_server(i)->serialize();
93
94                 // The stream state should be identical between the servers, so we only store it once.
95                 if (i == 0) {
96                         state.mutable_streams()->MergeFrom(local_state.streams());
97                 }
98                 for (int j = 0; j < local_state.clients_size(); ++j) {
99                         state.add_clients()->MergeFrom(local_state.clients(j));
100                 }
101         }
102
103         return state;
104 }
105
106 // Read the state back from the file descriptor made by make_tempfile,
107 // and close it.
108 CubemapStateProto read_tempfile(int state_fd)
109 {
110         if (lseek(state_fd, 0, SEEK_SET) == -1) {
111                 perror("lseek");
112                 exit(1);
113         }
114
115         string serialized;
116         char buf[4096];
117         for ( ;; ) {
118                 ssize_t ret = read(state_fd, buf, sizeof(buf));
119                 if (ret == -1) {
120                         perror("read");
121                         exit(1);
122                 }
123                 if (ret == 0) {
124                         // EOF.
125                         break;
126                 }
127
128                 serialized.append(string(buf, buf + ret));
129         }
130
131         close(state_fd);  // Implicitly deletes the file.
132
133         CubemapStateProto state;
134         if (!state.ParseFromString(serialized)) {
135                 fprintf(stderr, "PANIC: Failed deserialization of state.\n");
136                 exit(1);
137         }
138
139         return state;
140 }
141         
142 // Find all port statements in the configuration file, and create acceptors for htem.
143 vector<Acceptor *> create_acceptors(
144         const Config &config,
145         map<int, Acceptor *> *deserialized_acceptors)
146 {
147         vector<Acceptor *> acceptors;
148         for (unsigned i = 0; i < config.acceptors.size(); ++i) {
149                 const AcceptorConfig &acceptor_config = config.acceptors[i];
150                 Acceptor *acceptor = NULL;
151                 map<int, Acceptor *>::iterator deserialized_acceptor_it =
152                         deserialized_acceptors->find(acceptor_config.port);
153                 if (deserialized_acceptor_it != deserialized_acceptors->end()) {
154                         acceptor = deserialized_acceptor_it->second;
155                         deserialized_acceptors->erase(deserialized_acceptor_it);
156                 } else {
157                         int server_sock = create_server_socket(acceptor_config.port, TCP_SOCKET);
158                         acceptor = new Acceptor(server_sock, acceptor_config.port);
159                 }
160                 acceptor->run();
161                 acceptors.push_back(acceptor);
162         }
163
164         // Close all acceptors that are no longer in the configuration file.
165         for (map<int, Acceptor *>::iterator acceptor_it = deserialized_acceptors->begin();
166              acceptor_it != deserialized_acceptors->end();
167              ++acceptor_it) {
168                 acceptor_it->second->close_socket();
169                 delete acceptor_it->second;
170         }
171
172         return acceptors;
173 }
174
175 // Find all streams in the configuration file, and create inputs for them.
176 vector<Input *> create_inputs(const Config &config,
177                               map<string, Input *> *deserialized_inputs)
178 {
179         vector<Input *> inputs;
180         for (unsigned i = 0; i < config.streams.size(); ++i) {
181                 const StreamConfig &stream_config = config.streams[i];
182                 if (stream_config.src.empty()) {
183                         continue;
184                 }
185
186                 string stream_id = stream_config.stream_id;
187                 string src = stream_config.src;
188
189                 Input *input = NULL;
190                 map<string, Input *>::iterator deserialized_input_it =
191                         deserialized_inputs->find(stream_id);
192                 if (deserialized_input_it != deserialized_inputs->end()) {
193                         input = deserialized_input_it->second;
194                         if (input->get_url() != src) {
195                                 fprintf(stderr, "INFO: Stream '%s' has changed URL from '%s' to '%s', restarting input.\n",
196                                         stream_id.c_str(), input->get_url().c_str(), src.c_str());
197                                 input->close_socket();
198                                 delete input;
199                                 input = NULL;
200                         }
201                         deserialized_inputs->erase(deserialized_input_it);
202                 }
203                 if (input == NULL) {
204                         input = create_input(stream_id, src);
205                         if (input == NULL) {
206                                 fprintf(stderr, "ERROR: did not understand URL '%s', clients will not get any data.\n",
207                                         src.c_str());
208                                 continue;
209                         }
210                 }
211                 input->run();
212                 inputs.push_back(input);
213         }
214         return inputs;
215 }
216
217 void create_streams(const Config &config,
218                     const set<string> &deserialized_stream_ids,
219                     map<string, Input *> *deserialized_inputs)
220 {
221         vector<MarkPool *> mark_pools;  // FIXME: leak
222         for (unsigned i = 0; i < config.mark_pools.size(); ++i) {
223                 const MarkPoolConfig &mp_config = config.mark_pools[i];
224                 mark_pools.push_back(new MarkPool(mp_config.from, mp_config.to));
225         }
226
227         set<string> expecting_stream_ids = deserialized_stream_ids;
228         for (unsigned i = 0; i < config.streams.size(); ++i) {
229                 const StreamConfig &stream_config = config.streams[i];
230                 if (deserialized_stream_ids.count(stream_config.stream_id) == 0) {
231                         servers->add_stream(stream_config.stream_id);
232                 }
233                 expecting_stream_ids.erase(stream_config.stream_id);
234
235                 if (stream_config.mark_pool != -1) {
236                         servers->set_mark_pool(stream_config.stream_id,
237                                                mark_pools[stream_config.mark_pool]);
238                 }
239         }
240
241         // Warn about any servers we've lost.
242         // TODO: Make an option (delete=yes?) to actually shut down streams.
243         for (set<string>::const_iterator stream_it = expecting_stream_ids.begin();
244              stream_it != expecting_stream_ids.end();
245              ++stream_it) {
246                 string stream_id = *stream_it;
247                 fprintf(stderr, "WARNING: stream '%s' disappeared from the configuration file.\n",
248                         stream_id.c_str());
249                 fprintf(stderr, "         It will not be deleted, but clients will not get any new inputs.\n");
250                 if (deserialized_inputs->count(stream_id) != 0) {
251                         delete (*deserialized_inputs)[stream_id];
252                         deserialized_inputs->erase(stream_id);
253                 }
254         }
255 }
256
257 int main(int argc, char **argv)
258 {
259         fprintf(stderr, "\nCubemap " SERVER_VERSION " starting.\n");
260
261         struct timeval serialize_start;
262         bool is_reexec = false;
263
264         string config_filename = (argc == 1) ? "cubemap.config" : argv[1];
265         Config config;
266         if (!parse_config(config_filename, &config)) {
267                 exit(1);
268         }       
269
270         servers = new ServerPool(config.num_servers);
271
272         CubemapStateProto loaded_state;
273         set<string> deserialized_stream_ids;
274         map<string, Input *> deserialized_inputs;
275         map<int, Acceptor *> deserialized_acceptors;
276         if (argc == 4 && strcmp(argv[2], "-state") == 0) {
277                 is_reexec = true;
278
279                 fprintf(stderr, "Deserializing state from previous process... ");
280                 int state_fd = atoi(argv[3]);
281                 loaded_state = read_tempfile(state_fd);
282
283                 serialize_start.tv_sec = loaded_state.serialize_start_sec();
284                 serialize_start.tv_usec = loaded_state.serialize_start_usec();
285
286                 // Deserialize the streams.
287                 for (int i = 0; i < loaded_state.streams_size(); ++i) {
288                         servers->add_stream_from_serialized(loaded_state.streams(i));
289                         deserialized_stream_ids.insert(loaded_state.streams(i).stream_id());
290                 }
291
292                 // Deserialize the inputs. Note that we don't actually add them to any state yet.
293                 for (int i = 0; i < loaded_state.inputs_size(); ++i) {
294                         deserialized_inputs.insert(make_pair(
295                                 loaded_state.inputs(i).stream_id(),
296                                 create_input(loaded_state.inputs(i))));
297                 } 
298
299                 // Convert the acceptor from older serialized formats.
300                 if (loaded_state.has_server_sock() && loaded_state.has_port()) {
301                         AcceptorProto *acceptor = loaded_state.add_acceptors();
302                         acceptor->set_server_sock(loaded_state.server_sock());
303                         acceptor->set_port(loaded_state.port());
304                 }
305
306                 // Deserialize the acceptors.
307                 for (int i = 0; i < loaded_state.acceptors_size(); ++i) {
308                         deserialized_acceptors.insert(make_pair(
309                                 loaded_state.acceptors(i).port(),
310                                 new Acceptor(loaded_state.acceptors(i))));
311                 }
312
313                 fprintf(stderr, "done.\n");
314         }
315
316         // Find all streams in the configuration file, and create them.
317         create_streams(config, deserialized_stream_ids, &deserialized_inputs);
318
319         servers->run();
320
321         vector<Acceptor *> acceptors = create_acceptors(config, &deserialized_acceptors);
322         vector<Input *> inputs = create_inputs(config, &deserialized_inputs);
323         
324         // All deserialized inputs should now have been taken care of, one way or the other.
325         assert(deserialized_inputs.empty());
326         
327         if (is_reexec) {        
328                 // Put back the existing clients. It doesn't matter which server we
329                 // allocate them to, so just do round-robin. However, we need to add
330                 // them after the mark pools have been set up.
331                 for (int i = 0; i < loaded_state.clients_size(); ++i) {
332                         servers->add_client_from_serialized(loaded_state.clients(i));
333                 }
334         }
335
336         // Start writing statistics.
337         StatsThread *stats_thread = NULL;
338         if (!config.stats_file.empty()) {
339                 stats_thread = new StatsThread(config.stats_file, config.stats_interval);
340                 stats_thread->run();
341         } else if (config.stats_interval != -1) {
342                 fprintf(stderr, "WARNING: 'stats_interval' given, but no 'stats_file'. No statistics will be written.\n");
343         }
344
345         signal(SIGHUP, hup);
346         
347         struct timeval server_start;
348         gettimeofday(&server_start, NULL);
349         if (is_reexec) {
350                 // Measure time from we started deserializing (below) to now, when basically everything
351                 // is up and running. This is, in other words, a conservative estimate of how long our
352                 // “glitch” period was, not counting of course reconnects if the configuration changed.
353                 double glitch_time = server_start.tv_sec - serialize_start.tv_sec +
354                         1e-6 * (server_start.tv_usec - serialize_start.tv_usec);
355                 fprintf(stderr, "Re-exec happened in approx. %.0f ms.\n", glitch_time * 1000.0);
356         }
357
358         while (!hupped) {
359                 usleep(100000);
360         }
361
362         // OK, we've been HUPed. Time to shut down everything, serialize, and re-exec.
363         gettimeofday(&serialize_start, NULL);
364
365         if (stats_thread != NULL) {
366                 stats_thread->stop();
367         }
368         for (size_t i = 0; i < acceptors.size(); ++i) {
369                 acceptors[i]->stop();
370         }
371         for (size_t i = 0; i < inputs.size(); ++i) {
372                 inputs[i]->stop();
373         }
374         servers->stop();
375
376         fprintf(stderr, "Serializing state and re-execing...\n");
377         int state_fd = make_tempfile(collect_state(
378                 serialize_start, acceptors, inputs, servers, config.num_servers));
379         delete servers;
380          
381         char buf[16];
382         sprintf(buf, "%d", state_fd);
383
384         for ( ;; ) {
385                 execlp(argv[0], argv[0], config_filename.c_str(), "-state", buf, NULL);
386                 perror("execlp");
387                 fprintf(stderr, "PANIC: re-exec of %s failed. Waiting 0.2 seconds and trying again...\n", argv[0]);
388                 usleep(200000);
389         }
390 }