Add a listen statement to listen on only specific IP addresses, in addition to the...
authorSteinar H. Gunderson <sgunderson@bigfoot.com>
Sun, 1 Dec 2013 00:16:41 +0000 (01:16 +0100)
committerSteinar H. Gunderson <sgunderson@bigfoot.com>
Sun, 1 Dec 2013 00:16:41 +0000 (01:16 +0100)
Makefile
acceptor.cpp
acceptor.h
config.cpp
config.h
cubemap.config.sample
main.cpp
sa_compare.cpp [new file with mode: 0644]
sa_compare.h [new file with mode: 0644]
state.proto
udpinput.cpp

index 103c6a4..40ae397 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -5,7 +5,7 @@ PROTOC=protoc
 CXXFLAGS=-Wall -O2 -g -pthread
 LDLIBS=-lprotobuf -pthread -lrt
 
-OBJS=main.o client.o server.o stream.o udpstream.o serverpool.o mutexlock.o input.o input_stats.o httpinput.o udpinput.o parse.o config.o markpool.o acceptor.o stats.o accesslog.o thread.o util.o log.o metacube2.o state.pb.o
+OBJS=main.o client.o server.o stream.o udpstream.o serverpool.o mutexlock.o input.o input_stats.o httpinput.o udpinput.o parse.o config.o markpool.o acceptor.o stats.o accesslog.o thread.o util.o log.o metacube2.o sa_compare.o state.pb.o
 
 all: cubemap
 
index d4d4efb..b5b4039 100644 (file)
@@ -19,7 +19,7 @@ using namespace std;
 
 extern ServerPool *servers;
 
-int create_server_socket(int port, SocketType socket_type)
+int create_server_socket(const sockaddr_in6 &addr, SocketType socket_type)
 {
        int server_sock;
        if (socket_type == TCP_SOCKET) {
@@ -52,12 +52,7 @@ int create_server_socket(int port, SocketType socket_type)
                exit(1);
        }
 
-       sockaddr_in6 addr;
-       memset(&addr, 0, sizeof(addr));
-       addr.sin6_family = AF_INET6;
-       addr.sin6_port = htons(port);
-
-       if (bind(server_sock, reinterpret_cast<sockaddr *>(&addr), sizeof(addr)) == -1) {
+       if (bind(server_sock, reinterpret_cast<const sockaddr *>(&addr), sizeof(addr)) == -1) {
                log_perror("bind");
                exit(1);
        }
@@ -71,24 +66,52 @@ int create_server_socket(int port, SocketType socket_type)
 
        return server_sock;
 }
+
+sockaddr_in6 CreateAnyAddress(int port)
+{
+       sockaddr_in6 sin6;
+       memset(&sin6, 0, sizeof(sin6));
+       sin6.sin6_family = AF_INET6;
+       sin6.sin6_port = htons(port);
+       return sin6;
+}
+
+sockaddr_in6 ExtractAddressFromAcceptorProto(const AcceptorProto &proto)
+{
+       sockaddr_in6 sin6;
+       memset(&sin6, 0, sizeof(sin6));
+       sin6.sin6_family = AF_INET6;
+
+       if (!proto.addr().empty()) {
+               int ret = inet_pton(AF_INET6, proto.addr().c_str(), &sin6.sin6_addr);
+               assert(ret == 1);
+       }
+
+       sin6.sin6_port = htons(proto.port());
+       return sin6;
+}
        
-Acceptor::Acceptor(int server_sock, int port)
+Acceptor::Acceptor(int server_sock, const sockaddr_in6 &addr)
        : server_sock(server_sock),
-         port(port)
+         addr(addr)
 {
 }
 
 Acceptor::Acceptor(const AcceptorProto &serialized)
        : server_sock(serialized.server_sock()),
-         port(serialized.port())
+         addr(ExtractAddressFromAcceptorProto(serialized))
 {
 }
 
 AcceptorProto Acceptor::serialize() const
 {
+       char buf[INET6_ADDRSTRLEN];
+       inet_ntop(addr.sin6_family, &addr.sin6_addr, buf, sizeof(buf));
+
        AcceptorProto serialized;
        serialized.set_server_sock(server_sock);
-       serialized.set_port(port);
+       serialized.set_addr(buf);
+       serialized.set_port(ntohs(addr.sin6_port));
        return serialized;
 }
 
index e0a1bd5..199aea3 100644 (file)
@@ -1,21 +1,26 @@
 #ifndef _ACCEPTOR_H
 #define _ACCEPTOR_H
 
+#include <netinet/in.h>
+
 #include "thread.h"
 
 enum SocketType {
        TCP_SOCKET,
        UDP_SOCKET,
 };
-int create_server_socket(int port, SocketType socket_type);
+int create_server_socket(const sockaddr_in6 &addr, SocketType socket_type);
 
 class AcceptorProto;
 
+sockaddr_in6 CreateAnyAddress(int port);
+sockaddr_in6 ExtractAddressFromAcceptorProto(const AcceptorProto &proto);
+
 // A thread that accepts new connections on a given socket,
 // and hands them off to the server pool.
 class Acceptor : public Thread {
 public:
-       Acceptor(int server_sock, int port);
+       Acceptor(int server_sock, const sockaddr_in6 &addr);
 
        // Serialization/deserialization.
        Acceptor(const AcceptorProto &serialized);
@@ -26,7 +31,8 @@ public:
 private:
        virtual void do_work();
 
-       int server_sock, port;
+       int server_sock;
+       sockaddr_in6 addr;
 };
 
 #endif  // !defined(_ACCEPTOR_H)
index f5182aa..e7c9c84 100644 (file)
@@ -11,6 +11,7 @@
 #include <utility>
 #include <vector>
 
+#include "acceptor.h"
 #include "config.h"
 #include "log.h"
 #include "parse.h"
@@ -119,12 +120,80 @@ bool parse_port(const ConfigLine &line, Config *config)
                return false;
        }
 
