2 * Copyright (c) 2019 Guo Yejun
4 * This file is part of FFmpeg.
6 * FFmpeg is free software; you can redistribute it and/or
7 * modify it under the terms of the GNU Lesser General Public
8 * License as published by the Free Software Foundation; either
9 * version 2.1 of the License, or (at your option) any later version.
11 * FFmpeg is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 * Lesser General Public License for more details.
16 * You should have received a copy of the GNU Lesser General Public
17 * License along with FFmpeg; if not, write to the Free Software
18 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
23 * implementing a generic image processing filter using deep learning networks.
26 #include "libavformat/avio.h"
27 #include "libavutil/opt.h"
28 #include "libavutil/pixdesc.h"
29 #include "libavutil/avassert.h"
30 #include "libavutil/imgutils.h"
32 #include "dnn_interface.h"
35 #include "libswscale/swscale.h"
/**
 * Private context of the dnn_processing filter: user options, the DNN
 * backend module/model handles, and libswscale contexts used to convert
 * planes between uint8 and float32 sample formats.
 * NOTE(review): several fields (e.g. the AVClass pointer, model_filename,
 * model, and the input/output DNNData members referenced elsewhere in this
 * file) are not visible in this view of the source.
 */
typedef struct DnnProcessingContext {
    DNNBackendType backend_type;   // selected via the "dnn_backend" option (0 = native)
    char *model_inputname;         // name of the model's input tensor ("input" option)
    char *model_outputname;        // name of the model's output tensor ("output" option)
    DNNModule *dnn_module;         // backend entry points (load_model/execute_model/free_model)
    // input & output of the model at execution time
    struct SwsContext *sws_gray8_to_grayf32;   // converts uint8 planes to float32 for the model
    struct SwsContext *sws_grayf32_to_gray8;   // converts float32 model output back to uint8
} DnnProcessingContext;
#define OFFSET(x) offsetof(DnnProcessingContext, x)
#define FLAGS AV_OPT_FLAG_FILTERING_PARAM | AV_OPT_FLAG_VIDEO_PARAM
// AVOption table: backend selection plus model file / tensor names.
// The "tensorflow" constant is only registered when libtensorflow is
// compiled in (the matching #endif is not visible in this view).
static const AVOption dnn_processing_options[] = {
    { "dnn_backend", "DNN backend", OFFSET(backend_type), AV_OPT_TYPE_INT, { .i64 = 0 }, 0, 1, FLAGS, "backend" },
    { "native", "native backend flag", 0, AV_OPT_TYPE_CONST, { .i64 = 0 }, 0, 0, FLAGS, "backend" },
#if (CONFIG_LIBTENSORFLOW == 1)
    { "tensorflow", "tensorflow backend flag", 0, AV_OPT_TYPE_CONST, { .i64 = 1 }, 0, 0, FLAGS, "backend" },
    { "model", "path to model file", OFFSET(model_filename), AV_OPT_TYPE_STRING, { .str = NULL }, 0, 0, FLAGS },
    { "input", "input name of the model", OFFSET(model_inputname), AV_OPT_TYPE_STRING, { .str = NULL }, 0, 0, FLAGS },
    { "output", "output name of the model", OFFSET(model_outputname), AV_OPT_TYPE_STRING, { .str = NULL }, 0, 0, FLAGS },

AVFILTER_DEFINE_CLASS(dnn_processing);
/**
 * Filter init callback: validate that the mandatory "model", "input" and
 * "output" options were supplied, create the DNN backend module for the
 * requested backend type, and load the model from file.
 *
 * @return 0 on success (return statement not visible in this view),
 *         AVERROR(EINVAL) on missing options or load failure,
 *         AVERROR(ENOMEM) if the backend module cannot be created.
 */
static av_cold int init(AVFilterContext *context)
    DnnProcessingContext *ctx = context->priv;

    // all three model-related options are mandatory
    if (!ctx->model_filename) {
        av_log(ctx, AV_LOG_ERROR, "model file for network is not specified\n");
        return AVERROR(EINVAL);
    if (!ctx->model_inputname) {
        av_log(ctx, AV_LOG_ERROR, "input name of the model network is not specified\n");
        return AVERROR(EINVAL);
    if (!ctx->model_outputname) {
        av_log(ctx, AV_LOG_ERROR, "output name of the model network is not specified\n");
        return AVERROR(EINVAL);

    // instantiate the backend (native or tensorflow) chosen via options
    ctx->dnn_module = ff_get_dnn_module(ctx->backend_type);
    if (!ctx->dnn_module) {
        av_log(ctx, AV_LOG_ERROR, "could not create DNN module for requested backend\n");
        return AVERROR(ENOMEM);
    if (!ctx->dnn_module->load_model) {
        av_log(ctx, AV_LOG_ERROR, "load_model for network is not specified\n");
        return AVERROR(EINVAL);

    // load the model; failure check on ctx->model not visible in this view
    ctx->model = (ctx->dnn_module->load_model)(ctx->model_filename);
        av_log(ctx, AV_LOG_ERROR, "could not load DNN model\n");
        return AVERROR(EINVAL);
/**
 * Advertise the pixel formats this filter accepts: packed RGB/BGR,
 * 8-bit and float gray, and several planar YUV layouts (list may be
 * truncated in this view).
 */
static int query_formats(AVFilterContext *context)
    static const enum AVPixelFormat pix_fmts[] = {
        AV_PIX_FMT_RGB24, AV_PIX_FMT_BGR24,
        AV_PIX_FMT_GRAY8, AV_PIX_FMT_GRAYF32,
        AV_PIX_FMT_YUV420P, AV_PIX_FMT_YUV422P,
        AV_PIX_FMT_YUV444P, AV_PIX_FMT_YUV410P, AV_PIX_FMT_YUV411P,
    AVFilterFormats *fmts_list = ff_make_format_list(pix_fmts);
    return ff_set_common_formats(context, fmts_list);
// Helper macro: log a mismatch between the inlink pixel format and the
// channel count the model input expects. Relies on `ctx`, `fmt` and
// `model_input` being in scope at the expansion site.
#define LOG_FORMAT_CHANNEL_MISMATCH()                       \
    av_log(ctx, AV_LOG_ERROR,                               \
           "the frame's format %s does not match "          \
           "the model input channel %d\n",                  \
           av_get_pix_fmt_name(fmt),                        \
           model_input->channels);

/**
 * Check that the input link's geometry and pixel format are compatible
 * with what the model input expects (height/width of -1 means the model
 * accepts any size). Per-format rules: RGB/BGR need 3 channels with
 * float32 or uint8 data; GRAY8 needs 1 channel uint8; GRAYF32 and the
 * planar YUV formats need 1 channel float32 (only the luma plane is fed
 * to the model).
 * Error-return statements are not visible in this view of the source.
 */
static int check_modelinput_inlink(const DNNData *model_input, const AVFilterLink *inlink)
    AVFilterContext *ctx   = inlink->dst;
    enum AVPixelFormat fmt = inlink->format;

    // the design is to add explicit scale filter before this filter
    if (model_input->height != -1 && model_input->height != inlink->h) {
        av_log(ctx, AV_LOG_ERROR, "the model requires frame height %d but got %d\n",
                                   model_input->height, inlink->h);
    if (model_input->width != -1 && model_input->width != inlink->w) {
        av_log(ctx, AV_LOG_ERROR, "the model requires frame width %d but got %d\n",
                                   model_input->width, inlink->w);

    case AV_PIX_FMT_RGB24:
    case AV_PIX_FMT_BGR24:
        if (model_input->channels != 3) {
            LOG_FORMAT_CHANNEL_MISMATCH();
        if (model_input->dt != DNN_FLOAT && model_input->dt != DNN_UINT8) {
            av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type as float32 and uint8.\n");
    case AV_PIX_FMT_GRAY8:
        if (model_input->channels != 1) {
            LOG_FORMAT_CHANNEL_MISMATCH();
        if (model_input->dt != DNN_UINT8) {
            av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type uint8.\n");
    case AV_PIX_FMT_GRAYF32:
    case AV_PIX_FMT_YUV420P:
    case AV_PIX_FMT_YUV422P:
    case AV_PIX_FMT_YUV444P:
    case AV_PIX_FMT_YUV410P:
    case AV_PIX_FMT_YUV411P:
        if (model_input->channels != 1) {
            LOG_FORMAT_CHANNEL_MISMATCH();
        if (model_input->dt != DNN_FLOAT) {
            av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type float32.\n");
        // presumably the default case of a switch on fmt — switch header not visible
        av_log(ctx, AV_LOG_ERROR, "%s not supported.\n", av_get_pix_fmt_name(fmt));
/**
 * Input-link config callback: query the model's input description, verify
 * it matches the inlink (via check_modelinput_inlink), then bind the
 * model's input/output tensors by name with the actual frame geometry.
 * Error-return statements are not visible in this view of the source.
 */
static int config_input(AVFilterLink *inlink)
    AVFilterContext *context     = inlink->dst;
    DnnProcessingContext *ctx = context->priv;
    DNNReturnType result;

    // ask the backend what the named input tensor looks like
    result = ctx->model->get_input(ctx->model->model, &model_input, ctx->model_inputname);
    if (result != DNN_SUCCESS) {
        av_log(ctx, AV_LOG_ERROR, "could not get input from the model\n");

    check = check_modelinput_inlink(&model_input, inlink);

    // model may accept any size (-1); use the actual inlink geometry
    ctx->input.width    = inlink->w;
    ctx->input.height   = inlink->h;
    ctx->input.channels = model_input.channels;
    ctx->input.dt       = model_input.dt;

    result = (ctx->model->set_input_output)(ctx->model->model,
                                        &ctx->input, ctx->model_inputname,
                                        (const char **)&ctx->model_outputname, 1);
    if (result != DNN_SUCCESS) {
        av_log(ctx, AV_LOG_ERROR, "could not set input and output for the model\n");
/**
 * Create the libswscale contexts used to convert between uint8 frame
 * planes and the float32 data the model consumes/produces.
 * - RGB24/BGR24: the packed line is treated as a gray plane of width w*3,
 *   converted only when the model side uses DNN_FLOAT.
 * - Planar YUV: both directions are always float (asserted); only the
 *   luma plane is converted, input at inlink size, output at outlink size.
 * Several sws_getContext() argument lines (src/dst formats, heights) are
 * not visible in this view of the source.
 */
static int prepare_sws_context(AVFilterLink *outlink)
    AVFilterContext *context = outlink->src;
    DnnProcessingContext *ctx = context->priv;
    AVFilterLink *inlink = context->inputs[0];
    enum AVPixelFormat fmt = inlink->format;
    DNNDataType input_dt  = ctx->input.dt;
    DNNDataType output_dt = ctx->output.dt;

    case AV_PIX_FMT_RGB24:
    case AV_PIX_FMT_BGR24:
        if (input_dt == DNN_FLOAT) {
            // width * 3: one gray8 sample per R/G/B byte of the packed line
            ctx->sws_gray8_to_grayf32 = sws_getContext(inlink->w * 3,
                                                       0, NULL, NULL, NULL);
        if (output_dt == DNN_FLOAT) {
            ctx->sws_grayf32_to_gray8 = sws_getContext(outlink->w * 3,
                                                       0, NULL, NULL, NULL);
    case AV_PIX_FMT_YUV420P:
    case AV_PIX_FMT_YUV422P:
    case AV_PIX_FMT_YUV444P:
    case AV_PIX_FMT_YUV410P:
    case AV_PIX_FMT_YUV411P:
        // YUV path only supports float models in both directions
        av_assert0(input_dt == DNN_FLOAT);
        av_assert0(output_dt == DNN_FLOAT);
        ctx->sws_gray8_to_grayf32 = sws_getContext(inlink->w,
                                                   0, NULL, NULL, NULL);
        ctx->sws_grayf32_to_gray8 = sws_getContext(outlink->w,
                                                   0, NULL, NULL, NULL);
/**
 * Output-link config callback: run the model once on the (empty) input so
 * the backend reports the real output geometry — the model may resize the
 * frame — then size the outlink accordingly and build the sws contexts.
 * Error-return statement and final return are not visible in this view.
 */
static int config_output(AVFilterLink *outlink)
    AVFilterContext *context = outlink->src;
    DnnProcessingContext *ctx = context->priv;
    DNNReturnType result;

    // have a try run in case that the dnn model resize the frame
    result = (ctx->dnn_module->execute_model)(ctx->model, &ctx->output, 1);
    if (result != DNN_SUCCESS){
        av_log(ctx, AV_LOG_ERROR, "failed to execute model\n");

    // adopt whatever geometry the model produced
    outlink->w = ctx->output.width;
    outlink->h = ctx->output.height;

    prepare_sws_context(outlink);
/**
 * Fill the model's input buffer from an incoming frame.
 * - RGB/BGR + float model: sws-convert the packed plane (treated as gray,
 *   width*3) to float32; uint8 model: plain byte copy of plane 0.
 * - GRAY8/GRAYF32: plain byte copy of plane 0.
 * - Planar YUV: sws-convert only the luma plane to float32; chroma is
 *   handled separately (see copy_uv_planes).
 * `break;` statements between cases are not visible in this view.
 */
static int copy_from_frame_to_dnn(DnnProcessingContext *ctx, const AVFrame *frame)
    int bytewidth = av_image_get_linesize(frame->format, frame->width, 0);
    DNNData *dnn_input = &ctx->input;

    switch (frame->format) {
    case AV_PIX_FMT_RGB24:
    case AV_PIX_FMT_BGR24:
        if (dnn_input->dt == DNN_FLOAT) {
            sws_scale(ctx->sws_gray8_to_grayf32, (const uint8_t **)frame->data, frame->linesize,
                      0, frame->height, (uint8_t * const*)(&dnn_input->data),
                      (const int [4]){frame->linesize[0] * sizeof(float), 0, 0, 0});
            // uint8 model input: raw copy, no conversion needed
            av_assert0(dnn_input->dt == DNN_UINT8);
            av_image_copy_plane(dnn_input->data, bytewidth,
                                frame->data[0], frame->linesize[0],
                                bytewidth, frame->height);
    case AV_PIX_FMT_GRAY8:
    case AV_PIX_FMT_GRAYF32:
        av_image_copy_plane(dnn_input->data, bytewidth,
                            frame->data[0], frame->linesize[0],
                            bytewidth, frame->height);
    case AV_PIX_FMT_YUV420P:
    case AV_PIX_FMT_YUV422P:
    case AV_PIX_FMT_YUV444P:
    case AV_PIX_FMT_YUV410P:
    case AV_PIX_FMT_YUV411P:
        // only the Y plane goes through the model
        sws_scale(ctx->sws_gray8_to_grayf32, (const uint8_t **)frame->data, frame->linesize,
                  0, frame->height, (uint8_t * const*)(&dnn_input->data),
                  (const int [4]){frame->width * sizeof(float), 0, 0, 0});
/**
 * Write the model's output buffer back into an output frame — the mirror
 * of copy_from_frame_to_dnn:
 * - RGB/BGR + float output: sws-convert float32 back to the packed uint8
 *   plane; uint8 output: plain byte copy.
 * - GRAY8: must be uint8 (asserted); GRAYF32: must be float (asserted).
 * - Planar YUV: convert only the luma plane; chroma is copied elsewhere.
 * `break;` statements between cases are not visible in this view.
 */
static int copy_from_dnn_to_frame(DnnProcessingContext *ctx, AVFrame *frame)
    int bytewidth = av_image_get_linesize(frame->format, frame->width, 0);
    DNNData *dnn_output = &ctx->output;

    switch (frame->format) {
    case AV_PIX_FMT_RGB24:
    case AV_PIX_FMT_BGR24:
        if (dnn_output->dt == DNN_FLOAT) {
            sws_scale(ctx->sws_grayf32_to_gray8, (const uint8_t *[4]){(const uint8_t *)dnn_output->data, 0, 0, 0},
                      (const int[4]){frame->linesize[0] * sizeof(float), 0, 0, 0},
                      0, frame->height, (uint8_t * const*)frame->data, frame->linesize);
            av_assert0(dnn_output->dt == DNN_UINT8);
            av_image_copy_plane(frame->data[0], frame->linesize[0],
                                dnn_output->data, bytewidth,
                                bytewidth, frame->height);
    case AV_PIX_FMT_GRAY8:
        // it is possible that data type of dnn output is float32,
        // need to add support for such case when needed.
        av_assert0(dnn_output->dt == DNN_UINT8);
        av_image_copy_plane(frame->data[0], frame->linesize[0],
                            dnn_output->data, bytewidth,
                            bytewidth, frame->height);
    case AV_PIX_FMT_GRAYF32:
        av_assert0(dnn_output->dt == DNN_FLOAT);
        av_image_copy_plane(frame->data[0], frame->linesize[0],
                            dnn_output->data, bytewidth,
                            bytewidth, frame->height);
    case AV_PIX_FMT_YUV420P:
    case AV_PIX_FMT_YUV422P:
    case AV_PIX_FMT_YUV444P:
    case AV_PIX_FMT_YUV410P:
    case AV_PIX_FMT_YUV411P:
        // luma plane only; chroma planes are copied from the input frame
        sws_scale(ctx->sws_grayf32_to_gray8, (const uint8_t *[4]){(const uint8_t *)dnn_output->data, 0, 0, 0},
                  (const int[4]){frame->width * sizeof(float), 0, 0, 0},
                  0, frame->height, (uint8_t * const*)frame->data, frame->linesize);
/**
 * Return nonzero if pix_fmt is a 3-component non-RGB (i.e. planar YUV)
 * format — used to decide whether chroma planes must be copied through.
 */
static av_always_inline int isPlanarYUV(enum AVPixelFormat pix_fmt)
    const AVPixFmtDescriptor *desc = av_pix_fmt_desc_get(pix_fmt);
    return !(desc->flags & AV_PIX_FMT_FLAG_RGB) && desc->nb_components == 3;
/**
 * Copy the U and V planes (indices 1 and 2) from the input frame to the
 * output frame unchanged — the model only processes the luma plane.
 * Chroma height is derived from the format's log2_chroma_h shift.
 * Return statement not visible in this view of the source.
 */
static int copy_uv_planes(DnnProcessingContext *ctx, AVFrame *out, const AVFrame *in)
    const AVPixFmtDescriptor *desc = av_pix_fmt_desc_get(in->format);
    int uv_height = AV_CEIL_RSHIFT(in->height, desc->log2_chroma_h);
    for (int i = 1; i < 3; ++i) {
        int bytewidth = av_image_get_linesize(in->format, in->width, i);
        av_image_copy_plane(out->data[i], out->linesize[i],
                            in->data[i], in->linesize[i],
                            bytewidth, uv_height);
/**
 * Per-frame processing: feed the input frame into the model, execute it,
 * allocate an output frame at the outlink size, copy the model output
 * (plus untouched chroma planes for planar YUV) into it, and push it on.
 * The declaration of `out`, error cleanup and av_frame_free calls are not
 * visible in this view of the source.
 */
static int filter_frame(AVFilterLink *inlink, AVFrame *in)
    AVFilterContext *context  = inlink->dst;
    AVFilterLink *outlink = context->outputs[0];
    DnnProcessingContext *ctx = context->priv;
    DNNReturnType dnn_result;

    copy_from_frame_to_dnn(ctx, in);

    dnn_result = (ctx->dnn_module->execute_model)(ctx->model, &ctx->output, 1);
    if (dnn_result != DNN_SUCCESS){
        av_log(ctx, AV_LOG_ERROR, "failed to execute model\n");

    out = ff_get_video_buffer(outlink, outlink->w, outlink->h);
        return AVERROR(ENOMEM);

    av_frame_copy_props(out, in);
    copy_from_dnn_to_frame(ctx, out);

    // model only touched luma; carry the original chroma through
    if (isPlanarYUV(in->format))
        copy_uv_planes(ctx, out, in);

    return ff_filter_frame(outlink, out);
/**
 * Filter uninit callback: release both sws contexts (sws_freeContext is
 * NULL-safe), free the model through the backend, then free the module
 * itself. dnn_module is checked because init may have failed early.
 */
static av_cold void uninit(AVFilterContext *ctx)
    DnnProcessingContext *context = ctx->priv;

    sws_freeContext(context->sws_gray8_to_grayf32);
    sws_freeContext(context->sws_grayf32_to_gray8);

    if (context->dnn_module)
        (context->dnn_module->free_model)(&context->model);

    av_freep(&context->dnn_module);
// Single video input pad: validates/binds the model on link configuration
// and processes each incoming frame.
static const AVFilterPad dnn_processing_inputs[] = {
        .type         = AVMEDIA_TYPE_VIDEO,
        .config_props = config_input,
        .filter_frame = filter_frame,
// Single video output pad: sized from the model's output on configuration.
static const AVFilterPad dnn_processing_outputs[] = {
        .type         = AVMEDIA_TYPE_VIDEO,
        .config_props = config_output,
// Filter registration entry; init/uninit callback assignments are not
// visible in this view of the source.
AVFilter ff_vf_dnn_processing = {
    .name          = "dnn_processing",
    .description   = NULL_IF_CONFIG_SMALL("Apply DNN processing filter to the input."),
    .priv_size     = sizeof(DnnProcessingContext),
    .query_formats = query_formats,
    .inputs        = dnn_processing_inputs,
    .outputs       = dnn_processing_outputs,
    .priv_class    = &dnn_processing_class,