]> git.sesse.net Git - ffmpeg/blob - libavfilter/af_arnndn.c
avfilter/af_arnndn: use memcpy for copying in compute_rnn()
[ffmpeg] / libavfilter / af_arnndn.c
1 /*
2  * Copyright (c) 2018 Gregor Richards
3  * Copyright (c) 2017 Mozilla
4  * Copyright (c) 2005-2009 Xiph.Org Foundation
5  * Copyright (c) 2007-2008 CSIRO
6  * Copyright (c) 2008-2011 Octasic Inc.
7  * Copyright (c) Jean-Marc Valin
8  * Copyright (c) 2019 Paul B Mahol
9  *
10  * Redistribution and use in source and binary forms, with or without
11  * modification, are permitted provided that the following conditions
12  * are met:
13  *
14  * - Redistributions of source code must retain the above copyright
15  *   notice, this list of conditions and the following disclaimer.
16  *
17  * - Redistributions in binary form must reproduce the above copyright
18  *   notice, this list of conditions and the following disclaimer in the
19  *   documentation and/or other materials provided with the distribution.
20  *
21  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22  * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24  * A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25  * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26  * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27  * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29  * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30  * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32  */
33
34 #include <float.h>
35
36 #include "libavutil/avassert.h"
37 #include "libavutil/avstring.h"
38 #include "libavutil/float_dsp.h"
39 #include "libavutil/opt.h"
40 #include "libavutil/tx.h"
41 #include "avfilter.h"
42 #include "audio.h"
43 #include "filters.h"
44 #include "formats.h"
45
46 #define FRAME_SIZE_SHIFT 2
47 #define FRAME_SIZE (120<<FRAME_SIZE_SHIFT)
48 #define WINDOW_SIZE (2*FRAME_SIZE)
49 #define FREQ_SIZE (FRAME_SIZE + 1)
50
51 #define PITCH_MIN_PERIOD 60
52 #define PITCH_MAX_PERIOD 768
53 #define PITCH_FRAME_SIZE 960
54 #define PITCH_BUF_SIZE (PITCH_MAX_PERIOD+PITCH_FRAME_SIZE)
55
56 #define SQUARE(x) ((x)*(x))
57
58 #define NB_BANDS 22
59
60 #define CEPS_MEM 8
61 #define NB_DELTA_CEPS 6
62
63 #define NB_FEATURES (NB_BANDS+3*NB_DELTA_CEPS+2)
64
65 #define WEIGHTS_SCALE (1.f/256)
66
67 #define MAX_NEURONS 128
68
69 #define ACTIVATION_TANH    0
70 #define ACTIVATION_SIGMOID 1
71 #define ACTIVATION_RELU    2
72
73 #define Q15ONE 1.0f
74
75 typedef struct DenseLayer {
76     const float *bias;
77     const float *input_weights;
78     int nb_inputs;
79     int nb_neurons;
80     int activation;
81 } DenseLayer;
82
83 typedef struct GRULayer {
84     const float *bias;
85     const float *input_weights;
86     const float *recurrent_weights;
87     int nb_inputs;
88     int nb_neurons;
89     int activation;
90 } GRULayer;
91
92 typedef struct RNNModel {
93     int input_dense_size;
94     const DenseLayer *input_dense;
95
96     int vad_gru_size;
97     const GRULayer *vad_gru;
98
99     int noise_gru_size;
100     const GRULayer *noise_gru;
101
102     int denoise_gru_size;
103     const GRULayer *denoise_gru;
104
105     int denoise_output_size;
106     const DenseLayer *denoise_output;
107
108     int vad_output_size;
109     const DenseLayer *vad_output;
110 } RNNModel;
111
112 typedef struct RNNState {
113     float *vad_gru_state;
114     float *noise_gru_state;
115     float *denoise_gru_state;
116     RNNModel *model;
117 } RNNState;
118
119 typedef struct DenoiseState {
120     float analysis_mem[FRAME_SIZE];
121     float cepstral_mem[CEPS_MEM][NB_BANDS];
122     int memid;
123     DECLARE_ALIGNED(32, float, synthesis_mem)[FRAME_SIZE];
124     float pitch_buf[PITCH_BUF_SIZE];
125     float pitch_enh_buf[PITCH_BUF_SIZE];
126     float last_gain;
127     int last_period;
128     float mem_hp_x[2];
129     float lastg[NB_BANDS];
130     RNNState rnn;
131     AVTXContext *tx, *txi;
132     av_tx_fn tx_fn, txi_fn;
133 } DenoiseState;
134
135 typedef struct AudioRNNContext {
136     const AVClass *class;
137
138     char *model_name;
139
140     int channels;
141     DenoiseState *st;
142
143     DECLARE_ALIGNED(32, float, window)[WINDOW_SIZE];
144     DECLARE_ALIGNED(32, float, dct_table)[FFALIGN(NB_BANDS, 4)][FFALIGN(NB_BANDS, 4)];
145
146     RNNModel *model;
147
148     AVFloatDSPContext *fdsp;
149 } AudioRNNContext;
150
151 #define F_ACTIVATION_TANH       0
152 #define F_ACTIVATION_SIGMOID    1
153 #define F_ACTIVATION_RELU       2
154
155 static void rnnoise_model_free(RNNModel *model)
156 {
157 #define FREE_MAYBE(ptr) do { if (ptr) free(ptr); } while (0)
158 #define FREE_DENSE(name) do { \
159     if (model->name) { \
160         av_free((void *) model->name->input_weights); \
161         av_free((void *) model->name->bias); \
162         av_free((void *) model->name); \
163     } \
164     } while (0)
165 #define FREE_GRU(name) do { \
166     if (model->name) { \
167         av_free((void *) model->name->input_weights); \
168         av_free((void *) model->name->recurrent_weights); \
169         av_free((void *) model->name->bias); \
170         av_free((void *) model->name); \
171     } \
172     } while (0)
173
174     if (!model)
175         return;
176     FREE_DENSE(input_dense);
177     FREE_GRU(vad_gru);
178     FREE_GRU(noise_gru);
179     FREE_GRU(denoise_gru);
180     FREE_DENSE(denoise_output);
181     FREE_DENSE(vad_output);
182     av_free(model);
183 }
184
185 static RNNModel *rnnoise_model_from_file(FILE *f)
186 {
187     RNNModel *ret;
188     DenseLayer *input_dense;
189     GRULayer *vad_gru;
190     GRULayer *noise_gru;
191     GRULayer *denoise_gru;
192     DenseLayer *denoise_output;
193     DenseLayer *vad_output;
194     int in;
195
196     if (fscanf(f, "rnnoise-nu model file version %d\n", &in) != 1 || in != 1)
197         return NULL;
198
199     ret = av_calloc(1, sizeof(RNNModel));
200     if (!ret)
201         return NULL;
202
203 #define ALLOC_LAYER(type, name) \
204     name = av_calloc(1, sizeof(type)); \
205     if (!name) { \
206         rnnoise_model_free(ret); \
207         return NULL; \
208     } \
209     ret->name = name
210
211     ALLOC_LAYER(DenseLayer, input_dense);
212     ALLOC_LAYER(GRULayer, vad_gru);
213     ALLOC_LAYER(GRULayer, noise_gru);
214     ALLOC_LAYER(GRULayer, denoise_gru);
215     ALLOC_LAYER(DenseLayer, denoise_output);
216     ALLOC_LAYER(DenseLayer, vad_output);
217
218 #define INPUT_VAL(name) do { \
219     if (fscanf(f, "%d", &in) != 1 || in < 0 || in > 128) { \
220         rnnoise_model_free(ret); \
221         return NULL; \
222     } \
223     name = in; \
224     } while (0)
225
226 #define INPUT_ACTIVATION(name) do { \
227     int activation; \
228     INPUT_VAL(activation); \
229     switch (activation) { \
230     case F_ACTIVATION_SIGMOID: \
231         name = ACTIVATION_SIGMOID; \
232         break; \
233     case F_ACTIVATION_RELU: \
234         name = ACTIVATION_RELU; \
235         break; \
236     default: \
237         name = ACTIVATION_TANH; \
238     } \
239     } while (0)
240
241 #define INPUT_ARRAY(name, len) do { \
242     float *values = av_calloc((len), sizeof(float)); \
243     if (!values) { \
244         rnnoise_model_free(ret); \
245         return NULL; \
246     } \
247     name = values; \
248     for (int i = 0; i < (len); i++) { \
249         if (fscanf(f, "%d", &in) != 1) { \
250             rnnoise_model_free(ret); \
251             return NULL; \
252         } \
253         values[i] = in; \
254     } \
255     } while (0)
256
257 #define INPUT_ARRAY3(name, len0, len1, len2) do { \
258     float *values = av_calloc(FFALIGN((len0), 4) * FFALIGN((len1), 4) * (len2), sizeof(float)); \
259     if (!values) { \
260         rnnoise_model_free(ret); \
261         return NULL; \
262     } \
263     name = values; \
264     for (int k = 0; k < (len0); k++) { \
265         for (int i = 0; i < (len2); i++) { \
266             for (int j = 0; j < (len1); j++) { \
267                 if (fscanf(f, "%d", &in) != 1) { \
268                     rnnoise_model_free(ret); \
269                     return NULL; \
270                 } \
271                 values[j * (len2) * FFALIGN((len0), 4) + i * FFALIGN((len0), 4) + k] = in; \
272             } \
273         } \
274     } \
275     } while (0)
276
277 #define INPUT_DENSE(name) do { \
278     INPUT_VAL(name->nb_inputs); \
279     INPUT_VAL(name->nb_neurons); \
280     ret->name ## _size = name->nb_neurons; \
281     INPUT_ACTIVATION(name->activation); \
282     INPUT_ARRAY(name->input_weights, name->nb_inputs * name->nb_neurons); \
283     INPUT_ARRAY(name->bias, name->nb_neurons); \
284     } while (0)
285
286 #define INPUT_GRU(name) do { \
287     INPUT_VAL(name->nb_inputs); \
288     INPUT_VAL(name->nb_neurons); \
289     ret->name ## _size = name->nb_neurons; \
290     INPUT_ACTIVATION(name->activation); \
291     INPUT_ARRAY3(name->input_weights, name->nb_inputs, name->nb_neurons, 3); \
292     INPUT_ARRAY3(name->recurrent_weights, name->nb_neurons, name->nb_neurons, 3); \
293     INPUT_ARRAY(name->bias, name->nb_neurons * 3); \
294     } while (0)
295
296     INPUT_DENSE(input_dense);
297     INPUT_GRU(vad_gru);
298     INPUT_GRU(noise_gru);
299     INPUT_GRU(denoise_gru);
300     INPUT_DENSE(denoise_output);
301     INPUT_DENSE(vad_output);
302
303     if (vad_output->nb_neurons != 1) {
304         rnnoise_model_free(ret);
305         return NULL;
306     }
307
308     return ret;
309 }
310
311 static int query_formats(AVFilterContext *ctx)
312 {
313     AVFilterFormats *formats = NULL;
314     AVFilterChannelLayouts *layouts = NULL;
315     static const enum AVSampleFormat sample_fmts[] = {
316         AV_SAMPLE_FMT_FLTP,
317         AV_SAMPLE_FMT_NONE
318     };
319     int ret, sample_rates[] = { 48000, -1 };
320
321     formats = ff_make_format_list(sample_fmts);
322     if (!formats)
323         return AVERROR(ENOMEM);
324     ret = ff_set_common_formats(ctx, formats);
325     if (ret < 0)
326         return ret;
327
328     layouts = ff_all_channel_counts();
329     if (!layouts)
330         return AVERROR(ENOMEM);
331
332     ret = ff_set_common_channel_layouts(ctx, layouts);
333     if (ret < 0)
334         return ret;
335
336     formats = ff_make_format_list(sample_rates);
337     if (!formats)
338         return AVERROR(ENOMEM);
339     return ff_set_common_samplerates(ctx, formats);
340 }
341
342 static int config_input(AVFilterLink *inlink)
343 {
344     AVFilterContext *ctx = inlink->dst;
345     AudioRNNContext *s = ctx->priv;
346     int ret;
347
348     s->channels = inlink->channels;
349
350     s->st = av_calloc(s->channels, sizeof(DenoiseState));
351     if (!s->st)
352         return AVERROR(ENOMEM);
353
354     for (int i = 0; i < s->channels; i++) {
355         DenoiseState *st = &s->st[i];
356
357         st->rnn.model = s->model;
358         st->rnn.vad_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->vad_gru_size, 16));
359         st->rnn.noise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->noise_gru_size, 16));
360         st->rnn.denoise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->denoise_gru_size, 16));
361         if (!st->rnn.vad_gru_state ||
362             !st->rnn.noise_gru_state ||
363             !st->rnn.denoise_gru_state)
364             return AVERROR(ENOMEM);
365
366         ret = av_tx_init(&st->tx, &st->tx_fn, AV_TX_FLOAT_FFT, 0, WINDOW_SIZE, NULL, 0);
367         if (ret < 0)
368             return ret;
369
370         ret = av_tx_init(&st->txi, &st->txi_fn, AV_TX_FLOAT_FFT, 1, WINDOW_SIZE, NULL, 0);
371         if (ret < 0)
372             return ret;
373     }
374
375     return 0;
376 }
377
378 static void biquad(float *y, float mem[2], const float *x,
379                    const float *b, const float *a, int N)
380 {
381     for (int i = 0; i < N; i++) {
382         float xi, yi;
383
384         xi = x[i];
385         yi = x[i] + mem[0];
386         mem[0] = mem[1] + (b[0]*xi - a[0]*yi);
387         mem[1] = (b[1]*xi - a[1]*yi);
388         y[i] = yi;
389     }
390 }
391
392 #define RNN_MOVE(dst, src, n) (memmove((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
393 #define RNN_CLEAR(dst, n) (memset((dst), 0, (n)*sizeof(*(dst))))
394 #define RNN_COPY(dst, src, n) (memcpy((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
395
396 static void forward_transform(DenoiseState *st, AVComplexFloat *out, const float *in)
397 {
398     AVComplexFloat x[WINDOW_SIZE];
399     AVComplexFloat y[WINDOW_SIZE];
400
401     for (int i = 0; i < WINDOW_SIZE; i++) {
402         x[i].re = in[i];
403         x[i].im = 0;
404     }
405
406     st->tx_fn(st->tx, y, x, sizeof(float));
407
408     RNN_COPY(out, y, FREQ_SIZE);
409 }
410
411 static void inverse_transform(DenoiseState *st, float *out, const AVComplexFloat *in)
412 {
413     AVComplexFloat x[WINDOW_SIZE];
414     AVComplexFloat y[WINDOW_SIZE];
415
416     RNN_COPY(x, in, FREQ_SIZE);
417
418     for (int i = FREQ_SIZE; i < WINDOW_SIZE; i++) {
419         x[i].re =  x[WINDOW_SIZE - i].re;
420         x[i].im = -x[WINDOW_SIZE - i].im;
421     }
422
423     st->txi_fn(st->txi, y, x, sizeof(float));
424
425     for (int i = 0; i < WINDOW_SIZE; i++)
426         out[i] = y[i].re / WINDOW_SIZE;
427 }
428
429 static const uint8_t eband5ms[] = {
430 /*0  200 400 600 800  1k 1.2 1.4 1.6  2k 2.4 2.8 3.2  4k 4.8 5.6 6.8  8k 9.6 12k 15.6 20k*/
431   0,  1,  2,  3,  4,   5, 6,  7,  8,  10, 12, 14, 16, 20, 24, 28, 34, 40, 48, 60, 78, 100
432 };
433
434 static void compute_band_energy(float *bandE, const AVComplexFloat *X)
435 {
436     float sum[NB_BANDS] = {0};
437
438     for (int i = 0; i < NB_BANDS - 1; i++) {
439         int band_size;
440
441         band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
442         for (int j = 0; j < band_size; j++) {
443             float tmp, frac = (float)j / band_size;
444
445             tmp         = SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].re);
446             tmp        += SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].im);
447             sum[i]     += (1.f - frac) * tmp;
448             sum[i + 1] +=        frac  * tmp;
449         }
450     }
451
452     sum[0] *= 2;
453     sum[NB_BANDS - 1] *= 2;
454
455     for (int i = 0; i < NB_BANDS; i++)
456         bandE[i] = sum[i];
457 }
458
459 static void compute_band_corr(float *bandE, const AVComplexFloat *X, const AVComplexFloat *P)
460 {
461     float sum[NB_BANDS] = { 0 };
462
463     for (int i = 0; i < NB_BANDS - 1; i++) {
464         int band_size;
465
466         band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
467         for (int j = 0; j < band_size; j++) {
468             float tmp, frac = (float)j / band_size;
469
470             tmp  = X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re;
471             tmp += X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im;
472             sum[i]     += (1 - frac) * tmp;
473             sum[i + 1] +=      frac  * tmp;
474         }
475     }
476
477     sum[0] *= 2;
478     sum[NB_BANDS-1] *= 2;
479
480     for (int i = 0; i < NB_BANDS; i++)
481         bandE[i] = sum[i];
482 }
483
484 static void frame_analysis(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, float *Ex, const float *in)
485 {
486     LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
487
488     RNN_COPY(x, st->analysis_mem, FRAME_SIZE);
489     RNN_COPY(x + FRAME_SIZE, in, FRAME_SIZE);
490     RNN_COPY(st->analysis_mem, in, FRAME_SIZE);
491     s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
492     forward_transform(st, X, x);
493     compute_band_energy(Ex, X);
494 }
495
496 static void frame_synthesis(AudioRNNContext *s, DenoiseState *st, float *out, const AVComplexFloat *y)
497 {
498     LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
499
500     inverse_transform(st, x, y);
501     s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
502     s->fdsp->vector_fmac_scalar(x, st->synthesis_mem, 1.f, FRAME_SIZE);
503     RNN_COPY(out, x, FRAME_SIZE);
504     RNN_COPY(st->synthesis_mem, &x[FRAME_SIZE], FRAME_SIZE);
505 }
506
507 static inline void xcorr_kernel(const float *x, const float *y, float sum[4], int len)
508 {
509     float y_0, y_1, y_2, y_3 = 0;
510     int j;
511
512     y_0 = *y++;
513     y_1 = *y++;
514     y_2 = *y++;
515
516     for (j = 0; j < len - 3; j += 4) {
517         float tmp;
518
519         tmp = *x++;
520         y_3 = *y++;
521         sum[0] += tmp * y_0;
522         sum[1] += tmp * y_1;
523         sum[2] += tmp * y_2;
524         sum[3] += tmp * y_3;
525         tmp = *x++;
526         y_0 = *y++;
527         sum[0] += tmp * y_1;
528         sum[1] += tmp * y_2;
529         sum[2] += tmp * y_3;
530         sum[3] += tmp * y_0;
531         tmp = *x++;
532         y_1 = *y++;
533         sum[0] += tmp * y_2;
534         sum[1] += tmp * y_3;
535         sum[2] += tmp * y_0;
536         sum[3] += tmp * y_1;
537         tmp = *x++;
538         y_2 = *y++;
539         sum[0] += tmp * y_3;
540         sum[1] += tmp * y_0;
541         sum[2] += tmp * y_1;
542         sum[3] += tmp * y_2;
543     }
544
545     if (j++ < len) {
546         float tmp = *x++;
547
548         y_3 = *y++;
549         sum[0] += tmp * y_0;
550         sum[1] += tmp * y_1;
551         sum[2] += tmp * y_2;
552         sum[3] += tmp * y_3;
553     }
554
555     if (j++ < len) {
556         float tmp=*x++;
557
558         y_0 = *y++;
559         sum[0] += tmp * y_1;
560         sum[1] += tmp * y_2;
561         sum[2] += tmp * y_3;
562         sum[3] += tmp * y_0;
563     }
564
565     if (j < len) {
566         float tmp=*x++;
567
568         y_1 = *y++;
569         sum[0] += tmp * y_2;
570         sum[1] += tmp * y_3;
571         sum[2] += tmp * y_0;
572         sum[3] += tmp * y_1;
573     }
574 }
575
576 static inline float celt_inner_prod(const float *x,
577                                     const float *y, int N)
578 {
579     float xy = 0.f;
580
581     for (int i = 0; i < N; i++)
582         xy += x[i] * y[i];
583
584     return xy;
585 }
586
587 static void celt_pitch_xcorr(const float *x, const float *y,
588                              float *xcorr, int len, int max_pitch)
589 {
590     int i;
591
592     for (i = 0; i < max_pitch - 3; i += 4) {
593         float sum[4] = { 0, 0, 0, 0};
594
595         xcorr_kernel(x, y + i, sum, len);
596
597         xcorr[i]     = sum[0];
598         xcorr[i + 1] = sum[1];
599         xcorr[i + 2] = sum[2];
600         xcorr[i + 3] = sum[3];
601     }
602     /* In case max_pitch isn't a multiple of 4, do non-unrolled version. */
603     for (; i < max_pitch; i++) {
604         xcorr[i] = celt_inner_prod(x, y + i, len);
605     }
606 }
607
608 static int celt_autocorr(const float *x,   /*  in: [0...n-1] samples x   */
609                          float       *ac,  /* out: [0...lag-1] ac values */
610                          const float *window,
611                          int          overlap,
612                          int          lag,
613                          int          n)
614 {
615     int fastN = n - lag;
616     int shift;
617     const float *xptr;
618     float xx[PITCH_BUF_SIZE>>1];
619
620     if (overlap == 0) {
621         xptr = x;
622     } else {
623         for (int i = 0; i < n; i++)
624             xx[i] = x[i];
625         for (int i = 0; i < overlap; i++) {
626             xx[i] = x[i] * window[i];
627             xx[n-i-1] = x[n-i-1] * window[i];
628         }
629         xptr = xx;
630     }
631
632     shift = 0;
633     celt_pitch_xcorr(xptr, xptr, ac, fastN, lag+1);
634
635     for (int k = 0; k <= lag; k++) {
636         float d = 0.f;
637
638         for (int i = k + fastN; i < n; i++)
639             d += xptr[i] * xptr[i-k];
640         ac[k] += d;
641     }
642
643     return shift;
644 }
645
646 static void celt_lpc(float *lpc, /* out: [0...p-1] LPC coefficients      */
647                 const float *ac,   /* in:  [0...p] autocorrelation values  */
648                           int p)
649 {
650     float r, error = ac[0];
651
652     RNN_CLEAR(lpc, p);
653     if (ac[0] != 0) {
654         for (int i = 0; i < p; i++) {
655             /* Sum up this iteration's reflection coefficient */
656             float rr = 0;
657             for (int j = 0; j < i; j++)
658                 rr += (lpc[j] * ac[i - j]);
659             rr += ac[i + 1];
660             r = -rr/error;
661             /*  Update LPC coefficients and total error */
662             lpc[i] = r;
663             for (int j = 0; j < (i + 1) >> 1; j++) {
664                 float tmp1, tmp2;
665                 tmp1 = lpc[j];
666                 tmp2 = lpc[i-1-j];
667                 lpc[j]     = tmp1 + (r*tmp2);
668                 lpc[i-1-j] = tmp2 + (r*tmp1);
669             }
670
671             error = error - (r * r *error);
672             /* Bail out once we get 30 dB gain */
673             if (error < .001f * ac[0])
674                 break;
675         }
676     }
677 }
678
679 static void celt_fir5(const float *x,
680                       const float *num,
681                       float *y,
682                       int N,
683                       float *mem)
684 {
685     float num0, num1, num2, num3, num4;
686     float mem0, mem1, mem2, mem3, mem4;
687
688     num0 = num[0];
689     num1 = num[1];
690     num2 = num[2];
691     num3 = num[3];
692     num4 = num[4];
693     mem0 = mem[0];
694     mem1 = mem[1];
695     mem2 = mem[2];
696     mem3 = mem[3];
697     mem4 = mem[4];
698
699     for (int i = 0; i < N; i++) {
700         float sum = x[i];
701
702         sum += (num0*mem0);
703         sum += (num1*mem1);
704         sum += (num2*mem2);
705         sum += (num3*mem3);
706         sum += (num4*mem4);
707         mem4 = mem3;
708         mem3 = mem2;
709         mem2 = mem1;
710         mem1 = mem0;
711         mem0 = x[i];
712         y[i] = sum;
713     }
714
715     mem[0] = mem0;
716     mem[1] = mem1;
717     mem[2] = mem2;
718     mem[3] = mem3;
719     mem[4] = mem4;
720 }
721
722 static void pitch_downsample(float *x[], float *x_lp,
723                              int len, int C)
724 {
725     float ac[5];
726     float tmp=Q15ONE;
727     float lpc[4], mem[5]={0,0,0,0,0};
728     float lpc2[5];
729     float c1 = .8f;
730
731     for (int i = 1; i < len >> 1; i++)
732         x_lp[i] = .5f * (.5f * (x[0][(2*i-1)]+x[0][(2*i+1)])+x[0][2*i]);
733     x_lp[0] = .5f * (.5f * (x[0][1])+x[0][0]);
734     if (C==2) {
735         for (int i = 1; i < len >> 1; i++)
736             x_lp[i] += (.5f * (.5f * (x[1][(2*i-1)]+x[1][(2*i+1)])+x[1][2*i]));
737         x_lp[0] += .5f * (.5f * (x[1][1])+x[1][0]);
738     }
739
740     celt_autocorr(x_lp, ac, NULL, 0, 4, len>>1);
741
742     /* Noise floor -40 dB */
743     ac[0] *= 1.0001f;
744     /* Lag windowing */
745     for (int i = 1; i <= 4; i++) {
746         /*ac[i] *= exp(-.5*(2*M_PI*.002*i)*(2*M_PI*.002*i));*/
747         ac[i] -= ac[i]*(.008f*i)*(.008f*i);
748     }
749
750     celt_lpc(lpc, ac, 4);
751     for (int i = 0; i < 4; i++) {
752         tmp = .9f * tmp;
753         lpc[i] = (lpc[i] * tmp);
754     }
755     /* Add a zero */
756     lpc2[0] = lpc[0] + .8f;
757     lpc2[1] = lpc[1] + (c1 * lpc[0]);
758     lpc2[2] = lpc[2] + (c1 * lpc[1]);
759     lpc2[3] = lpc[3] + (c1 * lpc[2]);
760     lpc2[4] = (c1 * lpc[3]);
761     celt_fir5(x_lp, lpc2, x_lp, len>>1, mem);
762 }
763
764 static inline void dual_inner_prod(const float *x, const float *y01, const float *y02,
765                                    int N, float *xy1, float *xy2)
766 {
767     float xy01 = 0, xy02 = 0;
768
769     for (int i = 0; i < N; i++) {
770         xy01 += (x[i] * y01[i]);
771         xy02 += (x[i] * y02[i]);
772     }
773
774     *xy1 = xy01;
775     *xy2 = xy02;
776 }
777
778 static float compute_pitch_gain(float xy, float xx, float yy)
779 {
780     return xy / sqrtf(1.f + xx * yy);
781 }
782
783 static const int second_check[16] = {0, 0, 3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2};
784 static float remove_doubling(float *x, int maxperiod, int minperiod, int N,
785                              int *T0_, int prev_period, float prev_gain)
786 {
787     int k, i, T, T0;
788     float g, g0;
789     float pg;
790     float xy,xx,yy,xy2;
791     float xcorr[3];
792     float best_xy, best_yy;
793     int offset;
794     int minperiod0;
795     float yy_lookup[PITCH_MAX_PERIOD+1];
796
797     minperiod0 = minperiod;
798     maxperiod /= 2;
799     minperiod /= 2;
800     *T0_ /= 2;
801     prev_period /= 2;
802     N /= 2;
803     x += maxperiod;
804     if (*T0_>=maxperiod)
805         *T0_=maxperiod-1;
806
807     T = T0 = *T0_;
808     dual_inner_prod(x, x, x-T0, N, &xx, &xy);
809     yy_lookup[0] = xx;
810     yy=xx;
811     for (i = 1; i <= maxperiod; i++) {
812         yy = yy+(x[-i] * x[-i])-(x[N-i] * x[N-i]);
813         yy_lookup[i] = FFMAX(0, yy);
814     }
815     yy = yy_lookup[T0];
816     best_xy = xy;
817     best_yy = yy;
818     g = g0 = compute_pitch_gain(xy, xx, yy);
819     /* Look for any pitch at T/k */
820     for (k = 2; k <= 15; k++) {
821         int T1, T1b;
822         float g1;
823         float cont=0;
824         float thresh;
825         T1 = (2*T0+k)/(2*k);
826         if (T1 < minperiod)
827             break;
828         /* Look for another strong correlation at T1b */
829         if (k==2)
830         {
831             if (T1+T0>maxperiod)
832                 T1b = T0;
833             else
834                 T1b = T0+T1;
835         } else
836         {
837             T1b = (2*second_check[k]*T0+k)/(2*k);
838         }
839         dual_inner_prod(x, &x[-T1], &x[-T1b], N, &xy, &xy2);
840         xy = .5f * (xy + xy2);
841         yy = .5f * (yy_lookup[T1] + yy_lookup[T1b]);
842         g1 = compute_pitch_gain(xy, xx, yy);
843         if (FFABS(T1-prev_period)<=1)
844             cont = prev_gain;
845         else if (FFABS(T1-prev_period)<=2 && 5 * k * k < T0)
846             cont = prev_gain * .5f;
847         else
848             cont = 0;
849         thresh = FFMAX(.3f, (.7f * g0) - cont);
850         /* Bias against very high pitch (very short period) to avoid false-positives
851            due to short-term correlation */
852         if (T1<3*minperiod)
853             thresh = FFMAX(.4f, (.85f * g0) - cont);
854         else if (T1<2*minperiod)
855             thresh = FFMAX(.5f, (.9f * g0) - cont);
856         if (g1 > thresh)
857         {
858             best_xy = xy;
859             best_yy = yy;
860             T = T1;
861             g = g1;
862         }
863     }
864     best_xy = FFMAX(0, best_xy);
865     if (best_yy <= best_xy)
866         pg = Q15ONE;
867     else
868         pg = best_xy/(best_yy + 1);
869
870     for (k = 0; k < 3; k++)
871         xcorr[k] = celt_inner_prod(x, x-(T+k-1), N);
872     if ((xcorr[2]-xcorr[0]) > .7f * (xcorr[1]-xcorr[0]))
873         offset = 1;
874     else if ((xcorr[0]-xcorr[2]) > (.7f * (xcorr[1] - xcorr[2])))
875         offset = -1;
876     else
877         offset = 0;
878     if (pg > g)
879         pg = g;
880     *T0_ = 2*T+offset;
881
882     if (*T0_<minperiod0)
883         *T0_=minperiod0;
884     return pg;
885 }
886
887 static void find_best_pitch(float *xcorr, float *y, int len,
888                             int max_pitch, int *best_pitch)
889 {
890     float best_num[2];
891     float best_den[2];
892     float Syy = 1.f;
893
894     best_num[0] = -1;
895     best_num[1] = -1;
896     best_den[0] = 0;
897     best_den[1] = 0;
898     best_pitch[0] = 0;
899     best_pitch[1] = 1;
900
901     for (int j = 0; j < len; j++)
902         Syy += y[j] * y[j];
903
904     for (int i = 0; i < max_pitch; i++) {
905         if (xcorr[i]>0) {
906             float num;
907             float xcorr16;
908
909             xcorr16 = xcorr[i];
910             /* Considering the range of xcorr16, this should avoid both underflows
911                and overflows (inf) when squaring xcorr16 */
912             xcorr16 *= 1e-12f;
913             num = xcorr16 * xcorr16;
914             if ((num * best_den[1]) > (best_num[1] * Syy)) {
915                 if ((num * best_den[0]) > (best_num[0] * Syy)) {
916                     best_num[1] = best_num[0];
917                     best_den[1] = best_den[0];
918                     best_pitch[1] = best_pitch[0];
919                     best_num[0] = num;
920                     best_den[0] = Syy;
921                     best_pitch[0] = i;
922                 } else {
923                     best_num[1] = num;
924                     best_den[1] = Syy;
925                     best_pitch[1] = i;
926                 }
927             }
928         }
929         Syy += y[i+len]*y[i+len] - y[i] * y[i];
930         Syy = FFMAX(1, Syy);
931     }
932 }
933
934 static void pitch_search(const float *x_lp, float *y,
935                          int len, int max_pitch, int *pitch)
936 {
937     int lag;
938     int best_pitch[2]={0,0};
939     int offset;
940
941     float x_lp4[WINDOW_SIZE];
942     float y_lp4[WINDOW_SIZE];
943     float xcorr[WINDOW_SIZE];
944
945     lag = len+max_pitch;
946
947     /* Downsample by 2 again */
948     for (int j = 0; j < len >> 2; j++)
949         x_lp4[j] = x_lp[2*j];
950     for (int j = 0; j < lag >> 2; j++)
951         y_lp4[j] = y[2*j];
952
953     /* Coarse search with 4x decimation */
954
955     celt_pitch_xcorr(x_lp4, y_lp4, xcorr, len>>2, max_pitch>>2);
956
957     find_best_pitch(xcorr, y_lp4, len>>2, max_pitch>>2, best_pitch);
958
959     /* Finer search with 2x decimation */
960     for (int i = 0; i < max_pitch >> 1; i++) {
961         float sum;
962         xcorr[i] = 0;
963         if (FFABS(i-2*best_pitch[0])>2 && FFABS(i-2*best_pitch[1])>2)
964             continue;
965         sum = celt_inner_prod(x_lp, y+i, len>>1);
966         xcorr[i] = FFMAX(-1, sum);
967     }
968
969     find_best_pitch(xcorr, y, len>>1, max_pitch>>1, best_pitch);
970
971     /* Refine by pseudo-interpolation */
972     if (best_pitch[0] > 0 && best_pitch[0] < (max_pitch >> 1) - 1) {
973         float a, b, c;
974
975         a = xcorr[best_pitch[0] - 1];
976         b = xcorr[best_pitch[0]];
977         c = xcorr[best_pitch[0] + 1];
978         if (c - a > .7f * (b - a))
979             offset = 1;
980         else if (a - c > .7f * (b-c))
981             offset = -1;
982         else
983             offset = 0;
984     } else {
985         offset = 0;
986     }
987
988     *pitch = 2 * best_pitch[0] - offset;
989 }
990
991 static void dct(AudioRNNContext *s, float *out, const float *in)
992 {
993     for (int i = 0; i < NB_BANDS; i++) {
994         float sum;
995
996         sum = s->fdsp->scalarproduct_float(in, s->dct_table[i], FFALIGN(NB_BANDS, 4));
997         out[i] = sum * sqrtf(2.f / 22);
998     }
999 }
1000
1001 static int compute_frame_features(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, AVComplexFloat *P,
1002                                   float *Ex, float *Ep, float *Exp, float *features, const float *in)
1003 {
1004     float E = 0;
1005     float *ceps_0, *ceps_1, *ceps_2;
1006     float spec_variability = 0;
1007     LOCAL_ALIGNED_32(float, Ly, [NB_BANDS]);
1008     LOCAL_ALIGNED_32(float, p, [WINDOW_SIZE]);
1009     float pitch_buf[PITCH_BUF_SIZE>>1];
1010     int pitch_index;
1011     float gain;
1012     float *(pre[1]);
1013     float tmp[NB_BANDS];
1014     float follow, logMax;
1015
1016     frame_analysis(s, st, X, Ex, in);
1017     RNN_MOVE(st->pitch_buf, &st->pitch_buf[FRAME_SIZE], PITCH_BUF_SIZE-FRAME_SIZE);
1018     RNN_COPY(&st->pitch_buf[PITCH_BUF_SIZE-FRAME_SIZE], in, FRAME_SIZE);
1019     pre[0] = &st->pitch_buf[0];
1020     pitch_downsample(pre, pitch_buf, PITCH_BUF_SIZE, 1);
1021     pitch_search(pitch_buf+(PITCH_MAX_PERIOD>>1), pitch_buf, PITCH_FRAME_SIZE,
1022             PITCH_MAX_PERIOD-3*PITCH_MIN_PERIOD, &pitch_index);
1023     pitch_index = PITCH_MAX_PERIOD-pitch_index;
1024
1025     gain = remove_doubling(pitch_buf, PITCH_MAX_PERIOD, PITCH_MIN_PERIOD,
1026             PITCH_FRAME_SIZE, &pitch_index, st->last_period, st->last_gain);
1027     st->last_period = pitch_index;
1028     st->last_gain = gain;
1029
1030     for (int i = 0; i < WINDOW_SIZE; i++)
1031         p[i] = st->pitch_buf[PITCH_BUF_SIZE-WINDOW_SIZE-pitch_index+i];
1032
1033     s->fdsp->vector_fmul(p, p, s->window, WINDOW_SIZE);
1034     forward_transform(st, P, p);
1035     compute_band_energy(Ep, P);
1036     compute_band_corr(Exp, X, P);
1037
1038     for (int i = 0; i < NB_BANDS; i++)
1039         Exp[i] = Exp[i] / sqrtf(.001f+Ex[i]*Ep[i]);
1040
1041     dct(s, tmp, Exp);
1042
1043     for (int i = 0; i < NB_DELTA_CEPS; i++)
1044         features[NB_BANDS+2*NB_DELTA_CEPS+i] = tmp[i];
1045
1046     features[NB_BANDS+2*NB_DELTA_CEPS] -= 1.3;
1047     features[NB_BANDS+2*NB_DELTA_CEPS+1] -= 0.9;
1048     features[NB_BANDS+3*NB_DELTA_CEPS] = .01*(pitch_index-300);
1049     logMax = -2;
1050     follow = -2;
1051
1052     for (int i = 0; i < NB_BANDS; i++) {
1053         Ly[i] = log10f(1e-2f + Ex[i]);
1054         Ly[i] = FFMAX(logMax-7, FFMAX(follow-1.5, Ly[i]));
1055         logMax = FFMAX(logMax, Ly[i]);
1056         follow = FFMAX(follow-1.5, Ly[i]);
1057         E += Ex[i];
1058     }
1059
1060     if (E < 0.04f) {
1061         /* If there's no audio, avoid messing up the state. */
1062         RNN_CLEAR(features, NB_FEATURES);
1063         return 1;
1064     }
1065
1066     dct(s, features, Ly);
1067     features[0] -= 12;
1068     features[1] -= 4;
1069     ceps_0 = st->cepstral_mem[st->memid];
1070     ceps_1 = (st->memid < 1) ? st->cepstral_mem[CEPS_MEM+st->memid-1] : st->cepstral_mem[st->memid-1];
1071     ceps_2 = (st->memid < 2) ? st->cepstral_mem[CEPS_MEM+st->memid-2] : st->cepstral_mem[st->memid-2];
1072
1073     for (int i = 0; i < NB_BANDS; i++)
1074         ceps_0[i] = features[i];
1075
1076     st->memid++;
1077     for (int i = 0; i < NB_DELTA_CEPS; i++) {
1078         features[i] = ceps_0[i] + ceps_1[i] + ceps_2[i];
1079         features[NB_BANDS+i] = ceps_0[i] - ceps_2[i];
1080         features[NB_BANDS+NB_DELTA_CEPS+i] =  ceps_0[i] - 2*ceps_1[i] + ceps_2[i];
1081     }
1082     /* Spectral variability features. */
1083     if (st->memid == CEPS_MEM)
1084         st->memid = 0;
1085
1086     for (int i = 0; i < CEPS_MEM; i++) {
1087         float mindist = 1e15f;
1088         for (int j = 0; j < CEPS_MEM; j++) {
1089             float dist = 0.f;
1090             for (int k = 0; k < NB_BANDS; k++) {
1091                 float tmp;
1092
1093                 tmp = st->cepstral_mem[i][k] - st->cepstral_mem[j][k];
1094                 dist += tmp*tmp;
1095             }
1096
1097             if (j != i)
1098                 mindist = FFMIN(mindist, dist);
1099         }
1100
1101         spec_variability += mindist;
1102     }
1103
1104     features[NB_BANDS+3*NB_DELTA_CEPS+1] = spec_variability/CEPS_MEM-2.1;
1105
1106     return 0;
1107 }
1108
1109 static void interp_band_gain(float *g, const float *bandE)
1110 {
1111     memset(g, 0, sizeof(*g) * FREQ_SIZE);
1112
1113     for (int i = 0; i < NB_BANDS - 1; i++) {
1114         const int band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
1115
1116         for (int j = 0; j < band_size; j++) {
1117             float frac = (float)j / band_size;
1118
1119             g[(eband5ms[i] << FRAME_SIZE_SHIFT) + j] = (1.f - frac) * bandE[i] + frac * bandE[i + 1];
1120         }
1121     }
1122 }
1123
1124 static void pitch_filter(AVComplexFloat *X, const AVComplexFloat *P, const float *Ex, const float *Ep,
1125                          const float *Exp, const float *g)
1126 {
1127     float newE[NB_BANDS];
1128     float r[NB_BANDS];
1129     float norm[NB_BANDS];
1130     float rf[FREQ_SIZE] = {0};
1131     float normf[FREQ_SIZE]={0};
1132
1133     for (int i = 0; i < NB_BANDS; i++) {
1134         if (Exp[i]>g[i]) r[i] = 1;
1135         else r[i] = SQUARE(Exp[i])*(1-SQUARE(g[i]))/(.001 + SQUARE(g[i])*(1-SQUARE(Exp[i])));
1136         r[i]  = sqrtf(av_clipf(r[i], 0, 1));
1137         r[i] *= sqrtf(Ex[i]/(1e-8+Ep[i]));
1138     }
1139     interp_band_gain(rf, r);
1140     for (int i = 0; i < FREQ_SIZE; i++) {
1141         X[i].re += rf[i]*P[i].re;
1142         X[i].im += rf[i]*P[i].im;
1143     }
1144     compute_band_energy(newE, X);
1145     for (int i = 0; i < NB_BANDS; i++) {
1146         norm[i] = sqrtf(Ex[i] / (1e-8+newE[i]));
1147     }
1148     interp_band_gain(normf, norm);
1149     for (int i = 0; i < FREQ_SIZE; i++) {
1150         X[i].re *= normf[i];
1151         X[i].im *= normf[i];
1152     }
1153 }
1154
1155 static const float tansig_table[201] = {
1156     0.000000f, 0.039979f, 0.079830f, 0.119427f, 0.158649f,
1157     0.197375f, 0.235496f, 0.272905f, 0.309507f, 0.345214f,
1158     0.379949f, 0.413644f, 0.446244f, 0.477700f, 0.507977f,
1159     0.537050f, 0.564900f, 0.591519f, 0.616909f, 0.641077f,
1160     0.664037f, 0.685809f, 0.706419f, 0.725897f, 0.744277f,
1161     0.761594f, 0.777888f, 0.793199f, 0.807569f, 0.821040f,
1162     0.833655f, 0.845456f, 0.856485f, 0.866784f, 0.876393f,
1163     0.885352f, 0.893698f, 0.901468f, 0.908698f, 0.915420f,
1164     0.921669f, 0.927473f, 0.932862f, 0.937863f, 0.942503f,
1165     0.946806f, 0.950795f, 0.954492f, 0.957917f, 0.961090f,
1166     0.964028f, 0.966747f, 0.969265f, 0.971594f, 0.973749f,
1167     0.975743f, 0.977587f, 0.979293f, 0.980869f, 0.982327f,
1168     0.983675f, 0.984921f, 0.986072f, 0.987136f, 0.988119f,
1169     0.989027f, 0.989867f, 0.990642f, 0.991359f, 0.992020f,
1170     0.992631f, 0.993196f, 0.993718f, 0.994199f, 0.994644f,
1171     0.995055f, 0.995434f, 0.995784f, 0.996108f, 0.996407f,
1172     0.996682f, 0.996937f, 0.997172f, 0.997389f, 0.997590f,
1173     0.997775f, 0.997946f, 0.998104f, 0.998249f, 0.998384f,
1174     0.998508f, 0.998623f, 0.998728f, 0.998826f, 0.998916f,
1175     0.999000f, 0.999076f, 0.999147f, 0.999213f, 0.999273f,
1176     0.999329f, 0.999381f, 0.999428f, 0.999472f, 0.999513f,
1177     0.999550f, 0.999585f, 0.999617f, 0.999646f, 0.999673f,
1178     0.999699f, 0.999722f, 0.999743f, 0.999763f, 0.999781f,
1179     0.999798f, 0.999813f, 0.999828f, 0.999841f, 0.999853f,
1180     0.999865f, 0.999875f, 0.999885f, 0.999893f, 0.999902f,
1181     0.999909f, 0.999916f, 0.999923f, 0.999929f, 0.999934f,
1182     0.999939f, 0.999944f, 0.999948f, 0.999952f, 0.999956f,
1183     0.999959f, 0.999962f, 0.999965f, 0.999968f, 0.999970f,
1184     0.999973f, 0.999975f, 0.999977f, 0.999978f, 0.999980f,
1185     0.999982f, 0.999983f, 0.999984f, 0.999986f, 0.999987f,
1186     0.999988f, 0.999989f, 0.999990f, 0.999990f, 0.999991f,
1187     0.999992f, 0.999992f, 0.999993f, 0.999994f, 0.999994f,
1188     0.999994f, 0.999995f, 0.999995f, 0.999996f, 0.999996f,
1189     0.999996f, 0.999997f, 0.999997f, 0.999997f, 0.999997f,
1190     0.999997f, 0.999998f, 0.999998f, 0.999998f, 0.999998f,
1191     0.999998f, 0.999998f, 0.999999f, 0.999999f, 0.999999f,
1192     0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
1193     0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
1194     1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1195     1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1196     1.000000f,
1197 };
1198
1199 static inline float tansig_approx(float x)
1200 {
1201     float y, dy;
1202     float sign=1;
1203     int i;
1204
1205     /* Tests are reversed to catch NaNs */
1206     if (!(x<8))
1207         return 1;
1208     if (!(x>-8))
1209         return -1;
1210     /* Another check in case of -ffast-math */
1211
1212     if (isnan(x))
1213        return 0;
1214
1215     if (x < 0) {
1216        x=-x;
1217        sign=-1;
1218     }
1219     i = (int)floor(.5f+25*x);
1220     x -= .04f*i;
1221     y = tansig_table[i];
1222     dy = 1-y*y;
1223     y = y + x*dy*(1 - y*x);
1224     return sign*y;
1225 }
1226
1227 static inline float sigmoid_approx(float x)
1228 {
1229     return .5f + .5f*tansig_approx(.5f*x);
1230 }
1231
1232 static void compute_dense(const DenseLayer *layer, float *output, const float *input)
1233 {
1234     const int N = layer->nb_neurons, M = layer->nb_inputs, stride = N;
1235
1236     for (int i = 0; i < N; i++) {
1237         /* Compute update gate. */
1238         float sum = layer->bias[i];
1239
1240         for (int j = 0; j < M; j++)
1241             sum += layer->input_weights[j * stride + i] * input[j];
1242
1243         output[i] = WEIGHTS_SCALE * sum;
1244     }
1245
1246     if (layer->activation == ACTIVATION_SIGMOID) {
1247         for (int i = 0; i < N; i++)
1248             output[i] = sigmoid_approx(output[i]);
1249     } else if (layer->activation == ACTIVATION_TANH) {
1250         for (int i = 0; i < N; i++)
1251             output[i] = tansig_approx(output[i]);
1252     } else if (layer->activation == ACTIVATION_RELU) {
1253         for (int i = 0; i < N; i++)
1254             output[i] = FFMAX(0, output[i]);
1255     } else {
1256         av_assert0(0);
1257     }
1258 }
1259
1260 static void compute_gru(AudioRNNContext *s, const GRULayer *gru, float *state, const float *input)
1261 {
1262     LOCAL_ALIGNED_32(float, z, [MAX_NEURONS]);
1263     LOCAL_ALIGNED_32(float, r, [MAX_NEURONS]);
1264     LOCAL_ALIGNED_32(float, h, [MAX_NEURONS]);
1265     const int M = gru->nb_inputs;
1266     const int N = gru->nb_neurons;
1267     const int AN = FFALIGN(N, 4);
1268     const int AM = FFALIGN(M, 4);
1269     const int stride = 3 * AN, istride = 3 * AM;
1270
1271     for (int i = 0; i < N; i++) {
1272         /* Compute update gate. */
1273         float sum = gru->bias[i];
1274
1275         sum += s->fdsp->scalarproduct_float(gru->input_weights + i * istride, input, AM);
1276         sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + i * stride, state, AN);
1277         z[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1278     }
1279
1280     for (int i = 0; i < N; i++) {
1281         /* Compute reset gate. */
1282         float sum = gru->bias[N + i];
1283
1284         sum += s->fdsp->scalarproduct_float(gru->input_weights + AM + i * istride, input, AM);
1285         sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + AN + i * stride, state, AN);
1286         r[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1287     }
1288
1289     for (int i = 0; i < N; i++) {
1290         /* Compute output. */
1291         float sum = gru->bias[2 * N + i];
1292
1293         sum += s->fdsp->scalarproduct_float(gru->input_weights + 2 * AM + i * istride, input, AM);
1294         for (int j = 0; j < N; j++)
1295             sum += gru->recurrent_weights[2 * AN + i * stride + j] * state[j] * r[j];
1296
1297         if (gru->activation == ACTIVATION_SIGMOID)
1298             sum = sigmoid_approx(WEIGHTS_SCALE * sum);
1299         else if (gru->activation == ACTIVATION_TANH)
1300             sum = tansig_approx(WEIGHTS_SCALE * sum);
1301         else if (gru->activation == ACTIVATION_RELU)
1302             sum = FFMAX(0, WEIGHTS_SCALE * sum);
1303         else
1304             av_assert0(0);
1305         h[i] = z[i] * state[i] + (1.f - z[i]) * sum;
1306     }
1307
1308     RNN_COPY(state, h, N);
1309 }
1310
1311 #define INPUT_SIZE 42
1312
1313 static void compute_rnn(AudioRNNContext *s, RNNState *rnn, float *gains, float *vad, const float *input)
1314 {
1315     LOCAL_ALIGNED_32(float, dense_out,     [MAX_NEURONS]);
1316     LOCAL_ALIGNED_32(float, noise_input,   [MAX_NEURONS * 3]);
1317     LOCAL_ALIGNED_32(float, denoise_input, [MAX_NEURONS * 3]);
1318
1319     compute_dense(rnn->model->input_dense, dense_out, input);
1320     compute_gru(s, rnn->model->vad_gru, rnn->vad_gru_state, dense_out);
1321     compute_dense(rnn->model->vad_output, vad, rnn->vad_gru_state);
1322
1323     memcpy(noise_input, dense_out, rnn->model->input_dense_size * sizeof(float));
1324     memcpy(noise_input + rnn->model->input_dense_size,
1325            rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
1326     memcpy(noise_input + rnn->model->input_dense_size + rnn->model->vad_gru_size,
1327            input, INPUT_SIZE * sizeof(float));
1328
1329     compute_gru(s, rnn->model->noise_gru, rnn->noise_gru_state, noise_input);
1330
1331     memcpy(denoise_input, rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
1332     memcpy(denoise_input + rnn->model->vad_gru_size,
1333            rnn->noise_gru_state, rnn->model->noise_gru_size * sizeof(float));
1334     memcpy(denoise_input + rnn->model->vad_gru_size + rnn->model->noise_gru_size,
1335            input, INPUT_SIZE * sizeof(float));
1336
1337     compute_gru(s, rnn->model->denoise_gru, rnn->denoise_gru_state, denoise_input);
1338     compute_dense(rnn->model->denoise_output, gains, rnn->denoise_gru_state);
1339 }
1340
1341 static float rnnoise_channel(AudioRNNContext *s, DenoiseState *st, float *out, const float *in)
1342 {
1343     AVComplexFloat X[FREQ_SIZE];
1344     AVComplexFloat P[WINDOW_SIZE];
1345     float x[FRAME_SIZE];
1346     float Ex[NB_BANDS], Ep[NB_BANDS];
1347     LOCAL_ALIGNED_32(float, Exp, [NB_BANDS]);
1348     float features[NB_FEATURES];
1349     float g[NB_BANDS];
1350     float gf[FREQ_SIZE];
1351     float vad_prob = 0;
1352     static const float a_hp[2] = {-1.99599, 0.99600};
1353     static const float b_hp[2] = {-2, 1};
1354     int silence;
1355
1356     biquad(x, st->mem_hp_x, in, b_hp, a_hp, FRAME_SIZE);
1357     silence = compute_frame_features(s, st, X, P, Ex, Ep, Exp, features, x);
1358
1359     if (!silence) {
1360         compute_rnn(s, &st->rnn, g, &vad_prob, features);
1361         pitch_filter(X, P, Ex, Ep, Exp, g);
1362         for (int i = 0; i < NB_BANDS; i++) {
1363             float alpha = .6f;
1364
1365             g[i] = FFMAX(g[i], alpha * st->lastg[i]);
1366             st->lastg[i] = g[i];
1367         }
1368
1369         interp_band_gain(gf, g);
1370
1371         for (int i = 0; i < FREQ_SIZE; i++) {
1372             X[i].re *= gf[i];
1373             X[i].im *= gf[i];
1374         }
1375     }
1376
1377     frame_synthesis(s, st, out, X);
1378
1379     return vad_prob;
1380 }
1381
1382 typedef struct ThreadData {
1383     AVFrame *in, *out;
1384 } ThreadData;
1385
1386 static int rnnoise_channels(AVFilterContext *ctx, void *arg, int jobnr, int nb_jobs)
1387 {
1388     AudioRNNContext *s = ctx->priv;
1389     ThreadData *td = arg;
1390     AVFrame *in = td->in;
1391     AVFrame *out = td->out;
1392     const int start = (out->channels * jobnr) / nb_jobs;
1393     const int end = (out->channels * (jobnr+1)) / nb_jobs;
1394
1395     for (int ch = start; ch < end; ch++) {
1396         rnnoise_channel(s, &s->st[ch],
1397                         (float *)out->extended_data[ch],
1398                         (const float *)in->extended_data[ch]);
1399     }
1400
1401     return 0;
1402 }
1403
1404 static int filter_frame(AVFilterLink *inlink, AVFrame *in)
1405 {
1406     AVFilterContext *ctx = inlink->dst;
1407     AVFilterLink *outlink = ctx->outputs[0];
1408     AVFrame *out = NULL;
1409     ThreadData td;
1410
1411     out = ff_get_audio_buffer(outlink, FRAME_SIZE);
1412     if (!out) {
1413         av_frame_free(&in);
1414         return AVERROR(ENOMEM);
1415     }
1416     out->pts = in->pts;
1417
1418     td.in = in; td.out = out;
1419     ctx->internal->execute(ctx, rnnoise_channels, &td, NULL, FFMIN(outlink->channels,
1420                                                                    ff_filter_get_nb_threads(ctx)));
1421
1422     av_frame_free(&in);
1423     return ff_filter_frame(outlink, out);
1424 }
1425
1426 static int activate(AVFilterContext *ctx)
1427 {
1428     AVFilterLink *inlink = ctx->inputs[0];
1429     AVFilterLink *outlink = ctx->outputs[0];
1430     AVFrame *in = NULL;
1431     int ret;
1432
1433     FF_FILTER_FORWARD_STATUS_BACK(outlink, inlink);
1434
1435     ret = ff_inlink_consume_samples(inlink, FRAME_SIZE, FRAME_SIZE, &in);
1436     if (ret < 0)
1437         return ret;
1438
1439     if (ret > 0)
1440         return filter_frame(inlink, in);
1441
1442     FF_FILTER_FORWARD_STATUS(inlink, outlink);
1443     FF_FILTER_FORWARD_WANTED(outlink, inlink);
1444
1445     return FFERROR_NOT_READY;
1446 }
1447
1448 static av_cold int init(AVFilterContext *ctx)
1449 {
1450     AudioRNNContext *s = ctx->priv;
1451     FILE *f;
1452
1453     s->fdsp = avpriv_float_dsp_alloc(0);
1454     if (!s->fdsp)
1455         return AVERROR(ENOMEM);
1456
1457     if (!s->model_name)
1458         return AVERROR(EINVAL);
1459     f = av_fopen_utf8(s->model_name, "r");
1460     if (!f)
1461         return AVERROR(EINVAL);
1462
1463     s->model = rnnoise_model_from_file(f);
1464     fclose(f);
1465     if (!s->model)
1466         return AVERROR(EINVAL);
1467
1468     for (int i = 0; i < FRAME_SIZE; i++) {
1469         s->window[i] = sin(.5*M_PI*sin(.5*M_PI*(i+.5)/FRAME_SIZE) * sin(.5*M_PI*(i+.5)/FRAME_SIZE));
1470         s->window[WINDOW_SIZE - 1 - i] = s->window[i];
1471     }
1472
1473     for (int i = 0; i < NB_BANDS; i++) {
1474         for (int j = 0; j < NB_BANDS; j++) {
1475             s->dct_table[j][i] = cosf((i + .5f) * j * M_PI / NB_BANDS);
1476             if (j == 0)
1477                 s->dct_table[j][i] *= sqrtf(.5);
1478         }
1479     }
1480
1481     return 0;
1482 }
1483
1484 static av_cold void uninit(AVFilterContext *ctx)
1485 {
1486     AudioRNNContext *s = ctx->priv;
1487
1488     av_freep(&s->fdsp);
1489     rnnoise_model_free(s->model);
1490     s->model = NULL;
1491
1492     if (s->st) {
1493         for (int ch = 0; ch < s->channels; ch++) {
1494             av_freep(&s->st[ch].rnn.vad_gru_state);
1495             av_freep(&s->st[ch].rnn.noise_gru_state);
1496             av_freep(&s->st[ch].rnn.denoise_gru_state);
1497             av_tx_uninit(&s->st[ch].tx);
1498             av_tx_uninit(&s->st[ch].txi);
1499         }
1500     }
1501     av_freep(&s->st);
1502 }
1503
1504 static const AVFilterPad inputs[] = {
1505     {
1506         .name         = "default",
1507         .type         = AVMEDIA_TYPE_AUDIO,
1508         .config_props = config_input,
1509     },
1510     { NULL }
1511 };
1512
1513 static const AVFilterPad outputs[] = {
1514     {
1515         .name          = "default",
1516         .type          = AVMEDIA_TYPE_AUDIO,
1517     },
1518     { NULL }
1519 };
1520
1521 #define OFFSET(x) offsetof(AudioRNNContext, x)
1522 #define AF AV_OPT_FLAG_AUDIO_PARAM|AV_OPT_FLAG_FILTERING_PARAM
1523
1524 static const AVOption arnndn_options[] = {
1525     { "model", "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1526     { "m",     "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1527     { NULL }
1528 };
1529
1530 AVFILTER_DEFINE_CLASS(arnndn);
1531
1532 AVFilter ff_af_arnndn = {
1533     .name          = "arnndn",
1534     .description   = NULL_IF_CONFIG_SMALL("Reduce noise from speech using Recurrent Neural Networks."),
1535     .query_formats = query_formats,
1536     .priv_size     = sizeof(AudioRNNContext),
1537     .priv_class    = &arnndn_class,
1538     .activate      = activate,
1539     .init          = init,
1540     .uninit        = uninit,
1541     .inputs        = inputs,
1542     .outputs       = outputs,
1543     .flags         = AVFILTER_FLAG_SLICE_THREADS,
1544 };