+       int port = atoi(line.arguments[0].c_str());
+       if (port < 1 || port >= 65536) {
+               log(ERROR, "port %d is out of range (must be [1,65536>).", port);
+               return false;
+       }
+
+       AcceptorConfig acceptor;
+       acceptor.addr = CreateAnyAddress(port);
+
+       config->acceptors.push_back(acceptor);
+       return true;
+}
+
+bool parse_listen(const ConfigLine &line, Config *config)
+{
+       if (line.arguments.size() != 1) {
+               log(ERROR, "'listen' takes exactly one argument");
+               return false;
+       }
+
+       string addr_string = line.arguments[0];
+       if (addr_string.empty()) {
+               // Actually, this should never happen.
+               log(ERROR, "'listen' argument cannot be empty");
+               return false;
+       }
+
+       string port_string;
+
        AcceptorConfig acceptor;
-       acceptor.port = atoi(line.arguments[0].c_str());
-       if (acceptor.port < 1 || acceptor.port >= 65536) {
-               log(ERROR, "port %d is out of range (must be [1,65536>).", acceptor.port);
+       memset(&acceptor.addr, 0, sizeof(acceptor.addr));
+       acceptor.addr.sin6_family = AF_INET6;
+       if (addr_string[0] == '[') {
+               // IPv6 address: [addr]:port.
+               size_t addr_end = addr_string.find("]:");
+               if (addr_end == string::npos) {
+                       log(ERROR, "IPv6 address '%s' should be on form [address]:port", addr_string.c_str());
+                       return false;
+               }
+
+               string addr_only = addr_string.substr(1, addr_end - 1);
+               if (inet_pton(AF_INET6, addr_only.c_str(), &acceptor.addr.sin6_addr) != 1) {
+                       log(ERROR, "Invalid IPv6 address '%s'", addr_only.c_str());
+                       return false;
+               }
+
+               port_string = addr_string.substr(addr_end + 2);
+       } else {
+               // IPv4 address: addr:port.
+               size_t addr_end = addr_string.find(":");
+               if (addr_end == string::npos) {
+                       log(ERROR, "IPv4 address '%s' should be on form address:port", addr_string.c_str());
+                       return false;
+               }
+
+               in_addr addr4;
+               string addr_only = addr_string.substr(0, addr_end);
+               if (inet_pton(AF_INET, addr_only.c_str(), &addr4) != 1) {
+                       log(ERROR, "Invalid IPv4 address '%s'", addr_only.c_str());
+                       return false;
+               }
+
+               // Convert to a v4-mapped address.
+               acceptor.addr.sin6_addr.s6_addr32[2] = htonl(0xffff);
+               acceptor.addr.sin6_addr.s6_addr32[3] = addr4.s_addr;
+               port_string = addr_string.substr(addr_end + 1);
+       }
+
+       int port = atoi(port_string.c_str());
+       if (port < 1 || port >= 65536) {
+               log(ERROR, "port %d is out of range (must be [1,65536>).", port);
                return false;
        }
+       acceptor.addr.sin6_port = ntohs(port);
 
        config->acceptors.push_back(acceptor);
        return true;
@@ -423,6 +492,10 @@ bool parse_config(const string &filename, Config *config)
                        if (!parse_port(line, config)) {
                                return false;
                        }
+               } else if (line.keyword == "listen") {
+                       if (!parse_listen(line, config)) {
+                               return false;
+                       }
                } else if (line.keyword == "stream") {
                        if (!parse_stream(line, config)) {
                                return false;
index b84b957..f06b9c5 100644 (file)
--- a/config.h
+++ b/config.h
@@ -30,7 +30,7 @@ struct UDPStreamConfig {
 };
 
 struct AcceptorConfig {
-       int port;
+       sockaddr_in6 addr;
 };
 
 struct LogConfig {
index bce0659..83b24ff 100644 (file)
@@ -12,6 +12,8 @@ num_servers 1
 # All input ports are treated exactly the same, but you may use multiple ones nevertheless.
 #
 port 9094
+# listen 127.0.0.1:9095
+# listen [::1]:9095
 
 stats_file cubemap.stats
 stats_interval 60
index 86ba28a..194ffc7 100644 (file)
--- a/main.cpp
+++ b/main.cpp
@@ -23,6 +23,7 @@
 #include "input_stats.h"
 #include "log.h"
 #include "markpool.h"
+#include "sa_compare.h"
 #include "serverpool.h"
 #include "state.pb.h"
 #include "stats.h"
@@ -80,27 +81,28 @@ CubemapStateProto collect_state(const timeval &serialize_start,
 // Find all port statements in the configuration file, and create acceptors for htem.
 vector<Acceptor *> create_acceptors(
        const Config &config,
-       map<int, Acceptor *> *deserialized_acceptors)
+       map<sockaddr_in6, Acceptor *, Sockaddr6Compare> *deserialized_acceptors)
 {
        vector<Acceptor *> acceptors;
        for (unsigned i = 0; i < config.acceptors.size(); ++i) {
                const AcceptorConfig &acceptor_config = config.acceptors[i];
                Acceptor *acceptor = NULL;
-               map<int, Acceptor *>::iterator deserialized_acceptor_it =
-                       deserialized_acceptors->find(acceptor_config.port);
+               map<sockaddr_in6, Acceptor *, Sockaddr6Compare>::iterator deserialized_acceptor_it =
+                       deserialized_acceptors->find(acceptor_config.addr);
                if (deserialized_acceptor_it != deserialized_acceptors->end()) {
                        acceptor = deserialized_acceptor_it->second;
                        deserialized_acceptors->erase(deserialized_acceptor_it);
                } else {
-                       int server_sock = create_server_socket(acceptor_config.port, TCP_SOCKET);
-                       acceptor = new Acceptor(server_sock, acceptor_config.port);
+                       int server_sock = create_server_socket(acceptor_config.addr, TCP_SOCKET);
+                       acceptor = new Acceptor(server_sock, acceptor_config.addr);
                }
                acceptor->run();
                acceptors.push_back(acceptor);
        }
 
        // Close all acceptors that are no longer in the configuration file.
-       for (map<int, Acceptor *>::iterator acceptor_it = deserialized_acceptors->begin();
+       for (map<sockaddr_in6, Acceptor *, Sockaddr6Compare>::iterator
+                acceptor_it = deserialized_acceptors->begin();
             acceptor_it != deserialized_acceptors->end();
             ++acceptor_it) {
                acceptor_it->second->close_socket();
@@ -386,7 +388,7 @@ start:
        CubemapStateProto loaded_state;
        struct timeval serialize_start;
        set<string> deserialized_urls;
-       map<int, Acceptor *> deserialized_acceptors;
+       map<sockaddr_in6, Acceptor *, Sockaddr6Compare> deserialized_acceptors;
        multimap<string, InputWithRefcount> inputs;  // multimap due to older versions without deduplication.
        if (state_fd != -1) {
                log(INFO, "Deserializing state from previous process...");
@@ -452,8 +454,9 @@ start:
 
                // Deserialize the acceptors.
                for (int i = 0; i < loaded_state.acceptors_size(); ++i) {
+                       sockaddr_in6 sin6 = ExtractAddressFromAcceptorProto(loaded_state.acceptors(i));
                        deserialized_acceptors.insert(make_pair(
-                               loaded_state.acceptors(i).port(),
+                               sin6,
                                new Acceptor(loaded_state.acceptors(i))));
                }
 
diff --git a/sa_compare.cpp b/sa_compare.cpp
new file mode 100644 (file)
index 0000000..aff5d2f
--- /dev/null
@@ -0,0 +1,17 @@
+#include "sa_compare.h"
+
+#include <arpa/inet.h>
+#include <assert.h>
+#include <string.h>
+
+bool Sockaddr6Compare::operator() (const sockaddr_in6 &a, const sockaddr_in6 &b) const
+{
+       assert(a.sin6_family == AF_INET6);
+       assert(b.sin6_family == AF_INET6);
+       int addr_cmp = memcmp(&a.sin6_addr, &b.sin6_addr, sizeof(a.sin6_addr));
+       if (addr_cmp == 0) {
+               return (ntohs(a.sin6_port) < ntohs(b.sin6_port));
+       } else {
+               return (addr_cmp < 0);
+       }
+}
diff --git a/sa_compare.h b/sa_compare.h
new file mode 100644 (file)
index 0000000..95fa0d8
--- /dev/null
@@ -0,0 +1,11 @@
+#ifndef _SA_COMPARE_H
+#define _SA_COMPARE_H
+
+#include <netinet/in.h>
+
+// A utility functor to help use sockaddr_in6 as keys in a map.
+struct Sockaddr6Compare {
+       bool operator() (const sockaddr_in6 &a, const sockaddr_in6 &b) const;
+};
+
+#endif  // !defined(_SA_COMPARE_H)
index 9479c2a..cd17253 100644 (file)
@@ -53,6 +53,7 @@ message InputProto {
 message AcceptorProto {
        optional int32 server_sock = 1;
        optional int32 port = 2;
+       optional string addr = 3;  // As a string. Empty is equivalent to "::".
 };
 
 message CubemapStateProto {
index f4c78fe..6b17799 100644 (file)
@@ -100,7 +100,8 @@ void UDPInput::do_work()
        while (!should_stop()) {
                if (sock == -1) {
                        int port_num = atoi(port.c_str());
-                       sock = create_server_socket(port_num, UDP_SOCKET);
+                       sockaddr_in6 addr = CreateAnyAddress(port_num);
+                       sock = create_server_socket(addr, UDP_SOCKET);
                        if (sock == -1) {
                                log(WARNING, "[%s] UDP socket creation failed. Waiting 0.2 seconds and trying again...",
                                             url.c_str());