X-Git-Url: https://git.sesse.net/?a=blobdiff_plain;f=protocol%2Futil%2FAsyncEventServer.cpp;h=d404becdbb9c5711f5d6f2203a8a14864ef49441;hb=f036a5c1bf55701cfcd2317c34302f837204db50;hp=614593f7c160929fa3c64775573ca0a64da3658a;hpb=52ec6eb600439392872d60b1505f85fb79203173;p=casparcg diff --git a/protocol/util/AsyncEventServer.cpp b/protocol/util/AsyncEventServer.cpp index 614593f7c..d404becdb 100644 --- a/protocol/util/AsyncEventServer.cpp +++ b/protocol/util/AsyncEventServer.cpp @@ -19,7 +19,7 @@ * Author: Robert Nagy, ronag89@gmail.com */ -#include "..\stdafx.h" +#include "../StdAfx.h" #include "AsyncEventServer.h" @@ -28,12 +28,14 @@ #include #include #include +#include #include -#include #include -#include +#include +#include +#include using boost::asio::ip::tcp; @@ -45,141 +47,180 @@ typedef std::set> connection_set; class connection : public spl::enable_shared_from_this { + typedef tbb::concurrent_hash_map> lifecycle_map_type; + typedef tbb::concurrent_queue send_queue; + const spl::shared_ptr socket_; - boost::asio::io_service& service_; + std::shared_ptr service_; + const std::wstring listen_port_; const spl::shared_ptr connection_set_; - const std::wstring name_; protocol_strategy_factory::ptr protocol_factory_; std::shared_ptr> protocol_; std::array data_; - std::map> lifecycle_bound_objects_; + lifecycle_map_type lifecycle_bound_objects_; + send_queue send_queue_; + bool is_writing_; class connection_holder : public client_connection { std::weak_ptr connection_; public: - explicit connection_holder(std::weak_ptr conn) : connection_(conn) + explicit connection_holder(std::weak_ptr conn) : connection_(std::move(conn)) {} - virtual void send(std::basic_string&& data) + void send(std::basic_string&& data) override { - //TODO: need to implement a send-queue auto conn = connection_.lock(); - conn->send(std::move(data)); + + if (conn) + conn->send(std::move(data)); } - virtual void disconnect() + + void disconnect() override { auto conn = connection_.lock(); - conn->disconnect(); + + if (conn) + conn->disconnect(); } - virtual std::wstring print() const + + std::wstring address() const override { auto conn = connection_.lock(); - return conn->print(); + + if (conn) + return conn->ipv4_address(); + else + return L"[destroyed-connection]"; } - virtual void add_lifecycle_bound_object(const std::wstring& key, const std::shared_ptr& lifecycle_bound) + void add_lifecycle_bound_object(const std::wstring& key, const std::shared_ptr& lifecycle_bound) override { auto conn = connection_.lock(); - return conn->add_lifecycle_bound_object(key, lifecycle_bound); + + if (conn) + return conn->add_lifecycle_bound_object(key, lifecycle_bound); } - virtual std::shared_ptr remove_lifecycle_bound_object(const std::wstring& key) + + std::shared_ptr remove_lifecycle_bound_object(const std::wstring& key) override { auto conn = connection_.lock(); - return conn->remove_lifecycle_bound_object(key); + + if (conn) + return conn->remove_lifecycle_bound_object(key); + else + return std::shared_ptr(); } }; public: - static spl::shared_ptr create(spl::shared_ptr socket, const protocol_strategy_factory::ptr& protocol, spl::shared_ptr connection_set) + static spl::shared_ptr create(std::shared_ptr service, spl::shared_ptr socket, const protocol_strategy_factory::ptr& protocol, spl::shared_ptr connection_set) { - spl::shared_ptr con(new connection(std::move(socket), std::move(protocol), std::move(connection_set))); + spl::shared_ptr con(new connection(std::move(service), std::move(socket), std::move(protocol), std::move(connection_set))); + con->init(); con->read_some(); return con; } + void init() + { + protocol_ = protocol_factory_->create(spl::make_shared(shared_from_this())); + } + ~connection() { - CASPAR_LOG(info) << print() << L" connection destroyed."; + CASPAR_LOG(debug) << print() << L" connection destroyed."; } std::wstring print() const { - return L"[" + name_ + L"]"; + return L"async_event_server[:" + listen_port_ + L"]"; + } + + std::wstring address() const + { + return u16(socket_->local_endpoint().address().to_string()); } - virtual void send(std::string&& data) + std::wstring ipv4_address() const { - write_some(std::move(data)); + return socket_->is_open() ? u16(socket_->remote_endpoint().address().to_string()) : L"no-address"; } - virtual void disconnect() + void send(std::string&& data) { - service_.dispatch([=] { stop(); }); + send_queue_.push(std::move(data)); + auto self = shared_from_this(); + service_->dispatch([=] { self->do_write(); }); } - void add_lifecycle_bound_object(const std::wstring& key, const std::shared_ptr lifecycle_bound) + void disconnect() { - //TODO: needs protection from evil concurrent access - //tbb::concurrent_hash_map ? + auto self = shared_from_this(); + service_->dispatch([=] { self->stop(); }); + } + + void add_lifecycle_bound_object(const std::wstring& key, const std::shared_ptr& lifecycle_bound) + { + //thread-safe tbb_concurrent_hash_map lifecycle_bound_objects_.insert(std::pair>(key, lifecycle_bound)); } std::shared_ptr remove_lifecycle_bound_object(const std::wstring& key) { - //TODO: needs protection from evil concurrent access - //tbb::concurrent_hash_map ? - auto it = lifecycle_bound_objects_.find(key); - if(it != lifecycle_bound_objects_.end()) + //thread-safe tbb_concurrent_hash_map + lifecycle_map_type::const_accessor acc; + if(lifecycle_bound_objects_.find(acc, key)) { - auto result = (*it).second; - lifecycle_bound_objects_.erase(it); + auto result = acc->second; + lifecycle_bound_objects_.erase(acc); return result; } return std::shared_ptr(); } - /**************/ private: - void stop() + void do_write() //always called from the asio-service-thread + { + if(!is_writing_) + { + std::string data; + if(send_queue_.try_pop(data)) + { + write_some(std::move(data)); + } + } + } + + void stop() //always called from the asio-service-thread { connection_set_->erase(shared_from_this()); + + CASPAR_LOG(info) << print() << L" Client " << ipv4_address() << L" disconnected (" << connection_set_->size() << L" connections)."; + try { + socket_->cancel(); socket_->close(); } catch(...) { CASPAR_LOG_CURRENT_EXCEPTION(); } - - CASPAR_LOG(info) << print() << L" Disconnected."; - } - - const std::string ipv4_address() const - { - return socket_->is_open() ? socket_->local_endpoint().address().to_string() : "no-address"; } - connection(const spl::shared_ptr& socket, const protocol_strategy_factory::ptr& protocol_factory, const spl::shared_ptr& connection_set) + connection(const std::shared_ptr& service, const spl::shared_ptr& socket, const protocol_strategy_factory::ptr& protocol_factory, const spl::shared_ptr& connection_set) : socket_(socket) - , service_(socket->get_io_service()) - , name_((socket_->is_open() ? u16(socket_->local_endpoint().address().to_string() + ":" + boost::lexical_cast(socket_->local_endpoint().port())) : L"no-address")) + , service_(service) + , listen_port_(socket_->is_open() ? boost::lexical_cast(socket_->local_endpoint().port()) : L"no-port") , connection_set_(connection_set) , protocol_factory_(protocol_factory) + , is_writing_(false) { - CASPAR_LOG(info) << print() << L" Connected."; + CASPAR_LOG(info) << print() << L" Accepted connection from " << ipv4_address() << L" (" << (connection_set_->size() + 1) << L" connections)."; } - - protocol_strategy& protocol() - { - if (!protocol_) - protocol_ = protocol_factory_->create(spl::make_shared(shared_from_this())); - - return *protocol_; - } - void handle_read(const boost::system::error_code& error, size_t bytes_transferred) + void handle_read(const boost::system::error_code& error, size_t bytes_transferred) //always called from the asio-service-thread { if(!error) { @@ -187,9 +228,7 @@ private: { std::string data(data_.begin(), data_.begin() + bytes_transferred); - CASPAR_LOG(trace) << print() << L" Received: " << u16(data); - - protocol().parse(data); + protocol_->parse(data); } catch(...) { @@ -202,21 +241,33 @@ private: stop(); } - void handle_write(const spl::shared_ptr& data, const boost::system::error_code& error, size_t bytes_transferred) + void handle_write(const spl::shared_ptr& str, const boost::system::error_code& error, size_t bytes_transferred) //always called from the asio-service-thread { - if(!error) - CASPAR_LOG(trace) << print() << L" Sent: " << (data->size() < 512 ? u16(*data) : L"more than 512 bytes."); - else if (error != boost::asio::error::operation_aborted) + if(!error) + { + if(bytes_transferred != str->size()) + { + str->assign(str->substr(bytes_transferred)); + socket_->async_write_some(boost::asio::buffer(str->data(), str->size()), std::bind(&connection::handle_write, shared_from_this(), str, std::placeholders::_1, std::placeholders::_2)); + } + else + { + is_writing_ = false; + do_write(); + } + } + else if (error != boost::asio::error::operation_aborted && socket_->is_open()) stop(); } - void read_some() + void read_some() //always called from the asio-service-thread { socket_->async_read_some(boost::asio::buffer(data_.data(), data_.size()), std::bind(&connection::handle_read, shared_from_this(), std::placeholders::_1, std::placeholders::_2)); } - void write_some(std::string&& data) + void write_some(std::string&& data) //always called from the asio-service-thread { + is_writing_ = true; auto str = spl::make_shared(std::move(data)); socket_->async_write_some(boost::asio::buffer(str->data(), str->size()), std::bind(&connection::handle_write, shared_from_this(), str, std::placeholders::_1, std::placeholders::_2)); } @@ -224,49 +275,51 @@ private: friend struct AsyncEventServer::implementation; }; -struct AsyncEventServer::implementation +struct AsyncEventServer::implementation : public spl::enable_shared_from_this { - boost::asio::io_service service_; - tcp::acceptor acceptor_; - protocol_strategy_factory::ptr protocol_factory_; - spl::shared_ptr connection_set_; - boost::thread thread_; - std::vector lifecycle_factories_; - tbb::mutex mutex_; - - implementation(const protocol_strategy_factory::ptr& protocol, unsigned short port) - : acceptor_(service_, tcp::endpoint(tcp::v4(), port)) + std::shared_ptr service_; + tcp::acceptor acceptor_; + protocol_strategy_factory::ptr protocol_factory_; + spl::shared_ptr connection_set_; + std::vector lifecycle_factories_; + tbb::mutex mutex_; + + implementation(std::shared_ptr service, const protocol_strategy_factory::ptr& protocol, unsigned short port) + : service_(std::move(service)) + , acceptor_(*service_, tcp::endpoint(tcp::v4(), port)) , protocol_factory_(protocol) - , thread_(std::bind(&boost::asio::io_service::run, &service_)) { - start_accept(); } - ~implementation() + void stop() { try { - acceptor_.close(); + acceptor_.cancel(); + acceptor_.close(); } - catch(...) + catch (...) { CASPAR_LOG_CURRENT_EXCEPTION(); } + } + + ~implementation() + { + auto conns_set = connection_set_; - service_.post([=] + service_->post([conns_set] { - auto connections = *connection_set_; - BOOST_FOREACH(auto& connection, connections) - connection->stop(); + auto connections = *conns_set; + for (auto& connection : connections) + connection->stop(); }); - - thread_.join(); } - void start_accept() + void start_accept() { - spl::shared_ptr socket(new tcp::socket(service_)); - acceptor_.async_accept(*socket, std::bind(&implementation::handle_accept, this, socket, std::placeholders::_1)); + spl::shared_ptr socket(new tcp::socket(*service_)); + acceptor_.async_accept(*socket, std::bind(&implementation::handle_accept, shared_from_this(), socket, std::placeholders::_1)); } void handle_accept(const spl::shared_ptr& socket, const boost::system::error_code& error) @@ -276,29 +329,48 @@ struct AsyncEventServer::implementation if (!error) { - auto conn = connection::create(socket, protocol_factory_, connection_set_); + boost::system::error_code ec; + socket->set_option(boost::asio::socket_base::keep_alive(true), ec); + + if (ec) + CASPAR_LOG(warning) << print() << L" Failed to enable TCP keep-alive on socket"; + + auto conn = connection::create(service_, socket, protocol_factory_, connection_set_); connection_set_->insert(conn); - BOOST_FOREACH(auto& lifecycle_factory, lifecycle_factories_) + for (auto& lifecycle_factory : lifecycle_factories_) { - auto lifecycle_bound = lifecycle_factory(conn->ipv4_address()); + auto lifecycle_bound = lifecycle_factory(u8(conn->ipv4_address())); conn->add_lifecycle_bound_object(lifecycle_bound.first, lifecycle_bound.second); } } start_accept(); } + std::wstring print() const + { + return L"async_event_server[:" + boost::lexical_cast(acceptor_.local_endpoint().port()) + L"]"; + } + void add_client_lifecycle_object_factory(const lifecycle_factory_t& factory) { - service_.post([=]{ lifecycle_factories_.push_back(factory); }); + auto self = shared_from_this(); + service_->post([=]{ self->lifecycle_factories_.push_back(factory); }); } }; AsyncEventServer::AsyncEventServer( - const protocol_strategy_factory::ptr& protocol, unsigned short port) - : impl_(new implementation(protocol, port)) {} + std::shared_ptr service, const protocol_strategy_factory::ptr& protocol, unsigned short port) + : impl_(new implementation(std::move(service), protocol, port)) +{ + impl_->start_accept(); +} + +AsyncEventServer::~AsyncEventServer() +{ + impl_->stop(); +} -AsyncEventServer::~AsyncEventServer() {} void AsyncEventServer::add_client_lifecycle_object_factory(const lifecycle_factory_t& factory) { impl_->add_client_lifecycle_object_factory(factory); } -}} \ No newline at end of file +}}