]> git.sesse.net Git - casparcg/blob - protocol/util/AsyncEventServer.cpp
Merged fix for asio::io_service lifetime race condition (sometimes destroyed too...
[casparcg] / protocol / util / AsyncEventServer.cpp
1 /*
2 * Copyright (c) 2011 Sveriges Television AB <info@casparcg.com>
3 *
4 * This file is part of CasparCG (www.casparcg.com).
5 *
6 * CasparCG is free software: you can redistribute it and/or modify
7 * it under the terms of the GNU General Public License as published by
8 * the Free Software Foundation, either version 3 of the License, or
9 * (at your option) any later version.
10 *
11 * CasparCG is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 * GNU General Public License for more details.
15 *
16 * You should have received a copy of the GNU General Public License
17 * along with CasparCG. If not, see <http://www.gnu.org/licenses/>.
18 *
19 * Author: Robert Nagy, ronag89@gmail.com
20 */
21
22 #include "../StdAfx.h"
23
24 #include "AsyncEventServer.h"
25
26 #include <algorithm>
27 #include <array>
28 #include <string>
29 #include <set>
30 #include <memory>
31 #include <functional>
32
33 #include <boost/asio.hpp>
34 #include <boost/lexical_cast.hpp>
35
36 #include <tbb/mutex.h>
37 #include <tbb/concurrent_hash_map.h>
38 #include <tbb/concurrent_queue.h>
39
40 using boost::asio::ip::tcp;
41
42 namespace caspar { namespace IO {
43         
44 class connection;
45
46 typedef std::set<spl::shared_ptr<connection>> connection_set;
47
48 class connection : public spl::enable_shared_from_this<connection>
49 {   
50         typedef tbb::concurrent_hash_map<std::wstring, std::shared_ptr<void>> lifecycle_map_type;
51         typedef tbb::concurrent_queue<std::string>      send_queue;
52
53     const spl::shared_ptr<tcp::socket>                          socket_; 
54         boost::asio::io_service&                                                service_;
55         const spl::shared_ptr<connection_set>                   connection_set_;
56         const std::wstring                                                              name_;
57         protocol_strategy_factory<char>::ptr                    protocol_factory_;
58         std::shared_ptr<protocol_strategy<char>>                protocol_;
59
60         std::array<char, 32768>                                                 data_;
61         lifecycle_map_type                                                              lifecycle_bound_objects_;
62         send_queue                                                                              send_queue_;
63         bool                                                                                    is_writing_;
64
65         class connection_holder : public client_connection<char>
66         {
67                 std::weak_ptr<connection> connection_;
68         public:
69                 explicit connection_holder(std::weak_ptr<connection> conn) : connection_(std::move(conn))
70                 {}
71
72                 void send(std::basic_string<char>&& data) override
73                 {
74                         auto conn = connection_.lock();
75
76                         if (conn)
77                                 conn->send(std::move(data));
78                 }
79
80                 void disconnect() override
81                 {
82                         auto conn = connection_.lock();
83
84                         if (conn)
85                                 conn->disconnect();
86                 }
87
88                 std::wstring print() const override
89                 {
90                         auto conn = connection_.lock();
91
92                         if (conn)
93                                 return conn->print();
94                         else
95                                 return L"[destroyed-connection]";
96                 }
97
98                 std::wstring address() const override
99                 {
100                         auto conn = connection_.lock();
101
102                         if (conn)
103                                 return conn->address();
104                         else
105                                 return L"[destroyed-connection]";
106                 }
107
108                 void add_lifecycle_bound_object(const std::wstring& key, const std::shared_ptr<void>& lifecycle_bound) override
109                 {
110                         auto conn = connection_.lock();
111
112                         if (conn)
113                                 return conn->add_lifecycle_bound_object(key, lifecycle_bound);
114                 }
115
116                 std::shared_ptr<void> remove_lifecycle_bound_object(const std::wstring& key) override
117                 {
118                         auto conn = connection_.lock();
119
120                         if (conn)
121                                 return conn->remove_lifecycle_bound_object(key);
122                         else
123                                 return std::shared_ptr<void>();
124                 }
125         };
126
127 public:
128     static spl::shared_ptr<connection> create(spl::shared_ptr<tcp::socket> socket, const protocol_strategy_factory<char>::ptr& protocol, spl::shared_ptr<connection_set> connection_set)
129         {
130                 spl::shared_ptr<connection> con(new connection(std::move(socket), std::move(protocol), std::move(connection_set)));
131                 con->read_some();
132                 return con;
133     }
134
135         ~connection()
136         {
137                 CASPAR_LOG(info) << print() << L" connection destroyed.";
138         }
139
140         std::wstring print() const
141         {
142                 return L"[" + name_ + L"]";
143         }
144
145         std::wstring address() const
146         {
147                 return u16(socket_->local_endpoint().address().to_string());
148         }
149         
150         virtual void send(std::string&& data)
151         {
152                 send_queue_.push(std::move(data));
153                 service_.dispatch([=] { do_write(); });
154         }
155
156         virtual void disconnect()
157         {
158                 service_.dispatch([=] { stop(); });
159         }
160
161         void add_lifecycle_bound_object(const std::wstring& key, const std::shared_ptr<void>& lifecycle_bound)
162         {
163                 //thread-safe tbb_concurrent_hash_map
164                 lifecycle_bound_objects_.insert(std::pair<std::wstring, std::shared_ptr<void>>(key, lifecycle_bound));
165         }
166         std::shared_ptr<void> remove_lifecycle_bound_object(const std::wstring& key)
167         {
168                 //thread-safe tbb_concurrent_hash_map
169                 lifecycle_map_type::const_accessor acc;
170                 if(lifecycle_bound_objects_.find(acc, key))
171                 {
172                         auto result = acc->second;
173                         lifecycle_bound_objects_.erase(acc);
174                         return result;
175                 }
176                 return std::shared_ptr<void>();
177         }
178
179         /**************/
180 private:
181         void do_write() //always called from the asio-service-thread
182         {
183                 if(!is_writing_)
184                 {
185                         std::string data;
186                         if(send_queue_.try_pop(data))
187                         {
188                                 write_some(std::move(data));
189                         }
190                 }
191         }
192
193         void stop()     //always called from the asio-service-thread
194         {
195                 connection_set_->erase(shared_from_this());
196                 try
197                 {
198                         socket_->close();
199                 }
200                 catch(...)
201                 {
202                         CASPAR_LOG_CURRENT_EXCEPTION();
203                 }
204                 
205                 CASPAR_LOG(info) << print() << L" Disconnected.";
206         }
207
208         const std::string ipv4_address() const
209         {
210                 return socket_->is_open() ? socket_->remote_endpoint().address().to_string() : "no-address";
211         }
212
213     connection(const spl::shared_ptr<tcp::socket>& socket, const protocol_strategy_factory<char>::ptr& protocol_factory, const spl::shared_ptr<connection_set>& connection_set) 
214                 : socket_(socket)
215                 , service_(socket->get_io_service())
216                 , name_((socket_->is_open() ? u16(socket_->local_endpoint().address().to_string() + ":" + boost::lexical_cast<std::string>(socket_->local_endpoint().port())) : L"no-address"))
217                 , connection_set_(connection_set)
218                 , protocol_factory_(protocol_factory)
219                 , is_writing_(false)
220         {
221                 CASPAR_LOG(info) << print() << L" Connected.";
222     }
223
224         protocol_strategy<char>& protocol()     //always called from the asio-service-thread
225         {
226                 if (!protocol_)
227                         protocol_ = protocol_factory_->create(spl::make_shared<connection_holder>(shared_from_this()));
228
229                 return *protocol_;
230         }
231                         
232     void handle_read(const boost::system::error_code& error, size_t bytes_transferred)  //always called from the asio-service-thread
233         {               
234                 if(!error)
235                 {
236                         try
237                         {
238                                 std::string data(data_.begin(), data_.begin() + bytes_transferred);
239
240                                 CASPAR_LOG(trace) << print() << L" Received: " << u16(data);
241
242                                 protocol().parse(data);
243                         }
244                         catch(...)
245                         {
246                                 CASPAR_LOG_CURRENT_EXCEPTION();
247                         }
248                         
249                         read_some();
250                 }  
251                 else if (error != boost::asio::error::operation_aborted)
252                         stop();         
253     }
254
255     void handle_write(const spl::shared_ptr<std::string>& str, const boost::system::error_code& error, size_t bytes_transferred)        //always called from the asio-service-thread
256         {
257                 if(!error)
258                 {
259                         CASPAR_LOG(trace) << print() << L" Sent: " << (str->size() < 512 ? u16(*str) : L"more than 512 bytes.");
260                         if(bytes_transferred != str->size())
261                         {
262                                 str->assign(str->substr(bytes_transferred));
263                                 socket_->async_write_some(boost::asio::buffer(str->data(), str->size()), std::bind(&connection::handle_write, shared_from_this(), str, std::placeholders::_1, std::placeholders::_2));
264                         }
265                         else
266                         {
267                                 is_writing_ = false;
268                                 do_write();
269                         }
270                 }
271                 else if (error != boost::asio::error::operation_aborted)                
272                         stop();
273     }
274
275         void read_some()        //always called from the asio-service-thread
276         {
277                 socket_->async_read_some(boost::asio::buffer(data_.data(), data_.size()), std::bind(&connection::handle_read, shared_from_this(), std::placeholders::_1, std::placeholders::_2));
278         }
279         
280         void write_some(std::string&& data)     //always called from the asio-service-thread
281         {
282                 is_writing_ = true;
283                 auto str = spl::make_shared<std::string>(std::move(data));
284                 socket_->async_write_some(boost::asio::buffer(str->data(), str->size()), std::bind(&connection::handle_write, shared_from_this(), str, std::placeholders::_1, std::placeholders::_2));
285         }
286
287         friend struct AsyncEventServer::implementation;
288 };
289
290 struct AsyncEventServer::implementation
291 {
292         std::shared_ptr<boost::asio::io_service>        service_;
293         tcp::acceptor                                                           acceptor_;
294         protocol_strategy_factory<char>::ptr            protocol_factory_;
295         spl::shared_ptr<connection_set>                         connection_set_;
296         std::vector<lifecycle_factory_t>                        lifecycle_factories_;
297         tbb::mutex mutex_;
298
299         implementation(std::shared_ptr<boost::asio::io_service> service, const protocol_strategy_factory<char>::ptr& protocol, unsigned short port)
300                 : service_(std::move(service))
301                 , acceptor_(*service_, tcp::endpoint(tcp::v4(), port))
302                 , protocol_factory_(protocol)
303         {
304                 start_accept();
305         }
306
307         ~implementation()
308         {
309                 try
310                 {
311                         acceptor_.close();                      
312                 }
313                 catch(...)
314                 {
315                         CASPAR_LOG_CURRENT_EXCEPTION();
316                 }
317
318                 service_->post([=]
319                 {
320                         auto connections = *connection_set_;
321                         for (auto& connection : connections)
322                                 connection->stop();                             
323                 });
324         }
325                 
326         void start_accept() 
327         {
328                 spl::shared_ptr<tcp::socket> socket(new tcp::socket(*service_));
329                 acceptor_.async_accept(*socket, std::bind(&implementation::handle_accept, this, socket, std::placeholders::_1));
330     }
331
332         void handle_accept(const spl::shared_ptr<tcp::socket>& socket, const boost::system::error_code& error) 
333         {
334                 if (!acceptor_.is_open())
335                         return;
336                 
337         if (!error)
338                 {
339                         auto conn = connection::create(socket, protocol_factory_, connection_set_);
340                         connection_set_->insert(conn);
341
342                         for (auto& lifecycle_factory : lifecycle_factories_)
343                         {
344                                 auto lifecycle_bound = lifecycle_factory(conn->ipv4_address());
345                                 conn->add_lifecycle_bound_object(lifecycle_bound.first, lifecycle_bound.second);
346                         }
347                 }
348                 start_accept();
349     }
350
351         void add_client_lifecycle_object_factory(const lifecycle_factory_t& factory)
352         {
353                 service_->post([=]{ lifecycle_factories_.push_back(factory); });
354         }
355 };
356
357 AsyncEventServer::AsyncEventServer(
358                 std::shared_ptr<boost::asio::io_service> service, const protocol_strategy_factory<char>::ptr& protocol, unsigned short port)
359         : impl_(new implementation(std::move(service), protocol, port)) {}
360
361 AsyncEventServer::~AsyncEventServer() {}
362 void AsyncEventServer::add_client_lifecycle_object_factory(const lifecycle_factory_t& factory) { impl_->add_client_lifecycle_object_factory(factory); }
363
364 }}