]> git.sesse.net Git - casparcg/blob - protocol/util/AsyncEventServer.cpp
* made lifecycle_bound_objects_ thread-safe
[casparcg] / protocol / util / AsyncEventServer.cpp
1 /*
2 * Copyright (c) 2011 Sveriges Television AB <info@casparcg.com>
3 *
4 * This file is part of CasparCG (www.casparcg.com).
5 *
6 * CasparCG is free software: you can redistribute it and/or modify
7 * it under the terms of the GNU General Public License as published by
8 * the Free Software Foundation, either version 3 of the License, or
9 * (at your option) any later version.
10 *
11 * CasparCG is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 * GNU General Public License for more details.
15 *
16 * You should have received a copy of the GNU General Public License
17 * along with CasparCG. If not, see <http://www.gnu.org/licenses/>.
18 *
19 * Author: Robert Nagy, ronag89@gmail.com
20 */
21
#include "..\stdafx.h"

#include "AsyncEventServer.h"

#include <algorithm>
#include <array>
#include <deque>
#include <functional>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include <boost/asio.hpp>
#include <boost/thread.hpp>
#include <boost/lexical_cast.hpp>

#include <tbb\mutex.h>
#include <tbb\concurrent_hash_map.h>
38
39 using boost::asio::ip::tcp;
40
41 namespace caspar { namespace IO {
42         
43 class connection;
44
45 typedef std::set<spl::shared_ptr<connection>> connection_set;
46
47 class connection : public spl::enable_shared_from_this<connection>
48 {   
49         typedef tbb::concurrent_hash_map<std::wstring, std::shared_ptr<void>> lifecycle_map_type;
50
51     const spl::shared_ptr<tcp::socket>                          socket_; 
52         boost::asio::io_service&                                                service_;
53         const spl::shared_ptr<connection_set>                   connection_set_;
54         const std::wstring                                                              name_;
55         protocol_strategy_factory<char>::ptr                    protocol_factory_;
56         std::shared_ptr<protocol_strategy<char>>                protocol_;
57
58         std::array<char, 32768>                                                 data_;
59         //std::map<std::wstring, std::shared_ptr<void>> lifecycle_bound_objects_;
60         lifecycle_map_type lifecycle_bound_objects_;
61
62         class connection_holder : public client_connection<char>
63         {
64                 std::weak_ptr<connection> connection_;
65         public:
66                 explicit connection_holder(std::weak_ptr<connection> conn) : connection_(conn)
67                 {}
68
69                 virtual void send(std::basic_string<char>&& data)
70                 {
71                         //TODO: need to implement a send-queue
72                         auto conn = connection_.lock();
73                         conn->send(std::move(data));
74                 }
75                 virtual void disconnect()
76                 {
77                         auto conn = connection_.lock();
78                         conn->disconnect();
79                 }
80                 virtual std::wstring print() const
81                 {
82                         auto conn = connection_.lock();
83                         return conn->print();
84                 }
85
86                 virtual void add_lifecycle_bound_object(const std::wstring& key, const std::shared_ptr<void>& lifecycle_bound)
87                 {
88                         auto conn = connection_.lock();
89                         return conn->add_lifecycle_bound_object(key, lifecycle_bound);
90                 }
91                 virtual std::shared_ptr<void> remove_lifecycle_bound_object(const std::wstring& key)
92                 {
93                         auto conn = connection_.lock();
94                         return conn->remove_lifecycle_bound_object(key);
95                 }
96         };
97
98 public:
99     static spl::shared_ptr<connection> create(spl::shared_ptr<tcp::socket> socket, const protocol_strategy_factory<char>::ptr& protocol, spl::shared_ptr<connection_set> connection_set)
100         {
101                 spl::shared_ptr<connection> con(new connection(std::move(socket), std::move(protocol), std::move(connection_set)));
102                 con->read_some();
103                 return con;
104     }
105
106         ~connection()
107         {
108                 CASPAR_LOG(info) << print() << L" connection destroyed.";
109         }
110
111         std::wstring print() const
112         {
113                 return L"[" + name_ + L"]";
114         }
115         
116         virtual void send(std::string&& data)
117         {
118                 write_some(std::move(data));
119         }
120
121         virtual void disconnect()
122         {
123                 service_.dispatch([=] { stop(); });
124         }
125
126         void add_lifecycle_bound_object(const std::wstring& key, const std::shared_ptr<void> lifecycle_bound)
127         {
128                 //thread-safe tbb_concurrent_hash_map
129                 lifecycle_bound_objects_.insert(std::pair<std::wstring, std::shared_ptr<void>>(key, lifecycle_bound));
130         }
131         std::shared_ptr<void> remove_lifecycle_bound_object(const std::wstring& key)
132         {
133                 //thread-safe tbb_concurrent_hash_map
134                 lifecycle_map_type::const_accessor acc;
135                 if(lifecycle_bound_objects_.find(acc, key))
136                 {
137                         auto result = acc->second;
138                         lifecycle_bound_objects_.erase(acc);
139                         return result;
140                 }
141                 return std::shared_ptr<void>();
142         }
143
144         /**************/
145 private:
146         void stop()
147         {
148                 connection_set_->erase(shared_from_this());
149                 try
150                 {
151                         socket_->close();
152                 }
153                 catch(...)
154                 {
155                         CASPAR_LOG_CURRENT_EXCEPTION();
156                 }
157                 
158                 CASPAR_LOG(info) << print() << L" Disconnected.";
159         }
160
161         const std::string ipv4_address() const
162         {
163                 return socket_->is_open() ? socket_->local_endpoint().address().to_string() : "no-address";
164         }
165
166     connection(const spl::shared_ptr<tcp::socket>& socket, const protocol_strategy_factory<char>::ptr& protocol_factory, const spl::shared_ptr<connection_set>& connection_set) 
167                 : socket_(socket)
168                 , service_(socket->get_io_service())
169                 , name_((socket_->is_open() ? u16(socket_->local_endpoint().address().to_string() + ":" + boost::lexical_cast<std::string>(socket_->local_endpoint().port())) : L"no-address"))
170                 , connection_set_(connection_set)
171                 , protocol_factory_(protocol_factory)
172         {
173                 CASPAR_LOG(info) << print() << L" Connected.";
174     }
175
176         protocol_strategy<char>& protocol()
177         {
178                 if (!protocol_)
179                         protocol_ = protocol_factory_->create(spl::make_shared<connection_holder>(shared_from_this()));
180
181                 return *protocol_;
182         }
183                         
184     void handle_read(const boost::system::error_code& error, size_t bytes_transferred) 
185         {               
186                 if(!error)
187                 {
188                         try
189                         {
190                                 std::string data(data_.begin(), data_.begin() + bytes_transferred);
191
192                                 CASPAR_LOG(trace) << print() << L" Received: " << u16(data);
193
194                                 protocol().parse(data);
195                         }
196                         catch(...)
197                         {
198                                 CASPAR_LOG_CURRENT_EXCEPTION();
199                         }
200                         
201                         read_some();
202                 }  
203                 else if (error != boost::asio::error::operation_aborted)
204                         stop();         
205     }
206
207     void handle_write(const spl::shared_ptr<std::string>& data, const boost::system::error_code& error, size_t bytes_transferred)
208         {
209                 if(!error)
210                 {
211                         CASPAR_LOG(trace) << print() << L" Sent: " << (data->size() < 512 ? u16(*data) : L"more than 512 bytes.");              
212                 }
213                 else if (error != boost::asio::error::operation_aborted)                
214                         stop();
215     }
216
217         void read_some()
218         {
219                 socket_->async_read_some(boost::asio::buffer(data_.data(), data_.size()), std::bind(&connection::handle_read, shared_from_this(), std::placeholders::_1, std::placeholders::_2));
220         }
221         
222         void write_some(std::string&& data)
223         {
224                 auto str = spl::make_shared<std::string>(std::move(data));
225                 socket_->async_write_some(boost::asio::buffer(str->data(), str->size()), std::bind(&connection::handle_write, shared_from_this(), str, std::placeholders::_1, std::placeholders::_2));
226         }
227
228         friend struct AsyncEventServer::implementation;
229 };
230
231 struct AsyncEventServer::implementation
232 {
233         boost::asio::io_service                                 service_;
234         tcp::acceptor                                                   acceptor_;
235         protocol_strategy_factory<char>::ptr    protocol_factory_;
236         spl::shared_ptr<connection_set>                 connection_set_;
237         boost::thread                                                   thread_;
238         std::vector<lifecycle_factory_t>                lifecycle_factories_;
239         tbb::mutex mutex_;
240
241         implementation(const protocol_strategy_factory<char>::ptr& protocol, unsigned short port)
242                 : acceptor_(service_, tcp::endpoint(tcp::v4(), port))
243                 , protocol_factory_(protocol)
244                 , thread_(std::bind(&boost::asio::io_service::run, &service_))
245         {
246                 start_accept();
247         }
248
249         ~implementation()
250         {
251                 try
252                 {
253                         acceptor_.close();                      
254                 }
255                 catch(...)
256                 {
257                         CASPAR_LOG_CURRENT_EXCEPTION();
258                 }
259
260                 service_.post([=]
261                 {
262                         auto connections = *connection_set_;
263                         BOOST_FOREACH(auto& connection, connections)
264                                 connection->stop();                             
265                 });
266
267                 thread_.join();
268         }
269                 
270     void start_accept() 
271         {
272                 spl::shared_ptr<tcp::socket> socket(new tcp::socket(service_));
273                 acceptor_.async_accept(*socket, std::bind(&implementation::handle_accept, this, socket, std::placeholders::_1));
274     }
275
276         void handle_accept(const spl::shared_ptr<tcp::socket>& socket, const boost::system::error_code& error) 
277         {
278                 if (!acceptor_.is_open())
279                         return;
280                 
281         if (!error)
282                 {
283                         auto conn = connection::create(socket, protocol_factory_, connection_set_);
284                         connection_set_->insert(conn);
285
286                         BOOST_FOREACH(auto& lifecycle_factory, lifecycle_factories_)
287                         {
288                                 auto lifecycle_bound = lifecycle_factory(conn->ipv4_address());
289                                 conn->add_lifecycle_bound_object(lifecycle_bound.first, lifecycle_bound.second);
290                         }
291                 }
292                 start_accept();
293     }
294
295         void add_client_lifecycle_object_factory(const lifecycle_factory_t& factory)
296         {
297                 service_.post([=]{ lifecycle_factories_.push_back(factory); });
298         }
299 };
300
301 AsyncEventServer::AsyncEventServer(
302                 const protocol_strategy_factory<char>::ptr& protocol, unsigned short port)
303         : impl_(new implementation(protocol, port)) {}
304
305 AsyncEventServer::~AsyncEventServer() {}
306 void AsyncEventServer::add_client_lifecycle_object_factory(const lifecycle_factory_t& factory) { impl_->add_client_lifecycle_object_factory(factory); }
307
308 }}