]> git.sesse.net Git - ffmpeg/commitdiff
dnn: put DNNModel.set_input and DNNModule.execute_model together
authorGuo, Yejun <yejun.guo@intel.com>
Thu, 10 Sep 2020 14:29:57 +0000 (22:29 +0800)
committerGuo, Yejun <yejun.guo@intel.com>
Mon, 21 Sep 2020 13:26:56 +0000 (21:26 +0800)
suppose we have a detect and classify filter in the future, the
detect filter generates some bounding boxes (BBox) as AVFrame sidedata,
and the classify filter executes DNN model for each BBox. For each
BBox, we need to crop the AVFrame, copy data to DNN model input and do
the model execution. So we have to save the in_frame at DNNModel.set_input
and use it at DNNModule.execute_model, such saving is not feasible
when we support async execute_model.

This patch sets the in_frame as execution_model parameter, and so
all the information are put together within the same function for
each inference. It also makes easy to support BBox async inference.

libavfilter/dnn/dnn_backend_native.c
libavfilter/dnn/dnn_backend_native.h
libavfilter/dnn/dnn_backend_openvino.c
libavfilter/dnn/dnn_backend_openvino.h
libavfilter/dnn/dnn_backend_tf.c
libavfilter/dnn/dnn_backend_tf.h
libavfilter/dnn_interface.h
libavfilter/vf_derain.c
libavfilter/vf_dnn_processing.c
libavfilter/vf_sr.c

index 14e878b6b8204d4f4cd5dcee3731dc961bbd2169..dc47c9b5429ee970807217ad4c9242ac73566fb3 100644 (file)
@@ -70,64 +70,6 @@ static DNNReturnType get_input_native(void *model, DNNData *input, const char *i
     return DNN_ERROR;
 }
 
-static DNNReturnType set_input_native(void *model, AVFrame *frame, const char *input_name)
-{
-    NativeModel *native_model = (NativeModel *)model;
-    NativeContext *ctx = &native_model->ctx;
-    DnnOperand *oprd = NULL;
-    DNNData input;
-
-    if (native_model->layers_num <= 0 || native_model->operands_num <= 0) {
-        av_log(ctx, AV_LOG_ERROR, "No operands or layers in model\n");
-        return DNN_ERROR;
-    }
-
-    /* inputs */
-    for (int i = 0; i < native_model->operands_num; ++i) {
-        oprd = &native_model->operands[i];
-        if (strcmp(oprd->name, input_name) == 0) {
-            if (oprd->type != DOT_INPUT) {
-                av_log(ctx, AV_LOG_ERROR, "Found \"%s\" in model, but it is not input node\n", input_name);
-                return DNN_ERROR;
-            }
-            break;
-        }
-        oprd = NULL;
-    }
-    if (!oprd) {
-        av_log(ctx, AV_LOG_ERROR, "Could not find \"%s\" in model\n", input_name);
-        return DNN_ERROR;
-    }
-
-    oprd->dims[1] = frame->height;
-    oprd->dims[2] = frame->width;
-
-    av_freep(&oprd->data);
-    oprd->length = calculate_operand_data_length(oprd);
-    if (oprd->length <= 0) {
-        av_log(ctx, AV_LOG_ERROR, "The input data length overflow\n");
-        return DNN_ERROR;
-    }
-    oprd->data = av_malloc(oprd->length);
-    if (!oprd->data) {
-        av_log(ctx, AV_LOG_ERROR, "Failed to malloc memory for input data\n");
-        return DNN_ERROR;
-    }
-
-    input.height = oprd->dims[1];
-    input.width = oprd->dims[2];
-    input.channels = oprd->dims[3];
-    input.data = oprd->data;
-    input.dt = oprd->data_type;
-    if (native_model->model->pre_proc != NULL) {
-        native_model->model->pre_proc(frame, &input, native_model->model->userdata);
-    } else {
-        proc_from_frame_to_dnn(frame, &input, ctx);
-    }
-
-    return DNN_SUCCESS;
-}
-
 // Loads model and its parameters that are stored in a binary file with following structure:
 // layers_num,layer_type,layer_parameterss,layer_type,layer_parameters...
 // For CONV layer: activation_function, input_num, output_num, kernel_size, kernel, biases
@@ -273,7 +215,6 @@ DNNModel *ff_dnn_load_model_native(const char *model_filename, const char *optio
         return NULL;
     }
 
