]> git.sesse.net Git - ffmpeg/blob - libavfilter/vf_dnn_processing.c
avfilter/vf_dnn_processing: refine code for better naming
[ffmpeg] / libavfilter / vf_dnn_processing.c
1 /*
2  * Copyright (c) 2019 Guo Yejun
3  *
4  * This file is part of FFmpeg.
5  *
6  * FFmpeg is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2.1 of the License, or (at your option) any later version.
10  *
11  * FFmpeg is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with FFmpeg; if not, write to the Free Software
18  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
19  */
20
21 /**
22  * @file
23  * implementing a generic image processing filter using deep learning networks.
24  */
25
26 #include "libavformat/avio.h"
27 #include "libavutil/opt.h"
28 #include "libavutil/pixdesc.h"
29 #include "libavutil/avassert.h"
30 #include "avfilter.h"
31 #include "dnn_interface.h"
32 #include "formats.h"
33 #include "internal.h"
34
35 typedef struct DnnProcessingContext {
36     const AVClass *class;
37
38     char *model_filename;
39     DNNBackendType backend_type;
40     enum AVPixelFormat fmt;
41     char *model_inputname;
42     char *model_outputname;
43
44     DNNModule *dnn_module;
45     DNNModel *model;
46
47     // input & output of the model at execution time
48     DNNData input;
49     DNNData output;
50 } DnnProcessingContext;
51
52 #define OFFSET(x) offsetof(DnnProcessingContext, x)
53 #define FLAGS AV_OPT_FLAG_FILTERING_PARAM | AV_OPT_FLAG_VIDEO_PARAM
54 static const AVOption dnn_processing_options[] = {
55     { "dnn_backend", "DNN backend",                OFFSET(backend_type),     AV_OPT_TYPE_INT,       { .i64 = 0 },    0, 1, FLAGS, "backend" },
56     { "native",      "native backend flag",        0,                        AV_OPT_TYPE_CONST,     { .i64 = 0 },    0, 0, FLAGS, "backend" },
57 #if (CONFIG_LIBTENSORFLOW == 1)
58     { "tensorflow",  "tensorflow backend flag",    0,                        AV_OPT_TYPE_CONST,     { .i64 = 1 },    0, 0, FLAGS, "backend" },
59 #endif
60     { "model",       "path to model file",         OFFSET(model_filename),   AV_OPT_TYPE_STRING,    { .str = NULL }, 0, 0, FLAGS },
61     { "input",       "input name of the model",    OFFSET(model_inputname),  AV_OPT_TYPE_STRING,    { .str = NULL }, 0, 0, FLAGS },
62     { "output",      "output name of the model",   OFFSET(model_outputname), AV_OPT_TYPE_STRING,    { .str = NULL }, 0, 0, FLAGS },
63     { "fmt",         "AVPixelFormat of the frame", OFFSET(fmt),              AV_OPT_TYPE_PIXEL_FMT, { .i64=AV_PIX_FMT_RGB24 }, AV_PIX_FMT_NONE, AV_PIX_FMT_NB - 1, FLAGS },
64     { NULL }
65 };
66
67 AVFILTER_DEFINE_CLASS(dnn_processing);
68
69 static av_cold int init(AVFilterContext *context)
70 {
71     DnnProcessingContext *ctx = context->priv;
72     int supported = 0;
73     // as the first step, only rgb24 and bgr24 are supported
74     const enum AVPixelFormat supported_pixel_fmts[] = {
75         AV_PIX_FMT_RGB24,
76         AV_PIX_FMT_BGR24,
77     };
78     for (int i = 0; i < sizeof(supported_pixel_fmts) / sizeof(enum AVPixelFormat); ++i) {
79         if (supported_pixel_fmts[i] == ctx->fmt) {
80             supported = 1;
81             break;
82         }
83     }
84     if (!supported) {
85         av_log(context, AV_LOG_ERROR, "pixel fmt %s not supported yet\n",
86                                        av_get_pix_fmt_name(ctx->fmt));
87         return AVERROR(AVERROR_INVALIDDATA);
88     }
89
90     if (!ctx->model_filename) {
91         av_log(ctx, AV_LOG_ERROR, "model file for network is not specified\n");
92         return AVERROR(EINVAL);
93     }
94     if (!ctx->model_inputname) {
95         av_log(ctx, AV_LOG_ERROR, "input name of the model network is not specified\n");
96         return AVERROR(EINVAL);
97     }
98     if (!ctx->model_outputname) {
99         av_log(ctx, AV_LOG_ERROR, "output name of the model network is not specified\n");
100         return AVERROR(EINVAL);
101     }
102
103     ctx->dnn_module = ff_get_dnn_module(ctx->backend_type);
104     if (!ctx->dnn_module) {
105         av_log(ctx, AV_LOG_ERROR, "could not create DNN module for requested backend\n");
106         return AVERROR(ENOMEM);
107     }
108     if (!ctx->dnn_module->load_model) {
109         av_log(ctx, AV_LOG_ERROR, "load_model for network is not specified\n");
110         return AVERROR(EINVAL);
111     }
112
113     ctx->model = (ctx->dnn_module->load_model)(ctx->model_filename);
114     if (!ctx->model) {
115         av_log(ctx, AV_LOG_ERROR, "could not load DNN model\n");
116         return AVERROR(EINVAL);
117     }
118
119     return 0;
120 }
121
122 static int query_formats(AVFilterContext *context)
123 {
124     AVFilterFormats *formats;
125     DnnProcessingContext *ctx = context->priv;
126     enum AVPixelFormat pixel_fmts[2];
127     pixel_fmts[0] = ctx->fmt;
128     pixel_fmts[1] = AV_PIX_FMT_NONE;
129
130     formats = ff_make_format_list(pixel_fmts);
131     return ff_set_common_formats(context, formats);
132 }
133
134 static int config_input(AVFilterLink *inlink)
135 {
136     AVFilterContext *context     = inlink->dst;
137     DnnProcessingContext *ctx = context->priv;
138     DNNReturnType result;
139     DNNData model_input;
140
141     result = ctx->model->get_input(ctx->model->model, &model_input, ctx->model_inputname);
142     if (result != DNN_SUCCESS) {
143         av_log(ctx, AV_LOG_ERROR, "could not get input from the model\n");
144         return AVERROR(EIO);
145     }
146
147     // the design is to add explicit scale filter before this filter
148     if (model_input.height != -1 && model_input.height != inlink->h) {
149         av_log(ctx, AV_LOG_ERROR, "the model requires frame height %d but got %d\n",
150                                    model_input.height, inlink->h);
151         return AVERROR(EIO);
152     }
153     if (model_input.width != -1 && model_input.width != inlink->w) {
154         av_log(ctx, AV_LOG_ERROR, "the model requires frame width %d but got %d\n",
155                                    model_input.width, inlink->w);
156         return AVERROR(EIO);
157     }
158
159     if (model_input.channels != 3) {
160         av_log(ctx, AV_LOG_ERROR, "the model requires input channels %d\n",
161                                    model_input.channels);
162         return AVERROR(EIO);
163     }
164     if (model_input.dt != DNN_FLOAT && model_input.dt != DNN_UINT8) {
165         av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type as float32 and uint8.\n");
166         return AVERROR(EIO);
167     }
168
169     ctx->input.width    = inlink->w;
170     ctx->input.height   = inlink->h;
171     ctx->input.channels = model_input.channels;
172     ctx->input.dt = model_input.dt;
173
174     result = (ctx->model->set_input_output)(ctx->model->model,
175                                         &ctx->input, ctx->model_inputname,
176                                         (const char **)&ctx->model_outputname, 1);
177     if (result != DNN_SUCCESS) {
178         av_log(ctx, AV_LOG_ERROR, "could not set input and output for the model\n");
179         return AVERROR(EIO);
180     }
181
182     return 0;
183 }
184
185 static int config_output(AVFilterLink *outlink)
186 {
187     AVFilterContext *context = outlink->src;
188     DnnProcessingContext *ctx = context->priv;
189     DNNReturnType result;
190
191     // have a try run in case that the dnn model resize the frame
192     result = (ctx->dnn_module->execute_model)(ctx->model, &ctx->output, 1);
193     if (result != DNN_SUCCESS){
194         av_log(ctx, AV_LOG_ERROR, "failed to execute model\n");
195         return AVERROR(EIO);
196     }
197
198     outlink->w = ctx->output.width;
199     outlink->h = ctx->output.height;
200
201     return 0;
202 }
203
204 static int copy_from_frame_to_dnn(DNNData *dnn_input, const AVFrame *frame)
205 {
206     // extend this function to support more formats
207     av_assert0(frame->format == AV_PIX_FMT_RGB24 || frame->format == AV_PIX_FMT_BGR24);
208
209     if (dnn_input->dt == DNN_FLOAT) {
210         float *dnn_input_data = dnn_input->data;
211         for (int i = 0; i < frame->height; i++) {
212             for(int j = 0; j < frame->width * 3; j++) {
213                 int k = i * frame->linesize[0] + j;
214                 int t = i * frame->width * 3 + j;
215                 dnn_input_data[t] = frame->data[0][k] / 255.0f;
216             }
217         }
218     } else {
219         uint8_t *dnn_input_data = dnn_input->data;
220         av_assert0(dnn_input->dt == DNN_UINT8);
221         for (int i = 0; i < frame->height; i++) {
222             for(int j = 0; j < frame->width * 3; j++) {
223                 int k = i * frame->linesize[0] + j;
224                 int t = i * frame->width * 3 + j;
225                 dnn_input_data[t] = frame->data[0][k];
226             }
227         }
228     }
229
230     return 0;
231 }
232
233 static int copy_from_dnn_to_frame(AVFrame *frame, const DNNData *dnn_output)
234 {
235     // extend this function to support more formats
236     av_assert0(frame->format == AV_PIX_FMT_RGB24 || frame->format == AV_PIX_FMT_BGR24);
237
238     if (dnn_output->dt == DNN_FLOAT) {
239         float *dnn_output_data = dnn_output->data;
240         for (int i = 0; i < frame->height; i++) {
241             for(int j = 0; j < frame->width * 3; j++) {
242                 int k = i * frame->linesize[0] + j;
243                 int t = i * frame->width * 3 + j;
244                 frame->data[0][k] = av_clip_uintp2((int)(dnn_output_data[t] * 255.0f), 8);
245             }
246         }
247     } else {
248         uint8_t *dnn_output_data = dnn_output->data;
249         av_assert0(dnn_output->dt == DNN_UINT8);
250         for (int i = 0; i < frame->height; i++) {
251             for(int j = 0; j < frame->width * 3; j++) {
252                 int k = i * frame->linesize[0] + j;
253                 int t = i * frame->width * 3 + j;
254                 frame->data[0][k] = dnn_output_data[t];
255             }
256         }
257     }
258
259     return 0;
260 }
261
262 static int filter_frame(AVFilterLink *inlink, AVFrame *in)
263 {
264     AVFilterContext *context  = inlink->dst;
265     AVFilterLink *outlink = context->outputs[0];
266     DnnProcessingContext *ctx = context->priv;
267     DNNReturnType dnn_result;
268     AVFrame *out;
269
270     copy_from_frame_to_dnn(&ctx->input, in);
271
272     dnn_result = (ctx->dnn_module->execute_model)(ctx->model, &ctx->output, 1);
273     if (dnn_result != DNN_SUCCESS){
274         av_log(ctx, AV_LOG_ERROR, "failed to execute model\n");
275         av_frame_free(&in);
276         return AVERROR(EIO);
277     }
278     av_assert0(ctx->output.channels == 3);
279
280     out = ff_get_video_buffer(outlink, outlink->w, outlink->h);
281     if (!out) {
282         av_frame_free(&in);
283         return AVERROR(ENOMEM);
284     }
285
286     av_frame_copy_props(out, in);
287     copy_from_dnn_to_frame(out, &ctx->output);
288     av_frame_free(&in);
289     return ff_filter_frame(outlink, out);
290 }
291
292 static av_cold void uninit(AVFilterContext *ctx)
293 {
294     DnnProcessingContext *context = ctx->priv;
295
296     if (context->dnn_module)
297         (context->dnn_module->free_model)(&context->model);
298
299     av_freep(&context->dnn_module);
300 }
301
302 static const AVFilterPad dnn_processing_inputs[] = {
303     {
304         .name         = "default",
305         .type         = AVMEDIA_TYPE_VIDEO,
306         .config_props = config_input,
307         .filter_frame = filter_frame,
308     },
309     { NULL }
310 };
311
312 static const AVFilterPad dnn_processing_outputs[] = {
313     {
314         .name = "default",
315         .type = AVMEDIA_TYPE_VIDEO,
316         .config_props  = config_output,
317     },
318     { NULL }
319 };
320
321 AVFilter ff_vf_dnn_processing = {
322     .name          = "dnn_processing",
323     .description   = NULL_IF_CONFIG_SMALL("Apply DNN processing filter to the input."),
324     .priv_size     = sizeof(DnnProcessingContext),
325     .init          = init,
326     .uninit        = uninit,
327     .query_formats = query_formats,
328     .inputs        = dnn_processing_inputs,
329     .outputs       = dnn_processing_outputs,
330     .priv_class    = &dnn_processing_class,
331 };