]> git.sesse.net Git - ffmpeg/commitdiff
dnn: add function type for model
authorGuo, Yejun <yejun.guo@intel.com>
Sun, 7 Feb 2021 06:35:22 +0000 (14:35 +0800)
committerGuo, Yejun <yejun.guo@intel.com>
Thu, 18 Feb 2021 01:59:37 +0000 (09:59 +0800)
So the backend knows the usage of model is for frame processing,
detect, classify, etc. Each function type has different behavior
in backend when handling the input/output data of the model.

Signed-off-by: Guo, Yejun <yejun.guo@intel.com>
12 files changed:
libavfilter/dnn/dnn_backend_native.c
libavfilter/dnn/dnn_backend_native.h
libavfilter/dnn/dnn_backend_openvino.c
libavfilter/dnn/dnn_backend_openvino.h
libavfilter/dnn/dnn_backend_tf.c
libavfilter/dnn/dnn_backend_tf.h
libavfilter/dnn_filter_common.c
libavfilter/dnn_filter_common.h
libavfilter/dnn_interface.h
libavfilter/vf_derain.c
libavfilter/vf_dnn_processing.c
libavfilter/vf_sr.c

index 87f3568cc21fecabe218ba85359032b9922674f2..be6451367a32e3dceec997ff2695db20df58aaab 100644 (file)
@@ -112,7 +112,7 @@ static DNNReturnType get_output_native(void *model, const char *input_name, int
 // layers_num,layer_type,layer_parameterss,layer_type,layer_parameters...
 // For CONV layer: activation_function, input_num, output_num, kernel_size, kernel, biases
 // For DEPTH_TO_SPACE layer: block_size
-DNNModel *ff_dnn_load_model_native(const char *model_filename, const char *options, AVFilterContext *filter_ctx)
+DNNModel *ff_dnn_load_model_native(const char *model_filename, DNNFunctionType func_type, const char *options, AVFilterContext *filter_ctx)
 {
     DNNModel *model = NULL;
     char header_expected[] = "FFMPEGDNNNATIVE";
@@ -256,6 +256,7 @@ DNNModel *ff_dnn_load_model_native(const char *model_filename, const char *optio
     model->get_input = &get_input_native;
     model->get_output = &get_output_native;
     model->filter_ctx = filter_ctx;
+    model->func_type = func_type;
 
     return model;
 
index 5c8ce82b35a0dd0f415910c32e6627180b5820c2..d313c48f3a47a01d96c6cf607063307441defd89 100644 (file)
@@ -128,7 +128,7 @@ typedef struct NativeModel{
     int32_t operands_num;
 } NativeModel;
 
-DNNModel *ff_dnn_load_model_native(const char *model_filename, const char *options, AVFilterContext *filter_ctx);
+DNNModel *ff_dnn_load_model_native(const char *model_filename, DNNFunctionType func_type, const char *options, AVFilterContext *filter_ctx);
 
 DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, const char *input_name, AVFrame *in_frame,
                                           const char **output_names, uint32_t nb_output, AVFrame *out_frame);
index ed41b721fc54fd3be0d2d18218da695843a584ad..7c1abb3eebcb01d518e0e1e41c025436cf9a6736 100644 (file)
@@ -524,7 +524,7 @@ static DNNReturnType get_output_ov(void *model, const char *input_name, int inpu
     return ret;
 }
 
-DNNModel *ff_dnn_load_model_ov(const char *model_filename, const char *options, AVFilterContext *filter_ctx)
+DNNModel *ff_dnn_load_model_ov(const char *model_filename, DNNFunctionType func_type, const char *options, AVFilterContext *filter_ctx)
 {
     DNNModel *model = NULL;
     OVModel *ov_model = NULL;
@@ -572,6 +572,7 @@ DNNModel *ff_dnn_load_model_ov(const char *model_filename, const char *options,
     model->get_output = &get_output_ov;
     model->options = options;
     model->filter_ctx = filter_ctx;
+    model->func_type = func_type;
 
     return model;
 
index 23b819440eb81c10a33594baaf70082b1cdccffa..a484a7be32a75a72841ce9d7a808e62a4bf95a3a 100644 (file)
@@ -29,7 +29,7 @@
 
 #include "../dnn_interface.h"
 
-DNNModel *ff_dnn_load_model_ov(const char *model_filename, const char *options, AVFilterContext *filter_ctx);
+DNNModel *ff_dnn_load_model_ov(const char *model_filename, DNNFunctionType func_type, const char *options, AVFilterContext *filter_ctx);
 
 DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, const char *input_name, AVFrame *in_frame,
                                       const char **output_names, uint32_t nb_output, AVFrame *out_frame);
index 71a2a308b5376d86adc2728361bae731526b487b..e7e5f221f385d1f23a83a48bba77dca48da97430 100644 (file)
@@ -580,7 +580,7 @@ static DNNReturnType load_native_model(TFModel *tf_model, const char *model_file
     DNNModel *model = NULL;
     NativeModel *native_model;
 
-    model = ff_dnn_load_model_native(model_filename, NULL, NULL);
+    model = ff_dnn_load_model_native(model_filename, DFT_PROCESS_FRAME, NULL, NULL);
     if (!model){
         av_log(ctx, AV_LOG_ERROR, "Failed to load native model\n");
         return DNN_ERROR;
@@ -664,7 +664,7 @@ static DNNReturnType load_native_model(TFModel *tf_model, const char *model_file
     return DNN_SUCCESS;
 }
 
-DNNModel *ff_dnn_load_model_tf(const char *model_filename, const char *options, AVFilterContext *filter_ctx)
+DNNModel *ff_dnn_load_model_tf(const char *model_filename, DNNFunctionType func_type, const char *options, AVFilterContext *filter_ctx)
 {
     DNNModel *model = NULL;
     TFModel *tf_model = NULL;
@@ -705,6 +705,7 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename, const char *options,
     model->get_output = &get_output_tf;
     model->options = options;
     model->filter_ctx = filter_ctx;
+    model->func_type = func_type;
 
     return model;
 }
index cac893672954497b500002f1df12ad0cc4fc4621..8cec04748e72e40b7ac59a5626c966e431e5312b 100644 (file)
@@ -29,7 +29,7 @@
 
 #include "../dnn_interface.h"
 
-DNNModel *ff_dnn_load_model_tf(const char *model_filename, const char *options, AVFilterContext *filter_ctx);
+DNNModel *ff_dnn_load_model_tf(const char *model_filename, DNNFunctionType func_type, const char *options, AVFilterContext *filter_ctx);
 
 DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, const char *input_name, AVFrame *in_frame,
                                       const char **output_names, uint32_t nb_output, AVFrame *out_frame);
index 5d0d7d3b906c25711fcc427b106c5cc98945889a..413adba4069d232f9bfe9801e74b0db88a0a412e 100644 (file)
@@ -18,7 +18,7 @@
 
 #include "dnn_filter_common.h"
 
-int ff_dnn_init(DnnContext *ctx, AVFilterContext *filter_ctx)
+int ff_dnn_init(DnnContext *ctx, DNNFunctionType func_type, AVFilterContext *filter_ctx)
 {
     if (!ctx->model_filename) {
         av_log(filter_ctx, AV_LOG_ERROR, "model file for network is not specified\n");
@@ -43,7 +43,7 @@ int ff_dnn_init(DnnContext *ctx, AVFilterContext *filter_ctx)
         return AVERROR(EINVAL);
     }
 
-    ctx->model = (ctx->dnn_module->load_model)(ctx->model_filename, ctx->backend_options, filter_ctx);
+    ctx->model = (ctx->dnn_module->load_model)(ctx->model_filename, func_type, ctx->backend_options, filter_ctx);
     if (!ctx->model) {
         av_log(filter_ctx, AV_LOG_ERROR, "could not load DNN model\n");
         return AVERROR(EINVAL);
index ab49a992ed0d66af2ff9397e35f845d6f26157dd..79c4d3efe35ed8ec7a6253648501c90f9e8644ec 100644 (file)
@@ -47,7 +47,7 @@ typedef struct DnnContext {
     { "async",              "use DNN async inference",    OFFSET(async),            AV_OPT_TYPE_BOOL,      { .i64 = 1},     0, 1, FLAGS},
 
 
-int ff_dnn_init(DnnContext *ctx, AVFilterContext *filter_ctx);
+int ff_dnn_init(DnnContext *ctx, DNNFunctionType func_type, AVFilterContext *filter_ctx);
 DNNReturnType ff_dnn_get_input(DnnContext *ctx, DNNData *input);
 DNNReturnType ff_dnn_get_output(DnnContext *ctx, int input_width, int input_height, int *output_width, int *output_height);
 DNNReturnType ff_dnn_execute_model(DnnContext *ctx, AVFrame *in_frame, AVFrame *out_frame);
index ff338ea084e1b67ca848fe108fa5b37563596d1d..2fb9b15676157570b52e69118af22dccb8d86d4e 100644 (file)
@@ -43,6 +43,13 @@ typedef enum {
     DAST_SUCCESS            // got a result frame successfully
 } DNNAsyncStatusType;
 
+typedef enum {
+    DFT_NONE,
+    DFT_PROCESS_FRAME,      // process the whole frame
+    DFT_ANALYTICS_DETECT,   // detect from the whole frame
+    // we can add more such as detect_from_crop, classify_from_bbox, etc.
+}DNNFunctionType;
+
 typedef struct DNNData{
     void *data;
     DNNDataType dt;
@@ -56,6 +63,8 @@ typedef struct DNNModel{
     const char *options;
     // Stores FilterContext used for the interaction between AVFrame and DNNData
     AVFilterContext *filter_ctx;
+    // Stores function type of the model
+    DNNFunctionType func_type;
     // Gets model input information
     // Just reuse struct DNNData here, actually the DNNData.data field is not needed.
     DNNReturnType (*get_input)(void *model, DNNData *input, const char *input_name);
@@ -73,7 +82,7 @@ typedef struct DNNModel{
 // Stores pointers to functions for loading, executing, freeing DNN models for one of the backends.
 typedef struct DNNModule{
     // Loads model and parameters from given file. Returns NULL if it is not possible.
-    DNNModel *(*load_model)(const char *model_filename, const char *options, AVFilterContext *filter_ctx);
+    DNNModel *(*load_model)(const char *model_filename, DNNFunctionType func_type, const char *options, AVFilterContext *filter_ctx);
     // Executes model with specified input and output. Returns DNN_ERROR otherwise.
     DNNReturnType (*execute_model)(const DNNModel *model, const char *input_name, AVFrame *in_frame,
                                    const char **output_names, uint32_t nb_output, AVFrame *out_frame);
index ec9853d95751c5de8886a729062faec2c77c20dc..7814fc1e03de5ec2ba211cf4dea73283cfc17173 100644 (file)
@@ -100,7 +100,7 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in)
 static av_cold int init(AVFilterContext *ctx)
 {
     DRContext *dr_context = ctx->priv;
-    return ff_dnn_init(&dr_context->dnnctx, ctx);
+    return ff_dnn_init(&dr_context->dnnctx, DFT_PROCESS_FRAME, ctx);
 }
 
 static av_cold void uninit(AVFilterContext *ctx)
index 08ebf122c923f420fa63c2e28431fbd8dd6e217b..88e95e8ae3a5917a64b2724923b0009c411246a2 100644 (file)
@@ -62,7 +62,7 @@ AVFILTER_DEFINE_CLASS(dnn_processing);
 static av_cold int init(AVFilterContext *context)
 {
     DnnProcessingContext *ctx = context->priv;
-    return ff_dnn_init(&ctx->dnnctx, context);
+    return ff_dnn_init(&ctx->dnnctx, DFT_PROCESS_FRAME, context);
 }
 
 static int query_formats(AVFilterContext *context)
index 20334a84c49ee2ea60d87c3bfbc807a08734f7a9..45f941acdb2268e08bc0398e69ead7f42b2938ba 100644 (file)
@@ -63,7 +63,7 @@ AVFILTER_DEFINE_CLASS(sr);
 static av_cold int init(AVFilterContext *context)
 {
     SRContext *sr_context = context->priv;
-    return ff_dnn_init(&sr_context->dnnctx, context);
+    return ff_dnn_init(&sr_context->dnnctx, DFT_PROCESS_FRAME, context);
 }
 
 static int query_formats(AVFilterContext *context)