-    model->set_input = &set_input_native;
     model->get_input = &get_input_native;
     model->userdata = userdata;
 
@@ -285,26 +226,66 @@ fail:
     return NULL;
 }
 
-DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame)
+DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, const char *input_name, AVFrame *in_frame,
+                                          const char **output_names, uint32_t nb_output, AVFrame *out_frame)
 {
     NativeModel *native_model = (NativeModel *)model->model;
     NativeContext *ctx = &native_model->ctx;
     int32_t layer;
-    DNNData output;
+    DNNData input, output;
+    DnnOperand *oprd = NULL;
 
-    if (nb_output != 1) {
-        // currently, the filter does not need multiple outputs,
-        // so we just pending the support until we really need it.
-        av_log(ctx, AV_LOG_ERROR, "do not support multiple outputs\n");
+    if (native_model->layers_num <= 0 || native_model->operands_num <= 0) {
+        av_log(ctx, AV_LOG_ERROR, "No operands or layers in model\n");
         return DNN_ERROR;
     }
 
-    if (native_model->layers_num <= 0 || native_model->operands_num <= 0) {
-        av_log(ctx, AV_LOG_ERROR, "No operands or layers in model\n");
+    for (int i = 0; i < native_model->operands_num; ++i) {
+        oprd = &native_model->operands[i];
+        if (strcmp(oprd->name, input_name) == 0) {
+            if (oprd->type != DOT_INPUT) {
+                av_log(ctx, AV_LOG_ERROR, "Found \"%s\" in model, but it is not input node\n", input_name);
+                return DNN_ERROR;
+            }
+            break;
+        }
+        oprd = NULL;
+    }
+    if (!oprd) {
+        av_log(ctx, AV_LOG_ERROR, "Could not find \"%s\" in model\n", input_name);
+        return DNN_ERROR;
+    }
+
+    oprd->dims[1] = in_frame->height;
+    oprd->dims[2] = in_frame->width;
+
+    av_freep(&oprd->data);
+    oprd->length = calculate_operand_data_length(oprd);
+    if (oprd->length <= 0) {
+        av_log(ctx, AV_LOG_ERROR, "The input data length overflow\n");
         return DNN_ERROR;
     }
-    if (!native_model->operands[0].data) {
-        av_log(ctx, AV_LOG_ERROR, "Empty model input data\n");
+    oprd->data = av_malloc(oprd->length);
+    if (!oprd->data) {
+        av_log(ctx, AV_LOG_ERROR, "Failed to malloc memory for input data\n");
+        return DNN_ERROR;
+    }
+
+    input.height = oprd->dims[1];
+    input.width = oprd->dims[2];
+    input.channels = oprd->dims[3];
+    input.data = oprd->data;
+    input.dt = oprd->data_type;
+    if (native_model->model->pre_proc != NULL) {
+        native_model->model->pre_proc(in_frame, &input, native_model->model->userdata);
+    } else {
+        proc_from_frame_to_dnn(in_frame, &input, ctx);
+    }
+
+    if (nb_output != 1) {
+        // currently, the filter does not need multiple outputs,
+        // so we just pending the support until we really need it.
+        av_log(ctx, AV_LOG_ERROR, "do not support multiple outputs\n");
         return DNN_ERROR;
     }
 
index 553438bd22763466920a84b4b260e86c677df144..2f8d73fcf604db3c8b0f30fe4ce3f754138eb918 100644 (file)
@@ -128,7 +128,8 @@ typedef struct NativeModel{
 
 DNNModel *ff_dnn_load_model_native(const char *model_filename, const char *options, void *userdata);
 
-DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame);
+DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, const char *input_name, AVFrame *in_frame,
+                                          const char **output_names, uint32_t nb_output, AVFrame *out_frame);
 
 void ff_dnn_free_model_native(DNNModel **model);
 
index b1bad3f659d2ac4d63ef99af2d685c0f39ac8ec0..0dba1c1adcf63d61ba13f9622e1509fb503722d3 100644 (file)
@@ -48,7 +48,6 @@ typedef struct OVModel{
     ie_network_t *network;
     ie_executable_network_t *exe_network;
     ie_infer_request_t *infer_request;
-    ie_blob_t *input_blob;
 } OVModel;
 
 #define APPEND_STRING(generated_string, iterate_string)                                            \
@@ -133,49 +132,6 @@ static DNNReturnType get_input_ov(void *model, DNNData *input, const char *input
     return DNN_ERROR;
 }
 
-static DNNReturnType set_input_ov(void *model, AVFrame *frame, const char *input_name)
-{
-    OVModel *ov_model = (OVModel *)model;
-    OVContext *ctx = &ov_model->ctx;
-    IEStatusCode status;
-    dimensions_t dims;
-    precision_e precision;
-    ie_blob_buffer_t blob_buffer;
-    DNNData input;
-
-    status = ie_infer_request_get_blob(ov_model->infer_request, input_name, &ov_model->input_blob);
-    if (status != OK)
-        goto err;
-
-    status |= ie_blob_get_dims(ov_model->input_blob, &dims);
-    status |= ie_blob_get_precision(ov_model->input_blob, &precision);
-    if (status != OK)
-        goto err;
-
-    status = ie_blob_get_buffer(ov_model->input_blob, &blob_buffer);
-    if (status != OK)
-        goto err;
-
-    input.height = dims.dims[2];
-    input.width = dims.dims[3];
-    input.channels = dims.dims[1];
-    input.data = blob_buffer.buffer;
-    input.dt = precision_to_datatype(precision);
-    if (ov_model->model->pre_proc != NULL) {
-        ov_model->model->pre_proc(frame, &input, ov_model->model->userdata);
-    } else {
-        proc_from_frame_to_dnn(frame, &input, ctx);
-    }
-
-    return DNN_SUCCESS;
-
-err:
-    if (ov_model->input_blob)
-        ie_blob_free(&ov_model->input_blob);
-    av_log(ctx, AV_LOG_ERROR, "Failed to create inference instance or get input data/dims/precision/memory\n");
-    return DNN_ERROR;
-}
-
 DNNModel *ff_dnn_load_model_ov(const char *model_filename, const char *options, void *userdata)
 {
     char *all_dev_names = NULL;
@@ -234,7 +190,6 @@ DNNModel *ff_dnn_load_model_ov(const char *model_filename, const char *options,
         goto err;
 
     model->model = (void *)ov_model;
-    model->set_input = &set_input_ov;
     model->get_input = &get_input_ov;
     model->options = options;
     model->userdata = userdata;
@@ -258,7 +213,8 @@ err:
     return NULL;
 }
 
-DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame)
+DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, const char *input_name, AVFrame *in_frame,
+                                      const char **output_names, uint32_t nb_output, AVFrame *out_frame)
 {
     char *model_output_name = NULL;
     char *all_output_names = NULL;
@@ -269,7 +225,39 @@ DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, const char **output
     OVContext *ctx = &ov_model->ctx;
     IEStatusCode status;
     size_t model_output_count = 0;
-    DNNData output;
+    DNNData input, output;
+    ie_blob_t *input_blob = NULL;
+
+    status = ie_infer_request_get_blob(ov_model->infer_request, input_name, &input_blob);
+    if (status != OK) {
+        av_log(ctx, AV_LOG_ERROR, "Failed to get input blob\n");
+        return DNN_ERROR;
+    }
+
+    status |= ie_blob_get_dims(input_blob, &dims);
+    status |= ie_blob_get_precision(input_blob, &precision);
+    if (status != OK) {
+        av_log(ctx, AV_LOG_ERROR, "Failed to get input blob dims/precision\n");
+        return DNN_ERROR;
+    }
+
+    status = ie_blob_get_buffer(input_blob, &blob_buffer);
+    if (status != OK) {
+        av_log(ctx, AV_LOG_ERROR, "Failed to get input blob buffer\n");
+        return DNN_ERROR;
+    }
+
+    input.height = dims.dims[2];
+    input.width = dims.dims[3];
+    input.channels = dims.dims[1];
+    input.data = blob_buffer.buffer;
+    input.dt = precision_to_datatype(precision);
+    if (ov_model->model->pre_proc != NULL) {
+        ov_model->model->pre_proc(in_frame, &input, ov_model->model->userdata);
+    } else {
+        proc_from_frame_to_dnn(in_frame, &input, ctx);
+    }
+    ie_blob_free(&input_blob);
 
     if (nb_output != 1) {
         // currently, the filter does not need multiple outputs,
@@ -330,6 +318,7 @@ DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, const char **output
                 proc_from_dnn_to_frame(out_frame, &output, ctx);
             }
         }
+        ie_blob_free(&output_blob);
     }
 
     return DNN_SUCCESS;
@@ -339,8 +328,6 @@ void ff_dnn_free_model_ov(DNNModel **model)
 {
     if (*model){
         OVModel *ov_model = (OVModel *)(*model)->model;
-        if (ov_model->input_blob)
-            ie_blob_free(&ov_model->input_blob);
         if (ov_model->infer_request)
             ie_infer_request_free(&ov_model->infer_request);
         if (ov_model->exe_network)
index efb349cb49776d8ffbb411b79a06d064d15953e1..3f8f01da60a890a639664928fa08378db0353803 100644 (file)
@@ -31,7 +31,8 @@
 
 DNNModel *ff_dnn_load_model_ov(const char *model_filename, const char *options, void *userdata);
 
-DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame);
+DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, const char *input_name, AVFrame *in_frame,
+                                      const char **output_names, uint32_t nb_output, AVFrame *out_frame);
 
 void ff_dnn_free_model_ov(DNNModel **model);
 
index c2d8c06931346f64fc05bbfe9b6838d5ba434b42..8467f8a459b27308ce153aa3704fd5e3aa2e31ec 100644 (file)
@@ -45,8 +45,6 @@ typedef struct TFModel{
     TF_Graph *graph;
     TF_Session *session;
     TF_Status *status;
-    TF_Output input;
-    TF_Tensor *input_tensor;
 } TFModel;
 
 static const AVClass dnn_tensorflow_class = {
@@ -152,48 +150,33 @@ static DNNReturnType get_input_tf(void *model, DNNData *input, const char *input
     return DNN_SUCCESS;
 }
 
-static DNNReturnType set_input_tf(void *model, AVFrame *frame, const char *input_name)
+static DNNReturnType load_tf_model(TFModel *tf_model, const char *model_filename)
 {
-    TFModel *tf_model = (TFModel *)model;
     TFContext *ctx = &tf_model->ctx;
-    DNNData input;
+    TF_Buffer *graph_def;
+    TF_ImportGraphDefOptions *graph_opts;
     TF_SessionOptions *sess_opts;
-    const TF_Operation *init_op = TF_GraphOperationByName(tf_model->graph, "init");
-
-    if (get_input_tf(model, &input, input_name) != DNN_SUCCESS)
-        return DNN_ERROR;
-    input.height = frame->height;
-    input.width = frame->width;
+    const TF_Operation *init_op;
 
-    // Input operation
-    tf_model->input.oper = TF_GraphOperationByName(tf_model->graph, input_name);
-    if (!tf_model->input.oper){
-        av_log(ctx, AV_LOG_ERROR, "Could not find \"%s\" in model\n", input_name);
+    graph_def = read_graph(model_filename);
+    if (!graph_def){
+        av_log(ctx, AV_LOG_ERROR, "Failed to read model \"%s\" graph\n", model_filename);
         return DNN_ERROR;
     }
-    tf_model->input.index = 0;
-    if (tf_model->input_tensor){
-        TF_DeleteTensor(tf_model->input_tensor);
-    }
-    tf_model->input_tensor = allocate_input_tensor(&input);
-    if (!tf_model->input_tensor){
-        av_log(ctx, AV_LOG_ERROR, "Failed to allocate memory for input tensor\n");
+    tf_model->graph = TF_NewGraph();
+    tf_model->status = TF_NewStatus();
+    graph_opts = TF_NewImportGraphDefOptions();
+    TF_GraphImportGraphDef(tf_model->graph, graph_def, graph_opts, tf_model->status);
+    TF_DeleteImportGraphDefOptions(graph_opts);
+    TF_DeleteBuffer(graph_def);
+    if (TF_GetCode(tf_model->status) != TF_OK){
+        TF_DeleteGraph(tf_model->graph);
+        TF_DeleteStatus(tf_model->status);
+        av_log(ctx, AV_LOG_ERROR, "Failed to import serialized graph to model graph\n");
         return DNN_ERROR;
     }
-    input.data = (float *)TF_TensorData(tf_model->input_tensor);
-
-    if (tf_model->model->pre_proc != NULL) {
-        tf_model->model->pre_proc(frame, &input, tf_model->model->userdata);
-    } else {
-        proc_from_frame_to_dnn(frame, &input, ctx);
-    }
-
-    // session
-    if (tf_model->session){
-        TF_CloseSession(tf_model->session, tf_model->status);
-        TF_DeleteSession(tf_model->session, tf_model->status);
-    }
 
+    init_op = TF_GraphOperationByName(tf_model->graph, "init");
     sess_opts = TF_NewSessionOptions();
     tf_model->session = TF_NewSession(tf_model->graph, sess_opts, tf_model->status);
     TF_DeleteSessionOptions(sess_opts);
@@ -219,33 +202,6 @@ static DNNReturnType set_input_tf(void *model, AVFrame *frame, const char *input
     return DNN_SUCCESS;
 }
 
-static DNNReturnType load_tf_model(TFModel *tf_model, const char *model_filename)
-{
-    TFContext *ctx = &tf_model->ctx;
-    TF_Buffer *graph_def;
-    TF_ImportGraphDefOptions *graph_opts;
-
-    graph_def = read_graph(model_filename);
-    if (!graph_def){
-        av_log(ctx, AV_LOG_ERROR, "Failed to read model \"%s\" graph\n", model_filename);
-        return DNN_ERROR;
-    }
-    tf_model->graph = TF_NewGraph();
-    tf_model->status = TF_NewStatus();
-    graph_opts = TF_NewImportGraphDefOptions();
-    TF_GraphImportGraphDef(tf_model->graph, graph_def, graph_opts, tf_model->status);
-    TF_DeleteImportGraphDefOptions(graph_opts);
-    TF_DeleteBuffer(graph_def);
-    if (TF_GetCode(tf_model->status) != TF_OK){
-        TF_DeleteGraph(tf_model->graph);
-        TF_DeleteStatus(tf_model->status);
-        av_log(ctx, AV_LOG_ERROR, "Failed to import serialized graph to model graph\n");
-        return DNN_ERROR;
-    }
-
-    return DNN_SUCCESS;
-}
-
 #define NAME_BUFFER_SIZE 256
 
 static DNNReturnType add_conv_layer(TFModel *tf_model, TF_Operation *transpose_op, TF_Operation **cur_op,
@@ -626,7 +582,6 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename, const char *options,
     }
 
     model->model = (void *)tf_model;
-    model->set_input = &set_input_tf;
     model->get_input = &get_input_tf;
     model->options = options;
     model->userdata = userdata;
@@ -634,13 +589,40 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename, const char *options,
     return model;
 }
 
-DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame)
+DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, const char *input_name, AVFrame *in_frame,
+                                      const char **output_names, uint32_t nb_output, AVFrame *out_frame)
 {
     TF_Output *tf_outputs;
     TFModel *tf_model = (TFModel *)model->model;
     TFContext *ctx = &tf_model->ctx;
-    DNNData output;
+    DNNData input, output;
     TF_Tensor **output_tensors;
+    TF_Output tf_input;
+    TF_Tensor *input_tensor;
+
+    if (get_input_tf(tf_model, &input, input_name) != DNN_SUCCESS)
+        return DNN_ERROR;
+    input.height = in_frame->height;
+    input.width = in_frame->width;
+
+    tf_input.oper = TF_GraphOperationByName(tf_model->graph, input_name);
+    if (!tf_input.oper){
+        av_log(ctx, AV_LOG_ERROR, "Could not find \"%s\" in model\n", input_name);
+        return DNN_ERROR;
+    }
+    tf_input.index = 0;
+    input_tensor = allocate_input_tensor(&input);
+    if (!input_tensor){
+        av_log(ctx, AV_LOG_ERROR, "Failed to allocate memory for input tensor\n");
+        return DNN_ERROR;
+    }
+    input.data = (float *)TF_TensorData(input_tensor);
+
+    if (tf_model->model->pre_proc != NULL) {
+        tf_model->model->pre_proc(in_frame, &input, tf_model->model->userdata);
+    } else {
+        proc_from_frame_to_dnn(in_frame, &input, ctx);
+    }
 
     if (nb_output != 1) {
         // currently, the filter does not need multiple outputs,
@@ -674,7 +656,7 @@ DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, const char **output
     }
 
     TF_SessionRun(tf_model->session, NULL,
-                  &tf_model->input, &tf_model->input_tensor, 1,
+                  &tf_input, &input_tensor, 1,
                   tf_outputs, output_tensors, nb_output,
                   NULL, 0, NULL, tf_model->status);
     if (TF_GetCode(tf_model->status) != TF_OK) {
@@ -708,6 +690,7 @@ DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, const char **output
             TF_DeleteTensor(output_tensors[i]);
         }
     }
+    TF_DeleteTensor(input_tensor);
     av_freep(&output_tensors);
     av_freep(&tf_outputs);
     return DNN_SUCCESS;
@@ -729,9 +712,6 @@ void ff_dnn_free_model_tf(DNNModel **model)
         if (tf_model->status){
             TF_DeleteStatus(tf_model->status);
         }
-        if (tf_model->input_tensor){
-            TF_DeleteTensor(tf_model->input_tensor);
-        }
         av_freep(&tf_model);
         av_freep(model);
     }
index f379e83d8d6cdc1f1694c07786354932f88440ff..1e006697360eeba3b414e0ed8ec96748678352d9 100644 (file)
@@ -31,7 +31,8 @@
 
 DNNModel *ff_dnn_load_model_tf(const char *model_filename, const char *options, void *userdata);
 
-DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame);
+DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, const char *input_name, AVFrame *in_frame,
+                                      const char **output_names, uint32_t nb_output, AVFrame *out_frame);
 
 void ff_dnn_free_model_tf(DNNModel **model);
 
