]> git.sesse.net Git - ffmpeg/commitdiff
libavfilter/dnn: add more data type support for dnn model input
authorGuo, Yejun <yejun.guo@intel.com>
Thu, 25 Apr 2019 02:14:42 +0000 (10:14 +0800)
committerPedro Arthur <bygrandao@gmail.com>
Wed, 8 May 2019 15:33:00 +0000 (12:33 -0300)
currently, only float is supported as model input, actually, there
are other data types, this patch adds uint8.

Signed-off-by: Guo, Yejun <yejun.guo@intel.com>
Signed-off-by: Pedro Arthur <bygrandao@gmail.com>
libavfilter/dnn_backend_native.c
libavfilter/dnn_backend_tf.c
libavfilter/dnn_interface.h
libavfilter/vf_sr.c

index 8a83c63c73d88d3cb31edd53a5f7e72e53da1c38..06fbdf368bb0c305279ebc79e1716868e9f3bc98 100644 (file)
@@ -24,8 +24,9 @@
  */
 
 #include "dnn_backend_native.h"
+#include "libavutil/avassert.h"
 
-static DNNReturnType set_input_output_native(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output)
+static DNNReturnType set_input_output_native(void *model, DNNInputData *input, const char *input_name, const char **output_names, uint32_t nb_output)
 {
     ConvolutionalNetwork *network = (ConvolutionalNetwork *)model;
     InputParams *input_params;
@@ -45,6 +46,7 @@ static DNNReturnType set_input_output_native(void *model, DNNData *input, const
         if (input->data){
             av_freep(&input->data);
         }
+        av_assert0(input->dt == DNN_FLOAT);
         network->layers[0].output = input->data = av_malloc(cur_height * cur_width * cur_channels * sizeof(float));
         if (!network->layers[0].output){
             return DNN_ERROR;
index ca6472d445ad115d71b7a42e2bbcebba7e545343..ba959ae3a260f0e1c4b56b87c512b2b68dd6189f 100644 (file)
@@ -79,10 +79,31 @@ static TF_Buffer *read_graph(const char *model_filename)
     return graph_buf;
 }
 
-static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output)
+static TF_Tensor *allocate_input_tensor(const DNNInputData *input)
 {
-    TFModel *tf_model = (TFModel *)model;
+    TF_DataType dt;
+    size_t size;
     int64_t input_dims[] = {1, input->height, input->width, input->channels};
+    switch (input->dt) {
+    case DNN_FLOAT:
+        dt = TF_FLOAT;
+        size = sizeof(float);
+        break;
+    case DNN_UINT8:
+        dt = TF_UINT8;
+        size = sizeof(char);
+        break;
+    default:
+        av_assert0(!"should not reach here");
+    }
+
+    return TF_AllocateTensor(dt, input_dims, 4,
+                             input_dims[1] * input_dims[2] * input_dims[3] * size);
+}
+
+static DNNReturnType set_input_output_tf(void *model, DNNInputData *input, const char *input_name, const char **output_names, uint32_t nb_output)
+{
+    TFModel *tf_model = (TFModel *)model;
     TF_SessionOptions *sess_opts;
     const TF_Operation *init_op = TF_GraphOperationByName(tf_model->graph, "init");
 
@@ -95,8 +116,7 @@ static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char
     if (tf_model->input_tensor){
         TF_DeleteTensor(tf_model->input_tensor);
     }
-    tf_model->input_tensor = TF_AllocateTensor(TF_FLOAT, input_dims, 4,
-                                               input_dims[1] * input_dims[2] * input_dims[3] * sizeof(float));
+    tf_model->input_tensor = allocate_input_tensor(input);
     if (!tf_model->input_tensor){
         return DNN_ERROR;
     }
index 73d226ec91bd2c50508ccf72711ae8449d7dbe32..c24df0e96174e0c68b61c22a5a7e2b1ae597b34d 100644 (file)
@@ -32,6 +32,14 @@ typedef enum {DNN_SUCCESS, DNN_ERROR} DNNReturnType;
 
 typedef enum {DNN_NATIVE, DNN_TF} DNNBackendType;
 
+typedef enum {DNN_FLOAT, DNN_UINT8} DNNDataType;
+
+typedef struct DNNInputData{
+    void *data;
+    DNNDataType dt;
+    int width, height, channels;
+} DNNInputData;
+
 typedef struct DNNData{
     float *data;
     int width, height, channels;
@@ -42,7 +50,7 @@ typedef struct DNNModel{
     void *model;
     // Sets model input and output.
     // Should be called at least once before model execution.
-    DNNReturnType (*set_input_output)(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output);
+    DNNReturnType (*set_input_output)(void *model, DNNInputData *input, const char *input_name, const char **output_names, uint32_t nb_output);
 } DNNModel;
 
 // Stores pointers to functions for loading, executing, freeing DNN models for one of the backends.
index 0145511d1151805d1c7f7fb42092b4b956636de1..65baf5f901bdb4aef411ca6793e0c1cad9e9bbb7 100644 (file)
@@ -40,7 +40,8 @@ typedef struct SRContext {
     DNNBackendType backend_type;
     DNNModule *dnn_module;
     DNNModel *model;
-    DNNData input, output;
+    DNNInputData input;
+    DNNData output;
     int scale_factor;
     struct SwsContext *sws_contexts[3];
     int sws_slice_h, sws_input_linesize, sws_output_linesize;
@@ -86,6 +87,7 @@ static av_cold int init(AVFilterContext *context)
         return AVERROR(EIO);
     }
 
+    sr_context->input.dt = DNN_FLOAT;
     sr_context->sws_contexts[0] = NULL;
     sr_context->sws_contexts[1] = NULL;
     sr_context->sws_contexts[2] = NULL;