]> git.sesse.net Git - ffmpeg/blobdiff - libavfilter/dnn/dnn_backend_tf.c
avfilter/dnn: get the data type of network output from dnn execution result
[ffmpeg] / libavfilter / dnn / dnn_backend_tf.c
index c8dff51744836a9c2fba30d63cf33a4e872a2330..ed91d0500dc69bac6d6557bfbb131cf91f09a430 100644 (file)
@@ -83,7 +83,7 @@ static TF_Buffer *read_graph(const char *model_filename)
     return graph_buf;
 }
 
-static TF_Tensor *allocate_input_tensor(const DNNInputData *input)
+static TF_Tensor *allocate_input_tensor(const DNNData *input)
 {
     TF_DataType dt;
     size_t size;
@@ -105,7 +105,7 @@ static TF_Tensor *allocate_input_tensor(const DNNInputData *input)
                              input_dims[1] * input_dims[2] * input_dims[3] * size);
 }
 
-static DNNReturnType set_input_output_tf(void *model, DNNInputData *input, const char *input_name, const char **output_names, uint32_t nb_output)
+static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output)
 {
     TFModel *tf_model = (TFModel *)model;
     TF_SessionOptions *sess_opts;
@@ -603,6 +603,7 @@ DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, DNNData *outputs, u
         outputs[i].width = TF_Dim(tf_model->output_tensors[i], 2);
         outputs[i].channels = TF_Dim(tf_model->output_tensors[i], 3);
         outputs[i].data = TF_TensorData(tf_model->output_tensors[i]);
+        outputs[i].dt = TF_TensorType(tf_model->output_tensors[i]);
     }
 
     return DNN_SUCCESS;