]> git.sesse.net Git - casparcg/blob - protocol/amcp/AMCPProtocolStrategy.cpp
140e266e5bcbca03e99a964e06ad6ee7768bbede
[casparcg] / protocol / amcp / AMCPProtocolStrategy.cpp
1 /*
2 * Copyright (c) 2011 Sveriges Television AB <info@casparcg.com>
3 *
4 * This file is part of CasparCG (www.casparcg.com).
5 *
6 * CasparCG is free software: you can redistribute it and/or modify
7 * it under the terms of the GNU General Public License as published by
8 * the Free Software Foundation, either version 3 of the License, or
9 * (at your option) any later version.
10 *
11 * CasparCG is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 * GNU General Public License for more details.
15 *
16 * You should have received a copy of the GNU General Public License
17 * along with CasparCG. If not, see <http://www.gnu.org/licenses/>.
18 *
19 * Author: Nicklas P Andersson
20 */
21
22  
23 #include "../StdAfx.h"
24
25 #include "AMCPProtocolStrategy.h"
26 #include "AMCPCommandsImpl.h"
27 #include "amcp_shared.h"
28 #include "AMCPCommand.h"
29 #include "AMCPCommandQueue.h"
30
31 #include <stdio.h>
32 //#include <crtdbg.h>
33 #include <string.h>
34 #include <algorithm>
35 #include <cctype>
36 #include <future>
37
38 #include <boost/algorithm/string/trim.hpp>
39 #include <boost/algorithm/string/split.hpp>
40 #include <boost/algorithm/string/replace.hpp>
41 #include <boost/lexical_cast.hpp>
42
43 #if defined(_MSC_VER)
44 #pragma warning (push, 1) // TODO: Legacy code, just disable warnings
45 #endif
46
47 namespace caspar { namespace protocol { namespace amcp {
48
49 using IO::ClientInfoPtr;
50
51 struct AMCPProtocolStrategy::impl
52 {
53 private:
54         std::vector<channel_context>                                                    channels_;
55         std::vector<AMCPCommandQueue::ptr_type>                                 commandQueues_;
56         std::shared_ptr<core::thumbnail_generator>                              thumb_gen_;
57         spl::shared_ptr<core::media_info_repository>                    media_info_repo_;
58         spl::shared_ptr<core::system_info_provider_repository>  system_info_provider_repo_;
59         spl::shared_ptr<core::cg_producer_registry>                             cg_registry_;
60         std::promise<bool>&                                                                             shutdown_server_now_;
61
62 public:
63         impl(
64                         const std::vector<spl::shared_ptr<core::video_channel>>& channels,
65                         const std::shared_ptr<core::thumbnail_generator>& thumb_gen,
66                         const spl::shared_ptr<core::media_info_repository>& media_info_repo,
67                         const spl::shared_ptr<core::system_info_provider_repository>& system_info_provider_repo,
68                         const spl::shared_ptr<core::cg_producer_registry>& cg_registry,
69                         std::promise<bool>& shutdown_server_now)
70                 : thumb_gen_(thumb_gen)
71                 , media_info_repo_(media_info_repo)
72                 , system_info_provider_repo_(system_info_provider_repo)
73                 , cg_registry_(cg_registry)
74                 , shutdown_server_now_(shutdown_server_now)
75         {
76                 commandQueues_.push_back(std::make_shared<AMCPCommandQueue>());
77
78                 int index = 0;
79                 for (const auto& channel : channels)
80                 {
81                         std::wstring lifecycle_key = L"lock" + boost::lexical_cast<std::wstring>(index);
82                         channels_.push_back(channel_context(channel, lifecycle_key));
83                         auto queue(std::make_shared<AMCPCommandQueue>());
84                         commandQueues_.push_back(queue);
85                         ++index;
86                 }
87         }
88
89         ~impl() {}
90
91         enum class parser_state {
92                 New = 0,
93                 GetSwitch,
94                 GetCommand,
95                 GetParameters
96         };
97         enum class error_state {
98                 no_error = 0,
99                 command_error,
100                 channel_error,
101                 parameters_error,
102                 unknown_error,
103                 access_error
104         };
105
106         struct command_interpreter_result
107         {
108                 command_interpreter_result() : error(error_state::no_error) {}
109
110                 std::shared_ptr<caspar::IO::lock_container>     lock;
111                 std::wstring                                                            command_name;
112                 AMCPCommand::ptr_type                                           command;
113                 error_state                                                                     error;
114                 AMCPCommandQueue::ptr_type                                      queue;
115         };
116
117         //The paser method expects message to be complete messages with the delimiter stripped away.
118         //Thesefore the AMCPProtocolStrategy should be decorated with a delimiter_based_chunking_strategy
119         void Parse(const std::wstring& message, ClientInfoPtr client)
120         {
121                 CASPAR_LOG(info) << L"Received message from " << client->print() << ": " << message << L"\\r\\n";
122         
123                 command_interpreter_result result;
124                 if(interpret_command_string(message, result, client))
125                 {
126                         if(result.lock && !result.lock->check_access(client))
127                                 result.error = error_state::access_error;
128                         else
129                                 result.queue->AddCommand(result.command);
130                 }
131                 
132                 if (result.error != error_state::no_error)
133                 {
134                         std::wstringstream answer;
135                         boost::to_upper(result.command_name);
136
137                         switch(result.error)
138                         {
139                         case error_state::command_error:
140                                 answer << L"400 ERROR\r\n" << message << "\r\n";
141                                 break;
142                         case error_state::channel_error:
143                                 answer << L"401 " << result.command_name << " ERROR\r\n";
144                                 break;
145                         case error_state::parameters_error:
146                                 answer << L"402 " << result.command_name << " ERROR\r\n";
147                                 break;
148                         case error_state::access_error:
149                                 answer << L"503 " << result.command_name << " FAILED\r\n";
150                                 break;
151                         default:
152                                 answer << L"500 FAILED\r\n";
153                                 break;
154                         }
155                         client->send(answer.str());
156                 }
157         }
158
159 private:
160         friend class AMCPCommand;
161
162         bool interpret_command_string(const std::wstring& message, command_interpreter_result& result, ClientInfoPtr client)
163         {
164                 try
165                 {
166                         std::vector<std::wstring> tokens;
167                         parser_state state = parser_state::New;
168
169                         tokenize(message, &tokens);
170
171                         //parse the message one token at the time
172                         auto end = tokens.end();
173                         auto it = tokens.begin();
174                         while (it != end && result.error == error_state::no_error)
175                         {
176                                 switch(state)
177                                 {
178                                 case parser_state::New:
179                                         if((*it)[0] == L'/')
180                                                 state = parser_state::GetSwitch;
181                                         else
182                                                 state = parser_state::GetCommand;
183                                         break;
184
185                                 case parser_state::GetSwitch:
186                                         //command_switch = (*it);       //we dont care for the switch anymore
187                                         state = parser_state::GetCommand;
188                                         ++it;
189                                         break;
190
191                                 case parser_state::GetCommand:
192                                         {
193                                                 result.command_name = (*it);
194                                                 result.command = create_command(result.command_name, client);
195                                                 if(result.command)      //the command doesn't need a channel
196                                                 {
197                                                         result.queue = commandQueues_[0];
198                                                         state = parser_state::GetParameters;
199                                                 }
200                                                 else
201                                                 {
202                                                         //get channel index from next token
203                                                         int channel_index = -1;
204                                                         int layer_index = -1;
205
206                                                         ++it;
207                                                         if(it == end)
208                                                         {
209                                                                 if(create_channel_command(result.command_name, client, channels_.at(0), 0, 0))  //check if there is a command like this
210                                                                         result.error = error_state::channel_error;
211                                                                 else
212                                                                         result.error = error_state::command_error;
213
214                                                                 break;
215                                                         }
216
217                                                         {       //parse channel/layer token
218                                                                 try
219                                                                 {
220                                                                         std::wstring channelid_str = boost::trim_copy(*it);
221                                                                         std::vector<std::wstring> split;
222                                                                         boost::split(split, channelid_str, boost::is_any_of("-"));
223
224                                                                         channel_index = boost::lexical_cast<int>(split[0]) - 1;
225                                                                         if(split.size() > 1)
226                                                                                 layer_index = boost::lexical_cast<int>(split[1]);
227                                                                 }
228                                                                 catch(...)
229                                                                 {
230                                                                         result.error = error_state::channel_error;
231                                                                         break;
232                                                                 }
233                                                         }
234                                                 
235                                                         if(channel_index >= 0 && channel_index < channels_.size())
236                                                         {
237                                                                 result.command = create_channel_command(result.command_name, client, channels_.at(channel_index), channel_index, layer_index);
238                                                                 if(result.command)
239                                                                 {
240                                                                         result.lock = channels_.at(channel_index).lock;
241                                                                         result.queue = commandQueues_[channel_index + 1];
242                                                                 }
243                                                                 else
244                                                                 {
245                                                                         result.error = error_state::command_error;
246                                                                         break;
247                                                                 }
248                                                         }
249                                                         else
250                                                         {
251                                                                 result.error = error_state::channel_error;
252                                                                 break;
253                                                         }
254                                                 }
255
256                                                 state = parser_state::GetParameters;
257                                                 ++it;
258                                         }
259                                         break;
260
261                                 case parser_state::GetParameters:
262                                         {
263                                                 int parameterCount=0;
264                                                 while(it != end)
265                                                 {
266                                                         result.command->parameters().push_back((*it));
267                                                         ++it;
268                                                         ++parameterCount;
269                                                 }
270                                         }
271                                         break;
272                                 }
273                         }
274
275                         if(result.command && result.error == error_state::no_error && result.command->parameters().size() < result.command->minimum_parameters()) {
276                                 result.error = error_state::parameters_error;
277                         }
278                 }
279                 catch(...)
280                 {
281                         CASPAR_LOG_CURRENT_EXCEPTION();
282                         result.error = error_state::unknown_error;
283                 }
284
285                 return result.error == error_state::no_error;
286         }
287
288         std::size_t tokenize(const std::wstring& message, std::vector<std::wstring>* pTokenVector)
289         {
290                 //split on whitespace but keep strings within quotationmarks
291                 //treat \ as the start of an escape-sequence: the following char will indicate what to actually put in the string
292
293                 std::wstring currentToken;
294
295                 bool inQuote = false;
296                 bool getSpecialCode = false;
297
298                 for(unsigned int charIndex=0; charIndex<message.size(); ++charIndex)
299                 {
300                         if(getSpecialCode)
301                         {
302                                 //insert code-handling here
303                                 switch(message[charIndex])
304                                 {
305                                 case L'\\':
306                                         currentToken += L"\\";
307                                         break;
308                                 case L'\"':
309                                         currentToken += L"\"";
310                                         break;
311                                 case L'n':
312                                         currentToken += L"\n";
313                                         break;
314                                 default:
315                                         break;
316                                 };
317                                 getSpecialCode = false;
318                                 continue;
319                         }
320
321                         if(message[charIndex]==L'\\')
322                         {
323                                 getSpecialCode = true;
324                                 continue;
325                         }
326
327                         if(message[charIndex]==L' ' && inQuote==false)
328                         {
329                                 if(currentToken.size()>0)
330                                 {
331                                         pTokenVector->push_back(currentToken);
332                                         currentToken.clear();
333                                 }
334                                 continue;
335                         }
336
337                         if(message[charIndex]==L'\"')
338                         {
339                                 inQuote = !inQuote;
340
341                                 if(currentToken.size()>0 || !inQuote)
342                                 {
343                                         pTokenVector->push_back(currentToken);
344                                         currentToken.clear();
345                                 }
346                                 continue;
347                         }
348
349                         currentToken += message[charIndex];
350                 }
351
352                 if(currentToken.size()>0)
353                 {
354                         pTokenVector->push_back(currentToken);
355                         currentToken.clear();
356                 }
357
358                 return pTokenVector->size();
359         }
360
361         AMCPCommand::ptr_type create_command(const std::wstring& str, ClientInfoPtr client)
362         {
363                 std::wstring s = boost::to_upper_copy(str);
364                 if (     s == L"DIAG")                  return std::make_shared<DiagnosticsCommand>(client);
365                 else if (s == L"CHANNEL_GRID")  return std::make_shared<ChannelGridCommand>(client, channels_);
366                 else if (s == L"DATA")                  return std::make_shared<DataCommand>(client);
367                 else if (s == L"CINF")                  return std::make_shared<CinfCommand>(client, media_info_repo_);
368                 else if (s == L"INFO")                  return std::make_shared<InfoCommand>(client, channels_, system_info_provider_repo_, cg_registry_);
369                 else if (s == L"CLS")                   return std::make_shared<ClsCommand>(client, media_info_repo_);
370                 else if (s == L"TLS")                   return std::make_shared<TlsCommand>(client, cg_registry_);
371                 else if (s == L"VERSION")               return std::make_shared<VersionCommand>(client, system_info_provider_repo_);
372                 else if (s == L"BYE")                   return std::make_shared<ByeCommand>(client);
373                 else if (s == L"LOCK")                  return std::make_shared<LockCommand>(client, channels_);
374                 else if (s == L"LOG")                   return std::make_shared<LogCommand>(client);
375                 else if (s == L"THUMBNAIL")             return std::make_shared<ThumbnailCommand>(client, thumb_gen_);
376                 else if (s == L"KILL")                  return std::make_shared<KillCommand>(client, shutdown_server_now_);
377                 else if (s == L"RESTART")               return std::make_shared<RestartCommand>(client, shutdown_server_now_);
378
379                 return nullptr;
380         }
381
382         AMCPCommand::ptr_type create_channel_command(const std::wstring& str, ClientInfoPtr client, const channel_context& channel, unsigned int channel_index, int layer_index)
383         {
384                 std::wstring s = boost::to_upper_copy(str);
385         
386                 if (     s == L"MIXER")         return std::make_shared<MixerCommand>(client, channel, channel_index, layer_index);
387                 else if (s == L"CALL")          return std::make_shared<CallCommand>(client, channel, channel_index, layer_index);
388                 else if (s == L"SWAP")          return std::make_shared<SwapCommand>(client, channel, channel_index, layer_index, channels_);
389                 else if (s == L"LOAD")          return std::make_shared<LoadCommand>(client, channel, channel_index, layer_index);
390                 else if (s == L"LOADBG")        return std::make_shared<LoadbgCommand>(client, channel, channel_index, layer_index, channels_);
391                 else if (s == L"ADD")           return std::make_shared<AddCommand>(client, channel, channel_index, layer_index);
392                 else if (s == L"REMOVE")        return std::make_shared<RemoveCommand>(client, channel, channel_index, layer_index);
393                 else if (s == L"PAUSE")         return std::make_shared<PauseCommand>(client, channel, channel_index, layer_index);
394                 else if (s == L"PLAY")          return std::make_shared<PlayCommand>(client, channel, channel_index, layer_index, channels_);
395                 else if (s == L"STOP")          return std::make_shared<StopCommand>(client, channel, channel_index, layer_index);
396                 else if (s == L"CLEAR")         return std::make_shared<ClearCommand>(client, channel, channel_index, layer_index);
397                 else if (s == L"PRINT")         return std::make_shared<PrintCommand>(client, channel, channel_index, layer_index);
398                 else if (s == L"CG")            return std::make_shared<CGCommand>(client, channel, channel_index, layer_index, cg_registry_);
399                 else if (s == L"SET")           return std::make_shared<SetCommand>(client, channel, channel_index, layer_index);
400
401                 return nullptr;
402         }
403 };
404
405
406 AMCPProtocolStrategy::AMCPProtocolStrategy(
407                 const std::vector<spl::shared_ptr<core::video_channel>>& channels,
408                 const std::shared_ptr<core::thumbnail_generator>& thumb_gen,
409                 const spl::shared_ptr<core::media_info_repository>& media_info_repo,
410                 const spl::shared_ptr<core::system_info_provider_repository>& system_info_provider_repo,
411                 const spl::shared_ptr<core::cg_producer_registry>& cg_registry,
412                 std::promise<bool>& shutdown_server_now)
413         : impl_(spl::make_unique<impl>(
414                         channels,
415                         thumb_gen,
416                         media_info_repo,
417                         system_info_provider_repo,
418                         cg_registry,
419                         shutdown_server_now))
420 {
421 }
422 AMCPProtocolStrategy::~AMCPProtocolStrategy() {}
423 void AMCPProtocolStrategy::Parse(const std::wstring& msg, IO::ClientInfoPtr pClientInfo) { impl_->Parse(msg, pClientInfo); }
424
425
426 }       //namespace amcp
427 }}      //namespace caspar