Make a useful constructor for Client.
[cubemap] / server.cpp
1 #include <stdio.h>
2 #include <string.h>
3 #include <stdint.h>
4 #include <assert.h>
5 #include <arpa/inet.h>
6 #include <curl/curl.h>
7 #include <sys/socket.h>
8 #include <pthread.h>
9 #include <sys/types.h>
10 #include <sys/ioctl.h>
11 #include <sys/epoll.h>
12 #include <errno.h>
13 #include <vector>
14 #include <string>
15 #include <map>
16 #include <algorithm>
17
18 #include "metacube.h"
19 #include "server.h"
20 #include "mutexlock.h"
21
22 using namespace std;
23
24 Client::Client(int sock)
25         : state(Client::READING_REQUEST),
26           header_bytes_sent(0),
27           bytes_sent(0)
28 {
29         request.reserve(1024);
30 }
31
32 Server::Server()
33 {
34         pthread_mutex_init(&mutex, NULL);
35
36         epoll_fd = epoll_create(1024);  // Size argument is ignored.
37         if (epoll_fd == -1) {
38                 perror("epoll_fd");
39                 exit(1);
40         }
41 }
42
43 void Server::run()
44 {
45         should_stop = false;
46         
47         // Joinable is already the default, but it's good to be certain.
48         pthread_attr_t attr;
49         pthread_attr_init(&attr);
50         pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_JOINABLE);
51         pthread_create(&worker_thread, &attr, Server::do_work_thunk, this);
52 }
53
54 void Server::stop()
55 {
56         {
57                 MutexLock lock(&mutex);
58                 should_stop = true;
59         }
60
61         if (pthread_join(worker_thread, NULL) == -1) {
62                 perror("pthread_join");
63                 exit(1);
64         }
65 }
66
67 void *Server::do_work_thunk(void *arg)
68 {
69         Server *server = static_cast<Server *>(arg);
70         server->do_work();
71         return NULL;
72 }
73
74 void Server::do_work()
75 {
76         for ( ;; ) {
77                 int nfds = epoll_wait(epoll_fd, events, EPOLL_MAX_EVENTS, EPOLL_TIMEOUT_MS);
78                 if (nfds == -1) {
79                         perror("epoll_wait");
80                         exit(1);
81                 }
82
83                 MutexLock lock(&mutex);  // We release the mutex between iterations.
84         
85                 if (should_stop) {
86                         return;
87                 }
88         
89                 for (int i = 0; i < nfds; ++i) {
90                         int fd = events[i].data.fd;
91                         assert(clients.count(fd) != 0);
92                         Client *client = &clients[fd];
93
94                         if (events[i].events & (EPOLLERR | EPOLLRDHUP | EPOLLHUP)) {
95                                 close_client(client);
96                                 continue;
97                         }
98
99                         process_client(client);
100                 }
101         }
102 }
103         
104 void Server::add_client(int sock)
105 {
106         MutexLock lock(&mutex);
107         clients.insert(make_pair(sock, Client(sock)));
108
109         // Start listening on data from this socket.
110         epoll_event ev;
111         ev.events = EPOLLIN | EPOLLRDHUP;
112         ev.data.fd = sock;
113         if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, sock, &ev) == -1) {
114                 perror("epoll_ctl(EPOLL_CTL_ADD)");
115                 exit(1);
116         }
117 }
118         
119 void Server::add_stream(const string &stream_id)
120 {
121         MutexLock lock(&mutex);
122         streams.insert(make_pair(stream_id, Stream()));
123 }
124         
125 void Server::set_header(const string &stream_id, const string &header)
126 {
127         MutexLock lock(&mutex);
128         assert(streams.count(stream_id) != 0);
129         streams[stream_id].header = header;
130 }
131         
132 void Server::add_data(const string &stream_id, const char *data, size_t bytes)
133 {
134         if (bytes == 0) {
135                 return;
136         }
137
138         MutexLock lock(&mutex);
139         assert(streams.count(stream_id) != 0);
140         Stream *stream = &streams[stream_id];
141         size_t pos = stream->data_size % BACKLOG_SIZE;
142         stream->data_size += bytes;
143
144         if (pos + bytes > BACKLOG_SIZE) {
145                 size_t to_copy = BACKLOG_SIZE - pos;
146                 memcpy(stream->data + pos, data, to_copy);
147                 data += to_copy;
148                 bytes -= to_copy;
149                 pos = 0;
150         }
151
152         memcpy(stream->data + pos, data, bytes);
153         wake_up_all_clients();
154 }
155         
156 void Server::process_client(Client *client)
157 {
158         switch (client->state) {
159         case Client::READING_REQUEST: {
160                 // Try to read more of the request.
161                 char buf[1024];
162                 int ret = read(client->sock, buf, sizeof(buf));
163                 if (ret == -1) {
164                         perror("read");
165                         close_client(client);
166                         return;
167                 }
168                 if (ret == 0) {
169                         // No data? This really means that we were triggered for something else than
170                         // POLLIN (which suggests a logic error in epoll).
171                         fprintf(stderr, "WARNING: fd %d returned unexpectedly 0 bytes!\n", client->sock);
172                         close_client(client);
173                         return;
174                 }
175
176                 // Guard against overlong requests gobbling up all of our space.
177                 if (client->request.size() + ret > MAX_CLIENT_REQUEST) {
178                         fprintf(stderr, "WARNING: fd %d sent overlong request!\n", client->sock);
179                         close_client(client);
180                         return;
181                 }       
182
183                 // See if we have \r\n\r\n anywhere in the request. We start three bytes
184                 // before what we just appended, in case we just got the final character.
185                 size_t existing_req_bytes = client->request.size();
186                 client->request.append(string(buf, buf + ret));
187         
188                 size_t start_at = (existing_req_bytes >= 3 ? existing_req_bytes - 3 : 0);
189                 const char *ptr = reinterpret_cast<char *>(
190                         memmem(client->request.data() + start_at, client->request.size() - start_at,
191                                "\r\n\r\n", 4));
192                 if (ptr == NULL) {
193                         // OK, we don't have the entire header yet. Fine; we'll get it later.
194                         return;
195                 }
196
197                 if (ptr != client->request.data() + client->request.size() - 4) {
198                         fprintf(stderr, "WARNING: fd %d had junk data after request!\n", client->sock);
199                         close_client(client);
200                         return;
201                 }
202
203                 parse_request(client);
204                 break;
205         }
206         case Client::SENDING_HEADER: {
207                 int ret = write(client->sock,
208                                 client->header.data() + client->header_bytes_sent,
209                                 client->header.size() - client->header_bytes_sent);
210                 if (ret == -1) {
211                         perror("write");
212                         close_client(client);
213                         return;
214                 }
215                 
216                 client->header_bytes_sent += ret;
217                 assert(client->header_bytes_sent <= client->header.size());
218
219                 if (client->header_bytes_sent < client->header.size()) {
220                         // We haven't sent all yet. Fine; we'll do that later.
221                         return;
222                 }
223
224                 // We're done sending the header! Clear the entire header to release some memory.
225                 client->header.clear();
226
227                 // Start sending from the end. In other words, we won't send any of the backlog,
228                 // but we'll start sending immediately as we get data.
229                 client->state = Client::SENDING_DATA;
230                 client->bytes_sent = streams[client->stream_id].data_size;
231                 break;
232         }
233         case Client::SENDING_DATA: {
234                 // See if there's some data we've lost. Ideally, we should drop to a block boundary,
235                 // but resync will be the mux's problem.
236                 const Stream &stream = streams[client->stream_id];
237                 size_t bytes_to_send = stream.data_size - client->bytes_sent;
238                 if (bytes_to_send > BACKLOG_SIZE) {
239                         fprintf(stderr, "WARNING: fd %d lost %lld bytes, maybe too slow connection\n",
240                                 client->sock,
241                                 (long long int)(bytes_to_send - BACKLOG_SIZE));
242                         client->bytes_sent = streams[client->stream_id].data_size - BACKLOG_SIZE;
243                         bytes_to_send = BACKLOG_SIZE;
244                 }
245
246                 // See if we need to split across the circular buffer.
247                 ssize_t ret;
248                 if ((client->bytes_sent % BACKLOG_SIZE) + bytes_to_send > BACKLOG_SIZE) {
249                         size_t bytes_first_part = BACKLOG_SIZE - (client->bytes_sent % BACKLOG_SIZE);
250
251                         iovec iov[2];
252                         iov[0].iov_base = const_cast<char *>(stream.data + (client->bytes_sent % BACKLOG_SIZE));
253                         iov[0].iov_len = bytes_first_part;
254
255                         iov[1].iov_base = const_cast<char *>(stream.data);
256                         iov[1].iov_len = bytes_to_send - bytes_first_part;
257
258                         ret = writev(client->sock, iov, 2);
259                 } else {
260                         ret = write(client->sock,
261                                     stream.data + (client->bytes_sent % BACKLOG_SIZE),
262                                     bytes_to_send);
263                 }
264                 if (ret == -1) {
265                         perror("write/writev");
266                         close_client(client);
267                         return;
268                 }
269                 client->bytes_sent += ret;
270
271                 if (client->bytes_sent == stream.data_size) {
272                         // We don't have any more data for this client, so put it to sleep.
273                         put_client_to_sleep(client);
274                 }
275                 break;
276         }
277         default:
278                 assert(false);
279         }
280 }
281
282 void Server::parse_request(Client *client)
283 {
284         // TODO: Actually parse the request. :-)
285         client->stream_id = "stream";
286         client->request.clear();
287
288         // Construct the header.
289         client->header = "HTTP/1.0 200 OK\r\n  Content-type: video/x-flv\r\nCache-Control: no-cache\r\nContent-type: todo/fixme\r\n\r\n" +
290                 streams[client->stream_id].header;
291
292         // Switch states.
293         client->state = Client::SENDING_HEADER;
294
295         epoll_event ev;
296         ev.events = EPOLLOUT | EPOLLRDHUP;
297         ev.data.fd = client->sock;
298
299         if (epoll_ctl(epoll_fd, EPOLL_CTL_MOD, client->sock, &ev) == -1) {
300                 perror("epoll_ctl(EPOLL_CTL_MOD)");
301                 exit(1);
302         }
303 }
304         
305 void Server::close_client(Client *client)
306 {
307         if (epoll_ctl(epoll_fd, EPOLL_CTL_DEL, client->sock, NULL) == -1) {
308                 perror("epoll_ctl(EPOLL_CTL_DEL)");
309                 exit(1);
310         }
311
312         // This client could be sleeping, so we'll need to fix that. (Argh, O(n).)
313         vector<int>::iterator new_end =
314                 remove(sleeping_clients.begin(), sleeping_clients.end(), client->sock);
315         sleeping_clients.erase(new_end, sleeping_clients.end());
316         
317         // Bye-bye!
318         close(client->sock);
319         clients.erase(client->sock);
320 }
321         
322 void Server::put_client_to_sleep(Client *client)
323 {
324         epoll_event ev;
325         ev.events = EPOLLRDHUP;
326         ev.data.fd = client->sock;
327
328         if (epoll_ctl(epoll_fd, EPOLL_CTL_MOD, client->sock, &ev) == -1) {
329                 perror("epoll_ctl(EPOLL_CTL_MOD)");
330                 exit(1);
331         }
332
333         sleeping_clients.push_back(client->sock);
334 }
335
336 void Server::wake_up_all_clients()
337 {
338         for (unsigned i = 0; i < sleeping_clients.size(); ++i) {
339                 epoll_event ev;
340                 ev.events = EPOLLOUT | EPOLLRDHUP;
341                 ev.data.fd = sleeping_clients[i];
342                 if (epoll_ctl(epoll_fd, EPOLL_CTL_MOD, sleeping_clients[i], &ev) == -1) {
343                         perror("epoll_ctl(EPOLL_CTL_MOD)");
344                         exit(1);
345                 }
346         }
347         sleeping_clients.clear();
348 }