]> git.sesse.net Git - ffmpeg/blobdiff - libavfilter/vf_sr.c
avformat/avio: Add Metacube support
[ffmpeg] / libavfilter / vf_sr.c
index fe6c5d3c0da509fb4d4426bd18695848c40e0972..4360439ca6b21b4b933e4424693832f9464ce9fe 100644 (file)
 #include "libavutil/pixdesc.h"
 #include "libavformat/avio.h"
 #include "libswscale/swscale.h"
-#include "dnn_interface.h"
+#include "dnn_filter_common.h"
 
 typedef struct SRContext {
     const AVClass *class;
-
-    char *model_filename;
-    DNNBackendType backend_type;
-    DNNModule *dnn_module;
-    DNNModel *model;
+    DnnContext dnnctx;
     int scale_factor;
     struct SwsContext *sws_uv_scale;
     int sws_uv_height;
@@ -50,13 +46,15 @@ typedef struct SRContext {
 #define OFFSET(x) offsetof(SRContext, x)
 #define FLAGS AV_OPT_FLAG_FILTERING_PARAM | AV_OPT_FLAG_VIDEO_PARAM
 static const AVOption sr_options[] = {
-    { "dnn_backend", "DNN backend used for model execution", OFFSET(backend_type), AV_OPT_TYPE_INT, { .i64 = 0 }, 0, 1, FLAGS, "backend" },
+    { "dnn_backend", "DNN backend used for model execution", OFFSET(dnnctx.backend_type), AV_OPT_TYPE_INT, { .i64 = 0 }, 0, 1, FLAGS, "backend" },
     { "native", "native backend flag", 0, AV_OPT_TYPE_CONST, { .i64 = 0 }, 0, 0, FLAGS, "backend" },
 #if (CONFIG_LIBTENSORFLOW == 1)
     { "tensorflow", "tensorflow backend flag", 0, AV_OPT_TYPE_CONST, { .i64 = 1 }, 0, 0, FLAGS, "backend" },
 #endif
     { "scale_factor", "scale factor for SRCNN model", OFFSET(scale_factor), AV_OPT_TYPE_INT, { .i64 = 2 }, 2, 4, FLAGS },
-    { "model", "path to model file specifying network architecture and its parameters", OFFSET(model_filename), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, FLAGS },
+    { "model", "path to model file specifying network architecture and its parameters", OFFSET(dnnctx.model_filename), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, FLAGS },
+    { "input",       "input name of the model",     OFFSET(dnnctx.model_inputname),  AV_OPT_TYPE_STRING,    { .str = "x" },  0, 0, FLAGS },
+    { "output",      "output name of the model",    OFFSET(dnnctx.model_outputname), AV_OPT_TYPE_STRING,    { .str = "y" },  0, 0, FLAGS },
     { NULL }
 };
 
@@ -65,28 +63,7 @@ AVFILTER_DEFINE_CLASS(sr);
 static av_cold int init(AVFilterContext *context)
 {
     SRContext *sr_context = context->priv;
-
-    sr_context->dnn_module = ff_get_dnn_module(sr_context->backend_type);
-    if (!sr_context->dnn_module){
-        av_log(context, AV_LOG_ERROR, "could not create DNN module for requested backend\n");
-        return AVERROR(ENOMEM);
-    }
-
-    if (!sr_context->model_filename){
-        av_log(context, AV_LOG_ERROR, "model file for network was not specified\n");
-        return AVERROR(EIO);
-    }
-    if (!sr_context->dnn_module->load_model) {
-        av_log(context, AV_LOG_ERROR, "load_model for network was not specified\n");
-        return AVERROR(EIO);
-    }
-    sr_context->model = (sr_context->dnn_module->load_model)(sr_context->model_filename, NULL, NULL);
-    if (!sr_context->model){
-        av_log(context, AV_LOG_ERROR, "could not load DNN model\n");
-        return AVERROR(EIO);
-    }
-
-    return 0;
+    return ff_dnn_init(&sr_context->dnnctx, DFT_PROCESS_FRAME, context);
 }
 
 static int query_formats(AVFilterContext *context)
@@ -114,8 +91,7 @@ static int config_output(AVFilterLink *outlink)
     int out_width, out_height;
 
     // have a try run in case that the dnn model resize the frame
-    result = ctx->model->get_output(ctx->model->model, "x", inlink->w, inlink->h,
-                                    "y", &out_width, &out_height);
+    result = ff_dnn_get_output(&ctx->dnnctx, inlink->w, inlink->h, &out_width, &out_height);
     if (result != DNN_SUCCESS) {
         av_log(ctx, AV_LOG_ERROR, "could not get output from the model\n");
         return AVERROR(EIO);
@@ -155,7 +131,6 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in)
     AVFilterLink *outlink = context->outputs[0];
     AVFrame *out = ff_get_video_buffer(outlink, outlink->w, outlink->h);
     DNNReturnType dnn_result;
-    const char *model_output_name = "y";
 
     if (!out){
         av_log(context, AV_LOG_ERROR, "could not allocate memory for output frame\n");
@@ -168,11 +143,9 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in)
         sws_scale(ctx->sws_pre_scale,
                     (const uint8_t **)in->data, in->linesize, 0, in->height,
                     out->data, out->linesize);
-        dnn_result = (ctx->dnn_module->execute_model)(ctx->model, "x", out,
-                                                      (const char **)&model_output_name, 1, out);
+        dnn_result = ff_dnn_execute_model(&ctx->dnnctx, out, out);
     } else {
-        dnn_result = (ctx->dnn_module->execute_model)(ctx->model, "x", in,
-                                                      (const char **)&model_output_name, 1, out);
+        dnn_result = ff_dnn_execute_model(&ctx->dnnctx, in, out);
     }
 
     if (dnn_result != DNN_SUCCESS){
@@ -197,11 +170,7 @@ static av_cold void uninit(AVFilterContext *context)
 {
     SRContext *sr_context = context->priv;
 
-    if (sr_context->dnn_module){
-        (sr_context->dnn_module->free_model)(&sr_context->model);
-        av_freep(&sr_context->dnn_module);
-    }
-
+    ff_dnn_uninit(&sr_context->dnnctx);
     sws_freeContext(sr_context->sws_uv_scale);
     sws_freeContext(sr_context->sws_pre_scale);
 }
@@ -224,7 +193,7 @@ static const AVFilterPad sr_outputs[] = {
     { NULL }
 };
 
-AVFilter ff_vf_sr = {
+const AVFilter ff_vf_sr = {
     .name          = "sr",
     .description   = NULL_IF_CONFIG_SMALL("Apply DNN-based image super resolution to the input."),
     .priv_size     = sizeof(SRContext),