]> git.sesse.net Git - cubemap/blob - main.cpp
Move some serialization logic into ServerPool, where it belongs.
[cubemap] / main.cpp
1 #include <stdio.h>
2 #include <string.h>
3 #include <stdint.h>
4 #include <assert.h>
5 #include <arpa/inet.h>
6 #include <sys/socket.h>
7 #include <pthread.h>
8 #include <sys/types.h>
9 #include <sys/ioctl.h>
10 #include <sys/poll.h>
11 #include <sys/time.h>
12 #include <signal.h>
13 #include <errno.h>
14 #include <ctype.h>
15 #include <fcntl.h>
16 #include <vector>
17 #include <string>
18 #include <map>
19 #include <set>
20
21 #include "acceptor.h"
22 #include "config.h"
23 #include "markpool.h"
24 #include "metacube.h"
25 #include "parse.h"
26 #include "server.h"
27 #include "serverpool.h"
28 #include "input.h"
29 #include "stats.h"
30 #include "version.h"
31 #include "state.pb.h"
32
33 using namespace std;
34
35 ServerPool *servers = NULL;
36 volatile bool hupped = false;
37
38 void hup(int ignored)
39 {
40         hupped = true;
41 }
42
43 // Serialize the given state to a file descriptor, and return the (still open)
44 // descriptor.
45 int make_tempfile(const CubemapStateProto &state)
46 {
47         char tmpl[] = "/tmp/cubemapstate.XXXXXX";
48         int state_fd = mkstemp(tmpl);
49         if (state_fd == -1) {
50                 perror("mkstemp");
51                 exit(1);
52         }
53
54         string serialized;
55         state.SerializeToString(&serialized);
56
57         const char *ptr = serialized.data();
58         size_t to_write = serialized.size();
59         while (to_write > 0) {
60                 ssize_t ret = write(state_fd, ptr, to_write);
61                 if (ret == -1) {
62                         perror("write");
63                         exit(1);
64                 }
65
66                 ptr += ret;
67                 to_write -= ret;
68         }
69
70         return state_fd;
71 }
72
73 CubemapStateProto collect_state(const timeval &serialize_start,
74                                 const vector<Acceptor *> acceptors,
75                                 const vector<Input *> inputs,
76                                 ServerPool *servers)
77 {
78         CubemapStateProto state = servers->serialize();  // Fills streams() and clients().
79         state.set_serialize_start_sec(serialize_start.tv_sec);
80         state.set_serialize_start_usec(serialize_start.tv_usec);
81         
82         for (size_t i = 0; i < acceptors.size(); ++i) {
83                 state.add_acceptors()->MergeFrom(acceptors[i]->serialize());
84         }
85
86         for (size_t i = 0; i < inputs.size(); ++i) {
87                 state.add_inputs()->MergeFrom(inputs[i]->serialize());
88         }
89
90         return state;
91 }
92
93 // Read the state back from the file descriptor made by make_tempfile,
94 // and close it.
95 CubemapStateProto read_tempfile(int state_fd)
96 {
97         if (lseek(state_fd, 0, SEEK_SET) == -1) {
98                 perror("lseek");
99                 exit(1);
100         }
101
102         string serialized;
103         char buf[4096];
104         for ( ;; ) {
105                 ssize_t ret = read(state_fd, buf, sizeof(buf));
106                 if (ret == -1) {
107                         perror("read");
108                         exit(1);
109                 }
110                 if (ret == 0) {
111                         // EOF.
112                         break;
113                 }
114
115                 serialized.append(string(buf, buf + ret));
116         }
117
118         close(state_fd);  // Implicitly deletes the file.
119
120         CubemapStateProto state;
121         if (!state.ParseFromString(serialized)) {
122                 fprintf(stderr, "PANIC: Failed deserialization of state.\n");
123                 exit(1);
124         }
125
126         return state;
127 }
128         
129 // Find all port statements in the configuration file, and create acceptors for htem.
130 vector<Acceptor *> create_acceptors(
131         const Config &config,
132         map<int, Acceptor *> *deserialized_acceptors)
133 {
134         vector<Acceptor *> acceptors;
135         for (unsigned i = 0; i < config.acceptors.size(); ++i) {
136                 const AcceptorConfig &acceptor_config = config.acceptors[i];
137                 Acceptor *acceptor = NULL;
138                 map<int, Acceptor *>::iterator deserialized_acceptor_it =
139                         deserialized_acceptors->find(acceptor_config.port);
140                 if (deserialized_acceptor_it != deserialized_acceptors->end()) {
141                         acceptor = deserialized_acceptor_it->second;
142                         deserialized_acceptors->erase(deserialized_acceptor_it);
143                 } else {
144                         int server_sock = create_server_socket(acceptor_config.port, TCP_SOCKET);
145                         acceptor = new Acceptor(server_sock, acceptor_config.port);
146                 }
147                 acceptor->run();
148                 acceptors.push_back(acceptor);
149         }
150
151         // Close all acceptors that are no longer in the configuration file.
152         for (map<int, Acceptor *>::iterator acceptor_it = deserialized_acceptors->begin();
153              acceptor_it != deserialized_acceptors->end();
154              ++acceptor_it) {
155                 acceptor_it->second->close_socket();
156                 delete acceptor_it->second;
157         }
158
159         return acceptors;
160 }
161
162 // Find all streams in the configuration file, and create inputs for them.
163 vector<Input *> create_inputs(const Config &config,
164                               map<string, Input *> *deserialized_inputs)
165 {
166         vector<Input *> inputs;
167         for (unsigned i = 0; i < config.streams.size(); ++i) {
168                 const StreamConfig &stream_config = config.streams[i];
169                 if (stream_config.src.empty()) {
170                         continue;
171                 }
172
173                 string stream_id = stream_config.stream_id;
174                 string src = stream_config.src;
175
176                 Input *input = NULL;
177                 map<string, Input *>::iterator deserialized_input_it =
178                         deserialized_inputs->find(stream_id);
179                 if (deserialized_input_it != deserialized_inputs->end()) {
180                         input = deserialized_input_it->second;
181                         if (input->get_url() != src) {
182                                 fprintf(stderr, "INFO: Stream '%s' has changed URL from '%s' to '%s', restarting input.\n",
183                                         stream_id.c_str(), input->get_url().c_str(), src.c_str());
184                                 input->close_socket();
185                                 delete input;
186                                 input = NULL;
187                         }
188                         deserialized_inputs->erase(deserialized_input_it);
189                 }
190                 if (input == NULL) {
191                         input = create_input(stream_id, src);
192                         if (input == NULL) {
193                                 fprintf(stderr, "ERROR: did not understand URL '%s', clients will not get any data.\n",
194                                         src.c_str());
195                                 continue;
196                         }
197                 }
198                 input->run();
199                 inputs.push_back(input);
200         }
201         return inputs;
202 }
203
204 void create_streams(const Config &config,
205                     const set<string> &deserialized_stream_ids,
206                     map<string, Input *> *deserialized_inputs)
207 {
208         vector<MarkPool *> mark_pools;  // FIXME: leak
209         for (unsigned i = 0; i < config.mark_pools.size(); ++i) {
210                 const MarkPoolConfig &mp_config = config.mark_pools[i];
211                 mark_pools.push_back(new MarkPool(mp_config.from, mp_config.to));
212         }
213
214         set<string> expecting_stream_ids = deserialized_stream_ids;
215         for (unsigned i = 0; i < config.streams.size(); ++i) {
216                 const StreamConfig &stream_config = config.streams[i];
217                 if (deserialized_stream_ids.count(stream_config.stream_id) == 0) {
218                         servers->add_stream(stream_config.stream_id);
219                 }
220                 expecting_stream_ids.erase(stream_config.stream_id);
221
222                 if (stream_config.mark_pool != -1) {
223                         servers->set_mark_pool(stream_config.stream_id,
224                                                mark_pools[stream_config.mark_pool]);
225                 }
226         }
227
228         // Warn about any servers we've lost.
229         // TODO: Make an option (delete=yes?) to actually shut down streams.
230         for (set<string>::const_iterator stream_it = expecting_stream_ids.begin();
231              stream_it != expecting_stream_ids.end();
232              ++stream_it) {
233                 string stream_id = *stream_it;
234                 fprintf(stderr, "WARNING: stream '%s' disappeared from the configuration file.\n",
235                         stream_id.c_str());
236                 fprintf(stderr, "         It will not be deleted, but clients will not get any new inputs.\n");
237                 if (deserialized_inputs->count(stream_id) != 0) {
238                         delete (*deserialized_inputs)[stream_id];
239                         deserialized_inputs->erase(stream_id);
240                 }
241         }
242 }
243
244 int main(int argc, char **argv)
245 {
246         fprintf(stderr, "\nCubemap " SERVER_VERSION " starting.\n");
247
248         struct timeval serialize_start;
249         bool is_reexec = false;
250
251         string config_filename = (argc == 1) ? "cubemap.config" : argv[1];
252         Config config;
253         if (!parse_config(config_filename, &config)) {
254                 exit(1);
255         }       
256
257         servers = new ServerPool(config.num_servers);
258
259         CubemapStateProto loaded_state;
260         set<string> deserialized_stream_ids;
261         map<string, Input *> deserialized_inputs;
262         map<int, Acceptor *> deserialized_acceptors;
263         if (argc == 4 && strcmp(argv[2], "-state") == 0) {
264                 is_reexec = true;
265
266                 fprintf(stderr, "Deserializing state from previous process... ");
267                 int state_fd = atoi(argv[3]);
268                 loaded_state = read_tempfile(state_fd);
269
270                 serialize_start.tv_sec = loaded_state.serialize_start_sec();
271                 serialize_start.tv_usec = loaded_state.serialize_start_usec();
272
273                 // Deserialize the streams.
274                 for (int i = 0; i < loaded_state.streams_size(); ++i) {
275                         servers->add_stream_from_serialized(loaded_state.streams(i));
276                         deserialized_stream_ids.insert(loaded_state.streams(i).stream_id());
277                 }
278
279                 // Deserialize the inputs. Note that we don't actually add them to any state yet.
280                 for (int i = 0; i < loaded_state.inputs_size(); ++i) {
281                         deserialized_inputs.insert(make_pair(
282                                 loaded_state.inputs(i).stream_id(),
283                                 create_input(loaded_state.inputs(i))));
284                 } 
285
286                 // Convert the acceptor from older serialized formats.
287                 if (loaded_state.has_server_sock() && loaded_state.has_port()) {
288                         AcceptorProto *acceptor = loaded_state.add_acceptors();
289                         acceptor->set_server_sock(loaded_state.server_sock());
290                         acceptor->set_port(loaded_state.port());
291                 }
292
293                 // Deserialize the acceptors.
294                 for (int i = 0; i < loaded_state.acceptors_size(); ++i) {
295                         deserialized_acceptors.insert(make_pair(
296                                 loaded_state.acceptors(i).port(),
297                                 new Acceptor(loaded_state.acceptors(i))));
298                 }
299
300                 fprintf(stderr, "done.\n");
301         }
302
303         // Find all streams in the configuration file, and create them.
304         create_streams(config, deserialized_stream_ids, &deserialized_inputs);
305
306         servers->run();
307
308         vector<Acceptor *> acceptors = create_acceptors(config, &deserialized_acceptors);
309         vector<Input *> inputs = create_inputs(config, &deserialized_inputs);
310         
311         // All deserialized inputs should now have been taken care of, one way or the other.
312         assert(deserialized_inputs.empty());
313         
314         if (is_reexec) {        
315                 // Put back the existing clients. It doesn't matter which server we
316                 // allocate them to, so just do round-robin. However, we need to add
317                 // them after the mark pools have been set up.
318                 for (int i = 0; i < loaded_state.clients_size(); ++i) {
319                         servers->add_client_from_serialized(loaded_state.clients(i));
320                 }
321         }
322
323         // Start writing statistics.
324         StatsThread *stats_thread = NULL;
325         if (!config.stats_file.empty()) {
326                 stats_thread = new StatsThread(config.stats_file, config.stats_interval);
327                 stats_thread->run();
328         } else if (config.stats_interval != -1) {
329                 fprintf(stderr, "WARNING: 'stats_interval' given, but no 'stats_file'. No statistics will be written.\n");
330         }
331
332         signal(SIGHUP, hup);
333         
334         struct timeval server_start;
335         gettimeofday(&server_start, NULL);
336         if (is_reexec) {
337                 // Measure time from we started deserializing (below) to now, when basically everything
338                 // is up and running. This is, in other words, a conservative estimate of how long our
339                 // “glitch” period was, not counting of course reconnects if the configuration changed.
340                 double glitch_time = server_start.tv_sec - serialize_start.tv_sec +
341                         1e-6 * (server_start.tv_usec - serialize_start.tv_usec);
342                 fprintf(stderr, "Re-exec happened in approx. %.0f ms.\n", glitch_time * 1000.0);
343         }
344
345         while (!hupped) {
346                 usleep(100000);
347         }
348
349         // OK, we've been HUPed. Time to shut down everything, serialize, and re-exec.
350         gettimeofday(&serialize_start, NULL);
351
352         if (stats_thread != NULL) {
353                 stats_thread->stop();
354         }
355         for (size_t i = 0; i < acceptors.size(); ++i) {
356                 acceptors[i]->stop();
357         }
358         for (size_t i = 0; i < inputs.size(); ++i) {
359                 inputs[i]->stop();
360         }
361         servers->stop();
362
363         fprintf(stderr, "Serializing state and re-execing...\n");
364         int state_fd = make_tempfile(collect_state(
365                 serialize_start, acceptors, inputs, servers));
366         delete servers;
367          
368         char buf[16];
369         sprintf(buf, "%d", state_fd);
370
371         for ( ;; ) {
372                 execlp(argv[0], argv[0], config_filename.c_str(), "-state", buf, NULL);
373                 perror("execlp");
374                 fprintf(stderr, "PANIC: re-exec of %s failed. Waiting 0.2 seconds and trying again...\n", argv[0]);
375                 usleep(200000);
376         }
377 }