]> git.sesse.net Git - casparcg/blobdiff - protocol/util/AsyncEventServer.cpp
Enabled TCP keep alive on TCP connections to clients. This might help in discovering...
[casparcg] / protocol / util / AsyncEventServer.cpp
index 762599f564627c8cd5d1b655123b00659a13485f..d404becdbb9c5711f5d6f2203a8a14864ef49441 100644 (file)
@@ -19,7 +19,7 @@
 * Author: Robert Nagy, ronag89@gmail.com
 */
 
-#include "..\stdafx.h"
+#include "../StdAfx.h"
 
 #include "AsyncEventServer.h"
 
 #include <functional>
 
 #include <boost/asio.hpp>
-#include <boost/thread.hpp>
 #include <boost/lexical_cast.hpp>
 
-#include <tbb\mutex.h>
-#include <tbb\concurrent_hash_map.h>
-#include <tbb\concurrent_queue.h>
+#include <tbb/mutex.h>
+#include <tbb/concurrent_hash_map.h>
+#include <tbb/concurrent_queue.h>
 
 using boost::asio::ip::tcp;
 
@@ -52,9 +51,9 @@ class connection : public spl::enable_shared_from_this<connection>
        typedef tbb::concurrent_queue<std::string>      send_queue;
 
     const spl::shared_ptr<tcp::socket>                         socket_; 
-       boost::asio::io_service&                                                service_;
+       std::shared_ptr<boost::asio::io_service>                service_;
+       const std::wstring                                                              listen_port_;
        const spl::shared_ptr<connection_set>                   connection_set_;
-       const std::wstring                                                              name_;
        protocol_strategy_factory<char>::ptr                    protocol_factory_;
        std::shared_ptr<protocol_strategy<char>>                protocol_;
 
@@ -70,7 +69,7 @@ class connection : public spl::enable_shared_from_this<connection>
                explicit connection_holder(std::weak_ptr<connection> conn) : connection_(std::move(conn))
                {}
 
-               virtual void send(std::basic_string<char>&& data)
+               void send(std::basic_string<char>&& data) override
                {
                        auto conn = connection_.lock();
 
@@ -78,7 +77,7 @@ class connection : public spl::enable_shared_from_this<connection>
                                conn->send(std::move(data));
                }
 
-               virtual void disconnect()
+               void disconnect() override
                {
                        auto conn = connection_.lock();
 
@@ -86,24 +85,25 @@ class connection : public spl::enable_shared_from_this<connection>
                                conn->disconnect();
                }
 
-               virtual std::wstring print() const
+               std::wstring address() const override
                {
                        auto conn = connection_.lock();
 
                        if (conn)
-                               return conn->print();
+                               return conn->ipv4_address();
                        else
                                return L"[destroyed-connection]";
                }
 
-               virtual void add_lifecycle_bound_object(const std::wstring& key, const std::shared_ptr<void>& lifecycle_bound)
+               void add_lifecycle_bound_object(const std::wstring& key, const std::shared_ptr<void>& lifecycle_bound) override
                {
                        auto conn = connection_.lock();
 
                        if (conn)
                                return conn->add_lifecycle_bound_object(key, lifecycle_bound);
                }
-               virtual std::shared_ptr<void> remove_lifecycle_bound_object(const std::wstring& key)
+
+               std::shared_ptr<void> remove_lifecycle_bound_object(const std::wstring& key) override
                {
                        auto conn = connection_.lock();
 
@@ -115,32 +115,50 @@ class connection : public spl::enable_shared_from_this<connection>
        };
 
 public:
-    static spl::shared_ptr<connection> create(spl::shared_ptr<tcp::socket> socket, const protocol_strategy_factory<char>::ptr& protocol, spl::shared_ptr<connection_set> connection_set)
+       static spl::shared_ptr<connection> create(std::shared_ptr<boost::asio::io_service> service, spl::shared_ptr<tcp::socket> socket, const protocol_strategy_factory<char>::ptr& protocol, spl::shared_ptr<connection_set> connection_set)
        {
-               spl::shared_ptr<connection> con(new connection(std::move(socket), std::move(protocol), std::move(connection_set)));
+               spl::shared_ptr<connection> con(new connection(std::move(service), std::move(socket), std::move(protocol), std::move(connection_set)));
+               con->init();
                con->read_some();
                return con;
     }
 
+       void init()
+       {
+               protocol_ = protocol_factory_->create(spl::make_shared<connection_holder>(shared_from_this()));
+       }
+
        ~connection()
        {
-               CASPAR_LOG(info) << print() << L" connection destroyed.";
+               CASPAR_LOG(debug) << print() << L" connection destroyed.";
        }
 
        std::wstring print() const
        {
-               return L"[" + name_ + L"]";
+               return L"async_event_server[:" + listen_port_ + L"]";
+       }
+
+       std::wstring address() const
+       {
+               return u16(socket_->local_endpoint().address().to_string());
        }
        