index 6debc506078295da569d3b88de0d26e2d675ffd1..0369ee4f716ecf0c420fc6badf22bb613710f127 100644 (file)
@@ -51,9 +51,6 @@ typedef struct DNNModel{
     // Gets model input information
     // Just reuse struct DNNData here, actually the DNNData.data field is not needed.
     DNNReturnType (*get_input)(void *model, DNNData *input, const char *input_name);
-    // Sets model input.
-    // Should be called every time before model execution.
-    DNNReturnType (*set_input)(void *model, AVFrame *frame, const char *input_name);
     // set the pre process to transfer data from AVFrame to DNNData
     // the default implementation within DNN is used if it is not provided by the filter
     int (*pre_proc)(AVFrame *frame_in, DNNData *model_input, void *user_data);
@@ -66,8 +63,9 @@ typedef struct DNNModel{
 typedef struct DNNModule{
     // Loads model and parameters from given file. Returns NULL if it is not possible.
     DNNModel *(*load_model)(const char *model_filename, const char *options, void *userdata);
-    // Executes model with specified output. Returns DNN_ERROR otherwise.
-    DNNReturnType (*execute_model)(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame);
+    // Executes model with specified input and output. Returns DNN_ERROR otherwise.
+    DNNReturnType (*execute_model)(const DNNModel *model, const char *input_name, AVFrame *in_frame,
+                                   const char **output_names, uint32_t nb_output, AVFrame *out_frame);
     // Frees memory allocated for model.
     void (*free_model)(DNNModel **model);
 } DNNModule;
index a59cd6e941ad28cb8fabc6842dfa9589db169857..77dd401263608063b35d7af7fffcaaa653a48edd 100644 (file)
@@ -80,13 +80,6 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in)
     const char *model_output_name = "y";
     AVFrame *out;
 
-    dnn_result = (dr_context->model->set_input)(dr_context->model->model, in, "x");
-    if (dnn_result != DNN_SUCCESS) {
-        av_log(ctx, AV_LOG_ERROR, "could not set input for the model\n");
-        av_frame_free(&in);
-        return AVERROR(EIO);
-    }
-
     out = ff_get_video_buffer(outlink, outlink->w, outlink->h);
     if (!out) {
         av_log(ctx, AV_LOG_ERROR, "could not allocate memory for output frame\n");
@@ -95,7 +88,7 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in)
     }
     av_frame_copy_props(out, in);
 
