]> git.sesse.net Git - casparcg/blobdiff - protocol/amcp/AMCPProtocolStrategy.cpp
#430 Fixed bug where it was assumed that all Decklink devices implements the IDeckLin...
[casparcg] / protocol / amcp / AMCPProtocolStrategy.cpp
index f18008cf089662b636b72473f2539faf592cb1fe..871f25c918d17d7d9133925f28d8e0d43c825028 100644 (file)
 #include "../StdAfx.h"
 
 #include "AMCPProtocolStrategy.h"
-
-#include "../util/AsyncEventServer.h"
-#include "AMCPCommandsImpl.h"
+#include "amcp_shared.h"
+#include "AMCPCommand.h"
+#include "AMCPCommandQueue.h"
+#include "amcp_command_repository.h"
 
 #include <stdio.h>
-#include <crtdbg.h>
 #include <string.h>
 #include <algorithm>
 #include <cctype>
+#include <future>
+
+#include <core/help/help_repository.h>
+#include <core/help/help_sink.h>
 
 #include <boost/algorithm/string/trim.hpp>
 #include <boost/algorithm/string/split.hpp>
@@ -46,365 +50,287 @@ namespace caspar { namespace protocol { namespace amcp {
 
 using IO::ClientInfoPtr;
 
-const std::wstring AMCPProtocolStrategy::MessageDelimiter = TEXT("\r\n");
-
-inline std::shared_ptr<core::video_channel> GetChannelSafe(unsigned int index, const std::vector<spl::shared_ptr<core::video_channel>>& channels)
+template <typename Out, typename In>
+bool try_lexical_cast(const In& input, Out& result)
 {
-       return index < channels.size() ? std::shared_ptr<core::video_channel>(channels[index]) : nullptr;
-}
+       Out saved = result;
+       bool success = boost::conversion::detail::try_lexical_convert(input, result);
 
-AMCPProtocolStrategy::AMCPProtocolStrategy(const std::vector<spl::shared_ptr<core::video_channel>>& channels) : channels_(channels) {
-       AMCPCommandQueuePtr pGeneralCommandQueue(new AMCPCommandQueue());
-       commandQueues_.push_back(pGeneralCommandQueue);
+       if (!success)
+               result = saved; // Needed because of how try_lexical_convert is implemented.
 
-
-       std::shared_ptr<core::video_channel> pChannel;
-       unsigned int index = -1;
-       //Create a commandpump for each video_channel
-       while((pChannel = GetChannelSafe(++index, channels_)) != 0) {
-               AMCPCommandQueuePtr pChannelCommandQueue(new AMCPCommandQueue());
-               std::wstring title = TEXT("video_channel ");
-
-               //HACK: Perform real conversion from int to string
-               TCHAR num = TEXT('1')+static_cast<TCHAR>(index);
-               title += num;
-               
-               commandQueues_.push_back(pChannelCommandQueue);
-       }
+       return success;
 }
 
-AMCPProtocolStrategy::~AMCPProtocolStrategy() {
-}
-
-void AMCPProtocolStrategy::Parse(const TCHAR* pData, int charCount, ClientInfoPtr pClientInfo)
+struct AMCPProtocolStrategy::impl
 {
-       size_t pos;
-       std::wstring recvData(pData, charCount);
-       std::wstring availibleData = (pClientInfo != nullptr ? pClientInfo->currentMessage_ : L"") + recvData;
-
-       while(true) {
-               pos = availibleData.find(MessageDelimiter);
-               if(pos != std::wstring::npos)
-               {
-                       std::wstring message = availibleData.substr(0,pos);
+private:
+       std::vector<AMCPCommandQueue::ptr_type>         commandQueues_;
+       spl::shared_ptr<amcp_command_repository>        repo_;
 
-                       //This is where a complete message gets taken care of
-                       if(message.length() > 0) {
-                               ProcessMessage(message, pClientInfo);
-                       }
+public:
+       impl(const std::wstring& name, const spl::shared_ptr<amcp_command_repository>& repo)
+               : repo_(repo)
+       {
+               commandQueues_.push_back(spl::make_shared<AMCPCommandQueue>(L"General Queue for " + name));
 
-                       std::size_t nextStartPos = pos + MessageDelimiter.length();
-                       if(nextStartPos < availibleData.length())
-                               availibleData = availibleData.substr(nextStartPos);
-                       else {
-                               availibleData.clear();
-                               break;
-                       }
-               }
-               else
+               for (int i = 0; i < repo_->channels().size(); ++i)
                {
-                       break;
+                       commandQueues_.push_back(spl::make_shared<AMCPCommandQueue>(
+                                       L"Channel " + boost::lexical_cast<std::wstring>(i + 1) + L" for " + name));
                }
        }
-       if(pClientInfo)
-               pClientInfo->currentMessage_ = availibleData;
-}
-
-void AMCPProtocolStrategy::ProcessMessage(const std::wstring& message, ClientInfoPtr& pClientInfo)
-{      
-       CASPAR_LOG(info) << L"Received message from " << pClientInfo->print() << ": " << message << L"\\r\\n";
-       
-       bool bError = true;
-       MessageParserState state = New;
 
-       AMCPCommand::ptr_type pCommand(InterpretCommandString(message, &state));
+       ~impl() {}
 
-       if(pCommand != 0) {
-               pCommand->SetClientInfo(pClientInfo);   
-               if(QueueCommand(pCommand))
-                       bError = false;
-               else
-                       state = GetChannel;
-       }
+       enum class error_state {
+               no_error = 0,
+               command_error,
+               channel_error,
+               parameters_error,
+               unknown_error,
+               access_error
+       };
 
-       if(bError == true) {
-               std::wstringstream answer;
-               switch(state)
+       struct command_interpreter_result
+       {
+               std::shared_ptr<caspar::IO::lock_container>     lock;
+               std::wstring                                                            command_name;
+               AMCPCommand::ptr_type                                           command;
+               error_state                                                                     error                   = error_state::no_error;
+               std::shared_ptr<AMCPCommandQueue>                       queue;
+       };
+
+       //The paser method expects message to be complete messages with the delimiter stripped away.
+       //Thesefore the AMCPProtocolStrategy should be decorated with a delimiter_based_chunking_strategy
+       void Parse(const std::wstring& message, ClientInfoPtr client)
+       {
+               CASPAR_LOG_COMMUNICATION(info) << L"Received message from " << client->address() << ": " << message << L"\\r\\n";
+       
+               command_interpreter_result result;
+               if(interpret_command_string(message, result, client))
                {
-               case GetCommand:
-                       answer << TEXT("400 ERROR\r\n") + message << "\r\n";
-                       break;
-               case GetChannel:
-                       answer << TEXT("401 ERROR\r\n");
-                       break;
-               case GetParameters:
-                       answer << TEXT("402 ERROR\r\n");
-                       break;
-               default:
-                       answer << TEXT("500 FAILED\r\n");
-                       break;
+                       if(result.lock && !result.lock->check_access(client))
+                               result.error = error_state::access_error;
+                       else
+                               result.queue->AddCommand(result.command);
                }
-               pClientInfo->Send(answer.str());
-       }
-}
-
-AMCPCommand::ptr_type AMCPProtocolStrategy::InterpretCommandString(const std::wstring& message, MessageParserState* pOutState)
-{
-       std::vector<std::wstring> tokens;
-       unsigned int currentToken = 0;
-       std::wstring commandSwitch;
-
-       AMCPCommand::ptr_type pCommand;
-       MessageParserState state = New;
+               
+               if (result.error != error_state::no_error)
+               {
+                       std::wstringstream answer;
 
-       std::size_t tokensInMessage = TokenizeMessage(message, &tokens);
+                       switch(result.error)
+                       {
+                       case error_state::command_error:
+                               answer << L"400 ERROR\r\n" << message << "\r\n";
+                               break;
+                       case error_state::channel_error:
+                               answer << L"401 " << result.command_name << " ERROR\r\n";
+                               break;
+                       case error_state::parameters_error:
+                               answer << L"402 " << result.command_name << " ERROR\r\n";
+                               break;
+                       case error_state::access_error:
+                               answer << L"503 " << result.command_name << " FAILED\r\n";
+                               break;
+                       case error_state::unknown_error:
+                               answer << L"500 FAILED\r\n";
+                               break;
+                       default:
+                               CASPAR_THROW_EXCEPTION(programming_error()
+                                               << msg_info(L"Unhandled error_state enum constant " + boost::lexical_cast<std::wstring>(static_cast<int>(result.error))));
+                       }
+                       client->send(answer.str());
+               }
+       }
 
-       //parse the message one token at the time
-       while(currentToken < tokensInMessage)
+private:
+       bool interpret_command_string(const std::wstring& message, command_interpreter_result& result, ClientInfoPtr client)
        {
-               switch(state)
+               try
                {
-               case New:
-                       if(tokens[currentToken][0] == TEXT('/'))
-                               state = GetSwitch;
-                       else
-                               state = GetCommand;
-                       break;
-
-               case GetSwitch:
-                       commandSwitch = tokens[currentToken];
-                       state = GetCommand;
-                       ++currentToken;
-                       break;
-
-               case GetCommand:
-                       pCommand = CommandFactory(tokens[currentToken]);
-                       if(pCommand == 0) {
-                               goto ParseFinnished;
-                       }
-                       else
-                       {
-                               pCommand->SetChannels(channels_);
-                               //Set scheduling
-                               if(commandSwitch.size() > 0) {
-                                       transform(commandSwitch.begin(), commandSwitch.end(), commandSwitch.begin(), toupper);
-
-                                       //if(commandSwitch == TEXT("/APP"))
-                                       //      pCommand->SetScheduling(AddToQueue);
-                                       //else if(commandSwitch  == TEXT("/IMMF"))
-                                       //      pCommand->SetScheduling(ImmediatelyAndClear);
-                               }
+                       std::list<std::wstring> tokens;
+                       tokenize(message, tokens);
 
-                               if(pCommand->NeedChannel())
-                                       state = GetChannel;
-                               else
-                                       state = GetParameters;
+                       // Discard GetSwitch
+                       if (!tokens.empty() && tokens.front().at(0) == L'/')
+                               tokens.pop_front();
+
+                       // Fail if no more tokens.
+                       if (tokens.empty())
+                       {
+                               result.error = error_state::command_error;
+                               return false;
                        }
-                       ++currentToken;
-                       break;
 
-               case GetParameters:
+                       // Consume command name
+                       result.command_name = boost::to_upper_copy(tokens.front());
+                       tokens.pop_front();
+
+                       // Determine whether the next parameter is a channel spec or not
+                       int channel_index = -1;
+                       int layer_index = -1;
+                       std::wstring channel_spec;
+
+                       if (!tokens.empty())
                        {
-                               _ASSERTE(pCommand != 0);
-                               int parameterCount=0;
-                               while(currentToken<tokensInMessage)
+                               channel_spec = tokens.front();
+                               std::wstring channelid_str = boost::trim_copy(channel_spec);
+                               std::vector<std::wstring> split;
+                               boost::split(split, channelid_str, boost::is_any_of("-"));
+
+                               // Use non_throwing lexical cast to not hit exception break point all the time.
+                               if (try_lexical_cast(split[0], channel_index))
                                {
-                                       pCommand->AddParameter(tokens[currentToken++]);
-                                       ++parameterCount;
-                               }
+                                       --channel_index;
 
-                               if(parameterCount < pCommand->GetMinimumParameters()) {
-                                       goto ParseFinnished;
-                               }
+                                       if (split.size() > 1)
+                                               try_lexical_cast(split[1], layer_index);
 
-                               state = Done;
-                               break;
+                                       // Consume channel-spec
+                                       tokens.pop_front();
+                               }
                        }
 
-               case GetChannel:
+                       bool is_channel_command = channel_index != -1;
+
+                       // Create command instance
+                       if (is_channel_command)
                        {
-//                             assert(pCommand != 0);
+                               result.command = repo_->create_channel_command(result.command_name, client, channel_index, layer_index, tokens);
 
-                               std::wstring str = boost::trim_copy(tokens[currentToken]);
-                               std::vector<std::wstring> split;
-                               boost::split(split, str, boost::is_any_of("-"));
-                                       
-                               int channelIndex = -1;
-                               int layerIndex = -1;
-                               try
+                               if (result.command)
                                {
-                                       channelIndex = boost::lexical_cast<int>(split[0]) - 1;
-
-                                       if(split.size() > 1)
-                                               layerIndex = boost::lexical_cast<int>(split[1]);
+                                       result.lock = repo_->channels().at(channel_index).lock;
+                                       result.queue = commandQueues_.at(channel_index + 1);
                                }
-                               catch(...)
+                               else // Might be a non channel command, although the first argument is numeric
                                {
-                                       goto ParseFinnished;
-                               }
+                                       // Restore backed up channel spec string.
+                                       tokens.push_front(channel_spec);
+                                       result.command = repo_->create_command(result.command_name, client, tokens);
 
-                               std::shared_ptr<core::video_channel> pChannel = GetChannelSafe(channelIndex, channels_);
-                               if(pChannel == 0) {
-                                       goto ParseFinnished;
+                                       if (result.command)
+                                               result.queue = commandQueues_.at(0);
                                }
-
-                               pCommand->SetChannel(pChannel);
-                               pCommand->SetChannels(channels_);
-                               pCommand->SetChannelIndex(channelIndex);
-                               pCommand->SetLayerIntex(layerIndex);
-
-                               state = GetParameters;
-                               ++currentToken;
-                               break;
                        }
+                       else
+                       {
+                               result.command = repo_->create_command(result.command_name, client, tokens);
 
-               default:        //Done and unexpected
-                       goto ParseFinnished;
-               }
-       }
-
-ParseFinnished:
-       if(state == GetParameters && pCommand->GetMinimumParameters()==0)
-               state = Done;
-
-       if(state != Done) {
-               pCommand.reset();
-       }
+                               if (result.command)
+                                       result.queue = commandQueues_.at(0);
+                       }
 
-       if(pOutState != 0) {
-               *pOutState = state;
-       }
+                       if (!result.command)
+                               result.error = error_state::command_error;
+                       else
+                       {
+                               std::vector<std::wstring> parameters(tokens.begin(), tokens.end());
 
-       return pCommand;
-}
+                               result.command->parameters() = std::move(parameters);
 
-bool AMCPProtocolStrategy::QueueCommand(AMCPCommand::ptr_type pCommand) {
-       if(pCommand->NeedChannel()) {
-               unsigned int channelIndex = pCommand->GetChannelIndex() + 1;
-               if(commandQueues_.size() > channelIndex) {
-                       commandQueues_[channelIndex]->AddCommand(pCommand);
+                               if (result.command->parameters().size() < result.command->minimum_parameters())
+                                       result.error = error_state::parameters_error;
+                       }
+               }
+               catch (std::out_of_range&)
+               {
+                       CASPAR_LOG(error) << "Invalid channel specified.";
+                       result.error = error_state::channel_error;
+               }
+               catch (...)
+               {
+                       CASPAR_LOG_CURRENT_EXCEPTION();
+                       result.error = error_state::unknown_error;
                }
-               else
-                       return false;
-       }
-       else {
-               commandQueues_[0]->AddCommand(pCommand);
-       }
-       return true;
-}
 
-AMCPCommand::ptr_type AMCPProtocolStrategy::CommandFactory(const std::wstring& str)
-{
-       std::wstring s = str;
-       transform(s.begin(), s.end(), s.begin(), toupper);
-       
-       if         (s == TEXT("MIXER"))                 return std::make_shared<MixerCommand>();
-       else if(s == TEXT("DIAG"))                      return std::make_shared<DiagnosticsCommand>();
-       else if(s == TEXT("CHANNEL_GRID"))      return std::make_shared<ChannelGridCommand>();
-       else if(s == TEXT("CALL"))                      return std::make_shared<CallCommand>();
-       else if(s == TEXT("SWAP"))                      return std::make_shared<SwapCommand>();
-       else if(s == TEXT("LOAD"))                      return std::make_shared<LoadCommand>();
-       else if(s == TEXT("LOADBG"))            return std::make_shared<LoadbgCommand>();
-       else if(s == TEXT("ADD"))                       return std::make_shared<AddCommand>();
-       else if(s == TEXT("REMOVE"))            return std::make_shared<RemoveCommand>();
-       else if(s == TEXT("PAUSE"))                     return std::make_shared<PauseCommand>();
-       else if(s == TEXT("PLAY"))                      return std::make_shared<PlayCommand>();
-       else if(s == TEXT("STOP"))                      return std::make_shared<StopCommand>();
-       else if(s == TEXT("CLEAR"))                     return std::make_shared<ClearCommand>();
-       else if(s == TEXT("PRINT"))                     return std::make_shared<PrintCommand>();
-       else if(s == TEXT("LOG"))                       return std::make_shared<LogCommand>();
-       else if(s == TEXT("CG"))                        return std::make_shared<CGCommand>();
-       else if(s == TEXT("DATA"))                      return std::make_shared<DataCommand>();
-       else if(s == TEXT("CINF"))                      return std::make_shared<CinfCommand>();
-       else if(s == TEXT("INFO"))                      return std::make_shared<InfoCommand>(channels_);
-       else if(s == TEXT("CLS"))                       return std::make_shared<ClsCommand>();
-       else if(s == TEXT("TLS"))                       return std::make_shared<TlsCommand>();
-       else if(s == TEXT("VERSION"))           return std::make_shared<VersionCommand>();
-       else if(s == TEXT("BYE"))                       return std::make_shared<ByeCommand>();
-       else if(s == TEXT("SET"))                       return std::make_shared<SetCommand>();
-       //else if(s == TEXT("MONITOR"))
-       //{
-       //      result = AMCPCommandPtr(new MonitorCommand());
-       //}
-       //else if(s == TEXT("KILL"))
-       //{
-       //      result = AMCPCommandPtr(new KillCommand());
-       //}
-       return nullptr;
-}
+               return result.error == error_state::no_error;
+       }
 
-std::size_t AMCPProtocolStrategy::TokenizeMessage(const std::wstring& message, std::vector<std::wstring>* pTokenVector)
-{
-       //split on whitespace but keep strings within quotationmarks
-       //treat \ as the start of an escape-sequence: the following char will indicate what to actually put in the string
+       template<typename C>
+       std::size_t tokenize(const std::wstring& message, C& pTokenVector)
+       {
+               //split on whitespace but keep strings within quotationmarks
+               //treat \ as the start of an escape-sequence: the following char will indicate what to actually put in the string
 
-       std::wstring currentToken;
+               std::wstring currentToken;
 
-       bool inQuote = false;
-       bool getSpecialCode = false;
+               bool inQuote = false;
+               bool getSpecialCode = false;
 
-       for(unsigned int charIndex=0; charIndex<message.size(); ++charIndex)
-       {
-               if(getSpecialCode)
+               for(unsigned int charIndex=0; charIndex<message.size(); ++charIndex)
                {
-                       //insert code-handling here
-                       switch(message[charIndex])
+                       if(getSpecialCode)
                        {
-                       case TEXT('\\'):
-                               currentToken += TEXT("\\");
-                               break;
-                       case TEXT('\"'):
-                               currentToken += TEXT("\"");
-                               break;
-                       case TEXT('n'):
-                               currentToken += TEXT("\n");
-                               break;
-                       default:
-                               break;
-                       };
-                       getSpecialCode = false;
-                       continue;
-               }
-
-               if(message[charIndex]==TEXT('\\'))
-               {
-                       getSpecialCode = true;
-                       continue;
-               }
+                               //insert code-handling here
+                               switch(message[charIndex])
+                               {
+                               case L'\\':
+                                       currentToken += L"\\";
+                                       break;
+                               case L'\"':
+                                       currentToken += L"\"";
+                                       break;
+                               case L'n':
+                                       currentToken += L"\n";
+                                       break;
+                               default:
+                                       break;
+                               };
+                               getSpecialCode = false;
+                               continue;
+                       }
 
-               if(message[charIndex]==' ' && inQuote==false)
-               {
-                       if(currentToken.size()>0)
+                       if(message[charIndex]==L'\\')
                        {
-                               pTokenVector->push_back(currentToken);
-                               currentToken.clear();
+                               getSpecialCode = true;
+                               continue;
                        }
-                       continue;
-               }
 
-               if(message[charIndex]==TEXT('\"'))
-               {
-                       inQuote = !inQuote;
+                       if(message[charIndex]==L' ' && inQuote==false)
+                       {
+                               if(!currentToken.empty())
+                               {
+                                       pTokenVector.push_back(currentToken);
+                                       currentToken.clear();
+                               }
+                               continue;
+                       }
 
-                       if(currentToken.size() > 0 || !inQuote)
+                       if(message[charIndex]==L'\"')
                        {
-                               pTokenVector->push_back(currentToken);
-                               currentToken.clear();
+                               inQuote = !inQuote;
+
+                               if(!currentToken.empty() || !inQuote)
+                               {
+                                       pTokenVector.push_back(currentToken);
+                                       currentToken.clear();
+                               }
+                               continue;
                        }
-                       continue;
+
+                       currentToken += message[charIndex];
                }
 
-               currentToken += message[charIndex];
-       }
+               if(!currentToken.empty())
+               {
+                       pTokenVector.push_back(currentToken);
+                       currentToken.clear();
+               }
 
-       if(currentToken.size()>0)
-       {
-               pTokenVector->push_back(currentToken);
-               currentToken.clear();
+               return pTokenVector.size();
        }
+};
 
-       return pTokenVector->size();
+AMCPProtocolStrategy::AMCPProtocolStrategy(const std::wstring& name, const spl::shared_ptr<amcp_command_repository>& repo)
+       : impl_(spl::make_unique<impl>(name, repo))
+{
 }
+AMCPProtocolStrategy::~AMCPProtocolStrategy() {}
+void AMCPProtocolStrategy::Parse(const std::wstring& msg, IO::ClientInfoPtr pClientInfo) { impl_->Parse(msg, pClientInfo); }
+
 
 }      //namespace amcp
-}}     //namespace caspar
\ No newline at end of file
+}}     //namespace caspar