/*
 * Copyright (c) 2019 Guo Yejun
 *
 * This file is part of FFmpeg.
 *
 * FFmpeg is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * FFmpeg is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with FFmpeg; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 */

/**
 * @file
 * implementing a generic image processing filter using deep learning networks.
 */
#include "libavformat/avio.h"
#include "libavutil/opt.h"
#include "libavutil/pixdesc.h"
#include "libavutil/avassert.h"
#include "avfilter.h"
#include "dnn_interface.h"
#include "formats.h"
#include "internal.h"
35 typedef struct DnnProcessingContext {
39 DNNBackendType backend_type;
40 enum AVPixelFormat fmt;
41 char *model_inputname;
42 char *model_outputname;
44 DNNModule *dnn_module;
47 // input & output of the model at execution time
50 } DnnProcessingContext;
52 #define OFFSET(x) offsetof(DnnProcessingContext, x)
53 #define FLAGS AV_OPT_FLAG_FILTERING_PARAM | AV_OPT_FLAG_VIDEO_PARAM
54 static const AVOption dnn_processing_options[] = {
55 { "dnn_backend", "DNN backend", OFFSET(backend_type), AV_OPT_TYPE_INT, { .i64 = 0 }, 0, 1, FLAGS, "backend" },
56 { "native", "native backend flag", 0, AV_OPT_TYPE_CONST, { .i64 = 0 }, 0, 0, FLAGS, "backend" },
57 #if (CONFIG_LIBTENSORFLOW == 1)
58 { "tensorflow", "tensorflow backend flag", 0, AV_OPT_TYPE_CONST, { .i64 = 1 }, 0, 0, FLAGS, "backend" },
60 { "model", "path to model file", OFFSET(model_filename), AV_OPT_TYPE_STRING, { .str = NULL }, 0, 0, FLAGS },
61 { "input", "input name of the model", OFFSET(model_inputname), AV_OPT_TYPE_STRING, { .str = NULL }, 0, 0, FLAGS },
62 { "output", "output name of the model", OFFSET(model_outputname), AV_OPT_TYPE_STRING, { .str = NULL }, 0, 0, FLAGS },
63 { "fmt", "AVPixelFormat of the frame", OFFSET(fmt), AV_OPT_TYPE_PIXEL_FMT, { .i64=AV_PIX_FMT_RGB24 }, AV_PIX_FMT_NONE, AV_PIX_FMT_NB - 1, FLAGS },
67 AVFILTER_DEFINE_CLASS(dnn_processing);
69 static av_cold int init(AVFilterContext *context)
71 DnnProcessingContext *ctx = context->priv;
73 // as the first step, only rgb24 and bgr24 are supported
74 const enum AVPixelFormat supported_pixel_fmts[] = {
78 for (int i = 0; i < sizeof(supported_pixel_fmts) / sizeof(enum AVPixelFormat); ++i) {
79 if (supported_pixel_fmts[i] == ctx->fmt) {
85 av_log(context, AV_LOG_ERROR, "pixel fmt %s not supported yet\n",
86 av_get_pix_fmt_name(ctx->fmt));
87 return AVERROR(AVERROR_INVALIDDATA);
90 if (!ctx->model_filename) {
91 av_log(ctx, AV_LOG_ERROR, "model file for network is not specified\n");
92 return AVERROR(EINVAL);
94 if (!ctx->model_inputname) {
95 av_log(ctx, AV_LOG_ERROR, "input name of the model network is not specified\n");
96 return AVERROR(EINVAL);
98 if (!ctx->model_outputname) {
99 av_log(ctx, AV_LOG_ERROR, "output name of the model network is not specified\n");
100 return AVERROR(EINVAL);
103 ctx->dnn_module = ff_get_dnn_module(ctx->backend_type);
104 if (!ctx->dnn_module) {
105 av_log(ctx, AV_LOG_ERROR, "could not create DNN module for requested backend\n");
106 return AVERROR(ENOMEM);
108 if (!ctx->dnn_module->load_model) {
109 av_log(ctx, AV_LOG_ERROR, "load_model for network is not specified\n");
110 return AVERROR(EINVAL);
113 ctx->model = (ctx->dnn_module->load_model)(ctx->model_filename);
115 av_log(ctx, AV_LOG_ERROR, "could not load DNN model\n");
116 return AVERROR(EINVAL);
122 static int query_formats(AVFilterContext *context)
124 AVFilterFormats *formats;
125 DnnProcessingContext *ctx = context->priv;
126 enum AVPixelFormat pixel_fmts[2];
127 pixel_fmts[0] = ctx->fmt;
128 pixel_fmts[1] = AV_PIX_FMT_NONE;
130 formats = ff_make_format_list(pixel_fmts);
131 return ff_set_common_formats(context, formats);
134 static int config_input(AVFilterLink *inlink)
136 AVFilterContext *context = inlink->dst;
137 DnnProcessingContext *ctx = context->priv;
138 DNNReturnType result;
141 result = ctx->model->get_input(ctx->model->model, &model_input, ctx->model_inputname);
142 if (result != DNN_SUCCESS) {
143 av_log(ctx, AV_LOG_ERROR, "could not get input from the model\n");
147 // the design is to add explicit scale filter before this filter
148 if (model_input.height != -1 && model_input.height != inlink->h) {
149 av_log(ctx, AV_LOG_ERROR, "the model requires frame height %d but got %d\n",
150 model_input.height, inlink->h);
153 if (model_input.width != -1 && model_input.width != inlink->w) {
154 av_log(ctx, AV_LOG_ERROR, "the model requires frame width %d but got %d\n",
155 model_input.width, inlink->w);
159 if (model_input.channels != 3) {
160 av_log(ctx, AV_LOG_ERROR, "the model requires input channels %d\n",
161 model_input.channels);
164 if (model_input.dt != DNN_FLOAT && model_input.dt != DNN_UINT8) {
165 av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type as float32 and uint8.\n");
169 ctx->input.width = inlink->w;
170 ctx->input.height = inlink->h;
171 ctx->input.channels = model_input.channels;
172 ctx->input.dt = model_input.dt;
174 result = (ctx->model->set_input_output)(ctx->model->model,
175 &ctx->input, ctx->model_inputname,
176 (const char **)&ctx->model_outputname, 1);
177 if (result != DNN_SUCCESS) {
178 av_log(ctx, AV_LOG_ERROR, "could not set input and output for the model\n");
185 static int config_output(AVFilterLink *outlink)
187 AVFilterContext *context = outlink->src;
188 DnnProcessingContext *ctx = context->priv;
189 DNNReturnType result;
191 // have a try run in case that the dnn model resize the frame
192 result = (ctx->dnn_module->execute_model)(ctx->model, &ctx->output, 1);
193 if (result != DNN_SUCCESS){
194 av_log(ctx, AV_LOG_ERROR, "failed to execute model\n");
198 outlink->w = ctx->output.width;
199 outlink->h = ctx->output.height;
204 static int copy_from_frame_to_dnn(DNNData *dnn_input, const AVFrame *frame)
206 // extend this function to support more formats
207 av_assert0(frame->format == AV_PIX_FMT_RGB24 || frame->format == AV_PIX_FMT_BGR24);
209 if (dnn_input->dt == DNN_FLOAT) {
210 float *dnn_input_data = dnn_input->data;
211 for (int i = 0; i < frame->height; i++) {
212 for(int j = 0; j < frame->width * 3; j++) {
213 int k = i * frame->linesize[0] + j;
214 int t = i * frame->width * 3 + j;
215 dnn_input_data[t] = frame->data[0][k] / 255.0f;
219 uint8_t *dnn_input_data = dnn_input->data;
220 av_assert0(dnn_input->dt == DNN_UINT8);
221 for (int i = 0; i < frame->height; i++) {
222 for(int j = 0; j < frame->width * 3; j++) {
223 int k = i * frame->linesize[0] + j;
224 int t = i * frame->width * 3 + j;
225 dnn_input_data[t] = frame->data[0][k];
233 static int copy_from_dnn_to_frame(AVFrame *frame, const DNNData *dnn_output)
235 // extend this function to support more formats
236 av_assert0(frame->format == AV_PIX_FMT_RGB24 || frame->format == AV_PIX_FMT_BGR24);
238 if (dnn_output->dt == DNN_FLOAT) {
239 float *dnn_output_data = dnn_output->data;
240 for (int i = 0; i < frame->height; i++) {
241 for(int j = 0; j < frame->width * 3; j++) {
242 int k = i * frame->linesize[0] + j;
243 int t = i * frame->width * 3 + j;
244 frame->data[0][k] = av_clip_uintp2((int)(dnn_output_data[t] * 255.0f), 8);
248 uint8_t *dnn_output_data = dnn_output->data;
249 av_assert0(dnn_output->dt == DNN_UINT8);
250 for (int i = 0; i < frame->height; i++) {
251 for(int j = 0; j < frame->width * 3; j++) {
252 int k = i * frame->linesize[0] + j;
253 int t = i * frame->width * 3 + j;
254 frame->data[0][k] = dnn_output_data[t];
262 static int filter_frame(AVFilterLink *inlink, AVFrame *in)
264 AVFilterContext *context = inlink->dst;
265 AVFilterLink *outlink = context->outputs[0];
266 DnnProcessingContext *ctx = context->priv;
267 DNNReturnType dnn_result;
270 copy_from_frame_to_dnn(&ctx->input, in);
272 dnn_result = (ctx->dnn_module->execute_model)(ctx->model, &ctx->output, 1);
273 if (dnn_result != DNN_SUCCESS){
274 av_log(ctx, AV_LOG_ERROR, "failed to execute model\n");
278 av_assert0(ctx->output.channels == 3);
280 out = ff_get_video_buffer(outlink, outlink->w, outlink->h);
283 return AVERROR(ENOMEM);
286 av_frame_copy_props(out, in);
287 copy_from_dnn_to_frame(out, &ctx->output);
289 return ff_filter_frame(outlink, out);
292 static av_cold void uninit(AVFilterContext *ctx)
294 DnnProcessingContext *context = ctx->priv;
296 if (context->dnn_module)
297 (context->dnn_module->free_model)(&context->model);
299 av_freep(&context->dnn_module);
302 static const AVFilterPad dnn_processing_inputs[] = {
305 .type = AVMEDIA_TYPE_VIDEO,
306 .config_props = config_input,
307 .filter_frame = filter_frame,
312 static const AVFilterPad dnn_processing_outputs[] = {
315 .type = AVMEDIA_TYPE_VIDEO,
316 .config_props = config_output,
321 AVFilter ff_vf_dnn_processing = {
322 .name = "dnn_processing",
323 .description = NULL_IF_CONFIG_SMALL("Apply DNN processing filter to the input."),
324 .priv_size = sizeof(DnnProcessingContext),
327 .query_formats = query_formats,
328 .inputs = dnn_processing_inputs,
329 .outputs = dnn_processing_outputs,
330 .priv_class = &dnn_processing_class,