]> git.sesse.net Git - ffmpeg/blob - libavfilter/af_arnndn.c
avfilter: add arnndn filter
[ffmpeg] / libavfilter / af_arnndn.c
1 /*
2  * Copyright (c) 2018 Gregor Richards
3  * Copyright (c) 2017 Mozilla
4  * Copyright (c) 2005-2009 Xiph.Org Foundation
5  * Copyright (c) 2007-2008 CSIRO
6  * Copyright (c) 2008-2011 Octasic Inc.
7  * Copyright (c) Jean-Marc Valin
8  * Copyright (c) 2019 Paul B Mahol
9  *
10  * Redistribution and use in source and binary forms, with or without
11  * modification, are permitted provided that the following conditions
12  * are met:
13  *
14  * - Redistributions of source code must retain the above copyright
15  *   notice, this list of conditions and the following disclaimer.
16  *
17  * - Redistributions in binary form must reproduce the above copyright
18  *   notice, this list of conditions and the following disclaimer in the
19  *   documentation and/or other materials provided with the distribution.
20  *
21  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22  * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24  * A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25  * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26  * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27  * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29  * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30  * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32  */
33
34 #include <float.h>
35
36 #include "libavutil/avassert.h"
37 #include "libavutil/avstring.h"
38 #include "libavutil/float_dsp.h"
39 #include "libavutil/opt.h"
40 #include "libavutil/tx.h"
41 #include "avfilter.h"
42 #include "audio.h"
43 #include "filters.h"
44 #include "formats.h"
45
46 #define FRAME_SIZE_SHIFT 2
47 #define FRAME_SIZE (120<<FRAME_SIZE_SHIFT)
48 #define WINDOW_SIZE (2*FRAME_SIZE)
49 #define FREQ_SIZE (FRAME_SIZE + 1)
50
51 #define PITCH_MIN_PERIOD 60
52 #define PITCH_MAX_PERIOD 768
53 #define PITCH_FRAME_SIZE 960
54 #define PITCH_BUF_SIZE (PITCH_MAX_PERIOD+PITCH_FRAME_SIZE)
55
56 #define SQUARE(x) ((x)*(x))
57
58 #define NB_BANDS 22
59
60 #define CEPS_MEM 8
61 #define NB_DELTA_CEPS 6
62
63 #define NB_FEATURES (NB_BANDS+3*NB_DELTA_CEPS+2)
64
65 #define WEIGHTS_SCALE (1.f/256)
66
67 #define MAX_NEURONS 128
68
69 #define ACTIVATION_TANH    0
70 #define ACTIVATION_SIGMOID 1
71 #define ACTIVATION_RELU    2
72
73 #define Q15ONE 1.0f
74
75 typedef struct DenseLayer {
76     const float *bias;
77     const float *input_weights;
78     int nb_inputs;
79     int nb_neurons;
80     int activation;
81 } DenseLayer;
82
83 typedef struct GRULayer {
84     const float *bias;
85     const float *input_weights;
86     const float *recurrent_weights;
87     int nb_inputs;
88     int nb_neurons;
89     int activation;
90 } GRULayer;
91
92 typedef struct RNNModel {
93     int input_dense_size;
94     const DenseLayer *input_dense;
95
96     int vad_gru_size;
97     const GRULayer *vad_gru;
98
99     int noise_gru_size;
100     const GRULayer *noise_gru;
101
102     int denoise_gru_size;
103     const GRULayer *denoise_gru;
104
105     int denoise_output_size;
106     const DenseLayer *denoise_output;
107
108     int vad_output_size;
109     const DenseLayer *vad_output;
110 } RNNModel;
111
112 typedef struct RNNState {
113     float *vad_gru_state;
114     float *noise_gru_state;
115     float *denoise_gru_state;
116     RNNModel *model;
117 } RNNState;
118
119 typedef struct DenoiseState {
120     float analysis_mem[FRAME_SIZE];
121     float cepstral_mem[CEPS_MEM][NB_BANDS];
122     int memid;
123     DECLARE_ALIGNED(32, float, synthesis_mem)[FRAME_SIZE];
124     float pitch_buf[PITCH_BUF_SIZE];
125     float pitch_enh_buf[PITCH_BUF_SIZE];
126     float last_gain;
127     int last_period;
128     float mem_hp_x[2];
129     float lastg[NB_BANDS];
130     RNNState rnn;
131     AVTXContext *tx, *txi;
132     av_tx_fn tx_fn, txi_fn;
133 } DenoiseState;
134
135 typedef struct AudioRNNContext {
136     const AVClass *class;
137
138     char *model_name;
139
140     int channels;
141     DenoiseState *st;
142
143     DECLARE_ALIGNED(32, float, window)[WINDOW_SIZE];
144     float dct_table[NB_BANDS*NB_BANDS];
145
146     RNNModel *model;
147
148     AVFloatDSPContext *fdsp;
149 } AudioRNNContext;
150
151 #define F_ACTIVATION_TANH       0
152 #define F_ACTIVATION_SIGMOID    1
153 #define F_ACTIVATION_RELU       2
154
155 static void rnnoise_model_free(RNNModel *model)
156 {
157 #define FREE_MAYBE(ptr) do { if (ptr) free(ptr); } while (0)
158 #define FREE_DENSE(name) do { \
159     if (model->name) { \
160         av_free((void *) model->name->input_weights); \
161         av_free((void *) model->name->bias); \
162         av_free((void *) model->name); \
163     } \
164     } while (0)
165 #define FREE_GRU(name) do { \
166     if (model->name) { \
167         av_free((void *) model->name->input_weights); \
168         av_free((void *) model->name->recurrent_weights); \
169         av_free((void *) model->name->bias); \
170         av_free((void *) model->name); \
171     } \
172     } while (0)
173
174     if (!model)
175         return;
176     FREE_DENSE(input_dense);
177     FREE_GRU(vad_gru);
178     FREE_GRU(noise_gru);
179     FREE_GRU(denoise_gru);
180     FREE_DENSE(denoise_output);
181     FREE_DENSE(vad_output);
182     av_free(model);
183 }
184
185 static RNNModel *rnnoise_model_from_file(FILE *f)
186 {
187     RNNModel *ret;
188     DenseLayer *input_dense;
189     GRULayer *vad_gru;
190     GRULayer *noise_gru;
191     GRULayer *denoise_gru;
192     DenseLayer *denoise_output;
193     DenseLayer *vad_output;
194     int in;
195
196     if (fscanf(f, "rnnoise-nu model file version %d\n", &in) != 1 || in != 1)
197         return NULL;
198
199     ret = av_calloc(1, sizeof(RNNModel));
200     if (!ret)
201         return NULL;
202
203 #define ALLOC_LAYER(type, name) \
204     name = av_calloc(1, sizeof(type)); \
205     if (!name) { \
206         rnnoise_model_free(ret); \
207         return NULL; \
208     } \
209     ret->name = name
210
211     ALLOC_LAYER(DenseLayer, input_dense);
212     ALLOC_LAYER(GRULayer, vad_gru);
213     ALLOC_LAYER(GRULayer, noise_gru);
214     ALLOC_LAYER(GRULayer, denoise_gru);
215     ALLOC_LAYER(DenseLayer, denoise_output);
216     ALLOC_LAYER(DenseLayer, vad_output);
217
218 #define INPUT_VAL(name) do { \
219     if (fscanf(f, "%d", &in) != 1 || in < 0 || in > 128) { \
220         rnnoise_model_free(ret); \
221         return NULL; \
222     } \
223     name = in; \
224     } while (0)
225
226 #define INPUT_ACTIVATION(name) do { \
227     int activation; \
228     INPUT_VAL(activation); \
229     switch (activation) { \
230     case F_ACTIVATION_SIGMOID: \
231         name = ACTIVATION_SIGMOID; \
232         break; \
233     case F_ACTIVATION_RELU: \
234         name = ACTIVATION_RELU; \
235         break; \
236     default: \
237         name = ACTIVATION_TANH; \
238     } \
239     } while (0)
240
241 #define INPUT_ARRAY(name, len) do { \
242     float *values = av_calloc((len), sizeof(float)); \
243     if (!values) { \
244         rnnoise_model_free(ret); \
245         return NULL; \
246     } \
247     name = values; \
248     for (int i = 0; i < (len); i++) { \
249         if (fscanf(f, "%d", &in) != 1) { \
250             rnnoise_model_free(ret); \
251             return NULL; \
252         } \
253         values[i] = in; \
254     } \
255     } while (0)
256
257 #define INPUT_ARRAY3(name, len0, len1, len2) do { \
258     float *values = av_calloc(FFALIGN((len0), 4) * FFALIGN((len1), 4) * (len2), sizeof(float)); \
259     if (!values) { \
260         rnnoise_model_free(ret); \
261         return NULL; \
262     } \
263     name = values; \
264     for (int k = 0; k < (len0); k++) { \
265         for (int i = 0; i < (len2); i++) { \
266             for (int j = 0; j < (len1); j++) { \
267                 if (fscanf(f, "%d", &in) != 1) { \
268                     rnnoise_model_free(ret); \
269                     return NULL; \
270                 } \
271                 values[j * (len2) * FFALIGN((len0), 4) + i * FFALIGN((len0), 4) + k] = in; \
272             } \
273         } \
274     } \
275     } while (0)
276
277 #define INPUT_DENSE(name) do { \
278     INPUT_VAL(name->nb_inputs); \
279     INPUT_VAL(name->nb_neurons); \
280     ret->name ## _size = name->nb_neurons; \
281     INPUT_ACTIVATION(name->activation); \
282     INPUT_ARRAY(name->input_weights, name->nb_inputs * name->nb_neurons); \
283     INPUT_ARRAY(name->bias, name->nb_neurons); \
284     } while (0)
285
286 #define INPUT_GRU(name) do { \
287     INPUT_VAL(name->nb_inputs); \
288     INPUT_VAL(name->nb_neurons); \
289     ret->name ## _size = name->nb_neurons; \
290     INPUT_ACTIVATION(name->activation); \
291     INPUT_ARRAY3(name->input_weights, name->nb_inputs, name->nb_neurons, 3); \
292     INPUT_ARRAY3(name->recurrent_weights, name->nb_neurons, name->nb_neurons, 3); \
293     INPUT_ARRAY(name->bias, name->nb_neurons * 3); \
294     } while (0)
295
296     INPUT_DENSE(input_dense);
297     INPUT_GRU(vad_gru);
298     INPUT_GRU(noise_gru);
299     INPUT_GRU(denoise_gru);
300     INPUT_DENSE(denoise_output);
301     INPUT_DENSE(vad_output);
302
303     return ret;
304 }
305
306 static int query_formats(AVFilterContext *ctx)
307 {
308     AVFilterFormats *formats = NULL;
309     AVFilterChannelLayouts *layouts = NULL;
310     static const enum AVSampleFormat sample_fmts[] = {
311         AV_SAMPLE_FMT_FLTP,
312         AV_SAMPLE_FMT_NONE
313     };
314     int ret, sample_rates[] = { 48000, -1 };
315
316     formats = ff_make_format_list(sample_fmts);
317     if (!formats)
318         return AVERROR(ENOMEM);
319     ret = ff_set_common_formats(ctx, formats);
320     if (ret < 0)
321         return ret;
322
323     layouts = ff_all_channel_counts();
324     if (!layouts)
325         return AVERROR(ENOMEM);
326
327     ret = ff_set_common_channel_layouts(ctx, layouts);
328     if (ret < 0)
329         return ret;
330
331     formats = ff_make_format_list(sample_rates);
332     if (!formats)
333         return AVERROR(ENOMEM);
334     return ff_set_common_samplerates(ctx, formats);
335 }
336
337 static int config_input(AVFilterLink *inlink)
338 {
339     AVFilterContext *ctx = inlink->dst;
340     AudioRNNContext *s = ctx->priv;
341     int ret;
342
343     s->channels = inlink->channels;
344
345     s->st = av_calloc(s->channels, sizeof(DenoiseState));
346     if (!s->st)
347         return AVERROR(ENOMEM);
348
349     for (int i = 0; i < s->channels; i++) {
350         DenoiseState *st = &s->st[i];
351
352         st->rnn.model = s->model;
353         st->rnn.vad_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->vad_gru_size, 16));
354         st->rnn.noise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->noise_gru_size, 16));
355         st->rnn.denoise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->denoise_gru_size, 16));
356         if (!st->rnn.vad_gru_state ||
357             !st->rnn.noise_gru_state ||
358             !st->rnn.denoise_gru_state)
359             return AVERROR(ENOMEM);
360
361         ret = av_tx_init(&st->tx, &st->tx_fn, AV_TX_FLOAT_FFT, 0, WINDOW_SIZE, NULL, 0);
362         if (ret < 0)
363             return ret;
364
365         ret = av_tx_init(&st->txi, &st->txi_fn, AV_TX_FLOAT_FFT, 1, WINDOW_SIZE, NULL, 0);
366         if (ret < 0)
367             return ret;
368     }
369
370     return 0;
371 }
372
373 static void biquad(float *y, float mem[2], const float *x,
374                    const float *b, const float *a, int N)
375 {
376     for (int i = 0; i < N; i++) {
377         float xi, yi;
378
379         xi = x[i];
380         yi = x[i] + mem[0];
381         mem[0] = mem[1] + (b[0]*xi - a[0]*yi);
382         mem[1] = (b[1]*xi - a[1]*yi);
383         y[i] = yi;
384     }
385 }
386
387 #define RNN_MOVE(dst, src, n) (memmove((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
388 #define RNN_CLEAR(dst, n) (memset((dst), 0, (n)*sizeof(*(dst))))
389 #define RNN_COPY(dst, src, n) (memcpy((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
390
391 static void forward_transform(DenoiseState *st, AVComplexFloat *out, const float *in)
392 {
393     AVComplexFloat x[WINDOW_SIZE];
394     AVComplexFloat y[WINDOW_SIZE];
395
396     for (int i = 0; i < WINDOW_SIZE; i++) {
397         x[i].re = in[i];
398         x[i].im = 0;
399     }
400
401     st->tx_fn(st->tx, y, x, sizeof(float));
402
403     RNN_COPY(out, y, FREQ_SIZE);
404 }
405
406 static void inverse_transform(DenoiseState *st, float *out, const AVComplexFloat *in)
407 {
408     AVComplexFloat x[WINDOW_SIZE];
409     AVComplexFloat y[WINDOW_SIZE];
410
411     for (int i = 0; i < FREQ_SIZE; i++)
412         x[i] = in[i];
413
414     for (int i = FREQ_SIZE; i < WINDOW_SIZE; i++) {
415         x[i].re =  x[WINDOW_SIZE - i].re;
416         x[i].im = -x[WINDOW_SIZE - i].im;
417     }
418
419     st->txi_fn(st->txi, y, x, sizeof(float));
420
421     for (int i = 0; i < WINDOW_SIZE; i++)
422         out[i] = y[i].re / WINDOW_SIZE;
423 }
424
425 static const uint8_t eband5ms[] = {
426 /*0  200 400 600 800  1k 1.2 1.4 1.6  2k 2.4 2.8 3.2  4k 4.8 5.6 6.8  8k 9.6 12k 15.6 20k*/
427   0,  1,  2,  3,  4,   5, 6,  7,  8,  10, 12, 14, 16, 20, 24, 28, 34, 40, 48, 60, 78, 100
428 };
429
430 static void compute_band_energy(float *bandE, const AVComplexFloat *X)
431 {
432     float sum[NB_BANDS] = {0};
433
434     for (int i = 0; i < NB_BANDS - 1; i++) {
435         int band_size;
436
437         band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
438         for (int j = 0; j < band_size; j++) {
439             float tmp, frac = (float)j / band_size;
440
441             tmp         = SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].re);
442             tmp        += SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].im);
443             sum[i]     += (1.f - frac) * tmp;
444             sum[i + 1] +=        frac  * tmp;
445         }
446     }
447
448     sum[0] *= 2;
449     sum[NB_BANDS - 1] *= 2;
450
451     for (int i = 0; i < NB_BANDS; i++)
452         bandE[i] = sum[i];
453 }
454
455 static void compute_band_corr(float *bandE, const AVComplexFloat *X, const AVComplexFloat *P)
456 {
457     float sum[NB_BANDS] = { 0 };
458
459     for (int i = 0; i < NB_BANDS - 1; i++) {
460         int band_size;
461
462         band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
463         for (int j = 0; j < band_size; j++) {
464             float tmp, frac = (float)j / band_size;
465
466             tmp  = X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re;
467             tmp += X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im;
468             sum[i]     += (1 - frac) * tmp;
469             sum[i + 1] +=      frac  * tmp;
470         }
471     }
472
473     sum[0] *= 2;
474     sum[NB_BANDS-1] *= 2;
475
476     for (int i = 0; i < NB_BANDS; i++)
477         bandE[i] = sum[i];
478 }
479
480 static void frame_analysis(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, float *Ex, const float *in)
481 {
482     LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
483
484     RNN_COPY(x, st->analysis_mem, FRAME_SIZE);
485     RNN_COPY(x + FRAME_SIZE, in, FRAME_SIZE);
486     RNN_COPY(st->analysis_mem, in, FRAME_SIZE);
487     s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
488     forward_transform(st, X, x);
489     compute_band_energy(Ex, X);
490 }
491
492 static void frame_synthesis(AudioRNNContext *s, DenoiseState *st, float *out, const AVComplexFloat *y)
493 {
494     LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
495
496     inverse_transform(st, x, y);
497     s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
498     s->fdsp->vector_fmac_scalar(x, st->synthesis_mem, 1.f, FRAME_SIZE);
499     RNN_COPY(out, x, FRAME_SIZE);
500     RNN_COPY(st->synthesis_mem, &x[FRAME_SIZE], FRAME_SIZE);
501 }
502
503 static inline void xcorr_kernel(const float *x, const float *y, float sum[4], int len)
504 {
505     float y_0, y_1, y_2, y_3 = 0;
506     int j;
507
508     y_0 = *y++;
509     y_1 = *y++;
510     y_2 = *y++;
511
512     for (j = 0; j < len - 3; j += 4) {
513         float tmp;
514
515         tmp = *x++;
516         y_3 = *y++;
517         sum[0] += tmp * y_0;
518         sum[1] += tmp * y_1;
519         sum[2] += tmp * y_2;
520         sum[3] += tmp * y_3;
521         tmp = *x++;
522         y_0 = *y++;
523         sum[0] += tmp * y_1;
524         sum[1] += tmp * y_2;
525         sum[2] += tmp * y_3;
526         sum[3] += tmp * y_0;
527         tmp = *x++;
528         y_1 = *y++;
529         sum[0] += tmp * y_2;
530         sum[1] += tmp * y_3;
531         sum[2] += tmp * y_0;
532         sum[3] += tmp * y_1;
533         tmp = *x++;
534         y_2 = *y++;
535         sum[0] += tmp * y_3;
536         sum[1] += tmp * y_0;
537         sum[2] += tmp * y_1;
538         sum[3] += tmp * y_2;
539     }
540
541     if (j++ < len) {
542         float tmp = *x++;
543
544         y_3 = *y++;
545         sum[0] += tmp * y_0;
546         sum[1] += tmp * y_1;
547         sum[2] += tmp * y_2;
548         sum[3] += tmp * y_3;
549     }
550
551     if (j++ < len) {
552         float tmp=*x++;
553
554         y_0 = *y++;
555         sum[0] += tmp * y_1;
556         sum[1] += tmp * y_2;
557         sum[2] += tmp * y_3;
558         sum[3] += tmp * y_0;
559     }
560
561     if (j < len) {
562         float tmp=*x++;
563
564         y_1 = *y++;
565         sum[0] += tmp * y_2;
566         sum[1] += tmp * y_3;
567         sum[2] += tmp * y_0;
568         sum[3] += tmp * y_1;
569     }
570 }
571
572 static inline float celt_inner_prod(const float *x,
573                                     const float *y, int N)
574 {
575     float xy = 0.f;
576
577     for (int i = 0; i < N; i++)
578         xy += x[i] * y[i];
579
580     return xy;
581 }
582
583 static void celt_pitch_xcorr(const float *x, const float *y,
584                              float *xcorr, int len, int max_pitch)
585 {
586     int i;
587
588     for (i = 0; i < max_pitch - 3; i += 4) {
589         float sum[4] = { 0, 0, 0, 0};
590
591         xcorr_kernel(x, y + i, sum, len);
592
593         xcorr[i]     = sum[0];
594         xcorr[i + 1] = sum[1];
595         xcorr[i + 2] = sum[2];
596         xcorr[i + 3] = sum[3];
597     }
598     /* In case max_pitch isn't a multiple of 4, do non-unrolled version. */
599     for (; i < max_pitch; i++) {
600         xcorr[i] = celt_inner_prod(x, y + i, len);
601     }
602 }
603
604 static int celt_autocorr(const float *x,   /*  in: [0...n-1] samples x   */
605                          float       *ac,  /* out: [0...lag-1] ac values */
606                          const float *window,
607                          int          overlap,
608                          int          lag,
609                          int          n)
610 {
611     int fastN = n - lag;
612     int shift;
613     const float *xptr;
614     float xx[PITCH_BUF_SIZE>>1];
615
616     if (overlap == 0) {
617         xptr = x;
618     } else {
619         for (int i = 0; i < n; i++)
620             xx[i] = x[i];
621         for (int i = 0; i < overlap; i++) {
622             xx[i] = x[i] * window[i];
623             xx[n-i-1] = x[n-i-1] * window[i];
624         }
625         xptr = xx;
626     }
627
628     shift = 0;
629     celt_pitch_xcorr(xptr, xptr, ac, fastN, lag+1);
630
631     for (int k = 0; k <= lag; k++) {
632         float d = 0.f;
633
634         for (int i = k + fastN; i < n; i++)
635             d += xptr[i] * xptr[i-k];
636         ac[k] += d;
637     }
638
639     return shift;
640 }
641
642 static void celt_lpc(float *lpc, /* out: [0...p-1] LPC coefficients      */
643                 const float *ac,   /* in:  [0...p] autocorrelation values  */
644                           int p)
645 {
646     float r, error = ac[0];
647
648     RNN_CLEAR(lpc, p);
649     if (ac[0] != 0) {
650         for (int i = 0; i < p; i++) {
651             /* Sum up this iteration's reflection coefficient */
652             float rr = 0;
653             for (int j = 0; j < i; j++)
654                 rr += (lpc[j] * ac[i - j]);
655             rr += ac[i + 1];
656             r = -rr/error;
657             /*  Update LPC coefficients and total error */
658             lpc[i] = r;
659             for (int j = 0; j < (i + 1) >> 1; j++) {
660                 float tmp1, tmp2;
661                 tmp1 = lpc[j];
662                 tmp2 = lpc[i-1-j];
663                 lpc[j]     = tmp1 + (r*tmp2);
664                 lpc[i-1-j] = tmp2 + (r*tmp1);
665             }
666
667             error = error - (r * r *error);
668             /* Bail out once we get 30 dB gain */
669             if (error < .001f * ac[0])
670                 break;
671         }
672     }
673 }
674
675 static void celt_fir5(const float *x,
676                       const float *num,
677                       float *y,
678                       int N,
679                       float *mem)
680 {
681     float num0, num1, num2, num3, num4;
682     float mem0, mem1, mem2, mem3, mem4;
683
684     num0 = num[0];
685     num1 = num[1];
686     num2 = num[2];
687     num3 = num[3];
688     num4 = num[4];
689     mem0 = mem[0];
690     mem1 = mem[1];
691     mem2 = mem[2];
692     mem3 = mem[3];
693     mem4 = mem[4];
694
695     for (int i = 0; i < N; i++) {
696         float sum = x[i];
697
698         sum += (num0*mem0);
699         sum += (num1*mem1);
700         sum += (num2*mem2);
701         sum += (num3*mem3);
702         sum += (num4*mem4);
703         mem4 = mem3;
704         mem3 = mem2;
705         mem2 = mem1;
706         mem1 = mem0;
707         mem0 = x[i];
708         y[i] = sum;
709     }
710
711     mem[0] = mem0;
712     mem[1] = mem1;
713     mem[2] = mem2;
714     mem[3] = mem3;
715     mem[4] = mem4;
716 }
717
718 static void pitch_downsample(float *x[], float *x_lp,
719                              int len, int C)
720 {
721     float ac[5];
722     float tmp=Q15ONE;
723     float lpc[4], mem[5]={0,0,0,0,0};
724     float lpc2[5];
725     float c1 = .8f;
726
727     for (int i = 1; i < len >> 1; i++)
728         x_lp[i] = .5f * (.5f * (x[0][(2*i-1)]+x[0][(2*i+1)])+x[0][2*i]);
729     x_lp[0] = .5f * (.5f * (x[0][1])+x[0][0]);
730     if (C==2) {
731         for (int i = 1; i < len >> 1; i++)
732             x_lp[i] += (.5f * (.5f * (x[1][(2*i-1)]+x[1][(2*i+1)])+x[1][2*i]));
733         x_lp[0] += .5f * (.5f * (x[1][1])+x[1][0]);
734     }
735
736     celt_autocorr(x_lp, ac, NULL, 0, 4, len>>1);
737
738     /* Noise floor -40 dB */
739     ac[0] *= 1.0001f;
740     /* Lag windowing */
741     for (int i = 1; i <= 4; i++) {
742         /*ac[i] *= exp(-.5*(2*M_PI*.002*i)*(2*M_PI*.002*i));*/
743         ac[i] -= ac[i]*(.008f*i)*(.008f*i);
744     }
745
746     celt_lpc(lpc, ac, 4);
747     for (int i = 0; i < 4; i++) {
748         tmp = .9f * tmp;
749         lpc[i] = (lpc[i] * tmp);
750     }
751     /* Add a zero */
752     lpc2[0] = lpc[0] + .8f;
753     lpc2[1] = lpc[1] + (c1 * lpc[0]);
754     lpc2[2] = lpc[2] + (c1 * lpc[1]);
755     lpc2[3] = lpc[3] + (c1 * lpc[2]);
756     lpc2[4] = (c1 * lpc[3]);
757     celt_fir5(x_lp, lpc2, x_lp, len>>1, mem);
758 }
759
760 static inline void dual_inner_prod(const float *x, const float *y01, const float *y02,
761                                    int N, float *xy1, float *xy2)
762 {
763     float xy01 = 0, xy02 = 0;
764
765     for (int i = 0; i < N; i++) {
766         xy01 += (x[i] * y01[i]);
767         xy02 += (x[i] * y02[i]);
768     }
769
770     *xy1 = xy01;
771     *xy2 = xy02;
772 }
773
774 static float compute_pitch_gain(float xy, float xx, float yy)
775 {
776     return xy / sqrtf(1.f + xx * yy);
777 }
778
779 static const int second_check[16] = {0, 0, 3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2};
780 static const float remove_doubling(float *x, int maxperiod, int minperiod,
781                                    int N, int *T0_, int prev_period, float prev_gain)
782 {
783     int k, i, T, T0;
784     float g, g0;
785     float pg;
786     float xy,xx,yy,xy2;
787     float xcorr[3];
788     float best_xy, best_yy;
789     int offset;
790     int minperiod0;
791     float yy_lookup[PITCH_MAX_PERIOD+1];
792
793     minperiod0 = minperiod;
794     maxperiod /= 2;
795     minperiod /= 2;
796     *T0_ /= 2;
797     prev_period /= 2;
798     N /= 2;
799     x += maxperiod;
800     if (*T0_>=maxperiod)
801         *T0_=maxperiod-1;
802
803     T = T0 = *T0_;
804     dual_inner_prod(x, x, x-T0, N, &xx, &xy);
805     yy_lookup[0] = xx;
806     yy=xx;
807     for (i = 1; i <= maxperiod; i++) {
808         yy = yy+(x[-i] * x[-i])-(x[N-i] * x[N-i]);
809         yy_lookup[i] = FFMAX(0, yy);
810     }
811     yy = yy_lookup[T0];
812     best_xy = xy;
813     best_yy = yy;
814     g = g0 = compute_pitch_gain(xy, xx, yy);
815     /* Look for any pitch at T/k */
816     for (k = 2; k <= 15; k++) {
817         int T1, T1b;
818         float g1;
819         float cont=0;
820         float thresh;
821         T1 = (2*T0+k)/(2*k);
822         if (T1 < minperiod)
823             break;
824         /* Look for another strong correlation at T1b */
825         if (k==2)
826         {
827             if (T1+T0>maxperiod)
828                 T1b = T0;
829             else
830                 T1b = T0+T1;
831         } else
832         {
833             T1b = (2*second_check[k]*T0+k)/(2*k);
834         }
835         dual_inner_prod(x, &x[-T1], &x[-T1b], N, &xy, &xy2);
836         xy = .5f * (xy + xy2);
837         yy = .5f * (yy_lookup[T1] + yy_lookup[T1b]);
838         g1 = compute_pitch_gain(xy, xx, yy);
839         if (FFABS(T1-prev_period)<=1)
840             cont = prev_gain;
841         else if (FFABS(T1-prev_period)<=2 && 5 * k * k < T0)
842             cont = prev_gain * .5f;
843         else
844             cont = 0;
845         thresh = FFMAX(.3f, (.7f * g0) - cont);
846         /* Bias against very high pitch (very short period) to avoid false-positives
847            due to short-term correlation */
848         if (T1<3*minperiod)
849             thresh = FFMAX(.4f, (.85f * g0) - cont);
850         else if (T1<2*minperiod)
851             thresh = FFMAX(.5f, (.9f * g0) - cont);
852         if (g1 > thresh)
853         {
854             best_xy = xy;
855             best_yy = yy;
856             T = T1;
857             g = g1;
858         }
859     }
860     best_xy = FFMAX(0, best_xy);
861     if (best_yy <= best_xy)
862         pg = Q15ONE;
863     else
864         pg = best_xy/(best_yy + 1);
865
866     for (k = 0; k < 3; k++)
867         xcorr[k] = celt_inner_prod(x, x-(T+k-1), N);
868     if ((xcorr[2]-xcorr[0]) > .7f * (xcorr[1]-xcorr[0]))
869         offset = 1;
870     else if ((xcorr[0]-xcorr[2]) > (.7f * (xcorr[1] - xcorr[2])))
871         offset = -1;
872     else
873         offset = 0;
874     if (pg > g)
875         pg = g;
876     *T0_ = 2*T+offset;
877
878     if (*T0_<minperiod0)
879         *T0_=minperiod0;
880     return pg;
881 }
882
883 static void find_best_pitch(float *xcorr, float *y, int len,
884                             int max_pitch, int *best_pitch)
885 {
886     float best_num[2];
887     float best_den[2];
888     float Syy = 1.f;
889
890     best_num[0] = -1;
891     best_num[1] = -1;
892     best_den[0] = 0;
893     best_den[1] = 0;
894     best_pitch[0] = 0;
895     best_pitch[1] = 1;
896
897     for (int j = 0; j < len; j++)
898         Syy += y[j] * y[j];
899
900     for (int i = 0; i < max_pitch; i++) {
901         if (xcorr[i]>0) {
902             float num;
903             float xcorr16;
904
905             xcorr16 = xcorr[i];
906             /* Considering the range of xcorr16, this should avoid both underflows
907                and overflows (inf) when squaring xcorr16 */
908             xcorr16 *= 1e-12f;
909             num = xcorr16 * xcorr16;
910             if ((num * best_den[1]) > (best_num[1] * Syy)) {
911                 if ((num * best_den[0]) > (best_num[0] * Syy)) {
912                     best_num[1] = best_num[0];
913                     best_den[1] = best_den[0];
914                     best_pitch[1] = best_pitch[0];
915                     best_num[0] = num;
916                     best_den[0] = Syy;
917                     best_pitch[0] = i;
918                 } else {
919                     best_num[1] = num;
920                     best_den[1] = Syy;
921                     best_pitch[1] = i;
922                 }
923             }
924         }
925         Syy += y[i+len]*y[i+len] - y[i] * y[i];
926         Syy = FFMAX(1, Syy);
927     }
928 }
929
930 static void pitch_search(const float *x_lp, float *y,
931                          int len, int max_pitch, int *pitch)
932 {
933     int lag;
934     int best_pitch[2]={0,0};
935     int offset;
936
937     float x_lp4[WINDOW_SIZE];
938     float y_lp4[WINDOW_SIZE];
939     float xcorr[WINDOW_SIZE];
940
941     lag = len+max_pitch;
942
943     /* Downsample by 2 again */
944     for (int j = 0; j < len >> 2; j++)
945         x_lp4[j] = x_lp[2*j];
946     for (int j = 0; j < lag >> 2; j++)
947         y_lp4[j] = y[2*j];
948
949     /* Coarse search with 4x decimation */
950
951     celt_pitch_xcorr(x_lp4, y_lp4, xcorr, len>>2, max_pitch>>2);
952
953     find_best_pitch(xcorr, y_lp4, len>>2, max_pitch>>2, best_pitch);
954
955     /* Finer search with 2x decimation */
956     for (int i = 0; i < max_pitch >> 1; i++) {
957         float sum;
958         xcorr[i] = 0;
959         if (FFABS(i-2*best_pitch[0])>2 && FFABS(i-2*best_pitch[1])>2)
960             continue;
961         sum = celt_inner_prod(x_lp, y+i, len>>1);
962         xcorr[i] = FFMAX(-1, sum);
963     }
964
965     find_best_pitch(xcorr, y, len>>1, max_pitch>>1, best_pitch);
966
967     /* Refine by pseudo-interpolation */
968     if (best_pitch[0] > 0 && best_pitch[0] < (max_pitch >> 1) - 1) {
969         float a, b, c;
970
971         a = xcorr[best_pitch[0] - 1];
972         b = xcorr[best_pitch[0]];
973         c = xcorr[best_pitch[0] + 1];
974         if (c - a > .7f * (b - a))
975             offset = 1;
976         else if (a - c > .7f * (b-c))
977             offset = -1;
978         else
979             offset = 0;
980     } else {
981         offset = 0;
982     }
983
984     *pitch = 2 * best_pitch[0] - offset;
985 }
986
987 static void dct(AudioRNNContext *s, float *out, const float *in)
988 {
989     for (int i = 0; i < NB_BANDS; i++) {
990         float sum = 0.f;
991
992         for (int j = 0; j < NB_BANDS; j++) {
993             sum += in[j] * s->dct_table[j * NB_BANDS + i];
994         }
995         out[i] = sum * sqrtf(2.f / 22);
996     }
997 }
998
999 static int compute_frame_features(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, AVComplexFloat *P,
1000                                   float *Ex, float *Ep, float *Exp, float *features, const float *in)
1001 {
1002     float E = 0;
1003     float *ceps_0, *ceps_1, *ceps_2;
1004     float spec_variability = 0;
1005     float Ly[NB_BANDS];
1006     LOCAL_ALIGNED_32(float, p, [WINDOW_SIZE]);
1007     float pitch_buf[PITCH_BUF_SIZE>>1];
1008     int pitch_index;
1009     float gain;
1010     float *(pre[1]);
1011     float tmp[NB_BANDS];
1012     float follow, logMax;
1013
1014     frame_analysis(s, st, X, Ex, in);
1015     RNN_MOVE(st->pitch_buf, &st->pitch_buf[FRAME_SIZE], PITCH_BUF_SIZE-FRAME_SIZE);
1016     RNN_COPY(&st->pitch_buf[PITCH_BUF_SIZE-FRAME_SIZE], in, FRAME_SIZE);
1017     pre[0] = &st->pitch_buf[0];
1018     pitch_downsample(pre, pitch_buf, PITCH_BUF_SIZE, 1);
1019     pitch_search(pitch_buf+(PITCH_MAX_PERIOD>>1), pitch_buf, PITCH_FRAME_SIZE,
1020             PITCH_MAX_PERIOD-3*PITCH_MIN_PERIOD, &pitch_index);
1021     pitch_index = PITCH_MAX_PERIOD-pitch_index;
1022
1023     gain = remove_doubling(pitch_buf, PITCH_MAX_PERIOD, PITCH_MIN_PERIOD,
1024             PITCH_FRAME_SIZE, &pitch_index, st->last_period, st->last_gain);
1025     st->last_period = pitch_index;
1026     st->last_gain = gain;
1027
1028     for (int i = 0; i < WINDOW_SIZE; i++)
1029         p[i] = st->pitch_buf[PITCH_BUF_SIZE-WINDOW_SIZE-pitch_index+i];
1030
1031     s->fdsp->vector_fmul(p, p, s->window, WINDOW_SIZE);
1032     forward_transform(st, P, p);
1033     compute_band_energy(Ep, P);
1034     compute_band_corr(Exp, X, P);
1035
1036     for (int i = 0; i < NB_BANDS; i++)
1037         Exp[i] = Exp[i] / sqrtf(.001f+Ex[i]*Ep[i]);
1038
1039     dct(s, tmp, Exp);
1040
1041     for (int i = 0; i < NB_DELTA_CEPS; i++)
1042         features[NB_BANDS+2*NB_DELTA_CEPS+i] = tmp[i];
1043
1044     features[NB_BANDS+2*NB_DELTA_CEPS] -= 1.3;
1045     features[NB_BANDS+2*NB_DELTA_CEPS+1] -= 0.9;
1046     features[NB_BANDS+3*NB_DELTA_CEPS] = .01*(pitch_index-300);
1047     logMax = -2;
1048     follow = -2;
1049
1050     for (int i = 0; i < NB_BANDS; i++) {
1051         Ly[i] = log10f(1e-2f + Ex[i]);
1052         Ly[i] = FFMAX(logMax-7, FFMAX(follow-1.5, Ly[i]));
1053         logMax = FFMAX(logMax, Ly[i]);
1054         follow = FFMAX(follow-1.5, Ly[i]);
1055         E += Ex[i];
1056     }
1057
1058     if (E < 0.04f) {
1059         /* If there's no audio, avoid messing up the state. */
1060         RNN_CLEAR(features, NB_FEATURES);
1061         return 1;
1062     }
1063
1064     dct(s, features, Ly);
1065     features[0] -= 12;
1066     features[1] -= 4;
1067     ceps_0 = st->cepstral_mem[st->memid];
1068     ceps_1 = (st->memid < 1) ? st->cepstral_mem[CEPS_MEM+st->memid-1] : st->cepstral_mem[st->memid-1];
1069     ceps_2 = (st->memid < 2) ? st->cepstral_mem[CEPS_MEM+st->memid-2] : st->cepstral_mem[st->memid-2];
1070
1071     for (int i = 0; i < NB_BANDS; i++)
1072         ceps_0[i] = features[i];
1073
1074     st->memid++;
1075     for (int i = 0; i < NB_DELTA_CEPS; i++) {
1076         features[i] = ceps_0[i] + ceps_1[i] + ceps_2[i];
1077         features[NB_BANDS+i] = ceps_0[i] - ceps_2[i];
1078         features[NB_BANDS+NB_DELTA_CEPS+i] =  ceps_0[i] - 2*ceps_1[i] + ceps_2[i];
1079     }
1080     /* Spectral variability features. */
1081     if (st->memid == CEPS_MEM)
1082         st->memid = 0;
1083
1084     for (int i = 0; i < CEPS_MEM; i++) {
1085         float mindist = 1e15f;
1086         for (int j = 0; j < CEPS_MEM; j++) {
1087             float dist = 0.f;
1088             for (int k = 0; k < NB_BANDS; k++) {
1089                 float tmp;
1090
1091                 tmp = st->cepstral_mem[i][k] - st->cepstral_mem[j][k];
1092                 dist += tmp*tmp;
1093             }
1094
1095             if (j != i)
1096                 mindist = FFMIN(mindist, dist);
1097         }
1098
1099         spec_variability += mindist;
1100     }
1101
1102     features[NB_BANDS+3*NB_DELTA_CEPS+1] = spec_variability/CEPS_MEM-2.1;
1103
1104     return 0;
1105 }
1106
1107 static void interp_band_gain(float *g, const float *bandE)
1108 {
1109     memset(g, 0, sizeof(*g) * FREQ_SIZE);
1110
1111     for (int i = 0; i < NB_BANDS - 1; i++) {
1112         const int band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
1113
1114         for (int j = 0; j < band_size; j++) {
1115             float frac = (float)j / band_size;
1116
1117             g[(eband5ms[i] << FRAME_SIZE_SHIFT) + j] = (1.f - frac) * bandE[i] + frac * bandE[i + 1];
1118         }
1119     }
1120 }
1121
1122 static void pitch_filter(AVComplexFloat *X, const AVComplexFloat *P, const float *Ex, const float *Ep,
1123                          const float *Exp, const float *g)
1124 {
1125     float newE[NB_BANDS];
1126     float r[NB_BANDS];
1127     float norm[NB_BANDS];
1128     float rf[FREQ_SIZE] = {0};
1129     float normf[FREQ_SIZE]={0};
1130
1131     for (int i = 0; i < NB_BANDS; i++) {
1132         if (Exp[i]>g[i]) r[i] = 1;
1133         else r[i] = SQUARE(Exp[i])*(1-SQUARE(g[i]))/(.001 + SQUARE(g[i])*(1-SQUARE(Exp[i])));
1134         r[i]  = sqrtf(av_clipf(r[i], 0, 1));
1135         r[i] *= sqrtf(Ex[i]/(1e-8+Ep[i]));
1136     }
1137     interp_band_gain(rf, r);
1138     for (int i = 0; i < FREQ_SIZE; i++) {
1139         X[i].re += rf[i]*P[i].re;
1140         X[i].im += rf[i]*P[i].im;
1141     }
1142     compute_band_energy(newE, X);
1143     for (int i = 0; i < NB_BANDS; i++) {
1144         norm[i] = sqrtf(Ex[i] / (1e-8+newE[i]));
1145     }
1146     interp_band_gain(normf, norm);
1147     for (int i = 0; i < FREQ_SIZE; i++) {
1148         X[i].re *= normf[i];
1149         X[i].im *= normf[i];
1150     }
1151 }
1152
1153 static const float tansig_table[201] = {
1154     0.000000f, 0.039979f, 0.079830f, 0.119427f, 0.158649f,
1155     0.197375f, 0.235496f, 0.272905f, 0.309507f, 0.345214f,
1156     0.379949f, 0.413644f, 0.446244f, 0.477700f, 0.507977f,
1157     0.537050f, 0.564900f, 0.591519f, 0.616909f, 0.641077f,
1158     0.664037f, 0.685809f, 0.706419f, 0.725897f, 0.744277f,
1159     0.761594f, 0.777888f, 0.793199f, 0.807569f, 0.821040f,
1160     0.833655f, 0.845456f, 0.856485f, 0.866784f, 0.876393f,
1161     0.885352f, 0.893698f, 0.901468f, 0.908698f, 0.915420f,
1162     0.921669f, 0.927473f, 0.932862f, 0.937863f, 0.942503f,
1163     0.946806f, 0.950795f, 0.954492f, 0.957917f, 0.961090f,
1164     0.964028f, 0.966747f, 0.969265f, 0.971594f, 0.973749f,
1165     0.975743f, 0.977587f, 0.979293f, 0.980869f, 0.982327f,
1166     0.983675f, 0.984921f, 0.986072f, 0.987136f, 0.988119f,
1167     0.989027f, 0.989867f, 0.990642f, 0.991359f, 0.992020f,
1168     0.992631f, 0.993196f, 0.993718f, 0.994199f, 0.994644f,
1169     0.995055f, 0.995434f, 0.995784f, 0.996108f, 0.996407f,
1170     0.996682f, 0.996937f, 0.997172f, 0.997389f, 0.997590f,
1171     0.997775f, 0.997946f, 0.998104f, 0.998249f, 0.998384f,
1172     0.998508f, 0.998623f, 0.998728f, 0.998826f, 0.998916f,
1173     0.999000f, 0.999076f, 0.999147f, 0.999213f, 0.999273f,
1174     0.999329f, 0.999381f, 0.999428f, 0.999472f, 0.999513f,
1175     0.999550f, 0.999585f, 0.999617f, 0.999646f, 0.999673f,
1176     0.999699f, 0.999722f, 0.999743f, 0.999763f, 0.999781f,
1177     0.999798f, 0.999813f, 0.999828f, 0.999841f, 0.999853f,
1178     0.999865f, 0.999875f, 0.999885f, 0.999893f, 0.999902f,
1179     0.999909f, 0.999916f, 0.999923f, 0.999929f, 0.999934f,
1180     0.999939f, 0.999944f, 0.999948f, 0.999952f, 0.999956f,
1181     0.999959f, 0.999962f, 0.999965f, 0.999968f, 0.999970f,
1182     0.999973f, 0.999975f, 0.999977f, 0.999978f, 0.999980f,
1183     0.999982f, 0.999983f, 0.999984f, 0.999986f, 0.999987f,
1184     0.999988f, 0.999989f, 0.999990f, 0.999990f, 0.999991f,
1185     0.999992f, 0.999992f, 0.999993f, 0.999994f, 0.999994f,
1186     0.999994f, 0.999995f, 0.999995f, 0.999996f, 0.999996f,
1187     0.999996f, 0.999997f, 0.999997f, 0.999997f, 0.999997f,
1188     0.999997f, 0.999998f, 0.999998f, 0.999998f, 0.999998f,
1189     0.999998f, 0.999998f, 0.999999f, 0.999999f, 0.999999f,
1190     0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
1191     0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
1192     1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1193     1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1194     1.000000f,
1195 };
1196
1197 static inline float tansig_approx(float x)
1198 {
1199     float y, dy;
1200     float sign=1;
1201     int i;
1202
1203     /* Tests are reversed to catch NaNs */
1204     if (!(x<8))
1205         return 1;
1206     if (!(x>-8))
1207         return -1;
1208     /* Another check in case of -ffast-math */
1209
1210     if (isnan(x))
1211        return 0;
1212
1213     if (x < 0) {
1214        x=-x;
1215        sign=-1;
1216     }
1217     i = (int)floor(.5f+25*x);
1218     x -= .04f*i;
1219     y = tansig_table[i];
1220     dy = 1-y*y;
1221     y = y + x*dy*(1 - y*x);
1222     return sign*y;
1223 }
1224
1225 static inline float sigmoid_approx(float x)
1226 {
1227     return .5f + .5f*tansig_approx(.5f*x);
1228 }
1229
1230 static void compute_dense(const DenseLayer *layer, float *output, const float *input)
1231 {
1232     const int N = layer->nb_neurons, M = layer->nb_inputs, stride = N;
1233
1234     for (int i = 0; i < N; i++) {
1235         /* Compute update gate. */
1236         float sum = layer->bias[i];
1237
1238         for (int j = 0; j < M; j++)
1239             sum += layer->input_weights[j * stride + i] * input[j];
1240
1241         output[i] = WEIGHTS_SCALE * sum;
1242     }
1243
1244     if (layer->activation == ACTIVATION_SIGMOID) {
1245         for (int i = 0; i < N; i++)
1246             output[i] = sigmoid_approx(output[i]);
1247     } else if (layer->activation == ACTIVATION_TANH) {
1248         for (int i = 0; i < N; i++)
1249             output[i] = tansig_approx(output[i]);
1250     } else if (layer->activation == ACTIVATION_RELU) {
1251         for (int i = 0; i < N; i++)
1252             output[i] = FFMAX(0, output[i]);
1253     } else {
1254         av_assert0(0);
1255     }
1256 }
1257
1258 static void compute_gru(AudioRNNContext *s, const GRULayer *gru, float *state, const float *input)
1259 {
1260     LOCAL_ALIGNED_32(float, z, [MAX_NEURONS]);
1261     LOCAL_ALIGNED_32(float, r, [MAX_NEURONS]);
1262     LOCAL_ALIGNED_32(float, h, [MAX_NEURONS]);
1263     const int M = gru->nb_inputs;
1264     const int N = gru->nb_neurons;
1265     const int AN = FFALIGN(N, 4);
1266     const int AM = FFALIGN(M, 4);
1267     const int stride = 3 * AN, istride = 3 * AM;
1268
1269     for (int i = 0; i < N; i++) {
1270         /* Compute update gate. */
1271         float sum = gru->bias[i];
1272
1273         sum += s->fdsp->scalarproduct_float(gru->input_weights + i * istride, input, AM);
1274         sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + i * stride, state, AN);
1275         z[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1276     }
1277
1278     for (int i = 0; i < N; i++) {
1279         /* Compute reset gate. */
1280         float sum = gru->bias[N + i];
1281
1282         sum += s->fdsp->scalarproduct_float(gru->input_weights + AM + i * istride, input, AM);
1283         sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + AN + i * stride, state, AN);
1284         r[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1285     }
1286
1287     for (int i = 0; i < N; i++) {
1288         /* Compute output. */
1289         float sum = gru->bias[2 * N + i];
1290
1291         sum += s->fdsp->scalarproduct_float(gru->input_weights + 2 * AM + i * istride, input, AM);
1292         for (int j = 0; j < N; j++)
1293             sum += gru->recurrent_weights[2 * AN + i * stride + j] * state[j] * r[j];
1294
1295         if (gru->activation == ACTIVATION_SIGMOID)
1296             sum = sigmoid_approx(WEIGHTS_SCALE * sum);
1297         else if (gru->activation == ACTIVATION_TANH)
1298             sum = tansig_approx(WEIGHTS_SCALE * sum);
1299         else if (gru->activation == ACTIVATION_RELU)
1300             sum = FFMAX(0, WEIGHTS_SCALE * sum);
1301         else
1302             av_assert0(0);
1303         h[i] = z[i] * state[i] + (1.f - z[i]) * sum;
1304     }
1305
1306     RNN_COPY(state, h, N);
1307 }
1308
1309 #define INPUT_SIZE 42
1310
1311 static void compute_rnn(AudioRNNContext *s, RNNState *rnn, float *gains, float *vad, const float *input)
1312 {
1313     LOCAL_ALIGNED_32(float, dense_out,     [MAX_NEURONS]);
1314     LOCAL_ALIGNED_32(float, noise_input,   [MAX_NEURONS * 3]);
1315     LOCAL_ALIGNED_32(float, denoise_input, [MAX_NEURONS * 3]);
1316
1317     compute_dense(rnn->model->input_dense, dense_out, input);
1318     compute_gru(s, rnn->model->vad_gru, rnn->vad_gru_state, dense_out);
1319     compute_dense(rnn->model->vad_output, vad, rnn->vad_gru_state);
1320
1321     for (int i = 0; i < rnn->model->input_dense_size; i++)
1322         noise_input[i] = dense_out[i];
1323     for (int i = 0; i < rnn->model->vad_gru_size; i++)
1324         noise_input[i + rnn->model->input_dense_size] = rnn->vad_gru_state[i];
1325     for (int i = 0; i < INPUT_SIZE; i++)
1326         noise_input[i + rnn->model->input_dense_size + rnn->model->vad_gru_size] = input[i];
1327
1328     compute_gru(s, rnn->model->noise_gru, rnn->noise_gru_state, noise_input);
1329
1330     for (int i = 0; i < rnn->model->vad_gru_size; i++)
1331         denoise_input[i] = rnn->vad_gru_state[i];
1332     for (int i = 0; i < rnn->model->noise_gru_size; i++)
1333         denoise_input[i + rnn->model->vad_gru_size] = rnn->noise_gru_state[i];
1334     for (int i = 0; i < INPUT_SIZE; i++)
1335         denoise_input[i + rnn->model->vad_gru_size + rnn->model->noise_gru_size] = input[i];
1336
1337     compute_gru(s, rnn->model->denoise_gru, rnn->denoise_gru_state, denoise_input);
1338     compute_dense(rnn->model->denoise_output, gains, rnn->denoise_gru_state);
1339 }
1340
1341 static float rnnoise_channel(AudioRNNContext *s, DenoiseState *st, float *out, const float *in)
1342 {
1343     AVComplexFloat X[FREQ_SIZE];
1344     AVComplexFloat P[WINDOW_SIZE];
1345     float x[FRAME_SIZE];
1346     float Ex[NB_BANDS], Ep[NB_BANDS];
1347     float Exp[NB_BANDS];
1348     float features[NB_FEATURES];
1349     float g[NB_BANDS];
1350     float gf[FREQ_SIZE];
1351     float vad_prob = 0;
1352     static const float a_hp[2] = {-1.99599, 0.99600};
1353     static const float b_hp[2] = {-2, 1};
1354     int silence;
1355
1356     biquad(x, st->mem_hp_x, in, b_hp, a_hp, FRAME_SIZE);
1357     silence = compute_frame_features(s, st, X, P, Ex, Ep, Exp, features, x);
1358
1359     if (!silence) {
1360         compute_rnn(s, &st->rnn, g, &vad_prob, features);
1361         pitch_filter(X, P, Ex, Ep, Exp, g);
1362         for (int i = 0; i < NB_BANDS; i++) {
1363             float alpha = .6f;
1364
1365             g[i] = FFMAX(g[i], alpha * st->lastg[i]);
1366             st->lastg[i] = g[i];
1367         }
1368
1369         interp_band_gain(gf, g);
1370
1371         for (int i = 0; i < FREQ_SIZE; i++) {
1372             X[i].re *= gf[i];
1373             X[i].im *= gf[i];
1374         }
1375     }
1376
1377     frame_synthesis(s, st, out, X);
1378
1379     return vad_prob;
1380 }
1381
1382 typedef struct ThreadData {
1383     AVFrame *in, *out;
1384 } ThreadData;
1385
1386 static int rnnoise_channels(AVFilterContext *ctx, void *arg, int jobnr, int nb_jobs)
1387 {
1388     AudioRNNContext *s = ctx->priv;
1389     ThreadData *td = arg;
1390     AVFrame *in = td->in;
1391     AVFrame *out = td->out;
1392     const int start = (out->channels * jobnr) / nb_jobs;
1393     const int end = (out->channels * (jobnr+1)) / nb_jobs;
1394
1395     for (int ch = start; ch < end; ch++) {
1396         rnnoise_channel(s, &s->st[ch],
1397                         (float *)out->extended_data[ch],
1398                         (const float *)in->extended_data[ch]);
1399     }
1400
1401     return 0;
1402 }
1403
1404 static int filter_frame(AVFilterLink *inlink, AVFrame *in)
1405 {
1406     AVFilterContext *ctx = inlink->dst;
1407     AVFilterLink *outlink = ctx->outputs[0];
1408     AVFrame *out = NULL;
1409     ThreadData td;
1410
1411     out = ff_get_audio_buffer(outlink, FRAME_SIZE);
1412     if (!out) {
1413         av_frame_free(&in);
1414         return AVERROR(ENOMEM);
1415     }
1416     out->pts = in->pts;
1417
1418     td.in = in; td.out = out;
1419     ctx->internal->execute(ctx, rnnoise_channels, &td, NULL, FFMIN(outlink->channels,
1420                                                                    ff_filter_get_nb_threads(ctx)));
1421
1422     av_frame_free(&in);
1423     return ff_filter_frame(outlink, out);
1424 }
1425
1426 static int activate(AVFilterContext *ctx)
1427 {
1428     AVFilterLink *inlink = ctx->inputs[0];
1429     AVFilterLink *outlink = ctx->outputs[0];
1430     AVFrame *in = NULL;
1431     int ret;
1432
1433     FF_FILTER_FORWARD_STATUS_BACK(outlink, inlink);
1434
1435     ret = ff_inlink_consume_samples(inlink, FRAME_SIZE, FRAME_SIZE, &in);
1436     if (ret < 0)
1437         return ret;
1438
1439     if (ret > 0)
1440         return filter_frame(inlink, in);
1441
1442     FF_FILTER_FORWARD_STATUS(inlink, outlink);
1443     FF_FILTER_FORWARD_WANTED(outlink, inlink);
1444
1445     return FFERROR_NOT_READY;
1446 }
1447
1448 static av_cold int init(AVFilterContext *ctx)
1449 {
1450     AudioRNNContext *s = ctx->priv;
1451     FILE *f;
1452
1453     s->fdsp = avpriv_float_dsp_alloc(0);
1454     if (!s->fdsp)
1455         return AVERROR(ENOMEM);
1456
1457     if (!s->model_name)
1458         return AVERROR(EINVAL);
1459     f = av_fopen_utf8(s->model_name, "r");
1460     if (!f)
1461         return AVERROR(EINVAL);
1462
1463     s->model = rnnoise_model_from_file(f);
1464     fclose(f);
1465     if (!s->model)
1466         return AVERROR(EINVAL);
1467
1468     for (int i = 0; i < FRAME_SIZE; i++) {
1469         s->window[i] = sin(.5*M_PI*sin(.5*M_PI*(i+.5)/FRAME_SIZE) * sin(.5*M_PI*(i+.5)/FRAME_SIZE));
1470         s->window[WINDOW_SIZE - 1 - i] = s->window[i];
1471     }
1472
1473     for (int i = 0; i < NB_BANDS; i++) {
1474         for (int j = 0; j < NB_BANDS; j++) {
1475             s->dct_table[i*NB_BANDS + j] = cosf((i + .5f) * j * M_PI / NB_BANDS);
1476             if (j == 0)
1477                 s->dct_table[i*NB_BANDS + j] *= sqrtf(.5);
1478         }
1479     }
1480
1481     return 0;
1482 }
1483
1484 static av_cold void uninit(AVFilterContext *ctx)
1485 {
1486     AudioRNNContext *s = ctx->priv;
1487
1488     av_freep(&s->fdsp);
1489     rnnoise_model_free(s->model);
1490     s->model = NULL;
1491
1492     if (s->st) {
1493         for (int ch = 0; ch < s->channels; ch++) {
1494             av_freep(&s->st[ch].rnn.vad_gru_state);
1495             av_freep(&s->st[ch].rnn.noise_gru_state);
1496             av_freep(&s->st[ch].rnn.denoise_gru_state);
1497             av_tx_uninit(&s->st[ch].tx);
1498             av_tx_uninit(&s->st[ch].txi);
1499         }
1500     }
1501     av_freep(&s->st);
1502 }
1503
1504 static const AVFilterPad inputs[] = {
1505     {
1506         .name         = "default",
1507         .type         = AVMEDIA_TYPE_AUDIO,
1508         .config_props = config_input,
1509     },
1510     { NULL }
1511 };
1512
1513 static const AVFilterPad outputs[] = {
1514     {
1515         .name          = "default",
1516         .type          = AVMEDIA_TYPE_AUDIO,
1517     },
1518     { NULL }
1519 };
1520
1521 #define OFFSET(x) offsetof(AudioRNNContext, x)
1522 #define AF AV_OPT_FLAG_AUDIO_PARAM|AV_OPT_FLAG_FILTERING_PARAM
1523
1524 static const AVOption arnndn_options[] = {
1525     { "model", "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1526     { "m",     "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1527     { NULL }
1528 };
1529
1530 AVFILTER_DEFINE_CLASS(arnndn);
1531
1532 AVFilter ff_af_arnndn = {
1533     .name          = "arnndn",
1534     .description   = NULL_IF_CONFIG_SMALL("Reduce noise from speech using Recurrent Neural Networks."),
1535     .query_formats = query_formats,
1536     .priv_size     = sizeof(AudioRNNContext),
1537     .priv_class    = &arnndn_class,
1538     .activate      = activate,
1539     .init          = init,
1540     .uninit        = uninit,
1541     .inputs        = inputs,
1542     .outputs       = outputs,
1543     .flags         = AVFILTER_FLAG_SLICE_THREADS,
1544 };