]> git.sesse.net Git - ffmpeg/blob - libavfilter/af_arnndn.c
avfilter/af_arnndn: add support for commands
[ffmpeg] / libavfilter / af_arnndn.c
1 /*
2  * Copyright (c) 2018 Gregor Richards
3  * Copyright (c) 2017 Mozilla
4  * Copyright (c) 2005-2009 Xiph.Org Foundation
5  * Copyright (c) 2007-2008 CSIRO
6  * Copyright (c) 2008-2011 Octasic Inc.
7  * Copyright (c) Jean-Marc Valin
8  * Copyright (c) 2019 Paul B Mahol
9  *
10  * Redistribution and use in source and binary forms, with or without
11  * modification, are permitted provided that the following conditions
12  * are met:
13  *
14  * - Redistributions of source code must retain the above copyright
15  *   notice, this list of conditions and the following disclaimer.
16  *
17  * - Redistributions in binary form must reproduce the above copyright
18  *   notice, this list of conditions and the following disclaimer in the
19  *   documentation and/or other materials provided with the distribution.
20  *
21  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22  * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24  * A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25  * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26  * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27  * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29  * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30  * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32  */
33
34 #include <float.h>
35
36 #include "libavutil/avassert.h"
37 #include "libavutil/avstring.h"
38 #include "libavutil/float_dsp.h"
39 #include "libavutil/mem_internal.h"
40 #include "libavutil/opt.h"
41 #include "libavutil/tx.h"
42 #include "avfilter.h"
43 #include "audio.h"
44 #include "filters.h"
45 #include "formats.h"
46
47 #define FRAME_SIZE_SHIFT 2
48 #define FRAME_SIZE (120<<FRAME_SIZE_SHIFT)
49 #define WINDOW_SIZE (2*FRAME_SIZE)
50 #define FREQ_SIZE (FRAME_SIZE + 1)
51
52 #define PITCH_MIN_PERIOD 60
53 #define PITCH_MAX_PERIOD 768
54 #define PITCH_FRAME_SIZE 960
55 #define PITCH_BUF_SIZE (PITCH_MAX_PERIOD+PITCH_FRAME_SIZE)
56
57 #define SQUARE(x) ((x)*(x))
58
59 #define NB_BANDS 22
60
61 #define CEPS_MEM 8
62 #define NB_DELTA_CEPS 6
63
64 #define NB_FEATURES (NB_BANDS+3*NB_DELTA_CEPS+2)
65
66 #define WEIGHTS_SCALE (1.f/256)
67
68 #define MAX_NEURONS 128
69
70 #define ACTIVATION_TANH    0
71 #define ACTIVATION_SIGMOID 1
72 #define ACTIVATION_RELU    2
73
74 #define Q15ONE 1.0f
75
76 typedef struct DenseLayer {
77     const float *bias;
78     const float *input_weights;
79     int nb_inputs;
80     int nb_neurons;
81     int activation;
82 } DenseLayer;
83
84 typedef struct GRULayer {
85     const float *bias;
86     const float *input_weights;
87     const float *recurrent_weights;
88     int nb_inputs;
89     int nb_neurons;
90     int activation;
91 } GRULayer;
92
93 typedef struct RNNModel {
94     int input_dense_size;
95     const DenseLayer *input_dense;
96
97     int vad_gru_size;
98     const GRULayer *vad_gru;
99
100     int noise_gru_size;
101     const GRULayer *noise_gru;
102
103     int denoise_gru_size;
104     const GRULayer *denoise_gru;
105
106     int denoise_output_size;
107     const DenseLayer *denoise_output;
108
109     int vad_output_size;
110     const DenseLayer *vad_output;
111 } RNNModel;
112
113 typedef struct RNNState {
114     float *vad_gru_state;
115     float *noise_gru_state;
116     float *denoise_gru_state;
117     RNNModel *model;
118 } RNNState;
119
120 typedef struct DenoiseState {
121     float analysis_mem[FRAME_SIZE];
122     float cepstral_mem[CEPS_MEM][NB_BANDS];
123     int memid;
124     DECLARE_ALIGNED(32, float, synthesis_mem)[FRAME_SIZE];
125     float pitch_buf[PITCH_BUF_SIZE];
126     float pitch_enh_buf[PITCH_BUF_SIZE];
127     float last_gain;
128     int last_period;
129     float mem_hp_x[2];
130     float lastg[NB_BANDS];
131     float history[FRAME_SIZE];
132     RNNState rnn[2];
133     AVTXContext *tx, *txi;
134     av_tx_fn tx_fn, txi_fn;
135 } DenoiseState;
136
137 typedef struct AudioRNNContext {
138     const AVClass *class;
139
140     char *model_name;
141     float mix;
142
143     int channels;
144     DenoiseState *st;
145
146     DECLARE_ALIGNED(32, float, window)[WINDOW_SIZE];
147     DECLARE_ALIGNED(32, float, dct_table)[FFALIGN(NB_BANDS, 4)][FFALIGN(NB_BANDS, 4)];
148
149     RNNModel *model[2];
150
151     AVFloatDSPContext *fdsp;
152 } AudioRNNContext;
153
154 #define F_ACTIVATION_TANH       0
155 #define F_ACTIVATION_SIGMOID    1
156 #define F_ACTIVATION_RELU       2
157
158 static void rnnoise_model_free(RNNModel *model)
159 {
160 #define FREE_MAYBE(ptr) do { if (ptr) free(ptr); } while (0)
161 #define FREE_DENSE(name) do { \
162     if (model->name) { \
163         av_free((void *) model->name->input_weights); \
164         av_free((void *) model->name->bias); \
165         av_free((void *) model->name); \
166     } \
167     } while (0)
168 #define FREE_GRU(name) do { \
169     if (model->name) { \
170         av_free((void *) model->name->input_weights); \
171         av_free((void *) model->name->recurrent_weights); \
172         av_free((void *) model->name->bias); \
173         av_free((void *) model->name); \
174     } \
175     } while (0)
176
177     if (!model)
178         return;
179     FREE_DENSE(input_dense);
180     FREE_GRU(vad_gru);
181     FREE_GRU(noise_gru);
182     FREE_GRU(denoise_gru);
183     FREE_DENSE(denoise_output);
184     FREE_DENSE(vad_output);
185     av_free(model);
186 }
187
188 static RNNModel *rnnoise_model_from_file(FILE *f)
189 {
190     RNNModel *ret;
191     DenseLayer *input_dense;
192     GRULayer *vad_gru;
193     GRULayer *noise_gru;
194     GRULayer *denoise_gru;
195     DenseLayer *denoise_output;
196     DenseLayer *vad_output;
197     int in;
198
199     if (fscanf(f, "rnnoise-nu model file version %d\n", &in) != 1 || in != 1)
200         return NULL;
201
202     ret = av_calloc(1, sizeof(RNNModel));
203     if (!ret)
204         return NULL;
205
206 #define ALLOC_LAYER(type, name) \
207     name = av_calloc(1, sizeof(type)); \
208     if (!name) { \
209         rnnoise_model_free(ret); \
210         return NULL; \
211     } \
212     ret->name = name
213
214     ALLOC_LAYER(DenseLayer, input_dense);
215     ALLOC_LAYER(GRULayer, vad_gru);
216     ALLOC_LAYER(GRULayer, noise_gru);
217     ALLOC_LAYER(GRULayer, denoise_gru);
218     ALLOC_LAYER(DenseLayer, denoise_output);
219     ALLOC_LAYER(DenseLayer, vad_output);
220
221 #define INPUT_VAL(name) do { \
222     if (fscanf(f, "%d", &in) != 1 || in < 0 || in > 128) { \
223         rnnoise_model_free(ret); \
224         return NULL; \
225     } \
226     name = in; \
227     } while (0)
228
229 #define INPUT_ACTIVATION(name) do { \
230     int activation; \
231     INPUT_VAL(activation); \
232     switch (activation) { \
233     case F_ACTIVATION_SIGMOID: \
234         name = ACTIVATION_SIGMOID; \
235         break; \
236     case F_ACTIVATION_RELU: \
237         name = ACTIVATION_RELU; \
238         break; \
239     default: \
240         name = ACTIVATION_TANH; \
241     } \
242     } while (0)
243
244 #define INPUT_ARRAY(name, len) do { \
245     float *values = av_calloc((len), sizeof(float)); \
246     if (!values) { \
247         rnnoise_model_free(ret); \
248         return NULL; \
249     } \
250     name = values; \
251     for (int i = 0; i < (len); i++) { \
252         if (fscanf(f, "%d", &in) != 1) { \
253             rnnoise_model_free(ret); \
254             return NULL; \
255         } \
256         values[i] = in; \
257     } \
258     } while (0)
259
260 #define INPUT_ARRAY3(name, len0, len1, len2) do { \
261     float *values = av_calloc(FFALIGN((len0), 4) * FFALIGN((len1), 4) * (len2), sizeof(float)); \
262     if (!values) { \
263         rnnoise_model_free(ret); \
264         return NULL; \
265     } \
266     name = values; \
267     for (int k = 0; k < (len0); k++) { \
268         for (int i = 0; i < (len2); i++) { \
269             for (int j = 0; j < (len1); j++) { \
270                 if (fscanf(f, "%d", &in) != 1) { \
271                     rnnoise_model_free(ret); \
272                     return NULL; \
273                 } \
274                 values[j * (len2) * FFALIGN((len0), 4) + i * FFALIGN((len0), 4) + k] = in; \
275             } \
276         } \
277     } \
278     } while (0)
279
280 #define INPUT_DENSE(name) do { \
281     INPUT_VAL(name->nb_inputs); \
282     INPUT_VAL(name->nb_neurons); \
283     ret->name ## _size = name->nb_neurons; \
284     INPUT_ACTIVATION(name->activation); \
285     INPUT_ARRAY(name->input_weights, name->nb_inputs * name->nb_neurons); \
286     INPUT_ARRAY(name->bias, name->nb_neurons); \
287     } while (0)
288
289 #define INPUT_GRU(name) do { \
290     INPUT_VAL(name->nb_inputs); \
291     INPUT_VAL(name->nb_neurons); \
292     ret->name ## _size = name->nb_neurons; \
293     INPUT_ACTIVATION(name->activation); \
294     INPUT_ARRAY3(name->input_weights, name->nb_inputs, name->nb_neurons, 3); \
295     INPUT_ARRAY3(name->recurrent_weights, name->nb_neurons, name->nb_neurons, 3); \
296     INPUT_ARRAY(name->bias, name->nb_neurons * 3); \
297     } while (0)
298
299     INPUT_DENSE(input_dense);
300     INPUT_GRU(vad_gru);
301     INPUT_GRU(noise_gru);
302     INPUT_GRU(denoise_gru);
303     INPUT_DENSE(denoise_output);
304     INPUT_DENSE(vad_output);
305
306     if (vad_output->nb_neurons != 1) {
307         rnnoise_model_free(ret);
308         return NULL;
309     }
310
311     return ret;
312 }
313
314 static int query_formats(AVFilterContext *ctx)
315 {
316     AVFilterFormats *formats = NULL;
317     AVFilterChannelLayouts *layouts = NULL;
318     static const enum AVSampleFormat sample_fmts[] = {
319         AV_SAMPLE_FMT_FLTP,
320         AV_SAMPLE_FMT_NONE
321     };
322     int ret, sample_rates[] = { 48000, -1 };
323
324     formats = ff_make_format_list(sample_fmts);
325     if (!formats)
326         return AVERROR(ENOMEM);
327     ret = ff_set_common_formats(ctx, formats);
328     if (ret < 0)
329         return ret;
330
331     layouts = ff_all_channel_counts();
332     if (!layouts)
333         return AVERROR(ENOMEM);
334
335     ret = ff_set_common_channel_layouts(ctx, layouts);
336     if (ret < 0)
337         return ret;
338
339     formats = ff_make_format_list(sample_rates);
340     if (!formats)
341         return AVERROR(ENOMEM);
342     return ff_set_common_samplerates(ctx, formats);
343 }
344
345 static int config_input(AVFilterLink *inlink)
346 {
347     AVFilterContext *ctx = inlink->dst;
348     AudioRNNContext *s = ctx->priv;
349     int ret;
350
351     s->channels = inlink->channels;
352
353     if (!s->st)
354         s->st = av_calloc(s->channels, sizeof(DenoiseState));
355     if (!s->st)
356         return AVERROR(ENOMEM);
357
358     for (int i = 0; i < s->channels; i++) {
359         DenoiseState *st = &s->st[i];
360
361         st->rnn[0].model = s->model[0];
362         st->rnn[0].vad_gru_state = av_calloc(sizeof(float), FFALIGN(s->model[0]->vad_gru_size, 16));
363         st->rnn[0].noise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model[0]->noise_gru_size, 16));
364         st->rnn[0].denoise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model[0]->denoise_gru_size, 16));
365         if (!st->rnn[0].vad_gru_state ||
366             !st->rnn[0].noise_gru_state ||
367             !st->rnn[0].denoise_gru_state)
368             return AVERROR(ENOMEM);
369     }
370
371     for (int i = 0; i < s->channels; i++) {
372         DenoiseState *st = &s->st[i];
373
374         if (!st->tx)
375             ret = av_tx_init(&st->tx, &st->tx_fn, AV_TX_FLOAT_FFT, 0, WINDOW_SIZE, NULL, 0);
376         if (ret < 0)
377             return ret;
378
379         if (!st->txi)
380             ret = av_tx_init(&st->txi, &st->txi_fn, AV_TX_FLOAT_FFT, 1, WINDOW_SIZE, NULL, 0);
381         if (ret < 0)
382             return ret;
383     }
384
385     return 0;
386 }
387
388 static void biquad(float *y, float mem[2], const float *x,
389                    const float *b, const float *a, int N)
390 {
391     for (int i = 0; i < N; i++) {
392         float xi, yi;
393
394         xi = x[i];
395         yi = x[i] + mem[0];
396         mem[0] = mem[1] + (b[0]*xi - a[0]*yi);
397         mem[1] = (b[1]*xi - a[1]*yi);
398         y[i] = yi;
399     }
400 }
401
402 #define RNN_MOVE(dst, src, n) (memmove((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
403 #define RNN_CLEAR(dst, n) (memset((dst), 0, (n)*sizeof(*(dst))))
404 #define RNN_COPY(dst, src, n) (memcpy((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
405
406 static void forward_transform(DenoiseState *st, AVComplexFloat *out, const float *in)
407 {
408     AVComplexFloat x[WINDOW_SIZE];
409     AVComplexFloat y[WINDOW_SIZE];
410
411     for (int i = 0; i < WINDOW_SIZE; i++) {
412         x[i].re = in[i];
413         x[i].im = 0;
414     }
415
416     st->tx_fn(st->tx, y, x, sizeof(float));
417
418     RNN_COPY(out, y, FREQ_SIZE);
419 }
420
421 static void inverse_transform(DenoiseState *st, float *out, const AVComplexFloat *in)
422 {
423     AVComplexFloat x[WINDOW_SIZE];
424     AVComplexFloat y[WINDOW_SIZE];
425
426     RNN_COPY(x, in, FREQ_SIZE);
427
428     for (int i = FREQ_SIZE; i < WINDOW_SIZE; i++) {
429         x[i].re =  x[WINDOW_SIZE - i].re;
430         x[i].im = -x[WINDOW_SIZE - i].im;
431     }
432
433     st->txi_fn(st->txi, y, x, sizeof(float));
434
435     for (int i = 0; i < WINDOW_SIZE; i++)
436         out[i] = y[i].re / WINDOW_SIZE;
437 }
438
439 static const uint8_t eband5ms[] = {
440 /*0  200 400 600 800  1k 1.2 1.4 1.6  2k 2.4 2.8 3.2  4k 4.8 5.6 6.8  8k 9.6 12k 15.6 20k*/
441   0,  1,  2,  3,  4,   5, 6,  7,  8,  10, 12, 14, 16, 20, 24, 28, 34, 40, 48, 60, 78, 100
442 };
443
444 static void compute_band_energy(float *bandE, const AVComplexFloat *X)
445 {
446     float sum[NB_BANDS] = {0};
447
448     for (int i = 0; i < NB_BANDS - 1; i++) {
449         int band_size;
450
451         band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
452         for (int j = 0; j < band_size; j++) {
453             float tmp, frac = (float)j / band_size;
454
455             tmp         = SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].re);
456             tmp        += SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].im);
457             sum[i]     += (1.f - frac) * tmp;
458             sum[i + 1] +=        frac  * tmp;
459         }
460     }
461
462     sum[0] *= 2;
463     sum[NB_BANDS - 1] *= 2;
464
465     for (int i = 0; i < NB_BANDS; i++)
466         bandE[i] = sum[i];
467 }
468
469 static void compute_band_corr(float *bandE, const AVComplexFloat *X, const AVComplexFloat *P)
470 {
471     float sum[NB_BANDS] = { 0 };
472
473     for (int i = 0; i < NB_BANDS - 1; i++) {
474         int band_size;
475
476         band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
477         for (int j = 0; j < band_size; j++) {
478             float tmp, frac = (float)j / band_size;
479
480             tmp  = X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re;
481             tmp += X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im;
482             sum[i]     += (1 - frac) * tmp;
483             sum[i + 1] +=      frac  * tmp;
484         }
485     }
486
487     sum[0] *= 2;
488     sum[NB_BANDS-1] *= 2;
489
490     for (int i = 0; i < NB_BANDS; i++)
491         bandE[i] = sum[i];
492 }
493
494 static void frame_analysis(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, float *Ex, const float *in)
495 {
496     LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
497
498     RNN_COPY(x, st->analysis_mem, FRAME_SIZE);
499     RNN_COPY(x + FRAME_SIZE, in, FRAME_SIZE);
500     RNN_COPY(st->analysis_mem, in, FRAME_SIZE);
501     s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
502     forward_transform(st, X, x);
503     compute_band_energy(Ex, X);
504 }
505
506 static void frame_synthesis(AudioRNNContext *s, DenoiseState *st, float *out, const AVComplexFloat *y)
507 {
508     LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
509     const float *src = st->history;
510     const float mix = s->mix;
511     const float imix = 1.f - FFMAX(mix, 0.f);
512
513     inverse_transform(st, x, y);
514     s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
515     s->fdsp->vector_fmac_scalar(x, st->synthesis_mem, 1.f, FRAME_SIZE);
516     RNN_COPY(out, x, FRAME_SIZE);
517     RNN_COPY(st->synthesis_mem, &x[FRAME_SIZE], FRAME_SIZE);
518
519     for (int n = 0; n < FRAME_SIZE; n++)
520         out[n] = out[n] * mix + src[n] * imix;
521 }
522
523 static inline void xcorr_kernel(const float *x, const float *y, float sum[4], int len)
524 {
525     float y_0, y_1, y_2, y_3 = 0;
526     int j;
527
528     y_0 = *y++;
529     y_1 = *y++;
530     y_2 = *y++;
531
532     for (j = 0; j < len - 3; j += 4) {
533         float tmp;
534
535         tmp = *x++;
536         y_3 = *y++;
537         sum[0] += tmp * y_0;
538         sum[1] += tmp * y_1;
539         sum[2] += tmp * y_2;
540         sum[3] += tmp * y_3;
541         tmp = *x++;
542         y_0 = *y++;
543         sum[0] += tmp * y_1;
544         sum[1] += tmp * y_2;
545         sum[2] += tmp * y_3;
546         sum[3] += tmp * y_0;
547         tmp = *x++;
548         y_1 = *y++;
549         sum[0] += tmp * y_2;
550         sum[1] += tmp * y_3;
551         sum[2] += tmp * y_0;
552         sum[3] += tmp * y_1;
553         tmp = *x++;
554         y_2 = *y++;
555         sum[0] += tmp * y_3;
556         sum[1] += tmp * y_0;
557         sum[2] += tmp * y_1;
558         sum[3] += tmp * y_2;
559     }
560
561     if (j++ < len) {
562         float tmp = *x++;
563
564         y_3 = *y++;
565         sum[0] += tmp * y_0;
566         sum[1] += tmp * y_1;
567         sum[2] += tmp * y_2;
568         sum[3] += tmp * y_3;
569     }
570
571     if (j++ < len) {
572         float tmp=*x++;
573
574         y_0 = *y++;
575         sum[0] += tmp * y_1;
576         sum[1] += tmp * y_2;
577         sum[2] += tmp * y_3;
578         sum[3] += tmp * y_0;
579     }
580
581     if (j < len) {
582         float tmp=*x++;
583
584         y_1 = *y++;
585         sum[0] += tmp * y_2;
586         sum[1] += tmp * y_3;
587         sum[2] += tmp * y_0;
588         sum[3] += tmp * y_1;
589     }
590 }
591
592 static inline float celt_inner_prod(const float *x,
593                                     const float *y, int N)
594 {
595     float xy = 0.f;
596
597     for (int i = 0; i < N; i++)
598         xy += x[i] * y[i];
599
600     return xy;
601 }
602
603 static void celt_pitch_xcorr(const float *x, const float *y,
604                              float *xcorr, int len, int max_pitch)
605 {
606     int i;
607
608     for (i = 0; i < max_pitch - 3; i += 4) {
609         float sum[4] = { 0, 0, 0, 0};
610
611         xcorr_kernel(x, y + i, sum, len);
612
613         xcorr[i]     = sum[0];
614         xcorr[i + 1] = sum[1];
615         xcorr[i + 2] = sum[2];
616         xcorr[i + 3] = sum[3];
617     }
618     /* In case max_pitch isn't a multiple of 4, do non-unrolled version. */
619     for (; i < max_pitch; i++) {
620         xcorr[i] = celt_inner_prod(x, y + i, len);
621     }
622 }
623
624 static int celt_autocorr(const float *x,   /*  in: [0...n-1] samples x   */
625                          float       *ac,  /* out: [0...lag-1] ac values */
626                          const float *window,
627                          int          overlap,
628                          int          lag,
629                          int          n)
630 {
631     int fastN = n - lag;
632     int shift;
633     const float *xptr;
634     float xx[PITCH_BUF_SIZE>>1];
635
636     if (overlap == 0) {
637         xptr = x;
638     } else {
639         for (int i = 0; i < n; i++)
640             xx[i] = x[i];
641         for (int i = 0; i < overlap; i++) {
642             xx[i] = x[i] * window[i];
643             xx[n-i-1] = x[n-i-1] * window[i];
644         }
645         xptr = xx;
646     }
647
648     shift = 0;
649     celt_pitch_xcorr(xptr, xptr, ac, fastN, lag+1);
650
651     for (int k = 0; k <= lag; k++) {
652         float d = 0.f;
653
654         for (int i = k + fastN; i < n; i++)
655             d += xptr[i] * xptr[i-k];
656         ac[k] += d;
657     }
658
659     return shift;
660 }
661
662 static void celt_lpc(float *lpc, /* out: [0...p-1] LPC coefficients      */
663                 const float *ac,   /* in:  [0...p] autocorrelation values  */
664                           int p)
665 {
666     float r, error = ac[0];
667
668     RNN_CLEAR(lpc, p);
669     if (ac[0] != 0) {
670         for (int i = 0; i < p; i++) {
671             /* Sum up this iteration's reflection coefficient */
672             float rr = 0;
673             for (int j = 0; j < i; j++)
674                 rr += (lpc[j] * ac[i - j]);
675             rr += ac[i + 1];
676             r = -rr/error;
677             /*  Update LPC coefficients and total error */
678             lpc[i] = r;
679             for (int j = 0; j < (i + 1) >> 1; j++) {
680                 float tmp1, tmp2;
681                 tmp1 = lpc[j];
682                 tmp2 = lpc[i-1-j];
683                 lpc[j]     = tmp1 + (r*tmp2);
684                 lpc[i-1-j] = tmp2 + (r*tmp1);
685             }
686
687             error = error - (r * r *error);
688             /* Bail out once we get 30 dB gain */
689             if (error < .001f * ac[0])
690                 break;
691         }
692     }
693 }
694
695 static void celt_fir5(const float *x,
696                       const float *num,
697                       float *y,
698                       int N,
699                       float *mem)
700 {
701     float num0, num1, num2, num3, num4;
702     float mem0, mem1, mem2, mem3, mem4;
703
704     num0 = num[0];
705     num1 = num[1];
706     num2 = num[2];
707     num3 = num[3];
708     num4 = num[4];
709     mem0 = mem[0];
710     mem1 = mem[1];
711     mem2 = mem[2];
712     mem3 = mem[3];
713     mem4 = mem[4];
714
715     for (int i = 0; i < N; i++) {
716         float sum = x[i];
717
718         sum += (num0*mem0);
719         sum += (num1*mem1);
720         sum += (num2*mem2);
721         sum += (num3*mem3);
722         sum += (num4*mem4);
723         mem4 = mem3;
724         mem3 = mem2;
725         mem2 = mem1;
726         mem1 = mem0;
727         mem0 = x[i];
728         y[i] = sum;
729     }
730
731     mem[0] = mem0;
732     mem[1] = mem1;
733     mem[2] = mem2;
734     mem[3] = mem3;
735     mem[4] = mem4;
736 }
737
738 static void pitch_downsample(float *x[], float *x_lp,
739                              int len, int C)
740 {
741     float ac[5];
742     float tmp=Q15ONE;
743     float lpc[4], mem[5]={0,0,0,0,0};
744     float lpc2[5];
745     float c1 = .8f;
746
747     for (int i = 1; i < len >> 1; i++)
748         x_lp[i] = .5f * (.5f * (x[0][(2*i-1)]+x[0][(2*i+1)])+x[0][2*i]);
749     x_lp[0] = .5f * (.5f * (x[0][1])+x[0][0]);
750     if (C==2) {
751         for (int i = 1; i < len >> 1; i++)
752             x_lp[i] += (.5f * (.5f * (x[1][(2*i-1)]+x[1][(2*i+1)])+x[1][2*i]));
753         x_lp[0] += .5f * (.5f * (x[1][1])+x[1][0]);
754     }
755
756     celt_autocorr(x_lp, ac, NULL, 0, 4, len>>1);
757
758     /* Noise floor -40 dB */
759     ac[0] *= 1.0001f;
760     /* Lag windowing */
761     for (int i = 1; i <= 4; i++) {
762         /*ac[i] *= exp(-.5*(2*M_PI*.002*i)*(2*M_PI*.002*i));*/
763         ac[i] -= ac[i]*(.008f*i)*(.008f*i);
764     }
765
766     celt_lpc(lpc, ac, 4);
767     for (int i = 0; i < 4; i++) {
768         tmp = .9f * tmp;
769         lpc[i] = (lpc[i] * tmp);
770     }
771     /* Add a zero */
772     lpc2[0] = lpc[0] + .8f;
773     lpc2[1] = lpc[1] + (c1 * lpc[0]);
774     lpc2[2] = lpc[2] + (c1 * lpc[1]);
775     lpc2[3] = lpc[3] + (c1 * lpc[2]);
776     lpc2[4] = (c1 * lpc[3]);
777     celt_fir5(x_lp, lpc2, x_lp, len>>1, mem);
778 }
779
780 static inline void dual_inner_prod(const float *x, const float *y01, const float *y02,
781                                    int N, float *xy1, float *xy2)
782 {
783     float xy01 = 0, xy02 = 0;
784
785     for (int i = 0; i < N; i++) {
786         xy01 += (x[i] * y01[i]);
787         xy02 += (x[i] * y02[i]);
788     }
789
790     *xy1 = xy01;
791     *xy2 = xy02;
792 }
793
794 static float compute_pitch_gain(float xy, float xx, float yy)
795 {
796     return xy / sqrtf(1.f + xx * yy);
797 }
798
799 static const uint8_t second_check[16] = {0, 0, 3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2};
800 static float remove_doubling(float *x, int maxperiod, int minperiod, int N,
801                              int *T0_, int prev_period, float prev_gain)
802 {
803     int k, i, T, T0;
804     float g, g0;
805     float pg;
806     float xy,xx,yy,xy2;
807     float xcorr[3];
808     float best_xy, best_yy;
809     int offset;
810     int minperiod0;
811     float yy_lookup[PITCH_MAX_PERIOD+1];
812
813     minperiod0 = minperiod;
814     maxperiod /= 2;
815     minperiod /= 2;
816     *T0_ /= 2;
817     prev_period /= 2;
818     N /= 2;
819     x += maxperiod;
820     if (*T0_>=maxperiod)
821         *T0_=maxperiod-1;
822
823     T = T0 = *T0_;
824     dual_inner_prod(x, x, x-T0, N, &xx, &xy);
825     yy_lookup[0] = xx;
826     yy=xx;
827     for (i = 1; i <= maxperiod; i++) {
828         yy = yy+(x[-i] * x[-i])-(x[N-i] * x[N-i]);
829         yy_lookup[i] = FFMAX(0, yy);
830     }
831     yy = yy_lookup[T0];
832     best_xy = xy;
833     best_yy = yy;
834     g = g0 = compute_pitch_gain(xy, xx, yy);
835     /* Look for any pitch at T/k */
836     for (k = 2; k <= 15; k++) {
837         int T1, T1b;
838         float g1;
839         float cont=0;
840         float thresh;
841         T1 = (2*T0+k)/(2*k);
842         if (T1 < minperiod)
843             break;
844         /* Look for another strong correlation at T1b */
845         if (k==2)
846         {
847             if (T1+T0>maxperiod)
848                 T1b = T0;
849             else
850                 T1b = T0+T1;
851         } else
852         {
853             T1b = (2*second_check[k]*T0+k)/(2*k);
854         }
855         dual_inner_prod(x, &x[-T1], &x[-T1b], N, &xy, &xy2);
856         xy = .5f * (xy + xy2);
857         yy = .5f * (yy_lookup[T1] + yy_lookup[T1b]);
858         g1 = compute_pitch_gain(xy, xx, yy);
859         if (FFABS(T1-prev_period)<=1)
860             cont = prev_gain;
861         else if (FFABS(T1-prev_period)<=2 && 5 * k * k < T0)
862             cont = prev_gain * .5f;
863         else
864             cont = 0;
865         thresh = FFMAX(.3f, (.7f * g0) - cont);
866         /* Bias against very high pitch (very short period) to avoid false-positives
867            due to short-term correlation */
868         if (T1<3*minperiod)
869             thresh = FFMAX(.4f, (.85f * g0) - cont);
870         else if (T1<2*minperiod)
871             thresh = FFMAX(.5f, (.9f * g0) - cont);
872         if (g1 > thresh)
873         {
874             best_xy = xy;
875             best_yy = yy;
876             T = T1;
877             g = g1;
878         }
879     }
880     best_xy = FFMAX(0, best_xy);
881     if (best_yy <= best_xy)
882         pg = Q15ONE;
883     else
884         pg = best_xy/(best_yy + 1);
885
886     for (k = 0; k < 3; k++)
887         xcorr[k] = celt_inner_prod(x, x-(T+k-1), N);
888     if ((xcorr[2]-xcorr[0]) > .7f * (xcorr[1]-xcorr[0]))
889         offset = 1;
890     else if ((xcorr[0]-xcorr[2]) > (.7f * (xcorr[1] - xcorr[2])))
891         offset = -1;
892     else
893         offset = 0;
894     if (pg > g)
895         pg = g;
896     *T0_ = 2*T+offset;
897
898     if (*T0_<minperiod0)
899         *T0_=minperiod0;
900     return pg;
901 }
902
903 static void find_best_pitch(float *xcorr, float *y, int len,
904                             int max_pitch, int *best_pitch)
905 {
906     float best_num[2];
907     float best_den[2];
908     float Syy = 1.f;
909
910     best_num[0] = -1;
911     best_num[1] = -1;
912     best_den[0] = 0;
913     best_den[1] = 0;
914     best_pitch[0] = 0;
915     best_pitch[1] = 1;
916
917     for (int j = 0; j < len; j++)
918         Syy += y[j] * y[j];
919
920     for (int i = 0; i < max_pitch; i++) {
921         if (xcorr[i]>0) {
922             float num;
923             float xcorr16;
924
925             xcorr16 = xcorr[i];
926             /* Considering the range of xcorr16, this should avoid both underflows
927                and overflows (inf) when squaring xcorr16 */
928             xcorr16 *= 1e-12f;
929             num = xcorr16 * xcorr16;
930             if ((num * best_den[1]) > (best_num[1] * Syy)) {
931                 if ((num * best_den[0]) > (best_num[0] * Syy)) {
932                     best_num[1] = best_num[0];
933                     best_den[1] = best_den[0];
934                     best_pitch[1] = best_pitch[0];
935                     best_num[0] = num;
936                     best_den[0] = Syy;
937                     best_pitch[0] = i;
938                 } else {
939                     best_num[1] = num;
940                     best_den[1] = Syy;
941                     best_pitch[1] = i;
942                 }
943             }
944         }
945         Syy += y[i+len]*y[i+len] - y[i] * y[i];
946         Syy = FFMAX(1, Syy);
947     }
948 }
949
950 static void pitch_search(const float *x_lp, float *y,
951                          int len, int max_pitch, int *pitch)
952 {
953     int lag;
954     int best_pitch[2]={0,0};
955     int offset;
956
957     float x_lp4[WINDOW_SIZE];
958     float y_lp4[WINDOW_SIZE];
959     float xcorr[WINDOW_SIZE];
960
961     lag = len+max_pitch;
962
963     /* Downsample by 2 again */
964     for (int j = 0; j < len >> 2; j++)
965         x_lp4[j] = x_lp[2*j];
966     for (int j = 0; j < lag >> 2; j++)
967         y_lp4[j] = y[2*j];
968
969     /* Coarse search with 4x decimation */
970
971     celt_pitch_xcorr(x_lp4, y_lp4, xcorr, len>>2, max_pitch>>2);
972
973     find_best_pitch(xcorr, y_lp4, len>>2, max_pitch>>2, best_pitch);
974
975     /* Finer search with 2x decimation */
976     for (int i = 0; i < max_pitch >> 1; i++) {
977         float sum;
978         xcorr[i] = 0;
979         if (FFABS(i-2*best_pitch[0])>2 && FFABS(i-2*best_pitch[1])>2)
980             continue;
981         sum = celt_inner_prod(x_lp, y+i, len>>1);
982         xcorr[i] = FFMAX(-1, sum);
983     }
984
985     find_best_pitch(xcorr, y, len>>1, max_pitch>>1, best_pitch);
986
987     /* Refine by pseudo-interpolation */
988     if (best_pitch[0] > 0 && best_pitch[0] < (max_pitch >> 1) - 1) {
989         float a, b, c;
990
991         a = xcorr[best_pitch[0] - 1];
992         b = xcorr[best_pitch[0]];
993         c = xcorr[best_pitch[0] + 1];
994         if (c - a > .7f * (b - a))
995             offset = 1;
996         else if (a - c > .7f * (b-c))
997             offset = -1;
998         else
999             offset = 0;
1000     } else {
1001         offset = 0;
1002     }
1003
1004     *pitch = 2 * best_pitch[0] - offset;
1005 }
1006
1007 static void dct(AudioRNNContext *s, float *out, const float *in)
1008 {
1009     for (int i = 0; i < NB_BANDS; i++) {
1010         float sum;
1011
1012         sum = s->fdsp->scalarproduct_float(in, s->dct_table[i], FFALIGN(NB_BANDS, 4));
1013         out[i] = sum * sqrtf(2.f / 22);
1014     }
1015 }
1016
1017 static int compute_frame_features(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, AVComplexFloat *P,
1018                                   float *Ex, float *Ep, float *Exp, float *features, const float *in)
1019 {
1020     float E = 0;
1021     float *ceps_0, *ceps_1, *ceps_2;
1022     float spec_variability = 0;
1023     LOCAL_ALIGNED_32(float, Ly, [NB_BANDS]);
1024     LOCAL_ALIGNED_32(float, p, [WINDOW_SIZE]);
1025     float pitch_buf[PITCH_BUF_SIZE>>1];
1026     int pitch_index;
1027     float gain;
1028     float *(pre[1]);
1029     float tmp[NB_BANDS];
1030     float follow, logMax;
1031
1032     frame_analysis(s, st, X, Ex, in);
1033     RNN_MOVE(st->pitch_buf, &st->pitch_buf[FRAME_SIZE], PITCH_BUF_SIZE-FRAME_SIZE);
1034     RNN_COPY(&st->pitch_buf[PITCH_BUF_SIZE-FRAME_SIZE], in, FRAME_SIZE);
1035     pre[0] = &st->pitch_buf[0];
1036     pitch_downsample(pre, pitch_buf, PITCH_BUF_SIZE, 1);
1037     pitch_search(pitch_buf+(PITCH_MAX_PERIOD>>1), pitch_buf, PITCH_FRAME_SIZE,
1038             PITCH_MAX_PERIOD-3*PITCH_MIN_PERIOD, &pitch_index);
1039     pitch_index = PITCH_MAX_PERIOD-pitch_index;
1040
1041     gain = remove_doubling(pitch_buf, PITCH_MAX_PERIOD, PITCH_MIN_PERIOD,
1042             PITCH_FRAME_SIZE, &pitch_index, st->last_period, st->last_gain);
1043     st->last_period = pitch_index;
1044     st->last_gain = gain;
1045
1046     for (int i = 0; i < WINDOW_SIZE; i++)
1047         p[i] = st->pitch_buf[PITCH_BUF_SIZE-WINDOW_SIZE-pitch_index+i];
1048
1049     s->fdsp->vector_fmul(p, p, s->window, WINDOW_SIZE);
1050     forward_transform(st, P, p);
1051     compute_band_energy(Ep, P);
1052     compute_band_corr(Exp, X, P);
1053
1054     for (int i = 0; i < NB_BANDS; i++)
1055         Exp[i] = Exp[i] / sqrtf(.001f+Ex[i]*Ep[i]);
1056
1057     dct(s, tmp, Exp);
1058
1059     for (int i = 0; i < NB_DELTA_CEPS; i++)
1060         features[NB_BANDS+2*NB_DELTA_CEPS+i] = tmp[i];
1061
1062     features[NB_BANDS+2*NB_DELTA_CEPS] -= 1.3;
1063     features[NB_BANDS+2*NB_DELTA_CEPS+1] -= 0.9;
1064     features[NB_BANDS+3*NB_DELTA_CEPS] = .01*(pitch_index-300);
1065     logMax = -2;
1066     follow = -2;
1067
1068     for (int i = 0; i < NB_BANDS; i++) {
1069         Ly[i] = log10f(1e-2f + Ex[i]);
1070         Ly[i] = FFMAX(logMax-7, FFMAX(follow-1.5, Ly[i]));
1071         logMax = FFMAX(logMax, Ly[i]);
1072         follow = FFMAX(follow-1.5, Ly[i]);
1073         E += Ex[i];
1074     }
1075
1076     if (E < 0.04f) {
1077         /* If there's no audio, avoid messing up the state. */
1078         RNN_CLEAR(features, NB_FEATURES);
1079         return 1;
1080     }
1081
1082     dct(s, features, Ly);
1083     features[0] -= 12;
1084     features[1] -= 4;
1085     ceps_0 = st->cepstral_mem[st->memid];
1086     ceps_1 = (st->memid < 1) ? st->cepstral_mem[CEPS_MEM+st->memid-1] : st->cepstral_mem[st->memid-1];
1087     ceps_2 = (st->memid < 2) ? st->cepstral_mem[CEPS_MEM+st->memid-2] : st->cepstral_mem[st->memid-2];
1088
1089     for (int i = 0; i < NB_BANDS; i++)
1090         ceps_0[i] = features[i];
1091
1092     st->memid++;
1093     for (int i = 0; i < NB_DELTA_CEPS; i++) {
1094         features[i] = ceps_0[i] + ceps_1[i] + ceps_2[i];
1095         features[NB_BANDS+i] = ceps_0[i] - ceps_2[i];
1096         features[NB_BANDS+NB_DELTA_CEPS+i] =  ceps_0[i] - 2*ceps_1[i] + ceps_2[i];
1097     }
1098     /* Spectral variability features. */
1099     if (st->memid == CEPS_MEM)
1100         st->memid = 0;
1101
1102     for (int i = 0; i < CEPS_MEM; i++) {
1103         float mindist = 1e15f;
1104         for (int j = 0; j < CEPS_MEM; j++) {
1105             float dist = 0.f;
1106             for (int k = 0; k < NB_BANDS; k++) {
1107                 float tmp;
1108
1109                 tmp = st->cepstral_mem[i][k] - st->cepstral_mem[j][k];
1110                 dist += tmp*tmp;
1111             }
1112
1113             if (j != i)
1114                 mindist = FFMIN(mindist, dist);
1115         }
1116
1117         spec_variability += mindist;
1118     }
1119
1120     features[NB_BANDS+3*NB_DELTA_CEPS+1] = spec_variability/CEPS_MEM-2.1;
1121
1122     return 0;
1123 }
1124
1125 static void interp_band_gain(float *g, const float *bandE)
1126 {
1127     memset(g, 0, sizeof(*g) * FREQ_SIZE);
1128
1129     for (int i = 0; i < NB_BANDS - 1; i++) {
1130         const int band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
1131
1132         for (int j = 0; j < band_size; j++) {
1133             float frac = (float)j / band_size;
1134
1135             g[(eband5ms[i] << FRAME_SIZE_SHIFT) + j] = (1.f - frac) * bandE[i] + frac * bandE[i + 1];
1136         }
1137     }
1138 }
1139
1140 static void pitch_filter(AVComplexFloat *X, const AVComplexFloat *P, const float *Ex, const float *Ep,
1141                          const float *Exp, const float *g)
1142 {
1143     float newE[NB_BANDS];
1144     float r[NB_BANDS];
1145     float norm[NB_BANDS];
1146     float rf[FREQ_SIZE] = {0};
1147     float normf[FREQ_SIZE]={0};
1148
1149     for (int i = 0; i < NB_BANDS; i++) {
1150         if (Exp[i]>g[i]) r[i] = 1;
1151         else r[i] = SQUARE(Exp[i])*(1-SQUARE(g[i]))/(.001 + SQUARE(g[i])*(1-SQUARE(Exp[i])));
1152         r[i]  = sqrtf(av_clipf(r[i], 0, 1));
1153         r[i] *= sqrtf(Ex[i]/(1e-8+Ep[i]));
1154     }
1155     interp_band_gain(rf, r);
1156     for (int i = 0; i < FREQ_SIZE; i++) {
1157         X[i].re += rf[i]*P[i].re;
1158         X[i].im += rf[i]*P[i].im;
1159     }
1160     compute_band_energy(newE, X);
1161     for (int i = 0; i < NB_BANDS; i++) {
1162         norm[i] = sqrtf(Ex[i] / (1e-8+newE[i]));
1163     }
1164     interp_band_gain(normf, norm);
1165     for (int i = 0; i < FREQ_SIZE; i++) {
1166         X[i].re *= normf[i];
1167         X[i].im *= normf[i];
1168     }
1169 }
1170
1171 static const float tansig_table[201] = {
1172     0.000000f, 0.039979f, 0.079830f, 0.119427f, 0.158649f,
1173     0.197375f, 0.235496f, 0.272905f, 0.309507f, 0.345214f,
1174     0.379949f, 0.413644f, 0.446244f, 0.477700f, 0.507977f,
1175     0.537050f, 0.564900f, 0.591519f, 0.616909f, 0.641077f,
1176     0.664037f, 0.685809f, 0.706419f, 0.725897f, 0.744277f,
1177     0.761594f, 0.777888f, 0.793199f, 0.807569f, 0.821040f,
1178     0.833655f, 0.845456f, 0.856485f, 0.866784f, 0.876393f,
1179     0.885352f, 0.893698f, 0.901468f, 0.908698f, 0.915420f,
1180     0.921669f, 0.927473f, 0.932862f, 0.937863f, 0.942503f,
1181     0.946806f, 0.950795f, 0.954492f, 0.957917f, 0.961090f,
1182     0.964028f, 0.966747f, 0.969265f, 0.971594f, 0.973749f,
1183     0.975743f, 0.977587f, 0.979293f, 0.980869f, 0.982327f,
1184     0.983675f, 0.984921f, 0.986072f, 0.987136f, 0.988119f,
1185     0.989027f, 0.989867f, 0.990642f, 0.991359f, 0.992020f,
1186     0.992631f, 0.993196f, 0.993718f, 0.994199f, 0.994644f,
1187     0.995055f, 0.995434f, 0.995784f, 0.996108f, 0.996407f,
1188     0.996682f, 0.996937f, 0.997172f, 0.997389f, 0.997590f,
1189     0.997775f, 0.997946f, 0.998104f, 0.998249f, 0.998384f,
1190     0.998508f, 0.998623f, 0.998728f, 0.998826f, 0.998916f,
1191     0.999000f, 0.999076f, 0.999147f, 0.999213f, 0.999273f,
1192     0.999329f, 0.999381f, 0.999428f, 0.999472f, 0.999513f,
1193     0.999550f, 0.999585f, 0.999617f, 0.999646f, 0.999673f,
1194     0.999699f, 0.999722f, 0.999743f, 0.999763f, 0.999781f,
1195     0.999798f, 0.999813f, 0.999828f, 0.999841f, 0.999853f,
1196     0.999865f, 0.999875f, 0.999885f, 0.999893f, 0.999902f,
1197     0.999909f, 0.999916f, 0.999923f, 0.999929f, 0.999934f,
1198     0.999939f, 0.999944f, 0.999948f, 0.999952f, 0.999956f,
1199     0.999959f, 0.999962f, 0.999965f, 0.999968f, 0.999970f,
1200     0.999973f, 0.999975f, 0.999977f, 0.999978f, 0.999980f,
1201     0.999982f, 0.999983f, 0.999984f, 0.999986f, 0.999987f,
1202     0.999988f, 0.999989f, 0.999990f, 0.999990f, 0.999991f,
1203     0.999992f, 0.999992f, 0.999993f, 0.999994f, 0.999994f,
1204     0.999994f, 0.999995f, 0.999995f, 0.999996f, 0.999996f,
1205     0.999996f, 0.999997f, 0.999997f, 0.999997f, 0.999997f,
1206     0.999997f, 0.999998f, 0.999998f, 0.999998f, 0.999998f,
1207     0.999998f, 0.999998f, 0.999999f, 0.999999f, 0.999999f,
1208     0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
1209     0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
1210     1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1211     1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1212     1.000000f,
1213 };
1214
1215 static inline float tansig_approx(float x)
1216 {
1217     float y, dy;
1218     float sign=1;
1219     int i;
1220
1221     /* Tests are reversed to catch NaNs */
1222     if (!(x<8))
1223         return 1;
1224     if (!(x>-8))
1225         return -1;
1226     /* Another check in case of -ffast-math */
1227
1228     if (isnan(x))
1229        return 0;
1230
1231     if (x < 0) {
1232        x=-x;
1233        sign=-1;
1234     }
1235     i = (int)floor(.5f+25*x);
1236     x -= .04f*i;
1237     y = tansig_table[i];
1238     dy = 1-y*y;
1239     y = y + x*dy*(1 - y*x);
1240     return sign*y;
1241 }
1242
1243 static inline float sigmoid_approx(float x)
1244 {
1245     return .5f + .5f*tansig_approx(.5f*x);
1246 }
1247
1248 static void compute_dense(const DenseLayer *layer, float *output, const float *input)
1249 {
1250     const int N = layer->nb_neurons, M = layer->nb_inputs, stride = N;
1251
1252     for (int i = 0; i < N; i++) {
1253         /* Compute update gate. */
1254         float sum = layer->bias[i];
1255
1256         for (int j = 0; j < M; j++)
1257             sum += layer->input_weights[j * stride + i] * input[j];
1258
1259         output[i] = WEIGHTS_SCALE * sum;
1260     }
1261
1262     if (layer->activation == ACTIVATION_SIGMOID) {
1263         for (int i = 0; i < N; i++)
1264             output[i] = sigmoid_approx(output[i]);
1265     } else if (layer->activation == ACTIVATION_TANH) {
1266         for (int i = 0; i < N; i++)
1267             output[i] = tansig_approx(output[i]);
1268     } else if (layer->activation == ACTIVATION_RELU) {
1269         for (int i = 0; i < N; i++)
1270             output[i] = FFMAX(0, output[i]);
1271     } else {
1272         av_assert0(0);
1273     }
1274 }
1275
1276 static void compute_gru(AudioRNNContext *s, const GRULayer *gru, float *state, const float *input)
1277 {
1278     LOCAL_ALIGNED_32(float, z, [MAX_NEURONS]);
1279     LOCAL_ALIGNED_32(float, r, [MAX_NEURONS]);
1280     LOCAL_ALIGNED_32(float, h, [MAX_NEURONS]);
1281     const int M = gru->nb_inputs;
1282     const int N = gru->nb_neurons;
1283     const int AN = FFALIGN(N, 4);
1284     const int AM = FFALIGN(M, 4);
1285     const int stride = 3 * AN, istride = 3 * AM;
1286
1287     for (int i = 0; i < N; i++) {
1288         /* Compute update gate. */
1289         float sum = gru->bias[i];
1290
1291         sum += s->fdsp->scalarproduct_float(gru->input_weights + i * istride, input, AM);
1292         sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + i * stride, state, AN);
1293         z[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1294     }
1295
1296     for (int i = 0; i < N; i++) {
1297         /* Compute reset gate. */
1298         float sum = gru->bias[N + i];
1299
1300         sum += s->fdsp->scalarproduct_float(gru->input_weights + AM + i * istride, input, AM);
1301         sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + AN + i * stride, state, AN);
1302         r[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1303     }
1304
1305     for (int i = 0; i < N; i++) {
1306         /* Compute output. */
1307         float sum = gru->bias[2 * N + i];
1308
1309         sum += s->fdsp->scalarproduct_float(gru->input_weights + 2 * AM + i * istride, input, AM);
1310         for (int j = 0; j < N; j++)
1311             sum += gru->recurrent_weights[2 * AN + i * stride + j] * state[j] * r[j];
1312
1313         if (gru->activation == ACTIVATION_SIGMOID)
1314             sum = sigmoid_approx(WEIGHTS_SCALE * sum);
1315         else if (gru->activation == ACTIVATION_TANH)
1316             sum = tansig_approx(WEIGHTS_SCALE * sum);
1317         else if (gru->activation == ACTIVATION_RELU)
1318             sum = FFMAX(0, WEIGHTS_SCALE * sum);
1319         else
1320             av_assert0(0);
1321         h[i] = z[i] * state[i] + (1.f - z[i]) * sum;
1322     }
1323
1324     RNN_COPY(state, h, N);
1325 }
1326
1327 #define INPUT_SIZE 42
1328
1329 static void compute_rnn(AudioRNNContext *s, RNNState *rnn, float *gains, float *vad, const float *input)
1330 {
1331     LOCAL_ALIGNED_32(float, dense_out,     [MAX_NEURONS]);
1332     LOCAL_ALIGNED_32(float, noise_input,   [MAX_NEURONS * 3]);
1333     LOCAL_ALIGNED_32(float, denoise_input, [MAX_NEURONS * 3]);
1334
1335     compute_dense(rnn->model->input_dense, dense_out, input);
1336     compute_gru(s, rnn->model->vad_gru, rnn->vad_gru_state, dense_out);
1337     compute_dense(rnn->model->vad_output, vad, rnn->vad_gru_state);
1338
1339     memcpy(noise_input, dense_out, rnn->model->input_dense_size * sizeof(float));
1340     memcpy(noise_input + rnn->model->input_dense_size,
1341            rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
1342     memcpy(noise_input + rnn->model->input_dense_size + rnn->model->vad_gru_size,
1343            input, INPUT_SIZE * sizeof(float));
1344
1345     compute_gru(s, rnn->model->noise_gru, rnn->noise_gru_state, noise_input);
1346
1347     memcpy(denoise_input, rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
1348     memcpy(denoise_input + rnn->model->vad_gru_size,
1349            rnn->noise_gru_state, rnn->model->noise_gru_size * sizeof(float));
1350     memcpy(denoise_input + rnn->model->vad_gru_size + rnn->model->noise_gru_size,
1351            input, INPUT_SIZE * sizeof(float));
1352
1353     compute_gru(s, rnn->model->denoise_gru, rnn->denoise_gru_state, denoise_input);
1354     compute_dense(rnn->model->denoise_output, gains, rnn->denoise_gru_state);
1355 }
1356
1357 static float rnnoise_channel(AudioRNNContext *s, DenoiseState *st, float *out, const float *in,
1358                              int disabled)
1359 {
1360     AVComplexFloat X[FREQ_SIZE];
1361     AVComplexFloat P[WINDOW_SIZE];
1362     float x[FRAME_SIZE];
1363     float Ex[NB_BANDS], Ep[NB_BANDS];
1364     LOCAL_ALIGNED_32(float, Exp, [NB_BANDS]);
1365     float features[NB_FEATURES];
1366     float g[NB_BANDS];
1367     float gf[FREQ_SIZE];
1368     float vad_prob = 0;
1369     float *history = st->history;
1370     static const float a_hp[2] = {-1.99599, 0.99600};
1371     static const float b_hp[2] = {-2, 1};
1372     int silence;
1373
1374     biquad(x, st->mem_hp_x, in, b_hp, a_hp, FRAME_SIZE);
1375     silence = compute_frame_features(s, st, X, P, Ex, Ep, Exp, features, x);
1376
1377     if (!silence && !disabled) {
1378         compute_rnn(s, &st->rnn[0], g, &vad_prob, features);
1379         pitch_filter(X, P, Ex, Ep, Exp, g);
1380         for (int i = 0; i < NB_BANDS; i++) {
1381             float alpha = .6f;
1382
1383             g[i] = FFMAX(g[i], alpha * st->lastg[i]);
1384             st->lastg[i] = g[i];
1385         }
1386
1387         interp_band_gain(gf, g);
1388
1389         for (int i = 0; i < FREQ_SIZE; i++) {
1390             X[i].re *= gf[i];
1391             X[i].im *= gf[i];
1392         }
1393     }
1394
1395     frame_synthesis(s, st, out, X);
1396     memcpy(history, in, FRAME_SIZE * sizeof(*history));
1397
1398     return vad_prob;
1399 }
1400
1401 typedef struct ThreadData {
1402     AVFrame *in, *out;
1403 } ThreadData;
1404
1405 static int rnnoise_channels(AVFilterContext *ctx, void *arg, int jobnr, int nb_jobs)
1406 {
1407     AudioRNNContext *s = ctx->priv;
1408     ThreadData *td = arg;
1409     AVFrame *in = td->in;
1410     AVFrame *out = td->out;
1411     const int start = (out->channels * jobnr) / nb_jobs;
1412     const int end = (out->channels * (jobnr+1)) / nb_jobs;
1413
1414     for (int ch = start; ch < end; ch++) {
1415         rnnoise_channel(s, &s->st[ch],
1416                         (float *)out->extended_data[ch],
1417                         (const float *)in->extended_data[ch],
1418                         ctx->is_disabled);
1419     }
1420
1421     return 0;
1422 }
1423
1424 static int filter_frame(AVFilterLink *inlink, AVFrame *in)
1425 {
1426     AVFilterContext *ctx = inlink->dst;
1427     AVFilterLink *outlink = ctx->outputs[0];
1428     AVFrame *out = NULL;
1429     ThreadData td;
1430
1431     out = ff_get_audio_buffer(outlink, FRAME_SIZE);
1432     if (!out) {
1433         av_frame_free(&in);
1434         return AVERROR(ENOMEM);
1435     }
1436     out->pts = in->pts;
1437
1438     td.in = in; td.out = out;
1439     ctx->internal->execute(ctx, rnnoise_channels, &td, NULL, FFMIN(outlink->channels,
1440                                                                    ff_filter_get_nb_threads(ctx)));
1441
1442     av_frame_free(&in);
1443     return ff_filter_frame(outlink, out);
1444 }
1445
1446 static int activate(AVFilterContext *ctx)
1447 {
1448     AVFilterLink *inlink = ctx->inputs[0];
1449     AVFilterLink *outlink = ctx->outputs[0];
1450     AVFrame *in = NULL;
1451     int ret;
1452
1453     FF_FILTER_FORWARD_STATUS_BACK(outlink, inlink);
1454
1455     ret = ff_inlink_consume_samples(inlink, FRAME_SIZE, FRAME_SIZE, &in);
1456     if (ret < 0)
1457         return ret;
1458
1459     if (ret > 0)
1460         return filter_frame(inlink, in);
1461
1462     FF_FILTER_FORWARD_STATUS(inlink, outlink);
1463     FF_FILTER_FORWARD_WANTED(outlink, inlink);
1464
1465     return FFERROR_NOT_READY;
1466 }
1467
1468 static int open_model(AVFilterContext *ctx, RNNModel **model)
1469 {
1470     AudioRNNContext *s = ctx->priv;
1471     FILE *f;
1472
1473     if (!s->model_name)
1474         return AVERROR(EINVAL);
1475     f = av_fopen_utf8(s->model_name, "r");
1476     if (!f)
1477         return AVERROR(EINVAL);
1478
1479     *model = rnnoise_model_from_file(f);
1480     fclose(f);
1481     if (!*model)
1482         return AVERROR(EINVAL);
1483
1484     return 0;
1485 }
1486
1487 static av_cold int init(AVFilterContext *ctx)
1488 {
1489     AudioRNNContext *s = ctx->priv;
1490     int ret;
1491
1492     s->fdsp = avpriv_float_dsp_alloc(0);
1493     if (!s->fdsp)
1494         return AVERROR(ENOMEM);
1495
1496     ret = open_model(ctx, &s->model[0]);
1497     if (ret < 0)
1498         return ret;
1499
1500     for (int i = 0; i < FRAME_SIZE; i++) {
1501         s->window[i] = sin(.5*M_PI*sin(.5*M_PI*(i+.5)/FRAME_SIZE) * sin(.5*M_PI*(i+.5)/FRAME_SIZE));
1502         s->window[WINDOW_SIZE - 1 - i] = s->window[i];
1503     }
1504
1505     for (int i = 0; i < NB_BANDS; i++) {
1506         for (int j = 0; j < NB_BANDS; j++) {
1507             s->dct_table[j][i] = cosf((i + .5f) * j * M_PI / NB_BANDS);
1508             if (j == 0)
1509                 s->dct_table[j][i] *= sqrtf(.5);
1510         }
1511     }
1512
1513     return 0;
1514 }
1515
1516 static void free_model(AVFilterContext *ctx, int n)
1517 {
1518     AudioRNNContext *s = ctx->priv;
1519
1520     rnnoise_model_free(s->model[n]);
1521     s->model[n] = NULL;
1522
1523     for (int ch = 0; ch < s->channels && s->st; ch++) {
1524         av_freep(&s->st[ch].rnn[n].vad_gru_state);
1525         av_freep(&s->st[ch].rnn[n].noise_gru_state);
1526         av_freep(&s->st[ch].rnn[n].denoise_gru_state);
1527     }
1528 }
1529
1530 static int process_command(AVFilterContext *ctx, const char *cmd, const char *args,
1531                            char *res, int res_len, int flags)
1532 {
1533     AudioRNNContext *s = ctx->priv;
1534     int ret;
1535
1536     ret = ff_filter_process_command(ctx, cmd, args, res, res_len, flags);
1537     if (ret < 0)
1538         return ret;
1539
1540     ret = open_model(ctx, &s->model[1]);
1541     if (ret < 0)
1542         return ret;
1543
1544     FFSWAP(RNNModel *, s->model[0], s->model[1]);
1545     for (int ch = 0; ch < s->channels; ch++)
1546         FFSWAP(RNNState, s->st[ch].rnn[0], s->st[ch].rnn[1]);
1547
1548     ret = config_input(ctx->inputs[0]);
1549     if (ret < 0) {
1550         for (int ch = 0; ch < s->channels; ch++)
1551             FFSWAP(RNNState, s->st[ch].rnn[0], s->st[ch].rnn[1]);
1552         FFSWAP(RNNModel *, s->model[0], s->model[1]);
1553         return ret;
1554     }
1555
1556     free_model(ctx, 1);
1557     return 0;
1558 }
1559
1560 static av_cold void uninit(AVFilterContext *ctx)
1561 {
1562     AudioRNNContext *s = ctx->priv;
1563
1564     av_freep(&s->fdsp);
1565     free_model(ctx, 0);
1566     for (int ch = 0; ch < s->channels && s->st; ch++) {
1567         av_tx_uninit(&s->st[ch].tx);
1568         av_tx_uninit(&s->st[ch].txi);
1569     }
1570     av_freep(&s->st);
1571 }
1572
1573 static const AVFilterPad inputs[] = {
1574     {
1575         .name         = "default",
1576         .type         = AVMEDIA_TYPE_AUDIO,
1577         .config_props = config_input,
1578     },
1579     { NULL }
1580 };
1581
1582 static const AVFilterPad outputs[] = {
1583     {
1584         .name          = "default",
1585         .type          = AVMEDIA_TYPE_AUDIO,
1586     },
1587     { NULL }
1588 };
1589
1590 #define OFFSET(x) offsetof(AudioRNNContext, x)
1591 #define AF AV_OPT_FLAG_AUDIO_PARAM|AV_OPT_FLAG_FILTERING_PARAM|AV_OPT_FLAG_RUNTIME_PARAM
1592
1593 static const AVOption arnndn_options[] = {
1594     { "model", "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1595     { "m",     "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1596     { "mix",   "set output vs input mix", OFFSET(mix), AV_OPT_TYPE_FLOAT, {.dbl=1.0},-1, 1, AF },
1597     { NULL }
1598 };
1599
1600 AVFILTER_DEFINE_CLASS(arnndn);
1601
1602 AVFilter ff_af_arnndn = {
1603     .name          = "arnndn",
1604     .description   = NULL_IF_CONFIG_SMALL("Reduce noise from speech using Recurrent Neural Networks."),
1605     .query_formats = query_formats,
1606     .priv_size     = sizeof(AudioRNNContext),
1607     .priv_class    = &arnndn_class,
1608     .activate      = activate,
1609     .init          = init,
1610     .uninit        = uninit,
1611     .inputs        = inputs,
1612     .outputs       = outputs,
1613     .flags         = AVFILTER_FLAG_SUPPORT_TIMELINE_INTERNAL |
1614                      AVFILTER_FLAG_SLICE_THREADS,
1615     .process_command = process_command,
1616 };