X-Git-Url: https://git.sesse.net/?a=blobdiff_plain;f=libavfilter%2Fdnn%2Fdnn_backend_tf.c;h=9ceca5cea09bf21824df486ea6860abb14403b70;hb=726dbc57f8162ce82c245a2fdfef2fa074c99dc4;hp=ed91d0500dc69bac6d6557bfbb131cf91f09a430;hpb=e1b45b85963b5aa9d67e23638ef9b045e7fbd875;p=ffmpeg diff --git a/libavfilter/dnn/dnn_backend_tf.c b/libavfilter/dnn/dnn_backend_tf.c index ed91d0500dc..9ceca5cea09 100644 --- a/libavfilter/dnn/dnn_backend_tf.c +++ b/libavfilter/dnn/dnn_backend_tf.c @@ -95,7 +95,7 @@ static TF_Tensor *allocate_input_tensor(const DNNData *input) break; case DNN_UINT8: dt = TF_UINT8; - size = sizeof(char); + size = 1; break; default: av_assert0(!"should not reach here"); @@ -105,6 +105,37 @@ static TF_Tensor *allocate_input_tensor(const DNNData *input) input_dims[1] * input_dims[2] * input_dims[3] * size); } +static DNNReturnType get_input_tf(void *model, DNNData *input, const char *input_name) +{ + TFModel *tf_model = (TFModel *)model; + TF_Status *status; + int64_t dims[4]; + + TF_Output tf_output; + tf_output.oper = TF_GraphOperationByName(tf_model->graph, input_name); + if (!tf_output.oper) + return DNN_ERROR; + + tf_output.index = 0; + input->dt = TF_OperationOutputType(tf_output); + + status = TF_NewStatus(); + TF_GraphGetTensorShape(tf_model->graph, tf_output, dims, 4, status); + if (TF_GetCode(status) != TF_OK){ + TF_DeleteStatus(status); + return DNN_ERROR; + } + TF_DeleteStatus(status); + + // currently only NHWC is supported + av_assert0(dims[0] == 1); + input->height = dims[1]; + input->width = dims[2]; + input->channels = dims[3]; + + return DNN_SUCCESS; +} + static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output) { TFModel *tf_model = (TFModel *)model; @@ -568,6 +599,7 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename) model->model = (void *)tf_model; model->set_input_output = &set_input_output_tf; + model->get_input = &get_input_tf; return model; }