]> git.sesse.net Git - ffmpeg/blobdiff - libavfilter/vf_sr.c
avformat/avio: Add Metacube support
[ffmpeg] / libavfilter / vf_sr.c
index 2eda8c3219c29e46e9aa68ad3a1f3812dde289cc..4360439ca6b21b4b933e4424693832f9464ce9fe 100644 (file)
 #include "libavutil/pixdesc.h"
 #include "libavformat/avio.h"
 #include "libswscale/swscale.h"
-#include "dnn_interface.h"
+#include "dnn_filter_common.h"
 
 typedef struct SRContext {
     const AVClass *class;
-
-    char *model_filename;
-    DNNBackendType backend_type;
-    DNNModule *dnn_module;
-    DNNModel *model;
+    DnnContext dnnctx;
     int scale_factor;
     struct SwsContext *sws_uv_scale;
     int sws_uv_height;
@@ -50,13 +46,15 @@ typedef struct SRContext {
 #define OFFSET(x) offsetof(SRContext, x)
 #define FLAGS AV_OPT_FLAG_FILTERING_PARAM | AV_OPT_FLAG_VIDEO_PARAM
 static const AVOption sr_options[] = {
-    { "dnn_backend", "DNN backend used for model execution", OFFSET(backend_type), AV_OPT_TYPE_INT, { .i64 = 0 }, 0, 1, FLAGS, "backend" },
+    { "dnn_backend", "DNN backend used for model execution", OFFSET(dnnctx.backend_type), AV_OPT_TYPE_INT, { .i64 = 0 }, 0, 1, FLAGS, "backend" },
     { "native", "native backend flag", 0, AV_OPT_TYPE_CONST, { .i64 = 0 }, 0, 0, FLAGS, "backend" },
 #if (CONFIG_LIBTENSORFLOW == 1)
     { "tensorflow", "tensorflow backend flag", 0, AV_OPT_TYPE_CONST, { .i64 = 1 }, 0, 0, FLAGS, "backend" },
 #endif
     { "scale_factor", "scale factor for SRCNN model", OFFSET(scale_factor), AV_OPT_TYPE_INT, { .i64 = 2 }, 2, 4, FLAGS },
-    { "model", "path to model file specifying network architecture and its parameters", OFFSET(model_filename), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, FLAGS },
+    { "model", "path to model file specifying network architecture and its parameters", OFFSET(dnnctx.model_filename), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, FLAGS },
+    { "input",       "input name of the model",     OFFSET(dnnctx.model_inputname),  AV_OPT_TYPE_STRING,    { .str = "x" },  0, 0, FLAGS },
+    { "output",      "output name of the model",    OFFSET(dnnctx.model_outputname), AV_OPT_TYPE_STRING,    { .str = "y" },  0, 0, FLAGS },
     { NULL }
 };
 
@@ -65,28 +63,7 @@ AVFILTER_DEFINE_CLASS(sr);
 static av_cold int init(AVFilterContext *context)
 {
     SRContext *sr_context = context->priv;
-
-    sr_context->dnn_module = ff_get_dnn_module(sr_context->backend_type);
-    if (!sr_context->dnn_module){
-        av_log(context, AV_LOG_ERROR, "could not create DNN module for requested backend\n");
-        return AVERROR(ENOMEM);
-    }
-
-    if (!sr_context->model_filename){
-        av_log(context, AV_LOG_ERROR, "model file for network was not specified\n");
-        return AVERROR(EIO);
-    }
-    if (!sr_context->dnn_module->load_model) {
-        av_log(context, AV_LOG_ERROR, "load_model for network was not specified\n");
-        return AVERROR(EIO);
-    }
-    sr_context->model = (sr_context->dnn_module->load_model)(sr_context->model_filename, NULL, NULL);
-    if (!sr_context->model){
-        av_log(context, AV_LOG_ERROR, "could not load DNN model\n");
-        return AVERROR(EIO);
-    }
-
-    return 0;
+    return ff_dnn_init(&sr_context->dnnctx, DFT_PROCESS_FRAME, context);
 }
 
 static int query_formats(AVFilterContext *context)
@@ -111,28 +88,19 @@ static int config_output(AVFilterLink *outlink)
     SRContext *ctx = context->priv;
     DNNReturnType result;
     AVFilterLink *inlink = context->inputs[0];
-    AVFrame *out = NULL;
-    const char *model_output_name = "y";
-
-    AVFrame *fake_in = ff_get_video_buffer(inlink, inlink->w, inlink->h);
-    result = (ctx->model->set_input)(ctx->model->model, fake_in, "x");
-    if (result != DNN_SUCCESS) {
-        av_log(context, AV_LOG_ERROR, "could not set input for the model\n");
-        return AVERROR(EIO);
-    }
+    int out_width, out_height;
 
     // have a try run in case that the dnn model resize the frame
-    out = ff_get_video_buffer(inlink, inlink->w, inlink->h);
-    result = (ctx->dnn_module->execute_model)(ctx->model, (const char **)&model_output_name, 1, out);
-    if (result != DNN_SUCCESS){
-        av_log(context, AV_LOG_ERROR, "failed to execute loaded model\n");
+    result = ff_dnn_get_output(&ctx->dnnctx, inlink->w, inlink->h, &out_width, &out_height);
+    if (result != DNN_SUCCESS) {
+        av_log(ctx, AV_LOG_ERROR, "could not get output from the model\n");
         return AVERROR(EIO);
     }
 
-    if (fake_in->width != out->width || fake_in->height != out->height) {
+    if (inlink->w != out_width || inlink->h != out_height) {
         //espcn
-        outlink->w = out->width;
-        outlink->h = out->height;
+        outlink->w = out_width;
+        outlink->h = out_height;
         if (inlink->format != AV_PIX_FMT_GRAY8){
             const AVPixFmtDescriptor *desc = av_pix_fmt_desc_get(inlink->format);
             int sws_src_h = AV_CEIL_RSHIFT(inlink->h, desc->log2_chroma_h);
@@ -146,15 +114,13 @@ static int config_output(AVFilterLink *outlink)
         }
     } else {
         //srcnn
-        outlink->w = out->width * ctx->scale_factor;
-        outlink->h = out->height * ctx->scale_factor;
+        outlink->w = out_width * ctx->scale_factor;
+        outlink->h = out_height * ctx->scale_factor;
         ctx->sws_pre_scale = sws_getContext(inlink->w, inlink->h, inlink->format,
                                         outlink->w, outlink->h, outlink->format,
                                         SWS_BICUBIC, NULL, NULL, NULL);
     }
 
-    av_frame_free(&fake_in);
-    av_frame_free(&out);
     return 0;
 }
 
@@ -165,7 +131,6 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in)
     AVFilterLink *outlink = context->outputs[0];
     AVFrame *out = ff_get_video_buffer(outlink, outlink->w, outlink->h);
     DNNReturnType dnn_result;
-    const char *model_output_name = "y";
 
     if (!out){
         av_log(context, AV_LOG_ERROR, "could not allocate memory for output frame\n");
@@ -178,19 +143,11 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in)
         sws_scale(ctx->sws_pre_scale,
                     (const uint8_t **)in->data, in->linesize, 0, in->height,
                     out->data, out->linesize);
-        dnn_result = (ctx->model->set_input)(ctx->model->model, out, "x");
+        dnn_result = ff_dnn_execute_model(&ctx->dnnctx, out, out);
     } else {
-        dnn_result = (ctx->model->set_input)(ctx->model->model, in, "x");
-    }
-
-    if (dnn_result != DNN_SUCCESS) {
-        av_frame_free(&in);
-        av_frame_free(&out);
-        av_log(context, AV_LOG_ERROR, "could not set input for the model\n");
-        return AVERROR(EIO);
+        dnn_result = ff_dnn_execute_model(&ctx->dnnctx, in, out);
     }
 
-    dnn_result = (ctx->dnn_module->execute_model)(ctx->model, (const char **)&model_output_name, 1, out);
     if (dnn_result != DNN_SUCCESS){
         av_log(ctx, AV_LOG_ERROR, "failed to execute loaded model\n");
         av_frame_free(&in);
@@ -213,11 +170,7 @@ static av_cold void uninit(AVFilterContext *context)
 {
     SRContext *sr_context = context->priv;
 
-    if (sr_context->dnn_module){
-        (sr_context->dnn_module->free_model)(&sr_context->model);
-        av_freep(&sr_context->dnn_module);
-    }
-
+    ff_dnn_uninit(&sr_context->dnnctx);
     sws_freeContext(sr_context->sws_uv_scale);
     sws_freeContext(sr_context->sws_pre_scale);
 }
@@ -240,7 +193,7 @@ static const AVFilterPad sr_outputs[] = {
     { NULL }
 };
 
-AVFilter ff_vf_sr = {
+const AVFilter ff_vf_sr = {
     .name          = "sr",
     .description   = NULL_IF_CONFIG_SMALL("Apply DNN-based image super resolution to the input."),
     .priv_size     = sizeof(SRContext),