-    dnn_result = (dr_context->dnn_module->execute_model)(dr_context->model, &model_output_name, 1, out);
+    dnn_result = (dr_context->dnn_module->execute_model)(dr_context->model, "x", in, &model_output_name, 1, out);
     if (dnn_result != DNN_SUCCESS){
         av_log(ctx, AV_LOG_ERROR, "failed to execute model\n");
         av_frame_free(&in);
index d7462bc82811ff38a9e03b619fec06696bd590f6..2c8578c9b0626caa5289917744692ac2766c71aa 100644 (file)
@@ -236,15 +236,11 @@ static int config_output(AVFilterLink *outlink)
     AVFrame *out = NULL;
 
     AVFrame *fake_in = ff_get_video_buffer(inlink, inlink->w, inlink->h);
-    result = (ctx->model->set_input)(ctx->model->model, fake_in, ctx->model_inputname);
-    if (result != DNN_SUCCESS) {
-        av_log(ctx, AV_LOG_ERROR, "could not set input for the model\n");
-        return AVERROR(EIO);
-    }
 
     // have a try run in case that the dnn model resize the frame
     out = ff_get_video_buffer(inlink, inlink->w, inlink->h);
-    result = (ctx->dnn_module->execute_model)(ctx->model, (const char **)&ctx->model_outputname, 1, out);
+    result = (ctx->dnn_module->execute_model)(ctx->model, ctx->model_inputname, fake_in,
+                                              (const char **)&ctx->model_outputname, 1, out);
     if (result != DNN_SUCCESS){
         av_log(ctx, AV_LOG_ERROR, "failed to execute model\n");
         return AVERROR(EIO);
@@ -293,13 +289,6 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in)
     DNNReturnType dnn_result;
     AVFrame *out;
 
-    dnn_result = (ctx->model->set_input)(ctx->model->model, in, ctx->model_inputname);
-    if (dnn_result != DNN_SUCCESS) {
-        av_log(ctx, AV_LOG_ERROR, "could not set input for the model\n");
-        av_frame_free(&in);
-        return AVERROR(EIO);
-    }
-
     out = ff_get_video_buffer(outlink, outlink->w, outlink->h);
     if (!out) {
         av_frame_free(&in);
@@ -307,7 +296,8 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in)
     }
     av_frame_copy_props(out, in);
 
