]> git.sesse.net Git - ffmpeg/blob - libavfilter/af_arnndn.c
lavu: move LOCAL_ALIGNED from internal.h to mem_internal.h
[ffmpeg] / libavfilter / af_arnndn.c
1 /*
2  * Copyright (c) 2018 Gregor Richards
3  * Copyright (c) 2017 Mozilla
4  * Copyright (c) 2005-2009 Xiph.Org Foundation
5  * Copyright (c) 2007-2008 CSIRO
6  * Copyright (c) 2008-2011 Octasic Inc.
7  * Copyright (c) Jean-Marc Valin
8  * Copyright (c) 2019 Paul B Mahol
9  *
10  * Redistribution and use in source and binary forms, with or without
11  * modification, are permitted provided that the following conditions
12  * are met:
13  *
14  * - Redistributions of source code must retain the above copyright
15  *   notice, this list of conditions and the following disclaimer.
16  *
17  * - Redistributions in binary form must reproduce the above copyright
18  *   notice, this list of conditions and the following disclaimer in the
19  *   documentation and/or other materials provided with the distribution.
20  *
21  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22  * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24  * A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25  * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26  * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27  * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29  * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30  * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32  */
33
34 #include <float.h>
35
36 #include "libavutil/avassert.h"
37 #include "libavutil/avstring.h"
38 #include "libavutil/float_dsp.h"
39 #include "libavutil/mem_internal.h"
40 #include "libavutil/opt.h"
41 #include "libavutil/tx.h"
42 #include "avfilter.h"
43 #include "audio.h"
44 #include "filters.h"
45 #include "formats.h"
46
47 #define FRAME_SIZE_SHIFT 2
48 #define FRAME_SIZE (120<<FRAME_SIZE_SHIFT)
49 #define WINDOW_SIZE (2*FRAME_SIZE)
50 #define FREQ_SIZE (FRAME_SIZE + 1)
51
52 #define PITCH_MIN_PERIOD 60
53 #define PITCH_MAX_PERIOD 768
54 #define PITCH_FRAME_SIZE 960
55 #define PITCH_BUF_SIZE (PITCH_MAX_PERIOD+PITCH_FRAME_SIZE)
56
57 #define SQUARE(x) ((x)*(x))
58
59 #define NB_BANDS 22
60
61 #define CEPS_MEM 8
62 #define NB_DELTA_CEPS 6
63
64 #define NB_FEATURES (NB_BANDS+3*NB_DELTA_CEPS+2)
65
66 #define WEIGHTS_SCALE (1.f/256)
67
68 #define MAX_NEURONS 128
69
70 #define ACTIVATION_TANH    0
71 #define ACTIVATION_SIGMOID 1
72 #define ACTIVATION_RELU    2
73
74 #define Q15ONE 1.0f
75
76 typedef struct DenseLayer {
77     const float *bias;
78     const float *input_weights;
79     int nb_inputs;
80     int nb_neurons;
81     int activation;
82 } DenseLayer;
83
84 typedef struct GRULayer {
85     const float *bias;
86     const float *input_weights;
87     const float *recurrent_weights;
88     int nb_inputs;
89     int nb_neurons;
90     int activation;
91 } GRULayer;
92
93 typedef struct RNNModel {
94     int input_dense_size;
95     const DenseLayer *input_dense;
96
97     int vad_gru_size;
98     const GRULayer *vad_gru;
99
100     int noise_gru_size;
101     const GRULayer *noise_gru;
102
103     int denoise_gru_size;
104     const GRULayer *denoise_gru;
105
106     int denoise_output_size;
107     const DenseLayer *denoise_output;
108
109     int vad_output_size;
110     const DenseLayer *vad_output;
111 } RNNModel;
112
113 typedef struct RNNState {
114     float *vad_gru_state;
115     float *noise_gru_state;
116     float *denoise_gru_state;
117     RNNModel *model;
118 } RNNState;
119
120 typedef struct DenoiseState {
121     float analysis_mem[FRAME_SIZE];
122     float cepstral_mem[CEPS_MEM][NB_BANDS];
123     int memid;
124     DECLARE_ALIGNED(32, float, synthesis_mem)[FRAME_SIZE];
125     float pitch_buf[PITCH_BUF_SIZE];
126     float pitch_enh_buf[PITCH_BUF_SIZE];
127     float last_gain;
128     int last_period;
129     float mem_hp_x[2];
130     float lastg[NB_BANDS];
131     float history[FRAME_SIZE];
132     RNNState rnn;
133     AVTXContext *tx, *txi;
134     av_tx_fn tx_fn, txi_fn;
135 } DenoiseState;
136
137 typedef struct AudioRNNContext {
138     const AVClass *class;
139
140     char *model_name;
141     float mix;
142
143     int channels;
144     DenoiseState *st;
145
146     DECLARE_ALIGNED(32, float, window)[WINDOW_SIZE];
147     DECLARE_ALIGNED(32, float, dct_table)[FFALIGN(NB_BANDS, 4)][FFALIGN(NB_BANDS, 4)];
148
149     RNNModel *model;
150
151     AVFloatDSPContext *fdsp;
152 } AudioRNNContext;
153
154 #define F_ACTIVATION_TANH       0
155 #define F_ACTIVATION_SIGMOID    1
156 #define F_ACTIVATION_RELU       2
157
158 static void rnnoise_model_free(RNNModel *model)
159 {
160 #define FREE_MAYBE(ptr) do { if (ptr) free(ptr); } while (0)
161 #define FREE_DENSE(name) do { \
162     if (model->name) { \
163         av_free((void *) model->name->input_weights); \
164         av_free((void *) model->name->bias); \
165         av_free((void *) model->name); \
166     } \
167     } while (0)
168 #define FREE_GRU(name) do { \
169     if (model->name) { \
170         av_free((void *) model->name->input_weights); \
171         av_free((void *) model->name->recurrent_weights); \
172         av_free((void *) model->name->bias); \
173         av_free((void *) model->name); \
174     } \
175     } while (0)
176
177     if (!model)
178         return;
179     FREE_DENSE(input_dense);
180     FREE_GRU(vad_gru);
181     FREE_GRU(noise_gru);
182     FREE_GRU(denoise_gru);
183     FREE_DENSE(denoise_output);
184     FREE_DENSE(vad_output);
185     av_free(model);
186 }
187
188 static RNNModel *rnnoise_model_from_file(FILE *f)
189 {
190     RNNModel *ret;
191     DenseLayer *input_dense;
192     GRULayer *vad_gru;
193     GRULayer *noise_gru;
194     GRULayer *denoise_gru;
195     DenseLayer *denoise_output;
196     DenseLayer *vad_output;
197     int in;
198
199     if (fscanf(f, "rnnoise-nu model file version %d\n", &in) != 1 || in != 1)
200         return NULL;
201
202     ret = av_calloc(1, sizeof(RNNModel));
203     if (!ret)
204         return NULL;
205
206 #define ALLOC_LAYER(type, name) \
207     name = av_calloc(1, sizeof(type)); \
208     if (!name) { \
209         rnnoise_model_free(ret); \
210         return NULL; \
211     } \
212     ret->name = name
213
214     ALLOC_LAYER(DenseLayer, input_dense);
215     ALLOC_LAYER(GRULayer, vad_gru);
216     ALLOC_LAYER(GRULayer, noise_gru);
217     ALLOC_LAYER(GRULayer, denoise_gru);
218     ALLOC_LAYER(DenseLayer, denoise_output);
219     ALLOC_LAYER(DenseLayer, vad_output);
220
221 #define INPUT_VAL(name) do { \
222     if (fscanf(f, "%d", &in) != 1 || in < 0 || in > 128) { \
223         rnnoise_model_free(ret); \
224         return NULL; \
225     } \
226     name = in; \
227     } while (0)
228
229 #define INPUT_ACTIVATION(name) do { \
230     int activation; \
231     INPUT_VAL(activation); \
232     switch (activation) { \
233     case F_ACTIVATION_SIGMOID: \
234         name = ACTIVATION_SIGMOID; \
235         break; \
236     case F_ACTIVATION_RELU: \
237         name = ACTIVATION_RELU; \
238         break; \
239     default: \
240         name = ACTIVATION_TANH; \
241     } \
242     } while (0)
243
244 #define INPUT_ARRAY(name, len) do { \
245     float *values = av_calloc((len), sizeof(float)); \
246     if (!values) { \
247         rnnoise_model_free(ret); \
248         return NULL; \
249     } \
250     name = values; \
251     for (int i = 0; i < (len); i++) { \
252         if (fscanf(f, "%d", &in) != 1) { \
253             rnnoise_model_free(ret); \
254             return NULL; \
255         } \
256         values[i] = in; \
257     } \
258     } while (0)
259
260 #define INPUT_ARRAY3(name, len0, len1, len2) do { \
261     float *values = av_calloc(FFALIGN((len0), 4) * FFALIGN((len1), 4) * (len2), sizeof(float)); \
262     if (!values) { \
263         rnnoise_model_free(ret); \
264         return NULL; \
265     } \
266     name = values; \
267     for (int k = 0; k < (len0); k++) { \
268         for (int i = 0; i < (len2); i++) { \
269             for (int j = 0; j < (len1); j++) { \
270                 if (fscanf(f, "%d", &in) != 1) { \
271                     rnnoise_model_free(ret); \
272                     return NULL; \
273                 } \
274                 values[j * (len2) * FFALIGN((len0), 4) + i * FFALIGN((len0), 4) + k] = in; \
275             } \
276         } \
277     } \
278     } while (0)
279
280 #define INPUT_DENSE(name) do { \
281     INPUT_VAL(name->nb_inputs); \
282     INPUT_VAL(name->nb_neurons); \
283     ret->name ## _size = name->nb_neurons; \
284     INPUT_ACTIVATION(name->activation); \
285     INPUT_ARRAY(name->input_weights, name->nb_inputs * name->nb_neurons); \
286     INPUT_ARRAY(name->bias, name->nb_neurons); \
287     } while (0)
288
289 #define INPUT_GRU(name) do { \
290     INPUT_VAL(name->nb_inputs); \
291     INPUT_VAL(name->nb_neurons); \
292     ret->name ## _size = name->nb_neurons; \
293     INPUT_ACTIVATION(name->activation); \
294     INPUT_ARRAY3(name->input_weights, name->nb_inputs, name->nb_neurons, 3); \
295     INPUT_ARRAY3(name->recurrent_weights, name->nb_neurons, name->nb_neurons, 3); \
296     INPUT_ARRAY(name->bias, name->nb_neurons * 3); \
297     } while (0)
298
299     INPUT_DENSE(input_dense);
300     INPUT_GRU(vad_gru);
301     INPUT_GRU(noise_gru);
302     INPUT_GRU(denoise_gru);
303     INPUT_DENSE(denoise_output);
304     INPUT_DENSE(vad_output);
305
306     if (vad_output->nb_neurons != 1) {
307         rnnoise_model_free(ret);
308         return NULL;
309     }
310
311     return ret;
312 }
313
314 static int query_formats(AVFilterContext *ctx)
315 {
316     AVFilterFormats *formats = NULL;
317     AVFilterChannelLayouts *layouts = NULL;
318     static const enum AVSampleFormat sample_fmts[] = {
319         AV_SAMPLE_FMT_FLTP,
320         AV_SAMPLE_FMT_NONE
321     };
322     int ret, sample_rates[] = { 48000, -1 };
323
324     formats = ff_make_format_list(sample_fmts);
325     if (!formats)
326         return AVERROR(ENOMEM);
327     ret = ff_set_common_formats(ctx, formats);
328     if (ret < 0)
329         return ret;
330
331     layouts = ff_all_channel_counts();
332     if (!layouts)
333         return AVERROR(ENOMEM);
334
335     ret = ff_set_common_channel_layouts(ctx, layouts);
336     if (ret < 0)
337         return ret;
338
339     formats = ff_make_format_list(sample_rates);
340     if (!formats)
341         return AVERROR(ENOMEM);
342     return ff_set_common_samplerates(ctx, formats);
343 }
344
345 static int config_input(AVFilterLink *inlink)
346 {
347     AVFilterContext *ctx = inlink->dst;
348     AudioRNNContext *s = ctx->priv;
349     int ret;
350
351     s->channels = inlink->channels;
352
353     s->st = av_calloc(s->channels, sizeof(DenoiseState));
354     if (!s->st)
355         return AVERROR(ENOMEM);
356
357     for (int i = 0; i < s->channels; i++) {
358         DenoiseState *st = &s->st[i];
359
360         st->rnn.model = s->model;
361         st->rnn.vad_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->vad_gru_size, 16));
362         st->rnn.noise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->noise_gru_size, 16));
363         st->rnn.denoise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->denoise_gru_size, 16));
364         if (!st->rnn.vad_gru_state ||
365             !st->rnn.noise_gru_state ||
366             !st->rnn.denoise_gru_state)
367             return AVERROR(ENOMEM);
368
369         ret = av_tx_init(&st->tx, &st->tx_fn, AV_TX_FLOAT_FFT, 0, WINDOW_SIZE, NULL, 0);
370         if (ret < 0)
371             return ret;
372
373         ret = av_tx_init(&st->txi, &st->txi_fn, AV_TX_FLOAT_FFT, 1, WINDOW_SIZE, NULL, 0);
374         if (ret < 0)
375             return ret;
376     }
377
378     return 0;
379 }
380
381 static void biquad(float *y, float mem[2], const float *x,
382                    const float *b, const float *a, int N)
383 {
384     for (int i = 0; i < N; i++) {
385         float xi, yi;
386
387         xi = x[i];
388         yi = x[i] + mem[0];
389         mem[0] = mem[1] + (b[0]*xi - a[0]*yi);
390         mem[1] = (b[1]*xi - a[1]*yi);
391         y[i] = yi;
392     }
393 }
394
395 #define RNN_MOVE(dst, src, n) (memmove((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
396 #define RNN_CLEAR(dst, n) (memset((dst), 0, (n)*sizeof(*(dst))))
397 #define RNN_COPY(dst, src, n) (memcpy((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
398
399 static void forward_transform(DenoiseState *st, AVComplexFloat *out, const float *in)
400 {
401     AVComplexFloat x[WINDOW_SIZE];
402     AVComplexFloat y[WINDOW_SIZE];
403
404     for (int i = 0; i < WINDOW_SIZE; i++) {
405         x[i].re = in[i];
406         x[i].im = 0;
407     }
408
409     st->tx_fn(st->tx, y, x, sizeof(float));
410
411     RNN_COPY(out, y, FREQ_SIZE);
412 }
413
414 static void inverse_transform(DenoiseState *st, float *out, const AVComplexFloat *in)
415 {
416     AVComplexFloat x[WINDOW_SIZE];
417     AVComplexFloat y[WINDOW_SIZE];
418
419     RNN_COPY(x, in, FREQ_SIZE);
420
421     for (int i = FREQ_SIZE; i < WINDOW_SIZE; i++) {
422         x[i].re =  x[WINDOW_SIZE - i].re;
423         x[i].im = -x[WINDOW_SIZE - i].im;
424     }
425
426     st->txi_fn(st->txi, y, x, sizeof(float));
427
428     for (int i = 0; i < WINDOW_SIZE; i++)
429         out[i] = y[i].re / WINDOW_SIZE;
430 }
431
432 static const uint8_t eband5ms[] = {
433 /*0  200 400 600 800  1k 1.2 1.4 1.6  2k 2.4 2.8 3.2  4k 4.8 5.6 6.8  8k 9.6 12k 15.6 20k*/
434   0,  1,  2,  3,  4,   5, 6,  7,  8,  10, 12, 14, 16, 20, 24, 28, 34, 40, 48, 60, 78, 100
435 };
436
437 static void compute_band_energy(float *bandE, const AVComplexFloat *X)
438 {
439     float sum[NB_BANDS] = {0};
440
441     for (int i = 0; i < NB_BANDS - 1; i++) {
442         int band_size;
443
444         band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
445         for (int j = 0; j < band_size; j++) {
446             float tmp, frac = (float)j / band_size;
447
448             tmp         = SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].re);
449             tmp        += SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].im);
450             sum[i]     += (1.f - frac) * tmp;
451             sum[i + 1] +=        frac  * tmp;
452         }
453     }
454
455     sum[0] *= 2;
456     sum[NB_BANDS - 1] *= 2;
457
458     for (int i = 0; i < NB_BANDS; i++)
459         bandE[i] = sum[i];
460 }
461
462 static void compute_band_corr(float *bandE, const AVComplexFloat *X, const AVComplexFloat *P)
463 {
464     float sum[NB_BANDS] = { 0 };
465
466     for (int i = 0; i < NB_BANDS - 1; i++) {
467         int band_size;
468
469         band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
470         for (int j = 0; j < band_size; j++) {
471             float tmp, frac = (float)j / band_size;
472
473             tmp  = X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re;
474             tmp += X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im;
475             sum[i]     += (1 - frac) * tmp;
476             sum[i + 1] +=      frac  * tmp;
477         }
478     }
479
480     sum[0] *= 2;
481     sum[NB_BANDS-1] *= 2;
482
483     for (int i = 0; i < NB_BANDS; i++)
484         bandE[i] = sum[i];
485 }
486
487 static void frame_analysis(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, float *Ex, const float *in)
488 {
489     LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
490
491     RNN_COPY(x, st->analysis_mem, FRAME_SIZE);
492     RNN_COPY(x + FRAME_SIZE, in, FRAME_SIZE);
493     RNN_COPY(st->analysis_mem, in, FRAME_SIZE);
494     s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
495     forward_transform(st, X, x);
496     compute_band_energy(Ex, X);
497 }
498
499 static void frame_synthesis(AudioRNNContext *s, DenoiseState *st, float *out, const AVComplexFloat *y)
500 {
501     LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
502     const float *src = st->history;
503     const float mix = s->mix;
504     const float imix = 1.f - FFMAX(mix, 0.f);
505
506     inverse_transform(st, x, y);
507     s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
508     s->fdsp->vector_fmac_scalar(x, st->synthesis_mem, 1.f, FRAME_SIZE);
509     RNN_COPY(out, x, FRAME_SIZE);
510     RNN_COPY(st->synthesis_mem, &x[FRAME_SIZE], FRAME_SIZE);
511
512     for (int n = 0; n < FRAME_SIZE; n++)
513         out[n] = out[n] * mix + src[n] * imix;
514 }
515
516 static inline void xcorr_kernel(const float *x, const float *y, float sum[4], int len)
517 {
518     float y_0, y_1, y_2, y_3 = 0;
519     int j;
520
521     y_0 = *y++;
522     y_1 = *y++;
523     y_2 = *y++;
524
525     for (j = 0; j < len - 3; j += 4) {
526         float tmp;
527
528         tmp = *x++;
529         y_3 = *y++;
530         sum[0] += tmp * y_0;
531         sum[1] += tmp * y_1;
532         sum[2] += tmp * y_2;
533         sum[3] += tmp * y_3;
534         tmp = *x++;
535         y_0 = *y++;
536         sum[0] += tmp * y_1;
537         sum[1] += tmp * y_2;
538         sum[2] += tmp * y_3;
539         sum[3] += tmp * y_0;
540         tmp = *x++;
541         y_1 = *y++;
542         sum[0] += tmp * y_2;
543         sum[1] += tmp * y_3;
544         sum[2] += tmp * y_0;
545         sum[3] += tmp * y_1;
546         tmp = *x++;
547         y_2 = *y++;
548         sum[0] += tmp * y_3;
549         sum[1] += tmp * y_0;
550         sum[2] += tmp * y_1;
551         sum[3] += tmp * y_2;
552     }
553
554     if (j++ < len) {
555         float tmp = *x++;
556
557         y_3 = *y++;
558         sum[0] += tmp * y_0;
559         sum[1] += tmp * y_1;
560         sum[2] += tmp * y_2;
561         sum[3] += tmp * y_3;
562     }
563
564     if (j++ < len) {
565         float tmp=*x++;
566
567         y_0 = *y++;
568         sum[0] += tmp * y_1;
569         sum[1] += tmp * y_2;
570         sum[2] += tmp * y_3;
571         sum[3] += tmp * y_0;
572     }
573
574     if (j < len) {
575         float tmp=*x++;
576
577         y_1 = *y++;
578         sum[0] += tmp * y_2;
579         sum[1] += tmp * y_3;
580         sum[2] += tmp * y_0;
581         sum[3] += tmp * y_1;
582     }
583 }
584
585 static inline float celt_inner_prod(const float *x,
586                                     const float *y, int N)
587 {
588     float xy = 0.f;
589
590     for (int i = 0; i < N; i++)
591         xy += x[i] * y[i];
592
593     return xy;
594 }
595
596 static void celt_pitch_xcorr(const float *x, const float *y,
597                              float *xcorr, int len, int max_pitch)
598 {
599     int i;
600
601     for (i = 0; i < max_pitch - 3; i += 4) {
602         float sum[4] = { 0, 0, 0, 0};
603
604         xcorr_kernel(x, y + i, sum, len);
605
606         xcorr[i]     = sum[0];
607         xcorr[i + 1] = sum[1];
608         xcorr[i + 2] = sum[2];
609         xcorr[i + 3] = sum[3];
610     }
611     /* In case max_pitch isn't a multiple of 4, do non-unrolled version. */
612     for (; i < max_pitch; i++) {
613         xcorr[i] = celt_inner_prod(x, y + i, len);
614     }
615 }
616
617 static int celt_autocorr(const float *x,   /*  in: [0...n-1] samples x   */
618                          float       *ac,  /* out: [0...lag-1] ac values */
619                          const float *window,
620                          int          overlap,
621                          int          lag,
622                          int          n)
623 {
624     int fastN = n - lag;
625     int shift;
626     const float *xptr;
627     float xx[PITCH_BUF_SIZE>>1];
628
629     if (overlap == 0) {
630         xptr = x;
631     } else {
632         for (int i = 0; i < n; i++)
633             xx[i] = x[i];
634         for (int i = 0; i < overlap; i++) {
635             xx[i] = x[i] * window[i];
636             xx[n-i-1] = x[n-i-1] * window[i];
637         }
638         xptr = xx;
639     }
640
641     shift = 0;
642     celt_pitch_xcorr(xptr, xptr, ac, fastN, lag+1);
643
644     for (int k = 0; k <= lag; k++) {
645         float d = 0.f;
646
647         for (int i = k + fastN; i < n; i++)
648             d += xptr[i] * xptr[i-k];
649         ac[k] += d;
650     }
651
652     return shift;
653 }
654
655 static void celt_lpc(float *lpc, /* out: [0...p-1] LPC coefficients      */
656                 const float *ac,   /* in:  [0...p] autocorrelation values  */
657                           int p)
658 {
659     float r, error = ac[0];
660
661     RNN_CLEAR(lpc, p);
662     if (ac[0] != 0) {
663         for (int i = 0; i < p; i++) {
664             /* Sum up this iteration's reflection coefficient */
665             float rr = 0;
666             for (int j = 0; j < i; j++)
667                 rr += (lpc[j] * ac[i - j]);
668             rr += ac[i + 1];
669             r = -rr/error;
670             /*  Update LPC coefficients and total error */
671             lpc[i] = r;
672             for (int j = 0; j < (i + 1) >> 1; j++) {
673                 float tmp1, tmp2;
674                 tmp1 = lpc[j];
675                 tmp2 = lpc[i-1-j];
676                 lpc[j]     = tmp1 + (r*tmp2);
677                 lpc[i-1-j] = tmp2 + (r*tmp1);
678             }
679
680             error = error - (r * r *error);
681             /* Bail out once we get 30 dB gain */
682             if (error < .001f * ac[0])
683                 break;
684         }
685     }
686 }
687
688 static void celt_fir5(const float *x,
689                       const float *num,
690                       float *y,
691                       int N,
692                       float *mem)
693 {
694     float num0, num1, num2, num3, num4;
695     float mem0, mem1, mem2, mem3, mem4;
696
697     num0 = num[0];
698     num1 = num[1];
699     num2 = num[2];
700     num3 = num[3];
701     num4 = num[4];
702     mem0 = mem[0];
703     mem1 = mem[1];
704     mem2 = mem[2];
705     mem3 = mem[3];
706     mem4 = mem[4];
707
708     for (int i = 0; i < N; i++) {
709         float sum = x[i];
710
711         sum += (num0*mem0);
712         sum += (num1*mem1);
713         sum += (num2*mem2);
714         sum += (num3*mem3);
715         sum += (num4*mem4);
716         mem4 = mem3;
717         mem3 = mem2;
718         mem2 = mem1;
719         mem1 = mem0;
720         mem0 = x[i];
721         y[i] = sum;
722     }
723
724     mem[0] = mem0;
725     mem[1] = mem1;
726     mem[2] = mem2;
727     mem[3] = mem3;
728     mem[4] = mem4;
729 }
730
731 static void pitch_downsample(float *x[], float *x_lp,
732                              int len, int C)
733 {
734     float ac[5];
735     float tmp=Q15ONE;
736     float lpc[4], mem[5]={0,0,0,0,0};
737     float lpc2[5];
738     float c1 = .8f;
739
740     for (int i = 1; i < len >> 1; i++)
741         x_lp[i] = .5f * (.5f * (x[0][(2*i-1)]+x[0][(2*i+1)])+x[0][2*i]);
742     x_lp[0] = .5f * (.5f * (x[0][1])+x[0][0]);
743     if (C==2) {
744         for (int i = 1; i < len >> 1; i++)
745             x_lp[i] += (.5f * (.5f * (x[1][(2*i-1)]+x[1][(2*i+1)])+x[1][2*i]));
746         x_lp[0] += .5f * (.5f * (x[1][1])+x[1][0]);
747     }
748
749     celt_autocorr(x_lp, ac, NULL, 0, 4, len>>1);
750
751     /* Noise floor -40 dB */
752     ac[0] *= 1.0001f;
753     /* Lag windowing */
754     for (int i = 1; i <= 4; i++) {
755         /*ac[i] *= exp(-.5*(2*M_PI*.002*i)*(2*M_PI*.002*i));*/
756         ac[i] -= ac[i]*(.008f*i)*(.008f*i);
757     }
758
759     celt_lpc(lpc, ac, 4);
760     for (int i = 0; i < 4; i++) {
761         tmp = .9f * tmp;
762         lpc[i] = (lpc[i] * tmp);
763     }
764     /* Add a zero */
765     lpc2[0] = lpc[0] + .8f;
766     lpc2[1] = lpc[1] + (c1 * lpc[0]);
767     lpc2[2] = lpc[2] + (c1 * lpc[1]);
768     lpc2[3] = lpc[3] + (c1 * lpc[2]);
769     lpc2[4] = (c1 * lpc[3]);
770     celt_fir5(x_lp, lpc2, x_lp, len>>1, mem);
771 }
772
773 static inline void dual_inner_prod(const float *x, const float *y01, const float *y02,
774                                    int N, float *xy1, float *xy2)
775 {
776     float xy01 = 0, xy02 = 0;
777
778     for (int i = 0; i < N; i++) {
779         xy01 += (x[i] * y01[i]);
780         xy02 += (x[i] * y02[i]);
781     }
782
783     *xy1 = xy01;
784     *xy2 = xy02;
785 }
786
787 static float compute_pitch_gain(float xy, float xx, float yy)
788 {
789     return xy / sqrtf(1.f + xx * yy);
790 }
791
792 static const uint8_t second_check[16] = {0, 0, 3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2};
793 static float remove_doubling(float *x, int maxperiod, int minperiod, int N,
794                              int *T0_, int prev_period, float prev_gain)
795 {
796     int k, i, T, T0;
797     float g, g0;
798     float pg;
799     float xy,xx,yy,xy2;
800     float xcorr[3];
801     float best_xy, best_yy;
802     int offset;
803     int minperiod0;
804     float yy_lookup[PITCH_MAX_PERIOD+1];
805
806     minperiod0 = minperiod;
807     maxperiod /= 2;
808     minperiod /= 2;
809     *T0_ /= 2;
810     prev_period /= 2;
811     N /= 2;
812     x += maxperiod;
813     if (*T0_>=maxperiod)
814         *T0_=maxperiod-1;
815
816     T = T0 = *T0_;
817     dual_inner_prod(x, x, x-T0, N, &xx, &xy);
818     yy_lookup[0] = xx;
819     yy=xx;
820     for (i = 1; i <= maxperiod; i++) {
821         yy = yy+(x[-i] * x[-i])-(x[N-i] * x[N-i]);
822         yy_lookup[i] = FFMAX(0, yy);
823     }
824     yy = yy_lookup[T0];
825     best_xy = xy;
826     best_yy = yy;
827     g = g0 = compute_pitch_gain(xy, xx, yy);
828     /* Look for any pitch at T/k */
829     for (k = 2; k <= 15; k++) {
830         int T1, T1b;
831         float g1;
832         float cont=0;
833         float thresh;
834         T1 = (2*T0+k)/(2*k);
835         if (T1 < minperiod)
836             break;
837         /* Look for another strong correlation at T1b */
838         if (k==2)
839         {
840             if (T1+T0>maxperiod)
841                 T1b = T0;
842             else
843                 T1b = T0+T1;
844         } else
845         {
846             T1b = (2*second_check[k]*T0+k)/(2*k);
847         }
848         dual_inner_prod(x, &x[-T1], &x[-T1b], N, &xy, &xy2);
849         xy = .5f * (xy + xy2);
850         yy = .5f * (yy_lookup[T1] + yy_lookup[T1b]);
851         g1 = compute_pitch_gain(xy, xx, yy);
852         if (FFABS(T1-prev_period)<=1)
853             cont = prev_gain;
854         else if (FFABS(T1-prev_period)<=2 && 5 * k * k < T0)
855             cont = prev_gain * .5f;
856         else
857             cont = 0;
858         thresh = FFMAX(.3f, (.7f * g0) - cont);
859         /* Bias against very high pitch (very short period) to avoid false-positives
860            due to short-term correlation */
861         if (T1<3*minperiod)
862             thresh = FFMAX(.4f, (.85f * g0) - cont);
863         else if (T1<2*minperiod)
864             thresh = FFMAX(.5f, (.9f * g0) - cont);
865         if (g1 > thresh)
866         {
867             best_xy = xy;
868             best_yy = yy;
869             T = T1;
870             g = g1;
871         }
872     }
873     best_xy = FFMAX(0, best_xy);
874     if (best_yy <= best_xy)
875         pg = Q15ONE;
876     else
877         pg = best_xy/(best_yy + 1);
878
879     for (k = 0; k < 3; k++)
880         xcorr[k] = celt_inner_prod(x, x-(T+k-1), N);
881     if ((xcorr[2]-xcorr[0]) > .7f * (xcorr[1]-xcorr[0]))
882         offset = 1;
883     else if ((xcorr[0]-xcorr[2]) > (.7f * (xcorr[1] - xcorr[2])))
884         offset = -1;
885     else
886         offset = 0;
887     if (pg > g)
888         pg = g;
889     *T0_ = 2*T+offset;
890
891     if (*T0_<minperiod0)
892         *T0_=minperiod0;
893     return pg;
894 }
895
896 static void find_best_pitch(float *xcorr, float *y, int len,
897                             int max_pitch, int *best_pitch)
898 {
899     float best_num[2];
900     float best_den[2];
901     float Syy = 1.f;
902
903     best_num[0] = -1;
904     best_num[1] = -1;
905     best_den[0] = 0;
906     best_den[1] = 0;
907     best_pitch[0] = 0;
908     best_pitch[1] = 1;
909
910     for (int j = 0; j < len; j++)
911         Syy += y[j] * y[j];
912
913     for (int i = 0; i < max_pitch; i++) {
914         if (xcorr[i]>0) {
915             float num;
916             float xcorr16;
917
918             xcorr16 = xcorr[i];
919             /* Considering the range of xcorr16, this should avoid both underflows
920                and overflows (inf) when squaring xcorr16 */
921             xcorr16 *= 1e-12f;
922             num = xcorr16 * xcorr16;
923             if ((num * best_den[1]) > (best_num[1] * Syy)) {
924                 if ((num * best_den[0]) > (best_num[0] * Syy)) {
925                     best_num[1] = best_num[0];
926                     best_den[1] = best_den[0];
927                     best_pitch[1] = best_pitch[0];
928                     best_num[0] = num;
929                     best_den[0] = Syy;
930                     best_pitch[0] = i;
931                 } else {
932                     best_num[1] = num;
933                     best_den[1] = Syy;
934                     best_pitch[1] = i;
935                 }
936             }
937         }
938         Syy += y[i+len]*y[i+len] - y[i] * y[i];
939         Syy = FFMAX(1, Syy);
940     }
941 }
942
943 static void pitch_search(const float *x_lp, float *y,
944                          int len, int max_pitch, int *pitch)
945 {
946     int lag;
947     int best_pitch[2]={0,0};
948     int offset;
949
950     float x_lp4[WINDOW_SIZE];
951     float y_lp4[WINDOW_SIZE];
952     float xcorr[WINDOW_SIZE];
953
954     lag = len+max_pitch;
955
956     /* Downsample by 2 again */
957     for (int j = 0; j < len >> 2; j++)
958         x_lp4[j] = x_lp[2*j];
959     for (int j = 0; j < lag >> 2; j++)
960         y_lp4[j] = y[2*j];
961
962     /* Coarse search with 4x decimation */
963
964     celt_pitch_xcorr(x_lp4, y_lp4, xcorr, len>>2, max_pitch>>2);
965
966     find_best_pitch(xcorr, y_lp4, len>>2, max_pitch>>2, best_pitch);
967
968     /* Finer search with 2x decimation */
969     for (int i = 0; i < max_pitch >> 1; i++) {
970         float sum;
971         xcorr[i] = 0;
972         if (FFABS(i-2*best_pitch[0])>2 && FFABS(i-2*best_pitch[1])>2)
973             continue;
974         sum = celt_inner_prod(x_lp, y+i, len>>1);
975         xcorr[i] = FFMAX(-1, sum);
976     }
977
978     find_best_pitch(xcorr, y, len>>1, max_pitch>>1, best_pitch);
979
980     /* Refine by pseudo-interpolation */
981     if (best_pitch[0] > 0 && best_pitch[0] < (max_pitch >> 1) - 1) {
982         float a, b, c;
983
984         a = xcorr[best_pitch[0] - 1];
985         b = xcorr[best_pitch[0]];
986         c = xcorr[best_pitch[0] + 1];
987         if (c - a > .7f * (b - a))
988             offset = 1;
989         else if (a - c > .7f * (b-c))
990             offset = -1;
991         else
992             offset = 0;
993     } else {
994         offset = 0;
995     }
996
997     *pitch = 2 * best_pitch[0] - offset;
998 }
999
1000 static void dct(AudioRNNContext *s, float *out, const float *in)
1001 {
1002     for (int i = 0; i < NB_BANDS; i++) {
1003         float sum;
1004
1005         sum = s->fdsp->scalarproduct_float(in, s->dct_table[i], FFALIGN(NB_BANDS, 4));
1006         out[i] = sum * sqrtf(2.f / 22);
1007     }
1008 }
1009
1010 static int compute_frame_features(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, AVComplexFloat *P,
1011                                   float *Ex, float *Ep, float *Exp, float *features, const float *in)
1012 {
1013     float E = 0;
1014     float *ceps_0, *ceps_1, *ceps_2;
1015     float spec_variability = 0;
1016     LOCAL_ALIGNED_32(float, Ly, [NB_BANDS]);
1017     LOCAL_ALIGNED_32(float, p, [WINDOW_SIZE]);
1018     float pitch_buf[PITCH_BUF_SIZE>>1];
1019     int pitch_index;
1020     float gain;
1021     float *(pre[1]);
1022     float tmp[NB_BANDS];
1023     float follow, logMax;
1024
1025     frame_analysis(s, st, X, Ex, in);
1026     RNN_MOVE(st->pitch_buf, &st->pitch_buf[FRAME_SIZE], PITCH_BUF_SIZE-FRAME_SIZE);
1027     RNN_COPY(&st->pitch_buf[PITCH_BUF_SIZE-FRAME_SIZE], in, FRAME_SIZE);
1028     pre[0] = &st->pitch_buf[0];
1029     pitch_downsample(pre, pitch_buf, PITCH_BUF_SIZE, 1);
1030     pitch_search(pitch_buf+(PITCH_MAX_PERIOD>>1), pitch_buf, PITCH_FRAME_SIZE,
1031             PITCH_MAX_PERIOD-3*PITCH_MIN_PERIOD, &pitch_index);
1032     pitch_index = PITCH_MAX_PERIOD-pitch_index;
1033
1034     gain = remove_doubling(pitch_buf, PITCH_MAX_PERIOD, PITCH_MIN_PERIOD,
1035             PITCH_FRAME_SIZE, &pitch_index, st->last_period, st->last_gain);
1036     st->last_period = pitch_index;
1037     st->last_gain = gain;
1038
1039     for (int i = 0; i < WINDOW_SIZE; i++)
1040         p[i] = st->pitch_buf[PITCH_BUF_SIZE-WINDOW_SIZE-pitch_index+i];
1041
1042     s->fdsp->vector_fmul(p, p, s->window, WINDOW_SIZE);
1043     forward_transform(st, P, p);
1044     compute_band_energy(Ep, P);
1045     compute_band_corr(Exp, X, P);
1046
1047     for (int i = 0; i < NB_BANDS; i++)
1048         Exp[i] = Exp[i] / sqrtf(.001f+Ex[i]*Ep[i]);
1049
1050     dct(s, tmp, Exp);
1051
1052     for (int i = 0; i < NB_DELTA_CEPS; i++)
1053         features[NB_BANDS+2*NB_DELTA_CEPS+i] = tmp[i];
1054
1055     features[NB_BANDS+2*NB_DELTA_CEPS] -= 1.3;
1056     features[NB_BANDS+2*NB_DELTA_CEPS+1] -= 0.9;
1057     features[NB_BANDS+3*NB_DELTA_CEPS] = .01*(pitch_index-300);
1058     logMax = -2;
1059     follow = -2;
1060
1061     for (int i = 0; i < NB_BANDS; i++) {
1062         Ly[i] = log10f(1e-2f + Ex[i]);
1063         Ly[i] = FFMAX(logMax-7, FFMAX(follow-1.5, Ly[i]));
1064         logMax = FFMAX(logMax, Ly[i]);
1065         follow = FFMAX(follow-1.5, Ly[i]);
1066         E += Ex[i];
1067     }
1068
1069     if (E < 0.04f) {
1070         /* If there's no audio, avoid messing up the state. */
1071         RNN_CLEAR(features, NB_FEATURES);
1072         return 1;
1073     }
1074
1075     dct(s, features, Ly);
1076     features[0] -= 12;
1077     features[1] -= 4;
1078     ceps_0 = st->cepstral_mem[st->memid];
1079     ceps_1 = (st->memid < 1) ? st->cepstral_mem[CEPS_MEM+st->memid-1] : st->cepstral_mem[st->memid-1];
1080     ceps_2 = (st->memid < 2) ? st->cepstral_mem[CEPS_MEM+st->memid-2] : st->cepstral_mem[st->memid-2];
1081
1082     for (int i = 0; i < NB_BANDS; i++)
1083         ceps_0[i] = features[i];
1084
1085     st->memid++;
1086     for (int i = 0; i < NB_DELTA_CEPS; i++) {
1087         features[i] = ceps_0[i] + ceps_1[i] + ceps_2[i];
1088         features[NB_BANDS+i] = ceps_0[i] - ceps_2[i];
1089         features[NB_BANDS+NB_DELTA_CEPS+i] =  ceps_0[i] - 2*ceps_1[i] + ceps_2[i];
1090     }
1091     /* Spectral variability features. */
1092     if (st->memid == CEPS_MEM)
1093         st->memid = 0;
1094
1095     for (int i = 0; i < CEPS_MEM; i++) {
1096         float mindist = 1e15f;
1097         for (int j = 0; j < CEPS_MEM; j++) {
1098             float dist = 0.f;
1099             for (int k = 0; k < NB_BANDS; k++) {
1100                 float tmp;
1101
1102                 tmp = st->cepstral_mem[i][k] - st->cepstral_mem[j][k];
1103                 dist += tmp*tmp;
1104             }
1105
1106             if (j != i)
1107                 mindist = FFMIN(mindist, dist);
1108         }
1109
1110         spec_variability += mindist;
1111     }
1112
1113     features[NB_BANDS+3*NB_DELTA_CEPS+1] = spec_variability/CEPS_MEM-2.1;
1114
1115     return 0;
1116 }
1117
1118 static void interp_band_gain(float *g, const float *bandE)
1119 {
1120     memset(g, 0, sizeof(*g) * FREQ_SIZE);
1121
1122     for (int i = 0; i < NB_BANDS - 1; i++) {
1123         const int band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
1124
1125         for (int j = 0; j < band_size; j++) {
1126             float frac = (float)j / band_size;
1127
1128             g[(eband5ms[i] << FRAME_SIZE_SHIFT) + j] = (1.f - frac) * bandE[i] + frac * bandE[i + 1];
1129         }
1130     }
1131 }
1132
1133 static void pitch_filter(AVComplexFloat *X, const AVComplexFloat *P, const float *Ex, const float *Ep,
1134                          const float *Exp, const float *g)
1135 {
1136     float newE[NB_BANDS];
1137     float r[NB_BANDS];
1138     float norm[NB_BANDS];
1139     float rf[FREQ_SIZE] = {0};
1140     float normf[FREQ_SIZE]={0};
1141
1142     for (int i = 0; i < NB_BANDS; i++) {
1143         if (Exp[i]>g[i]) r[i] = 1;
1144         else r[i] = SQUARE(Exp[i])*(1-SQUARE(g[i]))/(.001 + SQUARE(g[i])*(1-SQUARE(Exp[i])));
1145         r[i]  = sqrtf(av_clipf(r[i], 0, 1));
1146         r[i] *= sqrtf(Ex[i]/(1e-8+Ep[i]));
1147     }
1148     interp_band_gain(rf, r);
1149     for (int i = 0; i < FREQ_SIZE; i++) {
1150         X[i].re += rf[i]*P[i].re;
1151         X[i].im += rf[i]*P[i].im;
1152     }
1153     compute_band_energy(newE, X);
1154     for (int i = 0; i < NB_BANDS; i++) {
1155         norm[i] = sqrtf(Ex[i] / (1e-8+newE[i]));
1156     }
1157     interp_band_gain(normf, norm);
1158     for (int i = 0; i < FREQ_SIZE; i++) {
1159         X[i].re *= normf[i];
1160         X[i].im *= normf[i];
1161     }
1162 }
1163
1164 static const float tansig_table[201] = {
1165     0.000000f, 0.039979f, 0.079830f, 0.119427f, 0.158649f,
1166     0.197375f, 0.235496f, 0.272905f, 0.309507f, 0.345214f,
1167     0.379949f, 0.413644f, 0.446244f, 0.477700f, 0.507977f,
1168     0.537050f, 0.564900f, 0.591519f, 0.616909f, 0.641077f,
1169     0.664037f, 0.685809f, 0.706419f, 0.725897f, 0.744277f,
1170     0.761594f, 0.777888f, 0.793199f, 0.807569f, 0.821040f,
1171     0.833655f, 0.845456f, 0.856485f, 0.866784f, 0.876393f,
1172     0.885352f, 0.893698f, 0.901468f, 0.908698f, 0.915420f,
1173     0.921669f, 0.927473f, 0.932862f, 0.937863f, 0.942503f,
1174     0.946806f, 0.950795f, 0.954492f, 0.957917f, 0.961090f,
1175     0.964028f, 0.966747f, 0.969265f, 0.971594f, 0.973749f,
1176     0.975743f, 0.977587f, 0.979293f, 0.980869f, 0.982327f,
1177     0.983675f, 0.984921f, 0.986072f, 0.987136f, 0.988119f,
1178     0.989027f, 0.989867f, 0.990642f, 0.991359f, 0.992020f,
1179     0.992631f, 0.993196f, 0.993718f, 0.994199f, 0.994644f,
1180     0.995055f, 0.995434f, 0.995784f, 0.996108f, 0.996407f,
1181     0.996682f, 0.996937f, 0.997172f, 0.997389f, 0.997590f,
1182     0.997775f, 0.997946f, 0.998104f, 0.998249f, 0.998384f,
1183     0.998508f, 0.998623f, 0.998728f, 0.998826f, 0.998916f,
1184     0.999000f, 0.999076f, 0.999147f, 0.999213f, 0.999273f,
1185     0.999329f, 0.999381f, 0.999428f, 0.999472f, 0.999513f,
1186     0.999550f, 0.999585f, 0.999617f, 0.999646f, 0.999673f,
1187     0.999699f, 0.999722f, 0.999743f, 0.999763f, 0.999781f,
1188     0.999798f, 0.999813f, 0.999828f, 0.999841f, 0.999853f,
1189     0.999865f, 0.999875f, 0.999885f, 0.999893f, 0.999902f,
1190     0.999909f, 0.999916f, 0.999923f, 0.999929f, 0.999934f,
1191     0.999939f, 0.999944f, 0.999948f, 0.999952f, 0.999956f,
1192     0.999959f, 0.999962f, 0.999965f, 0.999968f, 0.999970f,
1193     0.999973f, 0.999975f, 0.999977f, 0.999978f, 0.999980f,
1194     0.999982f, 0.999983f, 0.999984f, 0.999986f, 0.999987f,
1195     0.999988f, 0.999989f, 0.999990f, 0.999990f, 0.999991f,
1196     0.999992f, 0.999992f, 0.999993f, 0.999994f, 0.999994f,
1197     0.999994f, 0.999995f, 0.999995f, 0.999996f, 0.999996f,
1198     0.999996f, 0.999997f, 0.999997f, 0.999997f, 0.999997f,
1199     0.999997f, 0.999998f, 0.999998f, 0.999998f, 0.999998f,
1200     0.999998f, 0.999998f, 0.999999f, 0.999999f, 0.999999f,
1201     0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
1202     0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
1203     1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1204     1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1205     1.000000f,
1206 };
1207
1208 static inline float tansig_approx(float x)
1209 {
1210     float y, dy;
1211     float sign=1;
1212     int i;
1213
1214     /* Tests are reversed to catch NaNs */
1215     if (!(x<8))
1216         return 1;
1217     if (!(x>-8))
1218         return -1;
1219     /* Another check in case of -ffast-math */
1220
1221     if (isnan(x))
1222        return 0;
1223
1224     if (x < 0) {
1225        x=-x;
1226        sign=-1;
1227     }
1228     i = (int)floor(.5f+25*x);
1229     x -= .04f*i;
1230     y = tansig_table[i];
1231     dy = 1-y*y;
1232     y = y + x*dy*(1 - y*x);
1233     return sign*y;
1234 }
1235
1236 static inline float sigmoid_approx(float x)
1237 {
1238     return .5f + .5f*tansig_approx(.5f*x);
1239 }
1240
1241 static void compute_dense(const DenseLayer *layer, float *output, const float *input)
1242 {
1243     const int N = layer->nb_neurons, M = layer->nb_inputs, stride = N;
1244
1245     for (int i = 0; i < N; i++) {
1246         /* Compute update gate. */
1247         float sum = layer->bias[i];
1248
1249         for (int j = 0; j < M; j++)
1250             sum += layer->input_weights[j * stride + i] * input[j];
1251
1252         output[i] = WEIGHTS_SCALE * sum;
1253     }
1254
1255     if (layer->activation == ACTIVATION_SIGMOID) {
1256         for (int i = 0; i < N; i++)
1257             output[i] = sigmoid_approx(output[i]);
1258     } else if (layer->activation == ACTIVATION_TANH) {
1259         for (int i = 0; i < N; i++)
1260             output[i] = tansig_approx(output[i]);
1261     } else if (layer->activation == ACTIVATION_RELU) {
1262         for (int i = 0; i < N; i++)
1263             output[i] = FFMAX(0, output[i]);
1264     } else {
1265         av_assert0(0);
1266     }
1267 }
1268
1269 static void compute_gru(AudioRNNContext *s, const GRULayer *gru, float *state, const float *input)
1270 {
1271     LOCAL_ALIGNED_32(float, z, [MAX_NEURONS]);
1272     LOCAL_ALIGNED_32(float, r, [MAX_NEURONS]);
1273     LOCAL_ALIGNED_32(float, h, [MAX_NEURONS]);
1274     const int M = gru->nb_inputs;
1275     const int N = gru->nb_neurons;
1276     const int AN = FFALIGN(N, 4);
1277     const int AM = FFALIGN(M, 4);
1278     const int stride = 3 * AN, istride = 3 * AM;
1279
1280     for (int i = 0; i < N; i++) {
1281         /* Compute update gate. */
1282         float sum = gru->bias[i];
1283
1284         sum += s->fdsp->scalarproduct_float(gru->input_weights + i * istride, input, AM);
1285         sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + i * stride, state, AN);
1286         z[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1287     }
1288
1289     for (int i = 0; i < N; i++) {
1290         /* Compute reset gate. */
1291         float sum = gru->bias[N + i];
1292
1293         sum += s->fdsp->scalarproduct_float(gru->input_weights + AM + i * istride, input, AM);
1294         sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + AN + i * stride, state, AN);
1295         r[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1296     }
1297
1298     for (int i = 0; i < N; i++) {
1299         /* Compute output. */
1300         float sum = gru->bias[2 * N + i];
1301
1302         sum += s->fdsp->scalarproduct_float(gru->input_weights + 2 * AM + i * istride, input, AM);
1303         for (int j = 0; j < N; j++)
1304             sum += gru->recurrent_weights[2 * AN + i * stride + j] * state[j] * r[j];
1305
1306         if (gru->activation == ACTIVATION_SIGMOID)
1307             sum = sigmoid_approx(WEIGHTS_SCALE * sum);
1308         else if (gru->activation == ACTIVATION_TANH)
1309             sum = tansig_approx(WEIGHTS_SCALE * sum);
1310         else if (gru->activation == ACTIVATION_RELU)
1311             sum = FFMAX(0, WEIGHTS_SCALE * sum);
1312         else
1313             av_assert0(0);
1314         h[i] = z[i] * state[i] + (1.f - z[i]) * sum;
1315     }
1316
1317     RNN_COPY(state, h, N);
1318 }
1319
1320 #define INPUT_SIZE 42
1321
1322 static void compute_rnn(AudioRNNContext *s, RNNState *rnn, float *gains, float *vad, const float *input)
1323 {
1324     LOCAL_ALIGNED_32(float, dense_out,     [MAX_NEURONS]);
1325     LOCAL_ALIGNED_32(float, noise_input,   [MAX_NEURONS * 3]);
1326     LOCAL_ALIGNED_32(float, denoise_input, [MAX_NEURONS * 3]);
1327
1328     compute_dense(rnn->model->input_dense, dense_out, input);
1329     compute_gru(s, rnn->model->vad_gru, rnn->vad_gru_state, dense_out);
1330     compute_dense(rnn->model->vad_output, vad, rnn->vad_gru_state);
1331
1332     memcpy(noise_input, dense_out, rnn->model->input_dense_size * sizeof(float));
1333     memcpy(noise_input + rnn->model->input_dense_size,
1334            rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
1335     memcpy(noise_input + rnn->model->input_dense_size + rnn->model->vad_gru_size,
1336            input, INPUT_SIZE * sizeof(float));
1337
1338     compute_gru(s, rnn->model->noise_gru, rnn->noise_gru_state, noise_input);
1339
1340     memcpy(denoise_input, rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
1341     memcpy(denoise_input + rnn->model->vad_gru_size,
1342            rnn->noise_gru_state, rnn->model->noise_gru_size * sizeof(float));
1343     memcpy(denoise_input + rnn->model->vad_gru_size + rnn->model->noise_gru_size,
1344            input, INPUT_SIZE * sizeof(float));
1345
1346     compute_gru(s, rnn->model->denoise_gru, rnn->denoise_gru_state, denoise_input);
1347     compute_dense(rnn->model->denoise_output, gains, rnn->denoise_gru_state);
1348 }
1349
1350 static float rnnoise_channel(AudioRNNContext *s, DenoiseState *st, float *out, const float *in,
1351                              int disabled)
1352 {
1353     AVComplexFloat X[FREQ_SIZE];
1354     AVComplexFloat P[WINDOW_SIZE];
1355     float x[FRAME_SIZE];
1356     float Ex[NB_BANDS], Ep[NB_BANDS];
1357     LOCAL_ALIGNED_32(float, Exp, [NB_BANDS]);
1358     float features[NB_FEATURES];
1359     float g[NB_BANDS];
1360     float gf[FREQ_SIZE];
1361     float vad_prob = 0;
1362     float *history = st->history;
1363     static const float a_hp[2] = {-1.99599, 0.99600};
1364     static const float b_hp[2] = {-2, 1};
1365     int silence;
1366
1367     biquad(x, st->mem_hp_x, in, b_hp, a_hp, FRAME_SIZE);
1368     silence = compute_frame_features(s, st, X, P, Ex, Ep, Exp, features, x);
1369
1370     if (!silence && !disabled) {
1371         compute_rnn(s, &st->rnn, g, &vad_prob, features);
1372         pitch_filter(X, P, Ex, Ep, Exp, g);
1373         for (int i = 0; i < NB_BANDS; i++) {
1374             float alpha = .6f;
1375
1376             g[i] = FFMAX(g[i], alpha * st->lastg[i]);
1377             st->lastg[i] = g[i];
1378         }
1379
1380         interp_band_gain(gf, g);
1381
1382         for (int i = 0; i < FREQ_SIZE; i++) {
1383             X[i].re *= gf[i];
1384             X[i].im *= gf[i];
1385         }
1386     }
1387
1388     frame_synthesis(s, st, out, X);
1389     memcpy(history, in, FRAME_SIZE * sizeof(*history));
1390
1391     return vad_prob;
1392 }
1393
1394 typedef struct ThreadData {
1395     AVFrame *in, *out;
1396 } ThreadData;
1397
1398 static int rnnoise_channels(AVFilterContext *ctx, void *arg, int jobnr, int nb_jobs)
1399 {
1400     AudioRNNContext *s = ctx->priv;
1401     ThreadData *td = arg;
1402     AVFrame *in = td->in;
1403     AVFrame *out = td->out;
1404     const int start = (out->channels * jobnr) / nb_jobs;
1405     const int end = (out->channels * (jobnr+1)) / nb_jobs;
1406
1407     for (int ch = start; ch < end; ch++) {
1408         rnnoise_channel(s, &s->st[ch],
1409                         (float *)out->extended_data[ch],
1410                         (const float *)in->extended_data[ch],
1411                         ctx->is_disabled);
1412     }
1413
1414     return 0;
1415 }
1416
1417 static int filter_frame(AVFilterLink *inlink, AVFrame *in)
1418 {
1419     AVFilterContext *ctx = inlink->dst;
1420     AVFilterLink *outlink = ctx->outputs[0];
1421     AVFrame *out = NULL;
1422     ThreadData td;
1423
1424     out = ff_get_audio_buffer(outlink, FRAME_SIZE);
1425     if (!out) {
1426         av_frame_free(&in);
1427         return AVERROR(ENOMEM);
1428     }
1429     out->pts = in->pts;
1430
1431     td.in = in; td.out = out;
1432     ctx->internal->execute(ctx, rnnoise_channels, &td, NULL, FFMIN(outlink->channels,
1433                                                                    ff_filter_get_nb_threads(ctx)));
1434
1435     av_frame_free(&in);
1436     return ff_filter_frame(outlink, out);
1437 }
1438
1439 static int activate(AVFilterContext *ctx)
1440 {
1441     AVFilterLink *inlink = ctx->inputs[0];
1442     AVFilterLink *outlink = ctx->outputs[0];
1443     AVFrame *in = NULL;
1444     int ret;
1445
1446     FF_FILTER_FORWARD_STATUS_BACK(outlink, inlink);
1447
1448     ret = ff_inlink_consume_samples(inlink, FRAME_SIZE, FRAME_SIZE, &in);
1449     if (ret < 0)
1450         return ret;
1451
1452     if (ret > 0)
1453         return filter_frame(inlink, in);
1454
1455     FF_FILTER_FORWARD_STATUS(inlink, outlink);
1456     FF_FILTER_FORWARD_WANTED(outlink, inlink);
1457
1458     return FFERROR_NOT_READY;
1459 }
1460
1461 static av_cold int init(AVFilterContext *ctx)
1462 {
1463     AudioRNNContext *s = ctx->priv;
1464     FILE *f;
1465
1466     s->fdsp = avpriv_float_dsp_alloc(0);
1467     if (!s->fdsp)
1468         return AVERROR(ENOMEM);
1469
1470     if (!s->model_name)
1471         return AVERROR(EINVAL);
1472     f = av_fopen_utf8(s->model_name, "r");
1473     if (!f)
1474         return AVERROR(EINVAL);
1475
1476     s->model = rnnoise_model_from_file(f);
1477     fclose(f);
1478     if (!s->model)
1479         return AVERROR(EINVAL);
1480
1481     for (int i = 0; i < FRAME_SIZE; i++) {
1482         s->window[i] = sin(.5*M_PI*sin(.5*M_PI*(i+.5)/FRAME_SIZE) * sin(.5*M_PI*(i+.5)/FRAME_SIZE));
1483         s->window[WINDOW_SIZE - 1 - i] = s->window[i];
1484     }
1485
1486     for (int i = 0; i < NB_BANDS; i++) {
1487         for (int j = 0; j < NB_BANDS; j++) {
1488             s->dct_table[j][i] = cosf((i + .5f) * j * M_PI / NB_BANDS);
1489             if (j == 0)
1490                 s->dct_table[j][i] *= sqrtf(.5);
1491         }
1492     }
1493
1494     return 0;
1495 }
1496
1497 static av_cold void uninit(AVFilterContext *ctx)
1498 {
1499     AudioRNNContext *s = ctx->priv;
1500
1501     av_freep(&s->fdsp);
1502     rnnoise_model_free(s->model);
1503     s->model = NULL;
1504
1505     if (s->st) {
1506         for (int ch = 0; ch < s->channels; ch++) {
1507             av_freep(&s->st[ch].rnn.vad_gru_state);
1508             av_freep(&s->st[ch].rnn.noise_gru_state);
1509             av_freep(&s->st[ch].rnn.denoise_gru_state);
1510             av_tx_uninit(&s->st[ch].tx);
1511             av_tx_uninit(&s->st[ch].txi);
1512         }
1513     }
1514     av_freep(&s->st);
1515 }
1516
1517 static const AVFilterPad inputs[] = {
1518     {
1519         .name         = "default",
1520         .type         = AVMEDIA_TYPE_AUDIO,
1521         .config_props = config_input,
1522     },
1523     { NULL }
1524 };
1525
1526 static const AVFilterPad outputs[] = {
1527     {
1528         .name          = "default",
1529         .type          = AVMEDIA_TYPE_AUDIO,
1530     },
1531     { NULL }
1532 };
1533
1534 #define OFFSET(x) offsetof(AudioRNNContext, x)
1535 #define AF AV_OPT_FLAG_AUDIO_PARAM|AV_OPT_FLAG_FILTERING_PARAM
1536
1537 static const AVOption arnndn_options[] = {
1538     { "model", "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1539     { "m",     "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1540     { "mix",   "set output vs input mix", OFFSET(mix), AV_OPT_TYPE_FLOAT, {.dbl=1.0},-1, 1, AF },
1541     { NULL }
1542 };
1543
1544 AVFILTER_DEFINE_CLASS(arnndn);
1545
1546 AVFilter ff_af_arnndn = {
1547     .name          = "arnndn",
1548     .description   = NULL_IF_CONFIG_SMALL("Reduce noise from speech using Recurrent Neural Networks."),
1549     .query_formats = query_formats,
1550     .priv_size     = sizeof(AudioRNNContext),
1551     .priv_class    = &arnndn_class,
1552     .activate      = activate,
1553     .init          = init,
1554     .uninit        = uninit,
1555     .inputs        = inputs,
1556     .outputs       = outputs,
1557     .flags         = AVFILTER_FLAG_SUPPORT_TIMELINE_INTERNAL |
1558                      AVFILTER_FLAG_SLICE_THREADS,
1559 };