2 * Copyright (c) 2018 Gregor Richards
3 * Copyright (c) 2017 Mozilla
4 * Copyright (c) 2005-2009 Xiph.Org Foundation
5 * Copyright (c) 2007-2008 CSIRO
6 * Copyright (c) 2008-2011 Octasic Inc.
7 * Copyright (c) Jean-Marc Valin
8 * Copyright (c) 2019 Paul B Mahol
10 * Redistribution and use in source and binary forms, with or without
11 * modification, are permitted provided that the following conditions
14 * - Redistributions of source code must retain the above copyright
15 * notice, this list of conditions and the following disclaimer.
17 * - Redistributions in binary form must reproduce the above copyright
18 * notice, this list of conditions and the following disclaimer in the
19 * documentation and/or other materials provided with the distribution.
21 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR
25 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
36 #include "libavutil/avassert.h"
37 #include "libavutil/avstring.h"
38 #include "libavutil/float_dsp.h"
39 #include "libavutil/opt.h"
40 #include "libavutil/tx.h"
46 #define FRAME_SIZE_SHIFT 2
47 #define FRAME_SIZE (120<<FRAME_SIZE_SHIFT)
48 #define WINDOW_SIZE (2*FRAME_SIZE)
49 #define FREQ_SIZE (FRAME_SIZE + 1)
51 #define PITCH_MIN_PERIOD 60
52 #define PITCH_MAX_PERIOD 768
53 #define PITCH_FRAME_SIZE 960
54 #define PITCH_BUF_SIZE (PITCH_MAX_PERIOD+PITCH_FRAME_SIZE)
56 #define SQUARE(x) ((x)*(x))
61 #define NB_DELTA_CEPS 6
63 #define NB_FEATURES (NB_BANDS+3*NB_DELTA_CEPS+2)
65 #define WEIGHTS_SCALE (1.f/256)
67 #define MAX_NEURONS 128
69 #define ACTIVATION_TANH 0
70 #define ACTIVATION_SIGMOID 1
71 #define ACTIVATION_RELU 2
75 typedef struct DenseLayer {
77 const float *input_weights;
83 typedef struct GRULayer {
85 const float *input_weights;
86 const float *recurrent_weights;
92 typedef struct RNNModel {
94 const DenseLayer *input_dense;
97 const GRULayer *vad_gru;
100 const GRULayer *noise_gru;
102 int denoise_gru_size;
103 const GRULayer *denoise_gru;
105 int denoise_output_size;
106 const DenseLayer *denoise_output;
109 const DenseLayer *vad_output;
112 typedef struct RNNState {
113 float *vad_gru_state;
114 float *noise_gru_state;
115 float *denoise_gru_state;
119 typedef struct DenoiseState {
120 float analysis_mem[FRAME_SIZE];
121 float cepstral_mem[CEPS_MEM][NB_BANDS];
123 DECLARE_ALIGNED(32, float, synthesis_mem)[FRAME_SIZE];
124 float pitch_buf[PITCH_BUF_SIZE];
125 float pitch_enh_buf[PITCH_BUF_SIZE];
129 float lastg[NB_BANDS];
131 AVTXContext *tx, *txi;
132 av_tx_fn tx_fn, txi_fn;
135 typedef struct AudioRNNContext {
136 const AVClass *class;
143 DECLARE_ALIGNED(32, float, window)[WINDOW_SIZE];
144 float dct_table[NB_BANDS*NB_BANDS];
148 AVFloatDSPContext *fdsp;
151 #define F_ACTIVATION_TANH 0
152 #define F_ACTIVATION_SIGMOID 1
153 #define F_ACTIVATION_RELU 2
155 static void rnnoise_model_free(RNNModel *model)
157 #define FREE_MAYBE(ptr) do { if (ptr) free(ptr); } while (0)
158 #define FREE_DENSE(name) do { \
160 av_free((void *) model->name->input_weights); \
161 av_free((void *) model->name->bias); \
162 av_free((void *) model->name); \
165 #define FREE_GRU(name) do { \
167 av_free((void *) model->name->input_weights); \
168 av_free((void *) model->name->recurrent_weights); \
169 av_free((void *) model->name->bias); \
170 av_free((void *) model->name); \
176 FREE_DENSE(input_dense);
179 FREE_GRU(denoise_gru);
180 FREE_DENSE(denoise_output);
181 FREE_DENSE(vad_output);
185 static RNNModel *rnnoise_model_from_file(FILE *f)
188 DenseLayer *input_dense;
191 GRULayer *denoise_gru;
192 DenseLayer *denoise_output;
193 DenseLayer *vad_output;
196 if (fscanf(f, "rnnoise-nu model file version %d\n", &in) != 1 || in != 1)
199 ret = av_calloc(1, sizeof(RNNModel));
203 #define ALLOC_LAYER(type, name) \
204 name = av_calloc(1, sizeof(type)); \
206 rnnoise_model_free(ret); \
211 ALLOC_LAYER(DenseLayer, input_dense);
212 ALLOC_LAYER(GRULayer, vad_gru);
213 ALLOC_LAYER(GRULayer, noise_gru);
214 ALLOC_LAYER(GRULayer, denoise_gru);
215 ALLOC_LAYER(DenseLayer, denoise_output);
216 ALLOC_LAYER(DenseLayer, vad_output);
218 #define INPUT_VAL(name) do { \
219 if (fscanf(f, "%d", &in) != 1 || in < 0 || in > 128) { \
220 rnnoise_model_free(ret); \
226 #define INPUT_ACTIVATION(name) do { \
228 INPUT_VAL(activation); \
229 switch (activation) { \
230 case F_ACTIVATION_SIGMOID: \
231 name = ACTIVATION_SIGMOID; \
233 case F_ACTIVATION_RELU: \
234 name = ACTIVATION_RELU; \
237 name = ACTIVATION_TANH; \
241 #define INPUT_ARRAY(name, len) do { \
242 float *values = av_calloc((len), sizeof(float)); \
244 rnnoise_model_free(ret); \
248 for (int i = 0; i < (len); i++) { \
249 if (fscanf(f, "%d", &in) != 1) { \
250 rnnoise_model_free(ret); \
257 #define INPUT_ARRAY3(name, len0, len1, len2) do { \
258 float *values = av_calloc(FFALIGN((len0), 4) * FFALIGN((len1), 4) * (len2), sizeof(float)); \
260 rnnoise_model_free(ret); \
264 for (int k = 0; k < (len0); k++) { \
265 for (int i = 0; i < (len2); i++) { \
266 for (int j = 0; j < (len1); j++) { \
267 if (fscanf(f, "%d", &in) != 1) { \
268 rnnoise_model_free(ret); \
271 values[j * (len2) * FFALIGN((len0), 4) + i * FFALIGN((len0), 4) + k] = in; \
277 #define INPUT_DENSE(name) do { \
278 INPUT_VAL(name->nb_inputs); \
279 INPUT_VAL(name->nb_neurons); \
280 ret->name ## _size = name->nb_neurons; \
281 INPUT_ACTIVATION(name->activation); \
282 INPUT_ARRAY(name->input_weights, name->nb_inputs * name->nb_neurons); \
283 INPUT_ARRAY(name->bias, name->nb_neurons); \
286 #define INPUT_GRU(name) do { \
287 INPUT_VAL(name->nb_inputs); \
288 INPUT_VAL(name->nb_neurons); \
289 ret->name ## _size = name->nb_neurons; \
290 INPUT_ACTIVATION(name->activation); \
291 INPUT_ARRAY3(name->input_weights, name->nb_inputs, name->nb_neurons, 3); \
292 INPUT_ARRAY3(name->recurrent_weights, name->nb_neurons, name->nb_neurons, 3); \
293 INPUT_ARRAY(name->bias, name->nb_neurons * 3); \
296 INPUT_DENSE(input_dense);
298 INPUT_GRU(noise_gru);
299 INPUT_GRU(denoise_gru);
300 INPUT_DENSE(denoise_output);
301 INPUT_DENSE(vad_output);
303 if (vad_output->nb_neurons != 1) {
304 rnnoise_model_free(ret);
311 static int query_formats(AVFilterContext *ctx)
313 AVFilterFormats *formats = NULL;
314 AVFilterChannelLayouts *layouts = NULL;
315 static const enum AVSampleFormat sample_fmts[] = {
319 int ret, sample_rates[] = { 48000, -1 };
321 formats = ff_make_format_list(sample_fmts);
323 return AVERROR(ENOMEM);
324 ret = ff_set_common_formats(ctx, formats);
328 layouts = ff_all_channel_counts();
330 return AVERROR(ENOMEM);
332 ret = ff_set_common_channel_layouts(ctx, layouts);
336 formats = ff_make_format_list(sample_rates);
338 return AVERROR(ENOMEM);
339 return ff_set_common_samplerates(ctx, formats);
342 static int config_input(AVFilterLink *inlink)
344 AVFilterContext *ctx = inlink->dst;
345 AudioRNNContext *s = ctx->priv;
348 s->channels = inlink->channels;
350 s->st = av_calloc(s->channels, sizeof(DenoiseState));
352 return AVERROR(ENOMEM);
354 for (int i = 0; i < s->channels; i++) {
355 DenoiseState *st = &s->st[i];
357 st->rnn.model = s->model;
358 st->rnn.vad_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->vad_gru_size, 16));
359 st->rnn.noise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->noise_gru_size, 16));
360 st->rnn.denoise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->denoise_gru_size, 16));
361 if (!st->rnn.vad_gru_state ||
362 !st->rnn.noise_gru_state ||
363 !st->rnn.denoise_gru_state)
364 return AVERROR(ENOMEM);
366 ret = av_tx_init(&st->tx, &st->tx_fn, AV_TX_FLOAT_FFT, 0, WINDOW_SIZE, NULL, 0);
370 ret = av_tx_init(&st->txi, &st->txi_fn, AV_TX_FLOAT_FFT, 1, WINDOW_SIZE, NULL, 0);
378 static void biquad(float *y, float mem[2], const float *x,
379 const float *b, const float *a, int N)
381 for (int i = 0; i < N; i++) {
386 mem[0] = mem[1] + (b[0]*xi - a[0]*yi);
387 mem[1] = (b[1]*xi - a[1]*yi);
392 #define RNN_MOVE(dst, src, n) (memmove((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
393 #define RNN_CLEAR(dst, n) (memset((dst), 0, (n)*sizeof(*(dst))))
394 #define RNN_COPY(dst, src, n) (memcpy((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
396 static void forward_transform(DenoiseState *st, AVComplexFloat *out, const float *in)
398 AVComplexFloat x[WINDOW_SIZE];
399 AVComplexFloat y[WINDOW_SIZE];
401 for (int i = 0; i < WINDOW_SIZE; i++) {
406 st->tx_fn(st->tx, y, x, sizeof(float));
408 RNN_COPY(out, y, FREQ_SIZE);
411 static void inverse_transform(DenoiseState *st, float *out, const AVComplexFloat *in)
413 AVComplexFloat x[WINDOW_SIZE];
414 AVComplexFloat y[WINDOW_SIZE];
416 for (int i = 0; i < FREQ_SIZE; i++)
419 for (int i = FREQ_SIZE; i < WINDOW_SIZE; i++) {
420 x[i].re = x[WINDOW_SIZE - i].re;
421 x[i].im = -x[WINDOW_SIZE - i].im;
424 st->txi_fn(st->txi, y, x, sizeof(float));
426 for (int i = 0; i < WINDOW_SIZE; i++)
427 out[i] = y[i].re / WINDOW_SIZE;
430 static const uint8_t eband5ms[] = {
431 /*0 200 400 600 800 1k 1.2 1.4 1.6 2k 2.4 2.8 3.2 4k 4.8 5.6 6.8 8k 9.6 12k 15.6 20k*/
432 0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 34, 40, 48, 60, 78, 100
435 static void compute_band_energy(float *bandE, const AVComplexFloat *X)
437 float sum[NB_BANDS] = {0};
439 for (int i = 0; i < NB_BANDS - 1; i++) {
442 band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
443 for (int j = 0; j < band_size; j++) {
444 float tmp, frac = (float)j / band_size;
446 tmp = SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].re);
447 tmp += SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].im);
448 sum[i] += (1.f - frac) * tmp;
449 sum[i + 1] += frac * tmp;
454 sum[NB_BANDS - 1] *= 2;
456 for (int i = 0; i < NB_BANDS; i++)
460 static void compute_band_corr(float *bandE, const AVComplexFloat *X, const AVComplexFloat *P)
462 float sum[NB_BANDS] = { 0 };
464 for (int i = 0; i < NB_BANDS - 1; i++) {
467 band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
468 for (int j = 0; j < band_size; j++) {
469 float tmp, frac = (float)j / band_size;
471 tmp = X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re;
472 tmp += X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im;
473 sum[i] += (1 - frac) * tmp;
474 sum[i + 1] += frac * tmp;
479 sum[NB_BANDS-1] *= 2;
481 for (int i = 0; i < NB_BANDS; i++)
485 static void frame_analysis(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, float *Ex, const float *in)
487 LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
489 RNN_COPY(x, st->analysis_mem, FRAME_SIZE);
490 RNN_COPY(x + FRAME_SIZE, in, FRAME_SIZE);
491 RNN_COPY(st->analysis_mem, in, FRAME_SIZE);
492 s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
493 forward_transform(st, X, x);
494 compute_band_energy(Ex, X);
497 static void frame_synthesis(AudioRNNContext *s, DenoiseState *st, float *out, const AVComplexFloat *y)
499 LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
501 inverse_transform(st, x, y);
502 s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
503 s->fdsp->vector_fmac_scalar(x, st->synthesis_mem, 1.f, FRAME_SIZE);
504 RNN_COPY(out, x, FRAME_SIZE);
505 RNN_COPY(st->synthesis_mem, &x[FRAME_SIZE], FRAME_SIZE);
508 static inline void xcorr_kernel(const float *x, const float *y, float sum[4], int len)
510 float y_0, y_1, y_2, y_3 = 0;
517 for (j = 0; j < len - 3; j += 4) {
577 static inline float celt_inner_prod(const float *x,
578 const float *y, int N)
582 for (int i = 0; i < N; i++)
588 static void celt_pitch_xcorr(const float *x, const float *y,
589 float *xcorr, int len, int max_pitch)
593 for (i = 0; i < max_pitch - 3; i += 4) {
594 float sum[4] = { 0, 0, 0, 0};
596 xcorr_kernel(x, y + i, sum, len);
599 xcorr[i + 1] = sum[1];
600 xcorr[i + 2] = sum[2];
601 xcorr[i + 3] = sum[3];
603 /* In case max_pitch isn't a multiple of 4, do non-unrolled version. */
604 for (; i < max_pitch; i++) {
605 xcorr[i] = celt_inner_prod(x, y + i, len);
609 static int celt_autocorr(const float *x, /* in: [0...n-1] samples x */
610 float *ac, /* out: [0...lag-1] ac values */
619 float xx[PITCH_BUF_SIZE>>1];
624 for (int i = 0; i < n; i++)
626 for (int i = 0; i < overlap; i++) {
627 xx[i] = x[i] * window[i];
628 xx[n-i-1] = x[n-i-1] * window[i];
634 celt_pitch_xcorr(xptr, xptr, ac, fastN, lag+1);
636 for (int k = 0; k <= lag; k++) {
639 for (int i = k + fastN; i < n; i++)
640 d += xptr[i] * xptr[i-k];
647 static void celt_lpc(float *lpc, /* out: [0...p-1] LPC coefficients */
648 const float *ac, /* in: [0...p] autocorrelation values */
651 float r, error = ac[0];
655 for (int i = 0; i < p; i++) {
656 /* Sum up this iteration's reflection coefficient */
658 for (int j = 0; j < i; j++)
659 rr += (lpc[j] * ac[i - j]);
662 /* Update LPC coefficients and total error */
664 for (int j = 0; j < (i + 1) >> 1; j++) {
668 lpc[j] = tmp1 + (r*tmp2);
669 lpc[i-1-j] = tmp2 + (r*tmp1);
672 error = error - (r * r *error);
673 /* Bail out once we get 30 dB gain */
674 if (error < .001f * ac[0])
680 static void celt_fir5(const float *x,
686 float num0, num1, num2, num3, num4;
687 float mem0, mem1, mem2, mem3, mem4;
700 for (int i = 0; i < N; i++) {
723 static void pitch_downsample(float *x[], float *x_lp,
728 float lpc[4], mem[5]={0,0,0,0,0};
732 for (int i = 1; i < len >> 1; i++)
733 x_lp[i] = .5f * (.5f * (x[0][(2*i-1)]+x[0][(2*i+1)])+x[0][2*i]);
734 x_lp[0] = .5f * (.5f * (x[0][1])+x[0][0]);
736 for (int i = 1; i < len >> 1; i++)
737 x_lp[i] += (.5f * (.5f * (x[1][(2*i-1)]+x[1][(2*i+1)])+x[1][2*i]));
738 x_lp[0] += .5f * (.5f * (x[1][1])+x[1][0]);
741 celt_autocorr(x_lp, ac, NULL, 0, 4, len>>1);
743 /* Noise floor -40 dB */
746 for (int i = 1; i <= 4; i++) {
747 /*ac[i] *= exp(-.5*(2*M_PI*.002*i)*(2*M_PI*.002*i));*/
748 ac[i] -= ac[i]*(.008f*i)*(.008f*i);
751 celt_lpc(lpc, ac, 4);
752 for (int i = 0; i < 4; i++) {
754 lpc[i] = (lpc[i] * tmp);
757 lpc2[0] = lpc[0] + .8f;
758 lpc2[1] = lpc[1] + (c1 * lpc[0]);
759 lpc2[2] = lpc[2] + (c1 * lpc[1]);
760 lpc2[3] = lpc[3] + (c1 * lpc[2]);
761 lpc2[4] = (c1 * lpc[3]);
762 celt_fir5(x_lp, lpc2, x_lp, len>>1, mem);
765 static inline void dual_inner_prod(const float *x, const float *y01, const float *y02,
766 int N, float *xy1, float *xy2)
768 float xy01 = 0, xy02 = 0;
770 for (int i = 0; i < N; i++) {
771 xy01 += (x[i] * y01[i]);
772 xy02 += (x[i] * y02[i]);
779 static float compute_pitch_gain(float xy, float xx, float yy)
781 return xy / sqrtf(1.f + xx * yy);
784 static const int second_check[16] = {0, 0, 3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2};
785 static const float remove_doubling(float *x, int maxperiod, int minperiod,
786 int N, int *T0_, int prev_period, float prev_gain)
793 float best_xy, best_yy;
796 float yy_lookup[PITCH_MAX_PERIOD+1];
798 minperiod0 = minperiod;
809 dual_inner_prod(x, x, x-T0, N, &xx, &xy);
812 for (i = 1; i <= maxperiod; i++) {
813 yy = yy+(x[-i] * x[-i])-(x[N-i] * x[N-i]);
814 yy_lookup[i] = FFMAX(0, yy);
819 g = g0 = compute_pitch_gain(xy, xx, yy);
820 /* Look for any pitch at T/k */
821 for (k = 2; k <= 15; k++) {
829 /* Look for another strong correlation at T1b */
838 T1b = (2*second_check[k]*T0+k)/(2*k);
840 dual_inner_prod(x, &x[-T1], &x[-T1b], N, &xy, &xy2);
841 xy = .5f * (xy + xy2);
842 yy = .5f * (yy_lookup[T1] + yy_lookup[T1b]);
843 g1 = compute_pitch_gain(xy, xx, yy);
844 if (FFABS(T1-prev_period)<=1)
846 else if (FFABS(T1-prev_period)<=2 && 5 * k * k < T0)
847 cont = prev_gain * .5f;
850 thresh = FFMAX(.3f, (.7f * g0) - cont);
851 /* Bias against very high pitch (very short period) to avoid false-positives
852 due to short-term correlation */
854 thresh = FFMAX(.4f, (.85f * g0) - cont);
855 else if (T1<2*minperiod)
856 thresh = FFMAX(.5f, (.9f * g0) - cont);
865 best_xy = FFMAX(0, best_xy);
866 if (best_yy <= best_xy)
869 pg = best_xy/(best_yy + 1);
871 for (k = 0; k < 3; k++)
872 xcorr[k] = celt_inner_prod(x, x-(T+k-1), N);
873 if ((xcorr[2]-xcorr[0]) > .7f * (xcorr[1]-xcorr[0]))
875 else if ((xcorr[0]-xcorr[2]) > (.7f * (xcorr[1] - xcorr[2])))
888 static void find_best_pitch(float *xcorr, float *y, int len,
889 int max_pitch, int *best_pitch)
902 for (int j = 0; j < len; j++)
905 for (int i = 0; i < max_pitch; i++) {
911 /* Considering the range of xcorr16, this should avoid both underflows
912 and overflows (inf) when squaring xcorr16 */
914 num = xcorr16 * xcorr16;
915 if ((num * best_den[1]) > (best_num[1] * Syy)) {
916 if ((num * best_den[0]) > (best_num[0] * Syy)) {
917 best_num[1] = best_num[0];
918 best_den[1] = best_den[0];
919 best_pitch[1] = best_pitch[0];
930 Syy += y[i+len]*y[i+len] - y[i] * y[i];
935 static void pitch_search(const float *x_lp, float *y,
936 int len, int max_pitch, int *pitch)
939 int best_pitch[2]={0,0};
942 float x_lp4[WINDOW_SIZE];
943 float y_lp4[WINDOW_SIZE];
944 float xcorr[WINDOW_SIZE];
948 /* Downsample by 2 again */
949 for (int j = 0; j < len >> 2; j++)
950 x_lp4[j] = x_lp[2*j];
951 for (int j = 0; j < lag >> 2; j++)
954 /* Coarse search with 4x decimation */
956 celt_pitch_xcorr(x_lp4, y_lp4, xcorr, len>>2, max_pitch>>2);
958 find_best_pitch(xcorr, y_lp4, len>>2, max_pitch>>2, best_pitch);
960 /* Finer search with 2x decimation */
961 for (int i = 0; i < max_pitch >> 1; i++) {
964 if (FFABS(i-2*best_pitch[0])>2 && FFABS(i-2*best_pitch[1])>2)
966 sum = celt_inner_prod(x_lp, y+i, len>>1);
967 xcorr[i] = FFMAX(-1, sum);
970 find_best_pitch(xcorr, y, len>>1, max_pitch>>1, best_pitch);
972 /* Refine by pseudo-interpolation */
973 if (best_pitch[0] > 0 && best_pitch[0] < (max_pitch >> 1) - 1) {
976 a = xcorr[best_pitch[0] - 1];
977 b = xcorr[best_pitch[0]];
978 c = xcorr[best_pitch[0] + 1];
979 if (c - a > .7f * (b - a))
981 else if (a - c > .7f * (b-c))
989 *pitch = 2 * best_pitch[0] - offset;
992 static void dct(AudioRNNContext *s, float *out, const float *in)
994 for (int i = 0; i < NB_BANDS; i++) {
997 for (int j = 0; j < NB_BANDS; j++) {
998 sum += in[j] * s->dct_table[j * NB_BANDS + i];
1000 out[i] = sum * sqrtf(2.f / 22);
1004 static int compute_frame_features(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, AVComplexFloat *P,
1005 float *Ex, float *Ep, float *Exp, float *features, const float *in)
1008 float *ceps_0, *ceps_1, *ceps_2;
1009 float spec_variability = 0;
1011 LOCAL_ALIGNED_32(float, p, [WINDOW_SIZE]);
1012 float pitch_buf[PITCH_BUF_SIZE>>1];
1016 float tmp[NB_BANDS];
1017 float follow, logMax;
1019 frame_analysis(s, st, X, Ex, in);
1020 RNN_MOVE(st->pitch_buf, &st->pitch_buf[FRAME_SIZE], PITCH_BUF_SIZE-FRAME_SIZE);
1021 RNN_COPY(&st->pitch_buf[PITCH_BUF_SIZE-FRAME_SIZE], in, FRAME_SIZE);
1022 pre[0] = &st->pitch_buf[0];
1023 pitch_downsample(pre, pitch_buf, PITCH_BUF_SIZE, 1);
1024 pitch_search(pitch_buf+(PITCH_MAX_PERIOD>>1), pitch_buf, PITCH_FRAME_SIZE,
1025 PITCH_MAX_PERIOD-3*PITCH_MIN_PERIOD, &pitch_index);
1026 pitch_index = PITCH_MAX_PERIOD-pitch_index;
1028 gain = remove_doubling(pitch_buf, PITCH_MAX_PERIOD, PITCH_MIN_PERIOD,
1029 PITCH_FRAME_SIZE, &pitch_index, st->last_period, st->last_gain);
1030 st->last_period = pitch_index;
1031 st->last_gain = gain;
1033 for (int i = 0; i < WINDOW_SIZE; i++)
1034 p[i] = st->pitch_buf[PITCH_BUF_SIZE-WINDOW_SIZE-pitch_index+i];
1036 s->fdsp->vector_fmul(p, p, s->window, WINDOW_SIZE);
1037 forward_transform(st, P, p);
1038 compute_band_energy(Ep, P);
1039 compute_band_corr(Exp, X, P);
1041 for (int i = 0; i < NB_BANDS; i++)
1042 Exp[i] = Exp[i] / sqrtf(.001f+Ex[i]*Ep[i]);
1046 for (int i = 0; i < NB_DELTA_CEPS; i++)
1047 features[NB_BANDS+2*NB_DELTA_CEPS+i] = tmp[i];
1049 features[NB_BANDS+2*NB_DELTA_CEPS] -= 1.3;
1050 features[NB_BANDS+2*NB_DELTA_CEPS+1] -= 0.9;
1051 features[NB_BANDS+3*NB_DELTA_CEPS] = .01*(pitch_index-300);
1055 for (int i = 0; i < NB_BANDS; i++) {
1056 Ly[i] = log10f(1e-2f + Ex[i]);
1057 Ly[i] = FFMAX(logMax-7, FFMAX(follow-1.5, Ly[i]));
1058 logMax = FFMAX(logMax, Ly[i]);
1059 follow = FFMAX(follow-1.5, Ly[i]);
1064 /* If there's no audio, avoid messing up the state. */
1065 RNN_CLEAR(features, NB_FEATURES);
1069 dct(s, features, Ly);
1072 ceps_0 = st->cepstral_mem[st->memid];
1073 ceps_1 = (st->memid < 1) ? st->cepstral_mem[CEPS_MEM+st->memid-1] : st->cepstral_mem[st->memid-1];
1074 ceps_2 = (st->memid < 2) ? st->cepstral_mem[CEPS_MEM+st->memid-2] : st->cepstral_mem[st->memid-2];
1076 for (int i = 0; i < NB_BANDS; i++)
1077 ceps_0[i] = features[i];
1080 for (int i = 0; i < NB_DELTA_CEPS; i++) {
1081 features[i] = ceps_0[i] + ceps_1[i] + ceps_2[i];
1082 features[NB_BANDS+i] = ceps_0[i] - ceps_2[i];
1083 features[NB_BANDS+NB_DELTA_CEPS+i] = ceps_0[i] - 2*ceps_1[i] + ceps_2[i];
1085 /* Spectral variability features. */
1086 if (st->memid == CEPS_MEM)
1089 for (int i = 0; i < CEPS_MEM; i++) {
1090 float mindist = 1e15f;
1091 for (int j = 0; j < CEPS_MEM; j++) {
1093 for (int k = 0; k < NB_BANDS; k++) {
1096 tmp = st->cepstral_mem[i][k] - st->cepstral_mem[j][k];
1101 mindist = FFMIN(mindist, dist);
1104 spec_variability += mindist;
1107 features[NB_BANDS+3*NB_DELTA_CEPS+1] = spec_variability/CEPS_MEM-2.1;
1112 static void interp_band_gain(float *g, const float *bandE)
1114 memset(g, 0, sizeof(*g) * FREQ_SIZE);
1116 for (int i = 0; i < NB_BANDS - 1; i++) {
1117 const int band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
1119 for (int j = 0; j < band_size; j++) {
1120 float frac = (float)j / band_size;
1122 g[(eband5ms[i] << FRAME_SIZE_SHIFT) + j] = (1.f - frac) * bandE[i] + frac * bandE[i + 1];
1127 static void pitch_filter(AVComplexFloat *X, const AVComplexFloat *P, const float *Ex, const float *Ep,
1128 const float *Exp, const float *g)
1130 float newE[NB_BANDS];
1132 float norm[NB_BANDS];
1133 float rf[FREQ_SIZE] = {0};
1134 float normf[FREQ_SIZE]={0};
1136 for (int i = 0; i < NB_BANDS; i++) {
1137 if (Exp[i]>g[i]) r[i] = 1;
1138 else r[i] = SQUARE(Exp[i])*(1-SQUARE(g[i]))/(.001 + SQUARE(g[i])*(1-SQUARE(Exp[i])));
1139 r[i] = sqrtf(av_clipf(r[i], 0, 1));
1140 r[i] *= sqrtf(Ex[i]/(1e-8+Ep[i]));
1142 interp_band_gain(rf, r);
1143 for (int i = 0; i < FREQ_SIZE; i++) {
1144 X[i].re += rf[i]*P[i].re;
1145 X[i].im += rf[i]*P[i].im;
1147 compute_band_energy(newE, X);
1148 for (int i = 0; i < NB_BANDS; i++) {
1149 norm[i] = sqrtf(Ex[i] / (1e-8+newE[i]));
1151 interp_band_gain(normf, norm);
1152 for (int i = 0; i < FREQ_SIZE; i++) {
1153 X[i].re *= normf[i];
1154 X[i].im *= normf[i];
1158 static const float tansig_table[201] = {
1159 0.000000f, 0.039979f, 0.079830f, 0.119427f, 0.158649f,
1160 0.197375f, 0.235496f, 0.272905f, 0.309507f, 0.345214f,
1161 0.379949f, 0.413644f, 0.446244f, 0.477700f, 0.507977f,
1162 0.537050f, 0.564900f, 0.591519f, 0.616909f, 0.641077f,
1163 0.664037f, 0.685809f, 0.706419f, 0.725897f, 0.744277f,
1164 0.761594f, 0.777888f, 0.793199f, 0.807569f, 0.821040f,
1165 0.833655f, 0.845456f, 0.856485f, 0.866784f, 0.876393f,
1166 0.885352f, 0.893698f, 0.901468f, 0.908698f, 0.915420f,
1167 0.921669f, 0.927473f, 0.932862f, 0.937863f, 0.942503f,
1168 0.946806f, 0.950795f, 0.954492f, 0.957917f, 0.961090f,
1169 0.964028f, 0.966747f, 0.969265f, 0.971594f, 0.973749f,
1170 0.975743f, 0.977587f, 0.979293f, 0.980869f, 0.982327f,
1171 0.983675f, 0.984921f, 0.986072f, 0.987136f, 0.988119f,
1172 0.989027f, 0.989867f, 0.990642f, 0.991359f, 0.992020f,
1173 0.992631f, 0.993196f, 0.993718f, 0.994199f, 0.994644f,
1174 0.995055f, 0.995434f, 0.995784f, 0.996108f, 0.996407f,
1175 0.996682f, 0.996937f, 0.997172f, 0.997389f, 0.997590f,
1176 0.997775f, 0.997946f, 0.998104f, 0.998249f, 0.998384f,
1177 0.998508f, 0.998623f, 0.998728f, 0.998826f, 0.998916f,
1178 0.999000f, 0.999076f, 0.999147f, 0.999213f, 0.999273f,
1179 0.999329f, 0.999381f, 0.999428f, 0.999472f, 0.999513f,
1180 0.999550f, 0.999585f, 0.999617f, 0.999646f, 0.999673f,
1181 0.999699f, 0.999722f, 0.999743f, 0.999763f, 0.999781f,
1182 0.999798f, 0.999813f, 0.999828f, 0.999841f, 0.999853f,
1183 0.999865f, 0.999875f, 0.999885f, 0.999893f, 0.999902f,
1184 0.999909f, 0.999916f, 0.999923f, 0.999929f, 0.999934f,
1185 0.999939f, 0.999944f, 0.999948f, 0.999952f, 0.999956f,
1186 0.999959f, 0.999962f, 0.999965f, 0.999968f, 0.999970f,
1187 0.999973f, 0.999975f, 0.999977f, 0.999978f, 0.999980f,
1188 0.999982f, 0.999983f, 0.999984f, 0.999986f, 0.999987f,
1189 0.999988f, 0.999989f, 0.999990f, 0.999990f, 0.999991f,
1190 0.999992f, 0.999992f, 0.999993f, 0.999994f, 0.999994f,
1191 0.999994f, 0.999995f, 0.999995f, 0.999996f, 0.999996f,
1192 0.999996f, 0.999997f, 0.999997f, 0.999997f, 0.999997f,
1193 0.999997f, 0.999998f, 0.999998f, 0.999998f, 0.999998f,
1194 0.999998f, 0.999998f, 0.999999f, 0.999999f, 0.999999f,
1195 0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
1196 0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
1197 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1198 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1202 static inline float tansig_approx(float x)
1208 /* Tests are reversed to catch NaNs */
1213 /* Another check in case of -ffast-math */
1222 i = (int)floor(.5f+25*x);
1224 y = tansig_table[i];
1226 y = y + x*dy*(1 - y*x);
1230 static inline float sigmoid_approx(float x)
1232 return .5f + .5f*tansig_approx(.5f*x);
1235 static void compute_dense(const DenseLayer *layer, float *output, const float *input)
1237 const int N = layer->nb_neurons, M = layer->nb_inputs, stride = N;
1239 for (int i = 0; i < N; i++) {
1240 /* Compute update gate. */
1241 float sum = layer->bias[i];
1243 for (int j = 0; j < M; j++)
1244 sum += layer->input_weights[j * stride + i] * input[j];
1246 output[i] = WEIGHTS_SCALE * sum;
1249 if (layer->activation == ACTIVATION_SIGMOID) {
1250 for (int i = 0; i < N; i++)
1251 output[i] = sigmoid_approx(output[i]);
1252 } else if (layer->activation == ACTIVATION_TANH) {
1253 for (int i = 0; i < N; i++)
1254 output[i] = tansig_approx(output[i]);
1255 } else if (layer->activation == ACTIVATION_RELU) {
1256 for (int i = 0; i < N; i++)
1257 output[i] = FFMAX(0, output[i]);
1263 static void compute_gru(AudioRNNContext *s, const GRULayer *gru, float *state, const float *input)
1265 LOCAL_ALIGNED_32(float, z, [MAX_NEURONS]);
1266 LOCAL_ALIGNED_32(float, r, [MAX_NEURONS]);
1267 LOCAL_ALIGNED_32(float, h, [MAX_NEURONS]);
1268 const int M = gru->nb_inputs;
1269 const int N = gru->nb_neurons;
1270 const int AN = FFALIGN(N, 4);
1271 const int AM = FFALIGN(M, 4);
1272 const int stride = 3 * AN, istride = 3 * AM;
1274 for (int i = 0; i < N; i++) {
1275 /* Compute update gate. */
1276 float sum = gru->bias[i];
1278 sum += s->fdsp->scalarproduct_float(gru->input_weights + i * istride, input, AM);
1279 sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + i * stride, state, AN);
1280 z[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1283 for (int i = 0; i < N; i++) {
1284 /* Compute reset gate. */
1285 float sum = gru->bias[N + i];
1287 sum += s->fdsp->scalarproduct_float(gru->input_weights + AM + i * istride, input, AM);
1288 sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + AN + i * stride, state, AN);
1289 r[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1292 for (int i = 0; i < N; i++) {
1293 /* Compute output. */
1294 float sum = gru->bias[2 * N + i];
1296 sum += s->fdsp->scalarproduct_float(gru->input_weights + 2 * AM + i * istride, input, AM);
1297 for (int j = 0; j < N; j++)
1298 sum += gru->recurrent_weights[2 * AN + i * stride + j] * state[j] * r[j];
1300 if (gru->activation == ACTIVATION_SIGMOID)
1301 sum = sigmoid_approx(WEIGHTS_SCALE * sum);
1302 else if (gru->activation == ACTIVATION_TANH)
1303 sum = tansig_approx(WEIGHTS_SCALE * sum);
1304 else if (gru->activation == ACTIVATION_RELU)
1305 sum = FFMAX(0, WEIGHTS_SCALE * sum);
1308 h[i] = z[i] * state[i] + (1.f - z[i]) * sum;
1311 RNN_COPY(state, h, N);
1314 #define INPUT_SIZE 42
1316 static void compute_rnn(AudioRNNContext *s, RNNState *rnn, float *gains, float *vad, const float *input)
1318 LOCAL_ALIGNED_32(float, dense_out, [MAX_NEURONS]);
1319 LOCAL_ALIGNED_32(float, noise_input, [MAX_NEURONS * 3]);
1320 LOCAL_ALIGNED_32(float, denoise_input, [MAX_NEURONS * 3]);
1322 compute_dense(rnn->model->input_dense, dense_out, input);
1323 compute_gru(s, rnn->model->vad_gru, rnn->vad_gru_state, dense_out);
1324 compute_dense(rnn->model->vad_output, vad, rnn->vad_gru_state);
1326 for (int i = 0; i < rnn->model->input_dense_size; i++)
1327 noise_input[i] = dense_out[i];
1328 for (int i = 0; i < rnn->model->vad_gru_size; i++)
1329 noise_input[i + rnn->model->input_dense_size] = rnn->vad_gru_state[i];
1330 for (int i = 0; i < INPUT_SIZE; i++)
1331 noise_input[i + rnn->model->input_dense_size + rnn->model->vad_gru_size] = input[i];
1333 compute_gru(s, rnn->model->noise_gru, rnn->noise_gru_state, noise_input);
1335 for (int i = 0; i < rnn->model->vad_gru_size; i++)
1336 denoise_input[i] = rnn->vad_gru_state[i];
1337 for (int i = 0; i < rnn->model->noise_gru_size; i++)
1338 denoise_input[i + rnn->model->vad_gru_size] = rnn->noise_gru_state[i];
1339 for (int i = 0; i < INPUT_SIZE; i++)
1340 denoise_input[i + rnn->model->vad_gru_size + rnn->model->noise_gru_size] = input[i];
1342 compute_gru(s, rnn->model->denoise_gru, rnn->denoise_gru_state, denoise_input);
1343 compute_dense(rnn->model->denoise_output, gains, rnn->denoise_gru_state);
1346 static float rnnoise_channel(AudioRNNContext *s, DenoiseState *st, float *out, const float *in)
1348 AVComplexFloat X[FREQ_SIZE];
1349 AVComplexFloat P[WINDOW_SIZE];
1350 float x[FRAME_SIZE];
1351 float Ex[NB_BANDS], Ep[NB_BANDS];
1352 float Exp[NB_BANDS];
1353 float features[NB_FEATURES];
1355 float gf[FREQ_SIZE];
1357 static const float a_hp[2] = {-1.99599, 0.99600};
1358 static const float b_hp[2] = {-2, 1};
1361 biquad(x, st->mem_hp_x, in, b_hp, a_hp, FRAME_SIZE);
1362 silence = compute_frame_features(s, st, X, P, Ex, Ep, Exp, features, x);
1365 compute_rnn(s, &st->rnn, g, &vad_prob, features);
1366 pitch_filter(X, P, Ex, Ep, Exp, g);
1367 for (int i = 0; i < NB_BANDS; i++) {
1370 g[i] = FFMAX(g[i], alpha * st->lastg[i]);
1371 st->lastg[i] = g[i];
1374 interp_band_gain(gf, g);
1376 for (int i = 0; i < FREQ_SIZE; i++) {
1382 frame_synthesis(s, st, out, X);
1387 typedef struct ThreadData {
1391 static int rnnoise_channels(AVFilterContext *ctx, void *arg, int jobnr, int nb_jobs)
1393 AudioRNNContext *s = ctx->priv;
1394 ThreadData *td = arg;
1395 AVFrame *in = td->in;
1396 AVFrame *out = td->out;
1397 const int start = (out->channels * jobnr) / nb_jobs;
1398 const int end = (out->channels * (jobnr+1)) / nb_jobs;
1400 for (int ch = start; ch < end; ch++) {
1401 rnnoise_channel(s, &s->st[ch],
1402 (float *)out->extended_data[ch],
1403 (const float *)in->extended_data[ch]);
1409 static int filter_frame(AVFilterLink *inlink, AVFrame *in)
1411 AVFilterContext *ctx = inlink->dst;
1412 AVFilterLink *outlink = ctx->outputs[0];
1413 AVFrame *out = NULL;
1416 out = ff_get_audio_buffer(outlink, FRAME_SIZE);
1419 return AVERROR(ENOMEM);
1423 td.in = in; td.out = out;
1424 ctx->internal->execute(ctx, rnnoise_channels, &td, NULL, FFMIN(outlink->channels,
1425 ff_filter_get_nb_threads(ctx)));
1428 return ff_filter_frame(outlink, out);
1431 static int activate(AVFilterContext *ctx)
1433 AVFilterLink *inlink = ctx->inputs[0];
1434 AVFilterLink *outlink = ctx->outputs[0];
1438 FF_FILTER_FORWARD_STATUS_BACK(outlink, inlink);
1440 ret = ff_inlink_consume_samples(inlink, FRAME_SIZE, FRAME_SIZE, &in);
1445 return filter_frame(inlink, in);
1447 FF_FILTER_FORWARD_STATUS(inlink, outlink);
1448 FF_FILTER_FORWARD_WANTED(outlink, inlink);
1450 return FFERROR_NOT_READY;
1453 static av_cold int init(AVFilterContext *ctx)
1455 AudioRNNContext *s = ctx->priv;
1458 s->fdsp = avpriv_float_dsp_alloc(0);
1460 return AVERROR(ENOMEM);
1463 return AVERROR(EINVAL);
1464 f = av_fopen_utf8(s->model_name, "r");
1466 return AVERROR(EINVAL);
1468 s->model = rnnoise_model_from_file(f);
1471 return AVERROR(EINVAL);
1473 for (int i = 0; i < FRAME_SIZE; i++) {
1474 s->window[i] = sin(.5*M_PI*sin(.5*M_PI*(i+.5)/FRAME_SIZE) * sin(.5*M_PI*(i+.5)/FRAME_SIZE));
1475 s->window[WINDOW_SIZE - 1 - i] = s->window[i];
1478 for (int i = 0; i < NB_BANDS; i++) {
1479 for (int j = 0; j < NB_BANDS; j++) {
1480 s->dct_table[i*NB_BANDS + j] = cosf((i + .5f) * j * M_PI / NB_BANDS);
1482 s->dct_table[i*NB_BANDS + j] *= sqrtf(.5);
1489 static av_cold void uninit(AVFilterContext *ctx)
1491 AudioRNNContext *s = ctx->priv;
1494 rnnoise_model_free(s->model);
1498 for (int ch = 0; ch < s->channels; ch++) {
1499 av_freep(&s->st[ch].rnn.vad_gru_state);
1500 av_freep(&s->st[ch].rnn.noise_gru_state);
1501 av_freep(&s->st[ch].rnn.denoise_gru_state);
1502 av_tx_uninit(&s->st[ch].tx);
1503 av_tx_uninit(&s->st[ch].txi);
1509 static const AVFilterPad inputs[] = {
1512 .type = AVMEDIA_TYPE_AUDIO,
1513 .config_props = config_input,
1518 static const AVFilterPad outputs[] = {
1521 .type = AVMEDIA_TYPE_AUDIO,
1526 #define OFFSET(x) offsetof(AudioRNNContext, x)
1527 #define AF AV_OPT_FLAG_AUDIO_PARAM|AV_OPT_FLAG_FILTERING_PARAM
1529 static const AVOption arnndn_options[] = {
1530 { "model", "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1531 { "m", "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1535 AVFILTER_DEFINE_CLASS(arnndn);
1537 AVFilter ff_af_arnndn = {
1539 .description = NULL_IF_CONFIG_SMALL("Reduce noise from speech using Recurrent Neural Networks."),
1540 .query_formats = query_formats,
1541 .priv_size = sizeof(AudioRNNContext),
1542 .priv_class = &arnndn_class,
1543 .activate = activate,
1548 .flags = AVFILTER_FLAG_SLICE_THREADS,