]> git.sesse.net Git - ffmpeg/blob - libavfilter/af_arnndn.c
avfilter/af_arnndn: make model opening errors more verbose
[ffmpeg] / libavfilter / af_arnndn.c
1 /*
2  * Copyright (c) 2018 Gregor Richards
3  * Copyright (c) 2017 Mozilla
4  * Copyright (c) 2005-2009 Xiph.Org Foundation
5  * Copyright (c) 2007-2008 CSIRO
6  * Copyright (c) 2008-2011 Octasic Inc.
7  * Copyright (c) Jean-Marc Valin
8  * Copyright (c) 2019 Paul B Mahol
9  *
10  * Redistribution and use in source and binary forms, with or without
11  * modification, are permitted provided that the following conditions
12  * are met:
13  *
14  * - Redistributions of source code must retain the above copyright
15  *   notice, this list of conditions and the following disclaimer.
16  *
17  * - Redistributions in binary form must reproduce the above copyright
18  *   notice, this list of conditions and the following disclaimer in the
19  *   documentation and/or other materials provided with the distribution.
20  *
21  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22  * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24  * A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25  * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26  * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27  * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29  * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30  * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32  */
33
34 #include <float.h>
35
36 #include "libavutil/avassert.h"
37 #include "libavutil/avstring.h"
38 #include "libavutil/float_dsp.h"
39 #include "libavutil/mem_internal.h"
40 #include "libavutil/opt.h"
41 #include "libavutil/tx.h"
42 #include "avfilter.h"
43 #include "audio.h"
44 #include "filters.h"
45 #include "formats.h"
46
47 #define FRAME_SIZE_SHIFT 2
48 #define FRAME_SIZE (120<<FRAME_SIZE_SHIFT)
49 #define WINDOW_SIZE (2*FRAME_SIZE)
50 #define FREQ_SIZE (FRAME_SIZE + 1)
51
52 #define PITCH_MIN_PERIOD 60
53 #define PITCH_MAX_PERIOD 768
54 #define PITCH_FRAME_SIZE 960
55 #define PITCH_BUF_SIZE (PITCH_MAX_PERIOD+PITCH_FRAME_SIZE)
56
57 #define SQUARE(x) ((x)*(x))
58
59 #define NB_BANDS 22
60
61 #define CEPS_MEM 8
62 #define NB_DELTA_CEPS 6
63
64 #define NB_FEATURES (NB_BANDS+3*NB_DELTA_CEPS+2)
65
66 #define WEIGHTS_SCALE (1.f/256)
67
68 #define MAX_NEURONS 128
69
70 #define ACTIVATION_TANH    0
71 #define ACTIVATION_SIGMOID 1
72 #define ACTIVATION_RELU    2
73
74 #define Q15ONE 1.0f
75
76 typedef struct DenseLayer {
77     const float *bias;
78     const float *input_weights;
79     int nb_inputs;
80     int nb_neurons;
81     int activation;
82 } DenseLayer;
83
84 typedef struct GRULayer {
85     const float *bias;
86     const float *input_weights;
87     const float *recurrent_weights;
88     int nb_inputs;
89     int nb_neurons;
90     int activation;
91 } GRULayer;
92
93 typedef struct RNNModel {
94     int input_dense_size;
95     const DenseLayer *input_dense;
96
97     int vad_gru_size;
98     const GRULayer *vad_gru;
99
100     int noise_gru_size;
101     const GRULayer *noise_gru;
102
103     int denoise_gru_size;
104     const GRULayer *denoise_gru;
105
106     int denoise_output_size;
107     const DenseLayer *denoise_output;
108
109     int vad_output_size;
110     const DenseLayer *vad_output;
111 } RNNModel;
112
113 typedef struct RNNState {
114     float *vad_gru_state;
115     float *noise_gru_state;
116     float *denoise_gru_state;
117     RNNModel *model;
118 } RNNState;
119
120 typedef struct DenoiseState {
121     float analysis_mem[FRAME_SIZE];
122     float cepstral_mem[CEPS_MEM][NB_BANDS];
123     int memid;
124     DECLARE_ALIGNED(32, float, synthesis_mem)[FRAME_SIZE];
125     float pitch_buf[PITCH_BUF_SIZE];
126     float pitch_enh_buf[PITCH_BUF_SIZE];
127     float last_gain;
128     int last_period;
129     float mem_hp_x[2];
130     float lastg[NB_BANDS];
131     float history[FRAME_SIZE];
132     RNNState rnn[2];
133     AVTXContext *tx, *txi;
134     av_tx_fn tx_fn, txi_fn;
135 } DenoiseState;
136
137 typedef struct AudioRNNContext {
138     const AVClass *class;
139
140     char *model_name;
141     float mix;
142
143     int channels;
144     DenoiseState *st;
145
146     DECLARE_ALIGNED(32, float, window)[WINDOW_SIZE];
147     DECLARE_ALIGNED(32, float, dct_table)[FFALIGN(NB_BANDS, 4)][FFALIGN(NB_BANDS, 4)];
148
149     RNNModel *model[2];
150
151     AVFloatDSPContext *fdsp;
152 } AudioRNNContext;
153
154 #define F_ACTIVATION_TANH       0
155 #define F_ACTIVATION_SIGMOID    1
156 #define F_ACTIVATION_RELU       2
157
158 static void rnnoise_model_free(RNNModel *model)
159 {
160 #define FREE_MAYBE(ptr) do { if (ptr) free(ptr); } while (0)
161 #define FREE_DENSE(name) do { \
162     if (model->name) { \
163         av_free((void *) model->name->input_weights); \
164         av_free((void *) model->name->bias); \
165         av_free((void *) model->name); \
166     } \
167     } while (0)
168 #define FREE_GRU(name) do { \
169     if (model->name) { \
170         av_free((void *) model->name->input_weights); \
171         av_free((void *) model->name->recurrent_weights); \
172         av_free((void *) model->name->bias); \
173         av_free((void *) model->name); \
174     } \
175     } while (0)
176
177     if (!model)
178         return;
179     FREE_DENSE(input_dense);
180     FREE_GRU(vad_gru);
181     FREE_GRU(noise_gru);
182     FREE_GRU(denoise_gru);
183     FREE_DENSE(denoise_output);
184     FREE_DENSE(vad_output);
185     av_free(model);
186 }
187
188 static int rnnoise_model_from_file(FILE *f, RNNModel **rnn)
189 {
190     RNNModel *ret;
191     DenseLayer *input_dense;
192     GRULayer *vad_gru;
193     GRULayer *noise_gru;
194     GRULayer *denoise_gru;
195     DenseLayer *denoise_output;
196     DenseLayer *vad_output;
197     int in;
198
199     if (fscanf(f, "rnnoise-nu model file version %d\n", &in) != 1 || in != 1)
200         return AVERROR_INVALIDDATA;
201
202     ret = av_calloc(1, sizeof(RNNModel));
203     if (!ret)
204         return AVERROR(ENOMEM);
205
206 #define ALLOC_LAYER(type, name) \
207     name = av_calloc(1, sizeof(type)); \
208     if (!name) { \
209         rnnoise_model_free(ret); \
210         return AVERROR(ENOMEM); \
211     } \
212     ret->name = name
213
214     ALLOC_LAYER(DenseLayer, input_dense);
215     ALLOC_LAYER(GRULayer, vad_gru);
216     ALLOC_LAYER(GRULayer, noise_gru);
217     ALLOC_LAYER(GRULayer, denoise_gru);
218     ALLOC_LAYER(DenseLayer, denoise_output);
219     ALLOC_LAYER(DenseLayer, vad_output);
220
221 #define INPUT_VAL(name) do { \
222     if (fscanf(f, "%d", &in) != 1 || in < 0 || in > 128) { \
223         rnnoise_model_free(ret); \
224         return AVERROR(EINVAL); \
225     } \
226     name = in; \
227     } while (0)
228
229 #define INPUT_ACTIVATION(name) do { \
230     int activation; \
231     INPUT_VAL(activation); \
232     switch (activation) { \
233     case F_ACTIVATION_SIGMOID: \
234         name = ACTIVATION_SIGMOID; \
235         break; \
236     case F_ACTIVATION_RELU: \
237         name = ACTIVATION_RELU; \
238         break; \
239     default: \
240         name = ACTIVATION_TANH; \
241     } \
242     } while (0)
243
244 #define INPUT_ARRAY(name, len) do { \
245     float *values = av_calloc((len), sizeof(float)); \
246     if (!values) { \
247         rnnoise_model_free(ret); \
248         return AVERROR(ENOMEM); \
249     } \
250     name = values; \
251     for (int i = 0; i < (len); i++) { \
252         if (fscanf(f, "%d", &in) != 1) { \
253             rnnoise_model_free(ret); \
254             return AVERROR(EINVAL); \
255         } \
256         values[i] = in; \
257     } \
258     } while (0)
259
260 #define INPUT_ARRAY3(name, len0, len1, len2) do { \
261     float *values = av_calloc(FFALIGN((len0), 4) * FFALIGN((len1), 4) * (len2), sizeof(float)); \
262     if (!values) { \
263         rnnoise_model_free(ret); \
264         return AVERROR(ENOMEM); \
265     } \
266     name = values; \
267     for (int k = 0; k < (len0); k++) { \
268         for (int i = 0; i < (len2); i++) { \
269             for (int j = 0; j < (len1); j++) { \
270                 if (fscanf(f, "%d", &in) != 1) { \
271                     rnnoise_model_free(ret); \
272                     return AVERROR(EINVAL); \
273                 } \
274                 values[j * (len2) * FFALIGN((len0), 4) + i * FFALIGN((len0), 4) + k] = in; \
275             } \
276         } \
277     } \
278     } while (0)
279
280 #define INPUT_DENSE(name) do { \
281     INPUT_VAL(name->nb_inputs); \
282     INPUT_VAL(name->nb_neurons); \
283     ret->name ## _size = name->nb_neurons; \
284     INPUT_ACTIVATION(name->activation); \
285     INPUT_ARRAY(name->input_weights, name->nb_inputs * name->nb_neurons); \
286     INPUT_ARRAY(name->bias, name->nb_neurons); \
287     } while (0)
288
289 #define INPUT_GRU(name) do { \
290     INPUT_VAL(name->nb_inputs); \
291     INPUT_VAL(name->nb_neurons); \
292     ret->name ## _size = name->nb_neurons; \
293     INPUT_ACTIVATION(name->activation); \
294     INPUT_ARRAY3(name->input_weights, name->nb_inputs, name->nb_neurons, 3); \
295     INPUT_ARRAY3(name->recurrent_weights, name->nb_neurons, name->nb_neurons, 3); \
296     INPUT_ARRAY(name->bias, name->nb_neurons * 3); \
297     } while (0)
298
299     INPUT_DENSE(input_dense);
300     INPUT_GRU(vad_gru);
301     INPUT_GRU(noise_gru);
302     INPUT_GRU(denoise_gru);
303     INPUT_DENSE(denoise_output);
304     INPUT_DENSE(vad_output);
305
306     if (vad_output->nb_neurons != 1) {
307         rnnoise_model_free(ret);
308         return AVERROR(EINVAL);
309     }
310
311     *rnn = ret;
312
313     return 0;
314 }
315
316 static int query_formats(AVFilterContext *ctx)
317 {
318     AVFilterFormats *formats = NULL;
319     AVFilterChannelLayouts *layouts = NULL;
320     static const enum AVSampleFormat sample_fmts[] = {
321         AV_SAMPLE_FMT_FLTP,
322         AV_SAMPLE_FMT_NONE
323     };
324     int ret, sample_rates[] = { 48000, -1 };
325
326     formats = ff_make_format_list(sample_fmts);
327     if (!formats)
328         return AVERROR(ENOMEM);
329     ret = ff_set_common_formats(ctx, formats);
330     if (ret < 0)
331         return ret;
332
333     layouts = ff_all_channel_counts();
334     if (!layouts)
335         return AVERROR(ENOMEM);
336
337     ret = ff_set_common_channel_layouts(ctx, layouts);
338     if (ret < 0)
339         return ret;
340
341     formats = ff_make_format_list(sample_rates);
342     if (!formats)
343         return AVERROR(ENOMEM);
344     return ff_set_common_samplerates(ctx, formats);
345 }
346
347 static int config_input(AVFilterLink *inlink)
348 {
349     AVFilterContext *ctx = inlink->dst;
350     AudioRNNContext *s = ctx->priv;
351     int ret;
352
353     s->channels = inlink->channels;
354
355     if (!s->st)
356         s->st = av_calloc(s->channels, sizeof(DenoiseState));
357     if (!s->st)
358         return AVERROR(ENOMEM);
359
360     for (int i = 0; i < s->channels; i++) {
361         DenoiseState *st = &s->st[i];
362
363         st->rnn[0].model = s->model[0];
364         st->rnn[0].vad_gru_state = av_calloc(sizeof(float), FFALIGN(s->model[0]->vad_gru_size, 16));
365         st->rnn[0].noise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model[0]->noise_gru_size, 16));
366         st->rnn[0].denoise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model[0]->denoise_gru_size, 16));
367         if (!st->rnn[0].vad_gru_state ||
368             !st->rnn[0].noise_gru_state ||
369             !st->rnn[0].denoise_gru_state)
370             return AVERROR(ENOMEM);
371     }
372
373     for (int i = 0; i < s->channels; i++) {
374         DenoiseState *st = &s->st[i];
375
376         if (!st->tx)
377             ret = av_tx_init(&st->tx, &st->tx_fn, AV_TX_FLOAT_FFT, 0, WINDOW_SIZE, NULL, 0);
378         if (ret < 0)
379             return ret;
380
381         if (!st->txi)
382             ret = av_tx_init(&st->txi, &st->txi_fn, AV_TX_FLOAT_FFT, 1, WINDOW_SIZE, NULL, 0);
383         if (ret < 0)
384             return ret;
385     }
386
387     return 0;
388 }
389
390 static void biquad(float *y, float mem[2], const float *x,
391                    const float *b, const float *a, int N)
392 {
393     for (int i = 0; i < N; i++) {
394         float xi, yi;
395
396         xi = x[i];
397         yi = x[i] + mem[0];
398         mem[0] = mem[1] + (b[0]*xi - a[0]*yi);
399         mem[1] = (b[1]*xi - a[1]*yi);
400         y[i] = yi;
401     }
402 }
403
404 #define RNN_MOVE(dst, src, n) (memmove((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
405 #define RNN_CLEAR(dst, n) (memset((dst), 0, (n)*sizeof(*(dst))))
406 #define RNN_COPY(dst, src, n) (memcpy((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
407
408 static void forward_transform(DenoiseState *st, AVComplexFloat *out, const float *in)
409 {
410     AVComplexFloat x[WINDOW_SIZE];
411     AVComplexFloat y[WINDOW_SIZE];
412
413     for (int i = 0; i < WINDOW_SIZE; i++) {
414         x[i].re = in[i];
415         x[i].im = 0;
416     }
417
418     st->tx_fn(st->tx, y, x, sizeof(float));
419
420     RNN_COPY(out, y, FREQ_SIZE);
421 }
422
423 static void inverse_transform(DenoiseState *st, float *out, const AVComplexFloat *in)
424 {
425     AVComplexFloat x[WINDOW_SIZE];
426     AVComplexFloat y[WINDOW_SIZE];
427
428     RNN_COPY(x, in, FREQ_SIZE);
429
430     for (int i = FREQ_SIZE; i < WINDOW_SIZE; i++) {
431         x[i].re =  x[WINDOW_SIZE - i].re;
432         x[i].im = -x[WINDOW_SIZE - i].im;
433     }
434
435     st->txi_fn(st->txi, y, x, sizeof(float));
436
437     for (int i = 0; i < WINDOW_SIZE; i++)
438         out[i] = y[i].re / WINDOW_SIZE;
439 }
440
441 static const uint8_t eband5ms[] = {
442 /*0  200 400 600 800  1k 1.2 1.4 1.6  2k 2.4 2.8 3.2  4k 4.8 5.6 6.8  8k 9.6 12k 15.6 20k*/
443   0,  1,  2,  3,  4,   5, 6,  7,  8,  10, 12, 14, 16, 20, 24, 28, 34, 40, 48, 60, 78, 100
444 };
445
446 static void compute_band_energy(float *bandE, const AVComplexFloat *X)
447 {
448     float sum[NB_BANDS] = {0};
449
450     for (int i = 0; i < NB_BANDS - 1; i++) {
451         int band_size;
452
453         band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
454         for (int j = 0; j < band_size; j++) {
455             float tmp, frac = (float)j / band_size;
456
457             tmp         = SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].re);
458             tmp        += SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].im);
459             sum[i]     += (1.f - frac) * tmp;
460             sum[i + 1] +=        frac  * tmp;
461         }
462     }
463
464     sum[0] *= 2;
465     sum[NB_BANDS - 1] *= 2;
466
467     for (int i = 0; i < NB_BANDS; i++)
468         bandE[i] = sum[i];
469 }
470
471 static void compute_band_corr(float *bandE, const AVComplexFloat *X, const AVComplexFloat *P)
472 {
473     float sum[NB_BANDS] = { 0 };
474
475     for (int i = 0; i < NB_BANDS - 1; i++) {
476         int band_size;
477
478         band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
479         for (int j = 0; j < band_size; j++) {
480             float tmp, frac = (float)j / band_size;
481
482             tmp  = X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re;
483             tmp += X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im;
484             sum[i]     += (1 - frac) * tmp;
485             sum[i + 1] +=      frac  * tmp;
486         }
487     }
488
489     sum[0] *= 2;
490     sum[NB_BANDS-1] *= 2;
491
492     for (int i = 0; i < NB_BANDS; i++)
493         bandE[i] = sum[i];
494 }
495
496 static void frame_analysis(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, float *Ex, const float *in)
497 {
498     LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
499
500     RNN_COPY(x, st->analysis_mem, FRAME_SIZE);
501     RNN_COPY(x + FRAME_SIZE, in, FRAME_SIZE);
502     RNN_COPY(st->analysis_mem, in, FRAME_SIZE);
503     s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
504     forward_transform(st, X, x);
505     compute_band_energy(Ex, X);
506 }
507
508 static void frame_synthesis(AudioRNNContext *s, DenoiseState *st, float *out, const AVComplexFloat *y)
509 {
510     LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
511     const float *src = st->history;
512     const float mix = s->mix;
513     const float imix = 1.f - FFMAX(mix, 0.f);
514
515     inverse_transform(st, x, y);
516     s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
517     s->fdsp->vector_fmac_scalar(x, st->synthesis_mem, 1.f, FRAME_SIZE);
518     RNN_COPY(out, x, FRAME_SIZE);
519     RNN_COPY(st->synthesis_mem, &x[FRAME_SIZE], FRAME_SIZE);
520
521     for (int n = 0; n < FRAME_SIZE; n++)
522         out[n] = out[n] * mix + src[n] * imix;
523 }
524
525 static inline void xcorr_kernel(const float *x, const float *y, float sum[4], int len)
526 {
527     float y_0, y_1, y_2, y_3 = 0;
528     int j;
529
530     y_0 = *y++;
531     y_1 = *y++;
532     y_2 = *y++;
533
534     for (j = 0; j < len - 3; j += 4) {
535         float tmp;
536
537         tmp = *x++;
538         y_3 = *y++;
539         sum[0] += tmp * y_0;
540         sum[1] += tmp * y_1;
541         sum[2] += tmp * y_2;
542         sum[3] += tmp * y_3;
543         tmp = *x++;
544         y_0 = *y++;
545         sum[0] += tmp * y_1;
546         sum[1] += tmp * y_2;
547         sum[2] += tmp * y_3;
548         sum[3] += tmp * y_0;
549         tmp = *x++;
550         y_1 = *y++;
551         sum[0] += tmp * y_2;
552         sum[1] += tmp * y_3;
553         sum[2] += tmp * y_0;
554         sum[3] += tmp * y_1;
555         tmp = *x++;
556         y_2 = *y++;
557         sum[0] += tmp * y_3;
558         sum[1] += tmp * y_0;
559         sum[2] += tmp * y_1;
560         sum[3] += tmp * y_2;
561     }
562
563     if (j++ < len) {
564         float tmp = *x++;
565
566         y_3 = *y++;
567         sum[0] += tmp * y_0;
568         sum[1] += tmp * y_1;
569         sum[2] += tmp * y_2;
570         sum[3] += tmp * y_3;
571     }
572
573     if (j++ < len) {
574         float tmp=*x++;
575
576         y_0 = *y++;
577         sum[0] += tmp * y_1;
578         sum[1] += tmp * y_2;
579         sum[2] += tmp * y_3;
580         sum[3] += tmp * y_0;
581     }
582
583     if (j < len) {
584         float tmp=*x++;
585
586         y_1 = *y++;
587         sum[0] += tmp * y_2;
588         sum[1] += tmp * y_3;
589         sum[2] += tmp * y_0;
590         sum[3] += tmp * y_1;
591     }
592 }
593
594 static inline float celt_inner_prod(const float *x,
595                                     const float *y, int N)
596 {
597     float xy = 0.f;
598
599     for (int i = 0; i < N; i++)
600         xy += x[i] * y[i];
601
602     return xy;
603 }
604
605 static void celt_pitch_xcorr(const float *x, const float *y,
606                              float *xcorr, int len, int max_pitch)
607 {
608     int i;
609
610     for (i = 0; i < max_pitch - 3; i += 4) {
611         float sum[4] = { 0, 0, 0, 0};
612
613         xcorr_kernel(x, y + i, sum, len);
614
615         xcorr[i]     = sum[0];
616         xcorr[i + 1] = sum[1];
617         xcorr[i + 2] = sum[2];
618         xcorr[i + 3] = sum[3];
619     }
620     /* In case max_pitch isn't a multiple of 4, do non-unrolled version. */
621     for (; i < max_pitch; i++) {
622         xcorr[i] = celt_inner_prod(x, y + i, len);
623     }
624 }
625
626 static int celt_autocorr(const float *x,   /*  in: [0...n-1] samples x   */
627                          float       *ac,  /* out: [0...lag-1] ac values */
628                          const float *window,
629                          int          overlap,
630                          int          lag,
631                          int          n)
632 {
633     int fastN = n - lag;
634     int shift;
635     const float *xptr;
636     float xx[PITCH_BUF_SIZE>>1];
637
638     if (overlap == 0) {
639         xptr = x;
640     } else {
641         for (int i = 0; i < n; i++)
642             xx[i] = x[i];
643         for (int i = 0; i < overlap; i++) {
644             xx[i] = x[i] * window[i];
645             xx[n-i-1] = x[n-i-1] * window[i];
646         }
647         xptr = xx;
648     }
649
650     shift = 0;
651     celt_pitch_xcorr(xptr, xptr, ac, fastN, lag+1);
652
653     for (int k = 0; k <= lag; k++) {
654         float d = 0.f;
655
656         for (int i = k + fastN; i < n; i++)
657             d += xptr[i] * xptr[i-k];
658         ac[k] += d;
659     }
660
661     return shift;
662 }
663
664 static void celt_lpc(float *lpc, /* out: [0...p-1] LPC coefficients      */
665                 const float *ac,   /* in:  [0...p] autocorrelation values  */
666                           int p)
667 {
668     float r, error = ac[0];
669
670     RNN_CLEAR(lpc, p);
671     if (ac[0] != 0) {
672         for (int i = 0; i < p; i++) {
673             /* Sum up this iteration's reflection coefficient */
674             float rr = 0;
675             for (int j = 0; j < i; j++)
676                 rr += (lpc[j] * ac[i - j]);
677             rr += ac[i + 1];
678             r = -rr/error;
679             /*  Update LPC coefficients and total error */
680             lpc[i] = r;
681             for (int j = 0; j < (i + 1) >> 1; j++) {
682                 float tmp1, tmp2;
683                 tmp1 = lpc[j];
684                 tmp2 = lpc[i-1-j];
685                 lpc[j]     = tmp1 + (r*tmp2);
686                 lpc[i-1-j] = tmp2 + (r*tmp1);
687             }
688
689             error = error - (r * r *error);
690             /* Bail out once we get 30 dB gain */
691             if (error < .001f * ac[0])
692                 break;
693         }
694     }
695 }
696
697 static void celt_fir5(const float *x,
698                       const float *num,
699                       float *y,
700                       int N,
701                       float *mem)
702 {
703     float num0, num1, num2, num3, num4;
704     float mem0, mem1, mem2, mem3, mem4;
705
706     num0 = num[0];
707     num1 = num[1];
708     num2 = num[2];
709     num3 = num[3];
710     num4 = num[4];
711     mem0 = mem[0];
712     mem1 = mem[1];
713     mem2 = mem[2];
714     mem3 = mem[3];
715     mem4 = mem[4];
716
717     for (int i = 0; i < N; i++) {
718         float sum = x[i];
719
720         sum += (num0*mem0);
721         sum += (num1*mem1);
722         sum += (num2*mem2);
723         sum += (num3*mem3);
724         sum += (num4*mem4);
725         mem4 = mem3;
726         mem3 = mem2;
727         mem2 = mem1;
728         mem1 = mem0;
729         mem0 = x[i];
730         y[i] = sum;
731     }
732
733     mem[0] = mem0;
734     mem[1] = mem1;
735     mem[2] = mem2;
736     mem[3] = mem3;
737     mem[4] = mem4;
738 }
739
740 static void pitch_downsample(float *x[], float *x_lp,
741                              int len, int C)
742 {
743     float ac[5];
744     float tmp=Q15ONE;
745     float lpc[4], mem[5]={0,0,0,0,0};
746     float lpc2[5];
747     float c1 = .8f;
748
749     for (int i = 1; i < len >> 1; i++)
750         x_lp[i] = .5f * (.5f * (x[0][(2*i-1)]+x[0][(2*i+1)])+x[0][2*i]);
751     x_lp[0] = .5f * (.5f * (x[0][1])+x[0][0]);
752     if (C==2) {
753         for (int i = 1; i < len >> 1; i++)
754             x_lp[i] += (.5f * (.5f * (x[1][(2*i-1)]+x[1][(2*i+1)])+x[1][2*i]));
755         x_lp[0] += .5f * (.5f * (x[1][1])+x[1][0]);
756     }
757
758     celt_autocorr(x_lp, ac, NULL, 0, 4, len>>1);
759
760     /* Noise floor -40 dB */
761     ac[0] *= 1.0001f;
762     /* Lag windowing */
763     for (int i = 1; i <= 4; i++) {
764         /*ac[i] *= exp(-.5*(2*M_PI*.002*i)*(2*M_PI*.002*i));*/
765         ac[i] -= ac[i]*(.008f*i)*(.008f*i);
766     }
767
768     celt_lpc(lpc, ac, 4);
769     for (int i = 0; i < 4; i++) {
770         tmp = .9f * tmp;
771         lpc[i] = (lpc[i] * tmp);
772     }
773     /* Add a zero */
774     lpc2[0] = lpc[0] + .8f;
775     lpc2[1] = lpc[1] + (c1 * lpc[0]);
776     lpc2[2] = lpc[2] + (c1 * lpc[1]);
777     lpc2[3] = lpc[3] + (c1 * lpc[2]);
778     lpc2[4] = (c1 * lpc[3]);
779     celt_fir5(x_lp, lpc2, x_lp, len>>1, mem);
780 }
781
782 static inline void dual_inner_prod(const float *x, const float *y01, const float *y02,
783                                    int N, float *xy1, float *xy2)
784 {
785     float xy01 = 0, xy02 = 0;
786
787     for (int i = 0; i < N; i++) {
788         xy01 += (x[i] * y01[i]);
789         xy02 += (x[i] * y02[i]);
790     }
791
792     *xy1 = xy01;
793     *xy2 = xy02;
794 }
795
796 static float compute_pitch_gain(float xy, float xx, float yy)
797 {
798     return xy / sqrtf(1.f + xx * yy);
799 }
800
801 static const uint8_t second_check[16] = {0, 0, 3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2};
802 static float remove_doubling(float *x, int maxperiod, int minperiod, int N,
803                              int *T0_, int prev_period, float prev_gain)
804 {
805     int k, i, T, T0;
806     float g, g0;
807     float pg;
808     float xy,xx,yy,xy2;
809     float xcorr[3];
810     float best_xy, best_yy;
811     int offset;
812     int minperiod0;
813     float yy_lookup[PITCH_MAX_PERIOD+1];
814
815     minperiod0 = minperiod;
816     maxperiod /= 2;
817     minperiod /= 2;
818     *T0_ /= 2;
819     prev_period /= 2;
820     N /= 2;
821     x += maxperiod;
822     if (*T0_>=maxperiod)
823         *T0_=maxperiod-1;
824
825     T = T0 = *T0_;
826     dual_inner_prod(x, x, x-T0, N, &xx, &xy);
827     yy_lookup[0] = xx;
828     yy=xx;
829     for (i = 1; i <= maxperiod; i++) {
830         yy = yy+(x[-i] * x[-i])-(x[N-i] * x[N-i]);
831         yy_lookup[i] = FFMAX(0, yy);
832     }
833     yy = yy_lookup[T0];
834     best_xy = xy;
835     best_yy = yy;
836     g = g0 = compute_pitch_gain(xy, xx, yy);
837     /* Look for any pitch at T/k */
838     for (k = 2; k <= 15; k++) {
839         int T1, T1b;
840         float g1;
841         float cont=0;
842         float thresh;
843         T1 = (2*T0+k)/(2*k);
844         if (T1 < minperiod)
845             break;
846         /* Look for another strong correlation at T1b */
847         if (k==2)
848         {
849             if (T1+T0>maxperiod)
850                 T1b = T0;
851             else
852                 T1b = T0+T1;
853         } else
854         {
855             T1b = (2*second_check[k]*T0+k)/(2*k);
856         }
857         dual_inner_prod(x, &x[-T1], &x[-T1b], N, &xy, &xy2);
858         xy = .5f * (xy + xy2);
859         yy = .5f * (yy_lookup[T1] + yy_lookup[T1b]);
860         g1 = compute_pitch_gain(xy, xx, yy);
861         if (FFABS(T1-prev_period)<=1)
862             cont = prev_gain;
863         else if (FFABS(T1-prev_period)<=2 && 5 * k * k < T0)
864             cont = prev_gain * .5f;
865         else
866             cont = 0;
867         thresh = FFMAX(.3f, (.7f * g0) - cont);
868         /* Bias against very high pitch (very short period) to avoid false-positives
869            due to short-term correlation */
870         if (T1<3*minperiod)
871             thresh = FFMAX(.4f, (.85f * g0) - cont);
872         else if (T1<2*minperiod)
873             thresh = FFMAX(.5f, (.9f * g0) - cont);
874         if (g1 > thresh)
875         {
876             best_xy = xy;
877             best_yy = yy;
878             T = T1;
879             g = g1;
880         }
881     }
882     best_xy = FFMAX(0, best_xy);
883     if (best_yy <= best_xy)
884         pg = Q15ONE;
885     else
886         pg = best_xy/(best_yy + 1);
887
888     for (k = 0; k < 3; k++)
889         xcorr[k] = celt_inner_prod(x, x-(T+k-1), N);
890     if ((xcorr[2]-xcorr[0]) > .7f * (xcorr[1]-xcorr[0]))
891         offset = 1;
892     else if ((xcorr[0]-xcorr[2]) > (.7f * (xcorr[1] - xcorr[2])))
893         offset = -1;
894     else
895         offset = 0;
896     if (pg > g)
897         pg = g;
898     *T0_ = 2*T+offset;
899
900     if (*T0_<minperiod0)
901         *T0_=minperiod0;
902     return pg;
903 }
904
905 static void find_best_pitch(float *xcorr, float *y, int len,
906                             int max_pitch, int *best_pitch)
907 {
908     float best_num[2];
909     float best_den[2];
910     float Syy = 1.f;
911
912     best_num[0] = -1;
913     best_num[1] = -1;
914     best_den[0] = 0;
915     best_den[1] = 0;
916     best_pitch[0] = 0;
917     best_pitch[1] = 1;
918
919     for (int j = 0; j < len; j++)
920         Syy += y[j] * y[j];
921
922     for (int i = 0; i < max_pitch; i++) {
923         if (xcorr[i]>0) {
924             float num;
925             float xcorr16;
926
927             xcorr16 = xcorr[i];
928             /* Considering the range of xcorr16, this should avoid both underflows
929                and overflows (inf) when squaring xcorr16 */
930             xcorr16 *= 1e-12f;
931             num = xcorr16 * xcorr16;
932             if ((num * best_den[1]) > (best_num[1] * Syy)) {
933                 if ((num * best_den[0]) > (best_num[0] * Syy)) {
934                     best_num[1] = best_num[0];
935                     best_den[1] = best_den[0];
936                     best_pitch[1] = best_pitch[0];
937                     best_num[0] = num;
938                     best_den[0] = Syy;
939                     best_pitch[0] = i;
940                 } else {
941                     best_num[1] = num;
942                     best_den[1] = Syy;
943                     best_pitch[1] = i;
944                 }
945             }
946         }
947         Syy += y[i+len]*y[i+len] - y[i] * y[i];
948         Syy = FFMAX(1, Syy);
949     }
950 }
951
952 static void pitch_search(const float *x_lp, float *y,
953                          int len, int max_pitch, int *pitch)
954 {
955     int lag;
956     int best_pitch[2]={0,0};
957     int offset;
958
959     float x_lp4[WINDOW_SIZE];
960     float y_lp4[WINDOW_SIZE];
961     float xcorr[WINDOW_SIZE];
962
963     lag = len+max_pitch;
964
965     /* Downsample by 2 again */
966     for (int j = 0; j < len >> 2; j++)
967         x_lp4[j] = x_lp[2*j];
968     for (int j = 0; j < lag >> 2; j++)
969         y_lp4[j] = y[2*j];
970
971     /* Coarse search with 4x decimation */
972
973     celt_pitch_xcorr(x_lp4, y_lp4, xcorr, len>>2, max_pitch>>2);
974
975     find_best_pitch(xcorr, y_lp4, len>>2, max_pitch>>2, best_pitch);
976
977     /* Finer search with 2x decimation */
978     for (int i = 0; i < max_pitch >> 1; i++) {
979         float sum;
980         xcorr[i] = 0;
981         if (FFABS(i-2*best_pitch[0])>2 && FFABS(i-2*best_pitch[1])>2)
982             continue;
983         sum = celt_inner_prod(x_lp, y+i, len>>1);
984         xcorr[i] = FFMAX(-1, sum);
985     }
986
987     find_best_pitch(xcorr, y, len>>1, max_pitch>>1, best_pitch);
988
989     /* Refine by pseudo-interpolation */
990     if (best_pitch[0] > 0 && best_pitch[0] < (max_pitch >> 1) - 1) {
991         float a, b, c;
992
993         a = xcorr[best_pitch[0] - 1];
994         b = xcorr[best_pitch[0]];
995         c = xcorr[best_pitch[0] + 1];
996         if (c - a > .7f * (b - a))
997             offset = 1;
998         else if (a - c > .7f * (b-c))
999             offset = -1;
1000         else
1001             offset = 0;
1002     } else {
1003         offset = 0;
1004     }
1005
1006     *pitch = 2 * best_pitch[0] - offset;
1007 }
1008
1009 static void dct(AudioRNNContext *s, float *out, const float *in)
1010 {
1011     for (int i = 0; i < NB_BANDS; i++) {
1012         float sum;
1013
1014         sum = s->fdsp->scalarproduct_float(in, s->dct_table[i], FFALIGN(NB_BANDS, 4));
1015         out[i] = sum * sqrtf(2.f / 22);
1016     }
1017 }
1018
1019 static int compute_frame_features(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, AVComplexFloat *P,
1020                                   float *Ex, float *Ep, float *Exp, float *features, const float *in)
1021 {
1022     float E = 0;
1023     float *ceps_0, *ceps_1, *ceps_2;
1024     float spec_variability = 0;
1025     LOCAL_ALIGNED_32(float, Ly, [NB_BANDS]);
1026     LOCAL_ALIGNED_32(float, p, [WINDOW_SIZE]);
1027     float pitch_buf[PITCH_BUF_SIZE>>1];
1028     int pitch_index;
1029     float gain;
1030     float *(pre[1]);
1031     float tmp[NB_BANDS];
1032     float follow, logMax;
1033
1034     frame_analysis(s, st, X, Ex, in);
1035     RNN_MOVE(st->pitch_buf, &st->pitch_buf[FRAME_SIZE], PITCH_BUF_SIZE-FRAME_SIZE);
1036     RNN_COPY(&st->pitch_buf[PITCH_BUF_SIZE-FRAME_SIZE], in, FRAME_SIZE);
1037     pre[0] = &st->pitch_buf[0];
1038     pitch_downsample(pre, pitch_buf, PITCH_BUF_SIZE, 1);
1039     pitch_search(pitch_buf+(PITCH_MAX_PERIOD>>1), pitch_buf, PITCH_FRAME_SIZE,
1040             PITCH_MAX_PERIOD-3*PITCH_MIN_PERIOD, &pitch_index);
1041     pitch_index = PITCH_MAX_PERIOD-pitch_index;
1042
1043     gain = remove_doubling(pitch_buf, PITCH_MAX_PERIOD, PITCH_MIN_PERIOD,
1044             PITCH_FRAME_SIZE, &pitch_index, st->last_period, st->last_gain);
1045     st->last_period = pitch_index;
1046     st->last_gain = gain;
1047
1048     for (int i = 0; i < WINDOW_SIZE; i++)
1049         p[i] = st->pitch_buf[PITCH_BUF_SIZE-WINDOW_SIZE-pitch_index+i];
1050
1051     s->fdsp->vector_fmul(p, p, s->window, WINDOW_SIZE);
1052     forward_transform(st, P, p);
1053     compute_band_energy(Ep, P);
1054     compute_band_corr(Exp, X, P);
1055
1056     for (int i = 0; i < NB_BANDS; i++)
1057         Exp[i] = Exp[i] / sqrtf(.001f+Ex[i]*Ep[i]);
1058
1059     dct(s, tmp, Exp);
1060
1061     for (int i = 0; i < NB_DELTA_CEPS; i++)
1062         features[NB_BANDS+2*NB_DELTA_CEPS+i] = tmp[i];
1063
1064     features[NB_BANDS+2*NB_DELTA_CEPS] -= 1.3;
1065     features[NB_BANDS+2*NB_DELTA_CEPS+1] -= 0.9;
1066     features[NB_BANDS+3*NB_DELTA_CEPS] = .01*(pitch_index-300);
1067     logMax = -2;
1068     follow = -2;
1069
1070     for (int i = 0; i < NB_BANDS; i++) {
1071         Ly[i] = log10f(1e-2f + Ex[i]);
1072         Ly[i] = FFMAX(logMax-7, FFMAX(follow-1.5, Ly[i]));
1073         logMax = FFMAX(logMax, Ly[i]);
1074         follow = FFMAX(follow-1.5, Ly[i]);
1075         E += Ex[i];
1076     }
1077
1078     if (E < 0.04f) {
1079         /* If there's no audio, avoid messing up the state. */
1080         RNN_CLEAR(features, NB_FEATURES);
1081         return 1;
1082     }
1083
1084     dct(s, features, Ly);
1085     features[0] -= 12;
1086     features[1] -= 4;
1087     ceps_0 = st->cepstral_mem[st->memid];
1088     ceps_1 = (st->memid < 1) ? st->cepstral_mem[CEPS_MEM+st->memid-1] : st->cepstral_mem[st->memid-1];
1089     ceps_2 = (st->memid < 2) ? st->cepstral_mem[CEPS_MEM+st->memid-2] : st->cepstral_mem[st->memid-2];
1090
1091     for (int i = 0; i < NB_BANDS; i++)
1092         ceps_0[i] = features[i];
1093
1094     st->memid++;
1095     for (int i = 0; i < NB_DELTA_CEPS; i++) {
1096         features[i] = ceps_0[i] + ceps_1[i] + ceps_2[i];
1097         features[NB_BANDS+i] = ceps_0[i] - ceps_2[i];
1098         features[NB_BANDS+NB_DELTA_CEPS+i] =  ceps_0[i] - 2*ceps_1[i] + ceps_2[i];
1099     }
1100     /* Spectral variability features. */
1101     if (st->memid == CEPS_MEM)
1102         st->memid = 0;
1103
1104     for (int i = 0; i < CEPS_MEM; i++) {
1105         float mindist = 1e15f;
1106         for (int j = 0; j < CEPS_MEM; j++) {
1107             float dist = 0.f;
1108             for (int k = 0; k < NB_BANDS; k++) {
1109                 float tmp;
1110
1111                 tmp = st->cepstral_mem[i][k] - st->cepstral_mem[j][k];
1112                 dist += tmp*tmp;
1113             }
1114
1115             if (j != i)
1116                 mindist = FFMIN(mindist, dist);
1117         }
1118
1119         spec_variability += mindist;
1120     }
1121
1122     features[NB_BANDS+3*NB_DELTA_CEPS+1] = spec_variability/CEPS_MEM-2.1;
1123
1124     return 0;
1125 }
1126
1127 static void interp_band_gain(float *g, const float *bandE)
1128 {
1129     memset(g, 0, sizeof(*g) * FREQ_SIZE);
1130
1131     for (int i = 0; i < NB_BANDS - 1; i++) {
1132         const int band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
1133
1134         for (int j = 0; j < band_size; j++) {
1135             float frac = (float)j / band_size;
1136
1137             g[(eband5ms[i] << FRAME_SIZE_SHIFT) + j] = (1.f - frac) * bandE[i] + frac * bandE[i + 1];
1138         }
1139     }
1140 }
1141
1142 static void pitch_filter(AVComplexFloat *X, const AVComplexFloat *P, const float *Ex, const float *Ep,
1143                          const float *Exp, const float *g)
1144 {
1145     float newE[NB_BANDS];
1146     float r[NB_BANDS];
1147     float norm[NB_BANDS];
1148     float rf[FREQ_SIZE] = {0};
1149     float normf[FREQ_SIZE]={0};
1150
1151     for (int i = 0; i < NB_BANDS; i++) {
1152         if (Exp[i]>g[i]) r[i] = 1;
1153         else r[i] = SQUARE(Exp[i])*(1-SQUARE(g[i]))/(.001 + SQUARE(g[i])*(1-SQUARE(Exp[i])));
1154         r[i]  = sqrtf(av_clipf(r[i], 0, 1));
1155         r[i] *= sqrtf(Ex[i]/(1e-8+Ep[i]));
1156     }
1157     interp_band_gain(rf, r);
1158     for (int i = 0; i < FREQ_SIZE; i++) {
1159         X[i].re += rf[i]*P[i].re;
1160         X[i].im += rf[i]*P[i].im;
1161     }
1162     compute_band_energy(newE, X);
1163     for (int i = 0; i < NB_BANDS; i++) {
1164         norm[i] = sqrtf(Ex[i] / (1e-8+newE[i]));
1165     }
1166     interp_band_gain(normf, norm);
1167     for (int i = 0; i < FREQ_SIZE; i++) {
1168         X[i].re *= normf[i];
1169         X[i].im *= normf[i];
1170     }
1171 }
1172
1173 static const float tansig_table[201] = {
1174     0.000000f, 0.039979f, 0.079830f, 0.119427f, 0.158649f,
1175     0.197375f, 0.235496f, 0.272905f, 0.309507f, 0.345214f,
1176     0.379949f, 0.413644f, 0.446244f, 0.477700f, 0.507977f,
1177     0.537050f, 0.564900f, 0.591519f, 0.616909f, 0.641077f,
1178     0.664037f, 0.685809f, 0.706419f, 0.725897f, 0.744277f,
1179     0.761594f, 0.777888f, 0.793199f, 0.807569f, 0.821040f,
1180     0.833655f, 0.845456f, 0.856485f, 0.866784f, 0.876393f,
1181     0.885352f, 0.893698f, 0.901468f, 0.908698f, 0.915420f,
1182     0.921669f, 0.927473f, 0.932862f, 0.937863f, 0.942503f,
1183     0.946806f, 0.950795f, 0.954492f, 0.957917f, 0.961090f,
1184     0.964028f, 0.966747f, 0.969265f, 0.971594f, 0.973749f,
1185     0.975743f, 0.977587f, 0.979293f, 0.980869f, 0.982327f,
1186     0.983675f, 0.984921f, 0.986072f, 0.987136f, 0.988119f,
1187     0.989027f, 0.989867f, 0.990642f, 0.991359f, 0.992020f,
1188     0.992631f, 0.993196f, 0.993718f, 0.994199f, 0.994644f,
1189     0.995055f, 0.995434f, 0.995784f, 0.996108f, 0.996407f,
1190     0.996682f, 0.996937f, 0.997172f, 0.997389f, 0.997590f,
1191     0.997775f, 0.997946f, 0.998104f, 0.998249f, 0.998384f,
1192     0.998508f, 0.998623f, 0.998728f, 0.998826f, 0.998916f,
1193     0.999000f, 0.999076f, 0.999147f, 0.999213f, 0.999273f,
1194     0.999329f, 0.999381f, 0.999428f, 0.999472f, 0.999513f,
1195     0.999550f, 0.999585f, 0.999617f, 0.999646f, 0.999673f,
1196     0.999699f, 0.999722f, 0.999743f, 0.999763f, 0.999781f,
1197     0.999798f, 0.999813f, 0.999828f, 0.999841f, 0.999853f,
1198     0.999865f, 0.999875f, 0.999885f, 0.999893f, 0.999902f,
1199     0.999909f, 0.999916f, 0.999923f, 0.999929f, 0.999934f,
1200     0.999939f, 0.999944f, 0.999948f, 0.999952f, 0.999956f,
1201     0.999959f, 0.999962f, 0.999965f, 0.999968f, 0.999970f,
1202     0.999973f, 0.999975f, 0.999977f, 0.999978f, 0.999980f,
1203     0.999982f, 0.999983f, 0.999984f, 0.999986f, 0.999987f,
1204     0.999988f, 0.999989f, 0.999990f, 0.999990f, 0.999991f,
1205     0.999992f, 0.999992f, 0.999993f, 0.999994f, 0.999994f,
1206     0.999994f, 0.999995f, 0.999995f, 0.999996f, 0.999996f,
1207     0.999996f, 0.999997f, 0.999997f, 0.999997f, 0.999997f,
1208     0.999997f, 0.999998f, 0.999998f, 0.999998f, 0.999998f,
1209     0.999998f, 0.999998f, 0.999999f, 0.999999f, 0.999999f,
1210     0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
1211     0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
1212     1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1213     1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1214     1.000000f,
1215 };
1216
1217 static inline float tansig_approx(float x)
1218 {
1219     float y, dy;
1220     float sign=1;
1221     int i;
1222
1223     /* Tests are reversed to catch NaNs */
1224     if (!(x<8))
1225         return 1;
1226     if (!(x>-8))
1227         return -1;
1228     /* Another check in case of -ffast-math */
1229
1230     if (isnan(x))
1231        return 0;
1232
1233     if (x < 0) {
1234        x=-x;
1235        sign=-1;
1236     }
1237     i = (int)floor(.5f+25*x);
1238     x -= .04f*i;
1239     y = tansig_table[i];
1240     dy = 1-y*y;
1241     y = y + x*dy*(1 - y*x);
1242     return sign*y;
1243 }
1244
1245 static inline float sigmoid_approx(float x)
1246 {
1247     return .5f + .5f*tansig_approx(.5f*x);
1248 }
1249
1250 static void compute_dense(const DenseLayer *layer, float *output, const float *input)
1251 {
1252     const int N = layer->nb_neurons, M = layer->nb_inputs, stride = N;
1253
1254     for (int i = 0; i < N; i++) {
1255         /* Compute update gate. */
1256         float sum = layer->bias[i];
1257
1258         for (int j = 0; j < M; j++)
1259             sum += layer->input_weights[j * stride + i] * input[j];
1260
1261         output[i] = WEIGHTS_SCALE * sum;
1262     }
1263
1264     if (layer->activation == ACTIVATION_SIGMOID) {
1265         for (int i = 0; i < N; i++)
1266             output[i] = sigmoid_approx(output[i]);
1267     } else if (layer->activation == ACTIVATION_TANH) {
1268         for (int i = 0; i < N; i++)
1269             output[i] = tansig_approx(output[i]);
1270     } else if (layer->activation == ACTIVATION_RELU) {
1271         for (int i = 0; i < N; i++)
1272             output[i] = FFMAX(0, output[i]);
1273     } else {
1274         av_assert0(0);
1275     }
1276 }
1277
1278 static void compute_gru(AudioRNNContext *s, const GRULayer *gru, float *state, const float *input)
1279 {
1280     LOCAL_ALIGNED_32(float, z, [MAX_NEURONS]);
1281     LOCAL_ALIGNED_32(float, r, [MAX_NEURONS]);
1282     LOCAL_ALIGNED_32(float, h, [MAX_NEURONS]);
1283     const int M = gru->nb_inputs;
1284     const int N = gru->nb_neurons;
1285     const int AN = FFALIGN(N, 4);
1286     const int AM = FFALIGN(M, 4);
1287     const int stride = 3 * AN, istride = 3 * AM;
1288
1289     for (int i = 0; i < N; i++) {
1290         /* Compute update gate. */
1291         float sum = gru->bias[i];
1292
1293         sum += s->fdsp->scalarproduct_float(gru->input_weights + i * istride, input, AM);
1294         sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + i * stride, state, AN);
1295         z[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1296     }
1297
1298     for (int i = 0; i < N; i++) {
1299         /* Compute reset gate. */
1300         float sum = gru->bias[N + i];
1301
1302         sum += s->fdsp->scalarproduct_float(gru->input_weights + AM + i * istride, input, AM);
1303         sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + AN + i * stride, state, AN);
1304         r[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1305     }
1306
1307     for (int i = 0; i < N; i++) {
1308         /* Compute output. */
1309         float sum = gru->bias[2 * N + i];
1310
1311         sum += s->fdsp->scalarproduct_float(gru->input_weights + 2 * AM + i * istride, input, AM);
1312         for (int j = 0; j < N; j++)
1313             sum += gru->recurrent_weights[2 * AN + i * stride + j] * state[j] * r[j];
1314
1315         if (gru->activation == ACTIVATION_SIGMOID)
1316             sum = sigmoid_approx(WEIGHTS_SCALE * sum);
1317         else if (gru->activation == ACTIVATION_TANH)
1318             sum = tansig_approx(WEIGHTS_SCALE * sum);
1319         else if (gru->activation == ACTIVATION_RELU)
1320             sum = FFMAX(0, WEIGHTS_SCALE * sum);
1321         else
1322             av_assert0(0);
1323         h[i] = z[i] * state[i] + (1.f - z[i]) * sum;
1324     }
1325
1326     RNN_COPY(state, h, N);
1327 }
1328
1329 #define INPUT_SIZE 42
1330
1331 static void compute_rnn(AudioRNNContext *s, RNNState *rnn, float *gains, float *vad, const float *input)
1332 {
1333     LOCAL_ALIGNED_32(float, dense_out,     [MAX_NEURONS]);
1334     LOCAL_ALIGNED_32(float, noise_input,   [MAX_NEURONS * 3]);
1335     LOCAL_ALIGNED_32(float, denoise_input, [MAX_NEURONS * 3]);
1336
1337     compute_dense(rnn->model->input_dense, dense_out, input);
1338     compute_gru(s, rnn->model->vad_gru, rnn->vad_gru_state, dense_out);
1339     compute_dense(rnn->model->vad_output, vad, rnn->vad_gru_state);
1340
1341     memcpy(noise_input, dense_out, rnn->model->input_dense_size * sizeof(float));
1342     memcpy(noise_input + rnn->model->input_dense_size,
1343            rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
1344     memcpy(noise_input + rnn->model->input_dense_size + rnn->model->vad_gru_size,
1345            input, INPUT_SIZE * sizeof(float));
1346
1347     compute_gru(s, rnn->model->noise_gru, rnn->noise_gru_state, noise_input);
1348
1349     memcpy(denoise_input, rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
1350     memcpy(denoise_input + rnn->model->vad_gru_size,
1351            rnn->noise_gru_state, rnn->model->noise_gru_size * sizeof(float));
1352     memcpy(denoise_input + rnn->model->vad_gru_size + rnn->model->noise_gru_size,
1353            input, INPUT_SIZE * sizeof(float));
1354
1355     compute_gru(s, rnn->model->denoise_gru, rnn->denoise_gru_state, denoise_input);
1356     compute_dense(rnn->model->denoise_output, gains, rnn->denoise_gru_state);
1357 }
1358
1359 static float rnnoise_channel(AudioRNNContext *s, DenoiseState *st, float *out, const float *in,
1360                              int disabled)
1361 {
1362     AVComplexFloat X[FREQ_SIZE];
1363     AVComplexFloat P[WINDOW_SIZE];
1364     float x[FRAME_SIZE];
1365     float Ex[NB_BANDS], Ep[NB_BANDS];
1366     LOCAL_ALIGNED_32(float, Exp, [NB_BANDS]);
1367     float features[NB_FEATURES];
1368     float g[NB_BANDS];
1369     float gf[FREQ_SIZE];
1370     float vad_prob = 0;
1371     float *history = st->history;
1372     static const float a_hp[2] = {-1.99599, 0.99600};
1373     static const float b_hp[2] = {-2, 1};
1374     int silence;
1375
1376     biquad(x, st->mem_hp_x, in, b_hp, a_hp, FRAME_SIZE);
1377     silence = compute_frame_features(s, st, X, P, Ex, Ep, Exp, features, x);
1378
1379     if (!silence && !disabled) {
1380         compute_rnn(s, &st->rnn[0], g, &vad_prob, features);
1381         pitch_filter(X, P, Ex, Ep, Exp, g);
1382         for (int i = 0; i < NB_BANDS; i++) {
1383             float alpha = .6f;
1384
1385             g[i] = FFMAX(g[i], alpha * st->lastg[i]);
1386             st->lastg[i] = g[i];
1387         }
1388
1389         interp_band_gain(gf, g);
1390
1391         for (int i = 0; i < FREQ_SIZE; i++) {
1392             X[i].re *= gf[i];
1393             X[i].im *= gf[i];
1394         }
1395     }
1396
1397     frame_synthesis(s, st, out, X);
1398     memcpy(history, in, FRAME_SIZE * sizeof(*history));
1399
1400     return vad_prob;
1401 }
1402
1403 typedef struct ThreadData {
1404     AVFrame *in, *out;
1405 } ThreadData;
1406
1407 static int rnnoise_channels(AVFilterContext *ctx, void *arg, int jobnr, int nb_jobs)
1408 {
1409     AudioRNNContext *s = ctx->priv;
1410     ThreadData *td = arg;
1411     AVFrame *in = td->in;
1412     AVFrame *out = td->out;
1413     const int start = (out->channels * jobnr) / nb_jobs;
1414     const int end = (out->channels * (jobnr+1)) / nb_jobs;
1415
1416     for (int ch = start; ch < end; ch++) {
1417         rnnoise_channel(s, &s->st[ch],
1418                         (float *)out->extended_data[ch],
1419                         (const float *)in->extended_data[ch],
1420                         ctx->is_disabled);
1421     }
1422
1423     return 0;
1424 }
1425
1426 static int filter_frame(AVFilterLink *inlink, AVFrame *in)
1427 {
1428     AVFilterContext *ctx = inlink->dst;
1429     AVFilterLink *outlink = ctx->outputs[0];
1430     AVFrame *out = NULL;
1431     ThreadData td;
1432
1433     out = ff_get_audio_buffer(outlink, FRAME_SIZE);
1434     if (!out) {
1435         av_frame_free(&in);
1436         return AVERROR(ENOMEM);
1437     }
1438     out->pts = in->pts;
1439
1440     td.in = in; td.out = out;
1441     ctx->internal->execute(ctx, rnnoise_channels, &td, NULL, FFMIN(outlink->channels,
1442                                                                    ff_filter_get_nb_threads(ctx)));
1443
1444     av_frame_free(&in);
1445     return ff_filter_frame(outlink, out);
1446 }
1447
1448 static int activate(AVFilterContext *ctx)
1449 {
1450     AVFilterLink *inlink = ctx->inputs[0];
1451     AVFilterLink *outlink = ctx->outputs[0];
1452     AVFrame *in = NULL;
1453     int ret;
1454
1455     FF_FILTER_FORWARD_STATUS_BACK(outlink, inlink);
1456
1457     ret = ff_inlink_consume_samples(inlink, FRAME_SIZE, FRAME_SIZE, &in);
1458     if (ret < 0)
1459         return ret;
1460
1461     if (ret > 0)
1462         return filter_frame(inlink, in);
1463
1464     FF_FILTER_FORWARD_STATUS(inlink, outlink);
1465     FF_FILTER_FORWARD_WANTED(outlink, inlink);
1466
1467     return FFERROR_NOT_READY;
1468 }
1469
1470 static int open_model(AVFilterContext *ctx, RNNModel **model)
1471 {
1472     AudioRNNContext *s = ctx->priv;
1473     int ret;
1474     FILE *f;
1475
1476     if (!s->model_name)
1477         return AVERROR(EINVAL);
1478     f = av_fopen_utf8(s->model_name, "r");
1479     if (!f) {
1480         av_log(ctx, AV_LOG_ERROR, "Failed to open model file: %s\n", s->model_name);
1481         return AVERROR(EINVAL);
1482     }
1483
1484     ret = rnnoise_model_from_file(f, model);
1485     fclose(f);
1486     if (!*model || ret < 0)
1487         return ret;
1488
1489     return 0;
1490 }
1491
1492 static av_cold int init(AVFilterContext *ctx)
1493 {
1494     AudioRNNContext *s = ctx->priv;
1495     int ret;
1496
1497     s->fdsp = avpriv_float_dsp_alloc(0);
1498     if (!s->fdsp)
1499         return AVERROR(ENOMEM);
1500
1501     ret = open_model(ctx, &s->model[0]);
1502     if (ret < 0)
1503         return ret;
1504
1505     for (int i = 0; i < FRAME_SIZE; i++) {
1506         s->window[i] = sin(.5*M_PI*sin(.5*M_PI*(i+.5)/FRAME_SIZE) * sin(.5*M_PI*(i+.5)/FRAME_SIZE));
1507         s->window[WINDOW_SIZE - 1 - i] = s->window[i];
1508     }
1509
1510     for (int i = 0; i < NB_BANDS; i++) {
1511         for (int j = 0; j < NB_BANDS; j++) {
1512             s->dct_table[j][i] = cosf((i + .5f) * j * M_PI / NB_BANDS);
1513             if (j == 0)
1514                 s->dct_table[j][i] *= sqrtf(.5);
1515         }
1516     }
1517
1518     return 0;
1519 }
1520
1521 static void free_model(AVFilterContext *ctx, int n)
1522 {
1523     AudioRNNContext *s = ctx->priv;
1524
1525     rnnoise_model_free(s->model[n]);
1526     s->model[n] = NULL;
1527
1528     for (int ch = 0; ch < s->channels && s->st; ch++) {
1529         av_freep(&s->st[ch].rnn[n].vad_gru_state);
1530         av_freep(&s->st[ch].rnn[n].noise_gru_state);
1531         av_freep(&s->st[ch].rnn[n].denoise_gru_state);
1532     }
1533 }
1534
1535 static int process_command(AVFilterContext *ctx, const char *cmd, const char *args,
1536                            char *res, int res_len, int flags)
1537 {
1538     AudioRNNContext *s = ctx->priv;
1539     int ret;
1540
1541     ret = ff_filter_process_command(ctx, cmd, args, res, res_len, flags);
1542     if (ret < 0)
1543         return ret;
1544
1545     ret = open_model(ctx, &s->model[1]);
1546     if (ret < 0)
1547         return ret;
1548
1549     FFSWAP(RNNModel *, s->model[0], s->model[1]);
1550     for (int ch = 0; ch < s->channels; ch++)
1551         FFSWAP(RNNState, s->st[ch].rnn[0], s->st[ch].rnn[1]);
1552
1553     ret = config_input(ctx->inputs[0]);
1554     if (ret < 0) {
1555         for (int ch = 0; ch < s->channels; ch++)
1556             FFSWAP(RNNState, s->st[ch].rnn[0], s->st[ch].rnn[1]);
1557         FFSWAP(RNNModel *, s->model[0], s->model[1]);
1558         return ret;
1559     }
1560
1561     free_model(ctx, 1);
1562     return 0;
1563 }
1564
1565 static av_cold void uninit(AVFilterContext *ctx)
1566 {
1567     AudioRNNContext *s = ctx->priv;
1568
1569     av_freep(&s->fdsp);
1570     free_model(ctx, 0);
1571     for (int ch = 0; ch < s->channels && s->st; ch++) {
1572         av_tx_uninit(&s->st[ch].tx);
1573         av_tx_uninit(&s->st[ch].txi);
1574     }
1575     av_freep(&s->st);
1576 }
1577
1578 static const AVFilterPad inputs[] = {
1579     {
1580         .name         = "default",
1581         .type         = AVMEDIA_TYPE_AUDIO,
1582         .config_props = config_input,
1583     },
1584     { NULL }
1585 };
1586
1587 static const AVFilterPad outputs[] = {
1588     {
1589         .name          = "default",
1590         .type          = AVMEDIA_TYPE_AUDIO,
1591     },
1592     { NULL }
1593 };
1594
1595 #define OFFSET(x) offsetof(AudioRNNContext, x)
1596 #define AF AV_OPT_FLAG_AUDIO_PARAM|AV_OPT_FLAG_FILTERING_PARAM|AV_OPT_FLAG_RUNTIME_PARAM
1597
1598 static const AVOption arnndn_options[] = {
1599     { "model", "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1600     { "m",     "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1601     { "mix",   "set output vs input mix", OFFSET(mix), AV_OPT_TYPE_FLOAT, {.dbl=1.0},-1, 1, AF },
1602     { NULL }
1603 };
1604
1605 AVFILTER_DEFINE_CLASS(arnndn);
1606
1607 AVFilter ff_af_arnndn = {
1608     .name          = "arnndn",
1609     .description   = NULL_IF_CONFIG_SMALL("Reduce noise from speech using Recurrent Neural Networks."),
1610     .query_formats = query_formats,
1611     .priv_size     = sizeof(AudioRNNContext),
1612     .priv_class    = &arnndn_class,
1613     .activate      = activate,
1614     .init          = init,
1615     .uninit        = uninit,
1616     .inputs        = inputs,
1617     .outputs       = outputs,
1618     .flags         = AVFILTER_FLAG_SUPPORT_TIMELINE_INTERNAL |
1619                      AVFILTER_FLAG_SLICE_THREADS,
1620     .process_command = process_command,
1621 };