]> git.sesse.net Git - ffmpeg/blobdiff - libavfilter/dnn/dnn_backend_tf.c
avfilter: add dblur video filter
[ffmpeg] / libavfilter / dnn / dnn_backend_tf.c
index ed91d0500dc69bac6d6557bfbb131cf91f09a430..9ceca5cea09bf21824df486ea6860abb14403b70 100644 (file)
@@ -95,7 +95,7 @@ static TF_Tensor *allocate_input_tensor(const DNNData *input)
         break;
     case DNN_UINT8:
         dt = TF_UINT8;
-        size = sizeof(char);
+        size = 1;
         break;
     default:
         av_assert0(!"should not reach here");
@@ -105,6 +105,37 @@ static TF_Tensor *allocate_input_tensor(const DNNData *input)
                              input_dims[1] * input_dims[2] * input_dims[3] * size);
 }
 
+static DNNReturnType get_input_tf(void *model, DNNData *input, const char *input_name)
+{
+    TFModel *tf_model = (TFModel *)model;
+    TF_Status *status;
+    int64_t dims[4];
+
+    TF_Output tf_output;
+    tf_output.oper = TF_GraphOperationByName(tf_model->graph, input_name);
+    if (!tf_output.oper)
+        return DNN_ERROR;
+
+    tf_output.index = 0;
+    input->dt = TF_OperationOutputType(tf_output);
+
+    status = TF_NewStatus();
+    TF_GraphGetTensorShape(tf_model->graph, tf_output, dims, 4, status);
+    if (TF_GetCode(status) != TF_OK){
+        TF_DeleteStatus(status);
+        return DNN_ERROR;
+    }
+    TF_DeleteStatus(status);
+
+    // currently only NHWC is supported
+    av_assert0(dims[0] == 1);
+    input->height = dims[1];
+    input->width = dims[2];
+    input->channels = dims[3];
+
+    return DNN_SUCCESS;
+}
+
 static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output)
 {
     TFModel *tf_model = (TFModel *)model;
@@ -568,6 +599,7 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename)
 
     model->model = (void *)tf_model;
     model->set_input_output = &set_input_output_tf;
+    model->get_input = &get_input_tf;
 
     return model;
 }