]> git.sesse.net Git - ffmpeg/blob - libavfilter/vf_libvmaf.c
42c6b66b69a538b391a1143a58330a0e28c549ed
[ffmpeg] / libavfilter / vf_libvmaf.c
1 /*
2  * Copyright (c) 2017 Ronald S. Bultje <rsbultje@gmail.com>
3  * Copyright (c) 2017 Ashish Pratap Singh <ashk43712@gmail.com>
4  *
5  * This file is part of FFmpeg.
6  *
7  * FFmpeg is free software; you can redistribute it and/or
8  * modify it under the terms of the GNU Lesser General Public
9  * License as published by the Free Software Foundation; either
10  * version 2.1 of the License, or (at your option) any later version.
11  *
12  * FFmpeg is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
15  * Lesser General Public License for more details.
16  *
17  * You should have received a copy of the GNU Lesser General Public
18  * License along with FFmpeg; if not, write to the Free Software
19  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
20  */
21
22 /**
23  * @file
24  * Calculate the VMAF between two input videos.
25  */
26
27 #include <pthread.h>
28 #include <libvmaf.h>
29 #include "libavutil/avstring.h"
30 #include "libavutil/opt.h"
31 #include "libavutil/pixdesc.h"
32 #include "avfilter.h"
33 #include "drawutils.h"
34 #include "formats.h"
35 #include "framesync.h"
36 #include "internal.h"
37 #include "video.h"
38
/**
 * Private state of the libvmaf filter.
 *
 * The filter thread (do_vmaf) and the worker thread (call_vmaf and the
 * read_frame_*bit callbacks) exchange one frame pair at a time through
 * gmain/gref, coordinated by lock/cond and the frame_set/eof flags.
 */