-    dnn_result = (ctx->dnn_module->execute_model)(ctx->model, (const char **)&ctx->model_outputname, 1, out);
+    dnn_result = (ctx->dnn_module->execute_model)(ctx->model, ctx->model_inputname, in,
+                                                  (const char **)&ctx->model_outputname, 1, out);
     if (dnn_result != DNN_SUCCESS){
         av_log(ctx, AV_LOG_ERROR, "failed to execute model\n");
         av_frame_free(&in);
index 2eda8c3219c29e46e9aa68ad3a1f3812dde289cc..72a3137262733a8b94495132d31259ec965bae4a 100644 (file)
@@ -114,16 +114,11 @@ static int config_output(AVFilterLink *outlink)
     AVFrame *out = NULL;
     const char *model_output_name = "y";
 
-    AVFrame *fake_in = ff_get_video_buffer(inlink, inlink->w, inlink->h);
-    result = (ctx->model->set_input)(ctx->model->model, fake_in, "x");
-    if (result != DNN_SUCCESS) {
-        av_log(context, AV_LOG_ERROR, "could not set input for the model\n");
-        return AVERROR(EIO);
-    }
-
     // have a try run in case that the dnn model resize the frame
+    AVFrame *fake_in = ff_get_video_buffer(inlink, inlink->w, inlink->h);
     out = ff_get_video_buffer(inlink, inlink->w, inlink->h);
-    result = (ctx->dnn_module->execute_model)(ctx->model, (const char **)&model_output_name, 1, out);
+    result = (ctx->dnn_module->execute_model)(ctx->model, "x", fake_in,
+                                              (const char **)&model_output_name, 1, out);
     if (result != DNN_SUCCESS){
         av_log(context, AV_LOG_ERROR, "failed to execute loaded model\n");
         return AVERROR(EIO);
@@ -178,19 +173,13 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in)
         sws_scale(ctx->sws_pre_scale,
                     (const uint8_t **)in->data, in->linesize, 0, in->height,
                     out->data, out->linesize);
-        dnn_result = (ctx->model->set_input)(ctx->model->model, out, "x");
+        dnn_result = (ctx->dnn_module->execute_model)(ctx->model, "x", out,
+                                                      (const char **)&model_output_name, 1, out);
     } else {
-        dnn_result = (ctx->model->set_input)(ctx->model->model, in, "x");
-    }
-
-    if (dnn_result != DNN_SUCCESS) {
-        av_frame_free(&in);
-        av_frame_free(&out);
-        av_log(context, AV_LOG_ERROR, "could not set input for the model\n");
-        return AVERROR(EIO);
+        dnn_result = (ctx->dnn_module->execute_model)(ctx->model, "x", in,
+                                                      (const char **)&model_output_name, 1, out);
     }
 
-    dnn_result = (ctx->dnn_module->execute_model)(ctx->model, (const char **)&model_output_name, 1, out);
     if (dnn_result != DNN_SUCCESS){
         av_log(ctx, AV_LOG_ERROR, "failed to execute loaded model\n");
         av_frame_free(&in);