]> git.sesse.net Git - ffmpeg/blob - libavfilter/af_arnndn.c
avutil/opt: check return value of av_bprint_finalize()
[ffmpeg] / libavfilter / af_arnndn.c
1 /*
2  * Copyright (c) 2018 Gregor Richards
3  * Copyright (c) 2017 Mozilla
4  * Copyright (c) 2005-2009 Xiph.Org Foundation
5  * Copyright (c) 2007-2008 CSIRO
6  * Copyright (c) 2008-2011 Octasic Inc.
7  * Copyright (c) Jean-Marc Valin
8  * Copyright (c) 2019 Paul B Mahol
9  *
10  * Redistribution and use in source and binary forms, with or without
11  * modification, are permitted provided that the following conditions
12  * are met:
13  *
14  * - Redistributions of source code must retain the above copyright
15  *   notice, this list of conditions and the following disclaimer.
16  *
17  * - Redistributions in binary form must reproduce the above copyright
18  *   notice, this list of conditions and the following disclaimer in the
19  *   documentation and/or other materials provided with the distribution.
20  *
21  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22  * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24  * A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25  * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26  * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27  * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29  * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30  * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32  */
33
34 #include <float.h>
35
36 #include "libavutil/avassert.h"
37 #include "libavutil/avstring.h"
38 #include "libavutil/float_dsp.h"
39 #include "libavutil/opt.h"
40 #include "libavutil/tx.h"
41 #include "avfilter.h"
42 #include "audio.h"
43 #include "filters.h"
44 #include "formats.h"
45
46 #define FRAME_SIZE_SHIFT 2
47 #define FRAME_SIZE (120<<FRAME_SIZE_SHIFT)
48 #define WINDOW_SIZE (2*FRAME_SIZE)
49 #define FREQ_SIZE (FRAME_SIZE + 1)
50
51 #define PITCH_MIN_PERIOD 60
52 #define PITCH_MAX_PERIOD 768
53 #define PITCH_FRAME_SIZE 960
54 #define PITCH_BUF_SIZE (PITCH_MAX_PERIOD+PITCH_FRAME_SIZE)
55
56 #define SQUARE(x) ((x)*(x))
57
58 #define NB_BANDS 22
59
60 #define CEPS_MEM 8
61 #define NB_DELTA_CEPS 6
62
63 #define NB_FEATURES (NB_BANDS+3*NB_DELTA_CEPS+2)
64
65 #define WEIGHTS_SCALE (1.f/256)
66
67 #define MAX_NEURONS 128
68
69 #define ACTIVATION_TANH    0
70 #define ACTIVATION_SIGMOID 1
71 #define ACTIVATION_RELU    2
72
73 #define Q15ONE 1.0f
74
75 typedef struct DenseLayer {
76     const float *bias;
77     const float *input_weights;
78     int nb_inputs;
79     int nb_neurons;
80     int activation;
81 } DenseLayer;
82
83 typedef struct GRULayer {
84     const float *bias;
85     const float *input_weights;
86     const float *recurrent_weights;
87     int nb_inputs;
88     int nb_neurons;
89     int activation;
90 } GRULayer;
91
92 typedef struct RNNModel {
93     int input_dense_size;
94     const DenseLayer *input_dense;
95
96     int vad_gru_size;
97     const GRULayer *vad_gru;
98
99     int noise_gru_size;
100     const GRULayer *noise_gru;
101
102     int denoise_gru_size;
103     const GRULayer *denoise_gru;
104
105     int denoise_output_size;
106     const DenseLayer *denoise_output;
107
108     int vad_output_size;
109     const DenseLayer *vad_output;
110 } RNNModel;
111
112 typedef struct RNNState {
113     float *vad_gru_state;
114     float *noise_gru_state;
115     float *denoise_gru_state;
116     RNNModel *model;
117 } RNNState;
118
119 typedef struct DenoiseState {
120     float analysis_mem[FRAME_SIZE];
121     float cepstral_mem[CEPS_MEM][NB_BANDS];
122     int memid;
123     DECLARE_ALIGNED(32, float, synthesis_mem)[FRAME_SIZE];
124     float pitch_buf[PITCH_BUF_SIZE];
125     float pitch_enh_buf[PITCH_BUF_SIZE];
126     float last_gain;
127     int last_period;
128     float mem_hp_x[2];
129     float lastg[NB_BANDS];
130     RNNState rnn;
131     AVTXContext *tx, *txi;
132     av_tx_fn tx_fn, txi_fn;
133 } DenoiseState;
134
135 typedef struct AudioRNNContext {
136     const AVClass *class;
137
138     char *model_name;
139
140     int channels;
141     DenoiseState *st;
142
143     DECLARE_ALIGNED(32, float, window)[WINDOW_SIZE];
144     float dct_table[NB_BANDS*NB_BANDS];
145
146     RNNModel *model;
147
148     AVFloatDSPContext *fdsp;
149 } AudioRNNContext;
150
151 #define F_ACTIVATION_TANH       0
152 #define F_ACTIVATION_SIGMOID    1
153 #define F_ACTIVATION_RELU       2
154
155 static void rnnoise_model_free(RNNModel *model)
156 {
157 #define FREE_MAYBE(ptr) do { if (ptr) free(ptr); } while (0)
158 #define FREE_DENSE(name) do { \
159     if (model->name) { \
160         av_free((void *) model->name->input_weights); \
161         av_free((void *) model->name->bias); \
162         av_free((void *) model->name); \
163     } \
164     } while (0)
165 #define FREE_GRU(name) do { \
166     if (model->name) { \
167         av_free((void *) model->name->input_weights); \
168         av_free((void *) model->name->recurrent_weights); \
169         av_free((void *) model->name->bias); \
170         av_free((void *) model->name); \
171     } \
172     } while (0)
173
174     if (!model)
175         return;
176     FREE_DENSE(input_dense);
177     FREE_GRU(vad_gru);
178     FREE_GRU(noise_gru);
179     FREE_GRU(denoise_gru);
180     FREE_DENSE(denoise_output);
181     FREE_DENSE(vad_output);
182     av_free(model);
183 }
184
185 static RNNModel *rnnoise_model_from_file(FILE *f)
186 {
187     RNNModel *ret;
188     DenseLayer *input_dense;
189     GRULayer *vad_gru;
190     GRULayer *noise_gru;
191     GRULayer *denoise_gru;
192     DenseLayer *denoise_output;
193     DenseLayer *vad_output;
194     int in;
195
196     if (fscanf(f, "rnnoise-nu model file version %d\n", &in) != 1 || in != 1)
197         return NULL;
198
199     ret = av_calloc(1, sizeof(RNNModel));
200     if (!ret)
201         return NULL;
202
203 #define ALLOC_LAYER(type, name) \
204     name = av_calloc(1, sizeof(type)); \
205     if (!name) { \
206         rnnoise_model_free(ret); \
207         return NULL; \
208     } \
209     ret->name = name
210
211     ALLOC_LAYER(DenseLayer, input_dense);
212     ALLOC_LAYER(GRULayer, vad_gru);
213     ALLOC_LAYER(GRULayer, noise_gru);
214     ALLOC_LAYER(GRULayer, denoise_gru);
215     ALLOC_LAYER(DenseLayer, denoise_output);
216     ALLOC_LAYER(DenseLayer, vad_output);
217
218 #define INPUT_VAL(name) do { \
219     if (fscanf(f, "%d", &in) != 1 || in < 0 || in > 128) { \
220         rnnoise_model_free(ret); \
221         return NULL; \
222     } \
223     name = in; \
224     } while (0)
225
226 #define INPUT_ACTIVATION(name) do { \
227     int activation; \
228     INPUT_VAL(activation); \
229     switch (activation) { \
230     case F_ACTIVATION_SIGMOID: \
231         name = ACTIVATION_SIGMOID; \
232         break; \
233     case F_ACTIVATION_RELU: \
234         name = ACTIVATION_RELU; \
235         break; \
236     default: \
237         name = ACTIVATION_TANH; \
238     } \
239     } while (0)
240
241 #define INPUT_ARRAY(name, len) do { \
242     float *values = av_calloc((len), sizeof(float)); \
243     if (!values) { \
244         rnnoise_model_free(ret); \
245         return NULL; \
246     } \
247     name = values; \
248     for (int i = 0; i < (len); i++) { \
249         if (fscanf(f, "%d", &in) != 1) { \
250             rnnoise_model_free(ret); \
251             return NULL; \
252         } \
253         values[i] = in; \
254     } \
255     } while (0)
256
257 #define INPUT_ARRAY3(name, len0, len1, len2) do { \
258     float *values = av_calloc(FFALIGN((len0), 4) * FFALIGN((len1), 4) * (len2), sizeof(float)); \
259     if (!values) { \
260         rnnoise_model_free(ret); \
261         return NULL; \
262     } \
263     name = values; \
264     for (int k = 0; k < (len0); k++) { \
265         for (int i = 0; i < (len2); i++) { \
266             for (int j = 0; j < (len1); j++) { \
267                 if (fscanf(f, "%d", &in) != 1) { \
268                     rnnoise_model_free(ret); \
269                     return NULL; \
270                 } \
271                 values[j * (len2) * FFALIGN((len0), 4) + i * FFALIGN((len0), 4) + k] = in; \
272             } \
273         } \
274     } \
275     } while (0)
276
277 #define INPUT_DENSE(name) do { \
278     INPUT_VAL(name->nb_inputs); \
279     INPUT_VAL(name->nb_neurons); \
280     ret->name ## _size = name->nb_neurons; \
281     INPUT_ACTIVATION(name->activation); \
282     INPUT_ARRAY(name->input_weights, name->nb_inputs * name->nb_neurons); \
283     INPUT_ARRAY(name->bias, name->nb_neurons); \
284     } while (0)
285
286 #define INPUT_GRU(name) do { \
287     INPUT_VAL(name->nb_inputs); \
288     INPUT_VAL(name->nb_neurons); \
289     ret->name ## _size = name->nb_neurons; \
290     INPUT_ACTIVATION(name->activation); \
291     INPUT_ARRAY3(name->input_weights, name->nb_inputs, name->nb_neurons, 3); \
292     INPUT_ARRAY3(name->recurrent_weights, name->nb_neurons, name->nb_neurons, 3); \
293     INPUT_ARRAY(name->bias, name->nb_neurons * 3); \
294     } while (0)
295
296     INPUT_DENSE(input_dense);
297     INPUT_GRU(vad_gru);
298     INPUT_GRU(noise_gru);
299     INPUT_GRU(denoise_gru);
300     INPUT_DENSE(denoise_output);
301     INPUT_DENSE(vad_output);
302
303     if (vad_output->nb_neurons != 1) {
304         rnnoise_model_free(ret);
305         return NULL;
306     }
307
308     return ret;
309 }
310
311 static int query_formats(AVFilterContext *ctx)
312 {
313     AVFilterFormats *formats = NULL;
314     AVFilterChannelLayouts *layouts = NULL;
315     static const enum AVSampleFormat sample_fmts[] = {
316         AV_SAMPLE_FMT_FLTP,
317         AV_SAMPLE_FMT_NONE
318     };
319     int ret, sample_rates[] = { 48000, -1 };
320
321     formats = ff_make_format_list(sample_fmts);
322     if (!formats)
323         return AVERROR(ENOMEM);
324     ret = ff_set_common_formats(ctx, formats);
325     if (ret < 0)
326         return ret;
327
328     layouts = ff_all_channel_counts();
329     if (!layouts)
330         return AVERROR(ENOMEM);
331
332     ret = ff_set_common_channel_layouts(ctx, layouts);
333     if (ret < 0)
334         return ret;
335
336     formats = ff_make_format_list(sample_rates);
337     if (!formats)
338         return AVERROR(ENOMEM);
339     return ff_set_common_samplerates(ctx, formats);
340 }
341
342 static int config_input(AVFilterLink *inlink)
343 {
344     AVFilterContext *ctx = inlink->dst;
345     AudioRNNContext *s = ctx->priv;
346     int ret;
347
348     s->channels = inlink->channels;
349
350     s->st = av_calloc(s->channels, sizeof(DenoiseState));
351     if (!s->st)
352         return AVERROR(ENOMEM);
353
354     for (int i = 0; i < s->channels; i++) {
355         DenoiseState *st = &s->st[i];
356
357         st->rnn.model = s->model;
358         st->rnn.vad_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->vad_gru_size, 16));
359         st->rnn.noise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->noise_gru_size, 16));
360         st->rnn.denoise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->denoise_gru_size, 16));
361         if (!st->rnn.vad_gru_state ||
362             !st->rnn.noise_gru_state ||
363             !st->rnn.denoise_gru_state)
364             return AVERROR(ENOMEM);
365
366         ret = av_tx_init(&st->tx, &st->tx_fn, AV_TX_FLOAT_FFT, 0, WINDOW_SIZE, NULL, 0);
367         if (ret < 0)
368             return ret;
369
370         ret = av_tx_init(&st->txi, &st->txi_fn, AV_TX_FLOAT_FFT, 1, WINDOW_SIZE, NULL, 0);
371         if (ret < 0)
372             return ret;
373     }
374
375     return 0;
376 }
377
378 static void biquad(float *y, float mem[2], const float *x,
379                    const float *b, const float *a, int N)
380 {
381     for (int i = 0; i < N; i++) {
382         float xi, yi;
383
384         xi = x[i];
385         yi = x[i] + mem[0];
386         mem[0] = mem[1] + (b[0]*xi - a[0]*yi);
387         mem[1] = (b[1]*xi - a[1]*yi);
388         y[i] = yi;
389     }
390 }
391
392 #define RNN_MOVE(dst, src, n) (memmove((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
393 #define RNN_CLEAR(dst, n) (memset((dst), 0, (n)*sizeof(*(dst))))
394 #define RNN_COPY(dst, src, n) (memcpy((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
395
396 static void forward_transform(DenoiseState *st, AVComplexFloat *out, const float *in)
397 {
398     AVComplexFloat x[WINDOW_SIZE];
399     AVComplexFloat y[WINDOW_SIZE];
400
401     for (int i = 0; i < WINDOW_SIZE; i++) {
402         x[i].re = in[i];
403         x[i].im = 0;
404     }
405
406     st->tx_fn(st->tx, y, x, sizeof(float));
407
408     RNN_COPY(out, y, FREQ_SIZE);
409 }
410
411 static void inverse_transform(DenoiseState *st, float *out, const AVComplexFloat *in)
412 {
413     AVComplexFloat x[WINDOW_SIZE];
414     AVComplexFloat y[WINDOW_SIZE];
415
416     for (int i = 0; i < FREQ_SIZE; i++)
417         x[i] = in[i];
418
419     for (int i = FREQ_SIZE; i < WINDOW_SIZE; i++) {
420         x[i].re =  x[WINDOW_SIZE - i].re;
421         x[i].im = -x[WINDOW_SIZE - i].im;
422     }
423
424     st->txi_fn(st->txi, y, x, sizeof(float));
425
426     for (int i = 0; i < WINDOW_SIZE; i++)
427         out[i] = y[i].re / WINDOW_SIZE;
428 }
429
430 static const uint8_t eband5ms[] = {
431 /*0  200 400 600 800  1k 1.2 1.4 1.6  2k 2.4 2.8 3.2  4k 4.8 5.6 6.8  8k 9.6 12k 15.6 20k*/
432   0,  1,  2,  3,  4,   5, 6,  7,  8,  10, 12, 14, 16, 20, 24, 28, 34, 40, 48, 60, 78, 100
433 };
434
435 static void compute_band_energy(float *bandE, const AVComplexFloat *X)
436 {
437     float sum[NB_BANDS] = {0};
438
439     for (int i = 0; i < NB_BANDS - 1; i++) {
440         int band_size;
441
442         band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
443         for (int j = 0; j < band_size; j++) {
444             float tmp, frac = (float)j / band_size;
445
446             tmp         = SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].re);
447             tmp        += SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].im);
448             sum[i]     += (1.f - frac) * tmp;
449             sum[i + 1] +=        frac  * tmp;
450         }
451     }
452
453     sum[0] *= 2;
454     sum[NB_BANDS - 1] *= 2;
455
456     for (int i = 0; i < NB_BANDS; i++)
457         bandE[i] = sum[i];
458 }
459
460 static void compute_band_corr(float *bandE, const AVComplexFloat *X, const AVComplexFloat *P)
461 {
462     float sum[NB_BANDS] = { 0 };
463
464     for (int i = 0; i < NB_BANDS - 1; i++) {
465         int band_size;
466
467         band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
468         for (int j = 0; j < band_size; j++) {
469             float tmp, frac = (float)j / band_size;
470
471             tmp  = X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re;
472             tmp += X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im;
473             sum[i]     += (1 - frac) * tmp;
474             sum[i + 1] +=      frac  * tmp;
475         }
476     }
477
478     sum[0] *= 2;
479     sum[NB_BANDS-1] *= 2;
480
481     for (int i = 0; i < NB_BANDS; i++)
482         bandE[i] = sum[i];
483 }
484
485 static void frame_analysis(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, float *Ex, const float *in)
486 {
487     LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
488
489     RNN_COPY(x, st->analysis_mem, FRAME_SIZE);
490     RNN_COPY(x + FRAME_SIZE, in, FRAME_SIZE);
491     RNN_COPY(st->analysis_mem, in, FRAME_SIZE);
492     s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
493     forward_transform(st, X, x);
494     compute_band_energy(Ex, X);
495 }
496
497 static void frame_synthesis(AudioRNNContext *s, DenoiseState *st, float *out, const AVComplexFloat *y)
498 {
499     LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
500
501     inverse_transform(st, x, y);
502     s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
503     s->fdsp->vector_fmac_scalar(x, st->synthesis_mem, 1.f, FRAME_SIZE);
504     RNN_COPY(out, x, FRAME_SIZE);
505     RNN_COPY(st->synthesis_mem, &x[FRAME_SIZE], FRAME_SIZE);
506 }
507
508 static inline void xcorr_kernel(const float *x, const float *y, float sum[4], int len)
509 {
510     float y_0, y_1, y_2, y_3 = 0;
511     int j;
512
513     y_0 = *y++;
514     y_1 = *y++;
515     y_2 = *y++;
516
517     for (j = 0; j < len - 3; j += 4) {
518         float tmp;
519
520         tmp = *x++;
521         y_3 = *y++;
522         sum[0] += tmp * y_0;
523         sum[1] += tmp * y_1;
524         sum[2] += tmp * y_2;
525         sum[3] += tmp * y_3;
526         tmp = *x++;
527         y_0 = *y++;
528         sum[0] += tmp * y_1;
529         sum[1] += tmp * y_2;
530         sum[2] += tmp * y_3;
531         sum[3] += tmp * y_0;
532         tmp = *x++;
533         y_1 = *y++;
534         sum[0] += tmp * y_2;
535         sum[1] += tmp * y_3;
536         sum[2] += tmp * y_0;
537         sum[3] += tmp * y_1;
538         tmp = *x++;
539         y_2 = *y++;
540         sum[0] += tmp * y_3;
541         sum[1] += tmp * y_0;
542         sum[2] += tmp * y_1;
543         sum[3] += tmp * y_2;
544     }
545
546     if (j++ < len) {
547         float tmp = *x++;
548
549         y_3 = *y++;
550         sum[0] += tmp * y_0;
551         sum[1] += tmp * y_1;
552         sum[2] += tmp * y_2;
553         sum[3] += tmp * y_3;
554     }
555
556     if (j++ < len) {
557         float tmp=*x++;
558
559         y_0 = *y++;
560         sum[0] += tmp * y_1;
561         sum[1] += tmp * y_2;
562         sum[2] += tmp * y_3;
563         sum[3] += tmp * y_0;
564     }
565
566     if (j < len) {
567         float tmp=*x++;
568
569         y_1 = *y++;
570         sum[0] += tmp * y_2;
571         sum[1] += tmp * y_3;
572         sum[2] += tmp * y_0;
573         sum[3] += tmp * y_1;
574     }
575 }
576
577 static inline float celt_inner_prod(const float *x,
578                                     const float *y, int N)
579 {
580     float xy = 0.f;
581
582     for (int i = 0; i < N; i++)
583         xy += x[i] * y[i];
584
585     return xy;
586 }
587
588 static void celt_pitch_xcorr(const float *x, const float *y,
589                              float *xcorr, int len, int max_pitch)
590 {
591     int i;
592
593     for (i = 0; i < max_pitch - 3; i += 4) {
594         float sum[4] = { 0, 0, 0, 0};
595
596         xcorr_kernel(x, y + i, sum, len);
597
598         xcorr[i]     = sum[0];
599         xcorr[i + 1] = sum[1];
600         xcorr[i + 2] = sum[2];
601         xcorr[i + 3] = sum[3];
602     }
603     /* In case max_pitch isn't a multiple of 4, do non-unrolled version. */
604     for (; i < max_pitch; i++) {
605         xcorr[i] = celt_inner_prod(x, y + i, len);
606     }
607 }
608
609 static int celt_autocorr(const float *x,   /*  in: [0...n-1] samples x   */
610                          float       *ac,  /* out: [0...lag-1] ac values */
611                          const float *window,
612                          int          overlap,
613                          int          lag,
614                          int          n)
615 {
616     int fastN = n - lag;
617     int shift;
618     const float *xptr;
619     float xx[PITCH_BUF_SIZE>>1];
620
621     if (overlap == 0) {
622         xptr = x;
623     } else {
624         for (int i = 0; i < n; i++)
625             xx[i] = x[i];
626         for (int i = 0; i < overlap; i++) {
627             xx[i] = x[i] * window[i];
628             xx[n-i-1] = x[n-i-1] * window[i];
629         }
630         xptr = xx;
631     }
632
633     shift = 0;
634     celt_pitch_xcorr(xptr, xptr, ac, fastN, lag+1);
635
636     for (int k = 0; k <= lag; k++) {
637         float d = 0.f;
638
639         for (int i = k + fastN; i < n; i++)
640             d += xptr[i] * xptr[i-k];
641         ac[k] += d;
642     }
643
644     return shift;
645 }
646
647 static void celt_lpc(float *lpc, /* out: [0...p-1] LPC coefficients      */
648                 const float *ac,   /* in:  [0...p] autocorrelation values  */
649                           int p)
650 {
651     float r, error = ac[0];
652
653     RNN_CLEAR(lpc, p);
654     if (ac[0] != 0) {
655         for (int i = 0; i < p; i++) {
656             /* Sum up this iteration's reflection coefficient */
657             float rr = 0;
658             for (int j = 0; j < i; j++)
659                 rr += (lpc[j] * ac[i - j]);
660             rr += ac[i + 1];
661             r = -rr/error;
662             /*  Update LPC coefficients and total error */
663             lpc[i] = r;
664             for (int j = 0; j < (i + 1) >> 1; j++) {
665                 float tmp1, tmp2;
666                 tmp1 = lpc[j];
667                 tmp2 = lpc[i-1-j];
668                 lpc[j]     = tmp1 + (r*tmp2);
669                 lpc[i-1-j] = tmp2 + (r*tmp1);
670             }
671
672             error = error - (r * r *error);
673             /* Bail out once we get 30 dB gain */
674             if (error < .001f * ac[0])
675                 break;
676         }
677     }
678 }
679
680 static void celt_fir5(const float *x,
681                       const float *num,
682                       float *y,
683                       int N,
684                       float *mem)
685 {
686     float num0, num1, num2, num3, num4;
687     float mem0, mem1, mem2, mem3, mem4;
688
689     num0 = num[0];
690     num1 = num[1];
691     num2 = num[2];
692     num3 = num[3];
693     num4 = num[4];
694     mem0 = mem[0];
695     mem1 = mem[1];
696     mem2 = mem[2];
697     mem3 = mem[3];
698     mem4 = mem[4];
699
700     for (int i = 0; i < N; i++) {
701         float sum = x[i];
702
703         sum += (num0*mem0);
704         sum += (num1*mem1);
705         sum += (num2*mem2);
706         sum += (num3*mem3);
707         sum += (num4*mem4);
708         mem4 = mem3;
709         mem3 = mem2;
710         mem2 = mem1;
711         mem1 = mem0;
712         mem0 = x[i];
713         y[i] = sum;
714     }
715
716     mem[0] = mem0;
717     mem[1] = mem1;
718     mem[2] = mem2;
719     mem[3] = mem3;
720     mem[4] = mem4;
721 }
722
723 static void pitch_downsample(float *x[], float *x_lp,
724                              int len, int C)
725 {
726     float ac[5];
727     float tmp=Q15ONE;
728     float lpc[4], mem[5]={0,0,0,0,0};
729     float lpc2[5];
730     float c1 = .8f;
731
732     for (int i = 1; i < len >> 1; i++)
733         x_lp[i] = .5f * (.5f * (x[0][(2*i-1)]+x[0][(2*i+1)])+x[0][2*i]);
734     x_lp[0] = .5f * (.5f * (x[0][1])+x[0][0]);
735     if (C==2) {
736         for (int i = 1; i < len >> 1; i++)
737             x_lp[i] += (.5f * (.5f * (x[1][(2*i-1)]+x[1][(2*i+1)])+x[1][2*i]));
738         x_lp[0] += .5f * (.5f * (x[1][1])+x[1][0]);
739     }
740
741     celt_autocorr(x_lp, ac, NULL, 0, 4, len>>1);
742
743     /* Noise floor -40 dB */
744     ac[0] *= 1.0001f;
745     /* Lag windowing */
746     for (int i = 1; i <= 4; i++) {
747         /*ac[i] *= exp(-.5*(2*M_PI*.002*i)*(2*M_PI*.002*i));*/
748         ac[i] -= ac[i]*(.008f*i)*(.008f*i);
749     }
750
751     celt_lpc(lpc, ac, 4);
752     for (int i = 0; i < 4; i++) {
753         tmp = .9f * tmp;
754         lpc[i] = (lpc[i] * tmp);
755     }
756     /* Add a zero */
757     lpc2[0] = lpc[0] + .8f;
758     lpc2[1] = lpc[1] + (c1 * lpc[0]);
759     lpc2[2] = lpc[2] + (c1 * lpc[1]);
760     lpc2[3] = lpc[3] + (c1 * lpc[2]);
761     lpc2[4] = (c1 * lpc[3]);
762     celt_fir5(x_lp, lpc2, x_lp, len>>1, mem);
763 }
764
765 static inline void dual_inner_prod(const float *x, const float *y01, const float *y02,
766                                    int N, float *xy1, float *xy2)
767 {
768     float xy01 = 0, xy02 = 0;
769
770     for (int i = 0; i < N; i++) {
771         xy01 += (x[i] * y01[i]);
772         xy02 += (x[i] * y02[i]);
773     }
774
775     *xy1 = xy01;
776     *xy2 = xy02;
777 }
778
779 static float compute_pitch_gain(float xy, float xx, float yy)
780 {
781     return xy / sqrtf(1.f + xx * yy);
782 }
783
784 static const int second_check[16] = {0, 0, 3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2};
785 static float remove_doubling(float *x, int maxperiod, int minperiod, int N,
786                              int *T0_, int prev_period, float prev_gain)
787 {
788     int k, i, T, T0;
789     float g, g0;
790     float pg;
791     float xy,xx,yy,xy2;
792     float xcorr[3];
793     float best_xy, best_yy;
794     int offset;
795     int minperiod0;
796     float yy_lookup[PITCH_MAX_PERIOD+1];
797
798     minperiod0 = minperiod;
799     maxperiod /= 2;
800     minperiod /= 2;
801     *T0_ /= 2;
802     prev_period /= 2;
803     N /= 2;
804     x += maxperiod;
805     if (*T0_>=maxperiod)
806         *T0_=maxperiod-1;
807
808     T = T0 = *T0_;
809     dual_inner_prod(x, x, x-T0, N, &xx, &xy);
810     yy_lookup[0] = xx;
811     yy=xx;
812     for (i = 1; i <= maxperiod; i++) {
813         yy = yy+(x[-i] * x[-i])-(x[N-i] * x[N-i]);
814         yy_lookup[i] = FFMAX(0, yy);
815     }
816     yy = yy_lookup[T0];
817     best_xy = xy;
818     best_yy = yy;
819     g = g0 = compute_pitch_gain(xy, xx, yy);
820     /* Look for any pitch at T/k */
821     for (k = 2; k <= 15; k++) {
822         int T1, T1b;
823         float g1;
824         float cont=0;
825         float thresh;
826         T1 = (2*T0+k)/(2*k);
827         if (T1 < minperiod)
828             break;
829         /* Look for another strong correlation at T1b */
830         if (k==2)
831         {
832             if (T1+T0>maxperiod)
833                 T1b = T0;
834             else
835                 T1b = T0+T1;
836         } else
837         {
838             T1b = (2*second_check[k]*T0+k)/(2*k);
839         }
840         dual_inner_prod(x, &x[-T1], &x[-T1b], N, &xy, &xy2);
841         xy = .5f * (xy + xy2);
842         yy = .5f * (yy_lookup[T1] + yy_lookup[T1b]);
843         g1 = compute_pitch_gain(xy, xx, yy);
844         if (FFABS(T1-prev_period)<=1)
845             cont = prev_gain;
846         else if (FFABS(T1-prev_period)<=2 && 5 * k * k < T0)
847             cont = prev_gain * .5f;
848         else
849             cont = 0;
850         thresh = FFMAX(.3f, (.7f * g0) - cont);
851         /* Bias against very high pitch (very short period) to avoid false-positives
852            due to short-term correlation */
853         if (T1<3*minperiod)
854             thresh = FFMAX(.4f, (.85f * g0) - cont);
855         else if (T1<2*minperiod)
856             thresh = FFMAX(.5f, (.9f * g0) - cont);
857         if (g1 > thresh)
858         {
859             best_xy = xy;
860             best_yy = yy;
861             T = T1;
862             g = g1;
863         }
864     }
865     best_xy = FFMAX(0, best_xy);
866     if (best_yy <= best_xy)
867         pg = Q15ONE;
868     else
869         pg = best_xy/(best_yy + 1);
870
871     for (k = 0; k < 3; k++)
872         xcorr[k] = celt_inner_prod(x, x-(T+k-1), N);
873     if ((xcorr[2]-xcorr[0]) > .7f * (xcorr[1]-xcorr[0]))
874         offset = 1;
875     else if ((xcorr[0]-xcorr[2]) > (.7f * (xcorr[1] - xcorr[2])))
876         offset = -1;
877     else
878         offset = 0;
879     if (pg > g)
880         pg = g;
881     *T0_ = 2*T+offset;
882
883     if (*T0_<minperiod0)
884         *T0_=minperiod0;
885     return pg;
886 }
887
888 static void find_best_pitch(float *xcorr, float *y, int len,
889                             int max_pitch, int *best_pitch)
890 {
891     float best_num[2];
892     float best_den[2];
893     float Syy = 1.f;
894
895     best_num[0] = -1;
896     best_num[1] = -1;
897     best_den[0] = 0;
898     best_den[1] = 0;
899     best_pitch[0] = 0;
900     best_pitch[1] = 1;
901
902     for (int j = 0; j < len; j++)
903         Syy += y[j] * y[j];
904
905     for (int i = 0; i < max_pitch; i++) {
906         if (xcorr[i]>0) {
907             float num;
908             float xcorr16;
909
910             xcorr16 = xcorr[i];
911             /* Considering the range of xcorr16, this should avoid both underflows
912                and overflows (inf) when squaring xcorr16 */
913             xcorr16 *= 1e-12f;
914             num = xcorr16 * xcorr16;
915             if ((num * best_den[1]) > (best_num[1] * Syy)) {
916                 if ((num * best_den[0]) > (best_num[0] * Syy)) {
917                     best_num[1] = best_num[0];
918                     best_den[1] = best_den[0];
919                     best_pitch[1] = best_pitch[0];
920                     best_num[0] = num;
921                     best_den[0] = Syy;
922                     best_pitch[0] = i;
923                 } else {
924                     best_num[1] = num;
925                     best_den[1] = Syy;
926                     best_pitch[1] = i;
927                 }
928             }
929         }
930         Syy += y[i+len]*y[i+len] - y[i] * y[i];
931         Syy = FFMAX(1, Syy);
932     }
933 }
934
935 static void pitch_search(const float *x_lp, float *y,
936                          int len, int max_pitch, int *pitch)
937 {
938     int lag;
939     int best_pitch[2]={0,0};
940     int offset;
941
942     float x_lp4[WINDOW_SIZE];
943     float y_lp4[WINDOW_SIZE];
944     float xcorr[WINDOW_SIZE];
945
946     lag = len+max_pitch;
947
948     /* Downsample by 2 again */
949     for (int j = 0; j < len >> 2; j++)
950         x_lp4[j] = x_lp[2*j];
951     for (int j = 0; j < lag >> 2; j++)
952         y_lp4[j] = y[2*j];
953
954     /* Coarse search with 4x decimation */
955
956     celt_pitch_xcorr(x_lp4, y_lp4, xcorr, len>>2, max_pitch>>2);
957
958     find_best_pitch(xcorr, y_lp4, len>>2, max_pitch>>2, best_pitch);
959
960     /* Finer search with 2x decimation */
961     for (int i = 0; i < max_pitch >> 1; i++) {
962         float sum;
963         xcorr[i] = 0;
964         if (FFABS(i-2*best_pitch[0])>2 && FFABS(i-2*best_pitch[1])>2)
965             continue;
966         sum = celt_inner_prod(x_lp, y+i, len>>1);
967         xcorr[i] = FFMAX(-1, sum);
968     }
969
970     find_best_pitch(xcorr, y, len>>1, max_pitch>>1, best_pitch);
971
972     /* Refine by pseudo-interpolation */
973     if (best_pitch[0] > 0 && best_pitch[0] < (max_pitch >> 1) - 1) {
974         float a, b, c;
975
976         a = xcorr[best_pitch[0] - 1];
977         b = xcorr[best_pitch[0]];
978         c = xcorr[best_pitch[0] + 1];
979         if (c - a > .7f * (b - a))
980             offset = 1;
981         else if (a - c > .7f * (b-c))
982             offset = -1;
983         else
984             offset = 0;
985     } else {
986         offset = 0;
987     }
988
989     *pitch = 2 * best_pitch[0] - offset;
990 }
991
992 static void dct(AudioRNNContext *s, float *out, const float *in)
993 {
994     for (int i = 0; i < NB_BANDS; i++) {
995         float sum = 0.f;
996
997         for (int j = 0; j < NB_BANDS; j++) {
998             sum += in[j] * s->dct_table[j * NB_BANDS + i];
999         }
1000         out[i] = sum * sqrtf(2.f / 22);
1001     }
1002 }
1003
1004 static int compute_frame_features(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, AVComplexFloat *P,
1005                                   float *Ex, float *Ep, float *Exp, float *features, const float *in)
1006 {
1007     float E = 0;
1008     float *ceps_0, *ceps_1, *ceps_2;
1009     float spec_variability = 0;
1010     float Ly[NB_BANDS];
1011     LOCAL_ALIGNED_32(float, p, [WINDOW_SIZE]);
1012     float pitch_buf[PITCH_BUF_SIZE>>1];
1013     int pitch_index;
1014     float gain;
1015     float *(pre[1]);
1016     float tmp[NB_BANDS];
1017     float follow, logMax;
1018
1019     frame_analysis(s, st, X, Ex, in);
1020     RNN_MOVE(st->pitch_buf, &st->pitch_buf[FRAME_SIZE], PITCH_BUF_SIZE-FRAME_SIZE);
1021     RNN_COPY(&st->pitch_buf[PITCH_BUF_SIZE-FRAME_SIZE], in, FRAME_SIZE);
1022     pre[0] = &st->pitch_buf[0];
1023     pitch_downsample(pre, pitch_buf, PITCH_BUF_SIZE, 1);
1024     pitch_search(pitch_buf+(PITCH_MAX_PERIOD>>1), pitch_buf, PITCH_FRAME_SIZE,
1025             PITCH_MAX_PERIOD-3*PITCH_MIN_PERIOD, &pitch_index);
1026     pitch_index = PITCH_MAX_PERIOD-pitch_index;
1027
1028     gain = remove_doubling(pitch_buf, PITCH_MAX_PERIOD, PITCH_MIN_PERIOD,
1029             PITCH_FRAME_SIZE, &pitch_index, st->last_period, st->last_gain);
1030     st->last_period = pitch_index;
1031     st->last_gain = gain;
1032
1033     for (int i = 0; i < WINDOW_SIZE; i++)
1034         p[i] = st->pitch_buf[PITCH_BUF_SIZE-WINDOW_SIZE-pitch_index+i];
1035
1036     s->fdsp->vector_fmul(p, p, s->window, WINDOW_SIZE);
1037     forward_transform(st, P, p);
1038     compute_band_energy(Ep, P);
1039     compute_band_corr(Exp, X, P);
1040
1041     for (int i = 0; i < NB_BANDS; i++)
1042         Exp[i] = Exp[i] / sqrtf(.001f+Ex[i]*Ep[i]);
1043
1044     dct(s, tmp, Exp);
1045
1046     for (int i = 0; i < NB_DELTA_CEPS; i++)
1047         features[NB_BANDS+2*NB_DELTA_CEPS+i] = tmp[i];
1048
1049     features[NB_BANDS+2*NB_DELTA_CEPS] -= 1.3;
1050     features[NB_BANDS+2*NB_DELTA_CEPS+1] -= 0.9;
1051     features[NB_BANDS+3*NB_DELTA_CEPS] = .01*(pitch_index-300);
1052     logMax = -2;
1053     follow = -2;
1054
1055     for (int i = 0; i < NB_BANDS; i++) {
1056         Ly[i] = log10f(1e-2f + Ex[i]);
1057         Ly[i] = FFMAX(logMax-7, FFMAX(follow-1.5, Ly[i]));
1058         logMax = FFMAX(logMax, Ly[i]);
1059         follow = FFMAX(follow-1.5, Ly[i]);
1060         E += Ex[i];
1061     }
1062
1063     if (E < 0.04f) {
1064         /* If there's no audio, avoid messing up the state. */
1065         RNN_CLEAR(features, NB_FEATURES);
1066         return 1;
1067     }
1068
1069     dct(s, features, Ly);
1070     features[0] -= 12;
1071     features[1] -= 4;
1072     ceps_0 = st->cepstral_mem[st->memid];
1073     ceps_1 = (st->memid < 1) ? st->cepstral_mem[CEPS_MEM+st->memid-1] : st->cepstral_mem[st->memid-1];
1074     ceps_2 = (st->memid < 2) ? st->cepstral_mem[CEPS_MEM+st->memid-2] : st->cepstral_mem[st->memid-2];
1075
1076     for (int i = 0; i < NB_BANDS; i++)
1077         ceps_0[i] = features[i];
1078
1079     st->memid++;
1080     for (int i = 0; i < NB_DELTA_CEPS; i++) {
1081         features[i] = ceps_0[i] + ceps_1[i] + ceps_2[i];
1082         features[NB_BANDS+i] = ceps_0[i] - ceps_2[i];
1083         features[NB_BANDS+NB_DELTA_CEPS+i] =  ceps_0[i] - 2*ceps_1[i] + ceps_2[i];
1084     }
1085     /* Spectral variability features. */
1086     if (st->memid == CEPS_MEM)
1087         st->memid = 0;
1088
1089     for (int i = 0; i < CEPS_MEM; i++) {
1090         float mindist = 1e15f;
1091         for (int j = 0; j < CEPS_MEM; j++) {
1092             float dist = 0.f;
1093             for (int k = 0; k < NB_BANDS; k++) {
1094                 float tmp;
1095
1096                 tmp = st->cepstral_mem[i][k] - st->cepstral_mem[j][k];
1097                 dist += tmp*tmp;
1098             }
1099
1100             if (j != i)
1101                 mindist = FFMIN(mindist, dist);
1102         }
1103
1104         spec_variability += mindist;
1105     }
1106
1107     features[NB_BANDS+3*NB_DELTA_CEPS+1] = spec_variability/CEPS_MEM-2.1;
1108
1109     return 0;
1110 }
1111
1112 static void interp_band_gain(float *g, const float *bandE)
1113 {
1114     memset(g, 0, sizeof(*g) * FREQ_SIZE);
1115
1116     for (int i = 0; i < NB_BANDS - 1; i++) {
1117         const int band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
1118
1119         for (int j = 0; j < band_size; j++) {
1120             float frac = (float)j / band_size;
1121
1122             g[(eband5ms[i] << FRAME_SIZE_SHIFT) + j] = (1.f - frac) * bandE[i] + frac * bandE[i + 1];
1123         }
1124     }
1125 }
1126
1127 static void pitch_filter(AVComplexFloat *X, const AVComplexFloat *P, const float *Ex, const float *Ep,
1128                          const float *Exp, const float *g)
1129 {
1130     float newE[NB_BANDS];
1131     float r[NB_BANDS];
1132     float norm[NB_BANDS];
1133     float rf[FREQ_SIZE] = {0};
1134     float normf[FREQ_SIZE]={0};
1135
1136     for (int i = 0; i < NB_BANDS; i++) {
1137         if (Exp[i]>g[i]) r[i] = 1;
1138         else r[i] = SQUARE(Exp[i])*(1-SQUARE(g[i]))/(.001 + SQUARE(g[i])*(1-SQUARE(Exp[i])));
1139         r[i]  = sqrtf(av_clipf(r[i], 0, 1));
1140         r[i] *= sqrtf(Ex[i]/(1e-8+Ep[i]));
1141     }
1142     interp_band_gain(rf, r);
1143     for (int i = 0; i < FREQ_SIZE; i++) {
1144         X[i].re += rf[i]*P[i].re;
1145         X[i].im += rf[i]*P[i].im;
1146     }
1147     compute_band_energy(newE, X);
1148     for (int i = 0; i < NB_BANDS; i++) {
1149         norm[i] = sqrtf(Ex[i] / (1e-8+newE[i]));
1150     }
1151     interp_band_gain(normf, norm);
1152     for (int i = 0; i < FREQ_SIZE; i++) {
1153         X[i].re *= normf[i];
1154         X[i].im *= normf[i];
1155     }
1156 }
1157
1158 static const float tansig_table[201] = {
1159     0.000000f, 0.039979f, 0.079830f, 0.119427f, 0.158649f,
1160     0.197375f, 0.235496f, 0.272905f, 0.309507f, 0.345214f,
1161     0.379949f, 0.413644f, 0.446244f, 0.477700f, 0.507977f,
1162     0.537050f, 0.564900f, 0.591519f, 0.616909f, 0.641077f,
1163     0.664037f, 0.685809f, 0.706419f, 0.725897f, 0.744277f,
1164     0.761594f, 0.777888f, 0.793199f, 0.807569f, 0.821040f,
1165     0.833655f, 0.845456f, 0.856485f, 0.866784f, 0.876393f,
1166     0.885352f, 0.893698f, 0.901468f, 0.908698f, 0.915420f,
1167     0.921669f, 0.927473f, 0.932862f, 0.937863f, 0.942503f,
1168     0.946806f, 0.950795f, 0.954492f, 0.957917f, 0.961090f,
1169     0.964028f, 0.966747f, 0.969265f, 0.971594f, 0.973749f,
1170     0.975743f, 0.977587f, 0.979293f, 0.980869f, 0.982327f,
1171     0.983675f, 0.984921f, 0.986072f, 0.987136f, 0.988119f,
1172     0.989027f, 0.989867f, 0.990642f, 0.991359f, 0.992020f,
1173     0.992631f, 0.993196f, 0.993718f, 0.994199f, 0.994644f,
1174     0.995055f, 0.995434f, 0.995784f, 0.996108f, 0.996407f,
1175     0.996682f, 0.996937f, 0.997172f, 0.997389f, 0.997590f,
1176     0.997775f, 0.997946f, 0.998104f, 0.998249f, 0.998384f,
1177     0.998508f, 0.998623f, 0.998728f, 0.998826f, 0.998916f,
1178     0.999000f, 0.999076f, 0.999147f, 0.999213f, 0.999273f,
1179     0.999329f, 0.999381f, 0.999428f, 0.999472f, 0.999513f,
1180     0.999550f, 0.999585f, 0.999617f, 0.999646f, 0.999673f,
1181     0.999699f, 0.999722f, 0.999743f, 0.999763f, 0.999781f,
1182     0.999798f, 0.999813f, 0.999828f, 0.999841f, 0.999853f,
1183     0.999865f, 0.999875f, 0.999885f, 0.999893f, 0.999902f,
1184     0.999909f, 0.999916f, 0.999923f, 0.999929f, 0.999934f,
1185     0.999939f, 0.999944f, 0.999948f, 0.999952f, 0.999956f,
1186     0.999959f, 0.999962f, 0.999965f, 0.999968f, 0.999970f,
1187     0.999973f, 0.999975f, 0.999977f, 0.999978f, 0.999980f,
1188     0.999982f, 0.999983f, 0.999984f, 0.999986f, 0.999987f,
1189     0.999988f, 0.999989f, 0.999990f, 0.999990f, 0.999991f,
1190     0.999992f, 0.999992f, 0.999993f, 0.999994f, 0.999994f,
1191     0.999994f, 0.999995f, 0.999995f, 0.999996f, 0.999996f,
1192     0.999996f, 0.999997f, 0.999997f, 0.999997f, 0.999997f,
1193     0.999997f, 0.999998f, 0.999998f, 0.999998f, 0.999998f,
1194     0.999998f, 0.999998f, 0.999999f, 0.999999f, 0.999999f,
1195     0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
1196     0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
1197     1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1198     1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1199     1.000000f,
1200 };
1201
1202 static inline float tansig_approx(float x)
1203 {
1204     float y, dy;
1205     float sign=1;
1206     int i;
1207
1208     /* Tests are reversed to catch NaNs */
1209     if (!(x<8))
1210         return 1;
1211     if (!(x>-8))
1212         return -1;
1213     /* Another check in case of -ffast-math */
1214
1215     if (isnan(x))
1216        return 0;
1217
1218     if (x < 0) {
1219        x=-x;
1220        sign=-1;
1221     }
1222     i = (int)floor(.5f+25*x);
1223     x -= .04f*i;
1224     y = tansig_table[i];
1225     dy = 1-y*y;
1226     y = y + x*dy*(1 - y*x);
1227     return sign*y;
1228 }
1229
1230 static inline float sigmoid_approx(float x)
1231 {
1232     return .5f + .5f*tansig_approx(.5f*x);
1233 }
1234
1235 static void compute_dense(const DenseLayer *layer, float *output, const float *input)
1236 {
1237     const int N = layer->nb_neurons, M = layer->nb_inputs, stride = N;
1238
1239     for (int i = 0; i < N; i++) {
1240         /* Compute update gate. */
1241         float sum = layer->bias[i];
1242
1243         for (int j = 0; j < M; j++)
1244             sum += layer->input_weights[j * stride + i] * input[j];
1245
1246         output[i] = WEIGHTS_SCALE * sum;
1247     }
1248
1249     if (layer->activation == ACTIVATION_SIGMOID) {
1250         for (int i = 0; i < N; i++)
1251             output[i] = sigmoid_approx(output[i]);
1252     } else if (layer->activation == ACTIVATION_TANH) {
1253         for (int i = 0; i < N; i++)
1254             output[i] = tansig_approx(output[i]);
1255     } else if (layer->activation == ACTIVATION_RELU) {
1256         for (int i = 0; i < N; i++)
1257             output[i] = FFMAX(0, output[i]);
1258     } else {
1259         av_assert0(0);
1260     }
1261 }
1262
1263 static void compute_gru(AudioRNNContext *s, const GRULayer *gru, float *state, const float *input)
1264 {
1265     LOCAL_ALIGNED_32(float, z, [MAX_NEURONS]);
1266     LOCAL_ALIGNED_32(float, r, [MAX_NEURONS]);
1267     LOCAL_ALIGNED_32(float, h, [MAX_NEURONS]);
1268     const int M = gru->nb_inputs;
1269     const int N = gru->nb_neurons;
1270     const int AN = FFALIGN(N, 4);
1271     const int AM = FFALIGN(M, 4);
1272     const int stride = 3 * AN, istride = 3 * AM;
1273
1274     for (int i = 0; i < N; i++) {
1275         /* Compute update gate. */
1276         float sum = gru->bias[i];
1277
1278         sum += s->fdsp->scalarproduct_float(gru->input_weights + i * istride, input, AM);
1279         sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + i * stride, state, AN);
1280         z[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1281     }
1282
1283     for (int i = 0; i < N; i++) {
1284         /* Compute reset gate. */
1285         float sum = gru->bias[N + i];
1286
1287         sum += s->fdsp->scalarproduct_float(gru->input_weights + AM + i * istride, input, AM);
1288         sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + AN + i * stride, state, AN);
1289         r[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1290     }
1291
1292     for (int i = 0; i < N; i++) {
1293         /* Compute output. */
1294         float sum = gru->bias[2 * N + i];
1295
1296         sum += s->fdsp->scalarproduct_float(gru->input_weights + 2 * AM + i * istride, input, AM);
1297         for (int j = 0; j < N; j++)
1298             sum += gru->recurrent_weights[2 * AN + i * stride + j] * state[j] * r[j];
1299
1300         if (gru->activation == ACTIVATION_SIGMOID)
1301             sum = sigmoid_approx(WEIGHTS_SCALE * sum);
1302         else if (gru->activation == ACTIVATION_TANH)
1303             sum = tansig_approx(WEIGHTS_SCALE * sum);
1304         else if (gru->activation == ACTIVATION_RELU)
1305             sum = FFMAX(0, WEIGHTS_SCALE * sum);
1306         else
1307             av_assert0(0);
1308         h[i] = z[i] * state[i] + (1.f - z[i]) * sum;
1309     }
1310
1311     RNN_COPY(state, h, N);
1312 }
1313
1314 #define INPUT_SIZE 42
1315
1316 static void compute_rnn(AudioRNNContext *s, RNNState *rnn, float *gains, float *vad, const float *input)
1317 {
1318     LOCAL_ALIGNED_32(float, dense_out,     [MAX_NEURONS]);
1319     LOCAL_ALIGNED_32(float, noise_input,   [MAX_NEURONS * 3]);
1320     LOCAL_ALIGNED_32(float, denoise_input, [MAX_NEURONS * 3]);
1321
1322     compute_dense(rnn->model->input_dense, dense_out, input);
1323     compute_gru(s, rnn->model->vad_gru, rnn->vad_gru_state, dense_out);
1324     compute_dense(rnn->model->vad_output, vad, rnn->vad_gru_state);
1325
1326     for (int i = 0; i < rnn->model->input_dense_size; i++)
1327         noise_input[i] = dense_out[i];
1328     for (int i = 0; i < rnn->model->vad_gru_size; i++)
1329         noise_input[i + rnn->model->input_dense_size] = rnn->vad_gru_state[i];
1330     for (int i = 0; i < INPUT_SIZE; i++)
1331         noise_input[i + rnn->model->input_dense_size + rnn->model->vad_gru_size] = input[i];
1332
1333     compute_gru(s, rnn->model->noise_gru, rnn->noise_gru_state, noise_input);
1334
1335     for (int i = 0; i < rnn->model->vad_gru_size; i++)
1336         denoise_input[i] = rnn->vad_gru_state[i];
1337     for (int i = 0; i < rnn->model->noise_gru_size; i++)
1338         denoise_input[i + rnn->model->vad_gru_size] = rnn->noise_gru_state[i];
1339     for (int i = 0; i < INPUT_SIZE; i++)
1340         denoise_input[i + rnn->model->vad_gru_size + rnn->model->noise_gru_size] = input[i];
1341
1342     compute_gru(s, rnn->model->denoise_gru, rnn->denoise_gru_state, denoise_input);
1343     compute_dense(rnn->model->denoise_output, gains, rnn->denoise_gru_state);
1344 }
1345
1346 static float rnnoise_channel(AudioRNNContext *s, DenoiseState *st, float *out, const float *in)
1347 {
1348     AVComplexFloat X[FREQ_SIZE];
1349     AVComplexFloat P[WINDOW_SIZE];
1350     float x[FRAME_SIZE];
1351     float Ex[NB_BANDS], Ep[NB_BANDS];
1352     float Exp[NB_BANDS];
1353     float features[NB_FEATURES];
1354     float g[NB_BANDS];
1355     float gf[FREQ_SIZE];
1356     float vad_prob = 0;
1357     static const float a_hp[2] = {-1.99599, 0.99600};
1358     static const float b_hp[2] = {-2, 1};
1359     int silence;
1360
1361     biquad(x, st->mem_hp_x, in, b_hp, a_hp, FRAME_SIZE);
1362     silence = compute_frame_features(s, st, X, P, Ex, Ep, Exp, features, x);
1363
1364     if (!silence) {
1365         compute_rnn(s, &st->rnn, g, &vad_prob, features);
1366         pitch_filter(X, P, Ex, Ep, Exp, g);
1367         for (int i = 0; i < NB_BANDS; i++) {
1368             float alpha = .6f;
1369
1370             g[i] = FFMAX(g[i], alpha * st->lastg[i]);
1371             st->lastg[i] = g[i];
1372         }
1373
1374         interp_band_gain(gf, g);
1375
1376         for (int i = 0; i < FREQ_SIZE; i++) {
1377             X[i].re *= gf[i];
1378             X[i].im *= gf[i];
1379         }
1380     }
1381
1382     frame_synthesis(s, st, out, X);
1383
1384     return vad_prob;
1385 }
1386
1387 typedef struct ThreadData {
1388     AVFrame *in, *out;
1389 } ThreadData;
1390
1391 static int rnnoise_channels(AVFilterContext *ctx, void *arg, int jobnr, int nb_jobs)
1392 {
1393     AudioRNNContext *s = ctx->priv;
1394     ThreadData *td = arg;
1395     AVFrame *in = td->in;
1396     AVFrame *out = td->out;
1397     const int start = (out->channels * jobnr) / nb_jobs;
1398     const int end = (out->channels * (jobnr+1)) / nb_jobs;
1399
1400     for (int ch = start; ch < end; ch++) {
1401         rnnoise_channel(s, &s->st[ch],
1402                         (float *)out->extended_data[ch],
1403                         (const float *)in->extended_data[ch]);
1404     }
1405
1406     return 0;
1407 }
1408
1409 static int filter_frame(AVFilterLink *inlink, AVFrame *in)
1410 {
1411     AVFilterContext *ctx = inlink->dst;
1412     AVFilterLink *outlink = ctx->outputs[0];
1413     AVFrame *out = NULL;
1414     ThreadData td;
1415
1416     out = ff_get_audio_buffer(outlink, FRAME_SIZE);
1417     if (!out) {
1418         av_frame_free(&in);
1419         return AVERROR(ENOMEM);
1420     }
1421     out->pts = in->pts;
1422
1423     td.in = in; td.out = out;
1424     ctx->internal->execute(ctx, rnnoise_channels, &td, NULL, FFMIN(outlink->channels,
1425                                                                    ff_filter_get_nb_threads(ctx)));
1426
1427     av_frame_free(&in);
1428     return ff_filter_frame(outlink, out);
1429 }
1430
1431 static int activate(AVFilterContext *ctx)
1432 {
1433     AVFilterLink *inlink = ctx->inputs[0];
1434     AVFilterLink *outlink = ctx->outputs[0];
1435     AVFrame *in = NULL;
1436     int ret;
1437
1438     FF_FILTER_FORWARD_STATUS_BACK(outlink, inlink);
1439
1440     ret = ff_inlink_consume_samples(inlink, FRAME_SIZE, FRAME_SIZE, &in);
1441     if (ret < 0)
1442         return ret;
1443
1444     if (ret > 0)
1445         return filter_frame(inlink, in);
1446
1447     FF_FILTER_FORWARD_STATUS(inlink, outlink);
1448     FF_FILTER_FORWARD_WANTED(outlink, inlink);
1449
1450     return FFERROR_NOT_READY;
1451 }
1452
1453 static av_cold int init(AVFilterContext *ctx)
1454 {
1455     AudioRNNContext *s = ctx->priv;
1456     FILE *f;
1457
1458     s->fdsp = avpriv_float_dsp_alloc(0);
1459     if (!s->fdsp)
1460         return AVERROR(ENOMEM);
1461
1462     if (!s->model_name)
1463         return AVERROR(EINVAL);
1464     f = av_fopen_utf8(s->model_name, "r");
1465     if (!f)
1466         return AVERROR(EINVAL);
1467
1468     s->model = rnnoise_model_from_file(f);
1469     fclose(f);
1470     if (!s->model)
1471         return AVERROR(EINVAL);
1472
1473     for (int i = 0; i < FRAME_SIZE; i++) {
1474         s->window[i] = sin(.5*M_PI*sin(.5*M_PI*(i+.5)/FRAME_SIZE) * sin(.5*M_PI*(i+.5)/FRAME_SIZE));
1475         s->window[WINDOW_SIZE - 1 - i] = s->window[i];
1476     }
1477
1478     for (int i = 0; i < NB_BANDS; i++) {
1479         for (int j = 0; j < NB_BANDS; j++) {
1480             s->dct_table[i*NB_BANDS + j] = cosf((i + .5f) * j * M_PI / NB_BANDS);
1481             if (j == 0)
1482                 s->dct_table[i*NB_BANDS + j] *= sqrtf(.5);
1483         }
1484     }
1485
1486     return 0;
1487 }
1488
1489 static av_cold void uninit(AVFilterContext *ctx)
1490 {
1491     AudioRNNContext *s = ctx->priv;
1492
1493     av_freep(&s->fdsp);
1494     rnnoise_model_free(s->model);
1495     s->model = NULL;
1496
1497     if (s->st) {
1498         for (int ch = 0; ch < s->channels; ch++) {
1499             av_freep(&s->st[ch].rnn.vad_gru_state);
1500             av_freep(&s->st[ch].rnn.noise_gru_state);
1501             av_freep(&s->st[ch].rnn.denoise_gru_state);
1502             av_tx_uninit(&s->st[ch].tx);
1503             av_tx_uninit(&s->st[ch].txi);
1504         }
1505     }
1506     av_freep(&s->st);
1507 }
1508
1509 static const AVFilterPad inputs[] = {
1510     {
1511         .name         = "default",
1512         .type         = AVMEDIA_TYPE_AUDIO,
1513         .config_props = config_input,
1514     },
1515     { NULL }
1516 };
1517
1518 static const AVFilterPad outputs[] = {
1519     {
1520         .name          = "default",
1521         .type          = AVMEDIA_TYPE_AUDIO,
1522     },
1523     { NULL }
1524 };
1525
1526 #define OFFSET(x) offsetof(AudioRNNContext, x)
1527 #define AF AV_OPT_FLAG_AUDIO_PARAM|AV_OPT_FLAG_FILTERING_PARAM
1528
1529 static const AVOption arnndn_options[] = {
1530     { "model", "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1531     { "m",     "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1532     { NULL }
1533 };
1534
1535 AVFILTER_DEFINE_CLASS(arnndn);
1536
1537 AVFilter ff_af_arnndn = {
1538     .name          = "arnndn",
1539     .description   = NULL_IF_CONFIG_SMALL("Reduce noise from speech using Recurrent Neural Networks."),
1540     .query_formats = query_formats,
1541     .priv_size     = sizeof(AudioRNNContext),
1542     .priv_class    = &arnndn_class,
1543     .activate      = activate,
1544     .init          = init,
1545     .uninit        = uninit,
1546     .inputs        = inputs,
1547     .outputs       = outputs,
1548     .flags         = AVFILTER_FLAG_SLICE_THREADS,
1549 };