]> git.sesse.net Git - ffmpeg/blob - libavfilter/af_arnndn.c
avfilter/af_arnndn: add mix option
[ffmpeg] / libavfilter / af_arnndn.c
1 /*
2  * Copyright (c) 2018 Gregor Richards
3  * Copyright (c) 2017 Mozilla
4  * Copyright (c) 2005-2009 Xiph.Org Foundation
5  * Copyright (c) 2007-2008 CSIRO
6  * Copyright (c) 2008-2011 Octasic Inc.
7  * Copyright (c) Jean-Marc Valin
8  * Copyright (c) 2019 Paul B Mahol
9  *
10  * Redistribution and use in source and binary forms, with or without
11  * modification, are permitted provided that the following conditions
12  * are met:
13  *
14  * - Redistributions of source code must retain the above copyright
15  *   notice, this list of conditions and the following disclaimer.
16  *
17  * - Redistributions in binary form must reproduce the above copyright
18  *   notice, this list of conditions and the following disclaimer in the
19  *   documentation and/or other materials provided with the distribution.
20  *
21  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22  * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24  * A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25  * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26  * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27  * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29  * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30  * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32  */
33
34 #include <float.h>
35
36 #include "libavutil/avassert.h"
37 #include "libavutil/avstring.h"
38 #include "libavutil/float_dsp.h"
39 #include "libavutil/opt.h"
40 #include "libavutil/tx.h"
41 #include "avfilter.h"
42 #include "audio.h"
43 #include "filters.h"
44 #include "formats.h"
45
46 #define FRAME_SIZE_SHIFT 2
47 #define FRAME_SIZE (120<<FRAME_SIZE_SHIFT)
48 #define WINDOW_SIZE (2*FRAME_SIZE)
49 #define FREQ_SIZE (FRAME_SIZE + 1)
50
51 #define PITCH_MIN_PERIOD 60
52 #define PITCH_MAX_PERIOD 768
53 #define PITCH_FRAME_SIZE 960
54 #define PITCH_BUF_SIZE (PITCH_MAX_PERIOD+PITCH_FRAME_SIZE)
55
56 #define SQUARE(x) ((x)*(x))
57
58 #define NB_BANDS 22
59
60 #define CEPS_MEM 8
61 #define NB_DELTA_CEPS 6
62
63 #define NB_FEATURES (NB_BANDS+3*NB_DELTA_CEPS+2)
64
65 #define WEIGHTS_SCALE (1.f/256)
66
67 #define MAX_NEURONS 128
68
69 #define ACTIVATION_TANH    0
70 #define ACTIVATION_SIGMOID 1
71 #define ACTIVATION_RELU    2
72
73 #define Q15ONE 1.0f
74
75 typedef struct DenseLayer {
76     const float *bias;
77     const float *input_weights;
78     int nb_inputs;
79     int nb_neurons;
80     int activation;
81 } DenseLayer;
82
83 typedef struct GRULayer {
84     const float *bias;
85     const float *input_weights;
86     const float *recurrent_weights;
87     int nb_inputs;
88     int nb_neurons;
89     int activation;
90 } GRULayer;
91
92 typedef struct RNNModel {
93     int input_dense_size;
94     const DenseLayer *input_dense;
95
96     int vad_gru_size;
97     const GRULayer *vad_gru;
98
99     int noise_gru_size;
100     const GRULayer *noise_gru;
101
102     int denoise_gru_size;
103     const GRULayer *denoise_gru;
104
105     int denoise_output_size;
106     const DenseLayer *denoise_output;
107
108     int vad_output_size;
109     const DenseLayer *vad_output;
110 } RNNModel;
111
112 typedef struct RNNState {
113     float *vad_gru_state;
114     float *noise_gru_state;
115     float *denoise_gru_state;
116     RNNModel *model;
117 } RNNState;
118
119 typedef struct DenoiseState {
120     float analysis_mem[FRAME_SIZE];
121     float cepstral_mem[CEPS_MEM][NB_BANDS];
122     int memid;
123     DECLARE_ALIGNED(32, float, synthesis_mem)[FRAME_SIZE];
124     float pitch_buf[PITCH_BUF_SIZE];
125     float pitch_enh_buf[PITCH_BUF_SIZE];
126     float last_gain;
127     int last_period;
128     float mem_hp_x[2];
129     float lastg[NB_BANDS];
130     float history[FRAME_SIZE];
131     RNNState rnn;
132     AVTXContext *tx, *txi;
133     av_tx_fn tx_fn, txi_fn;
134 } DenoiseState;
135
136 typedef struct AudioRNNContext {
137     const AVClass *class;
138
139     char *model_name;
140     float mix;
141
142     int channels;
143     DenoiseState *st;
144
145     DECLARE_ALIGNED(32, float, window)[WINDOW_SIZE];
146     DECLARE_ALIGNED(32, float, dct_table)[FFALIGN(NB_BANDS, 4)][FFALIGN(NB_BANDS, 4)];
147
148     RNNModel *model;
149
150     AVFloatDSPContext *fdsp;
151 } AudioRNNContext;
152
153 #define F_ACTIVATION_TANH       0
154 #define F_ACTIVATION_SIGMOID    1
155 #define F_ACTIVATION_RELU       2
156
157 static void rnnoise_model_free(RNNModel *model)
158 {
159 #define FREE_MAYBE(ptr) do { if (ptr) free(ptr); } while (0)
160 #define FREE_DENSE(name) do { \
161     if (model->name) { \
162         av_free((void *) model->name->input_weights); \
163         av_free((void *) model->name->bias); \
164         av_free((void *) model->name); \
165     } \
166     } while (0)
167 #define FREE_GRU(name) do { \
168     if (model->name) { \
169         av_free((void *) model->name->input_weights); \
170         av_free((void *) model->name->recurrent_weights); \
171         av_free((void *) model->name->bias); \
172         av_free((void *) model->name); \
173     } \
174     } while (0)
175
176     if (!model)
177         return;
178     FREE_DENSE(input_dense);
179     FREE_GRU(vad_gru);
180     FREE_GRU(noise_gru);
181     FREE_GRU(denoise_gru);
182     FREE_DENSE(denoise_output);
183     FREE_DENSE(vad_output);
184     av_free(model);
185 }
186
187 static RNNModel *rnnoise_model_from_file(FILE *f)
188 {
189     RNNModel *ret;
190     DenseLayer *input_dense;
191     GRULayer *vad_gru;
192     GRULayer *noise_gru;
193     GRULayer *denoise_gru;
194     DenseLayer *denoise_output;
195     DenseLayer *vad_output;
196     int in;
197
198     if (fscanf(f, "rnnoise-nu model file version %d\n", &in) != 1 || in != 1)
199         return NULL;
200
201     ret = av_calloc(1, sizeof(RNNModel));
202     if (!ret)
203         return NULL;
204
205 #define ALLOC_LAYER(type, name) \
206     name = av_calloc(1, sizeof(type)); \
207     if (!name) { \
208         rnnoise_model_free(ret); \
209         return NULL; \
210     } \
211     ret->name = name
212
213     ALLOC_LAYER(DenseLayer, input_dense);
214     ALLOC_LAYER(GRULayer, vad_gru);
215     ALLOC_LAYER(GRULayer, noise_gru);
216     ALLOC_LAYER(GRULayer, denoise_gru);
217     ALLOC_LAYER(DenseLayer, denoise_output);
218     ALLOC_LAYER(DenseLayer, vad_output);
219
220 #define INPUT_VAL(name) do { \
221     if (fscanf(f, "%d", &in) != 1 || in < 0 || in > 128) { \
222         rnnoise_model_free(ret); \
223         return NULL; \
224     } \
225     name = in; \
226     } while (0)
227
228 #define INPUT_ACTIVATION(name) do { \
229     int activation; \
230     INPUT_VAL(activation); \
231     switch (activation) { \
232     case F_ACTIVATION_SIGMOID: \
233         name = ACTIVATION_SIGMOID; \
234         break; \
235     case F_ACTIVATION_RELU: \
236         name = ACTIVATION_RELU; \
237         break; \
238     default: \
239         name = ACTIVATION_TANH; \
240     } \
241     } while (0)
242
243 #define INPUT_ARRAY(name, len) do { \
244     float *values = av_calloc((len), sizeof(float)); \
245     if (!values) { \
246         rnnoise_model_free(ret); \
247         return NULL; \
248     } \
249     name = values; \
250     for (int i = 0; i < (len); i++) { \
251         if (fscanf(f, "%d", &in) != 1) { \
252             rnnoise_model_free(ret); \
253             return NULL; \
254         } \
255         values[i] = in; \
256     } \
257     } while (0)
258
259 #define INPUT_ARRAY3(name, len0, len1, len2) do { \
260     float *values = av_calloc(FFALIGN((len0), 4) * FFALIGN((len1), 4) * (len2), sizeof(float)); \
261     if (!values) { \
262         rnnoise_model_free(ret); \
263         return NULL; \
264     } \
265     name = values; \
266     for (int k = 0; k < (len0); k++) { \
267         for (int i = 0; i < (len2); i++) { \
268             for (int j = 0; j < (len1); j++) { \
269                 if (fscanf(f, "%d", &in) != 1) { \
270                     rnnoise_model_free(ret); \
271                     return NULL; \
272                 } \
273                 values[j * (len2) * FFALIGN((len0), 4) + i * FFALIGN((len0), 4) + k] = in; \
274             } \
275         } \
276     } \
277     } while (0)
278
279 #define INPUT_DENSE(name) do { \
280     INPUT_VAL(name->nb_inputs); \
281     INPUT_VAL(name->nb_neurons); \
282     ret->name ## _size = name->nb_neurons; \
283     INPUT_ACTIVATION(name->activation); \
284     INPUT_ARRAY(name->input_weights, name->nb_inputs * name->nb_neurons); \
285     INPUT_ARRAY(name->bias, name->nb_neurons); \
286     } while (0)
287
288 #define INPUT_GRU(name) do { \
289     INPUT_VAL(name->nb_inputs); \
290     INPUT_VAL(name->nb_neurons); \
291     ret->name ## _size = name->nb_neurons; \
292     INPUT_ACTIVATION(name->activation); \
293     INPUT_ARRAY3(name->input_weights, name->nb_inputs, name->nb_neurons, 3); \
294     INPUT_ARRAY3(name->recurrent_weights, name->nb_neurons, name->nb_neurons, 3); \
295     INPUT_ARRAY(name->bias, name->nb_neurons * 3); \
296     } while (0)
297
298     INPUT_DENSE(input_dense);
299     INPUT_GRU(vad_gru);
300     INPUT_GRU(noise_gru);
301     INPUT_GRU(denoise_gru);
302     INPUT_DENSE(denoise_output);
303     INPUT_DENSE(vad_output);
304
305     if (vad_output->nb_neurons != 1) {
306         rnnoise_model_free(ret);
307         return NULL;
308     }
309
310     return ret;
311 }
312
313 static int query_formats(AVFilterContext *ctx)
314 {
315     AVFilterFormats *formats = NULL;
316     AVFilterChannelLayouts *layouts = NULL;
317     static const enum AVSampleFormat sample_fmts[] = {
318         AV_SAMPLE_FMT_FLTP,
319         AV_SAMPLE_FMT_NONE
320     };
321     int ret, sample_rates[] = { 48000, -1 };
322
323     formats = ff_make_format_list(sample_fmts);
324     if (!formats)
325         return AVERROR(ENOMEM);
326     ret = ff_set_common_formats(ctx, formats);
327     if (ret < 0)
328         return ret;
329
330     layouts = ff_all_channel_counts();
331     if (!layouts)
332         return AVERROR(ENOMEM);
333
334     ret = ff_set_common_channel_layouts(ctx, layouts);
335     if (ret < 0)
336         return ret;
337
338     formats = ff_make_format_list(sample_rates);
339     if (!formats)
340         return AVERROR(ENOMEM);
341     return ff_set_common_samplerates(ctx, formats);
342 }
343
344 static int config_input(AVFilterLink *inlink)
345 {
346     AVFilterContext *ctx = inlink->dst;
347     AudioRNNContext *s = ctx->priv;
348     int ret;
349
350     s->channels = inlink->channels;
351
352     s->st = av_calloc(s->channels, sizeof(DenoiseState));
353     if (!s->st)
354         return AVERROR(ENOMEM);
355
356     for (int i = 0; i < s->channels; i++) {
357         DenoiseState *st = &s->st[i];
358
359         st->rnn.model = s->model;
360         st->rnn.vad_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->vad_gru_size, 16));
361         st->rnn.noise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->noise_gru_size, 16));
362         st->rnn.denoise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->denoise_gru_size, 16));
363         if (!st->rnn.vad_gru_state ||
364             !st->rnn.noise_gru_state ||
365             !st->rnn.denoise_gru_state)
366             return AVERROR(ENOMEM);
367
368         ret = av_tx_init(&st->tx, &st->tx_fn, AV_TX_FLOAT_FFT, 0, WINDOW_SIZE, NULL, 0);
369         if (ret < 0)
370             return ret;
371
372         ret = av_tx_init(&st->txi, &st->txi_fn, AV_TX_FLOAT_FFT, 1, WINDOW_SIZE, NULL, 0);
373         if (ret < 0)
374             return ret;
375     }
376
377     return 0;
378 }
379
380 static void biquad(float *y, float mem[2], const float *x,
381                    const float *b, const float *a, int N)
382 {
383     for (int i = 0; i < N; i++) {
384         float xi, yi;
385
386         xi = x[i];
387         yi = x[i] + mem[0];
388         mem[0] = mem[1] + (b[0]*xi - a[0]*yi);
389         mem[1] = (b[1]*xi - a[1]*yi);
390         y[i] = yi;
391     }
392 }
393
394 #define RNN_MOVE(dst, src, n) (memmove((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
395 #define RNN_CLEAR(dst, n) (memset((dst), 0, (n)*sizeof(*(dst))))
396 #define RNN_COPY(dst, src, n) (memcpy((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
397
398 static void forward_transform(DenoiseState *st, AVComplexFloat *out, const float *in)
399 {
400     AVComplexFloat x[WINDOW_SIZE];
401     AVComplexFloat y[WINDOW_SIZE];
402
403     for (int i = 0; i < WINDOW_SIZE; i++) {
404         x[i].re = in[i];
405         x[i].im = 0;
406     }
407
408     st->tx_fn(st->tx, y, x, sizeof(float));
409
410     RNN_COPY(out, y, FREQ_SIZE);
411 }
412
413 static void inverse_transform(DenoiseState *st, float *out, const AVComplexFloat *in)
414 {
415     AVComplexFloat x[WINDOW_SIZE];
416     AVComplexFloat y[WINDOW_SIZE];
417
418     RNN_COPY(x, in, FREQ_SIZE);
419
420     for (int i = FREQ_SIZE; i < WINDOW_SIZE; i++) {
421         x[i].re =  x[WINDOW_SIZE - i].re;
422         x[i].im = -x[WINDOW_SIZE - i].im;
423     }
424
425     st->txi_fn(st->txi, y, x, sizeof(float));
426
427     for (int i = 0; i < WINDOW_SIZE; i++)
428         out[i] = y[i].re / WINDOW_SIZE;
429 }
430
431 static const uint8_t eband5ms[] = {
432 /*0  200 400 600 800  1k 1.2 1.4 1.6  2k 2.4 2.8 3.2  4k 4.8 5.6 6.8  8k 9.6 12k 15.6 20k*/
433   0,  1,  2,  3,  4,   5, 6,  7,  8,  10, 12, 14, 16, 20, 24, 28, 34, 40, 48, 60, 78, 100
434 };
435
436 static void compute_band_energy(float *bandE, const AVComplexFloat *X)
437 {
438     float sum[NB_BANDS] = {0};
439
440     for (int i = 0; i < NB_BANDS - 1; i++) {
441         int band_size;
442
443         band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
444         for (int j = 0; j < band_size; j++) {
445             float tmp, frac = (float)j / band_size;
446
447             tmp         = SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].re);
448             tmp        += SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].im);
449             sum[i]     += (1.f - frac) * tmp;
450             sum[i + 1] +=        frac  * tmp;
451         }
452     }
453
454     sum[0] *= 2;
455     sum[NB_BANDS - 1] *= 2;
456
457     for (int i = 0; i < NB_BANDS; i++)
458         bandE[i] = sum[i];
459 }
460
461 static void compute_band_corr(float *bandE, const AVComplexFloat *X, const AVComplexFloat *P)
462 {
463     float sum[NB_BANDS] = { 0 };
464
465     for (int i = 0; i < NB_BANDS - 1; i++) {
466         int band_size;
467
468         band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
469         for (int j = 0; j < band_size; j++) {
470             float tmp, frac = (float)j / band_size;
471
472             tmp  = X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re;
473             tmp += X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im;
474             sum[i]     += (1 - frac) * tmp;
475             sum[i + 1] +=      frac  * tmp;
476         }
477     }
478
479     sum[0] *= 2;
480     sum[NB_BANDS-1] *= 2;
481
482     for (int i = 0; i < NB_BANDS; i++)
483         bandE[i] = sum[i];
484 }
485
486 static void frame_analysis(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, float *Ex, const float *in)
487 {
488     LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
489
490     RNN_COPY(x, st->analysis_mem, FRAME_SIZE);
491     RNN_COPY(x + FRAME_SIZE, in, FRAME_SIZE);
492     RNN_COPY(st->analysis_mem, in, FRAME_SIZE);
493     s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
494     forward_transform(st, X, x);
495     compute_band_energy(Ex, X);
496 }
497
498 static void frame_synthesis(AudioRNNContext *s, DenoiseState *st, float *out, const AVComplexFloat *y)
499 {
500     LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
501     const float *src = st->history;
502     const float mix = s->mix;
503     const float imix = 1.f - FFMAX(mix, 0.f);
504
505     inverse_transform(st, x, y);
506     s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
507     s->fdsp->vector_fmac_scalar(x, st->synthesis_mem, 1.f, FRAME_SIZE);
508     RNN_COPY(out, x, FRAME_SIZE);
509     RNN_COPY(st->synthesis_mem, &x[FRAME_SIZE], FRAME_SIZE);
510
511     for (int n = 0; n < FRAME_SIZE; n++)
512         out[n] = out[n] * mix + src[n] * imix;
513 }
514
515 static inline void xcorr_kernel(const float *x, const float *y, float sum[4], int len)
516 {
517     float y_0, y_1, y_2, y_3 = 0;
518     int j;
519
520     y_0 = *y++;
521     y_1 = *y++;
522     y_2 = *y++;
523
524     for (j = 0; j < len - 3; j += 4) {
525         float tmp;
526
527         tmp = *x++;
528         y_3 = *y++;
529         sum[0] += tmp * y_0;
530         sum[1] += tmp * y_1;
531         sum[2] += tmp * y_2;
532         sum[3] += tmp * y_3;
533         tmp = *x++;
534         y_0 = *y++;
535         sum[0] += tmp * y_1;
536         sum[1] += tmp * y_2;
537         sum[2] += tmp * y_3;
538         sum[3] += tmp * y_0;
539         tmp = *x++;
540         y_1 = *y++;
541         sum[0] += tmp * y_2;
542         sum[1] += tmp * y_3;
543         sum[2] += tmp * y_0;
544         sum[3] += tmp * y_1;
545         tmp = *x++;
546         y_2 = *y++;
547         sum[0] += tmp * y_3;
548         sum[1] += tmp * y_0;
549         sum[2] += tmp * y_1;
550         sum[3] += tmp * y_2;
551     }
552
553     if (j++ < len) {
554         float tmp = *x++;
555
556         y_3 = *y++;
557         sum[0] += tmp * y_0;
558         sum[1] += tmp * y_1;
559         sum[2] += tmp * y_2;
560         sum[3] += tmp * y_3;
561     }
562
563     if (j++ < len) {
564         float tmp=*x++;
565
566         y_0 = *y++;
567         sum[0] += tmp * y_1;
568         sum[1] += tmp * y_2;
569         sum[2] += tmp * y_3;
570         sum[3] += tmp * y_0;
571     }
572
573     if (j < len) {
574         float tmp=*x++;
575
576         y_1 = *y++;
577         sum[0] += tmp * y_2;
578         sum[1] += tmp * y_3;
579         sum[2] += tmp * y_0;
580         sum[3] += tmp * y_1;
581     }
582 }
583
584 static inline float celt_inner_prod(const float *x,
585                                     const float *y, int N)
586 {
587     float xy = 0.f;
588
589     for (int i = 0; i < N; i++)
590         xy += x[i] * y[i];
591
592     return xy;
593 }
594
595 static void celt_pitch_xcorr(const float *x, const float *y,
596                              float *xcorr, int len, int max_pitch)
597 {
598     int i;
599
600     for (i = 0; i < max_pitch - 3; i += 4) {
601         float sum[4] = { 0, 0, 0, 0};
602
603         xcorr_kernel(x, y + i, sum, len);
604
605         xcorr[i]     = sum[0];
606         xcorr[i + 1] = sum[1];
607         xcorr[i + 2] = sum[2];
608         xcorr[i + 3] = sum[3];
609     }
610     /* In case max_pitch isn't a multiple of 4, do non-unrolled version. */
611     for (; i < max_pitch; i++) {
612         xcorr[i] = celt_inner_prod(x, y + i, len);
613     }
614 }
615
616 static int celt_autocorr(const float *x,   /*  in: [0...n-1] samples x   */
617                          float       *ac,  /* out: [0...lag-1] ac values */
618                          const float *window,
619                          int          overlap,
620                          int          lag,
621                          int          n)
622 {
623     int fastN = n - lag;
624     int shift;
625     const float *xptr;
626     float xx[PITCH_BUF_SIZE>>1];
627
628     if (overlap == 0) {
629         xptr = x;
630     } else {
631         for (int i = 0; i < n; i++)
632             xx[i] = x[i];
633         for (int i = 0; i < overlap; i++) {
634             xx[i] = x[i] * window[i];
635             xx[n-i-1] = x[n-i-1] * window[i];
636         }
637         xptr = xx;
638     }
639
640     shift = 0;
641     celt_pitch_xcorr(xptr, xptr, ac, fastN, lag+1);
642
643     for (int k = 0; k <= lag; k++) {
644         float d = 0.f;
645
646         for (int i = k + fastN; i < n; i++)
647             d += xptr[i] * xptr[i-k];
648         ac[k] += d;
649     }
650
651     return shift;
652 }
653
654 static void celt_lpc(float *lpc, /* out: [0...p-1] LPC coefficients      */
655                 const float *ac,   /* in:  [0...p] autocorrelation values  */
656                           int p)
657 {
658     float r, error = ac[0];
659
660     RNN_CLEAR(lpc, p);
661     if (ac[0] != 0) {
662         for (int i = 0; i < p; i++) {
663             /* Sum up this iteration's reflection coefficient */
664             float rr = 0;
665             for (int j = 0; j < i; j++)
666                 rr += (lpc[j] * ac[i - j]);
667             rr += ac[i + 1];
668             r = -rr/error;
669             /*  Update LPC coefficients and total error */
670             lpc[i] = r;
671             for (int j = 0; j < (i + 1) >> 1; j++) {
672                 float tmp1, tmp2;
673                 tmp1 = lpc[j];
674                 tmp2 = lpc[i-1-j];
675                 lpc[j]     = tmp1 + (r*tmp2);
676                 lpc[i-1-j] = tmp2 + (r*tmp1);
677             }
678
679             error = error - (r * r *error);
680             /* Bail out once we get 30 dB gain */
681             if (error < .001f * ac[0])
682                 break;
683         }
684     }
685 }
686
687 static void celt_fir5(const float *x,
688                       const float *num,
689                       float *y,
690                       int N,
691                       float *mem)
692 {
693     float num0, num1, num2, num3, num4;
694     float mem0, mem1, mem2, mem3, mem4;
695
696     num0 = num[0];
697     num1 = num[1];
698     num2 = num[2];
699     num3 = num[3];
700     num4 = num[4];
701     mem0 = mem[0];
702     mem1 = mem[1];
703     mem2 = mem[2];
704     mem3 = mem[3];
705     mem4 = mem[4];
706
707     for (int i = 0; i < N; i++) {
708         float sum = x[i];
709
710         sum += (num0*mem0);
711         sum += (num1*mem1);
712         sum += (num2*mem2);
713         sum += (num3*mem3);
714         sum += (num4*mem4);
715         mem4 = mem3;
716         mem3 = mem2;
717         mem2 = mem1;
718         mem1 = mem0;
719         mem0 = x[i];
720         y[i] = sum;
721     }
722
723     mem[0] = mem0;
724     mem[1] = mem1;
725     mem[2] = mem2;
726     mem[3] = mem3;
727     mem[4] = mem4;
728 }
729
730 static void pitch_downsample(float *x[], float *x_lp,
731                              int len, int C)
732 {
733     float ac[5];
734     float tmp=Q15ONE;
735     float lpc[4], mem[5]={0,0,0,0,0};
736     float lpc2[5];
737     float c1 = .8f;
738
739     for (int i = 1; i < len >> 1; i++)
740         x_lp[i] = .5f * (.5f * (x[0][(2*i-1)]+x[0][(2*i+1)])+x[0][2*i]);
741     x_lp[0] = .5f * (.5f * (x[0][1])+x[0][0]);
742     if (C==2) {
743         for (int i = 1; i < len >> 1; i++)
744             x_lp[i] += (.5f * (.5f * (x[1][(2*i-1)]+x[1][(2*i+1)])+x[1][2*i]));
745         x_lp[0] += .5f * (.5f * (x[1][1])+x[1][0]);
746     }
747
748     celt_autocorr(x_lp, ac, NULL, 0, 4, len>>1);
749
750     /* Noise floor -40 dB */
751     ac[0] *= 1.0001f;
752     /* Lag windowing */
753     for (int i = 1; i <= 4; i++) {
754         /*ac[i] *= exp(-.5*(2*M_PI*.002*i)*(2*M_PI*.002*i));*/
755         ac[i] -= ac[i]*(.008f*i)*(.008f*i);
756     }
757
758     celt_lpc(lpc, ac, 4);
759     for (int i = 0; i < 4; i++) {
760         tmp = .9f * tmp;
761         lpc[i] = (lpc[i] * tmp);
762     }
763     /* Add a zero */
764     lpc2[0] = lpc[0] + .8f;
765     lpc2[1] = lpc[1] + (c1 * lpc[0]);
766     lpc2[2] = lpc[2] + (c1 * lpc[1]);
767     lpc2[3] = lpc[3] + (c1 * lpc[2]);
768     lpc2[4] = (c1 * lpc[3]);
769     celt_fir5(x_lp, lpc2, x_lp, len>>1, mem);
770 }
771
772 static inline void dual_inner_prod(const float *x, const float *y01, const float *y02,
773                                    int N, float *xy1, float *xy2)
774 {
775     float xy01 = 0, xy02 = 0;
776
777     for (int i = 0; i < N; i++) {
778         xy01 += (x[i] * y01[i]);
779         xy02 += (x[i] * y02[i]);
780     }
781
782     *xy1 = xy01;
783     *xy2 = xy02;
784 }
785
786 static float compute_pitch_gain(float xy, float xx, float yy)
787 {
788     return xy / sqrtf(1.f + xx * yy);
789 }
790
791 static const uint8_t second_check[16] = {0, 0, 3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2};
792 static float remove_doubling(float *x, int maxperiod, int minperiod, int N,
793                              int *T0_, int prev_period, float prev_gain)
794 {
795     int k, i, T, T0;
796     float g, g0;
797     float pg;
798     float xy,xx,yy,xy2;
799     float xcorr[3];
800     float best_xy, best_yy;
801     int offset;
802     int minperiod0;
803     float yy_lookup[PITCH_MAX_PERIOD+1];
804
805     minperiod0 = minperiod;
806     maxperiod /= 2;
807     minperiod /= 2;
808     *T0_ /= 2;
809     prev_period /= 2;
810     N /= 2;
811     x += maxperiod;
812     if (*T0_>=maxperiod)
813         *T0_=maxperiod-1;
814
815     T = T0 = *T0_;
816     dual_inner_prod(x, x, x-T0, N, &xx, &xy);
817     yy_lookup[0] = xx;
818     yy=xx;
819     for (i = 1; i <= maxperiod; i++) {
820         yy = yy+(x[-i] * x[-i])-(x[N-i] * x[N-i]);
821         yy_lookup[i] = FFMAX(0, yy);
822     }
823     yy = yy_lookup[T0];
824     best_xy = xy;
825     best_yy = yy;
826     g = g0 = compute_pitch_gain(xy, xx, yy);
827     /* Look for any pitch at T/k */
828     for (k = 2; k <= 15; k++) {
829         int T1, T1b;
830         float g1;
831         float cont=0;
832         float thresh;
833         T1 = (2*T0+k)/(2*k);
834         if (T1 < minperiod)
835             break;
836         /* Look for another strong correlation at T1b */
837         if (k==2)
838         {
839             if (T1+T0>maxperiod)
840                 T1b = T0;
841             else
842                 T1b = T0+T1;
843         } else
844         {
845             T1b = (2*second_check[k]*T0+k)/(2*k);
846         }
847         dual_inner_prod(x, &x[-T1], &x[-T1b], N, &xy, &xy2);
848         xy = .5f * (xy + xy2);
849         yy = .5f * (yy_lookup[T1] + yy_lookup[T1b]);
850         g1 = compute_pitch_gain(xy, xx, yy);
851         if (FFABS(T1-prev_period)<=1)
852             cont = prev_gain;
853         else if (FFABS(T1-prev_period)<=2 && 5 * k * k < T0)
854             cont = prev_gain * .5f;
855         else
856             cont = 0;
857         thresh = FFMAX(.3f, (.7f * g0) - cont);
858         /* Bias against very high pitch (very short period) to avoid false-positives
859            due to short-term correlation */
860         if (T1<3*minperiod)
861             thresh = FFMAX(.4f, (.85f * g0) - cont);
862         else if (T1<2*minperiod)
863             thresh = FFMAX(.5f, (.9f * g0) - cont);
864         if (g1 > thresh)
865         {
866             best_xy = xy;
867             best_yy = yy;
868             T = T1;
869             g = g1;
870         }
871     }
872     best_xy = FFMAX(0, best_xy);
873     if (best_yy <= best_xy)
874         pg = Q15ONE;
875     else
876         pg = best_xy/(best_yy + 1);
877
878     for (k = 0; k < 3; k++)
879         xcorr[k] = celt_inner_prod(x, x-(T+k-1), N);
880     if ((xcorr[2]-xcorr[0]) > .7f * (xcorr[1]-xcorr[0]))
881         offset = 1;
882     else if ((xcorr[0]-xcorr[2]) > (.7f * (xcorr[1] - xcorr[2])))
883         offset = -1;
884     else
885         offset = 0;
886     if (pg > g)
887         pg = g;
888     *T0_ = 2*T+offset;
889
890     if (*T0_<minperiod0)
891         *T0_=minperiod0;
892     return pg;
893 }
894
895 static void find_best_pitch(float *xcorr, float *y, int len,
896                             int max_pitch, int *best_pitch)
897 {
898     float best_num[2];
899     float best_den[2];
900     float Syy = 1.f;
901
902     best_num[0] = -1;
903     best_num[1] = -1;
904     best_den[0] = 0;
905     best_den[1] = 0;
906     best_pitch[0] = 0;
907     best_pitch[1] = 1;
908
909     for (int j = 0; j < len; j++)
910         Syy += y[j] * y[j];
911
912     for (int i = 0; i < max_pitch; i++) {
913         if (xcorr[i]>0) {
914             float num;
915             float xcorr16;
916
917             xcorr16 = xcorr[i];
918             /* Considering the range of xcorr16, this should avoid both underflows
919                and overflows (inf) when squaring xcorr16 */
920             xcorr16 *= 1e-12f;
921             num = xcorr16 * xcorr16;
922             if ((num * best_den[1]) > (best_num[1] * Syy)) {
923                 if ((num * best_den[0]) > (best_num[0] * Syy)) {
924                     best_num[1] = best_num[0];
925                     best_den[1] = best_den[0];
926                     best_pitch[1] = best_pitch[0];
927                     best_num[0] = num;
928                     best_den[0] = Syy;
929                     best_pitch[0] = i;
930                 } else {
931                     best_num[1] = num;
932                     best_den[1] = Syy;
933                     best_pitch[1] = i;
934                 }
935             }
936         }
937         Syy += y[i+len]*y[i+len] - y[i] * y[i];
938         Syy = FFMAX(1, Syy);
939     }
940 }
941
942 static void pitch_search(const float *x_lp, float *y,
943                          int len, int max_pitch, int *pitch)
944 {
945     int lag;
946     int best_pitch[2]={0,0};
947     int offset;
948
949     float x_lp4[WINDOW_SIZE];
950     float y_lp4[WINDOW_SIZE];
951     float xcorr[WINDOW_SIZE];
952
953     lag = len+max_pitch;
954
955     /* Downsample by 2 again */
956     for (int j = 0; j < len >> 2; j++)
957         x_lp4[j] = x_lp[2*j];
958     for (int j = 0; j < lag >> 2; j++)
959         y_lp4[j] = y[2*j];
960
961     /* Coarse search with 4x decimation */
962
963     celt_pitch_xcorr(x_lp4, y_lp4, xcorr, len>>2, max_pitch>>2);
964
965     find_best_pitch(xcorr, y_lp4, len>>2, max_pitch>>2, best_pitch);
966
967     /* Finer search with 2x decimation */
968     for (int i = 0; i < max_pitch >> 1; i++) {
969         float sum;
970         xcorr[i] = 0;
971         if (FFABS(i-2*best_pitch[0])>2 && FFABS(i-2*best_pitch[1])>2)
972             continue;
973         sum = celt_inner_prod(x_lp, y+i, len>>1);
974         xcorr[i] = FFMAX(-1, sum);
975     }
976
977     find_best_pitch(xcorr, y, len>>1, max_pitch>>1, best_pitch);
978
979     /* Refine by pseudo-interpolation */
980     if (best_pitch[0] > 0 && best_pitch[0] < (max_pitch >> 1) - 1) {
981         float a, b, c;
982
983         a = xcorr[best_pitch[0] - 1];
984         b = xcorr[best_pitch[0]];
985         c = xcorr[best_pitch[0] + 1];
986         if (c - a > .7f * (b - a))
987             offset = 1;
988         else if (a - c > .7f * (b-c))
989             offset = -1;
990         else
991             offset = 0;
992     } else {
993         offset = 0;
994     }
995
996     *pitch = 2 * best_pitch[0] - offset;
997 }
998
999 static void dct(AudioRNNContext *s, float *out, const float *in)
1000 {
1001     for (int i = 0; i < NB_BANDS; i++) {
1002         float sum;
1003
1004         sum = s->fdsp->scalarproduct_float(in, s->dct_table[i], FFALIGN(NB_BANDS, 4));
1005         out[i] = sum * sqrtf(2.f / 22);
1006     }
1007 }
1008
1009 static int compute_frame_features(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, AVComplexFloat *P,
1010                                   float *Ex, float *Ep, float *Exp, float *features, const float *in)
1011 {
1012     float E = 0;
1013     float *ceps_0, *ceps_1, *ceps_2;
1014     float spec_variability = 0;
1015     LOCAL_ALIGNED_32(float, Ly, [NB_BANDS]);
1016     LOCAL_ALIGNED_32(float, p, [WINDOW_SIZE]);
1017     float pitch_buf[PITCH_BUF_SIZE>>1];
1018     int pitch_index;
1019     float gain;
1020     float *(pre[1]);
1021     float tmp[NB_BANDS];
1022     float follow, logMax;
1023
1024     frame_analysis(s, st, X, Ex, in);
1025     RNN_MOVE(st->pitch_buf, &st->pitch_buf[FRAME_SIZE], PITCH_BUF_SIZE-FRAME_SIZE);
1026     RNN_COPY(&st->pitch_buf[PITCH_BUF_SIZE-FRAME_SIZE], in, FRAME_SIZE);
1027     pre[0] = &st->pitch_buf[0];
1028     pitch_downsample(pre, pitch_buf, PITCH_BUF_SIZE, 1);
1029     pitch_search(pitch_buf+(PITCH_MAX_PERIOD>>1), pitch_buf, PITCH_FRAME_SIZE,
1030             PITCH_MAX_PERIOD-3*PITCH_MIN_PERIOD, &pitch_index);
1031     pitch_index = PITCH_MAX_PERIOD-pitch_index;
1032
1033     gain = remove_doubling(pitch_buf, PITCH_MAX_PERIOD, PITCH_MIN_PERIOD,
1034             PITCH_FRAME_SIZE, &pitch_index, st->last_period, st->last_gain);
1035     st->last_period = pitch_index;
1036     st->last_gain = gain;
1037
1038     for (int i = 0; i < WINDOW_SIZE; i++)
1039         p[i] = st->pitch_buf[PITCH_BUF_SIZE-WINDOW_SIZE-pitch_index+i];
1040
1041     s->fdsp->vector_fmul(p, p, s->window, WINDOW_SIZE);
1042     forward_transform(st, P, p);
1043     compute_band_energy(Ep, P);
1044     compute_band_corr(Exp, X, P);
1045
1046     for (int i = 0; i < NB_BANDS; i++)
1047         Exp[i] = Exp[i] / sqrtf(.001f+Ex[i]*Ep[i]);
1048
1049     dct(s, tmp, Exp);
1050
1051     for (int i = 0; i < NB_DELTA_CEPS; i++)
1052         features[NB_BANDS+2*NB_DELTA_CEPS+i] = tmp[i];
1053
1054     features[NB_BANDS+2*NB_DELTA_CEPS] -= 1.3;
1055     features[NB_BANDS+2*NB_DELTA_CEPS+1] -= 0.9;
1056     features[NB_BANDS+3*NB_DELTA_CEPS] = .01*(pitch_index-300);
1057     logMax = -2;
1058     follow = -2;
1059
1060     for (int i = 0; i < NB_BANDS; i++) {
1061         Ly[i] = log10f(1e-2f + Ex[i]);
1062         Ly[i] = FFMAX(logMax-7, FFMAX(follow-1.5, Ly[i]));
1063         logMax = FFMAX(logMax, Ly[i]);
1064         follow = FFMAX(follow-1.5, Ly[i]);
1065         E += Ex[i];
1066     }
1067
1068     if (E < 0.04f) {
1069         /* If there's no audio, avoid messing up the state. */
1070         RNN_CLEAR(features, NB_FEATURES);
1071         return 1;
1072     }
1073
1074     dct(s, features, Ly);
1075     features[0] -= 12;
1076     features[1] -= 4;
1077     ceps_0 = st->cepstral_mem[st->memid];
1078     ceps_1 = (st->memid < 1) ? st->cepstral_mem[CEPS_MEM+st->memid-1] : st->cepstral_mem[st->memid-1];
1079     ceps_2 = (st->memid < 2) ? st->cepstral_mem[CEPS_MEM+st->memid-2] : st->cepstral_mem[st->memid-2];
1080
1081     for (int i = 0; i < NB_BANDS; i++)
1082         ceps_0[i] = features[i];
1083
1084     st->memid++;
1085     for (int i = 0; i < NB_DELTA_CEPS; i++) {
1086         features[i] = ceps_0[i] + ceps_1[i] + ceps_2[i];
1087         features[NB_BANDS+i] = ceps_0[i] - ceps_2[i];
1088         features[NB_BANDS+NB_DELTA_CEPS+i] =  ceps_0[i] - 2*ceps_1[i] + ceps_2[i];
1089     }
1090     /* Spectral variability features. */
1091     if (st->memid == CEPS_MEM)
1092         st->memid = 0;
1093
1094     for (int i = 0; i < CEPS_MEM; i++) {
1095         float mindist = 1e15f;
1096         for (int j = 0; j < CEPS_MEM; j++) {
1097             float dist = 0.f;
1098             for (int k = 0; k < NB_BANDS; k++) {
1099                 float tmp;
1100
1101                 tmp = st->cepstral_mem[i][k] - st->cepstral_mem[j][k];
1102                 dist += tmp*tmp;
1103             }
1104
1105             if (j != i)
1106                 mindist = FFMIN(mindist, dist);
1107         }
1108
1109         spec_variability += mindist;
1110     }
1111
1112     features[NB_BANDS+3*NB_DELTA_CEPS+1] = spec_variability/CEPS_MEM-2.1;
1113
1114     return 0;
1115 }
1116
1117 static void interp_band_gain(float *g, const float *bandE)
1118 {
1119     memset(g, 0, sizeof(*g) * FREQ_SIZE);
1120
1121     for (int i = 0; i < NB_BANDS - 1; i++) {
1122         const int band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
1123
1124         for (int j = 0; j < band_size; j++) {
1125             float frac = (float)j / band_size;
1126
1127             g[(eband5ms[i] << FRAME_SIZE_SHIFT) + j] = (1.f - frac) * bandE[i] + frac * bandE[i + 1];
1128         }
1129     }
1130 }
1131
1132 static void pitch_filter(AVComplexFloat *X, const AVComplexFloat *P, const float *Ex, const float *Ep,
1133                          const float *Exp, const float *g)
1134 {
1135     float newE[NB_BANDS];
1136     float r[NB_BANDS];
1137     float norm[NB_BANDS];
1138     float rf[FREQ_SIZE] = {0};
1139     float normf[FREQ_SIZE]={0};
1140
1141     for (int i = 0; i < NB_BANDS; i++) {
1142         if (Exp[i]>g[i]) r[i] = 1;
1143         else r[i] = SQUARE(Exp[i])*(1-SQUARE(g[i]))/(.001 + SQUARE(g[i])*(1-SQUARE(Exp[i])));
1144         r[i]  = sqrtf(av_clipf(r[i], 0, 1));
1145         r[i] *= sqrtf(Ex[i]/(1e-8+Ep[i]));
1146     }
1147     interp_band_gain(rf, r);
1148     for (int i = 0; i < FREQ_SIZE; i++) {
1149         X[i].re += rf[i]*P[i].re;
1150         X[i].im += rf[i]*P[i].im;
1151     }
1152     compute_band_energy(newE, X);
1153     for (int i = 0; i < NB_BANDS; i++) {
1154         norm[i] = sqrtf(Ex[i] / (1e-8+newE[i]));
1155     }
1156     interp_band_gain(normf, norm);
1157     for (int i = 0; i < FREQ_SIZE; i++) {
1158         X[i].re *= normf[i];
1159         X[i].im *= normf[i];
1160     }
1161 }
1162
1163 static const float tansig_table[201] = {
1164     0.000000f, 0.039979f, 0.079830f, 0.119427f, 0.158649f,
1165     0.197375f, 0.235496f, 0.272905f, 0.309507f, 0.345214f,
1166     0.379949f, 0.413644f, 0.446244f, 0.477700f, 0.507977f,
1167     0.537050f, 0.564900f, 0.591519f, 0.616909f, 0.641077f,
1168     0.664037f, 0.685809f, 0.706419f, 0.725897f, 0.744277f,
1169     0.761594f, 0.777888f, 0.793199f, 0.807569f, 0.821040f,
1170     0.833655f, 0.845456f, 0.856485f, 0.866784f, 0.876393f,
1171     0.885352f, 0.893698f, 0.901468f, 0.908698f, 0.915420f,
1172     0.921669f, 0.927473f, 0.932862f, 0.937863f, 0.942503f,
1173     0.946806f, 0.950795f, 0.954492f, 0.957917f, 0.961090f,
1174     0.964028f, 0.966747f, 0.969265f, 0.971594f, 0.973749f,
1175     0.975743f, 0.977587f, 0.979293f, 0.980869f, 0.982327f,
1176     0.983675f, 0.984921f, 0.986072f, 0.987136f, 0.988119f,
1177     0.989027f, 0.989867f, 0.990642f, 0.991359f, 0.992020f,
1178     0.992631f, 0.993196f, 0.993718f, 0.994199f, 0.994644f,
1179     0.995055f, 0.995434f, 0.995784f, 0.996108f, 0.996407f,
1180     0.996682f, 0.996937f, 0.997172f, 0.997389f, 0.997590f,
1181     0.997775f, 0.997946f, 0.998104f, 0.998249f, 0.998384f,
1182     0.998508f, 0.998623f, 0.998728f, 0.998826f, 0.998916f,
1183     0.999000f, 0.999076f, 0.999147f, 0.999213f, 0.999273f,
1184     0.999329f, 0.999381f, 0.999428f, 0.999472f, 0.999513f,
1185     0.999550f, 0.999585f, 0.999617f, 0.999646f, 0.999673f,
1186     0.999699f, 0.999722f, 0.999743f, 0.999763f, 0.999781f,
1187     0.999798f, 0.999813f, 0.999828f, 0.999841f, 0.999853f,
1188     0.999865f, 0.999875f, 0.999885f, 0.999893f, 0.999902f,
1189     0.999909f, 0.999916f, 0.999923f, 0.999929f, 0.999934f,
1190     0.999939f, 0.999944f, 0.999948f, 0.999952f, 0.999956f,
1191     0.999959f, 0.999962f, 0.999965f, 0.999968f, 0.999970f,
1192     0.999973f, 0.999975f, 0.999977f, 0.999978f, 0.999980f,
1193     0.999982f, 0.999983f, 0.999984f, 0.999986f, 0.999987f,
1194     0.999988f, 0.999989f, 0.999990f, 0.999990f, 0.999991f,
1195     0.999992f, 0.999992f, 0.999993f, 0.999994f, 0.999994f,
1196     0.999994f, 0.999995f, 0.999995f, 0.999996f, 0.999996f,
1197     0.999996f, 0.999997f, 0.999997f, 0.999997f, 0.999997f,
1198     0.999997f, 0.999998f, 0.999998f, 0.999998f, 0.999998f,
1199     0.999998f, 0.999998f, 0.999999f, 0.999999f, 0.999999f,
1200     0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
1201     0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
1202     1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1203     1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1204     1.000000f,
1205 };
1206
1207 static inline float tansig_approx(float x)
1208 {
1209     float y, dy;
1210     float sign=1;
1211     int i;
1212
1213     /* Tests are reversed to catch NaNs */
1214     if (!(x<8))
1215         return 1;
1216     if (!(x>-8))
1217         return -1;
1218     /* Another check in case of -ffast-math */
1219
1220     if (isnan(x))
1221        return 0;
1222
1223     if (x < 0) {
1224        x=-x;
1225        sign=-1;
1226     }
1227     i = (int)floor(.5f+25*x);
1228     x -= .04f*i;
1229     y = tansig_table[i];
1230     dy = 1-y*y;
1231     y = y + x*dy*(1 - y*x);
1232     return sign*y;
1233 }
1234
1235 static inline float sigmoid_approx(float x)
1236 {
1237     return .5f + .5f*tansig_approx(.5f*x);
1238 }
1239
1240 static void compute_dense(const DenseLayer *layer, float *output, const float *input)
1241 {
1242     const int N = layer->nb_neurons, M = layer->nb_inputs, stride = N;
1243
1244     for (int i = 0; i < N; i++) {
1245         /* Compute update gate. */
1246         float sum = layer->bias[i];
1247
1248         for (int j = 0; j < M; j++)
1249             sum += layer->input_weights[j * stride + i] * input[j];
1250
1251         output[i] = WEIGHTS_SCALE * sum;
1252     }
1253
1254     if (layer->activation == ACTIVATION_SIGMOID) {
1255         for (int i = 0; i < N; i++)
1256             output[i] = sigmoid_approx(output[i]);
1257     } else if (layer->activation == ACTIVATION_TANH) {
1258         for (int i = 0; i < N; i++)
1259             output[i] = tansig_approx(output[i]);
1260     } else if (layer->activation == ACTIVATION_RELU) {
1261         for (int i = 0; i < N; i++)
1262             output[i] = FFMAX(0, output[i]);
1263     } else {
1264         av_assert0(0);
1265     }
1266 }
1267
1268 static void compute_gru(AudioRNNContext *s, const GRULayer *gru, float *state, const float *input)
1269 {
1270     LOCAL_ALIGNED_32(float, z, [MAX_NEURONS]);
1271     LOCAL_ALIGNED_32(float, r, [MAX_NEURONS]);
1272     LOCAL_ALIGNED_32(float, h, [MAX_NEURONS]);
1273     const int M = gru->nb_inputs;
1274     const int N = gru->nb_neurons;
1275     const int AN = FFALIGN(N, 4);
1276     const int AM = FFALIGN(M, 4);
1277     const int stride = 3 * AN, istride = 3 * AM;
1278
1279     for (int i = 0; i < N; i++) {
1280         /* Compute update gate. */
1281         float sum = gru->bias[i];
1282
1283         sum += s->fdsp->scalarproduct_float(gru->input_weights + i * istride, input, AM);
1284         sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + i * stride, state, AN);
1285         z[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1286     }
1287
1288     for (int i = 0; i < N; i++) {
1289         /* Compute reset gate. */
1290         float sum = gru->bias[N + i];
1291
1292         sum += s->fdsp->scalarproduct_float(gru->input_weights + AM + i * istride, input, AM);
1293         sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + AN + i * stride, state, AN);
1294         r[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1295     }
1296
1297     for (int i = 0; i < N; i++) {
1298         /* Compute output. */
1299         float sum = gru->bias[2 * N + i];
1300
1301         sum += s->fdsp->scalarproduct_float(gru->input_weights + 2 * AM + i * istride, input, AM);
1302         for (int j = 0; j < N; j++)
1303             sum += gru->recurrent_weights[2 * AN + i * stride + j] * state[j] * r[j];
1304
1305         if (gru->activation == ACTIVATION_SIGMOID)
1306             sum = sigmoid_approx(WEIGHTS_SCALE * sum);
1307         else if (gru->activation == ACTIVATION_TANH)
1308             sum = tansig_approx(WEIGHTS_SCALE * sum);
1309         else if (gru->activation == ACTIVATION_RELU)
1310             sum = FFMAX(0, WEIGHTS_SCALE * sum);
1311         else
1312             av_assert0(0);
1313         h[i] = z[i] * state[i] + (1.f - z[i]) * sum;
1314     }
1315
1316     RNN_COPY(state, h, N);
1317 }
1318
1319 #define INPUT_SIZE 42
1320
1321 static void compute_rnn(AudioRNNContext *s, RNNState *rnn, float *gains, float *vad, const float *input)
1322 {
1323     LOCAL_ALIGNED_32(float, dense_out,     [MAX_NEURONS]);
1324     LOCAL_ALIGNED_32(float, noise_input,   [MAX_NEURONS * 3]);
1325     LOCAL_ALIGNED_32(float, denoise_input, [MAX_NEURONS * 3]);
1326
1327     compute_dense(rnn->model->input_dense, dense_out, input);
1328     compute_gru(s, rnn->model->vad_gru, rnn->vad_gru_state, dense_out);
1329     compute_dense(rnn->model->vad_output, vad, rnn->vad_gru_state);
1330
1331     memcpy(noise_input, dense_out, rnn->model->input_dense_size * sizeof(float));
1332     memcpy(noise_input + rnn->model->input_dense_size,
1333            rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
1334     memcpy(noise_input + rnn->model->input_dense_size + rnn->model->vad_gru_size,
1335            input, INPUT_SIZE * sizeof(float));
1336
1337     compute_gru(s, rnn->model->noise_gru, rnn->noise_gru_state, noise_input);
1338
1339     memcpy(denoise_input, rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
1340     memcpy(denoise_input + rnn->model->vad_gru_size,
1341            rnn->noise_gru_state, rnn->model->noise_gru_size * sizeof(float));
1342     memcpy(denoise_input + rnn->model->vad_gru_size + rnn->model->noise_gru_size,
1343            input, INPUT_SIZE * sizeof(float));
1344
1345     compute_gru(s, rnn->model->denoise_gru, rnn->denoise_gru_state, denoise_input);
1346     compute_dense(rnn->model->denoise_output, gains, rnn->denoise_gru_state);
1347 }
1348
1349 static float rnnoise_channel(AudioRNNContext *s, DenoiseState *st, float *out, const float *in,
1350                              int disabled)
1351 {
1352     AVComplexFloat X[FREQ_SIZE];
1353     AVComplexFloat P[WINDOW_SIZE];
1354     float x[FRAME_SIZE];
1355     float Ex[NB_BANDS], Ep[NB_BANDS];
1356     LOCAL_ALIGNED_32(float, Exp, [NB_BANDS]);
1357     float features[NB_FEATURES];
1358     float g[NB_BANDS];
1359     float gf[FREQ_SIZE];
1360     float vad_prob = 0;
1361     float *history = st->history;
1362     static const float a_hp[2] = {-1.99599, 0.99600};
1363     static const float b_hp[2] = {-2, 1};
1364     int silence;
1365
1366     biquad(x, st->mem_hp_x, in, b_hp, a_hp, FRAME_SIZE);
1367     silence = compute_frame_features(s, st, X, P, Ex, Ep, Exp, features, x);
1368
1369     if (!silence && !disabled) {
1370         compute_rnn(s, &st->rnn, g, &vad_prob, features);
1371         pitch_filter(X, P, Ex, Ep, Exp, g);
1372         for (int i = 0; i < NB_BANDS; i++) {
1373             float alpha = .6f;
1374
1375             g[i] = FFMAX(g[i], alpha * st->lastg[i]);
1376             st->lastg[i] = g[i];
1377         }
1378
1379         interp_band_gain(gf, g);
1380
1381         for (int i = 0; i < FREQ_SIZE; i++) {
1382             X[i].re *= gf[i];
1383             X[i].im *= gf[i];
1384         }
1385     }
1386
1387     frame_synthesis(s, st, out, X);
1388     memcpy(history, in, FRAME_SIZE * sizeof(*history));
1389
1390     return vad_prob;
1391 }
1392
1393 typedef struct ThreadData {
1394     AVFrame *in, *out;
1395 } ThreadData;
1396
1397 static int rnnoise_channels(AVFilterContext *ctx, void *arg, int jobnr, int nb_jobs)
1398 {
1399     AudioRNNContext *s = ctx->priv;
1400     ThreadData *td = arg;
1401     AVFrame *in = td->in;
1402     AVFrame *out = td->out;
1403     const int start = (out->channels * jobnr) / nb_jobs;
1404     const int end = (out->channels * (jobnr+1)) / nb_jobs;
1405
1406     for (int ch = start; ch < end; ch++) {
1407         rnnoise_channel(s, &s->st[ch],
1408                         (float *)out->extended_data[ch],
1409                         (const float *)in->extended_data[ch],
1410                         ctx->is_disabled);
1411     }
1412
1413     return 0;
1414 }
1415
1416 static int filter_frame(AVFilterLink *inlink, AVFrame *in)
1417 {
1418     AVFilterContext *ctx = inlink->dst;
1419     AVFilterLink *outlink = ctx->outputs[0];
1420     AVFrame *out = NULL;
1421     ThreadData td;
1422
1423     out = ff_get_audio_buffer(outlink, FRAME_SIZE);
1424     if (!out) {
1425         av_frame_free(&in);
1426         return AVERROR(ENOMEM);
1427     }
1428     out->pts = in->pts;
1429
1430     td.in = in; td.out = out;
1431     ctx->internal->execute(ctx, rnnoise_channels, &td, NULL, FFMIN(outlink->channels,
1432                                                                    ff_filter_get_nb_threads(ctx)));
1433
1434     av_frame_free(&in);
1435     return ff_filter_frame(outlink, out);
1436 }
1437
1438 static int activate(AVFilterContext *ctx)
1439 {
1440     AVFilterLink *inlink = ctx->inputs[0];
1441     AVFilterLink *outlink = ctx->outputs[0];
1442     AVFrame *in = NULL;
1443     int ret;
1444
1445     FF_FILTER_FORWARD_STATUS_BACK(outlink, inlink);
1446
1447     ret = ff_inlink_consume_samples(inlink, FRAME_SIZE, FRAME_SIZE, &in);
1448     if (ret < 0)
1449         return ret;
1450
1451     if (ret > 0)
1452         return filter_frame(inlink, in);
1453
1454     FF_FILTER_FORWARD_STATUS(inlink, outlink);
1455     FF_FILTER_FORWARD_WANTED(outlink, inlink);
1456
1457     return FFERROR_NOT_READY;
1458 }
1459
1460 static av_cold int init(AVFilterContext *ctx)
1461 {
1462     AudioRNNContext *s = ctx->priv;
1463     FILE *f;
1464
1465     s->fdsp = avpriv_float_dsp_alloc(0);
1466     if (!s->fdsp)
1467         return AVERROR(ENOMEM);
1468
1469     if (!s->model_name)
1470         return AVERROR(EINVAL);
1471     f = av_fopen_utf8(s->model_name, "r");
1472     if (!f)
1473         return AVERROR(EINVAL);
1474
1475     s->model = rnnoise_model_from_file(f);
1476     fclose(f);
1477     if (!s->model)
1478         return AVERROR(EINVAL);
1479
1480     for (int i = 0; i < FRAME_SIZE; i++) {
1481         s->window[i] = sin(.5*M_PI*sin(.5*M_PI*(i+.5)/FRAME_SIZE) * sin(.5*M_PI*(i+.5)/FRAME_SIZE));
1482         s->window[WINDOW_SIZE - 1 - i] = s->window[i];
1483     }
1484
1485     for (int i = 0; i < NB_BANDS; i++) {
1486         for (int j = 0; j < NB_BANDS; j++) {
1487             s->dct_table[j][i] = cosf((i + .5f) * j * M_PI / NB_BANDS);
1488             if (j == 0)
1489                 s->dct_table[j][i] *= sqrtf(.5);
1490         }
1491     }
1492
1493     return 0;
1494 }
1495
1496 static av_cold void uninit(AVFilterContext *ctx)
1497 {
1498     AudioRNNContext *s = ctx->priv;
1499
1500     av_freep(&s->fdsp);
1501     rnnoise_model_free(s->model);
1502     s->model = NULL;
1503
1504     if (s->st) {
1505         for (int ch = 0; ch < s->channels; ch++) {
1506             av_freep(&s->st[ch].rnn.vad_gru_state);
1507             av_freep(&s->st[ch].rnn.noise_gru_state);
1508             av_freep(&s->st[ch].rnn.denoise_gru_state);
1509             av_tx_uninit(&s->st[ch].tx);
1510             av_tx_uninit(&s->st[ch].txi);
1511         }
1512     }
1513     av_freep(&s->st);
1514 }
1515
1516 static const AVFilterPad inputs[] = {
1517     {
1518         .name         = "default",
1519         .type         = AVMEDIA_TYPE_AUDIO,
1520         .config_props = config_input,
1521     },
1522     { NULL }
1523 };
1524
1525 static const AVFilterPad outputs[] = {
1526     {
1527         .name          = "default",
1528         .type          = AVMEDIA_TYPE_AUDIO,
1529     },
1530     { NULL }
1531 };
1532
1533 #define OFFSET(x) offsetof(AudioRNNContext, x)
1534 #define AF AV_OPT_FLAG_AUDIO_PARAM|AV_OPT_FLAG_FILTERING_PARAM
1535
1536 static const AVOption arnndn_options[] = {
1537     { "model", "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1538     { "m",     "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1539     { "mix",   "set output vs input mix", OFFSET(mix), AV_OPT_TYPE_FLOAT, {.dbl=1.0},-1, 1, AF },
1540     { NULL }
1541 };
1542
1543 AVFILTER_DEFINE_CLASS(arnndn);
1544
1545 AVFilter ff_af_arnndn = {
1546     .name          = "arnndn",
1547     .description   = NULL_IF_CONFIG_SMALL("Reduce noise from speech using Recurrent Neural Networks."),
1548     .query_formats = query_formats,
1549     .priv_size     = sizeof(AudioRNNContext),
1550     .priv_class    = &arnndn_class,
1551     .activate      = activate,
1552     .init          = init,
1553     .uninit        = uninit,
1554     .inputs        = inputs,
1555     .outputs       = outputs,
1556     .flags         = AVFILTER_FLAG_SUPPORT_TIMELINE_INTERNAL |
1557                      AVFILTER_FLAG_SLICE_THREADS,
1558 };