]> git.sesse.net Git - cubemap/blob - serverpool.cpp
f8e64aaeb962e75b139e3cc3e9098161abda6990
[cubemap] / serverpool.cpp
1 #include <assert.h>
2 #include <errno.h>
3 #include <google/protobuf/repeated_field.h>
4 #include <stdlib.h>
5 #include <unistd.h>
6
7 #include "client.h"
8 #include "log.h"
9 #include "server.h"
10 #include "serverpool.h"
11 #include "state.pb.h"
12 #include "util.h"
13
14 using namespace std;
15
16 ServerPool::ServerPool(int size)
17         : servers(new Server[size]),
18           num_servers(size),
19           clients_added(0),
20           num_http_streams(0)
21 {
22 }
23
24 ServerPool::~ServerPool()
25 {
26         delete[] servers;
27
28         for (size_t i = 0; i < udp_streams.size(); ++i) {
29                 delete udp_streams[i];
30         }
31 }
32         
33 CubemapStateProto ServerPool::serialize()
34 {
35         CubemapStateProto state;
36
37         for (int i = 0; i < num_servers; ++i) {
38                 CubemapStateProto local_state = servers[i].serialize();
39
40                 // The stream state should be identical between the servers, so we only store it once,
41                 // save for the fds, which we keep around to distribute to the servers after re-exec.
42                 if (i == 0) {
43                         state.mutable_streams()->MergeFrom(local_state.streams());
44                 } else {
45                         assert(state.streams_size() == local_state.streams_size());
46                         for (int j = 0; j < local_state.streams_size(); ++j) {
47                                 assert(local_state.streams(j).data_fds_size() == 1);
48                                 state.mutable_streams(j)->add_data_fds(local_state.streams(j).data_fds(0));
49                         }
50                 }
51                 for (int j = 0; j < local_state.clients_size(); ++j) {
52                         state.add_clients()->MergeFrom(local_state.clients(j));
53                 }
54         }
55
56         return state;
57 }
58
59 void ServerPool::add_client(int sock)
60 {
61         servers[clients_added++ % num_servers].add_client_deferred(sock);
62 }
63
64 void ServerPool::add_client_from_serialized(const ClientProto &client)
65 {
66         servers[clients_added++ % num_servers].add_client_from_serialized(client);
67 }
68
69 int ServerPool::lookup_stream_by_url(const std::string &url) const
70 {
71         assert(servers != NULL);
72         return servers[0].lookup_stream_by_url(url);
73 }
74
75 int ServerPool::add_stream(const string &url, size_t backlog_size, Stream::Encoding encoding)
76 {
77         // Adding more HTTP streams after UDP streams would cause the UDP stream
78         // indices to move around, which is obviously not good.
79         assert(udp_streams.empty());
80
81         for (int i = 0; i < num_servers; ++i) {
82                 int stream_index = servers[i].add_stream(url, backlog_size, encoding);
83                 assert(stream_index == num_http_streams);
84         }
85         return num_http_streams++;
86 }
87
88 int ServerPool::add_stream_from_serialized(const StreamProto &stream, const vector<int> &data_fds)
89 {
90         // Adding more HTTP streams after UDP streams would cause the UDP stream
91         // indices to move around, which is obviously not good.
92         assert(udp_streams.empty());
93
94         assert(!data_fds.empty());
95         string contents;
96         for (int i = 0; i < num_servers; ++i) {
97                 int data_fd;
98                 if (i < int(data_fds.size())) {
99                         // Reuse one of the existing file descriptors.
100                         data_fd = data_fds[i];
101                 } else {
102                         // Clone the first one.
103                         if (contents.empty()) {
104                                 if (!read_tempfile(data_fds[0], &contents)) {
105                                         exit(1);
106                                 }
107                         }
108                         data_fd = make_tempfile(contents);
109                 }
110
111                 int stream_index = servers[i].add_stream_from_serialized(stream, data_fd);
112                 assert(stream_index == num_http_streams);
113         }
114
115         // Close and delete any leftovers, if the number of servers was reduced.
116         for (size_t i = num_servers; i < data_fds.size(); ++i) {
117                 safe_close(data_fds[i]);  // Implicitly deletes the file.
118         }
119
120         return num_http_streams++;
121 }
122         
123 int ServerPool::add_udpstream(const sockaddr_in6 &dst, MarkPool *mark_pool)
124 {
125         udp_streams.push_back(new UDPStream(dst, mark_pool));
126         return num_http_streams + udp_streams.size() - 1;
127 }
128
129 void ServerPool::set_header(int stream_index, const string &http_header, const string &stream_header)
130 {
131         assert(stream_index >= 0 && stream_index < ssize_t(num_http_streams + udp_streams.size()));
132
133         if (stream_index >= num_http_streams) {
134                 // UDP stream. TODO: Log which stream this is.
135                 if (!stream_header.empty()) {
136                         log(WARNING, "Trying to send stream format with headers to a UDP destination. This is unlikely to work well.");
137                 }
138
139                 // Ignore the HTTP header.
140                 return;
141         }
142
143         // HTTP stream.
144         for (int i = 0; i < num_servers; ++i) {
145                 servers[i].set_header(stream_index, http_header, stream_header);
146         }
147 }
148
149 void ServerPool::add_data(int stream_index, const char *data, size_t bytes)
150 {
151         assert(stream_index >= 0 && stream_index < ssize_t(num_http_streams + udp_streams.size()));
152
153         if (stream_index >= num_http_streams) {
154                 // UDP stream.
155                 udp_streams[stream_index - num_http_streams]->send(data, bytes);
156                 return;
157         }
158
159         // HTTP stream.
160         for (int i = 0; i < num_servers; ++i) {
161                 servers[i].add_data_deferred(stream_index, data, bytes);
162         }
163 }
164
165 void ServerPool::run()
166 {
167         for (int i = 0; i < num_servers; ++i) {
168                 servers[i].run();
169         }
170 }
171         
172 void ServerPool::stop()
173 {
174         for (int i = 0; i < num_servers; ++i) {
175                 servers[i].stop();
176         }
177 }
178         
179 vector<ClientStats> ServerPool::get_client_stats() const
180 {
181         vector<ClientStats> ret;
182         for (int i = 0; i < num_servers; ++i) {
183                 vector<ClientStats> stats = servers[i].get_client_stats();
184                 ret.insert(ret.end(), stats.begin(), stats.end());
185         }
186         return ret;
187 }
188         
189 void ServerPool::set_mark_pool(int stream_index, MarkPool *mark_pool)
190 {
191         for (int i = 0; i < num_servers; ++i) {
192                 servers[i].set_mark_pool(stream_index, mark_pool);
193         }       
194 }
195
196 void ServerPool::set_backlog_size(int stream_index, size_t new_size)
197 {
198         for (int i = 0; i < num_servers; ++i) {
199                 servers[i].set_backlog_size(stream_index, new_size);
200         }       
201 }
202
203 void ServerPool::set_encoding(int stream_index, Stream::Encoding encoding)
204 {
205         for (int i = 0; i < num_servers; ++i) {
206                 servers[i].set_encoding(stream_index, encoding);
207         }       
208 }