]> git.sesse.net Git - ffmpeg/blob - libavfilter/af_arnndn.c
avfilter/vf_scale: store the offset in a local variable before adding it
[ffmpeg] / libavfilter / af_arnndn.c
1 /*
2  * Copyright (c) 2018 Gregor Richards
3  * Copyright (c) 2017 Mozilla
4  * Copyright (c) 2005-2009 Xiph.Org Foundation
5  * Copyright (c) 2007-2008 CSIRO
6  * Copyright (c) 2008-2011 Octasic Inc.
7  * Copyright (c) Jean-Marc Valin
8  * Copyright (c) 2019 Paul B Mahol
9  *
10  * Redistribution and use in source and binary forms, with or without
11  * modification, are permitted provided that the following conditions
12  * are met:
13  *
14  * - Redistributions of source code must retain the above copyright
15  *   notice, this list of conditions and the following disclaimer.
16  *
17  * - Redistributions in binary form must reproduce the above copyright
18  *   notice, this list of conditions and the following disclaimer in the
19  *   documentation and/or other materials provided with the distribution.
20  *
21  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22  * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24  * A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25  * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26  * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27  * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29  * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30  * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32  */
33
34 #include <float.h>
35
36 #include "libavutil/avassert.h"
37 #include "libavutil/avstring.h"
38 #include "libavutil/float_dsp.h"
39 #include "libavutil/mem_internal.h"
40 #include "libavutil/opt.h"
41 #include "libavutil/tx.h"
42 #include "avfilter.h"
43 #include "audio.h"
44 #include "filters.h"
45 #include "formats.h"
46
47 #define FRAME_SIZE_SHIFT 2
48 #define FRAME_SIZE (120<<FRAME_SIZE_SHIFT)
49 #define WINDOW_SIZE (2*FRAME_SIZE)
50 #define FREQ_SIZE (FRAME_SIZE + 1)
51
52 #define PITCH_MIN_PERIOD 60
53 #define PITCH_MAX_PERIOD 768
54 #define PITCH_FRAME_SIZE 960
55 #define PITCH_BUF_SIZE (PITCH_MAX_PERIOD+PITCH_FRAME_SIZE)
56
57 #define SQUARE(x) ((x)*(x))
58
59 #define NB_BANDS 22
60
61 #define CEPS_MEM 8
62 #define NB_DELTA_CEPS 6
63
64 #define NB_FEATURES (NB_BANDS+3*NB_DELTA_CEPS+2)
65
66 #define WEIGHTS_SCALE (1.f/256)
67
68 #define MAX_NEURONS 128
69
70 #define ACTIVATION_TANH    0
71 #define ACTIVATION_SIGMOID 1
72 #define ACTIVATION_RELU    2
73
74 #define Q15ONE 1.0f
75
76 typedef struct DenseLayer {
77     const float *bias;
78     const float *input_weights;
79     int nb_inputs;
80     int nb_neurons;
81     int activation;
82 } DenseLayer;
83
84 typedef struct GRULayer {
85     const float *bias;
86     const float *input_weights;
87     const float *recurrent_weights;
88     int nb_inputs;
89     int nb_neurons;
90     int activation;
91 } GRULayer;
92
93 typedef struct RNNModel {
94     int input_dense_size;
95     const DenseLayer *input_dense;
96
97     int vad_gru_size;
98     const GRULayer *vad_gru;
99
100     int noise_gru_size;
101     const GRULayer *noise_gru;
102
103     int denoise_gru_size;
104     const GRULayer *denoise_gru;
105
106     int denoise_output_size;
107     const DenseLayer *denoise_output;
108
109     int vad_output_size;
110     const DenseLayer *vad_output;
111 } RNNModel;
112
113 typedef struct RNNState {
114     float *vad_gru_state;
115     float *noise_gru_state;
116     float *denoise_gru_state;
117     RNNModel *model;
118 } RNNState;
119
120 typedef struct DenoiseState {
121     float analysis_mem[FRAME_SIZE];
122     float cepstral_mem[CEPS_MEM][NB_BANDS];
123     int memid;
124     DECLARE_ALIGNED(32, float, synthesis_mem)[FRAME_SIZE];
125     float pitch_buf[PITCH_BUF_SIZE];
126     float pitch_enh_buf[PITCH_BUF_SIZE];
127     float last_gain;
128     int last_period;
129     float mem_hp_x[2];
130     float lastg[NB_BANDS];
131     float history[FRAME_SIZE];
132     RNNState rnn[2];
133     AVTXContext *tx, *txi;
134     av_tx_fn tx_fn, txi_fn;
135 } DenoiseState;
136
137 typedef struct AudioRNNContext {
138     const AVClass *class;
139
140     char *model_name;
141     float mix;
142
143     int channels;
144     DenoiseState *st;
145
146     DECLARE_ALIGNED(32, float, window)[WINDOW_SIZE];
147     DECLARE_ALIGNED(32, float, dct_table)[FFALIGN(NB_BANDS, 4)][FFALIGN(NB_BANDS, 4)];
148
149     RNNModel *model[2];
150
151     AVFloatDSPContext *fdsp;
152 } AudioRNNContext;
153
154 #define F_ACTIVATION_TANH       0
155 #define F_ACTIVATION_SIGMOID    1
156 #define F_ACTIVATION_RELU       2
157
158 static void rnnoise_model_free(RNNModel *model)
159 {
160 #define FREE_MAYBE(ptr) do { if (ptr) free(ptr); } while (0)
161 #define FREE_DENSE(name) do { \
162     if (model->name) { \
163         av_free((void *) model->name->input_weights); \
164         av_free((void *) model->name->bias); \
165         av_free((void *) model->name); \
166     } \
167     } while (0)
168 #define FREE_GRU(name) do { \
169     if (model->name) { \
170         av_free((void *) model->name->input_weights); \
171         av_free((void *) model->name->recurrent_weights); \
172         av_free((void *) model->name->bias); \
173         av_free((void *) model->name); \
174     } \
175     } while (0)
176
177     if (!model)
178         return;
179     FREE_DENSE(input_dense);
180     FREE_GRU(vad_gru);
181     FREE_GRU(noise_gru);
182     FREE_GRU(denoise_gru);
183     FREE_DENSE(denoise_output);
184     FREE_DENSE(vad_output);
185     av_free(model);
186 }
187
188 static int rnnoise_model_from_file(FILE *f, RNNModel **rnn)
189 {
190     RNNModel *ret = NULL;
191     DenseLayer *input_dense;
192     GRULayer *vad_gru;
193     GRULayer *noise_gru;
194     GRULayer *denoise_gru;
195     DenseLayer *denoise_output;
196     DenseLayer *vad_output;
197     int in;
198
199     if (fscanf(f, "rnnoise-nu model file version %d\n", &in) != 1 || in != 1)
200         return AVERROR_INVALIDDATA;
201
202     ret = av_calloc(1, sizeof(RNNModel));
203     if (!ret)
204         return AVERROR(ENOMEM);
205
206 #define ALLOC_LAYER(type, name) \
207     name = av_calloc(1, sizeof(type)); \
208     if (!name) { \
209         rnnoise_model_free(ret); \
210         return AVERROR(ENOMEM); \
211     } \
212     ret->name = name
213
214     ALLOC_LAYER(DenseLayer, input_dense);
215     ALLOC_LAYER(GRULayer, vad_gru);
216     ALLOC_LAYER(GRULayer, noise_gru);
217     ALLOC_LAYER(GRULayer, denoise_gru);
218     ALLOC_LAYER(DenseLayer, denoise_output);
219     ALLOC_LAYER(DenseLayer, vad_output);
220
221 #define INPUT_VAL(name) do { \
222     if (fscanf(f, "%d", &in) != 1 || in < 0 || in > 128) { \
223         rnnoise_model_free(ret); \
224         return AVERROR(EINVAL); \
225     } \
226     name = in; \
227     } while (0)
228
229 #define INPUT_ACTIVATION(name) do { \
230     int activation; \
231     INPUT_VAL(activation); \
232     switch (activation) { \
233     case F_ACTIVATION_SIGMOID: \
234         name = ACTIVATION_SIGMOID; \
235         break; \
236     case F_ACTIVATION_RELU: \
237         name = ACTIVATION_RELU; \
238         break; \
239     default: \
240         name = ACTIVATION_TANH; \
241     } \
242     } while (0)
243
244 #define INPUT_ARRAY(name, len) do { \
245     float *values = av_calloc((len), sizeof(float)); \
246     if (!values) { \
247         rnnoise_model_free(ret); \
248         return AVERROR(ENOMEM); \
249     } \
250     name = values; \
251     for (int i = 0; i < (len); i++) { \
252         if (fscanf(f, "%d", &in) != 1) { \
253             rnnoise_model_free(ret); \
254             return AVERROR(EINVAL); \
255         } \
256         values[i] = in; \
257     } \
258     } while (0)
259
260 #define INPUT_ARRAY3(name, len0, len1, len2) do { \
261     float *values = av_calloc(FFALIGN((len0), 4) * FFALIGN((len1), 4) * (len2), sizeof(float)); \
262     if (!values) { \
263         rnnoise_model_free(ret); \
264         return AVERROR(ENOMEM); \
265     } \
266     name = values; \
267     for (int k = 0; k < (len0); k++) { \
268         for (int i = 0; i < (len2); i++) { \
269             for (int j = 0; j < (len1); j++) { \
270                 if (fscanf(f, "%d", &in) != 1) { \
271                     rnnoise_model_free(ret); \
272                     return AVERROR(EINVAL); \
273                 } \
274                 values[j * (len2) * FFALIGN((len0), 4) + i * FFALIGN((len0), 4) + k] = in; \
275             } \
276         } \
277     } \
278     } while (0)
279
280 #define NEW_LINE() do { \
281     int c; \
282     while ((c = fgetc(f)) != EOF) { \
283         if (c == '\n') \
284         break; \
285     } \
286     } while (0)
287
288 #define INPUT_DENSE(name) do { \
289     INPUT_VAL(name->nb_inputs); \
290     INPUT_VAL(name->nb_neurons); \
291     ret->name ## _size = name->nb_neurons; \
292     INPUT_ACTIVATION(name->activation); \
293     NEW_LINE(); \
294     INPUT_ARRAY(name->input_weights, name->nb_inputs * name->nb_neurons); \
295     NEW_LINE(); \
296     INPUT_ARRAY(name->bias, name->nb_neurons); \
297     NEW_LINE(); \
298     } while (0)
299
300 #define INPUT_GRU(name) do { \
301     INPUT_VAL(name->nb_inputs); \
302     INPUT_VAL(name->nb_neurons); \
303     ret->name ## _size = name->nb_neurons; \
304     INPUT_ACTIVATION(name->activation); \
305     NEW_LINE(); \
306     INPUT_ARRAY3(name->input_weights, name->nb_inputs, name->nb_neurons, 3); \
307     NEW_LINE(); \
308     INPUT_ARRAY3(name->recurrent_weights, name->nb_neurons, name->nb_neurons, 3); \
309     NEW_LINE(); \
310     INPUT_ARRAY(name->bias, name->nb_neurons * 3); \
311     NEW_LINE(); \
312     } while (0)
313
314     INPUT_DENSE(input_dense);
315     INPUT_GRU(vad_gru);
316     INPUT_GRU(noise_gru);
317     INPUT_GRU(denoise_gru);
318     INPUT_DENSE(denoise_output);
319     INPUT_DENSE(vad_output);
320
321     if (vad_output->nb_neurons != 1) {
322         rnnoise_model_free(ret);
323         return AVERROR(EINVAL);
324     }
325
326     *rnn = ret;
327
328     return 0;
329 }
330
331 static int query_formats(AVFilterContext *ctx)
332 {
333     AVFilterFormats *formats = NULL;
334     AVFilterChannelLayouts *layouts = NULL;
335     static const enum AVSampleFormat sample_fmts[] = {
336         AV_SAMPLE_FMT_FLTP,
337         AV_SAMPLE_FMT_NONE
338     };
339     int ret, sample_rates[] = { 48000, -1 };
340
341     formats = ff_make_format_list(sample_fmts);
342     if (!formats)
343         return AVERROR(ENOMEM);
344     ret = ff_set_common_formats(ctx, formats);
345     if (ret < 0)
346         return ret;
347
348     layouts = ff_all_channel_counts();
349     if (!layouts)
350         return AVERROR(ENOMEM);
351
352     ret = ff_set_common_channel_layouts(ctx, layouts);
353     if (ret < 0)
354         return ret;
355
356     formats = ff_make_format_list(sample_rates);
357     if (!formats)
358         return AVERROR(ENOMEM);
359     return ff_set_common_samplerates(ctx, formats);
360 }
361
362 static int config_input(AVFilterLink *inlink)
363 {
364     AVFilterContext *ctx = inlink->dst;
365     AudioRNNContext *s = ctx->priv;
366     int ret;
367
368     s->channels = inlink->channels;
369
370     if (!s->st)
371         s->st = av_calloc(s->channels, sizeof(DenoiseState));
372     if (!s->st)
373         return AVERROR(ENOMEM);
374
375     for (int i = 0; i < s->channels; i++) {
376         DenoiseState *st = &s->st[i];
377
378         st->rnn[0].model = s->model[0];
379         st->rnn[0].vad_gru_state = av_calloc(sizeof(float), FFALIGN(s->model[0]->vad_gru_size, 16));
380         st->rnn[0].noise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model[0]->noise_gru_size, 16));
381         st->rnn[0].denoise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model[0]->denoise_gru_size, 16));
382         if (!st->rnn[0].vad_gru_state ||
383             !st->rnn[0].noise_gru_state ||
384             !st->rnn[0].denoise_gru_state)
385             return AVERROR(ENOMEM);
386     }
387
388     for (int i = 0; i < s->channels; i++) {
389         DenoiseState *st = &s->st[i];
390
391         if (!st->tx)
392             ret = av_tx_init(&st->tx, &st->tx_fn, AV_TX_FLOAT_FFT, 0, WINDOW_SIZE, NULL, 0);
393         if (ret < 0)
394             return ret;
395
396         if (!st->txi)
397             ret = av_tx_init(&st->txi, &st->txi_fn, AV_TX_FLOAT_FFT, 1, WINDOW_SIZE, NULL, 0);
398         if (ret < 0)
399             return ret;
400     }
401
402     return 0;
403 }
404
405 static void biquad(float *y, float mem[2], const float *x,
406                    const float *b, const float *a, int N)
407 {
408     for (int i = 0; i < N; i++) {
409         float xi, yi;
410
411         xi = x[i];
412         yi = x[i] + mem[0];
413         mem[0] = mem[1] + (b[0]*xi - a[0]*yi);
414         mem[1] = (b[1]*xi - a[1]*yi);
415         y[i] = yi;
416     }
417 }
418
419 #define RNN_MOVE(dst, src, n) (memmove((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
420 #define RNN_CLEAR(dst, n) (memset((dst), 0, (n)*sizeof(*(dst))))
421 #define RNN_COPY(dst, src, n) (memcpy((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
422
423 static void forward_transform(DenoiseState *st, AVComplexFloat *out, const float *in)
424 {
425     AVComplexFloat x[WINDOW_SIZE];
426     AVComplexFloat y[WINDOW_SIZE];
427
428     for (int i = 0; i < WINDOW_SIZE; i++) {
429         x[i].re = in[i];
430         x[i].im = 0;
431     }
432
433     st->tx_fn(st->tx, y, x, sizeof(float));
434
435     RNN_COPY(out, y, FREQ_SIZE);
436 }
437
438 static void inverse_transform(DenoiseState *st, float *out, const AVComplexFloat *in)
439 {
440     AVComplexFloat x[WINDOW_SIZE];
441     AVComplexFloat y[WINDOW_SIZE];
442
443     RNN_COPY(x, in, FREQ_SIZE);
444
445     for (int i = FREQ_SIZE; i < WINDOW_SIZE; i++) {
446         x[i].re =  x[WINDOW_SIZE - i].re;
447         x[i].im = -x[WINDOW_SIZE - i].im;
448     }
449
450     st->txi_fn(st->txi, y, x, sizeof(float));
451
452     for (int i = 0; i < WINDOW_SIZE; i++)
453         out[i] = y[i].re / WINDOW_SIZE;
454 }
455
456 static const uint8_t eband5ms[] = {
457 /*0  200 400 600 800  1k 1.2 1.4 1.6  2k 2.4 2.8 3.2  4k 4.8 5.6 6.8  8k 9.6 12k 15.6 20k*/
458   0,  1,  2,  3,  4,   5, 6,  7,  8,  10, 12, 14, 16, 20, 24, 28, 34, 40, 48, 60, 78, 100
459 };
460
461 static void compute_band_energy(float *bandE, const AVComplexFloat *X)
462 {
463     float sum[NB_BANDS] = {0};
464
465     for (int i = 0; i < NB_BANDS - 1; i++) {
466         int band_size;
467
468         band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
469         for (int j = 0; j < band_size; j++) {
470             float tmp, frac = (float)j / band_size;
471
472             tmp         = SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].re);
473             tmp        += SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].im);
474             sum[i]     += (1.f - frac) * tmp;
475             sum[i + 1] +=        frac  * tmp;
476         }
477     }
478
479     sum[0] *= 2;
480     sum[NB_BANDS - 1] *= 2;
481
482     for (int i = 0; i < NB_BANDS; i++)
483         bandE[i] = sum[i];
484 }
485
486 static void compute_band_corr(float *bandE, const AVComplexFloat *X, const AVComplexFloat *P)
487 {
488     float sum[NB_BANDS] = { 0 };
489
490     for (int i = 0; i < NB_BANDS - 1; i++) {
491         int band_size;
492
493         band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
494         for (int j = 0; j < band_size; j++) {
495             float tmp, frac = (float)j / band_size;
496
497             tmp  = X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re;
498             tmp += X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im;
499             sum[i]     += (1 - frac) * tmp;
500             sum[i + 1] +=      frac  * tmp;
501         }
502     }
503
504     sum[0] *= 2;
505     sum[NB_BANDS-1] *= 2;
506
507     for (int i = 0; i < NB_BANDS; i++)
508         bandE[i] = sum[i];
509 }
510
511 static void frame_analysis(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, float *Ex, const float *in)
512 {
513     LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
514
515     RNN_COPY(x, st->analysis_mem, FRAME_SIZE);
516     RNN_COPY(x + FRAME_SIZE, in, FRAME_SIZE);
517     RNN_COPY(st->analysis_mem, in, FRAME_SIZE);
518     s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
519     forward_transform(st, X, x);
520     compute_band_energy(Ex, X);
521 }
522
523 static void frame_synthesis(AudioRNNContext *s, DenoiseState *st, float *out, const AVComplexFloat *y)
524 {
525     LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
526     const float *src = st->history;
527     const float mix = s->mix;
528     const float imix = 1.f - FFMAX(mix, 0.f);
529
530     inverse_transform(st, x, y);
531     s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
532     s->fdsp->vector_fmac_scalar(x, st->synthesis_mem, 1.f, FRAME_SIZE);
533     RNN_COPY(out, x, FRAME_SIZE);
534     RNN_COPY(st->synthesis_mem, &x[FRAME_SIZE], FRAME_SIZE);
535
536     for (int n = 0; n < FRAME_SIZE; n++)
537         out[n] = out[n] * mix + src[n] * imix;
538 }
539
540 static inline void xcorr_kernel(const float *x, const float *y, float sum[4], int len)
541 {
542     float y_0, y_1, y_2, y_3 = 0;
543     int j;
544
545     y_0 = *y++;
546     y_1 = *y++;
547     y_2 = *y++;
548
549     for (j = 0; j < len - 3; j += 4) {
550         float tmp;
551
552         tmp = *x++;
553         y_3 = *y++;
554         sum[0] += tmp * y_0;
555         sum[1] += tmp * y_1;
556         sum[2] += tmp * y_2;
557         sum[3] += tmp * y_3;
558         tmp = *x++;
559         y_0 = *y++;
560         sum[0] += tmp * y_1;
561         sum[1] += tmp * y_2;
562         sum[2] += tmp * y_3;
563         sum[3] += tmp * y_0;
564         tmp = *x++;
565         y_1 = *y++;
566         sum[0] += tmp * y_2;
567         sum[1] += tmp * y_3;
568         sum[2] += tmp * y_0;
569         sum[3] += tmp * y_1;
570         tmp = *x++;
571         y_2 = *y++;
572         sum[0] += tmp * y_3;
573         sum[1] += tmp * y_0;
574         sum[2] += tmp * y_1;
575         sum[3] += tmp * y_2;
576     }
577
578     if (j++ < len) {
579         float tmp = *x++;
580
581         y_3 = *y++;
582         sum[0] += tmp * y_0;
583         sum[1] += tmp * y_1;
584         sum[2] += tmp * y_2;
585         sum[3] += tmp * y_3;
586     }
587
588     if (j++ < len) {
589         float tmp=*x++;
590
591         y_0 = *y++;
592         sum[0] += tmp * y_1;
593         sum[1] += tmp * y_2;
594         sum[2] += tmp * y_3;
595         sum[3] += tmp * y_0;
596     }
597
598     if (j < len) {
599         float tmp=*x++;
600
601         y_1 = *y++;
602         sum[0] += tmp * y_2;
603         sum[1] += tmp * y_3;
604         sum[2] += tmp * y_0;
605         sum[3] += tmp * y_1;
606     }
607 }
608
609 static inline float celt_inner_prod(const float *x,
610                                     const float *y, int N)
611 {
612     float xy = 0.f;
613
614     for (int i = 0; i < N; i++)
615         xy += x[i] * y[i];
616
617     return xy;
618 }
619
620 static void celt_pitch_xcorr(const float *x, const float *y,
621                              float *xcorr, int len, int max_pitch)
622 {
623     int i;
624
625     for (i = 0; i < max_pitch - 3; i += 4) {
626         float sum[4] = { 0, 0, 0, 0};
627
628         xcorr_kernel(x, y + i, sum, len);
629
630         xcorr[i]     = sum[0];
631         xcorr[i + 1] = sum[1];
632         xcorr[i + 2] = sum[2];
633         xcorr[i + 3] = sum[3];
634     }
635     /* In case max_pitch isn't a multiple of 4, do non-unrolled version. */
636     for (; i < max_pitch; i++) {
637         xcorr[i] = celt_inner_prod(x, y + i, len);
638     }
639 }
640
641 static int celt_autocorr(const float *x,   /*  in: [0...n-1] samples x   */
642                          float       *ac,  /* out: [0...lag-1] ac values */
643                          const float *window,
644                          int          overlap,
645                          int          lag,
646                          int          n)
647 {
648     int fastN = n - lag;
649     int shift;
650     const float *xptr;
651     float xx[PITCH_BUF_SIZE>>1];
652
653     if (overlap == 0) {
654         xptr = x;
655     } else {
656         for (int i = 0; i < n; i++)
657             xx[i] = x[i];
658         for (int i = 0; i < overlap; i++) {
659             xx[i] = x[i] * window[i];
660             xx[n-i-1] = x[n-i-1] * window[i];
661         }
662         xptr = xx;
663     }
664
665     shift = 0;
666     celt_pitch_xcorr(xptr, xptr, ac, fastN, lag+1);
667
668     for (int k = 0; k <= lag; k++) {
669         float d = 0.f;
670
671         for (int i = k + fastN; i < n; i++)
672             d += xptr[i] * xptr[i-k];
673         ac[k] += d;
674     }
675
676     return shift;
677 }
678
679 static void celt_lpc(float *lpc, /* out: [0...p-1] LPC coefficients      */
680                 const float *ac,   /* in:  [0...p] autocorrelation values  */
681                           int p)
682 {
683     float r, error = ac[0];
684
685     RNN_CLEAR(lpc, p);
686     if (ac[0] != 0) {
687         for (int i = 0; i < p; i++) {
688             /* Sum up this iteration's reflection coefficient */
689             float rr = 0;
690             for (int j = 0; j < i; j++)
691                 rr += (lpc[j] * ac[i - j]);
692             rr += ac[i + 1];
693             r = -rr/error;
694             /*  Update LPC coefficients and total error */
695             lpc[i] = r;
696             for (int j = 0; j < (i + 1) >> 1; j++) {
697                 float tmp1, tmp2;
698                 tmp1 = lpc[j];
699                 tmp2 = lpc[i-1-j];
700                 lpc[j]     = tmp1 + (r*tmp2);
701                 lpc[i-1-j] = tmp2 + (r*tmp1);
702             }
703
704             error = error - (r * r *error);
705             /* Bail out once we get 30 dB gain */
706             if (error < .001f * ac[0])
707                 break;
708         }
709     }
710 }
711
712 static void celt_fir5(const float *x,
713                       const float *num,
714                       float *y,
715                       int N,
716                       float *mem)
717 {
718     float num0, num1, num2, num3, num4;
719     float mem0, mem1, mem2, mem3, mem4;
720
721     num0 = num[0];
722     num1 = num[1];
723     num2 = num[2];
724     num3 = num[3];
725     num4 = num[4];
726     mem0 = mem[0];
727     mem1 = mem[1];
728     mem2 = mem[2];
729     mem3 = mem[3];
730     mem4 = mem[4];
731
732     for (int i = 0; i < N; i++) {
733         float sum = x[i];
734
735         sum += (num0*mem0);
736         sum += (num1*mem1);
737         sum += (num2*mem2);
738         sum += (num3*mem3);
739         sum += (num4*mem4);
740         mem4 = mem3;
741         mem3 = mem2;
742         mem2 = mem1;
743         mem1 = mem0;
744         mem0 = x[i];
745         y[i] = sum;
746     }
747
748     mem[0] = mem0;
749     mem[1] = mem1;
750     mem[2] = mem2;
751     mem[3] = mem3;
752     mem[4] = mem4;
753 }
754
755 static void pitch_downsample(float *x[], float *x_lp,
756                              int len, int C)
757 {
758     float ac[5];
759     float tmp=Q15ONE;
760     float lpc[4], mem[5]={0,0,0,0,0};
761     float lpc2[5];
762     float c1 = .8f;
763
764     for (int i = 1; i < len >> 1; i++)
765         x_lp[i] = .5f * (.5f * (x[0][(2*i-1)]+x[0][(2*i+1)])+x[0][2*i]);
766     x_lp[0] = .5f * (.5f * (x[0][1])+x[0][0]);
767     if (C==2) {
768         for (int i = 1; i < len >> 1; i++)
769             x_lp[i] += (.5f * (.5f * (x[1][(2*i-1)]+x[1][(2*i+1)])+x[1][2*i]));
770         x_lp[0] += .5f * (.5f * (x[1][1])+x[1][0]);
771     }
772
773     celt_autocorr(x_lp, ac, NULL, 0, 4, len>>1);
774
775     /* Noise floor -40 dB */
776     ac[0] *= 1.0001f;
777     /* Lag windowing */
778     for (int i = 1; i <= 4; i++) {
779         /*ac[i] *= exp(-.5*(2*M_PI*.002*i)*(2*M_PI*.002*i));*/
780         ac[i] -= ac[i]*(.008f*i)*(.008f*i);
781     }
782
783     celt_lpc(lpc, ac, 4);
784     for (int i = 0; i < 4; i++) {
785         tmp = .9f * tmp;
786         lpc[i] = (lpc[i] * tmp);
787     }
788     /* Add a zero */
789     lpc2[0] = lpc[0] + .8f;
790     lpc2[1] = lpc[1] + (c1 * lpc[0]);
791     lpc2[2] = lpc[2] + (c1 * lpc[1]);
792     lpc2[3] = lpc[3] + (c1 * lpc[2]);
793     lpc2[4] = (c1 * lpc[3]);
794     celt_fir5(x_lp, lpc2, x_lp, len>>1, mem);
795 }
796
797 static inline void dual_inner_prod(const float *x, const float *y01, const float *y02,
798                                    int N, float *xy1, float *xy2)
799 {
800     float xy01 = 0, xy02 = 0;
801
802     for (int i = 0; i < N; i++) {
803         xy01 += (x[i] * y01[i]);
804         xy02 += (x[i] * y02[i]);
805     }
806
807     *xy1 = xy01;
808     *xy2 = xy02;
809 }
810
811 static float compute_pitch_gain(float xy, float xx, float yy)
812 {
813     return xy / sqrtf(1.f + xx * yy);
814 }
815
816 static const uint8_t second_check[16] = {0, 0, 3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2};
817 static float remove_doubling(float *x, int maxperiod, int minperiod, int N,
818                              int *T0_, int prev_period, float prev_gain)
819 {
820     int k, i, T, T0;
821     float g, g0;
822     float pg;
823     float xy,xx,yy,xy2;
824     float xcorr[3];
825     float best_xy, best_yy;
826     int offset;
827     int minperiod0;
828     float yy_lookup[PITCH_MAX_PERIOD+1];
829
830     minperiod0 = minperiod;
831     maxperiod /= 2;
832     minperiod /= 2;
833     *T0_ /= 2;
834     prev_period /= 2;
835     N /= 2;
836     x += maxperiod;
837     if (*T0_>=maxperiod)
838         *T0_=maxperiod-1;
839
840     T = T0 = *T0_;
841     dual_inner_prod(x, x, x-T0, N, &xx, &xy);
842     yy_lookup[0] = xx;
843     yy=xx;
844     for (i = 1; i <= maxperiod; i++) {
845         yy = yy+(x[-i] * x[-i])-(x[N-i] * x[N-i]);
846         yy_lookup[i] = FFMAX(0, yy);
847     }
848     yy = yy_lookup[T0];
849     best_xy = xy;
850     best_yy = yy;
851     g = g0 = compute_pitch_gain(xy, xx, yy);
852     /* Look for any pitch at T/k */
853     for (k = 2; k <= 15; k++) {
854         int T1, T1b;
855         float g1;
856         float cont=0;
857         float thresh;
858         T1 = (2*T0+k)/(2*k);
859         if (T1 < minperiod)
860             break;
861         /* Look for another strong correlation at T1b */
862         if (k==2)
863         {
864             if (T1+T0>maxperiod)
865                 T1b = T0;
866             else
867                 T1b = T0+T1;
868         } else
869         {
870             T1b = (2*second_check[k]*T0+k)/(2*k);
871         }
872         dual_inner_prod(x, &x[-T1], &x[-T1b], N, &xy, &xy2);
873         xy = .5f * (xy + xy2);
874         yy = .5f * (yy_lookup[T1] + yy_lookup[T1b]);
875         g1 = compute_pitch_gain(xy, xx, yy);
876         if (FFABS(T1-prev_period)<=1)
877             cont = prev_gain;
878         else if (FFABS(T1-prev_period)<=2 && 5 * k * k < T0)
879             cont = prev_gain * .5f;
880         else
881             cont = 0;
882         thresh = FFMAX(.3f, (.7f * g0) - cont);
883         /* Bias against very high pitch (very short period) to avoid false-positives
884            due to short-term correlation */
885         if (T1<3*minperiod)
886             thresh = FFMAX(.4f, (.85f * g0) - cont);
887         else if (T1<2*minperiod)
888             thresh = FFMAX(.5f, (.9f * g0) - cont);
889         if (g1 > thresh)
890         {
891             best_xy = xy;
892             best_yy = yy;
893             T = T1;
894             g = g1;
895         }
896     }
897     best_xy = FFMAX(0, best_xy);
898     if (best_yy <= best_xy)
899         pg = Q15ONE;
900     else
901         pg = best_xy/(best_yy + 1);
902
903     for (k = 0; k < 3; k++)
904         xcorr[k] = celt_inner_prod(x, x-(T+k-1), N);
905     if ((xcorr[2]-xcorr[0]) > .7f * (xcorr[1]-xcorr[0]))
906         offset = 1;
907     else if ((xcorr[0]-xcorr[2]) > (.7f * (xcorr[1] - xcorr[2])))
908         offset = -1;
909     else
910         offset = 0;
911     if (pg > g)
912         pg = g;
913     *T0_ = 2*T+offset;
914
915     if (*T0_<minperiod0)
916         *T0_=minperiod0;
917     return pg;
918 }
919
920 static void find_best_pitch(float *xcorr, float *y, int len,
921                             int max_pitch, int *best_pitch)
922 {
923     float best_num[2];
924     float best_den[2];
925     float Syy = 1.f;
926
927     best_num[0] = -1;
928     best_num[1] = -1;
929     best_den[0] = 0;
930     best_den[1] = 0;
931     best_pitch[0] = 0;
932     best_pitch[1] = 1;
933
934     for (int j = 0; j < len; j++)
935         Syy += y[j] * y[j];
936
937     for (int i = 0; i < max_pitch; i++) {
938         if (xcorr[i]>0) {
939             float num;
940             float xcorr16;
941
942             xcorr16 = xcorr[i];
943             /* Considering the range of xcorr16, this should avoid both underflows
944                and overflows (inf) when squaring xcorr16 */
945             xcorr16 *= 1e-12f;
946             num = xcorr16 * xcorr16;
947             if ((num * best_den[1]) > (best_num[1] * Syy)) {
948                 if ((num * best_den[0]) > (best_num[0] * Syy)) {
949                     best_num[1] = best_num[0];
950                     best_den[1] = best_den[0];
951                     best_pitch[1] = best_pitch[0];
952                     best_num[0] = num;
953                     best_den[0] = Syy;
954                     best_pitch[0] = i;
955                 } else {
956                     best_num[1] = num;
957                     best_den[1] = Syy;
958                     best_pitch[1] = i;
959                 }
960             }
961         }
962         Syy += y[i+len]*y[i+len] - y[i] * y[i];
963         Syy = FFMAX(1, Syy);
964     }
965 }
966
967 static void pitch_search(const float *x_lp, float *y,
968                          int len, int max_pitch, int *pitch)
969 {
970     int lag;
971     int best_pitch[2]={0,0};
972     int offset;
973
974     float x_lp4[WINDOW_SIZE];
975     float y_lp4[WINDOW_SIZE];
976     float xcorr[WINDOW_SIZE];
977
978     lag = len+max_pitch;
979
980     /* Downsample by 2 again */
981     for (int j = 0; j < len >> 2; j++)
982         x_lp4[j] = x_lp[2*j];
983     for (int j = 0; j < lag >> 2; j++)
984         y_lp4[j] = y[2*j];
985
986     /* Coarse search with 4x decimation */
987
988     celt_pitch_xcorr(x_lp4, y_lp4, xcorr, len>>2, max_pitch>>2);
989
990     find_best_pitch(xcorr, y_lp4, len>>2, max_pitch>>2, best_pitch);
991
992     /* Finer search with 2x decimation */
993     for (int i = 0; i < max_pitch >> 1; i++) {
994         float sum;
995         xcorr[i] = 0;
996         if (FFABS(i-2*best_pitch[0])>2 && FFABS(i-2*best_pitch[1])>2)
997             continue;
998         sum = celt_inner_prod(x_lp, y+i, len>>1);
999         xcorr[i] = FFMAX(-1, sum);
1000     }
1001
1002     find_best_pitch(xcorr, y, len>>1, max_pitch>>1, best_pitch);
1003
1004     /* Refine by pseudo-interpolation */
1005     if (best_pitch[0] > 0 && best_pitch[0] < (max_pitch >> 1) - 1) {
1006         float a, b, c;
1007
1008         a = xcorr[best_pitch[0] - 1];
1009         b = xcorr[best_pitch[0]];
1010         c = xcorr[best_pitch[0] + 1];
1011         if (c - a > .7f * (b - a))
1012             offset = 1;
1013         else if (a - c > .7f * (b-c))
1014             offset = -1;
1015         else
1016             offset = 0;
1017     } else {
1018         offset = 0;
1019     }
1020
1021     *pitch = 2 * best_pitch[0] - offset;
1022 }
1023
1024 static void dct(AudioRNNContext *s, float *out, const float *in)
1025 {
1026     for (int i = 0; i < NB_BANDS; i++) {
1027         float sum;
1028
1029         sum = s->fdsp->scalarproduct_float(in, s->dct_table[i], FFALIGN(NB_BANDS, 4));
1030         out[i] = sum * sqrtf(2.f / 22);
1031     }
1032 }
1033
1034 static int compute_frame_features(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, AVComplexFloat *P,
1035                                   float *Ex, float *Ep, float *Exp, float *features, const float *in)
1036 {
1037     float E = 0;
1038     float *ceps_0, *ceps_1, *ceps_2;
1039     float spec_variability = 0;
1040     LOCAL_ALIGNED_32(float, Ly, [NB_BANDS]);
1041     LOCAL_ALIGNED_32(float, p, [WINDOW_SIZE]);
1042     float pitch_buf[PITCH_BUF_SIZE>>1];
1043     int pitch_index;
1044     float gain;
1045     float *(pre[1]);
1046     float tmp[NB_BANDS];
1047     float follow, logMax;
1048
1049     frame_analysis(s, st, X, Ex, in);
1050     RNN_MOVE(st->pitch_buf, &st->pitch_buf[FRAME_SIZE], PITCH_BUF_SIZE-FRAME_SIZE);
1051     RNN_COPY(&st->pitch_buf[PITCH_BUF_SIZE-FRAME_SIZE], in, FRAME_SIZE);
1052     pre[0] = &st->pitch_buf[0];
1053     pitch_downsample(pre, pitch_buf, PITCH_BUF_SIZE, 1);
1054     pitch_search(pitch_buf+(PITCH_MAX_PERIOD>>1), pitch_buf, PITCH_FRAME_SIZE,
1055             PITCH_MAX_PERIOD-3*PITCH_MIN_PERIOD, &pitch_index);
1056     pitch_index = PITCH_MAX_PERIOD-pitch_index;
1057
1058     gain = remove_doubling(pitch_buf, PITCH_MAX_PERIOD, PITCH_MIN_PERIOD,
1059             PITCH_FRAME_SIZE, &pitch_index, st->last_period, st->last_gain);
1060     st->last_period = pitch_index;
1061     st->last_gain = gain;
1062
1063     for (int i = 0; i < WINDOW_SIZE; i++)
1064         p[i] = st->pitch_buf[PITCH_BUF_SIZE-WINDOW_SIZE-pitch_index+i];
1065
1066     s->fdsp->vector_fmul(p, p, s->window, WINDOW_SIZE);
1067     forward_transform(st, P, p);
1068     compute_band_energy(Ep, P);
1069     compute_band_corr(Exp, X, P);
1070
1071     for (int i = 0; i < NB_BANDS; i++)
1072         Exp[i] = Exp[i] / sqrtf(.001f+Ex[i]*Ep[i]);
1073
1074     dct(s, tmp, Exp);
1075
1076     for (int i = 0; i < NB_DELTA_CEPS; i++)
1077         features[NB_BANDS+2*NB_DELTA_CEPS+i] = tmp[i];
1078
1079     features[NB_BANDS+2*NB_DELTA_CEPS] -= 1.3;
1080     features[NB_BANDS+2*NB_DELTA_CEPS+1] -= 0.9;
1081     features[NB_BANDS+3*NB_DELTA_CEPS] = .01*(pitch_index-300);
1082     logMax = -2;
1083     follow = -2;
1084
1085     for (int i = 0; i < NB_BANDS; i++) {
1086         Ly[i] = log10f(1e-2f + Ex[i]);
1087         Ly[i] = FFMAX(logMax-7, FFMAX(follow-1.5, Ly[i]));
1088         logMax = FFMAX(logMax, Ly[i]);
1089         follow = FFMAX(follow-1.5, Ly[i]);
1090         E += Ex[i];
1091     }
1092
1093     if (E < 0.04f) {
1094         /* If there's no audio, avoid messing up the state. */
1095         RNN_CLEAR(features, NB_FEATURES);
1096         return 1;
1097     }
1098
1099     dct(s, features, Ly);
1100     features[0] -= 12;
1101     features[1] -= 4;
1102     ceps_0 = st->cepstral_mem[st->memid];
1103     ceps_1 = (st->memid < 1) ? st->cepstral_mem[CEPS_MEM+st->memid-1] : st->cepstral_mem[st->memid-1];
1104     ceps_2 = (st->memid < 2) ? st->cepstral_mem[CEPS_MEM+st->memid-2] : st->cepstral_mem[st->memid-2];
1105
1106     for (int i = 0; i < NB_BANDS; i++)
1107         ceps_0[i] = features[i];
1108
1109     st->memid++;
1110     for (int i = 0; i < NB_DELTA_CEPS; i++) {
1111         features[i] = ceps_0[i] + ceps_1[i] + ceps_2[i];
1112         features[NB_BANDS+i] = ceps_0[i] - ceps_2[i];
1113         features[NB_BANDS+NB_DELTA_CEPS+i] =  ceps_0[i] - 2*ceps_1[i] + ceps_2[i];
1114     }
1115     /* Spectral variability features. */
1116     if (st->memid == CEPS_MEM)
1117         st->memid = 0;
1118
1119     for (int i = 0; i < CEPS_MEM; i++) {
1120         float mindist = 1e15f;
1121         for (int j = 0; j < CEPS_MEM; j++) {
1122             float dist = 0.f;
1123             for (int k = 0; k < NB_BANDS; k++) {
1124                 float tmp;
1125
1126                 tmp = st->cepstral_mem[i][k] - st->cepstral_mem[j][k];
1127                 dist += tmp*tmp;
1128             }
1129
1130             if (j != i)
1131                 mindist = FFMIN(mindist, dist);
1132         }
1133
1134         spec_variability += mindist;
1135     }
1136
1137     features[NB_BANDS+3*NB_DELTA_CEPS+1] = spec_variability/CEPS_MEM-2.1;
1138
1139     return 0;
1140 }
1141
1142 static void interp_band_gain(float *g, const float *bandE)
1143 {
1144     memset(g, 0, sizeof(*g) * FREQ_SIZE);
1145
1146     for (int i = 0; i < NB_BANDS - 1; i++) {
1147         const int band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
1148
1149         for (int j = 0; j < band_size; j++) {
1150             float frac = (float)j / band_size;
1151
1152             g[(eband5ms[i] << FRAME_SIZE_SHIFT) + j] = (1.f - frac) * bandE[i] + frac * bandE[i + 1];
1153         }
1154     }
1155 }
1156
1157 static void pitch_filter(AVComplexFloat *X, const AVComplexFloat *P, const float *Ex, const float *Ep,
1158                          const float *Exp, const float *g)
1159 {
1160     float newE[NB_BANDS];
1161     float r[NB_BANDS];
1162     float norm[NB_BANDS];
1163     float rf[FREQ_SIZE] = {0};
1164     float normf[FREQ_SIZE]={0};
1165
1166     for (int i = 0; i < NB_BANDS; i++) {
1167         if (Exp[i]>g[i]) r[i] = 1;
1168         else r[i] = SQUARE(Exp[i])*(1-SQUARE(g[i]))/(.001 + SQUARE(g[i])*(1-SQUARE(Exp[i])));
1169         r[i]  = sqrtf(av_clipf(r[i], 0, 1));
1170         r[i] *= sqrtf(Ex[i]/(1e-8+Ep[i]));
1171     }
1172     interp_band_gain(rf, r);
1173     for (int i = 0; i < FREQ_SIZE; i++) {
1174         X[i].re += rf[i]*P[i].re;
1175         X[i].im += rf[i]*P[i].im;
1176     }
1177     compute_band_energy(newE, X);
1178     for (int i = 0; i < NB_BANDS; i++) {
1179         norm[i] = sqrtf(Ex[i] / (1e-8+newE[i]));
1180     }
1181     interp_band_gain(normf, norm);
1182     for (int i = 0; i < FREQ_SIZE; i++) {
1183         X[i].re *= normf[i];
1184         X[i].im *= normf[i];
1185     }
1186 }
1187
1188 static const float tansig_table[201] = {
1189     0.000000f, 0.039979f, 0.079830f, 0.119427f, 0.158649f,
1190     0.197375f, 0.235496f, 0.272905f, 0.309507f, 0.345214f,
1191     0.379949f, 0.413644f, 0.446244f, 0.477700f, 0.507977f,
1192     0.537050f, 0.564900f, 0.591519f, 0.616909f, 0.641077f,
1193     0.664037f, 0.685809f, 0.706419f, 0.725897f, 0.744277f,
1194     0.761594f, 0.777888f, 0.793199f, 0.807569f, 0.821040f,
1195     0.833655f, 0.845456f, 0.856485f, 0.866784f, 0.876393f,
1196     0.885352f, 0.893698f, 0.901468f, 0.908698f, 0.915420f,
1197     0.921669f, 0.927473f, 0.932862f, 0.937863f, 0.942503f,
1198     0.946806f, 0.950795f, 0.954492f, 0.957917f, 0.961090f,
1199     0.964028f, 0.966747f, 0.969265f, 0.971594f, 0.973749f,
1200     0.975743f, 0.977587f, 0.979293f, 0.980869f, 0.982327f,
1201     0.983675f, 0.984921f, 0.986072f, 0.987136f, 0.988119f,
1202     0.989027f, 0.989867f, 0.990642f, 0.991359f, 0.992020f,
1203     0.992631f, 0.993196f, 0.993718f, 0.994199f, 0.994644f,
1204     0.995055f, 0.995434f, 0.995784f, 0.996108f, 0.996407f,
1205     0.996682f, 0.996937f, 0.997172f, 0.997389f, 0.997590f,
1206     0.997775f, 0.997946f, 0.998104f, 0.998249f, 0.998384f,
1207     0.998508f, 0.998623f, 0.998728f, 0.998826f, 0.998916f,
1208     0.999000f, 0.999076f, 0.999147f, 0.999213f, 0.999273f,
1209     0.999329f, 0.999381f, 0.999428f, 0.999472f, 0.999513f,
1210     0.999550f, 0.999585f, 0.999617f, 0.999646f, 0.999673f,
1211     0.999699f, 0.999722f, 0.999743f, 0.999763f, 0.999781f,
1212     0.999798f, 0.999813f, 0.999828f, 0.999841f, 0.999853f,
1213     0.999865f, 0.999875f, 0.999885f, 0.999893f, 0.999902f,
1214     0.999909f, 0.999916f, 0.999923f, 0.999929f, 0.999934f,
1215     0.999939f, 0.999944f, 0.999948f, 0.999952f, 0.999956f,
1216     0.999959f, 0.999962f, 0.999965f, 0.999968f, 0.999970f,
1217     0.999973f, 0.999975f, 0.999977f, 0.999978f, 0.999980f,
1218     0.999982f, 0.999983f, 0.999984f, 0.999986f, 0.999987f,
1219     0.999988f, 0.999989f, 0.999990f, 0.999990f, 0.999991f,
1220     0.999992f, 0.999992f, 0.999993f, 0.999994f, 0.999994f,
1221     0.999994f, 0.999995f, 0.999995f, 0.999996f, 0.999996f,
1222     0.999996f, 0.999997f, 0.999997f, 0.999997f, 0.999997f,
1223     0.999997f, 0.999998f, 0.999998f, 0.999998f, 0.999998f,
1224     0.999998f, 0.999998f, 0.999999f, 0.999999f, 0.999999f,
1225     0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
1226     0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
1227     1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1228     1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1229     1.000000f,
1230 };
1231
1232 static inline float tansig_approx(float x)
1233 {
1234     float y, dy;
1235     float sign=1;
1236     int i;
1237
1238     /* Tests are reversed to catch NaNs */
1239     if (!(x<8))
1240         return 1;
1241     if (!(x>-8))
1242         return -1;
1243     /* Another check in case of -ffast-math */
1244
1245     if (isnan(x))
1246        return 0;
1247
1248     if (x < 0) {
1249        x=-x;
1250        sign=-1;
1251     }
1252     i = (int)floor(.5f+25*x);
1253     x -= .04f*i;
1254     y = tansig_table[i];
1255     dy = 1-y*y;
1256     y = y + x*dy*(1 - y*x);
1257     return sign*y;
1258 }
1259
1260 static inline float sigmoid_approx(float x)
1261 {
1262     return .5f + .5f*tansig_approx(.5f*x);
1263 }
1264
1265 static void compute_dense(const DenseLayer *layer, float *output, const float *input)
1266 {
1267     const int N = layer->nb_neurons, M = layer->nb_inputs, stride = N;
1268
1269     for (int i = 0; i < N; i++) {
1270         /* Compute update gate. */
1271         float sum = layer->bias[i];
1272
1273         for (int j = 0; j < M; j++)
1274             sum += layer->input_weights[j * stride + i] * input[j];
1275
1276         output[i] = WEIGHTS_SCALE * sum;
1277     }
1278
1279     if (layer->activation == ACTIVATION_SIGMOID) {
1280         for (int i = 0; i < N; i++)
1281             output[i] = sigmoid_approx(output[i]);
1282     } else if (layer->activation == ACTIVATION_TANH) {
1283         for (int i = 0; i < N; i++)
1284             output[i] = tansig_approx(output[i]);
1285     } else if (layer->activation == ACTIVATION_RELU) {
1286         for (int i = 0; i < N; i++)
1287             output[i] = FFMAX(0, output[i]);
1288     } else {
1289         av_assert0(0);
1290     }
1291 }
1292
1293 static void compute_gru(AudioRNNContext *s, const GRULayer *gru, float *state, const float *input)
1294 {
1295     LOCAL_ALIGNED_32(float, z, [MAX_NEURONS]);
1296     LOCAL_ALIGNED_32(float, r, [MAX_NEURONS]);
1297     LOCAL_ALIGNED_32(float, h, [MAX_NEURONS]);
1298     const int M = gru->nb_inputs;
1299     const int N = gru->nb_neurons;
1300     const int AN = FFALIGN(N, 4);
1301     const int AM = FFALIGN(M, 4);
1302     const int stride = 3 * AN, istride = 3 * AM;
1303
1304     for (int i = 0; i < N; i++) {
1305         /* Compute update gate. */
1306         float sum = gru->bias[i];
1307
1308         sum += s->fdsp->scalarproduct_float(gru->input_weights + i * istride, input, AM);
1309         sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + i * stride, state, AN);
1310         z[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1311     }
1312
1313     for (int i = 0; i < N; i++) {
1314         /* Compute reset gate. */
1315         float sum = gru->bias[N + i];
1316
1317         sum += s->fdsp->scalarproduct_float(gru->input_weights + AM + i * istride, input, AM);
1318         sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + AN + i * stride, state, AN);
1319         r[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1320     }
1321
1322     for (int i = 0; i < N; i++) {
1323         /* Compute output. */
1324         float sum = gru->bias[2 * N + i];
1325
1326         sum += s->fdsp->scalarproduct_float(gru->input_weights + 2 * AM + i * istride, input, AM);
1327         for (int j = 0; j < N; j++)
1328             sum += gru->recurrent_weights[2 * AN + i * stride + j] * state[j] * r[j];
1329
1330         if (gru->activation == ACTIVATION_SIGMOID)
1331             sum = sigmoid_approx(WEIGHTS_SCALE * sum);
1332         else if (gru->activation == ACTIVATION_TANH)
1333             sum = tansig_approx(WEIGHTS_SCALE * sum);
1334         else if (gru->activation == ACTIVATION_RELU)
1335             sum = FFMAX(0, WEIGHTS_SCALE * sum);
1336         else
1337             av_assert0(0);
1338         h[i] = z[i] * state[i] + (1.f - z[i]) * sum;
1339     }
1340
1341     RNN_COPY(state, h, N);
1342 }
1343
1344 #define INPUT_SIZE 42
1345
1346 static void compute_rnn(AudioRNNContext *s, RNNState *rnn, float *gains, float *vad, const float *input)
1347 {
1348     LOCAL_ALIGNED_32(float, dense_out,     [MAX_NEURONS]);
1349     LOCAL_ALIGNED_32(float, noise_input,   [MAX_NEURONS * 3]);
1350     LOCAL_ALIGNED_32(float, denoise_input, [MAX_NEURONS * 3]);
1351
1352     compute_dense(rnn->model->input_dense, dense_out, input);
1353     compute_gru(s, rnn->model->vad_gru, rnn->vad_gru_state, dense_out);
1354     compute_dense(rnn->model->vad_output, vad, rnn->vad_gru_state);
1355
1356     memcpy(noise_input, dense_out, rnn->model->input_dense_size * sizeof(float));
1357     memcpy(noise_input + rnn->model->input_dense_size,
1358            rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
1359     memcpy(noise_input + rnn->model->input_dense_size + rnn->model->vad_gru_size,
1360            input, INPUT_SIZE * sizeof(float));
1361
1362     compute_gru(s, rnn->model->noise_gru, rnn->noise_gru_state, noise_input);
1363
1364     memcpy(denoise_input, rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
1365     memcpy(denoise_input + rnn->model->vad_gru_size,
1366            rnn->noise_gru_state, rnn->model->noise_gru_size * sizeof(float));
1367     memcpy(denoise_input + rnn->model->vad_gru_size + rnn->model->noise_gru_size,
1368            input, INPUT_SIZE * sizeof(float));
1369
1370     compute_gru(s, rnn->model->denoise_gru, rnn->denoise_gru_state, denoise_input);
1371     compute_dense(rnn->model->denoise_output, gains, rnn->denoise_gru_state);
1372 }
1373
1374 static float rnnoise_channel(AudioRNNContext *s, DenoiseState *st, float *out, const float *in,
1375                              int disabled)
1376 {
1377     AVComplexFloat X[FREQ_SIZE];
1378     AVComplexFloat P[WINDOW_SIZE];
1379     float x[FRAME_SIZE];
1380     float Ex[NB_BANDS], Ep[NB_BANDS];
1381     LOCAL_ALIGNED_32(float, Exp, [NB_BANDS]);
1382     float features[NB_FEATURES];
1383     float g[NB_BANDS];
1384     float gf[FREQ_SIZE];
1385     float vad_prob = 0;
1386     float *history = st->history;
1387     static const float a_hp[2] = {-1.99599, 0.99600};
1388     static const float b_hp[2] = {-2, 1};
1389     int silence;
1390
1391     biquad(x, st->mem_hp_x, in, b_hp, a_hp, FRAME_SIZE);
1392     silence = compute_frame_features(s, st, X, P, Ex, Ep, Exp, features, x);
1393
1394     if (!silence && !disabled) {
1395         compute_rnn(s, &st->rnn[0], g, &vad_prob, features);
1396         pitch_filter(X, P, Ex, Ep, Exp, g);
1397         for (int i = 0; i < NB_BANDS; i++) {
1398             float alpha = .6f;
1399
1400             g[i] = FFMAX(g[i], alpha * st->lastg[i]);
1401             st->lastg[i] = g[i];
1402         }
1403
1404         interp_band_gain(gf, g);
1405
1406         for (int i = 0; i < FREQ_SIZE; i++) {
1407             X[i].re *= gf[i];
1408             X[i].im *= gf[i];
1409         }
1410     }
1411
1412     frame_synthesis(s, st, out, X);
1413     memcpy(history, in, FRAME_SIZE * sizeof(*history));
1414
1415     return vad_prob;
1416 }
1417
1418 typedef struct ThreadData {
1419     AVFrame *in, *out;
1420 } ThreadData;
1421
1422 static int rnnoise_channels(AVFilterContext *ctx, void *arg, int jobnr, int nb_jobs)
1423 {
1424     AudioRNNContext *s = ctx->priv;
1425     ThreadData *td = arg;
1426     AVFrame *in = td->in;
1427     AVFrame *out = td->out;
1428     const int start = (out->channels * jobnr) / nb_jobs;
1429     const int end = (out->channels * (jobnr+1)) / nb_jobs;
1430
1431     for (int ch = start; ch < end; ch++) {
1432         rnnoise_channel(s, &s->st[ch],
1433                         (float *)out->extended_data[ch],
1434                         (const float *)in->extended_data[ch],
1435                         ctx->is_disabled);
1436     }
1437
1438     return 0;
1439 }
1440
1441 static int filter_frame(AVFilterLink *inlink, AVFrame *in)
1442 {
1443     AVFilterContext *ctx = inlink->dst;
1444     AVFilterLink *outlink = ctx->outputs[0];
1445     AVFrame *out = NULL;
1446     ThreadData td;
1447
1448     out = ff_get_audio_buffer(outlink, FRAME_SIZE);
1449     if (!out) {
1450         av_frame_free(&in);
1451         return AVERROR(ENOMEM);
1452     }
1453     out->pts = in->pts;
1454
1455     td.in = in; td.out = out;
1456     ctx->internal->execute(ctx, rnnoise_channels, &td, NULL, FFMIN(outlink->channels,
1457                                                                    ff_filter_get_nb_threads(ctx)));
1458
1459     av_frame_free(&in);
1460     return ff_filter_frame(outlink, out);
1461 }
1462
1463 static int activate(AVFilterContext *ctx)
1464 {
1465     AVFilterLink *inlink = ctx->inputs[0];
1466     AVFilterLink *outlink = ctx->outputs[0];
1467     AVFrame *in = NULL;
1468     int ret;
1469
1470     FF_FILTER_FORWARD_STATUS_BACK(outlink, inlink);
1471
1472     ret = ff_inlink_consume_samples(inlink, FRAME_SIZE, FRAME_SIZE, &in);
1473     if (ret < 0)
1474         return ret;
1475
1476     if (ret > 0)
1477         return filter_frame(inlink, in);
1478
1479     FF_FILTER_FORWARD_STATUS(inlink, outlink);
1480     FF_FILTER_FORWARD_WANTED(outlink, inlink);
1481
1482     return FFERROR_NOT_READY;
1483 }
1484
1485 static int open_model(AVFilterContext *ctx, RNNModel **model)
1486 {
1487     AudioRNNContext *s = ctx->priv;
1488     int ret;
1489     FILE *f;
1490
1491     if (!s->model_name)
1492         return AVERROR(EINVAL);
1493     f = av_fopen_utf8(s->model_name, "r");
1494     if (!f) {
1495         av_log(ctx, AV_LOG_ERROR, "Failed to open model file: %s\n", s->model_name);
1496         return AVERROR(EINVAL);
1497     }
1498
1499     ret = rnnoise_model_from_file(f, model);
1500     fclose(f);
1501     if (!*model || ret < 0)
1502         return ret;
1503
1504     return 0;
1505 }
1506
1507 static av_cold int init(AVFilterContext *ctx)
1508 {
1509     AudioRNNContext *s = ctx->priv;
1510     int ret;
1511
1512     s->fdsp = avpriv_float_dsp_alloc(0);
1513     if (!s->fdsp)
1514         return AVERROR(ENOMEM);
1515
1516     ret = open_model(ctx, &s->model[0]);
1517     if (ret < 0)
1518         return ret;
1519
1520     for (int i = 0; i < FRAME_SIZE; i++) {
1521         s->window[i] = sin(.5*M_PI*sin(.5*M_PI*(i+.5)/FRAME_SIZE) * sin(.5*M_PI*(i+.5)/FRAME_SIZE));
1522         s->window[WINDOW_SIZE - 1 - i] = s->window[i];
1523     }
1524
1525     for (int i = 0; i < NB_BANDS; i++) {
1526         for (int j = 0; j < NB_BANDS; j++) {
1527             s->dct_table[j][i] = cosf((i + .5f) * j * M_PI / NB_BANDS);
1528             if (j == 0)
1529                 s->dct_table[j][i] *= sqrtf(.5);
1530         }
1531     }
1532
1533     return 0;
1534 }
1535
1536 static void free_model(AVFilterContext *ctx, int n)
1537 {
1538     AudioRNNContext *s = ctx->priv;
1539
1540     rnnoise_model_free(s->model[n]);
1541     s->model[n] = NULL;
1542
1543     for (int ch = 0; ch < s->channels && s->st; ch++) {
1544         av_freep(&s->st[ch].rnn[n].vad_gru_state);
1545         av_freep(&s->st[ch].rnn[n].noise_gru_state);
1546         av_freep(&s->st[ch].rnn[n].denoise_gru_state);
1547     }
1548 }
1549
1550 static int process_command(AVFilterContext *ctx, const char *cmd, const char *args,
1551                            char *res, int res_len, int flags)
1552 {
1553     AudioRNNContext *s = ctx->priv;
1554     int ret;
1555
1556     ret = ff_filter_process_command(ctx, cmd, args, res, res_len, flags);
1557     if (ret < 0)
1558         return ret;
1559
1560     ret = open_model(ctx, &s->model[1]);
1561     if (ret < 0)
1562         return ret;
1563
1564     FFSWAP(RNNModel *, s->model[0], s->model[1]);
1565     for (int ch = 0; ch < s->channels; ch++)
1566         FFSWAP(RNNState, s->st[ch].rnn[0], s->st[ch].rnn[1]);
1567
1568     ret = config_input(ctx->inputs[0]);
1569     if (ret < 0) {
1570         for (int ch = 0; ch < s->channels; ch++)
1571             FFSWAP(RNNState, s->st[ch].rnn[0], s->st[ch].rnn[1]);
1572         FFSWAP(RNNModel *, s->model[0], s->model[1]);
1573         return ret;
1574     }
1575
1576     free_model(ctx, 1);
1577     return 0;
1578 }
1579
1580 static av_cold void uninit(AVFilterContext *ctx)
1581 {
1582     AudioRNNContext *s = ctx->priv;
1583
1584     av_freep(&s->fdsp);
1585     free_model(ctx, 0);
1586     for (int ch = 0; ch < s->channels && s->st; ch++) {
1587         av_tx_uninit(&s->st[ch].tx);
1588         av_tx_uninit(&s->st[ch].txi);
1589     }
1590     av_freep(&s->st);
1591 }
1592
1593 static const AVFilterPad inputs[] = {
1594     {
1595         .name         = "default",
1596         .type         = AVMEDIA_TYPE_AUDIO,
1597         .config_props = config_input,
1598     },
1599     { NULL }
1600 };
1601
1602 static const AVFilterPad outputs[] = {
1603     {
1604         .name          = "default",
1605         .type          = AVMEDIA_TYPE_AUDIO,
1606     },
1607     { NULL }
1608 };
1609
1610 #define OFFSET(x) offsetof(AudioRNNContext, x)
1611 #define AF AV_OPT_FLAG_AUDIO_PARAM|AV_OPT_FLAG_FILTERING_PARAM|AV_OPT_FLAG_RUNTIME_PARAM
1612
1613 static const AVOption arnndn_options[] = {
1614     { "model", "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1615     { "m",     "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1616     { "mix",   "set output vs input mix", OFFSET(mix), AV_OPT_TYPE_FLOAT, {.dbl=1.0},-1, 1, AF },
1617     { NULL }
1618 };
1619
1620 AVFILTER_DEFINE_CLASS(arnndn);
1621
1622 const AVFilter ff_af_arnndn = {
1623     .name          = "arnndn",
1624     .description   = NULL_IF_CONFIG_SMALL("Reduce noise from speech using Recurrent Neural Networks."),
1625     .query_formats = query_formats,
1626     .priv_size     = sizeof(AudioRNNContext),
1627     .priv_class    = &arnndn_class,
1628     .activate      = activate,
1629     .init          = init,
1630     .uninit        = uninit,
1631     .inputs        = inputs,
1632     .outputs       = outputs,
1633     .flags         = AVFILTER_FLAG_SUPPORT_TIMELINE_INTERNAL |
1634                      AVFILTER_FLAG_SLICE_THREADS,
1635     .process_command = process_command,
1636 };