git.sesse.net Git - casparcg/blob - protocol/util/AsyncEventServer.cpp
[AsyncEventServer] Fixed bug where server expected to be the one closing the socket...
[casparcg] / protocol / util / AsyncEventServer.cpp
1 /*
2 * Copyright (c) 2011 Sveriges Television AB <info@casparcg.com>
3 *
4 * This file is part of CasparCG (www.casparcg.com).
5 *
6 * CasparCG is free software: you can redistribute it and/or modify
7 * it under the terms of the GNU General Public License as published by
8 * the Free Software Foundation, either version 3 of the License, or
9 * (at your option) any later version.
10 *
11 * CasparCG is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 * GNU General Public License for more details.
15 *
16 * You should have received a copy of the GNU General Public License
17 * along with CasparCG. If not, see <http://www.gnu.org/licenses/>.
18 *
19 * Author: Robert Nagy, ronag89@gmail.com
20 */
21
22 #include "../StdAfx.h"
23
24 #include "AsyncEventServer.h"
25
26 #include <algorithm>
27 #include <array>
28 #include <string>
29 #include <set>
30 #include <memory>
31 #include <functional>
32
33 #include <boost/asio.hpp>
34 #include <boost/lexical_cast.hpp>
35
36 #include <tbb/mutex.h>
37 #include <tbb/concurrent_hash_map.h>
38 #include <tbb/concurrent_queue.h>
39
40 using boost::asio::ip::tcp;
41
42 namespace caspar { namespace IO {
43
44 class connection;
45
46 typedef std::set<spl::shared_ptr<connection>> connection_set;
47
48 class connection : public spl::enable_shared_from_this<connection>
49 {
50         typedef tbb::concurrent_hash_map<std::wstring, std::shared_ptr<void>> lifecycle_map_type;
51         typedef tbb::concurrent_queue<std::string>      send_queue;
52
53     const spl::shared_ptr<tcp::socket>                          socket_;
54         std::shared_ptr<boost::asio::io_service>                service_;
55         const std::wstring                                                              listen_port_;
56         const spl::shared_ptr<connection_set>                   connection_set_;
57         protocol_strategy_factory<char>::ptr                    protocol_factory_;
58         std::shared_ptr<protocol_strategy<char>>                protocol_;
59
60         std::array<char, 32768>                                                 data_;
61         lifecycle_map_type                                                              lifecycle_bound_objects_;
62         send_queue                                                                              send_queue_;
63         bool                                                                                    is_writing_;
64
65         class connection_holder : public client_connection<char>
66         {
67                 std::weak_ptr<connection> connection_;
68         public:
69                 explicit connection_holder(std::weak_ptr<connection> conn) : connection_(std::move(conn))
70                 {}
71
72                 void send(std::basic_string<char>&& data) override
73                 {
74                         auto conn = connection_.lock();
75
76                         if (conn)
77                                 conn->send(std::move(data));
78                 }
79
80                 void disconnect() override
81                 {
82                         auto conn = connection_.lock();
83
84                         if (conn)
85                                 conn->disconnect();
86                 }
87
88                 std::wstring address() const override
89                 {
90                         auto conn = connection_.lock();
91
92                         if (conn)
93                                 return conn->ipv4_address();
94                         else
95                                 return L"[destroyed-connection]";
96                 }
97
98                 void add_lifecycle_bound_object(const std::wstring& key, const std::shared_ptr<void>& lifecycle_bound) override
99                 {
100                         auto conn = connection_.lock();
101
102                         if (conn)
103                                 return conn->add_lifecycle_bound_object(key, lifecycle_bound);
104                 }
105
106                 std::shared_ptr<void> remove_lifecycle_bound_object(const std::wstring& key) override
107                 {
108                         auto conn = connection_.lock();
109
110                         if (conn)
111                                 return conn->remove_lifecycle_bound_object(key);
112                         else
113                                 return std::shared_ptr<void>();
114                 }
115         };
116
117 public:
118         static spl::shared_ptr<connection> create(std::shared_ptr<boost::asio::io_service> service, spl::shared_ptr<tcp::socket> socket, const protocol_strategy_factory<char>::ptr& protocol, spl::shared_ptr<connection_set> connection_set)
119         {
120                 spl::shared_ptr<connection> con(new connection(std::move(service), std::move(socket), std::move(protocol), std::move(connection_set)));
121                 con->init();
122                 con->read_some();
123                 return con;
124     }
125
126         void init()
127         {
128                 protocol_ = protocol_factory_->create(spl::make_shared<connection_holder>(shared_from_this()));
129         }
130
131         ~connection()
132         {
133                 CASPAR_LOG(debug) << print() << L" connection destroyed.";
134         }
135
136         std::wstring print() const
137         {
138                 return L"async_event_server[:" + listen_port_ + L"]";
139         }
140
141         std::wstring address() const
142         {
143                 return u16(socket_->local_endpoint().address().to_string());
144         }
145
146         std::wstring ipv4_address() const
147         {
148                 return socket_->is_open() ? u16(socket_->remote_endpoint().address().to_string()) : L"no-address";
149         }
150
151         void send(std::string&& data)
152         {
153                 send_queue_.push(std::move(data));
154                 auto self = shared_from_this();
155                 service_->dispatch([=] { self->do_write(); });
156         }
157
158         void disconnect()
159         {
160                 auto self = shared_from_this();
161                 service_->dispatch([=] { self->stop(); });
162         }
163
164         void add_lifecycle_bound_object(const std::wstring& key, const std::shared_ptr<void>& lifecycle_bound)
165         {
166                 //thread-safe tbb_concurrent_hash_map
167                 lifecycle_bound_objects_.insert(std::pair<std::wstring, std::shared_ptr<void>>(key, lifecycle_bound));
168         }
169         std::shared_ptr<void> remove_lifecycle_bound_object(const std::wstring& key)
170         {
171                 //thread-safe tbb_concurrent_hash_map
172                 lifecycle_map_type::const_accessor acc;
173                 if(lifecycle_bound_objects_.find(acc, key))
174                 {
175                         auto result = acc->second;
176                         lifecycle_bound_objects_.erase(acc);
177                         return result;
178                 }
179                 return std::shared_ptr<void>();
180         }
181
182 private:
183         void do_write() //always called from the asio-service-thread
184         {
185                 if(!is_writing_)
186                 {
187                         std::string data;
188                         if(send_queue_.try_pop(data))
189                         {
190                                 write_some(std::move(data));
191                         }
192                 }
193         }
194
195         void stop()     //always called from the asio-service-thread
196         {
197                 connection_set_->erase(shared_from_this());
198
199                 CASPAR_LOG(info) << print() << L" Client " << ipv4_address() << L" disconnected (" << connection_set_->size() << L" connections).";
200
201                 boost::system::error_code ec;
202                 socket_->shutdown(boost::asio::socket_base::shutdown_type::shutdown_both, ec);
203                 socket_->close(ec);
204         }
205
206     connection(const std::shared_ptr<boost::asio::io_service>& service, const spl::shared_ptr<tcp::socket>& socket, const protocol_strategy_factory<char>::ptr& protocol_factory, const spl::shared_ptr<connection_set>& connection_set)
207                 : socket_(socket)
208                 , service_(service)
209                 , listen_port_(socket_->is_open() ? boost::lexical_cast<std::wstring>(socket_->local_endpoint().port()) : L"no-port")
210                 , connection_set_(connection_set)
211                 , protocol_factory_(protocol_factory)
212                 , is_writing_(false)
213         {
214                 CASPAR_LOG(info) << print() << L" Accepted connection from " << ipv4_address() << L" (" << (connection_set_->size() + 1) << L" connections).";
215     }
216
217     void handle_read(const boost::system::error_code& error, size_t bytes_transferred)  //always called from the asio-service-thread
218         {
219                 if(!error)
220                 {
221                         try
222                         {
223                                 std::string data(data_.begin(), data_.begin() + bytes_transferred);
224
225                                 protocol_->parse(data);
226                         }
227                         catch(...)
228                         {
229                                 CASPAR_LOG_CURRENT_EXCEPTION();
230                         }
231
232                         read_some();
233                 }
234                 else if (error != boost::asio::error::operation_aborted)
235                         stop();
236     }
237
238     void handle_write(const spl::shared_ptr<std::string>& str, const boost::system::error_code& error, size_t bytes_transferred)        //always called from the asio-service-thread
239         {
240                 if(!error)
241                 {
242                         if(bytes_transferred != str->size())
243                         {
244                                 str->assign(str->substr(bytes_transferred));
245                                 socket_->async_write_some(boost::asio::buffer(str->data(), str->size()), std::bind(&connection::handle_write, shared_from_this(), str, std::placeholders::_1, std::placeholders::_2));
246                         }
247                         else
248                         {
249                                 is_writing_ = false;
250                                 do_write();
251                         }
252                 }
253                 else if (error != boost::asio::error::operation_aborted && socket_->is_open())
254                         stop();
255     }
256
257         void read_some()        //always called from the asio-service-thread
258         {
259                 socket_->async_read_some(boost::asio::buffer(data_.data(), data_.size()), std::bind(&connection::handle_read, shared_from_this(), std::placeholders::_1, std::placeholders::_2));
260         }
261
262         void write_some(std::string&& data)     //always called from the asio-service-thread
263         {
264                 is_writing_ = true;
265                 auto str = spl::make_shared<std::string>(std::move(data));
266                 socket_->async_write_some(boost::asio::buffer(str->data(), str->size()), std::bind(&connection::handle_write, shared_from_this(), str, std::placeholders::_1, std::placeholders::_2));
267         }
268
269         friend struct AsyncEventServer::implementation;
270 };
271
272 struct AsyncEventServer::implementation : public spl::enable_shared_from_this<implementation>
273 {
274         std::shared_ptr<boost::asio::io_service>        service_;
275         tcp::acceptor                                                           acceptor_;
276         protocol_strategy_factory<char>::ptr            protocol_factory_;
277         spl::shared_ptr<connection_set>                         connection_set_;
278         std::vector<lifecycle_factory_t>                        lifecycle_factories_;
279         tbb::mutex                                                                      mutex_;
280
281         implementation(std::shared_ptr<boost::asio::io_service> service, const protocol_strategy_factory<char>::ptr& protocol, unsigned short port)
282                 : service_(std::move(service))
283                 , acceptor_(*service_, tcp::endpoint(tcp::v4(), port))
284                 , protocol_factory_(protocol)
285         {
286         }
287
288         void stop()
289         {
290                 try
291                 {
292                         acceptor_.cancel();
293                         acceptor_.close();
294                 }
295                 catch (...)
296                 {
297                         CASPAR_LOG_CURRENT_EXCEPTION();
298                 }
299         }
300
301         ~implementation()
302         {
303                 auto conns_set = connection_set_;
304
305                 service_->post([conns_set]
306                 {
307                         auto connections = *conns_set;
308                         for (auto& connection : connections)
309                                 connection->stop();
310                 });
311         }
312
313         void start_accept()
314         {
315                 spl::shared_ptr<tcp::socket> socket(new tcp::socket(*service_));
316                 acceptor_.async_accept(*socket, std::bind(&implementation::handle_accept, shared_from_this(), socket, std::placeholders::_1));
317     }
318
319         void handle_accept(const spl::shared_ptr<tcp::socket>& socket, const boost::system::error_code& error)
320         {
321                 if (!acceptor_.is_open())
322                         return;
323
324         if (!error)
325                 {
326                         boost::system::error_code ec;
327                         socket->set_option(boost::asio::socket_base::keep_alive(true), ec);
328
329                         if (ec)
330                                 CASPAR_LOG(warning) << print() << L" Failed to enable TCP keep-alive on socket";
331
332                         auto conn = connection::create(service_, socket, protocol_factory_, connection_set_);
333                         connection_set_->insert(conn);
334
335                         for (auto& lifecycle_factory : lifecycle_factories_)
336                         {
337                                 auto lifecycle_bound = lifecycle_factory(u8(conn->ipv4_address()));
338                                 conn->add_lifecycle_bound_object(lifecycle_bound.first, lifecycle_bound.second);
339                         }
340                 }
341                 start_accept();
342     }
343
344         std::wstring print() const
345         {
346                 return L"async_event_server[:" + boost::lexical_cast<std::wstring>(acceptor_.local_endpoint().port()) + L"]";
347         }
348
349         void add_client_lifecycle_object_factory(const lifecycle_factory_t& factory)
350         {
351                 auto self = shared_from_this();
352                 service_->post([=]{ self->lifecycle_factories_.push_back(factory); });
353         }
354 };
355
356 AsyncEventServer::AsyncEventServer(
357                 std::shared_ptr<boost::asio::io_service> service, const protocol_strategy_factory<char>::ptr& protocol, unsigned short port)
358         : impl_(new implementation(std::move(service), protocol, port))
359 {
360         impl_->start_accept();
361 }
362
363 AsyncEventServer::~AsyncEventServer()
364 {
365         impl_->stop();
366 }
367
368 void AsyncEventServer::add_client_lifecycle_object_factory(const lifecycle_factory_t& factory) { impl_->add_client_lifecycle_object_factory(factory); }
369
370 }}