]> git.sesse.net Git - cubemap/blob - server.cpp
0c36e49363178f1ccd0a4955ef119a5ea6ebe9bd
[cubemap] / server.cpp
1 #include <stdio.h>
2 #include <string.h>
3 #include <stdint.h>
4 #include <unistd.h>
5 #include <assert.h>
6 #include <arpa/inet.h>
7 #include <sys/socket.h>
8 #include <pthread.h>
9 #include <sys/types.h>
10 #include <sys/ioctl.h>
11 #include <sys/epoll.h>
12 #include <sys/sendfile.h>
13 #include <time.h>
14 #include <signal.h>
15 #include <errno.h>
16 #include <vector>
17 #include <string>
18 #include <map>
19 #include <algorithm>
20
21 #include "markpool.h"
22 #include "metacube.h"
23 #include "server.h"
24 #include "mutexlock.h"
25 #include "parse.h"
26 #include "util.h"
27 #include "state.pb.h"
28
29 using namespace std;
30
31 Client::Client(int sock)
32         : sock(sock),
33           fwmark(0),
34           connect_time(time(NULL)),
35           state(Client::READING_REQUEST),
36           stream(NULL),
37           header_or_error_bytes_sent(0),
38           bytes_sent(0)
39 {
40         request.reserve(1024);
41
42         // Find the remote address, and convert it to ASCII.
43         sockaddr_in6 addr;
44         socklen_t addr_len = sizeof(addr);
45
46         if (getpeername(sock, reinterpret_cast<sockaddr *>(&addr), &addr_len) == -1) {
47                 perror("getpeername");
48                 remote_addr = "";
49         } else {
50                 char buf[INET6_ADDRSTRLEN];
51                 if (inet_ntop(addr.sin6_family, &addr.sin6_addr, buf, sizeof(buf)) == NULL) {
52                         perror("inet_ntop");
53                         remote_addr = "";
54                 } else {
55                         remote_addr = buf;
56                 }
57         }
58 }
59         
60 Client::Client(const ClientProto &serialized, Stream *stream)
61         : sock(serialized.sock()),
62           remote_addr(serialized.remote_addr()),
63           connect_time(serialized.connect_time()),
64           state(State(serialized.state())),
65           request(serialized.request()),
66           stream_id(serialized.stream_id()),
67           stream(stream),
68           header_or_error(serialized.header_or_error()),
69           header_or_error_bytes_sent(serialized.header_or_error_bytes_sent()),
70           bytes_sent(serialized.bytes_sent())
71 {
72         if (stream->mark_pool != NULL) {
73                 fwmark = stream->mark_pool->get_mark();
74         } else {
75                 fwmark = 0;  // No mark.
76         }
77         if (setsockopt(sock, SOL_SOCKET, SO_MARK, &fwmark, sizeof(fwmark)) == -1) {
78                 if (fwmark != 0) {
79                         perror("setsockopt(SO_MARK)");
80                 }
81         }
82 }
83
84 ClientProto Client::serialize() const
85 {
86         ClientProto serialized;
87         serialized.set_sock(sock);
88         serialized.set_remote_addr(remote_addr);
89         serialized.set_connect_time(connect_time);
90         serialized.set_state(state);
91         serialized.set_request(request);
92         serialized.set_stream_id(stream_id);
93         serialized.set_header_or_error(header_or_error);
94         serialized.set_header_or_error_bytes_sent(serialized.header_or_error_bytes_sent());
95         serialized.set_bytes_sent(bytes_sent);
96         return serialized;
97 }
98         
99 ClientStats Client::get_stats() const
100 {
101         ClientStats stats;
102         stats.stream_id = stream_id;
103         stats.remote_addr = remote_addr;
104         stats.connect_time = connect_time;
105         stats.bytes_sent = bytes_sent;
106         return stats;
107 }
108
109 Stream::Stream(const string &stream_id, size_t backlog_size)
110         : stream_id(stream_id),
111           data_fd(make_tempfile("")),
112           backlog_size(backlog_size),
113           bytes_received(0),
114           mark_pool(NULL)
115 {
116         if (data_fd == -1) {
117                 exit(1);
118         }
119 }
120
121 Stream::~Stream()
122 {
123         if (data_fd != -1) {
124                 int ret;
125                 do {
126                         ret = close(data_fd);
127                 } while (ret == -1 && errno == EINTR);
128                 if (ret == -1) {
129                         perror("close");
130                 }
131         }
132 }
133
134 Stream::Stream(const StreamProto &serialized)
135         : stream_id(serialized.stream_id()),
136           header(serialized.header()),
137           data_fd(make_tempfile(serialized.data())),
138           backlog_size(serialized.backlog_size()),
139           bytes_received(serialized.bytes_received()),
140           mark_pool(NULL)
141 {
142         if (data_fd == -1) {
143                 exit(1);
144         }
145 }
146
147 StreamProto Stream::serialize()
148 {
149         StreamProto serialized;
150         serialized.set_header(header);
151         if (!read_tempfile(data_fd, serialized.mutable_data())) {  // Closes data_fd.
152                 exit(1);
153         }
154         serialized.set_backlog_size(backlog_size);
155         serialized.set_bytes_received(bytes_received);
156         serialized.set_stream_id(stream_id);
157         data_fd = -1;
158         return serialized;
159 }
160
161 void Stream::put_client_to_sleep(Client *client)
162 {
163         sleeping_clients.push_back(client);
164 }
165
166 void Stream::wake_up_all_clients()
167 {
168         if (to_process.empty()) {
169                 swap(sleeping_clients, to_process);
170         } else {
171                 to_process.insert(to_process.end(), sleeping_clients.begin(), sleeping_clients.end());
172                 sleeping_clients.clear();
173         }
174 }
175
176 Server::Server()
177 {
178         pthread_mutex_init(&mutex, NULL);
179         pthread_mutex_init(&queued_data_mutex, NULL);
180
181         epoll_fd = epoll_create(1024);  // Size argument is ignored.
182         if (epoll_fd == -1) {
183                 perror("epoll_fd");
184                 exit(1);
185         }
186 }
187
188 Server::~Server()
189 {
190         int ret;
191         do {
192                 ret = close(epoll_fd);
193         } while (ret == -1 && errno == EINTR);
194
195         if (ret == -1) {
196                 perror("close(epoll_fd)");
197         }
198 }
199
200 vector<ClientStats> Server::get_client_stats() const
201 {
202         vector<ClientStats> ret;
203
204         MutexLock lock(&mutex);
205         for (map<int, Client>::const_iterator client_it = clients.begin();
206              client_it != clients.end();
207              ++client_it) {
208                 ret.push_back(client_it->second.get_stats());
209         }
210         return ret;
211 }
212
213 void Server::do_work()
214 {
215         for ( ;; ) {
216                 int nfds = epoll_wait(epoll_fd, events, EPOLL_MAX_EVENTS, EPOLL_TIMEOUT_MS);
217                 if (nfds == -1 && errno == EINTR) {
218                         if (should_stop) {
219                                 return;
220                         }
221                         continue;
222                 }
223                 if (nfds == -1) {
224                         perror("epoll_wait");
225                         exit(1);
226                 }
227
228                 MutexLock lock(&mutex);  // We release the mutex between iterations.
229         
230                 process_queued_data();
231
232                 for (int i = 0; i < nfds; ++i) {
233                         int fd = events[i].data.fd;
234                         assert(clients.count(fd) != 0);
235                         Client *client = &clients[fd];
236
237                         if (events[i].events & (EPOLLERR | EPOLLRDHUP | EPOLLHUP)) {
238                                 close_client(client);
239                                 continue;
240                         }
241
242                         process_client(client);
243                 }
244
245                 for (map<string, Stream *>::iterator stream_it = streams.begin();
246                      stream_it != streams.end();
247                      ++stream_it) {
248                         vector<Client *> to_process;
249                         swap(stream_it->second->to_process, to_process);
250                         for (size_t i = 0; i < to_process.size(); ++i) {
251                                 process_client(to_process[i]);
252                         }
253                 }
254
255                 if (should_stop) {
256                         return;
257                 }
258         }
259 }
260
261 CubemapStateProto Server::serialize()
262 {
263         // We don't serialize anything queued, so empty the queues.
264         process_queued_data();
265
266         CubemapStateProto serialized;
267         for (map<int, Client>::const_iterator client_it = clients.begin();
268              client_it != clients.end();
269              ++client_it) {
270                 serialized.add_clients()->MergeFrom(client_it->second.serialize());
271         }
272         for (map<string, Stream *>::const_iterator stream_it = streams.begin();
273              stream_it != streams.end();
274              ++stream_it) {
275                 serialized.add_streams()->MergeFrom(stream_it->second->serialize());
276         }
277         return serialized;
278 }
279
280 void Server::add_client_deferred(int sock)
281 {
282         MutexLock lock(&queued_data_mutex);
283         queued_add_clients.push_back(sock);
284 }
285
286 void Server::add_client(int sock)
287 {
288         clients.insert(make_pair(sock, Client(sock)));
289
290         // Start listening on data from this socket.
291         epoll_event ev;
292         ev.events = EPOLLIN | EPOLLET | EPOLLRDHUP;
293         ev.data.u64 = 0;  // Keep Valgrind happy.
294         ev.data.fd = sock;
295         if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, sock, &ev) == -1) {
296                 perror("epoll_ctl(EPOLL_CTL_ADD)");
297                 exit(1);
298         }
299
300         process_client(&clients[sock]);
301 }
302
303 void Server::add_client_from_serialized(const ClientProto &client)
304 {
305         MutexLock lock(&mutex);
306         Stream *stream = find_stream(client.stream_id());
307         clients.insert(make_pair(client.sock(), Client(client, stream)));
308         Client *client_ptr = &clients[client.sock()];
309
310         // Start listening on data from this socket.
311         epoll_event ev;
312         if (client.state() == Client::READING_REQUEST) {
313                 ev.events = EPOLLIN | EPOLLET | EPOLLRDHUP;
314         } else {
315                 // If we don't have more data for this client, we'll be putting it into
316                 // the sleeping array again soon.
317                 ev.events = EPOLLOUT | EPOLLET | EPOLLRDHUP;
318         }
319         ev.data.u64 = 0;  // Keep Valgrind happy.
320         ev.data.fd = client.sock();
321         if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, client.sock(), &ev) == -1) {
322                 perror("epoll_ctl(EPOLL_CTL_ADD)");
323                 exit(1);
324         }
325
326         if (client_ptr->state == Client::SENDING_DATA && 
327             client_ptr->bytes_sent == client_ptr->stream->bytes_received) {
328                 client_ptr->stream->put_client_to_sleep(client_ptr);
329         } else {
330                 process_client(client_ptr);
331         }
332 }
333
334 void Server::add_stream(const string &stream_id, size_t backlog_size)
335 {
336         MutexLock lock(&mutex);
337         streams.insert(make_pair(stream_id, new Stream(stream_id, backlog_size)));
338 }
339
340 void Server::add_stream_from_serialized(const StreamProto &stream)
341 {
342         MutexLock lock(&mutex);
343         streams.insert(make_pair(stream.stream_id(), new Stream(stream)));
344 }
345         
346 void Server::set_header(const string &stream_id, const string &header)
347 {
348         MutexLock lock(&mutex);
349         find_stream(stream_id)->header = header;
350
351         // If there are clients we haven't sent anything to yet, we should give
352         // them the header, so push back into the SENDING_HEADER state.
353         for (map<int, Client>::iterator client_it = clients.begin();
354              client_it != clients.end();
355              ++client_it) {
356                 Client *client = &client_it->second;
357                 if (client->state == Client::SENDING_DATA &&
358                     client->bytes_sent == 0) {
359                         construct_header(client);
360                 }
361         }
362 }
363         
364 void Server::set_mark_pool(const std::string &stream_id, MarkPool *mark_pool)
365 {
366         MutexLock lock(&mutex);
367         assert(clients.empty());
368         find_stream(stream_id)->mark_pool = mark_pool;
369 }
370
371 void Server::add_data_deferred(const string &stream_id, const char *data, size_t bytes)
372 {
373         MutexLock lock(&queued_data_mutex);
374         queued_data[stream_id].append(string(data, data + bytes));
375 }
376
377 void Server::add_data(const string &stream_id, const char *data, ssize_t bytes)
378 {
379         Stream *stream = find_stream(stream_id);
380         size_t pos = stream->bytes_received % stream->backlog_size;
381         stream->bytes_received += bytes;
382
383         if (pos + bytes > stream->backlog_size) {
384                 ssize_t to_copy = stream->backlog_size - pos;
385                 while (to_copy > 0) {
386                         int ret = pwrite(stream->data_fd, data, to_copy, pos);
387                         if (ret == -1 && errno == EINTR) {
388                                 continue;
389                         }
390                         if (ret == -1) {
391                                 perror("pwrite");
392                                 // Dazed and confused, but trying to continue...
393                                 break;
394                         }
395                         pos += ret;
396                         data += ret;
397                         to_copy -= ret;
398                         bytes -= ret;
399                 }
400                 pos = 0;
401         }
402
403         while (bytes > 0) {
404                 int ret = pwrite(stream->data_fd, data, bytes, pos);
405                 if (ret == -1 && errno == EINTR) {
406                         continue;
407                 }
408                 if (ret == -1) {
409                         perror("pwrite");
410                         // Dazed and confused, but trying to continue...
411                         break;
412                 }
413                 pos += ret;
414                 data += ret;
415                 bytes -= ret;
416         }
417
418         stream->wake_up_all_clients();
419 }
420
421 // See the .h file for postconditions after this function.      
422 void Server::process_client(Client *client)
423 {
424         switch (client->state) {
425         case Client::READING_REQUEST: {
426 read_request_again:
427                 // Try to read more of the request.
428                 char buf[1024];
429                 int ret;
430                 do {
431                         ret = read(client->sock, buf, sizeof(buf));
432                 } while (ret == -1 && errno == EINTR);
433
434                 if (ret == -1 && errno == EAGAIN) {
435                         // No more data right now. Nothing to do.
436                         // This is postcondition #2.
437                         return;
438                 }
439                 if (ret == -1) {
440                         perror("read");
441                         close_client(client);
442                         return;
443                 }
444                 if (ret == 0) {
445                         // OK, the socket is closed.
446                         close_client(client);
447                         return;
448                 }
449
450                 RequestParseStatus status = wait_for_double_newline(&client->request, buf, ret);
451         
452                 switch (status) {
453                 case RP_OUT_OF_SPACE:
454                         fprintf(stderr, "WARNING: fd %d sent overlong request!\n", client->sock);
455                         close_client(client);
456                         return;
457                 case RP_NOT_FINISHED_YET:
458                         // OK, we don't have the entire header yet. Fine; we'll get it later.
459                         // See if there's more data for us.
460                         goto read_request_again;
461                 case RP_EXTRA_DATA:
462                         fprintf(stderr, "WARNING: fd %d had junk data after request!\n", client->sock);
463                         close_client(client);
464                         return;
465                 case RP_FINISHED:
466                         break;
467                 }
468
469                 assert(status == RP_FINISHED);
470
471                 int error_code = parse_request(client);
472                 if (error_code == 200) {
473                         construct_header(client);
474                 } else {
475                         construct_error(client, error_code);
476                 }
477
478                 // We've changed states, so fall through.
479                 assert(client->state == Client::SENDING_ERROR ||
480                        client->state == Client::SENDING_HEADER);
481         }
482         case Client::SENDING_ERROR:
483         case Client::SENDING_HEADER: {
484 sending_header_or_error_again:
485                 int ret;
486                 do {
487                         ret = write(client->sock,
488                                     client->header_or_error.data() + client->header_or_error_bytes_sent,
489                                     client->header_or_error.size() - client->header_or_error_bytes_sent);
490                 } while (ret == -1 && errno == EINTR);
491
492                 if (ret == -1 && errno == EAGAIN) {
493                         // We're out of socket space, so now we're at the “low edge” of epoll's
494                         // edge triggering. epoll will tell us when there is more room, so for now,
495                         // just return.
496                         // This is postcondition #4.
497                         return;
498                 }
499
500                 if (ret == -1) {
501                         // Error! Postcondition #1.
502                         perror("write");
503                         close_client(client);
504                         return;
505                 }
506                 
507                 client->header_or_error_bytes_sent += ret;
508                 assert(client->header_or_error_bytes_sent <= client->header_or_error.size());
509
510                 if (client->header_or_error_bytes_sent < client->header_or_error.size()) {
511                         // We haven't sent all yet. Fine; go another round.
512                         goto sending_header_or_error_again;
513                 }
514
515                 // We're done sending the header or error! Clear it to release some memory.
516                 client->header_or_error.clear();
517
518                 if (client->state == Client::SENDING_ERROR) {
519                         // We're done sending the error, so now close.  
520                         // This is postcondition #1.
521                         close_client(client);
522                         return;
523                 }
524
525                 // Start sending from the end. In other words, we won't send any of the backlog,
526                 // but we'll start sending immediately as we get data.
527                 // This is postcondition #3.
528                 client->state = Client::SENDING_DATA;
529                 client->bytes_sent = client->stream->bytes_received;
530                 client->stream->put_client_to_sleep(client);
531                 return;
532         }
533         case Client::SENDING_DATA: {
534 sending_data_again:
535                 // See if there's some data we've lost. Ideally, we should drop to a block boundary,
536                 // but resync will be the mux's problem.
537                 Stream *stream = client->stream;
538                 size_t bytes_to_send = stream->bytes_received - client->bytes_sent;
539                 if (bytes_to_send == 0) {
540                         return;
541                 }
542                 if (bytes_to_send > stream->backlog_size) {
543                         fprintf(stderr, "WARNING: fd %d lost %lld bytes, maybe too slow connection\n",
544                                 client->sock,
545                                 (long long int)(bytes_to_send - stream->backlog_size));
546                         client->bytes_sent = stream->bytes_received - stream->backlog_size;
547                         bytes_to_send = stream->backlog_size;
548                 }
549
550                 // See if we need to split across the circular buffer.
551                 bool more_data = false;
552                 if ((client->bytes_sent % stream->backlog_size) + bytes_to_send > stream->backlog_size) {
553                         bytes_to_send = stream->backlog_size - (client->bytes_sent % stream->backlog_size);
554                         more_data = true;
555                 }
556
557                 ssize_t ret;
558                 do {
559                         loff_t offset = client->bytes_sent % stream->backlog_size;
560                         ret = sendfile(client->sock, stream->data_fd, &offset, bytes_to_send);
561                 } while (ret == -1 && errno == EINTR);
562
563                 if (ret == -1 && errno == EAGAIN) {
564                         // We're out of socket space, so return; epoll will wake us up
565                         // when there is more room.
566                         // This is postcondition #4.
567                         return;
568                 }
569                 if (ret == -1) {
570                         // Error, close; postcondition #1.
571                         perror("sendfile");
572                         close_client(client);
573                         return;
574                 }
575                 client->bytes_sent += ret;
576
577                 if (client->bytes_sent == stream->bytes_received) {
578                         // We don't have any more data for this client, so put it to sleep.
579                         // This is postcondition #3.
580                         stream->put_client_to_sleep(client);
581                 } else if (more_data) {
582                         goto sending_data_again;
583                 }
584                 break;
585         }
586         default:
587                 assert(false);
588         }
589 }
590
591 int Server::parse_request(Client *client)
592 {
593         vector<string> lines = split_lines(client->request);
594         if (lines.empty()) {
595                 return 400;  // Bad request (empty).
596         }
597
598         vector<string> request_tokens = split_tokens(lines[0]);
599         if (request_tokens.size() < 2) {
600                 return 400;  // Bad request (empty).
601         }
602         if (request_tokens[0] != "GET") {
603                 return 400;  // Should maybe be 405 instead?
604         }
605         if (streams.count(request_tokens[1]) == 0) {
606                 return 404;  // Not found.
607         }
608
609         client->stream_id = request_tokens[1];
610         client->stream = find_stream(client->stream_id);
611         if (client->stream->mark_pool != NULL) {
612                 client->fwmark = client->stream->mark_pool->get_mark();
613         } else {
614                 client->fwmark = 0;  // No mark.
615         }
616         if (setsockopt(client->sock, SOL_SOCKET, SO_MARK, &client->fwmark, sizeof(client->fwmark)) == -1) {                          
617                 if (client->fwmark != 0) {
618                         perror("setsockopt(SO_MARK)");
619                 }
620         }
621         client->request.clear();
622
623         return 200;  // OK!
624 }
625
626 void Server::construct_header(Client *client)
627 {
628         client->header_or_error = find_stream(client->stream_id)->header;
629
630         // Switch states.
631         client->state = Client::SENDING_HEADER;
632
633         epoll_event ev;
634         ev.events = EPOLLOUT | EPOLLET | EPOLLRDHUP;
635         ev.data.u64 = 0;  // Keep Valgrind happy.
636         ev.data.fd = client->sock;
637
638         if (epoll_ctl(epoll_fd, EPOLL_CTL_MOD, client->sock, &ev) == -1) {
639                 perror("epoll_ctl(EPOLL_CTL_MOD)");
640                 exit(1);
641         }
642 }
643         
644 void Server::construct_error(Client *client, int error_code)
645 {
646         char error[256];
647         snprintf(error, 256, "HTTP/1.0 %d Error\r\nContent-type: text/plain\r\n\r\nSomething went wrong. Sorry.\r\n",
648                 error_code);
649         client->header_or_error = error;
650
651         // Switch states.
652         client->state = Client::SENDING_ERROR;
653
654         epoll_event ev;
655         ev.events = EPOLLOUT | EPOLLET | EPOLLRDHUP;
656         ev.data.u64 = 0;  // Keep Valgrind happy.
657         ev.data.fd = client->sock;
658
659         if (epoll_ctl(epoll_fd, EPOLL_CTL_MOD, client->sock, &ev) == -1) {
660                 perror("epoll_ctl(EPOLL_CTL_MOD)");
661                 exit(1);
662         }
663 }
664
// Removes every element equal to <elem> from <v>, keeping the relative order
// of the remaining elements.
template<class T>
void delete_from(std::vector<T> *v, T elem)
{
	v->erase(std::remove(v->begin(), v->end(), elem), v->end());
}
671         
672 void Server::close_client(Client *client)
673 {
674         if (epoll_ctl(epoll_fd, EPOLL_CTL_DEL, client->sock, NULL) == -1) {
675                 perror("epoll_ctl(EPOLL_CTL_DEL)");
676                 exit(1);
677         }
678
679         // This client could be sleeping, so we'll need to fix that. (Argh, O(n).)
680         if (client->stream != NULL) {
681                 delete_from(&client->stream->sleeping_clients, client);
682                 delete_from(&client->stream->to_process, client);
683                 if (client->stream->mark_pool != NULL) {
684                         int fwmark = client->fwmark;
685                         client->stream->mark_pool->release_mark(fwmark);
686                 }
687         }
688
689         // Bye-bye!
690         int ret;
691         do {
692                 ret = close(client->sock);
693         } while (ret == -1 && errno == EINTR);
694
695         if (ret == -1) {
696                 perror("close");
697         }
698
699         clients.erase(client->sock);
700 }
701         
702 Stream *Server::find_stream(const string &stream_id)
703 {
704         map<string, Stream *>::iterator it = streams.find(stream_id);
705         assert(it != streams.end());
706         return it->second;
707 }
708
709 void Server::process_queued_data()
710 {
711         MutexLock lock(&queued_data_mutex);
712
713         for (size_t i = 0; i < queued_add_clients.size(); ++i) {
714                 add_client(queued_add_clients[i]);
715         }
716         queued_add_clients.clear();     
717         
718         for (map<string, string>::iterator queued_it = queued_data.begin();
719              queued_it != queued_data.end();
720              ++queued_it) {
721                 add_data(queued_it->first, queued_it->second.data(), queued_it->second.size());
722         }
723         queued_data.clear();
724 }