Make Server stoppable.
[cubemap] / server.cpp
#include <assert.h>
#include <errno.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <arpa/inet.h>
#include <pthread.h>
#include <sys/types.h>
#include <sys/epoll.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <sys/uio.h>
#include <unistd.h>

#include <curl/curl.h>

#include <algorithm>
#include <map>
#include <string>
#include <vector>

#include "metacube.h"
#include "server.h"
#include "mutexlock.h"

using namespace std;
23
24 Server::Server()
25 {
26         pthread_mutex_init(&mutex, NULL);
27
28         epoll_fd = epoll_create(1024);  // Size argument is ignored.
29         if (epoll_fd == -1) {
30                 perror("epoll_fd");
31                 exit(1);
32         }
33 }
34
35 void Server::run()
36 {
37         should_stop = false;
38         
39         // Joinable is already the default, but it's good to be certain.
40         pthread_attr_t attr;
41         pthread_attr_init(&attr);
42         pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_JOINABLE);
43         pthread_create(&worker_thread, &attr, Server::do_work_thunk, this);
44 }
45
46 void Server::stop()
47 {
48         {
49                 MutexLock lock(&mutex);
50                 should_stop = true;
51         }
52
53         if (pthread_join(worker_thread, NULL) == -1) {
54                 perror("pthread_join");
55                 exit(1);
56         }
57 }
58
59 void *Server::do_work_thunk(void *arg)
60 {
61         Server *server = static_cast<Server *>(arg);
62         server->do_work();
63         return NULL;
64 }
65
66 void Server::do_work()
67 {
68         for ( ;; ) {
69                 int nfds = epoll_wait(epoll_fd, events, EPOLL_MAX_EVENTS, EPOLL_TIMEOUT_MS);
70                 if (nfds == -1) {
71                         perror("epoll_wait");
72                         exit(1);
73                 }
74
75                 MutexLock lock(&mutex);  // We release the mutex between iterations.
76         
77                 if (should_stop) {
78                         return;
79                 }
80         
81                 for (int i = 0; i < nfds; ++i) {
82                         int fd = events[i].data.fd;
83                         assert(clients.count(fd) != 0);
84                         Client *client = &clients[fd];
85
86                         if (events[i].events & (EPOLLERR | EPOLLRDHUP | EPOLLHUP)) {
87                                 close_client(client);
88                                 continue;
89                         }
90
91                         process_client(client);
92                 }
93         }
94 }
95         
96 void Server::add_client(int sock)
97 {
98         MutexLock lock(&mutex);
99         Client new_client;
100         new_client.sock = sock;
101         new_client.request.reserve(1024);
102         new_client.state = Client::READING_REQUEST;
103         new_client.header_bytes_sent = 0;
104         new_client.bytes_sent = 0;
105
106         clients.insert(make_pair(sock, new_client));
107
108         // Start listening on data from this socket.
109         epoll_event ev;
110         ev.events = EPOLLIN | EPOLLRDHUP;
111         ev.data.fd = sock;
112         if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, sock, &ev) == -1) {
113                 perror("epoll_ctl(EPOLL_CTL_ADD)");
114                 exit(1);
115         }
116 }
117         
118 void Server::add_stream(const string &stream_id)
119 {
120         MutexLock lock(&mutex);
121         streams.insert(make_pair(stream_id, Stream()));
122 }
123         
124 void Server::set_header(const string &stream_id, const string &header)
125 {
126         MutexLock lock(&mutex);
127         assert(streams.count(stream_id) != 0);
128         streams[stream_id].header = header;
129 }
130         
131 void Server::add_data(const string &stream_id, const char *data, size_t bytes)
132 {
133         if (bytes == 0) {
134                 return;
135         }
136
137         MutexLock lock(&mutex);
138         assert(streams.count(stream_id) != 0);
139         Stream *stream = &streams[stream_id];
140         size_t pos = stream->data_size % BACKLOG_SIZE;
141         stream->data_size += bytes;
142
143         if (pos + bytes > BACKLOG_SIZE) {
144                 size_t to_copy = BACKLOG_SIZE - pos;
145                 memcpy(stream->data + pos, data, to_copy);
146                 data += to_copy;
147                 bytes -= to_copy;
148                 pos = 0;
149         }
150
151         memcpy(stream->data + pos, data, bytes);
152         wake_up_all_clients();
153 }
154         
155 void Server::process_client(Client *client)
156 {
157         switch (client->state) {
158         case Client::READING_REQUEST: {
159                 // Try to read more of the request.
160                 char buf[1024];
161                 int ret = read(client->sock, buf, sizeof(buf));
162                 if (ret == -1) {
163                         perror("read");
164                         close_client(client);
165                         return;
166                 }
167                 if (ret == 0) {
168                         // No data? This really means that we were triggered for something else than
169                         // POLLIN (which suggests a logic error in epoll).
170                         fprintf(stderr, "WARNING: fd %d returned unexpectedly 0 bytes!\n", client->sock);
171                         close_client(client);
172                         return;
173                 }
174
175                 // Guard against overlong requests gobbling up all of our space.
176                 if (client->request.size() + ret > MAX_CLIENT_REQUEST) {
177                         fprintf(stderr, "WARNING: fd %d sent overlong request!\n", client->sock);
178                         close_client(client);
179                         return;
180                 }       
181
182                 // See if we have \r\n\r\n anywhere in the request. We start three bytes
183                 // before what we just appended, in case we just got the final character.
184                 size_t existing_req_bytes = client->request.size();
185                 client->request.append(string(buf, buf + ret));
186         
187                 size_t start_at = (existing_req_bytes >= 3 ? existing_req_bytes - 3 : 0);
188                 const char *ptr = reinterpret_cast<char *>(
189                         memmem(client->request.data() + start_at, client->request.size() - start_at,
190                                "\r\n\r\n", 4));
191                 if (ptr == NULL) {
192                         // OK, we don't have the entire header yet. Fine; we'll get it later.
193                         return;
194                 }
195
196                 if (ptr != client->request.data() + client->request.size() - 4) {
197                         fprintf(stderr, "WARNING: fd %d had junk data after request!\n", client->sock);
198                         close_client(client);
199                         return;
200                 }
201
202                 parse_request(client);
203                 break;
204         }
205         case Client::SENDING_HEADER: {
206                 int ret = write(client->sock,
207                                 client->header.data() + client->header_bytes_sent,
208                                 client->header.size() - client->header_bytes_sent);
209                 if (ret == -1) {
210                         perror("write");
211                         close_client(client);
212                         return;
213                 }
214                 
215                 client->header_bytes_sent += ret;
216                 assert(client->header_bytes_sent <= client->header.size());
217
218                 if (client->header_bytes_sent < client->header.size()) {
219                         // We haven't sent all yet. Fine; we'll do that later.
220                         return;
221                 }
222
223                 // We're done sending the header! Clear the entire header to release some memory.
224                 client->header.clear();
225
226                 // Start sending from the end. In other words, we won't send any of the backlog,
227                 // but we'll start sending immediately as we get data.
228                 client->state = Client::SENDING_DATA;
229                 client->bytes_sent = streams[client->stream_id].data_size;
230                 break;
231         }
232         case Client::SENDING_DATA: {
233                 // See if there's some data we've lost. Ideally, we should drop to a block boundary,
234                 // but resync will be the mux's problem.
235                 const Stream &stream = streams[client->stream_id];
236                 size_t bytes_to_send = stream.data_size - client->bytes_sent;
237                 if (bytes_to_send > BACKLOG_SIZE) {
238                         fprintf(stderr, "WARNING: fd %d lost %lld bytes, maybe too slow connection\n",
239                                 client->sock,
240                                 (long long int)(bytes_to_send - BACKLOG_SIZE));
241                         client->bytes_sent = streams[client->stream_id].data_size - BACKLOG_SIZE;
242                         bytes_to_send = BACKLOG_SIZE;
243                 }
244
245                 // See if we need to split across the circular buffer.
246                 ssize_t ret;
247                 if ((client->bytes_sent % BACKLOG_SIZE) + bytes_to_send > BACKLOG_SIZE) {
248                         size_t bytes_first_part = BACKLOG_SIZE - (client->bytes_sent % BACKLOG_SIZE);
249
250                         iovec iov[2];
251                         iov[0].iov_base = const_cast<char *>(stream.data + (client->bytes_sent % BACKLOG_SIZE));
252                         iov[0].iov_len = bytes_first_part;
253
254                         iov[1].iov_base = const_cast<char *>(stream.data);
255                         iov[1].iov_len = bytes_to_send - bytes_first_part;
256
257                         ret = writev(client->sock, iov, 2);
258                 } else {
259                         ret = write(client->sock,
260                                     stream.data + (client->bytes_sent % BACKLOG_SIZE),
261                                     bytes_to_send);
262                 }
263                 if (ret == -1) {
264                         perror("write/writev");
265                         close_client(client);
266                         return;
267                 }
268                 client->bytes_sent += ret;
269
270                 if (client->bytes_sent == stream.data_size) {
271                         // We don't have any more data for this client, so put it to sleep.
272                         put_client_to_sleep(client);
273                 }
274                 break;
275         }
276         default:
277                 assert(false);
278         }
279 }
280
281 void Server::parse_request(Client *client)
282 {
283         // TODO: Actually parse the request. :-)
284         client->stream_id = "stream";
285         client->request.clear();
286
287         // Construct the header.
288         client->header = "HTTP/1.0 200 OK\r\n  Content-type: video/x-flv\r\nCache-Control: no-cache\r\nContent-type: todo/fixme\r\n\r\n" +
289                 streams[client->stream_id].header;
290
291         // Switch states.
292         client->state = Client::SENDING_HEADER;
293
294         epoll_event ev;
295         ev.events = EPOLLOUT | EPOLLRDHUP;
296         ev.data.fd = client->sock;
297
298         if (epoll_ctl(epoll_fd, EPOLL_CTL_MOD, client->sock, &ev) == -1) {
299                 perror("epoll_ctl(EPOLL_CTL_MOD)");
300                 exit(1);
301         }
302 }
303         
304 void Server::close_client(Client *client)
305 {
306         if (epoll_ctl(epoll_fd, EPOLL_CTL_DEL, client->sock, NULL) == -1) {
307                 perror("epoll_ctl(EPOLL_CTL_DEL)");
308                 exit(1);
309         }
310
311         // This client could be sleeping, so we'll need to fix that. (Argh, O(n).)
312         vector<int>::iterator new_end =
313                 remove(sleeping_clients.begin(), sleeping_clients.end(), client->sock);
314         sleeping_clients.erase(new_end, sleeping_clients.end());
315         
316         // Bye-bye!
317         close(client->sock);
318         clients.erase(client->sock);
319 }
320         
321 void Server::put_client_to_sleep(Client *client)
322 {
323         epoll_event ev;
324         ev.events = EPOLLRDHUP;
325         ev.data.fd = client->sock;
326
327         if (epoll_ctl(epoll_fd, EPOLL_CTL_MOD, client->sock, &ev) == -1) {
328                 perror("epoll_ctl(EPOLL_CTL_MOD)");
329                 exit(1);
330         }
331
332         sleeping_clients.push_back(client->sock);
333 }
334
335 void Server::wake_up_all_clients()
336 {
337         for (unsigned i = 0; i < sleeping_clients.size(); ++i) {
338                 epoll_event ev;
339                 ev.events = EPOLLOUT | EPOLLRDHUP;
340                 ev.data.fd = sleeping_clients[i];
341                 if (epoll_ctl(epoll_fd, EPOLL_CTL_MOD, sleeping_clients[i], &ev) == -1) {
342                         perror("epoll_ctl(EPOLL_CTL_MOD)");
343                         exit(1);
344                 }
345         }
346         sleeping_clients.clear();
347 }