typedef struct LIBVMAFContext {
    const AVClass *class;           /* option/logging class */
    FFFrameSync fs;                 /* dual-input frame synchronizer */
    const AVPixFmtDescriptor *desc; /* descriptor of the negotiated pixel format, set in config_input_ref() */
    int width;                      /* frame width, taken from input 0 */
    int height;                     /* frame height, taken from input 0 */
    double vmaf_score;              /* score written by compute_vmaf() */
    pthread_t vmaf_thread;          /* worker thread running call_vmaf() */
    pthread_mutex_t lock;           /* held around accesses to frame_set, eof, gmain and gref */
    pthread_cond_t cond;            /* signalled on every frame hand-off, on EOF and on worker error */
    int eof;                        /* set to 1 in uninit() to tell the worker no more frames will come */
    AVFrame *gmain;                 /* reference to the pending main frame */
    AVFrame *gref;                  /* reference to the pending reference frame */
    int frame_set;                  /* 1 while gmain/gref hold a pair the worker has not consumed yet */
    char *model_path;               /* "model_path" option */
    char *log_path;                 /* "log_path" option */
    char *log_fmt;                  /* "log_fmt" option */
    int disable_clip;               /* NOTE(review): not referenced elsewhere in the visible source */
    int disable_avx;                /* NOTE(review): not referenced elsewhere in the visible source */
    int enable_transform;           /* "enable_transform" option */
    int phone_model;                /* "phone_model" option */
    int psnr;                       /* "psnr" option */
    int ssim;                       /* "ssim" option */
    int ms_ssim;                    /* "ms_ssim" option */
    char *pool;                     /* "pool" option */
    int error;                      /* return value of compute_vmaf(); non-zero is treated as failure */
} LIBVMAFContext;
66
67 #define OFFSET(x) offsetof(LIBVMAFContext, x)
68 #define FLAGS AV_OPT_FLAG_FILTERING_PARAM|AV_OPT_FLAG_VIDEO_PARAM
69
70 static const AVOption libvmaf_options[] = {
71     {"model_path",  "Set the model to be used for computing vmaf.",                     OFFSET(model_path), AV_OPT_TYPE_STRING, {.str="/usr/local/share/model/vmaf_v0.6.1.pkl"}, 0, 1, FLAGS},
72     {"log_path",  "Set the file path to be used to store logs.",                        OFFSET(log_path), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 1, FLAGS},
73     {"log_fmt",  "Set the format of the log (xml or json).",                            OFFSET(log_fmt), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 1, FLAGS},
74     {"enable_transform",  "Enables transform for computing vmaf.",                      OFFSET(enable_transform), AV_OPT_TYPE_BOOL, {.i64=0}, 0, 1, FLAGS},
75     {"phone_model",  "Invokes the phone model that will generate higher VMAF scores.",  OFFSET(phone_model), AV_OPT_TYPE_BOOL, {.i64=0}, 0, 1, FLAGS},
76     {"psnr",  "Enables computing psnr along with vmaf.",                                OFFSET(psnr), AV_OPT_TYPE_BOOL, {.i64=0}, 0, 1, FLAGS},
77     {"ssim",  "Enables computing ssim along with vmaf.",                                OFFSET(ssim), AV_OPT_TYPE_BOOL, {.i64=0}, 0, 1, FLAGS},
78     {"ms_ssim",  "Enables computing ms-ssim along with vmaf.",                          OFFSET(ms_ssim), AV_OPT_TYPE_BOOL, {.i64=0}, 0, 1, FLAGS},
79     {"pool",  "Set the pool method to be used for computing vmaf.",                     OFFSET(pool), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 1, FLAGS},
80     { NULL }
81 };
82
83 FRAMESYNC_DEFINE_CLASS(libvmaf, LIBVMAFContext, fs);
84
85 #define read_frame_fn(type, bits)                                               \
86     static int read_frame_##bits##bit(float *ref_data, float *main_data,        \
87                                       float *temp_data, int stride, void *ctx)  \
88 {                                                                               \
89     LIBVMAFContext *s = (LIBVMAFContext *) ctx;                                 \
90     int ret;                                                                    \
91     \
92     pthread_mutex_lock(&s->lock);                                               \
93     \
94     while (!s->frame_set && !s->eof) {                                          \
95         pthread_cond_wait(&s->cond, &s->lock);                                  \
96     }                                                                           \
97     \
98     if (s->frame_set) {                                                         \
99         int ref_stride = s->gref->linesize[0];                                  \
100         int main_stride = s->gmain->linesize[0];                                \
101         \
102         const type *ref_ptr = (const type *) s->gref->data[0];                  \
103         const type *main_ptr = (const type *) s->gmain->data[0];                \
104         \
105         float *ptr = ref_data;                                                  \
106         \
107         int h = s->height;                                                      \
108         int w = s->width;                                                       \
109         \
110         int i,j;                                                                \
111         \
112         for (i = 0; i < h; i++) {                                               \
113             for ( j = 0; j < w; j++) {                                          \
114                 ptr[j] = (float)ref_ptr[j];                                     \
115             }                                                                   \
116             ref_ptr += ref_stride / sizeof(*ref_ptr);                           \
117             ptr += stride / sizeof(*ptr);                                       \
118         }                                                                       \
119         \
120         ptr = main_data;                                                        \
121         \
122         for (i = 0; i < h; i++) {                                               \
123             for (j = 0; j < w; j++) {                                           \
124                 ptr[j] = (float)main_ptr[j];                                    \
125             }                                                                   \
126             main_ptr += main_stride / sizeof(*main_ptr);                        \
127             ptr += stride / sizeof(*ptr);                                       \
128         }                                                                       \
129     }                                                                           \
130     \
131     ret = !s->frame_set;                                                        \
132     \
133     av_frame_unref(s->gref);                                                    \
134     av_frame_unref(s->gmain);                                                   \
135     s->frame_set = 0;                                                           \
136     \
137     pthread_cond_signal(&s->cond);                                              \
138     pthread_mutex_unlock(&s->lock);                                             \
139     \
140     if (ret) {                                                                  \
141         return 2;                                                               \
142     }                                                                           \
143     \
144     return 0;                                                                   \
145 }
146
147 read_frame_fn(uint8_t, 8);
148 read_frame_fn(uint16_t, 10);
149
150 static void compute_vmaf_score(LIBVMAFContext *s)
151 {
152     int (*read_frame)(float *ref_data, float *main_data, float *temp_data,
153                       int stride, void *ctx);
154     char *format;
155
156     if (s->desc->comp[0].depth <= 8) {
157         read_frame = read_frame_8bit;
158     } else {
159         read_frame = read_frame_10bit;
160     }
161
162     format = (char *) s->desc->name;
163
164     s->error = compute_vmaf(&s->vmaf_score, format, s->width, s->height,
165                             read_frame, s, s->model_path, s->log_path,
166                             s->log_fmt, 0, 0, s->enable_transform,
167                             s->phone_model, s->psnr, s->ssim,
168                             s->ms_ssim, s->pool);
169 }
170
171 static void *call_vmaf(void *ctx)
172 {
173     LIBVMAFContext *s = (LIBVMAFContext *) ctx;
174     compute_vmaf_score(s);
175     if (!s->error) {
176         av_log(ctx, AV_LOG_INFO, "VMAF score: %f\n",s->vmaf_score);
177     } else {
178         pthread_mutex_lock(&s->lock);
179         pthread_cond_signal(&s->cond);
180         pthread_mutex_unlock(&s->lock);
181     }
182     pthread_exit(NULL);
183     return NULL;
184 }
185
186 static int do_vmaf(FFFrameSync *fs)
187 {
188     AVFilterContext *ctx = fs->parent;
189     LIBVMAFContext *s = ctx->priv;
190     AVFrame *master, *ref;
191     int ret;
192
193     ret = ff_framesync_dualinput_get(fs, &master, &ref);
194     if (ret < 0)
195         return ret;
196     if (!ref)
197         return ff_filter_frame(ctx->outputs[0], master);
198
199     pthread_mutex_lock(&s->lock);
200
201     while (s->frame_set && !s->error) {
202         pthread_cond_wait(&s->cond, &s->lock);
203     }
204
205     if (s->error) {
206         av_log(ctx, AV_LOG_ERROR,
207                "libvmaf encountered an error, check log for details\n");
208         pthread_mutex_unlock(&s->lock);
209         return AVERROR(EINVAL);
210     }
211
212     av_frame_ref(s->gref, ref);
213     av_frame_ref(s->gmain, master);
214
215     s->frame_set = 1;
216
217     pthread_cond_signal(&s->cond);
218     pthread_mutex_unlock(&s->lock);
219
220     return ff_filter_frame(ctx->outputs[0], master);
221 }
222
223 static av_cold int init(AVFilterContext *ctx)
224 {
225     LIBVMAFContext *s = ctx->priv;
226
227     s->gref = av_frame_alloc();
228     s->gmain = av_frame_alloc();
229     s->error = 0;
230
231     pthread_mutex_init(&s->lock, NULL);
232     pthread_cond_init (&s->cond, NULL);
233
234     s->fs.on_event = do_vmaf;
235     return 0;
236 }
237
238 static int query_formats(AVFilterContext *ctx)
239 {
240     static const enum AVPixelFormat pix_fmts[] = {
241         AV_PIX_FMT_YUV444P, AV_PIX_FMT_YUV422P, AV_PIX_FMT_YUV420P,
242         AV_PIX_FMT_YUV444P10LE, AV_PIX_FMT_YUV422P10LE, AV_PIX_FMT_YUV420P10LE,
243         AV_PIX_FMT_NONE
244     };
245
246     AVFilterFormats *fmts_list = ff_make_format_list(pix_fmts);
247     if (!fmts_list)
248         return AVERROR(ENOMEM);
249     return ff_set_common_formats(ctx, fmts_list);
250 }
251
252
253 static int config_input_ref(AVFilterLink *inlink)
254 {
255     AVFilterContext *ctx  = inlink->dst;
256     LIBVMAFContext *s = ctx->priv;
257     int th;
258
259     if (ctx->inputs[0]->w != ctx->inputs[1]->w ||
260         ctx->inputs[0]->h != ctx->inputs[1]->h) {
261         av_log(ctx, AV_LOG_ERROR, "Width and height of input videos must be same.\n");
262         return AVERROR(EINVAL);
263     }
264     if (ctx->inputs[0]->format != ctx->inputs[1]->format) {
265         av_log(ctx, AV_LOG_ERROR, "Inputs must be of same pixel format.\n");
266         return AVERROR(EINVAL);
267     }
268
269     s->desc = av_pix_fmt_desc_get(inlink->format);
270     s->width = ctx->inputs[0]->w;
271     s->height = ctx->inputs[0]->h;
272
273     th = pthread_create(&s->vmaf_thread, NULL, call_vmaf, (void *) s);
274     if (th) {
275         av_log(ctx, AV_LOG_ERROR, "Thread creation failed.\n");
276         return AVERROR(EINVAL);
277     }
278
279     return 0;
280 }
281
282 static int config_output(AVFilterLink *outlink)
283 {
284     AVFilterContext *ctx = outlink->src;
285     LIBVMAFContext *s = ctx->priv;
286     AVFilterLink *mainlink = ctx->inputs[0];
287     int ret;
288
289     ret = ff_framesync_init_dualinput(&s->fs, ctx);
290     if (ret < 0)
291         return ret;
292     outlink->w = mainlink->w;
293     outlink->h = mainlink->h;
294     outlink->time_base = mainlink->time_base;
295     outlink->sample_aspect_ratio = mainlink->sample_aspect_ratio;
296     outlink->frame_rate = mainlink->frame_rate;
297     if ((ret = ff_framesync_configure(&s->fs)) < 0)
298         return ret;
299
300     return 0;
301 }
302
303 static int activate(AVFilterContext *ctx)
304 {
305     LIBVMAFContext *s = ctx->priv;
306     return ff_framesync_activate(&s->fs);
307 }
308
309 static av_cold void uninit(AVFilterContext *ctx)
310 {
311     LIBVMAFContext *s = ctx->priv;
312
313     ff_framesync_uninit(&s->fs);
314
315     pthread_mutex_lock(&s->lock);
316     s->eof = 1;
317     pthread_cond_signal(&s->cond);
318     pthread_mutex_unlock(&s->lock);
319
320     pthread_join(s->vmaf_thread, NULL);
321
322     av_frame_free(&s->gref);
323     av_frame_free(&s->gmain);
324
325     pthread_mutex_destroy(&s->lock);
326     pthread_cond_destroy(&s->cond);
327 }
328
/* Input pads: index 0 is the main video, index 1 is the reference it is
 * compared against. */
