diff --git a/libavfilter/dnn/dnn_backend_tf.c b/libavfilter/dnn/dnn_backend_tf.c
index 46dfa009cd356fbf981fab777688da5046684d62..9ceca5cea09bf21824df486ea6860abb14403b70 100644
--- a/libavfilter/dnn/dnn_backend_tf.c
+++ b/libavfilter/dnn/dnn_backend_tf.c
 #include "dnn_backend_tf.h"
 #include "dnn_backend_native.h"
 #include "dnn_backend_native_layer_conv2d.h"
+#include "dnn_backend_native_layer_depth2space.h"
 #include "libavformat/avio.h"
 #include "libavutil/avassert.h"
 #include "dnn_backend_native_layer_pad.h"
+#include "dnn_backend_native_layer_maximum.h"
 
 #include <tensorflow/c/c_api.h>
 
@@ -81,7 +83,7 @@ static TF_Buffer *read_graph(const char *model_filename)
     return graph_buf;
 }
 
-static TF_Tensor *allocate_input_tensor(const DNNInputData *input)
+static TF_Tensor *allocate_input_tensor(const DNNData *input)
 {
     TF_DataType dt;
     size_t size;
@@ -93,7 +95,7 @@ static TF_Tensor *allocate_input_tensor(const DNNInputData *input)
         break;
     case DNN_UINT8:
         dt = TF_UINT8;
-        size = sizeof(char);
+        size = 1;
         break;
     default:
         av_assert0(!"should not reach here");
@@ -103,7 +105,38 @@ static TF_Tensor *allocate_input_tensor(const DNNInputData *input)
                              input_dims[1] * input_dims[2] * input_dims[3] * size);
 }
 
-static DNNReturnType set_input_output_tf(void *model, DNNInputData *input, const char *input_name, const char **output_names, uint32_t nb_output)
+static DNNReturnType get_input_tf(void *model, DNNData *input, const char *input_name)
+{
+    TFModel *tf_model = (TFModel *)model;
+    TF_Status *status;
+    int64_t dims[4];
+
+    TF_Output tf_output;
+    tf_output.oper = TF_GraphOperationByName(tf_model->graph, input_name);
+    if (!tf_output.oper)
+        return DNN_ERROR;
+
+    tf_output.index = 0;
+    input->dt = TF_OperationOutputType(tf_output);
+
+    status = TF_NewStatus();
+    TF_GraphGetTensorShape(tf_model->graph, tf_output, dims, 4, status);
+    if (TF_GetCode(status) != TF_OK){
+        TF_DeleteStatus(status);
+        return DNN_ERROR;
+    }
+    TF_DeleteStatus(status);
+
+    // currently only NHWC is supported
+    av_assert0(dims[0] == 1);
+    input->height = dims[1];
+    input->width = dims[2];
+    input->channels = dims[3];
+
+    return DNN_SUCCESS;
+}
+
+static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output)
 {
     TFModel *tf_model = (TFModel *)model;
     TF_SessionOptions *sess_opts;
@@ -400,6 +433,48 @@ static DNNReturnType add_pad_layer(TFModel *tf_model, TF_Operation **cur_op,
     return DNN_SUCCESS;
 }
 
+static DNNReturnType add_maximum_layer(TFModel *tf_model, TF_Operation **cur_op,
+                                       DnnLayerMaximumParams *params, const int layer)
+{
+    TF_Operation *op;
+    TF_Tensor *tensor;
+    TF_OperationDescription *op_desc;
+    TF_Output input;
+    float *y;
+
+    char name_buffer[NAME_BUFFER_SIZE];
+    snprintf(name_buffer, NAME_BUFFER_SIZE, "maximum/y%d", layer);
+
+    op_desc = TF_NewOperation(tf_model->graph, "Const", name_buffer);
+    TF_SetAttrType(op_desc, "dtype", TF_FLOAT);
+    tensor = TF_AllocateTensor(TF_FLOAT, NULL, 0, TF_DataTypeSize(TF_FLOAT));
+    y = (float *)TF_TensorData(tensor);
+    *y = params->val.y;
+    TF_SetAttrTensor(op_desc, "value", tensor, tf_model->status);
+    if (TF_GetCode(tf_model->status) != TF_OK){
+        return DNN_ERROR;
+    }
+    op = TF_FinishOperation(op_desc, tf_model->status);
+    if (TF_GetCode(tf_model->status) != TF_OK){
+        return DNN_ERROR;
+    }
+
+    snprintf(name_buffer, NAME_BUFFER_SIZE, "maximum%d", layer);
+    op_desc = TF_NewOperation(tf_model->graph, "Maximum", name_buffer);
+    input.oper = *cur_op;
+    input.index = 0;
+    TF_AddInput(op_desc, input);
+    input.oper = op;
+    TF_AddInput(op_desc, input);
+    TF_SetAttrType(op_desc, "T", TF_FLOAT);
+    *cur_op = TF_FinishOperation(op_desc, tf_model->status);
+    if (TF_GetCode(tf_model->status) != TF_OK){
+        return DNN_ERROR;
+    }
+
+    return DNN_SUCCESS;
+}
+
 static DNNReturnType load_native_model(TFModel *tf_model, const char *model_filename)
 {
     int32_t layer;
@@ -455,21 +530,25 @@ static DNNReturnType load_native_model(TFModel *tf_model, const char *model_file
 
     for (layer = 0; layer < conv_network->layers_num; ++layer){
         switch (conv_network->layers[layer].type){
-        case INPUT:
+        case DLT_INPUT:
             layer_add_res = DNN_SUCCESS;
             break;
-        case CONV:
+        case DLT_CONV2D:
             layer_add_res = add_conv_layer(tf_model, transpose_op, &op,
                                            (ConvolutionalParams *)conv_network->layers[layer].params, layer);
             break;
-        case DEPTH_TO_SPACE:
+        case DLT_DEPTH_TO_SPACE:
             layer_add_res = add_depth_to_space_layer(tf_model, &op,
                                                      (DepthToSpaceParams *)conv_network->layers[layer].params, layer);
             break;
-        case MIRROR_PAD:
+        case DLT_MIRROR_PAD:
             layer_add_res = add_pad_layer(tf_model, &op,
                                           (LayerPadParams *)conv_network->layers[layer].params, layer);
             break;
+        case DLT_MAXIMUM:
+            layer_add_res = add_maximum_layer(tf_model, &op,
+                                          (DnnLayerMaximumParams *)conv_network->layers[layer].params, layer);
+            break;
         default:
             CLEANUP_ON_ERROR(tf_model);
         }
@@ -520,6 +599,7 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename)
 
     model->model = (void *)tf_model;
     model->set_input_output = &set_input_output_tf;
+    model->get_input = &get_input_tf;
 
     return model;
 }
@@ -555,6 +635,7 @@ DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, DNNData *outputs, u
         outputs[i].width = TF_Dim(tf_model->output_tensors[i], 2);
         outputs[i].channels = TF_Dim(tf_model->output_tensors[i], 3);
         outputs[i].data = TF_TensorData(tf_model->output_tensors[i]);
+        outputs[i].dt = TF_TensorType(tf_model->output_tensors[i]);
     }
 
     return DNN_SUCCESS;
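
The patch wires get_input_tf into DNNModel->get_input so callers can query a TensorFlow placeholder's NHWC shape and data type before feeding frames. Below is a minimal caller sketch, not part of the patch; the model path "model.pb" and placeholder name "x" are hypothetical examples.

/* Usage sketch (not part of the patch): query the input shape through the
 * new get_input callback.  "model.pb" and the placeholder name "x" are
 * hypothetical; real filters pass user-supplied values. */
#include "dnn_backend_tf.h"

static DNNReturnType probe_tf_input(int *height, int *width, int *channels)
{
    DNNData input;
    DNNModel *model = ff_dnn_load_model_tf("model.pb");
    if (!model)
        return DNN_ERROR;

    if (model->get_input(model->model, &input, "x") != DNN_SUCCESS) {
        ff_dnn_free_model_tf(&model);
        return DNN_ERROR;
    }

    /* get_input_tf reports the placeholder's NHWC shape (batch dimension
     * asserted to be 1) and stores its TensorFlow data type in input.dt. */
    *height   = input.height;
    *width    = input.width;
    *channels = input.channels;

    ff_dnn_free_model_tf(&model);
    return DNN_SUCCESS;
}

On the output side, the diff likewise fills outputs[i].dt from TF_TensorType() after ff_dnn_execute_model_tf() runs the session, so the output data type is reported alongside the dimensions and data pointer.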