-       virtual void send(std::string&& data)
+       std::wstring ipv4_address() const
+       {
+               return socket_->is_open() ? u16(socket_->remote_endpoint().address().to_string()) : L"no-address";
+       }
+
+       void send(std::string&& data)
        {
                send_queue_.push(std::move(data));
-               service_.dispatch([=] { do_write(); });
+               auto self = shared_from_this();
+               service_->dispatch([=] { self->do_write(); });
        }
 
-       virtual void disconnect()
+       void disconnect()
        {
-               service_.dispatch([=] { stop(); });
+               auto self = shared_from_this();
+               service_->dispatch([=] { self->stop(); });
        }
 
        void add_lifecycle_bound_object(const std::wstring& key, const std::shared_ptr<void>& lifecycle_bound)
@@ -161,7 +179,6 @@ public:
                return std::shared_ptr<void>();
        }
 
-       /**************/
 private:
        void do_write() //always called from the asio-service-thread
        {
@@ -178,41 +195,30 @@ private:
        void stop()     //always called from the asio-service-thread
        {
                connection_set_->erase(shared_from_this());
+
+               CASPAR_LOG(info) << print() << L" Client " << ipv4_address() << L" disconnected (" << connection_set_->size() << L" connections).";
+
                try
                {
+                       socket_->cancel();
                        socket_->close();
                }
                catch(...)
                {
                        CASPAR_LOG_CURRENT_EXCEPTION();
                }
-               
-               CASPAR_LOG(info) << print() << L" Disconnected.";
        }
 
-       const std::string ipv4_address() const
-       {
-               return socket_->is_open() ? socket_->remote_endpoint().address().to_string() : "no-address";
-       }
-
-    connection(const spl::shared_ptr<tcp::socket>& socket, const protocol_strategy_factory<char>::ptr& protocol_factory, const spl::shared_ptr<connection_set>& connection_set) 
+    connection(const std::shared_ptr<boost::asio::io_service>& service, const spl::shared_ptr<tcp::socket>& socket, const protocol_strategy_factory<char>::ptr& protocol_factory, const spl::shared_ptr<connection_set>& connection_set) 
                : socket_(socket)
-               , service_(socket->get_io_service())
-               , name_((socket_->is_open() ? u16(socket_->local_endpoint().address().to_string() + ":" + boost::lexical_cast<std::string>(socket_->local_endpoint().port())) : L"no-address"))
+               , service_(service)
+               , listen_port_(socket_->is_open() ? boost::lexical_cast<std::wstring>(socket_->local_endpoint().port()) : L"no-port")
                , connection_set_(connection_set)
                , protocol_factory_(protocol_factory)
                , is_writing_(false)
        {
-               CASPAR_LOG(info) << print() << L" Connected.";
+               CASPAR_LOG(info) << print() << L" Accepted connection from " << ipv4_address() << L" (" << (connection_set_->size() + 1) << L" connections).";
     }
-
-       protocol_strategy<char>& protocol()     //always called from the asio-service-thread
-       {
-               if (!protocol_)
-                       protocol_ = protocol_factory_->create(spl::make_shared<connection_holder>(shared_from_this()));
-
-               return *protocol_;
-       }
                        
     void handle_read(const boost::system::error_code& error, size_t bytes_transferred)         //always called from the asio-service-thread
        {               
@@ -222,9 +228,7 @@ private:
                        {
                                std::string data(data_.begin(), data_.begin() + bytes_transferred);
 
-                               CASPAR_LOG(trace) << print() << L" Received: " << u16(data);
-
-                               protocol().parse(data);
+                               protocol_->parse(data);
                        }
                        catch(...)
                        {
@@ -241,7 +245,6 @@ private:
        {
                if(!error)
                {
-                       CASPAR_LOG(trace) << print() << L" Sent: " << (str->size() < 512 ? u16(*str) : L"more than 512 bytes.");
                        if(bytes_transferred != str->size())
                        {
                                str->assign(str->substr(bytes_transferred));
@@ -253,7 +256,7 @@ private:
                                do_write();
                        }
                }
-               else if (error != boost::asio::error::operation_aborted)                
+               else if (error != boost::asio::error::operation_aborted && socket_->is_open())          
                        stop();
     }
 
@@ -272,49 +275,51 @@ private:
        friend struct AsyncEventServer::implementation;
 };
 
-struct AsyncEventServer::implementation
+struct AsyncEventServer::implementation : public spl::enable_shared_from_this<implementation>
 {
-       boost::asio::io_service                                 service_;
-       tcp::acceptor                                                   acceptor_;
-       protocol_strategy_factory<char>::ptr    protocol_factory_;
-       spl::shared_ptr<connection_set>                 connection_set_;
-       boost::thread                                                   thread_;
-       std::vector<lifecycle_factory_t>                lifecycle_factories_;
-       tbb::mutex mutex_;
-
-       implementation(const protocol_strategy_factory<char>::ptr& protocol, unsigned short port)
-               : acceptor_(service_, tcp::endpoint(tcp::v4(), port))
+       std::shared_ptr<boost::asio::io_service>        service_;
+       tcp::acceptor                                                           acceptor_;
+       protocol_strategy_factory<char>::ptr            protocol_factory_;
+       spl::shared_ptr<connection_set>                         connection_set_;
+       std::vector<lifecycle_factory_t>                        lifecycle_factories_;
+       tbb::mutex                                                                      mutex_;
+
+       implementation(std::shared_ptr<boost::asio::io_service> service, const protocol_strategy_factory<char>::ptr& protocol, unsigned short port)
+               : service_(std::move(service))
+               , acceptor_(*service_, tcp::endpoint(tcp::v4(), port))
                , protocol_factory_(protocol)
-               , thread_([&] { service_.run(); })
        {
-               start_accept();
        }
 
-       ~implementation()
+       void stop()
        {
                try
                {
-                       acceptor_.close();                      
+                       acceptor_.cancel();
+                       acceptor_.close();
                }
-               catch(...)
+               catch (...)
                {
                        CASPAR_LOG_CURRENT_EXCEPTION();
                }
+       }
 
-               service_.post([=]
+       ~implementation()
+       {
+               auto conns_set = connection_set_;
+
+               service_->post([conns_set]
                {
-                       auto connections = *connection_set_;
+                       auto connections = *conns_set;
                        for (auto& connection : connections)
-                               connection->stop();                             
+                               connection->stop();
                });
-
-               thread_.join();
        }
                
-    void start_accept() 
+       void start_accept() 
        {
-               spl::shared_ptr<tcp::socket> socket(new tcp::socket(service_));
-               acceptor_.async_accept(*socket, std::bind(&implementation::handle_accept, this, socket, std::placeholders::_1));
+               spl::shared_ptr<tcp::socket> socket(new tcp::socket(*service_));
+               acceptor_.async_accept(*socket, std::bind(&implementation::handle_accept, shared_from_this(), socket, std::placeholders::_1));
     }
 
        void handle_accept(const spl::shared_ptr<tcp::socket>& socket, const boost::system::error_code& error) 
@@ -324,29 +329,48 @@ struct AsyncEventServer::implementation
                
         if (!error)
                {
-                       auto conn = connection::create(socket, protocol_factory_, connection_set_);
+                       boost::system::error_code ec;
+                       socket->set_option(boost::asio::socket_base::keep_alive(true), ec);
+
+                       if (ec)
+                               CASPAR_LOG(warning) << print() << L" Failed to enable TCP keep-alive on socket";
+                       
+                       auto conn = connection::create(service_, socket, protocol_factory_, connection_set_);
                        connection_set_->insert(conn);
 
                        for (auto& lifecycle_factory : lifecycle_factories_)
                        {
-                               auto lifecycle_bound = lifecycle_factory(conn->ipv4_address());
+                               auto lifecycle_bound = lifecycle_factory(u8(conn->ipv4_address()));
                                conn->add_lifecycle_bound_object(lifecycle_bound.first, lifecycle_bound.second);
                        }
                }
                start_accept();
     }
 
+       std::wstring print() const
+       {
+               return L"async_event_server[:" + boost::lexical_cast<std::wstring>(acceptor_.local_endpoint().port()) + L"]";
+       }
+
        void add_client_lifecycle_object_factory(const lifecycle_factory_t& factory)
        {
-               service_.post([=]{ lifecycle_factories_.push_back(factory); });
+               auto self = shared_from_this();
+               service_->post([=]{ self->lifecycle_factories_.push_back(factory); });
        }
 };
 
 AsyncEventServer::AsyncEventServer(
-               const protocol_strategy_factory<char>::ptr& protocol, unsigned short port)
-       : impl_(new implementation(protocol, port)) {}
+               std::shared_ptr<boost::asio::io_service> service, const protocol_strategy_factory<char>::ptr& protocol, unsigned short port)
+       : impl_(new implementation(std::move(service), protocol, port))
+{
+       impl_->start_accept();
+}
+
+AsyncEventServer::~AsyncEventServer()
+{
+       impl_->stop();
+}
 
-AsyncEventServer::~AsyncEventServer() {}
 void AsyncEventServer::add_client_lifecycle_object_factory(const lifecycle_factory_t& factory) { impl_->add_client_lifecycle_object_factory(factory); }
 
-}}
\ No newline at end of file
+}}