static const AVFilterPad libvmaf_inputs[] = {
    {
        .name         = "main",
        .type         = AVMEDIA_TYPE_VIDEO,
    },{
        .name         = "reference",
        .type         = AVMEDIA_TYPE_VIDEO,
        /* Validates both inputs against each other and starts the worker
         * thread. */
        .config_props = config_input_ref,
    },
    { NULL }
};
340
/* Single video output pad; config_output() gives it the properties of the
 * main input and sets up the frame synchronizer. */
static const AVFilterPad libvmaf_outputs[] = {
    {
        .name          = "default",
        .type          = AVMEDIA_TYPE_VIDEO,
        .config_props  = config_output,
    },
    { NULL }
};
349
/* Filter definition for "libvmaf": wires the callbacks, option class and
 * pad tables defined in this file into one AVFilter. */
AVFilter ff_vf_libvmaf = {
    .name          = "libvmaf",
    .description   = NULL_IF_CONFIG_SMALL("Calculate the VMAF between two video streams."),
    .preinit       = libvmaf_framesync_preinit, /* not defined in the visible source; presumably generated by FRAMESYNC_DEFINE_CLASS */
    .init          = init,
    .uninit        = uninit,
    .query_formats = query_formats,
    .activate      = activate,
    .priv_size     = sizeof(LIBVMAFContext),
    .priv_class    = &libvmaf_class,
    .inputs        = libvmaf_inputs,
    .outputs       = libvmaf_outputs,
};