]> git.sesse.net Git - ffmpeg/blob - libavfilter/af_arnndn.c
avfilter/af_arnndn: use RNN_COPY macro to copy
[ffmpeg] / libavfilter / af_arnndn.c
1 /*
2  * Copyright (c) 2018 Gregor Richards
3  * Copyright (c) 2017 Mozilla
4  * Copyright (c) 2005-2009 Xiph.Org Foundation
5  * Copyright (c) 2007-2008 CSIRO
6  * Copyright (c) 2008-2011 Octasic Inc.
7  * Copyright (c) Jean-Marc Valin
8  * Copyright (c) 2019 Paul B Mahol
9  *
10  * Redistribution and use in source and binary forms, with or without
11  * modification, are permitted provided that the following conditions
12  * are met:
13  *
14  * - Redistributions of source code must retain the above copyright
15  *   notice, this list of conditions and the following disclaimer.
16  *
17  * - Redistributions in binary form must reproduce the above copyright
18  *   notice, this list of conditions and the following disclaimer in the
19  *   documentation and/or other materials provided with the distribution.
20  *
21  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22  * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24  * A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25  * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26  * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27  * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29  * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30  * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32  */
33
34 #include <float.h>
35
36 #include "libavutil/avassert.h"
37 #include "libavutil/avstring.h"
38 #include "libavutil/float_dsp.h"
39 #include "libavutil/opt.h"
40 #include "libavutil/tx.h"
41 #include "avfilter.h"
42 #include "audio.h"
43 #include "filters.h"
44 #include "formats.h"
45
46 #define FRAME_SIZE_SHIFT 2
47 #define FRAME_SIZE (120<<FRAME_SIZE_SHIFT)
48 #define WINDOW_SIZE (2*FRAME_SIZE)
49 #define FREQ_SIZE (FRAME_SIZE + 1)
50
51 #define PITCH_MIN_PERIOD 60
52 #define PITCH_MAX_PERIOD 768
53 #define PITCH_FRAME_SIZE 960
54 #define PITCH_BUF_SIZE (PITCH_MAX_PERIOD+PITCH_FRAME_SIZE)
55
56 #define SQUARE(x) ((x)*(x))
57
58 #define NB_BANDS 22
59
60 #define CEPS_MEM 8
61 #define NB_DELTA_CEPS 6
62
63 #define NB_FEATURES (NB_BANDS+3*NB_DELTA_CEPS+2)
64
65 #define WEIGHTS_SCALE (1.f/256)
66
67 #define MAX_NEURONS 128
68
69 #define ACTIVATION_TANH    0
70 #define ACTIVATION_SIGMOID 1
71 #define ACTIVATION_RELU    2
72
73 #define Q15ONE 1.0f
74
75 typedef struct DenseLayer {
76     const float *bias;
77     const float *input_weights;
78     int nb_inputs;
79     int nb_neurons;
80     int activation;
81 } DenseLayer;
82
83 typedef struct GRULayer {
84     const float *bias;
85     const float *input_weights;
86     const float *recurrent_weights;
87     int nb_inputs;
88     int nb_neurons;
89     int activation;
90 } GRULayer;
91
92 typedef struct RNNModel {
93     int input_dense_size;
94     const DenseLayer *input_dense;
95
96     int vad_gru_size;
97     const GRULayer *vad_gru;
98
99     int noise_gru_size;
100     const GRULayer *noise_gru;
101
102     int denoise_gru_size;
103     const GRULayer *denoise_gru;
104
105     int denoise_output_size;
106     const DenseLayer *denoise_output;
107
108     int vad_output_size;
109     const DenseLayer *vad_output;
110 } RNNModel;
111
112 typedef struct RNNState {
113     float *vad_gru_state;
114     float *noise_gru_state;
115     float *denoise_gru_state;
116     RNNModel *model;
117 } RNNState;
118
119 typedef struct DenoiseState {
120     float analysis_mem[FRAME_SIZE];
121     float cepstral_mem[CEPS_MEM][NB_BANDS];
122     int memid;
123     DECLARE_ALIGNED(32, float, synthesis_mem)[FRAME_SIZE];
124     float pitch_buf[PITCH_BUF_SIZE];
125     float pitch_enh_buf[PITCH_BUF_SIZE];
126     float last_gain;
127     int last_period;
128     float mem_hp_x[2];
129     float lastg[NB_BANDS];
130     RNNState rnn;
131     AVTXContext *tx, *txi;
132     av_tx_fn tx_fn, txi_fn;
133 } DenoiseState;
134
135 typedef struct AudioRNNContext {
136     const AVClass *class;
137
138     char *model_name;
139
140     int channels;
141     DenoiseState *st;
142
143     DECLARE_ALIGNED(32, float, window)[WINDOW_SIZE];
144     float dct_table[NB_BANDS*NB_BANDS];
145
146     RNNModel *model;
147
148     AVFloatDSPContext *fdsp;
149 } AudioRNNContext;
150
151 #define F_ACTIVATION_TANH       0
152 #define F_ACTIVATION_SIGMOID    1
153 #define F_ACTIVATION_RELU       2
154
155 static void rnnoise_model_free(RNNModel *model)
156 {
157 #define FREE_MAYBE(ptr) do { if (ptr) free(ptr); } while (0)
158 #define FREE_DENSE(name) do { \
159     if (model->name) { \
160         av_free((void *) model->name->input_weights); \
161         av_free((void *) model->name->bias); \
162         av_free((void *) model->name); \
163     } \
164     } while (0)
165 #define FREE_GRU(name) do { \
166     if (model->name) { \
167         av_free((void *) model->name->input_weights); \
168         av_free((void *) model->name->recurrent_weights); \
169         av_free((void *) model->name->bias); \
170         av_free((void *) model->name); \
171     } \
172     } while (0)
173
174     if (!model)
175         return;
176     FREE_DENSE(input_dense);
177     FREE_GRU(vad_gru);
178     FREE_GRU(noise_gru);
179     FREE_GRU(denoise_gru);
180     FREE_DENSE(denoise_output);
181     FREE_DENSE(vad_output);
182     av_free(model);
183 }
184
185 static RNNModel *rnnoise_model_from_file(FILE *f)
186 {
187     RNNModel *ret;
188     DenseLayer *input_dense;
189     GRULayer *vad_gru;
190     GRULayer *noise_gru;
191     GRULayer *denoise_gru;
192     DenseLayer *denoise_output;
193     DenseLayer *vad_output;
194     int in;
195
196     if (fscanf(f, "rnnoise-nu model file version %d\n", &in) != 1 || in != 1)
197         return NULL;
198
199     ret = av_calloc(1, sizeof(RNNModel));
200     if (!ret)
201         return NULL;
202
203 #define ALLOC_LAYER(type, name) \
204     name = av_calloc(1, sizeof(type)); \
205     if (!name) { \
206         rnnoise_model_free(ret); \
207         return NULL; \
208     } \
209     ret->name = name
210
211     ALLOC_LAYER(DenseLayer, input_dense);
212     ALLOC_LAYER(GRULayer, vad_gru);
213     ALLOC_LAYER(GRULayer, noise_gru);
214     ALLOC_LAYER(GRULayer, denoise_gru);
215     ALLOC_LAYER(DenseLayer, denoise_output);
216     ALLOC_LAYER(DenseLayer, vad_output);
217
218 #define INPUT_VAL(name) do { \
219     if (fscanf(f, "%d", &in) != 1 || in < 0 || in > 128) { \
220         rnnoise_model_free(ret); \
221         return NULL; \
222     } \
223     name = in; \
224     } while (0)
225
226 #define INPUT_ACTIVATION(name) do { \
227     int activation; \
228     INPUT_VAL(activation); \
229     switch (activation) { \
230     case F_ACTIVATION_SIGMOID: \
231         name = ACTIVATION_SIGMOID; \
232         break; \
233     case F_ACTIVATION_RELU: \
234         name = ACTIVATION_RELU; \
235         break; \
236     default: \
237         name = ACTIVATION_TANH; \
238     } \
239     } while (0)
240
241 #define INPUT_ARRAY(name, len) do { \
242     float *values = av_calloc((len), sizeof(float)); \
243     if (!values) { \
244         rnnoise_model_free(ret); \
245         return NULL; \
246     } \
247     name = values; \
248     for (int i = 0; i < (len); i++) { \
249         if (fscanf(f, "%d", &in) != 1) { \
250             rnnoise_model_free(ret); \
251             return NULL; \
252         } \
253         values[i] = in; \
254     } \
255     } while (0)
256
257 #define INPUT_ARRAY3(name, len0, len1, len2) do { \
258     float *values = av_calloc(FFALIGN((len0), 4) * FFALIGN((len1), 4) * (len2), sizeof(float)); \
259     if (!values) { \
260         rnnoise_model_free(ret); \
261         return NULL; \
262     } \
263     name = values; \
264     for (int k = 0; k < (len0); k++) { \
265         for (int i = 0; i < (len2); i++) { \
266             for (int j = 0; j < (len1); j++) { \
267                 if (fscanf(f, "%d", &in) != 1) { \
268                     rnnoise_model_free(ret); \
269                     return NULL; \
270                 } \
271                 values[j * (len2) * FFALIGN((len0), 4) + i * FFALIGN((len0), 4) + k] = in; \
272             } \
273         } \
274     } \
275     } while (0)
276
277 #define INPUT_DENSE(name) do { \
278     INPUT_VAL(name->nb_inputs); \
279     INPUT_VAL(name->nb_neurons); \
280     ret->name ## _size = name->nb_neurons; \
281     INPUT_ACTIVATION(name->activation); \
282     INPUT_ARRAY(name->input_weights, name->nb_inputs * name->nb_neurons); \
283     INPUT_ARRAY(name->bias, name->nb_neurons); \
284     } while (0)
285
286 #define INPUT_GRU(name) do { \
287     INPUT_VAL(name->nb_inputs); \
288     INPUT_VAL(name->nb_neurons); \
289     ret->name ## _size = name->nb_neurons; \
290     INPUT_ACTIVATION(name->activation); \
291     INPUT_ARRAY3(name->input_weights, name->nb_inputs, name->nb_neurons, 3); \
292     INPUT_ARRAY3(name->recurrent_weights, name->nb_neurons, name->nb_neurons, 3); \
293     INPUT_ARRAY(name->bias, name->nb_neurons * 3); \
294     } while (0)
295
296     INPUT_DENSE(input_dense);
297     INPUT_GRU(vad_gru);
298     INPUT_GRU(noise_gru);
299     INPUT_GRU(denoise_gru);
300     INPUT_DENSE(denoise_output);
301     INPUT_DENSE(vad_output);
302
303     if (vad_output->nb_neurons != 1) {
304         rnnoise_model_free(ret);
305         return NULL;
306     }
307
308     return ret;
309 }
310
311 static int query_formats(AVFilterContext *ctx)
312 {
313     AVFilterFormats *formats = NULL;
314     AVFilterChannelLayouts *layouts = NULL;
315     static const enum AVSampleFormat sample_fmts[] = {
316         AV_SAMPLE_FMT_FLTP,
317         AV_SAMPLE_FMT_NONE
318     };
319     int ret, sample_rates[] = { 48000, -1 };
320
321     formats = ff_make_format_list(sample_fmts);
322     if (!formats)
323         return AVERROR(ENOMEM);
324     ret = ff_set_common_formats(ctx, formats);
325     if (ret < 0)
326         return ret;
327
328     layouts = ff_all_channel_counts();
329     if (!layouts)
330         return AVERROR(ENOMEM);
331
332     ret = ff_set_common_channel_layouts(ctx, layouts);
333     if (ret < 0)
334         return ret;
335
336     formats = ff_make_format_list(sample_rates);
337     if (!formats)
338         return AVERROR(ENOMEM);
339     return ff_set_common_samplerates(ctx, formats);
340 }
341
342 static int config_input(AVFilterLink *inlink)
343 {
344     AVFilterContext *ctx = inlink->dst;
345     AudioRNNContext *s = ctx->priv;
346     int ret;
347
348     s->channels = inlink->channels;
349
350     s->st = av_calloc(s->channels, sizeof(DenoiseState));
351     if (!s->st)
352         return AVERROR(ENOMEM);
353
354     for (int i = 0; i < s->channels; i++) {
355         DenoiseState *st = &s->st[i];
356
357         st->rnn.model = s->model;
358         st->rnn.vad_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->vad_gru_size, 16));
359         st->rnn.noise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->noise_gru_size, 16));
360         st->rnn.denoise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->denoise_gru_size, 16));
361         if (!st->rnn.vad_gru_state ||
362             !st->rnn.noise_gru_state ||
363             !st->rnn.denoise_gru_state)
364             return AVERROR(ENOMEM);
365
366         ret = av_tx_init(&st->tx, &st->tx_fn, AV_TX_FLOAT_FFT, 0, WINDOW_SIZE, NULL, 0);
367         if (ret < 0)
368             return ret;
369
370         ret = av_tx_init(&st->txi, &st->txi_fn, AV_TX_FLOAT_FFT, 1, WINDOW_SIZE, NULL, 0);
371         if (ret < 0)
372             return ret;
373     }
374
375     return 0;
376 }
377
378 static void biquad(float *y, float mem[2], const float *x,
379                    const float *b, const float *a, int N)
380 {
381     for (int i = 0; i < N; i++) {
382         float xi, yi;
383
384         xi = x[i];
385         yi = x[i] + mem[0];
386         mem[0] = mem[1] + (b[0]*xi - a[0]*yi);
387         mem[1] = (b[1]*xi - a[1]*yi);
388         y[i] = yi;
389     }
390 }
391
392 #define RNN_MOVE(dst, src, n) (memmove((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
393 #define RNN_CLEAR(dst, n) (memset((dst), 0, (n)*sizeof(*(dst))))
394 #define RNN_COPY(dst, src, n) (memcpy((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
395
396 static void forward_transform(DenoiseState *st, AVComplexFloat *out, const float *in)
397 {
398     AVComplexFloat x[WINDOW_SIZE];
399     AVComplexFloat y[WINDOW_SIZE];
400
401     for (int i = 0; i < WINDOW_SIZE; i++) {
402         x[i].re = in[i];
403         x[i].im = 0;
404     }
405
406     st->tx_fn(st->tx, y, x, sizeof(float));
407
408     RNN_COPY(out, y, FREQ_SIZE);
409 }
410
411 static void inverse_transform(DenoiseState *st, float *out, const AVComplexFloat *in)
412 {
413     AVComplexFloat x[WINDOW_SIZE];
414     AVComplexFloat y[WINDOW_SIZE];
415
416     RNN_COPY(x, in, FREQ_SIZE);
417
418     for (int i = FREQ_SIZE; i < WINDOW_SIZE; i++) {
419         x[i].re =  x[WINDOW_SIZE - i].re;
420         x[i].im = -x[WINDOW_SIZE - i].im;
421     }
422
423     st->txi_fn(st->txi, y, x, sizeof(float));
424
425     for (int i = 0; i < WINDOW_SIZE; i++)
426         out[i] = y[i].re / WINDOW_SIZE;
427 }
428
429 static const uint8_t eband5ms[] = {
430 /*0  200 400 600 800  1k 1.2 1.4 1.6  2k 2.4 2.8 3.2  4k 4.8 5.6 6.8  8k 9.6 12k 15.6 20k*/
431   0,  1,  2,  3,  4,   5, 6,  7,  8,  10, 12, 14, 16, 20, 24, 28, 34, 40, 48, 60, 78, 100
432 };
433
434 static void compute_band_energy(float *bandE, const AVComplexFloat *X)
435 {
436     float sum[NB_BANDS] = {0};
437
438     for (int i = 0; i < NB_BANDS - 1; i++) {
439         int band_size;
440
441         band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
442         for (int j = 0; j < band_size; j++) {
443             float tmp, frac = (float)j / band_size;
444
445             tmp         = SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].re);
446             tmp        += SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].im);
447             sum[i]     += (1.f - frac) * tmp;
448             sum[i + 1] +=        frac  * tmp;
449         }
450     }
451
452     sum[0] *= 2;
453     sum[NB_BANDS - 1] *= 2;
454
455     for (int i = 0; i < NB_BANDS; i++)
456         bandE[i] = sum[i];
457 }
458
459 static void compute_band_corr(float *bandE, const AVComplexFloat *X, const AVComplexFloat *P)
460 {
461     float sum[NB_BANDS] = { 0 };
462
463     for (int i = 0; i < NB_BANDS - 1; i++) {
464         int band_size;
465
466         band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
467         for (int j = 0; j < band_size; j++) {
468             float tmp, frac = (float)j / band_size;
469
470             tmp  = X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re;
471             tmp += X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im;
472             sum[i]     += (1 - frac) * tmp;
473             sum[i + 1] +=      frac  * tmp;
474         }
475     }
476
477     sum[0] *= 2;
478     sum[NB_BANDS-1] *= 2;
479
480     for (int i = 0; i < NB_BANDS; i++)
481         bandE[i] = sum[i];
482 }
483
484 static void frame_analysis(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, float *Ex, const float *in)
485 {
486     LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
487
488     RNN_COPY(x, st->analysis_mem, FRAME_SIZE);
489     RNN_COPY(x + FRAME_SIZE, in, FRAME_SIZE);
490     RNN_COPY(st->analysis_mem, in, FRAME_SIZE);
491     s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
492     forward_transform(st, X, x);
493     compute_band_energy(Ex, X);
494 }
495
496 static void frame_synthesis(AudioRNNContext *s, DenoiseState *st, float *out, const AVComplexFloat *y)
497 {
498     LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
499
500     inverse_transform(st, x, y);
501     s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
502     s->fdsp->vector_fmac_scalar(x, st->synthesis_mem, 1.f, FRAME_SIZE);
503     RNN_COPY(out, x, FRAME_SIZE);
504     RNN_COPY(st->synthesis_mem, &x[FRAME_SIZE], FRAME_SIZE);
505 }
506
507 static inline void xcorr_kernel(const float *x, const float *y, float sum[4], int len)
508 {
509     float y_0, y_1, y_2, y_3 = 0;
510     int j;
511
512     y_0 = *y++;
513     y_1 = *y++;
514     y_2 = *y++;
515
516     for (j = 0; j < len - 3; j += 4) {
517         float tmp;
518
519         tmp = *x++;
520         y_3 = *y++;
521         sum[0] += tmp * y_0;
522         sum[1] += tmp * y_1;
523         sum[2] += tmp * y_2;
524         sum[3] += tmp * y_3;
525         tmp = *x++;
526         y_0 = *y++;
527         sum[0] += tmp * y_1;
528         sum[1] += tmp * y_2;
529         sum[2] += tmp * y_3;
530         sum[3] += tmp * y_0;
531         tmp = *x++;
532         y_1 = *y++;
533         sum[0] += tmp * y_2;
534         sum[1] += tmp * y_3;
535         sum[2] += tmp * y_0;
536         sum[3] += tmp * y_1;
537         tmp = *x++;
538         y_2 = *y++;
539         sum[0] += tmp * y_3;
540         sum[1] += tmp * y_0;
541         sum[2] += tmp * y_1;
542         sum[3] += tmp * y_2;
543     }
544
545     if (j++ < len) {
546         float tmp = *x++;
547
548         y_3 = *y++;
549         sum[0] += tmp * y_0;
550         sum[1] += tmp * y_1;
551         sum[2] += tmp * y_2;
552         sum[3] += tmp * y_3;
553     }
554
555     if (j++ < len) {
556         float tmp=*x++;
557
558         y_0 = *y++;
559         sum[0] += tmp * y_1;
560         sum[1] += tmp * y_2;
561         sum[2] += tmp * y_3;
562         sum[3] += tmp * y_0;
563     }
564
565     if (j < len) {
566         float tmp=*x++;
567
568         y_1 = *y++;
569         sum[0] += tmp * y_2;
570         sum[1] += tmp * y_3;
571         sum[2] += tmp * y_0;
572         sum[3] += tmp * y_1;
573     }
574 }
575
576 static inline float celt_inner_prod(const float *x,
577                                     const float *y, int N)
578 {
579     float xy = 0.f;
580
581     for (int i = 0; i < N; i++)
582         xy += x[i] * y[i];
583
584     return xy;
585 }
586
587 static void celt_pitch_xcorr(const float *x, const float *y,
588                              float *xcorr, int len, int max_pitch)
589 {
590     int i;
591
592     for (i = 0; i < max_pitch - 3; i += 4) {
593         float sum[4] = { 0, 0, 0, 0};
594
595         xcorr_kernel(x, y + i, sum, len);
596
597         xcorr[i]     = sum[0];
598         xcorr[i + 1] = sum[1];
599         xcorr[i + 2] = sum[2];
600         xcorr[i + 3] = sum[3];
601     }
602     /* In case max_pitch isn't a multiple of 4, do non-unrolled version. */
603     for (; i < max_pitch; i++) {
604         xcorr[i] = celt_inner_prod(x, y + i, len);
605     }
606 }
607
608 static int celt_autocorr(const float *x,   /*  in: [0...n-1] samples x   */
609                          float       *ac,  /* out: [0...lag-1] ac values */
610                          const float *window,
611                          int          overlap,
612                          int          lag,
613                          int          n)
614 {
615     int fastN = n - lag;
616     int shift;
617     const float *xptr;
618     float xx[PITCH_BUF_SIZE>>1];
619
620     if (overlap == 0) {
621         xptr = x;
622     } else {
623         for (int i = 0; i < n; i++)
624             xx[i] = x[i];
625         for (int i = 0; i < overlap; i++) {
626             xx[i] = x[i] * window[i];
627             xx[n-i-1] = x[n-i-1] * window[i];
628         }
629         xptr = xx;
630     }
631
632     shift = 0;
633     celt_pitch_xcorr(xptr, xptr, ac, fastN, lag+1);
634
635     for (int k = 0; k <= lag; k++) {
636         float d = 0.f;
637
638         for (int i = k + fastN; i < n; i++)
639             d += xptr[i] * xptr[i-k];
640         ac[k] += d;
641     }
642
643     return shift;
644 }
645
646 static void celt_lpc(float *lpc, /* out: [0...p-1] LPC coefficients      */
647                 const float *ac,   /* in:  [0...p] autocorrelation values  */
648                           int p)
649 {
650     float r, error = ac[0];
651
652     RNN_CLEAR(lpc, p);
653     if (ac[0] != 0) {
654         for (int i = 0; i < p; i++) {
655             /* Sum up this iteration's reflection coefficient */
656             float rr = 0;
657             for (int j = 0; j < i; j++)
658                 rr += (lpc[j] * ac[i - j]);
659             rr += ac[i + 1];
660             r = -rr/error;
661             /*  Update LPC coefficients and total error */
662             lpc[i] = r;
663             for (int j = 0; j < (i + 1) >> 1; j++) {
664                 float tmp1, tmp2;
665                 tmp1 = lpc[j];
666                 tmp2 = lpc[i-1-j];
667                 lpc[j]     = tmp1 + (r*tmp2);
668                 lpc[i-1-j] = tmp2 + (r*tmp1);
669             }
670
671             error = error - (r * r *error);
672             /* Bail out once we get 30 dB gain */
673             if (error < .001f * ac[0])
674                 break;
675         }
676     }
677 }
678
679 static void celt_fir5(const float *x,
680                       const float *num,
681                       float *y,
682                       int N,
683                       float *mem)
684 {
685     float num0, num1, num2, num3, num4;
686     float mem0, mem1, mem2, mem3, mem4;
687
688     num0 = num[0];
689     num1 = num[1];
690     num2 = num[2];
691     num3 = num[3];
692     num4 = num[4];
693     mem0 = mem[0];
694     mem1 = mem[1];
695     mem2 = mem[2];
696     mem3 = mem[3];
697     mem4 = mem[4];
698
699     for (int i = 0; i < N; i++) {
700         float sum = x[i];
701
702         sum += (num0*mem0);
703         sum += (num1*mem1);
704         sum += (num2*mem2);
705         sum += (num3*mem3);
706         sum += (num4*mem4);
707         mem4 = mem3;
708         mem3 = mem2;
709         mem2 = mem1;
710         mem1 = mem0;
711         mem0 = x[i];
712         y[i] = sum;
713     }
714
715     mem[0] = mem0;
716     mem[1] = mem1;
717     mem[2] = mem2;
718     mem[3] = mem3;
719     mem[4] = mem4;
720 }
721
722 static void pitch_downsample(float *x[], float *x_lp,
723                              int len, int C)
724 {
725     float ac[5];
726     float tmp=Q15ONE;
727     float lpc[4], mem[5]={0,0,0,0,0};
728     float lpc2[5];
729     float c1 = .8f;
730
731     for (int i = 1; i < len >> 1; i++)
732         x_lp[i] = .5f * (.5f * (x[0][(2*i-1)]+x[0][(2*i+1)])+x[0][2*i]);
733     x_lp[0] = .5f * (.5f * (x[0][1])+x[0][0]);
734     if (C==2) {
735         for (int i = 1; i < len >> 1; i++)
736             x_lp[i] += (.5f * (.5f * (x[1][(2*i-1)]+x[1][(2*i+1)])+x[1][2*i]));
737         x_lp[0] += .5f * (.5f * (x[1][1])+x[1][0]);
738     }
739
740     celt_autocorr(x_lp, ac, NULL, 0, 4, len>>1);
741
742     /* Noise floor -40 dB */
743     ac[0] *= 1.0001f;
744     /* Lag windowing */
745     for (int i = 1; i <= 4; i++) {
746         /*ac[i] *= exp(-.5*(2*M_PI*.002*i)*(2*M_PI*.002*i));*/
747         ac[i] -= ac[i]*(.008f*i)*(.008f*i);
748     }
749
750     celt_lpc(lpc, ac, 4);
751     for (int i = 0; i < 4; i++) {
752         tmp = .9f * tmp;
753         lpc[i] = (lpc[i] * tmp);
754     }
755     /* Add a zero */
756     lpc2[0] = lpc[0] + .8f;
757     lpc2[1] = lpc[1] + (c1 * lpc[0]);
758     lpc2[2] = lpc[2] + (c1 * lpc[1]);
759     lpc2[3] = lpc[3] + (c1 * lpc[2]);
760     lpc2[4] = (c1 * lpc[3]);
761     celt_fir5(x_lp, lpc2, x_lp, len>>1, mem);
762 }
763
764 static inline void dual_inner_prod(const float *x, const float *y01, const float *y02,
765                                    int N, float *xy1, float *xy2)
766 {
767     float xy01 = 0, xy02 = 0;
768
769     for (int i = 0; i < N; i++) {
770         xy01 += (x[i] * y01[i]);
771         xy02 += (x[i] * y02[i]);
772     }
773
774     *xy1 = xy01;
775     *xy2 = xy02;
776 }
777
778 static float compute_pitch_gain(float xy, float xx, float yy)
779 {
780     return xy / sqrtf(1.f + xx * yy);
781 }
782
783 static const int second_check[16] = {0, 0, 3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2};
784 static float remove_doubling(float *x, int maxperiod, int minperiod, int N,
785                              int *T0_, int prev_period, float prev_gain)
786 {
787     int k, i, T, T0;
788     float g, g0;
789     float pg;
790     float xy,xx,yy,xy2;
791     float xcorr[3];
792     float best_xy, best_yy;
793     int offset;
794     int minperiod0;
795     float yy_lookup[PITCH_MAX_PERIOD+1];
796
797     minperiod0 = minperiod;
798     maxperiod /= 2;
799     minperiod /= 2;
800     *T0_ /= 2;
801     prev_period /= 2;
802     N /= 2;
803     x += maxperiod;
804     if (*T0_>=maxperiod)
805         *T0_=maxperiod-1;
806
807     T = T0 = *T0_;
808     dual_inner_prod(x, x, x-T0, N, &xx, &xy);
809     yy_lookup[0] = xx;
810     yy=xx;
811     for (i = 1; i <= maxperiod; i++) {
812         yy = yy+(x[-i] * x[-i])-(x[N-i] * x[N-i]);
813         yy_lookup[i] = FFMAX(0, yy);
814     }
815     yy = yy_lookup[T0];
816     best_xy = xy;
817     best_yy = yy;
818     g = g0 = compute_pitch_gain(xy, xx, yy);
819     /* Look for any pitch at T/k */
820     for (k = 2; k <= 15; k++) {
821         int T1, T1b;
822         float g1;
823         float cont=0;
824         float thresh;
825         T1 = (2*T0+k)/(2*k);
826         if (T1 < minperiod)
827             break;
828         /* Look for another strong correlation at T1b */
829         if (k==2)
830         {
831             if (T1+T0>maxperiod)
832                 T1b = T0;
833             else
834                 T1b = T0+T1;
835         } else
836         {
837             T1b = (2*second_check[k]*T0+k)/(2*k);
838         }
839         dual_inner_prod(x, &x[-T1], &x[-T1b], N, &xy, &xy2);
840         xy = .5f * (xy + xy2);
841         yy = .5f * (yy_lookup[T1] + yy_lookup[T1b]);
842         g1 = compute_pitch_gain(xy, xx, yy);
843         if (FFABS(T1-prev_period)<=1)
844             cont = prev_gain;
845         else if (FFABS(T1-prev_period)<=2 && 5 * k * k < T0)
846             cont = prev_gain * .5f;
847         else
848             cont = 0;
849         thresh = FFMAX(.3f, (.7f * g0) - cont);
850         /* Bias against very high pitch (very short period) to avoid false-positives
851            due to short-term correlation */
852         if (T1<3*minperiod)
853             thresh = FFMAX(.4f, (.85f * g0) - cont);
854         else if (T1<2*minperiod)
855             thresh = FFMAX(.5f, (.9f * g0) - cont);
856         if (g1 > thresh)
857         {
858             best_xy = xy;
859             best_yy = yy;
860             T = T1;
861             g = g1;
862         }
863     }
864     best_xy = FFMAX(0, best_xy);
865     if (best_yy <= best_xy)
866         pg = Q15ONE;
867     else
868         pg = best_xy/(best_yy + 1);
869
870     for (k = 0; k < 3; k++)
871         xcorr[k] = celt_inner_prod(x, x-(T+k-1), N);
872     if ((xcorr[2]-xcorr[0]) > .7f * (xcorr[1]-xcorr[0]))
873         offset = 1;
874     else if ((xcorr[0]-xcorr[2]) > (.7f * (xcorr[1] - xcorr[2])))
875         offset = -1;
876     else
877         offset = 0;
878     if (pg > g)
879         pg = g;
880     *T0_ = 2*T+offset;
881
882     if (*T0_<minperiod0)
883         *T0_=minperiod0;
884     return pg;
885 }
886
887 static void find_best_pitch(float *xcorr, float *y, int len,
888                             int max_pitch, int *best_pitch)
889 {
890     float best_num[2];
891     float best_den[2];
892     float Syy = 1.f;
893
894     best_num[0] = -1;
895     best_num[1] = -1;
896     best_den[0] = 0;
897     best_den[1] = 0;
898     best_pitch[0] = 0;
899     best_pitch[1] = 1;
900
901     for (int j = 0; j < len; j++)
902         Syy += y[j] * y[j];
903
904     for (int i = 0; i < max_pitch; i++) {
905         if (xcorr[i]>0) {
906             float num;
907             float xcorr16;
908
909             xcorr16 = xcorr[i];
910             /* Considering the range of xcorr16, this should avoid both underflows
911                and overflows (inf) when squaring xcorr16 */
912             xcorr16 *= 1e-12f;
913             num = xcorr16 * xcorr16;
914             if ((num * best_den[1]) > (best_num[1] * Syy)) {
915                 if ((num * best_den[0]) > (best_num[0] * Syy)) {
916                     best_num[1] = best_num[0];
917                     best_den[1] = best_den[0];
918                     best_pitch[1] = best_pitch[0];
919                     best_num[0] = num;
920                     best_den[0] = Syy;
921                     best_pitch[0] = i;
922                 } else {
923                     best_num[1] = num;
924                     best_den[1] = Syy;
925                     best_pitch[1] = i;
926                 }
927             }
928         }
929         Syy += y[i+len]*y[i+len] - y[i] * y[i];
930         Syy = FFMAX(1, Syy);
931     }
932 }
933
934 static void pitch_search(const float *x_lp, float *y,
935                          int len, int max_pitch, int *pitch)
936 {
937     int lag;
938     int best_pitch[2]={0,0};
939     int offset;
940
941     float x_lp4[WINDOW_SIZE];
942     float y_lp4[WINDOW_SIZE];
943     float xcorr[WINDOW_SIZE];
944
945     lag = len+max_pitch;
946
947     /* Downsample by 2 again */
948     for (int j = 0; j < len >> 2; j++)
949         x_lp4[j] = x_lp[2*j];
950     for (int j = 0; j < lag >> 2; j++)
951         y_lp4[j] = y[2*j];
952
953     /* Coarse search with 4x decimation */
954
955     celt_pitch_xcorr(x_lp4, y_lp4, xcorr, len>>2, max_pitch>>2);
956
957     find_best_pitch(xcorr, y_lp4, len>>2, max_pitch>>2, best_pitch);
958
959     /* Finer search with 2x decimation */
960     for (int i = 0; i < max_pitch >> 1; i++) {
961         float sum;
962         xcorr[i] = 0;
963         if (FFABS(i-2*best_pitch[0])>2 && FFABS(i-2*best_pitch[1])>2)
964             continue;
965         sum = celt_inner_prod(x_lp, y+i, len>>1);
966         xcorr[i] = FFMAX(-1, sum);
967     }
968
969     find_best_pitch(xcorr, y, len>>1, max_pitch>>1, best_pitch);
970
971     /* Refine by pseudo-interpolation */
972     if (best_pitch[0] > 0 && best_pitch[0] < (max_pitch >> 1) - 1) {
973         float a, b, c;
974
975         a = xcorr[best_pitch[0] - 1];
976         b = xcorr[best_pitch[0]];
977         c = xcorr[best_pitch[0] + 1];
978         if (c - a > .7f * (b - a))
979             offset = 1;
980         else if (a - c > .7f * (b-c))
981             offset = -1;
982         else
983             offset = 0;
984     } else {
985         offset = 0;
986     }
987
988     *pitch = 2 * best_pitch[0] - offset;
989 }
990
991 static void dct(AudioRNNContext *s, float *out, const float *in)
992 {
993     for (int i = 0; i < NB_BANDS; i++) {
994         float sum = 0.f;
995
996         for (int j = 0; j < NB_BANDS; j++) {
997             sum += in[j] * s->dct_table[j * NB_BANDS + i];
998         }
999         out[i] = sum * sqrtf(2.f / 22);
1000     }
1001 }
1002
1003 static int compute_frame_features(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, AVComplexFloat *P,
1004                                   float *Ex, float *Ep, float *Exp, float *features, const float *in)
1005 {
1006     float E = 0;
1007     float *ceps_0, *ceps_1, *ceps_2;
1008     float spec_variability = 0;
1009     float Ly[NB_BANDS];
1010     LOCAL_ALIGNED_32(float, p, [WINDOW_SIZE]);
1011     float pitch_buf[PITCH_BUF_SIZE>>1];
1012     int pitch_index;
1013     float gain;
1014     float *(pre[1]);
1015     float tmp[NB_BANDS];
1016     float follow, logMax;
1017
1018     frame_analysis(s, st, X, Ex, in);
1019     RNN_MOVE(st->pitch_buf, &st->pitch_buf[FRAME_SIZE], PITCH_BUF_SIZE-FRAME_SIZE);
1020     RNN_COPY(&st->pitch_buf[PITCH_BUF_SIZE-FRAME_SIZE], in, FRAME_SIZE);
1021     pre[0] = &st->pitch_buf[0];
1022     pitch_downsample(pre, pitch_buf, PITCH_BUF_SIZE, 1);
1023     pitch_search(pitch_buf+(PITCH_MAX_PERIOD>>1), pitch_buf, PITCH_FRAME_SIZE,
1024             PITCH_MAX_PERIOD-3*PITCH_MIN_PERIOD, &pitch_index);
1025     pitch_index = PITCH_MAX_PERIOD-pitch_index;
1026
1027     gain = remove_doubling(pitch_buf, PITCH_MAX_PERIOD, PITCH_MIN_PERIOD,
1028             PITCH_FRAME_SIZE, &pitch_index, st->last_period, st->last_gain);
1029     st->last_period = pitch_index;
1030     st->last_gain = gain;
1031
1032     for (int i = 0; i < WINDOW_SIZE; i++)
1033         p[i] = st->pitch_buf[PITCH_BUF_SIZE-WINDOW_SIZE-pitch_index+i];
1034
1035     s->fdsp->vector_fmul(p, p, s->window, WINDOW_SIZE);
1036     forward_transform(st, P, p);
1037     compute_band_energy(Ep, P);
1038     compute_band_corr(Exp, X, P);
1039
1040     for (int i = 0; i < NB_BANDS; i++)
1041         Exp[i] = Exp[i] / sqrtf(.001f+Ex[i]*Ep[i]);
1042
1043     dct(s, tmp, Exp);
1044
1045     for (int i = 0; i < NB_DELTA_CEPS; i++)
1046         features[NB_BANDS+2*NB_DELTA_CEPS+i] = tmp[i];
1047
1048     features[NB_BANDS+2*NB_DELTA_CEPS] -= 1.3;
1049     features[NB_BANDS+2*NB_DELTA_CEPS+1] -= 0.9;
1050     features[NB_BANDS+3*NB_DELTA_CEPS] = .01*(pitch_index-300);
1051     logMax = -2;
1052     follow = -2;
1053
1054     for (int i = 0; i < NB_BANDS; i++) {
1055         Ly[i] = log10f(1e-2f + Ex[i]);
1056         Ly[i] = FFMAX(logMax-7, FFMAX(follow-1.5, Ly[i]));
1057         logMax = FFMAX(logMax, Ly[i]);
1058         follow = FFMAX(follow-1.5, Ly[i]);
1059         E += Ex[i];
1060     }
1061
1062     if (E < 0.04f) {
1063         /* If there's no audio, avoid messing up the state. */
1064         RNN_CLEAR(features, NB_FEATURES);
1065         return 1;
1066     }
1067
1068     dct(s, features, Ly);
1069     features[0] -= 12;
1070     features[1] -= 4;
1071     ceps_0 = st->cepstral_mem[st->memid];
1072     ceps_1 = (st->memid < 1) ? st->cepstral_mem[CEPS_MEM+st->memid-1] : st->cepstral_mem[st->memid-1];
1073     ceps_2 = (st->memid < 2) ? st->cepstral_mem[CEPS_MEM+st->memid-2] : st->cepstral_mem[st->memid-2];
1074
1075     for (int i = 0; i < NB_BANDS; i++)
1076         ceps_0[i] = features[i];
1077
1078     st->memid++;
1079     for (int i = 0; i < NB_DELTA_CEPS; i++) {
1080         features[i] = ceps_0[i] + ceps_1[i] + ceps_2[i];
1081         features[NB_BANDS+i] = ceps_0[i] - ceps_2[i];
1082         features[NB_BANDS+NB_DELTA_CEPS+i] =  ceps_0[i] - 2*ceps_1[i] + ceps_2[i];
1083     }
1084     /* Spectral variability features. */
1085     if (st->memid == CEPS_MEM)
1086         st->memid = 0;
1087
1088     for (int i = 0; i < CEPS_MEM; i++) {
1089         float mindist = 1e15f;
1090         for (int j = 0; j < CEPS_MEM; j++) {
1091             float dist = 0.f;
1092             for (int k = 0; k < NB_BANDS; k++) {
1093                 float tmp;
1094
1095                 tmp = st->cepstral_mem[i][k] - st->cepstral_mem[j][k];
1096                 dist += tmp*tmp;
1097             }
1098
1099             if (j != i)
1100                 mindist = FFMIN(mindist, dist);
1101         }
1102
1103         spec_variability += mindist;
1104     }
1105
1106     features[NB_BANDS+3*NB_DELTA_CEPS+1] = spec_variability/CEPS_MEM-2.1;
1107
1108     return 0;
1109 }
1110
1111 static void interp_band_gain(float *g, const float *bandE)
1112 {
1113     memset(g, 0, sizeof(*g) * FREQ_SIZE);
1114
1115     for (int i = 0; i < NB_BANDS - 1; i++) {
1116         const int band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
1117
1118         for (int j = 0; j < band_size; j++) {
1119             float frac = (float)j / band_size;
1120
1121             g[(eband5ms[i] << FRAME_SIZE_SHIFT) + j] = (1.f - frac) * bandE[i] + frac * bandE[i + 1];
1122         }
1123     }
1124 }
1125
1126 static void pitch_filter(AVComplexFloat *X, const AVComplexFloat *P, const float *Ex, const float *Ep,
1127                          const float *Exp, const float *g)
1128 {
1129     float newE[NB_BANDS];
1130     float r[NB_BANDS];
1131     float norm[NB_BANDS];
1132     float rf[FREQ_SIZE] = {0};
1133     float normf[FREQ_SIZE]={0};
1134
1135     for (int i = 0; i < NB_BANDS; i++) {
1136         if (Exp[i]>g[i]) r[i] = 1;
1137         else r[i] = SQUARE(Exp[i])*(1-SQUARE(g[i]))/(.001 + SQUARE(g[i])*(1-SQUARE(Exp[i])));
1138         r[i]  = sqrtf(av_clipf(r[i], 0, 1));
1139         r[i] *= sqrtf(Ex[i]/(1e-8+Ep[i]));
1140     }
1141     interp_band_gain(rf, r);
1142     for (int i = 0; i < FREQ_SIZE; i++) {
1143         X[i].re += rf[i]*P[i].re;
1144         X[i].im += rf[i]*P[i].im;
1145     }
1146     compute_band_energy(newE, X);
1147     for (int i = 0; i < NB_BANDS; i++) {
1148         norm[i] = sqrtf(Ex[i] / (1e-8+newE[i]));
1149     }
1150     interp_band_gain(normf, norm);
1151     for (int i = 0; i < FREQ_SIZE; i++) {
1152         X[i].re *= normf[i];
1153         X[i].im *= normf[i];
1154     }
1155 }
1156
1157 static const float tansig_table[201] = {
1158     0.000000f, 0.039979f, 0.079830f, 0.119427f, 0.158649f,
1159     0.197375f, 0.235496f, 0.272905f, 0.309507f, 0.345214f,
1160     0.379949f, 0.413644f, 0.446244f, 0.477700f, 0.507977f,
1161     0.537050f, 0.564900f, 0.591519f, 0.616909f, 0.641077f,
1162     0.664037f, 0.685809f, 0.706419f, 0.725897f, 0.744277f,
1163     0.761594f, 0.777888f, 0.793199f, 0.807569f, 0.821040f,
1164     0.833655f, 0.845456f, 0.856485f, 0.866784f, 0.876393f,
1165     0.885352f, 0.893698f, 0.901468f, 0.908698f, 0.915420f,
1166     0.921669f, 0.927473f, 0.932862f, 0.937863f, 0.942503f,
1167     0.946806f, 0.950795f, 0.954492f, 0.957917f, 0.961090f,
1168     0.964028f, 0.966747f, 0.969265f, 0.971594f, 0.973749f,
1169     0.975743f, 0.977587f, 0.979293f, 0.980869f, 0.982327f,
1170     0.983675f, 0.984921f, 0.986072f, 0.987136f, 0.988119f,
1171     0.989027f, 0.989867f, 0.990642f, 0.991359f, 0.992020f,
1172     0.992631f, 0.993196f, 0.993718f, 0.994199f, 0.994644f,
1173     0.995055f, 0.995434f, 0.995784f, 0.996108f, 0.996407f,
1174     0.996682f, 0.996937f, 0.997172f, 0.997389f, 0.997590f,
1175     0.997775f, 0.997946f, 0.998104f, 0.998249f, 0.998384f,
1176     0.998508f, 0.998623f, 0.998728f, 0.998826f, 0.998916f,
1177     0.999000f, 0.999076f, 0.999147f, 0.999213f, 0.999273f,
1178     0.999329f, 0.999381f, 0.999428f, 0.999472f, 0.999513f,
1179     0.999550f, 0.999585f, 0.999617f, 0.999646f, 0.999673f,
1180     0.999699f, 0.999722f, 0.999743f, 0.999763f, 0.999781f,
1181     0.999798f, 0.999813f, 0.999828f, 0.999841f, 0.999853f,
1182     0.999865f, 0.999875f, 0.999885f, 0.999893f, 0.999902f,
1183     0.999909f, 0.999916f, 0.999923f, 0.999929f, 0.999934f,
1184     0.999939f, 0.999944f, 0.999948f, 0.999952f, 0.999956f,
1185     0.999959f, 0.999962f, 0.999965f, 0.999968f, 0.999970f,
1186     0.999973f, 0.999975f, 0.999977f, 0.999978f, 0.999980f,
1187     0.999982f, 0.999983f, 0.999984f, 0.999986f, 0.999987f,
1188     0.999988f, 0.999989f, 0.999990f, 0.999990f, 0.999991f,
1189     0.999992f, 0.999992f, 0.999993f, 0.999994f, 0.999994f,
1190     0.999994f, 0.999995f, 0.999995f, 0.999996f, 0.999996f,
1191     0.999996f, 0.999997f, 0.999997f, 0.999997f, 0.999997f,
1192     0.999997f, 0.999998f, 0.999998f, 0.999998f, 0.999998f,
1193     0.999998f, 0.999998f, 0.999999f, 0.999999f, 0.999999f,
1194     0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
1195     0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
1196     1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1197     1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1198     1.000000f,
1199 };
1200
1201 static inline float tansig_approx(float x)
1202 {
1203     float y, dy;
1204     float sign=1;
1205     int i;
1206
1207     /* Tests are reversed to catch NaNs */
1208     if (!(x<8))
1209         return 1;
1210     if (!(x>-8))
1211         return -1;
1212     /* Another check in case of -ffast-math */
1213
1214     if (isnan(x))
1215        return 0;
1216
1217     if (x < 0) {
1218        x=-x;
1219        sign=-1;
1220     }
1221     i = (int)floor(.5f+25*x);
1222     x -= .04f*i;
1223     y = tansig_table[i];
1224     dy = 1-y*y;
1225     y = y + x*dy*(1 - y*x);
1226     return sign*y;
1227 }
1228
1229 static inline float sigmoid_approx(float x)
1230 {
1231     return .5f + .5f*tansig_approx(.5f*x);
1232 }
1233
1234 static void compute_dense(const DenseLayer *layer, float *output, const float *input)
1235 {
1236     const int N = layer->nb_neurons, M = layer->nb_inputs, stride = N;
1237
1238     for (int i = 0; i < N; i++) {
1239         /* Compute update gate. */
1240         float sum = layer->bias[i];
1241
1242         for (int j = 0; j < M; j++)
1243             sum += layer->input_weights[j * stride + i] * input[j];
1244
1245         output[i] = WEIGHTS_SCALE * sum;
1246     }
1247
1248     if (layer->activation == ACTIVATION_SIGMOID) {
1249         for (int i = 0; i < N; i++)
1250             output[i] = sigmoid_approx(output[i]);
1251     } else if (layer->activation == ACTIVATION_TANH) {
1252         for (int i = 0; i < N; i++)
1253             output[i] = tansig_approx(output[i]);
1254     } else if (layer->activation == ACTIVATION_RELU) {
1255         for (int i = 0; i < N; i++)
1256             output[i] = FFMAX(0, output[i]);
1257     } else {
1258         av_assert0(0);
1259     }
1260 }
1261
1262 static void compute_gru(AudioRNNContext *s, const GRULayer *gru, float *state, const float *input)
1263 {
1264     LOCAL_ALIGNED_32(float, z, [MAX_NEURONS]);
1265     LOCAL_ALIGNED_32(float, r, [MAX_NEURONS]);
1266     LOCAL_ALIGNED_32(float, h, [MAX_NEURONS]);
1267     const int M = gru->nb_inputs;
1268     const int N = gru->nb_neurons;
1269     const int AN = FFALIGN(N, 4);
1270     const int AM = FFALIGN(M, 4);
1271     const int stride = 3 * AN, istride = 3 * AM;
1272
1273     for (int i = 0; i < N; i++) {
1274         /* Compute update gate. */
1275         float sum = gru->bias[i];
1276
1277         sum += s->fdsp->scalarproduct_float(gru->input_weights + i * istride, input, AM);
1278         sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + i * stride, state, AN);
1279         z[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1280     }
1281
1282     for (int i = 0; i < N; i++) {
1283         /* Compute reset gate. */
1284         float sum = gru->bias[N + i];
1285
1286         sum += s->fdsp->scalarproduct_float(gru->input_weights + AM + i * istride, input, AM);
1287         sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + AN + i * stride, state, AN);
1288         r[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1289     }
1290
1291     for (int i = 0; i < N; i++) {
1292         /* Compute output. */
1293         float sum = gru->bias[2 * N + i];
1294
1295         sum += s->fdsp->scalarproduct_float(gru->input_weights + 2 * AM + i * istride, input, AM);
1296         for (int j = 0; j < N; j++)
1297             sum += gru->recurrent_weights[2 * AN + i * stride + j] * state[j] * r[j];
1298
1299         if (gru->activation == ACTIVATION_SIGMOID)
1300             sum = sigmoid_approx(WEIGHTS_SCALE * sum);
1301         else if (gru->activation == ACTIVATION_TANH)
1302             sum = tansig_approx(WEIGHTS_SCALE * sum);
1303         else if (gru->activation == ACTIVATION_RELU)
1304             sum = FFMAX(0, WEIGHTS_SCALE * sum);
1305         else
1306             av_assert0(0);
1307         h[i] = z[i] * state[i] + (1.f - z[i]) * sum;
1308     }
1309
1310     RNN_COPY(state, h, N);
1311 }
1312
1313 #define INPUT_SIZE 42
1314
1315 static void compute_rnn(AudioRNNContext *s, RNNState *rnn, float *gains, float *vad, const float *input)
1316 {
1317     LOCAL_ALIGNED_32(float, dense_out,     [MAX_NEURONS]);
1318     LOCAL_ALIGNED_32(float, noise_input,   [MAX_NEURONS * 3]);
1319     LOCAL_ALIGNED_32(float, denoise_input, [MAX_NEURONS * 3]);
1320
1321     compute_dense(rnn->model->input_dense, dense_out, input);
1322     compute_gru(s, rnn->model->vad_gru, rnn->vad_gru_state, dense_out);
1323     compute_dense(rnn->model->vad_output, vad, rnn->vad_gru_state);
1324
1325     for (int i = 0; i < rnn->model->input_dense_size; i++)
1326         noise_input[i] = dense_out[i];
1327     for (int i = 0; i < rnn->model->vad_gru_size; i++)
1328         noise_input[i + rnn->model->input_dense_size] = rnn->vad_gru_state[i];
1329     for (int i = 0; i < INPUT_SIZE; i++)
1330         noise_input[i + rnn->model->input_dense_size + rnn->model->vad_gru_size] = input[i];
1331
1332     compute_gru(s, rnn->model->noise_gru, rnn->noise_gru_state, noise_input);
1333
1334     for (int i = 0; i < rnn->model->vad_gru_size; i++)
1335         denoise_input[i] = rnn->vad_gru_state[i];
1336     for (int i = 0; i < rnn->model->noise_gru_size; i++)
1337         denoise_input[i + rnn->model->vad_gru_size] = rnn->noise_gru_state[i];
1338     for (int i = 0; i < INPUT_SIZE; i++)
1339         denoise_input[i + rnn->model->vad_gru_size + rnn->model->noise_gru_size] = input[i];
1340
1341     compute_gru(s, rnn->model->denoise_gru, rnn->denoise_gru_state, denoise_input);
1342     compute_dense(rnn->model->denoise_output, gains, rnn->denoise_gru_state);
1343 }
1344
1345 static float rnnoise_channel(AudioRNNContext *s, DenoiseState *st, float *out, const float *in)
1346 {
1347     AVComplexFloat X[FREQ_SIZE];
1348     AVComplexFloat P[WINDOW_SIZE];
1349     float x[FRAME_SIZE];
1350     float Ex[NB_BANDS], Ep[NB_BANDS];
1351     float Exp[NB_BANDS];
1352     float features[NB_FEATURES];
1353     float g[NB_BANDS];
1354     float gf[FREQ_SIZE];
1355     float vad_prob = 0;
1356     static const float a_hp[2] = {-1.99599, 0.99600};
1357     static const float b_hp[2] = {-2, 1};
1358     int silence;
1359
1360     biquad(x, st->mem_hp_x, in, b_hp, a_hp, FRAME_SIZE);
1361     silence = compute_frame_features(s, st, X, P, Ex, Ep, Exp, features, x);
1362
1363     if (!silence) {
1364         compute_rnn(s, &st->rnn, g, &vad_prob, features);
1365         pitch_filter(X, P, Ex, Ep, Exp, g);
1366         for (int i = 0; i < NB_BANDS; i++) {
1367             float alpha = .6f;
1368
1369             g[i] = FFMAX(g[i], alpha * st->lastg[i]);
1370             st->lastg[i] = g[i];
1371         }
1372
1373         interp_band_gain(gf, g);
1374
1375         for (int i = 0; i < FREQ_SIZE; i++) {
1376             X[i].re *= gf[i];
1377             X[i].im *= gf[i];
1378         }
1379     }
1380
1381     frame_synthesis(s, st, out, X);
1382
1383     return vad_prob;
1384 }
1385
1386 typedef struct ThreadData {
1387     AVFrame *in, *out;
1388 } ThreadData;
1389
1390 static int rnnoise_channels(AVFilterContext *ctx, void *arg, int jobnr, int nb_jobs)
1391 {
1392     AudioRNNContext *s = ctx->priv;
1393     ThreadData *td = arg;
1394     AVFrame *in = td->in;
1395     AVFrame *out = td->out;
1396     const int start = (out->channels * jobnr) / nb_jobs;
1397     const int end = (out->channels * (jobnr+1)) / nb_jobs;
1398
1399     for (int ch = start; ch < end; ch++) {
1400         rnnoise_channel(s, &s->st[ch],
1401                         (float *)out->extended_data[ch],
1402                         (const float *)in->extended_data[ch]);
1403     }
1404
1405     return 0;
1406 }
1407
1408 static int filter_frame(AVFilterLink *inlink, AVFrame *in)
1409 {
1410     AVFilterContext *ctx = inlink->dst;
1411     AVFilterLink *outlink = ctx->outputs[0];
1412     AVFrame *out = NULL;
1413     ThreadData td;
1414
1415     out = ff_get_audio_buffer(outlink, FRAME_SIZE);
1416     if (!out) {
1417         av_frame_free(&in);
1418         return AVERROR(ENOMEM);
1419     }
1420     out->pts = in->pts;
1421
1422     td.in = in; td.out = out;
1423     ctx->internal->execute(ctx, rnnoise_channels, &td, NULL, FFMIN(outlink->channels,
1424                                                                    ff_filter_get_nb_threads(ctx)));
1425
1426     av_frame_free(&in);
1427     return ff_filter_frame(outlink, out);
1428 }
1429
1430 static int activate(AVFilterContext *ctx)
1431 {
1432     AVFilterLink *inlink = ctx->inputs[0];
1433     AVFilterLink *outlink = ctx->outputs[0];
1434     AVFrame *in = NULL;
1435     int ret;
1436
1437     FF_FILTER_FORWARD_STATUS_BACK(outlink, inlink);
1438
1439     ret = ff_inlink_consume_samples(inlink, FRAME_SIZE, FRAME_SIZE, &in);
1440     if (ret < 0)
1441         return ret;
1442
1443     if (ret > 0)
1444         return filter_frame(inlink, in);
1445
1446     FF_FILTER_FORWARD_STATUS(inlink, outlink);
1447     FF_FILTER_FORWARD_WANTED(outlink, inlink);
1448
1449     return FFERROR_NOT_READY;
1450 }
1451
1452 static av_cold int init(AVFilterContext *ctx)
1453 {
1454     AudioRNNContext *s = ctx->priv;
1455     FILE *f;
1456
1457     s->fdsp = avpriv_float_dsp_alloc(0);
1458     if (!s->fdsp)
1459         return AVERROR(ENOMEM);
1460
1461     if (!s->model_name)
1462         return AVERROR(EINVAL);
1463     f = av_fopen_utf8(s->model_name, "r");
1464     if (!f)
1465         return AVERROR(EINVAL);
1466
1467     s->model = rnnoise_model_from_file(f);
1468     fclose(f);
1469     if (!s->model)
1470         return AVERROR(EINVAL);
1471
1472     for (int i = 0; i < FRAME_SIZE; i++) {
1473         s->window[i] = sin(.5*M_PI*sin(.5*M_PI*(i+.5)/FRAME_SIZE) * sin(.5*M_PI*(i+.5)/FRAME_SIZE));
1474         s->window[WINDOW_SIZE - 1 - i] = s->window[i];
1475     }
1476
1477     for (int i = 0; i < NB_BANDS; i++) {
1478         for (int j = 0; j < NB_BANDS; j++) {
1479             s->dct_table[i*NB_BANDS + j] = cosf((i + .5f) * j * M_PI / NB_BANDS);
1480             if (j == 0)
1481                 s->dct_table[i*NB_BANDS + j] *= sqrtf(.5);
1482         }
1483     }
1484
1485     return 0;
1486 }
1487
1488 static av_cold void uninit(AVFilterContext *ctx)
1489 {
1490     AudioRNNContext *s = ctx->priv;
1491
1492     av_freep(&s->fdsp);
1493     rnnoise_model_free(s->model);
1494     s->model = NULL;
1495
1496     if (s->st) {
1497         for (int ch = 0; ch < s->channels; ch++) {
1498             av_freep(&s->st[ch].rnn.vad_gru_state);
1499             av_freep(&s->st[ch].rnn.noise_gru_state);
1500             av_freep(&s->st[ch].rnn.denoise_gru_state);
1501             av_tx_uninit(&s->st[ch].tx);
1502             av_tx_uninit(&s->st[ch].txi);
1503         }
1504     }
1505     av_freep(&s->st);
1506 }
1507
1508 static const AVFilterPad inputs[] = {
1509     {
1510         .name         = "default",
1511         .type         = AVMEDIA_TYPE_AUDIO,
1512         .config_props = config_input,
1513     },
1514     { NULL }
1515 };
1516
1517 static const AVFilterPad outputs[] = {
1518     {
1519         .name          = "default",
1520         .type          = AVMEDIA_TYPE_AUDIO,
1521     },
1522     { NULL }
1523 };
1524
1525 #define OFFSET(x) offsetof(AudioRNNContext, x)
1526 #define AF AV_OPT_FLAG_AUDIO_PARAM|AV_OPT_FLAG_FILTERING_PARAM
1527
1528 static const AVOption arnndn_options[] = {
1529     { "model", "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1530     { "m",     "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1531     { NULL }
1532 };
1533
1534 AVFILTER_DEFINE_CLASS(arnndn);
1535
1536 AVFilter ff_af_arnndn = {
1537     .name          = "arnndn",
1538     .description   = NULL_IF_CONFIG_SMALL("Reduce noise from speech using Recurrent Neural Networks."),
1539     .query_formats = query_formats,
1540     .priv_size     = sizeof(AudioRNNContext),
1541     .priv_class    = &arnndn_class,
1542     .activate      = activate,
1543     .init          = init,
1544     .uninit        = uninit,
1545     .inputs        = inputs,
1546     .outputs       = outputs,
1547     .flags         = AVFILTER_FLAG_SLICE_THREADS,
1548 };