2 * Copyright (c) 2018 Gregor Richards
3 * Copyright (c) 2017 Mozilla
4 * Copyright (c) 2005-2009 Xiph.Org Foundation
5 * Copyright (c) 2007-2008 CSIRO
6 * Copyright (c) 2008-2011 Octasic Inc.
7 * Copyright (c) Jean-Marc Valin
8 * Copyright (c) 2019 Paul B Mahol
10 * Redistribution and use in source and binary forms, with or without
11 * modification, are permitted provided that the following conditions
14 * - Redistributions of source code must retain the above copyright
15 * notice, this list of conditions and the following disclaimer.
17 * - Redistributions in binary form must reproduce the above copyright
18 * notice, this list of conditions and the following disclaimer in the
19 * documentation and/or other materials provided with the distribution.
21 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR
25 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
36 #include "libavutil/avassert.h"
37 #include "libavutil/avstring.h"
38 #include "libavutil/float_dsp.h"
39 #include "libavutil/opt.h"
40 #include "libavutil/tx.h"
/* Frame geometry: a frame is 120 samples scaled by FRAME_SIZE_SHIFT, the
 * analysis window spans two frames, and FREQ_SIZE counts the spectrum bins
 * kept after the transform. */
#define FRAME_SIZE_SHIFT 2
#define FRAME_SIZE (120<<FRAME_SIZE_SHIFT)
#define WINDOW_SIZE (2*FRAME_SIZE)
#define FREQ_SIZE (FRAME_SIZE + 1)

/* Pitch analysis limits and history buffer length (in samples). */
#define PITCH_MIN_PERIOD 60
#define PITCH_MAX_PERIOD 768
#define PITCH_FRAME_SIZE 960
#define PITCH_BUF_SIZE (PITCH_MAX_PERIOD+PITCH_FRAME_SIZE)

/* NOTE: evaluates its argument twice; do not pass expressions with side effects. */
#define SQUARE(x) ((x)*(x))

/* Number of bands; matches the 22 entries of eband5ms[]. */
#define NB_BANDS 22

/* Depth of the cepstral history kept per channel. */
#define CEPS_MEM 8
#define NB_DELTA_CEPS 6

/* Length of the feature vector fed to the network. */
#define NB_FEATURES (NB_BANDS+3*NB_DELTA_CEPS+2)

/* Scale applied to accumulated weight sums before the activation. */
#define WEIGHTS_SCALE (1.f/256)

/* Upper bound on the neuron count of any layer (sizes the scratch buffers). */
#define MAX_NEURONS 128

/* In-memory activation identifiers. */
#define ACTIVATION_TANH    0
#define ACTIVATION_SIGMOID 1
#define ACTIVATION_RELU    2

#define Q15ONE 1.0f
/**
 * Fully connected layer.
 *
 * input_weights holds nb_inputs * nb_neurons values and bias holds
 * nb_neurons values; activation is one of the ACTIVATION_* identifiers.
 */
typedef struct DenseLayer {
    const float *bias;
    const float *input_weights;
    int nb_inputs;
    int nb_neurons;
    int activation;
} DenseLayer;
/**
 * Gated recurrent unit layer.
 *
 * Besides the input weights it carries a recurrent weight array applied to
 * the layer's previous state; bias holds three groups of nb_neurons values
 * (one per gate). activation is one of the ACTIVATION_* identifiers.
 */
typedef struct GRULayer {
    const float *bias;
    const float *input_weights;
    const float *recurrent_weights;
    int nb_inputs;
    int nb_neurons;
    int activation;
} GRULayer;
92 typedef struct RNNModel {
94 const DenseLayer *input_dense;
97 const GRULayer *vad_gru;
100 const GRULayer *noise_gru;
102 int denoise_gru_size;
103 const GRULayer *denoise_gru;
105 int denoise_output_size;
106 const DenseLayer *denoise_output;
109 const DenseLayer *vad_output;
112 typedef struct RNNState {
113 float *vad_gru_state;
114 float *noise_gru_state;
115 float *denoise_gru_state;
119 typedef struct DenoiseState {
120 float analysis_mem[FRAME_SIZE];
121 float cepstral_mem[CEPS_MEM][NB_BANDS];
123 DECLARE_ALIGNED(32, float, synthesis_mem)[FRAME_SIZE];
124 float pitch_buf[PITCH_BUF_SIZE];
125 float pitch_enh_buf[PITCH_BUF_SIZE];
129 float lastg[NB_BANDS];
131 AVTXContext *tx, *txi;
132 av_tx_fn tx_fn, txi_fn;
135 typedef struct AudioRNNContext {
136 const AVClass *class;
143 DECLARE_ALIGNED(32, float, window)[WINDOW_SIZE];
144 float dct_table[NB_BANDS*NB_BANDS];
148 AVFloatDSPContext *fdsp;
/* Activation identifiers as they appear in the model file; the parser maps
 * them onto the in-memory ACTIVATION_* values. */
#define F_ACTIVATION_TANH       0
#define F_ACTIVATION_SIGMOID    1
#define F_ACTIVATION_RELU       2
155 static void rnnoise_model_free(RNNModel *model)
157 #define FREE_MAYBE(ptr) do { if (ptr) free(ptr); } while (0)
158 #define FREE_DENSE(name) do { \
160 av_free((void *) model->name->input_weights); \
161 av_free((void *) model->name->bias); \
162 av_free((void *) model->name); \
165 #define FREE_GRU(name) do { \
167 av_free((void *) model->name->input_weights); \
168 av_free((void *) model->name->recurrent_weights); \
169 av_free((void *) model->name->bias); \
170 av_free((void *) model->name); \
176 FREE_DENSE(input_dense);
179 FREE_GRU(denoise_gru);
180 FREE_DENSE(denoise_output);
181 FREE_DENSE(vad_output);
/*
 * Build an RNNModel from the text stream f.
 *
 * NOTE(review): a number of lines of this function are not visible in this
 * excerpt, so the notes below cover only what the visible statements show.
 *
 * - The stream must start with "rnnoise-nu model file version 1".
 * - ALLOC_LAYER zero-allocates one layer struct; the visible failure path
 *   releases the partially built model with rnnoise_model_free().
 * - INPUT_VAL reads one integer and rejects values outside 0..128.
 * - INPUT_ACTIVATION maps the file's F_ACTIVATION_* identifier onto the
 *   ACTIVATION_* values; anything unrecognised falls back to tanh.
 * - INPUT_ARRAY reads len integers into a freshly allocated float array.
 * - INPUT_ARRAY3 allocates with len0 and len1 rounded up to a multiple of 4
 *   and stores each value at
 *   j * len2 * FFALIGN(len0, 4) + i * FFALIGN(len0, 4) + k.
 * - INPUT_DENSE / INPUT_GRU read the dimensions, activation, weights and
 *   bias of one layer and mirror nb_neurons into ret-><name>_size.
 * - NOTE(review): presumably returns NULL on any failure and the model
 *   otherwise; the return statements are not visible here - confirm.
 */
185 static RNNModel *rnnoise_model_from_file(FILE *f)
188 DenseLayer *input_dense;
191 GRULayer *denoise_gru;
192 DenseLayer *denoise_output;
193 DenseLayer *vad_output;
/* Header line: the format version must parse and be exactly 1. */
196 if (fscanf(f, "rnnoise-nu model file version %d\n", &in) != 1 || in != 1)
199 ret = av_calloc(1, sizeof(RNNModel));
/* Helper: allocate one zeroed layer of the given type. */
203 #define ALLOC_LAYER(type, name) \
204 name = av_calloc(1, sizeof(type)); \
206 rnnoise_model_free(ret); \
/* All six layers are allocated before any weights are parsed. */
211 ALLOC_LAYER(DenseLayer, input_dense);
212 ALLOC_LAYER(GRULayer, vad_gru);
213 ALLOC_LAYER(GRULayer, noise_gru);
214 ALLOC_LAYER(GRULayer, denoise_gru);
215 ALLOC_LAYER(DenseLayer, denoise_output);
216 ALLOC_LAYER(DenseLayer, vad_output);
/* Helper: read one bounded integer (0..128) from the stream. */
218 #define INPUT_VAL(name) do { \
219 if (fscanf(f, "%d", &in) != 1 || in < 0 || in > 128) { \
220 rnnoise_model_free(ret); \
226 #define INPUT_ACTIVATION(name) do { \
228 INPUT_VAL(activation); \
229 switch (activation) { \
230 case F_ACTIVATION_SIGMOID: \
231 name = ACTIVATION_SIGMOID; \
233 case F_ACTIVATION_RELU: \
234 name = ACTIVATION_RELU; \
237 name = ACTIVATION_TANH; \
241 #define INPUT_ARRAY(name, len) do { \
242 float *values = av_calloc((len), sizeof(float)); \
244 rnnoise_model_free(ret); \
248 for (int i = 0; i < (len); i++) { \
249 if (fscanf(f, "%d", &in) != 1) { \
250 rnnoise_model_free(ret); \
257 #define INPUT_ARRAY3(name, len0, len1, len2) do { \
258 float *values = av_calloc(FFALIGN((len0), 4) * FFALIGN((len1), 4) * (len2), sizeof(float)); \
260 rnnoise_model_free(ret); \
264 for (int k = 0; k < (len0); k++) { \
265 for (int i = 0; i < (len2); i++) { \
266 for (int j = 0; j < (len1); j++) { \
267 if (fscanf(f, "%d", &in) != 1) { \
268 rnnoise_model_free(ret); \
271 values[j * (len2) * FFALIGN((len0), 4) + i * FFALIGN((len0), 4) + k] = in; \
277 #define INPUT_DENSE(name) do { \
278 INPUT_VAL(name->nb_inputs); \
279 INPUT_VAL(name->nb_neurons); \
280 ret->name ## _size = name->nb_neurons; \
281 INPUT_ACTIVATION(name->activation); \
282 INPUT_ARRAY(name->input_weights, name->nb_inputs * name->nb_neurons); \
283 INPUT_ARRAY(name->bias, name->nb_neurons); \
286 #define INPUT_GRU(name) do { \
287 INPUT_VAL(name->nb_inputs); \
288 INPUT_VAL(name->nb_neurons); \
289 ret->name ## _size = name->nb_neurons; \
290 INPUT_ACTIVATION(name->activation); \
291 INPUT_ARRAY3(name->input_weights, name->nb_inputs, name->nb_neurons, 3); \
292 INPUT_ARRAY3(name->recurrent_weights, name->nb_neurons, name->nb_neurons, 3); \
293 INPUT_ARRAY(name->bias, name->nb_neurons * 3); \
296 INPUT_DENSE(input_dense);
298 INPUT_GRU(noise_gru);
299 INPUT_GRU(denoise_gru);
300 INPUT_DENSE(denoise_output);
301 INPUT_DENSE(vad_output);
306 static int query_formats(AVFilterContext *ctx)
308 AVFilterFormats *formats = NULL;
309 AVFilterChannelLayouts *layouts = NULL;
310 static const enum AVSampleFormat sample_fmts[] = {
314 int ret, sample_rates[] = { 48000, -1 };
316 formats = ff_make_format_list(sample_fmts);
318 return AVERROR(ENOMEM);
319 ret = ff_set_common_formats(ctx, formats);
323 layouts = ff_all_channel_counts();
325 return AVERROR(ENOMEM);
327 ret = ff_set_common_channel_layouts(ctx, layouts);
331 formats = ff_make_format_list(sample_rates);
333 return AVERROR(ENOMEM);
334 return ff_set_common_samplerates(ctx, formats);
337 static int config_input(AVFilterLink *inlink)
339 AVFilterContext *ctx = inlink->dst;
340 AudioRNNContext *s = ctx->priv;
343 s->channels = inlink->channels;
345 s->st = av_calloc(s->channels, sizeof(DenoiseState));
347 return AVERROR(ENOMEM);
349 for (int i = 0; i < s->channels; i++) {
350 DenoiseState *st = &s->st[i];
352 st->rnn.model = s->model;
353 st->rnn.vad_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->vad_gru_size, 16));
354 st->rnn.noise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->noise_gru_size, 16));
355 st->rnn.denoise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->denoise_gru_size, 16));
356 if (!st->rnn.vad_gru_state ||
357 !st->rnn.noise_gru_state ||
358 !st->rnn.denoise_gru_state)
359 return AVERROR(ENOMEM);
361 ret = av_tx_init(&st->tx, &st->tx_fn, AV_TX_FLOAT_FFT, 0, WINDOW_SIZE, NULL, 0);
365 ret = av_tx_init(&st->txi, &st->txi_fn, AV_TX_FLOAT_FFT, 1, WINDOW_SIZE, NULL, 0);
/**
 * Second-order IIR filter with an implicit unity leading numerator
 * coefficient, run over N samples.
 *
 * @param y    output samples (N values)
 * @param mem  two-element filter memory, updated in place
 * @param x    input samples (N values)
 * @param b    two numerator coefficients
 * @param a    two denominator coefficients
 * @param N    number of samples to process
 */
static void biquad(float *y, float mem[2], const float *x,
                   const float *b, const float *a, int N)
{
    for (int i = 0; i < N; i++) {
        float xi, yi;

        xi = x[i];
        yi = x[i] + mem[0];
        mem[0] = mem[1] + (b[0]*xi - a[0]*yi);
        mem[1] = (b[1]*xi - a[1]*yi);
        y[i] = yi;
    }
}
/* Typed wrappers around memmove/memset/memcpy that take an element count.
 * The "0*((dst)-(src))" term adds nothing to the size; it only makes the
 * compiler reject calls where dst and src have incompatible pointer types. */
#define RNN_MOVE(dst, src, n) (memmove((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
#define RNN_CLEAR(dst, n) (memset((dst), 0, (n)*sizeof(*(dst))))
#define RNN_COPY(dst, src, n) (memcpy((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
391 static void forward_transform(DenoiseState *st, AVComplexFloat *out, const float *in)
393 AVComplexFloat x[WINDOW_SIZE];
394 AVComplexFloat y[WINDOW_SIZE];
396 for (int i = 0; i < WINDOW_SIZE; i++) {
401 st->tx_fn(st->tx, y, x, sizeof(float));
403 RNN_COPY(out, y, FREQ_SIZE);
406 static void inverse_transform(DenoiseState *st, float *out, const AVComplexFloat *in)
408 AVComplexFloat x[WINDOW_SIZE];
409 AVComplexFloat y[WINDOW_SIZE];
411 for (int i = 0; i < FREQ_SIZE; i++)
414 for (int i = FREQ_SIZE; i < WINDOW_SIZE; i++) {
415 x[i].re = x[WINDOW_SIZE - i].re;
416 x[i].im = -x[WINDOW_SIZE - i].im;
419 st->txi_fn(st->txi, y, x, sizeof(float));
421 for (int i = 0; i < WINDOW_SIZE; i++)
422 out[i] = y[i].re / WINDOW_SIZE;
/* Band edges; callers shift each entry left by FRAME_SIZE_SHIFT to obtain a
 * bin index. The comment row gives the frequency associated with each edge. */
static const uint8_t eband5ms[] = {
/*0  200 400 600 800  1k 1.2 1.4 1.6  2k 2.4 2.8 3.2  4k 4.8 5.6 6.8  8k 9.6 12k 15.6 20k*/
  0,  1,  2,  3,  4,  5,  6,  7,  8, 10, 12, 14, 16, 20, 24, 28, 34, 40, 48, 60, 78, 100
};
430 static void compute_band_energy(float *bandE, const AVComplexFloat *X)
432 float sum[NB_BANDS] = {0};
434 for (int i = 0; i < NB_BANDS - 1; i++) {
437 band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
438 for (int j = 0; j < band_size; j++) {
439 float tmp, frac = (float)j / band_size;
441 tmp = SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].re);
442 tmp += SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].im);
443 sum[i] += (1.f - frac) * tmp;
444 sum[i + 1] += frac * tmp;
449 sum[NB_BANDS - 1] *= 2;
451 for (int i = 0; i < NB_BANDS; i++)
455 static void compute_band_corr(float *bandE, const AVComplexFloat *X, const AVComplexFloat *P)
457 float sum[NB_BANDS] = { 0 };
459 for (int i = 0; i < NB_BANDS - 1; i++) {
462 band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
463 for (int j = 0; j < band_size; j++) {
464 float tmp, frac = (float)j / band_size;
466 tmp = X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re;
467 tmp += X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im;
468 sum[i] += (1 - frac) * tmp;
469 sum[i + 1] += frac * tmp;
474 sum[NB_BANDS-1] *= 2;
476 for (int i = 0; i < NB_BANDS; i++)
480 static void frame_analysis(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, float *Ex, const float *in)
482 LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
484 RNN_COPY(x, st->analysis_mem, FRAME_SIZE);
485 RNN_COPY(x + FRAME_SIZE, in, FRAME_SIZE);
486 RNN_COPY(st->analysis_mem, in, FRAME_SIZE);
487 s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
488 forward_transform(st, X, x);
489 compute_band_energy(Ex, X);
492 static void frame_synthesis(AudioRNNContext *s, DenoiseState *st, float *out, const AVComplexFloat *y)
494 LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
496 inverse_transform(st, x, y);
497 s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
498 s->fdsp->vector_fmac_scalar(x, st->synthesis_mem, 1.f, FRAME_SIZE);
499 RNN_COPY(out, x, FRAME_SIZE);
500 RNN_COPY(st->synthesis_mem, &x[FRAME_SIZE], FRAME_SIZE);
/*
 * Correlation kernel used by celt_pitch_xcorr(), which reads sum[0..3] back
 * as the results for four consecutive offsets into y.
 * NOTE(review): most of the body is not visible in this excerpt; the summary
 * above is taken from the signature and the caller and should be confirmed.
 */
503 static inline void xcorr_kernel(const float *x, const float *y, float sum[4], int len)
/* Only y_3 carries an initializer in this declaration. */
505 float y_0, y_1, y_2, y_3 = 0;
/* Main loop advances four samples per iteration. */
512 for (j = 0; j < len - 3; j += 4) {
/**
 * Dot product of two float vectors.
 *
 * @param x  first vector (N values)
 * @param y  second vector (N values)
 * @param N  number of elements; N <= 0 yields 0
 * @return   sum of x[i] * y[i]
 */
static inline float celt_inner_prod(const float *x,
                                    const float *y, int N)
{
    float xy = 0.f;

    for (int i = 0; i < N; i++)
        xy += x[i] * y[i];

    return xy;
}
/**
 * Cross-correlate x against y at offsets 0 .. max_pitch-1.
 *
 * Offsets are handled four at a time through xcorr_kernel(); any remainder
 * falls back to one plain inner product per offset.
 *
 * @param x          reference vector (len values)
 * @param y          vector searched; read up to index len + max_pitch - 2
 * @param xcorr      receives max_pitch correlation values
 * @param len        number of samples correlated per offset
 * @param max_pitch  number of offsets to evaluate
 */
static void celt_pitch_xcorr(const float *x, const float *y,
                             float *xcorr, int len, int max_pitch)
{
    int i;

    for (i = 0; i < max_pitch - 3; i += 4) {
        float sum[4] = { 0, 0, 0, 0};

        xcorr_kernel(x, y + i, sum, len);

        xcorr[i]     = sum[0];
        xcorr[i + 1] = sum[1];
        xcorr[i + 2] = sum[2];
        xcorr[i + 3] = sum[3];
    }
    /* In case max_pitch isn't a multiple of 4, do non-unrolled version. */
    for (; i < max_pitch; i++) {
        xcorr[i] = celt_inner_prod(x, y + i, len);
    }
}
604 static int celt_autocorr(const float *x, /* in: [0...n-1] samples x */
605 float *ac, /* out: [0...lag-1] ac values */
614 float xx[PITCH_BUF_SIZE>>1];
619 for (int i = 0; i < n; i++)
621 for (int i = 0; i < overlap; i++) {
622 xx[i] = x[i] * window[i];
623 xx[n-i-1] = x[n-i-1] * window[i];
629 celt_pitch_xcorr(xptr, xptr, ac, fastN, lag+1);
631 for (int k = 0; k <= lag; k++) {
634 for (int i = k + fastN; i < n; i++)
635 d += xptr[i] * xptr[i-k];
/**
 * Derive p LPC coefficients from autocorrelation values using the
 * Levinson-Durbin recursion.
 *
 * The output is zeroed first; if ac[0] is zero it stays all-zero. The
 * recursion stops early once the prediction error drops below 0.001 * ac[0].
 *
 * @param lpc  receives p coefficients [0 .. p-1]
 * @param ac   autocorrelation values [0 .. p]
 * @param p    prediction order
 */
static void celt_lpc(float *lpc, /* out: [0...p-1] LPC coefficients      */
                     const float *ac,  /* in:  [0...p] autocorrelation values  */
                     int p)
{
    float r, error = ac[0];

    for (int i = 0; i < p; i++)
        lpc[i] = 0.f;

    if (ac[0] != 0) {
        for (int i = 0; i < p; i++) {
            /* Sum up this iteration's reflection coefficient */
            float rr = 0;
            for (int j = 0; j < i; j++)
                rr += (lpc[j] * ac[i - j]);
            rr += ac[i + 1];
            r = -rr/error;
            /*  Update LPC coefficients and total error */
            lpc[i] = r;
            for (int j = 0; j < (i + 1) >> 1; j++) {
                float tmp1, tmp2;
                tmp1 = lpc[j];
                tmp2 = lpc[i-1-j];
                lpc[j]     = tmp1 + (r*tmp2);
                lpc[i-1-j] = tmp2 + (r*tmp1);
            }

            error = error - (r * r *error);
            /* Bail out once we get 30 dB gain */
            if (error < .001f * ac[0])
                break;
        }
    }
}
/**
 * Five-tap FIR filter with an implicit unity tap on the current sample:
 * y[i] = x[i] + sum over k of num[k] * x[i-1-k].
 *
 * x and y may point to the same buffer: each input sample is latched into
 * the filter memory before the matching output is written.
 *
 * @param x    input samples (N values)
 * @param num  five tap coefficients
 * @param y    output samples (N values)
 * @param N    number of samples
 * @param mem  five previous input samples, most recent first; updated
 */
static void celt_fir5(const float *x,
                      const float *num,
                      float *y,
                      int N,
                      float *mem)
{
    float num0, num1, num2, num3, num4;
    float mem0, mem1, mem2, mem3, mem4;

    num0 = num[0];
    num1 = num[1];
    num2 = num[2];
    num3 = num[3];
    num4 = num[4];
    mem0 = mem[0];
    mem1 = mem[1];
    mem2 = mem[2];
    mem3 = mem[3];
    mem4 = mem[4];

    for (int i = 0; i < N; i++) {
        float sum = x[i];

        sum += (num0*mem0);
        sum += (num1*mem1);
        sum += (num2*mem2);
        sum += (num3*mem3);
        sum += (num4*mem4);
        mem4 = mem3;
        mem3 = mem2;
        mem2 = mem1;
        mem1 = mem0;
        mem0 = x[i];
        y[i] = sum;
    }

    mem[0] = mem0;
    mem[1] = mem1;
    mem[2] = mem2;
    mem[3] = mem3;
    mem[4] = mem4;
}
718 static void pitch_downsample(float *x[], float *x_lp,
723 float lpc[4], mem[5]={0,0,0,0,0};
727 for (int i = 1; i < len >> 1; i++)
728 x_lp[i] = .5f * (.5f * (x[0][(2*i-1)]+x[0][(2*i+1)])+x[0][2*i]);
729 x_lp[0] = .5f * (.5f * (x[0][1])+x[0][0]);
731 for (int i = 1; i < len >> 1; i++)
732 x_lp[i] += (.5f * (.5f * (x[1][(2*i-1)]+x[1][(2*i+1)])+x[1][2*i]));
733 x_lp[0] += .5f * (.5f * (x[1][1])+x[1][0]);
736 celt_autocorr(x_lp, ac, NULL, 0, 4, len>>1);
738 /* Noise floor -40 dB */
741 for (int i = 1; i <= 4; i++) {
742 /*ac[i] *= exp(-.5*(2*M_PI*.002*i)*(2*M_PI*.002*i));*/
743 ac[i] -= ac[i]*(.008f*i)*(.008f*i);
746 celt_lpc(lpc, ac, 4);
747 for (int i = 0; i < 4; i++) {
749 lpc[i] = (lpc[i] * tmp);
752 lpc2[0] = lpc[0] + .8f;
753 lpc2[1] = lpc[1] + (c1 * lpc[0]);
754 lpc2[2] = lpc[2] + (c1 * lpc[1]);
755 lpc2[3] = lpc[3] + (c1 * lpc[2]);
756 lpc2[4] = (c1 * lpc[3]);
757 celt_fir5(x_lp, lpc2, x_lp, len>>1, mem);
/**
 * Compute two dot products against the same vector in a single pass.
 *
 * @param x    common vector (N values)
 * @param y01  first partner vector (N values)
 * @param y02  second partner vector (N values)
 * @param N    number of elements
 * @param xy1  receives sum of x[i] * y01[i]
 * @param xy2  receives sum of x[i] * y02[i]
 */
static inline void dual_inner_prod(const float *x, const float *y01, const float *y02,
                                   int N, float *xy1, float *xy2)
{
    float xy01 = 0, xy02 = 0;

    for (int i = 0; i < N; i++) {
        xy01 += (x[i] * y01[i]);
        xy02 += (x[i] * y02[i]);
    }

    *xy1 = xy01;
    *xy2 = xy02;
}
774 static float compute_pitch_gain(float xy, float xx, float yy)
776 return xy / sqrtf(1.f + xx * yy);
/* Lookup used below to derive the second candidate period T1b for each
 * divisor k. */
779 static const int second_check[16] = {0, 0, 3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2};
/*
 * Re-examine a pitch estimate (*T0_) against shorter periods T0/k.
 *
 * NOTE(review): many lines of this function are not visible in this
 * excerpt; the comments below cover the visible statements only.
 * NOTE(review): the const qualifier on the float return type has no effect
 * and can be dropped.
 */
780 static const float remove_doubling(float *x, int maxperiod, int minperiod,
781 int N, int *T0_, int prev_period, float prev_gain)
788 float best_xy, best_yy;
791 float yy_lookup[PITCH_MAX_PERIOD+1];
793 minperiod0 = minperiod;
/* Energy of x and its correlation with the copy delayed by T0. */
804 dual_inner_prod(x, x, x-T0, N, &xx, &xy);
/* Running energy of the delayed window for every lag, clamped at zero.
 * x is indexed with negative offsets here, so it must point into a buffer
 * with at least maxperiod samples of history before it. */
807 for (i = 1; i <= maxperiod; i++) {
808 yy = yy+(x[-i] * x[-i])-(x[N-i] * x[N-i]);
809 yy_lookup[i] = FFMAX(0, yy);
814 g = g0 = compute_pitch_gain(xy, xx, yy);
815 /* Look for any pitch at T/k */
816 for (k = 2; k <= 15; k++) {
824 /* Look for another strong correlation at T1b */
833 T1b = (2*second_check[k]*T0+k)/(2*k);
/* Average the correlation and energy over the two candidate lags. */
835 dual_inner_prod(x, &x[-T1], &x[-T1b], N, &xy, &xy2);
836 xy = .5f * (xy + xy2);
837 yy = .5f * (yy_lookup[T1] + yy_lookup[T1b]);
838 g1 = compute_pitch_gain(xy, xx, yy);
/* cont depends on how close T1 is to the previous frame's period. */
839 if (FFABS(T1-prev_period)<=1)
841 else if (FFABS(T1-prev_period)<=2 && 5 * k * k < T0)
842 cont = prev_gain * .5f;
845 thresh = FFMAX(.3f, (.7f * g0) - cont);
846 /* Bias against very high pitch (very short period) to avoid false-positives
847 due to short-term correlation */
849 thresh = FFMAX(.4f, (.85f * g0) - cont);
850 else if (T1<2*minperiod)
851 thresh = FFMAX(.5f, (.9f * g0) - cont);
860 best_xy = FFMAX(0, best_xy);
861 if (best_yy <= best_xy)
864 pg = best_xy/(best_yy + 1);
/* Correlations at lags T-1, T and T+1, compared below with a 0.7 factor. */
866 for (k = 0; k < 3; k++)
867 xcorr[k] = celt_inner_prod(x, x-(T+k-1), N);
868 if ((xcorr[2]-xcorr[0]) > .7f * (xcorr[1]-xcorr[0]))
870 else if ((xcorr[0]-xcorr[2]) > (.7f * (xcorr[1] - xcorr[2])))
/*
 * Track the two best candidate lags from a set of correlation values and
 * store them in best_pitch[0] (best) and best_pitch[1] (runner-up).
 *
 * NOTE(review): several lines are not visible in this excerpt; the comments
 * below cover the visible statements only.
 */
883 static void find_best_pitch(float *xcorr, float *y, int len,
884 int max_pitch, int *best_pitch)
897 for (int j = 0; j < len; j++)
900 for (int i = 0; i < max_pitch; i++) {
906 /* Considering the range of xcorr16, this should avoid both underflows
907 and overflows (inf) when squaring xcorr16 */
909 num = xcorr16 * xcorr16;
/* Candidates are ranked by num / Syy, compared by cross-multiplication so
 * no division is needed. */
910 if ((num * best_den[1]) > (best_num[1] * Syy)) {
911 if ((num * best_den[0]) > (best_num[0] * Syy)) {
/* New overall best: the previous best becomes the runner-up. */
912 best_num[1] = best_num[0];
913 best_den[1] = best_den[0];
914 best_pitch[1] = best_pitch[0];
/* Slide the energy window of y forward by one sample. */
925 Syy += y[i+len]*y[i+len] - y[i] * y[i];
/*
 * Two-stage pitch lag search: a coarse pass on signals decimated by a
 * further factor of two, then a finer pass restricted to the neighbourhood
 * of the two coarse candidates, followed by a +/-1 refinement. The result is
 * written to *pitch.
 *
 * NOTE(review): several lines are not visible in this excerpt; the comments
 * below cover the visible statements only.
 */
930 static void pitch_search(const float *x_lp, float *y,
931 int len, int max_pitch, int *pitch)
934 int best_pitch[2]={0,0};
937 float x_lp4[WINDOW_SIZE];
938 float y_lp4[WINDOW_SIZE];
939 float xcorr[WINDOW_SIZE];
943 /* Downsample by 2 again */
944 for (int j = 0; j < len >> 2; j++)
945 x_lp4[j] = x_lp[2*j];
946 for (int j = 0; j < lag >> 2; j++)
949 /* Coarse search with 4x decimation */
951 celt_pitch_xcorr(x_lp4, y_lp4, xcorr, len>>2, max_pitch>>2);
953 find_best_pitch(xcorr, y_lp4, len>>2, max_pitch>>2, best_pitch);
955 /* Finer search with 2x decimation */
956 for (int i = 0; i < max_pitch >> 1; i++) {
/* Only lags within 2 of either doubled coarse candidate get past this test. */
959 if (FFABS(i-2*best_pitch[0])>2 && FFABS(i-2*best_pitch[1])>2)
961 sum = celt_inner_prod(x_lp, y+i, len>>1);
962 xcorr[i] = FFMAX(-1, sum);
965 find_best_pitch(xcorr, y, len>>1, max_pitch>>1, best_pitch);
967 /* Refine by pseudo-interpolation */
968 if (best_pitch[0] > 0 && best_pitch[0] < (max_pitch >> 1) - 1) {
/* a, b, c: correlation just before, at, and just after the best lag. */
971 a = xcorr[best_pitch[0] - 1];
972 b = xcorr[best_pitch[0]];
973 c = xcorr[best_pitch[0] + 1];
974 if (c - a > .7f * (b - a))
976 else if (a - c > .7f * (b-c))
984 *pitch = 2 * best_pitch[0] - offset;
987 static void dct(AudioRNNContext *s, float *out, const float *in)
989 for (int i = 0; i < NB_BANDS; i++) {
992 for (int j = 0; j < NB_BANDS; j++) {
993 sum += in[j] * s->dct_table[j * NB_BANDS + i];
995 out[i] = sum * sqrtf(2.f / 22);
/*
 * Analyse one frame and fill the feature vector for the network.
 *
 * Visible steps: spectrum and band energies of the input (X, Ex); pitch
 * history update and pitch estimation; spectrum, band energies and band
 * correlation of the pitch-delayed signal (P, Ep, Exp); log band energies
 * turned into cepstral coefficients; first/second differences over the
 * cepstral history; a spectral variability measure.
 *
 * NOTE(review): many lines of this function are not visible in this
 * excerpt, including the return statements. The caller stores the result in
 * a variable named "silence" - presumably non-zero means the frame carried
 * no audio; confirm.
 */
999 static int compute_frame_features(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, AVComplexFloat *P,
1000 float *Ex, float *Ep, float *Exp, float *features, const float *in)
1003 float *ceps_0, *ceps_1, *ceps_2;
1004 float spec_variability = 0;
1006 LOCAL_ALIGNED_32(float, p, [WINDOW_SIZE]);
1007 float pitch_buf[PITCH_BUF_SIZE>>1];
1011 float tmp[NB_BANDS];
1012 float follow, logMax;
1014 frame_analysis(s, st, X, Ex, in);
/* Slide the pitch history left by one frame and append the new samples. */
1015 RNN_MOVE(st->pitch_buf, &st->pitch_buf[FRAME_SIZE], PITCH_BUF_SIZE-FRAME_SIZE);
1016 RNN_COPY(&st->pitch_buf[PITCH_BUF_SIZE-FRAME_SIZE], in, FRAME_SIZE);
1017 pre[0] = &st->pitch_buf[0];
/* The local pitch_buf (half length) receives the downsampled history. */
1018 pitch_downsample(pre, pitch_buf, PITCH_BUF_SIZE, 1);
1019 pitch_search(pitch_buf+(PITCH_MAX_PERIOD>>1), pitch_buf, PITCH_FRAME_SIZE,
1020 PITCH_MAX_PERIOD-3*PITCH_MIN_PERIOD, &pitch_index);
1021 pitch_index = PITCH_MAX_PERIOD-pitch_index;
1023 gain = remove_doubling(pitch_buf, PITCH_MAX_PERIOD, PITCH_MIN_PERIOD,
1024 PITCH_FRAME_SIZE, &pitch_index, st->last_period, st->last_gain);
/* Remembered for the next frame's remove_doubling() call. */
1025 st->last_period = pitch_index;
1026 st->last_gain = gain;
/* Window-length slice of the history delayed by pitch_index samples. */
1028 for (int i = 0; i < WINDOW_SIZE; i++)
1029 p[i] = st->pitch_buf[PITCH_BUF_SIZE-WINDOW_SIZE-pitch_index+i];
1031 s->fdsp->vector_fmul(p, p, s->window, WINDOW_SIZE);
1032 forward_transform(st, P, p);
1033 compute_band_energy(Ep, P);
1034 compute_band_corr(Exp, X, P);
/* Normalise the band correlation by the two band energies. */
1036 for (int i = 0; i < NB_BANDS; i++)
1037 Exp[i] = Exp[i] / sqrtf(.001f+Ex[i]*Ep[i]);
1041 for (int i = 0; i < NB_DELTA_CEPS; i++)
1042 features[NB_BANDS+2*NB_DELTA_CEPS+i] = tmp[i];
1044 features[NB_BANDS+2*NB_DELTA_CEPS] -= 1.3;
1045 features[NB_BANDS+2*NB_DELTA_CEPS+1] -= 0.9;
1046 features[NB_BANDS+3*NB_DELTA_CEPS] = .01*(pitch_index-300);
/* Log band energies, floored relative to the running maximum (logMax - 7)
 * and to the previous band (follow - 1.5). */
1050 for (int i = 0; i < NB_BANDS; i++) {
1051 Ly[i] = log10f(1e-2f + Ex[i]);
1052 Ly[i] = FFMAX(logMax-7, FFMAX(follow-1.5, Ly[i]));
1053 logMax = FFMAX(logMax, Ly[i]);
1054 follow = FFMAX(follow-1.5, Ly[i]);
1059 /* If there's no audio, avoid messing up the state. */
1060 RNN_CLEAR(features, NB_FEATURES);
1064 dct(s, features, Ly);
/* Cepstral history is a ring of CEPS_MEM rows: ceps_0 is the row written
 * now, ceps_1 and ceps_2 are the rows one and two positions earlier. */
1067 ceps_0 = st->cepstral_mem[st->memid];
1068 ceps_1 = (st->memid < 1) ? st->cepstral_mem[CEPS_MEM+st->memid-1] : st->cepstral_mem[st->memid-1];
1069 ceps_2 = (st->memid < 2) ? st->cepstral_mem[CEPS_MEM+st->memid-2] : st->cepstral_mem[st->memid-2];
1071 for (int i = 0; i < NB_BANDS; i++)
1072 ceps_0[i] = features[i];
/* Sum, first difference and second difference across the three rows. */
1075 for (int i = 0; i < NB_DELTA_CEPS; i++) {
1076 features[i] = ceps_0[i] + ceps_1[i] + ceps_2[i];
1077 features[NB_BANDS+i] = ceps_0[i] - ceps_2[i];
1078 features[NB_BANDS+NB_DELTA_CEPS+i] = ceps_0[i] - 2*ceps_1[i] + ceps_2[i];
1080 /* Spectral variability features. */
1081 if (st->memid == CEPS_MEM)
/* For each stored row, take the smallest distance to the rows considered in
 * the inner loop, and accumulate those minima. */
1084 for (int i = 0; i < CEPS_MEM; i++) {
1085 float mindist = 1e15f;
1086 for (int j = 0; j < CEPS_MEM; j++) {
1088 for (int k = 0; k < NB_BANDS; k++) {
1091 tmp = st->cepstral_mem[i][k] - st->cepstral_mem[j][k];
1096 mindist = FFMIN(mindist, dist);
1099 spec_variability += mindist;
1102 features[NB_BANDS+3*NB_DELTA_CEPS+1] = spec_variability/CEPS_MEM-2.1;
1107 static void interp_band_gain(float *g, const float *bandE)
1109 memset(g, 0, sizeof(*g) * FREQ_SIZE);
1111 for (int i = 0; i < NB_BANDS - 1; i++) {
1112 const int band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
1114 for (int j = 0; j < band_size; j++) {
1115 float frac = (float)j / band_size;
1117 g[(eband5ms[i] << FRAME_SIZE_SHIFT) + j] = (1.f - frac) * bandE[i] + frac * bandE[i + 1];
1122 static void pitch_filter(AVComplexFloat *X, const AVComplexFloat *P, const float *Ex, const float *Ep,
1123 const float *Exp, const float *g)
1125 float newE[NB_BANDS];
1127 float norm[NB_BANDS];
1128 float rf[FREQ_SIZE] = {0};
1129 float normf[FREQ_SIZE]={0};
1131 for (int i = 0; i < NB_BANDS; i++) {
1132 if (Exp[i]>g[i]) r[i] = 1;
1133 else r[i] = SQUARE(Exp[i])*(1-SQUARE(g[i]))/(.001 + SQUARE(g[i])*(1-SQUARE(Exp[i])));
1134 r[i] = sqrtf(av_clipf(r[i], 0, 1));
1135 r[i] *= sqrtf(Ex[i]/(1e-8+Ep[i]));
1137 interp_band_gain(rf, r);
1138 for (int i = 0; i < FREQ_SIZE; i++) {
1139 X[i].re += rf[i]*P[i].re;
1140 X[i].im += rf[i]*P[i].im;
1142 compute_band_energy(newE, X);
1143 for (int i = 0; i < NB_BANDS; i++) {
1144 norm[i] = sqrtf(Ex[i] / (1e-8+newE[i]));
1146 interp_band_gain(normf, norm);
1147 for (int i = 0; i < FREQ_SIZE; i++) {
1148 X[i].re *= normf[i];
1149 X[i].im *= normf[i];
/* tanh sampled every 0.04 from 0 to 8 (201 entries). */
static const float tansig_table[201] = {
    0.000000f, 0.039979f, 0.079830f, 0.119427f, 0.158649f,
    0.197375f, 0.235496f, 0.272905f, 0.309507f, 0.345214f,
    0.379949f, 0.413644f, 0.446244f, 0.477700f, 0.507977f,
    0.537050f, 0.564900f, 0.591519f, 0.616909f, 0.641077f,
    0.664037f, 0.685809f, 0.706419f, 0.725897f, 0.744277f,
    0.761594f, 0.777888f, 0.793199f, 0.807569f, 0.821040f,
    0.833655f, 0.845456f, 0.856485f, 0.866784f, 0.876393f,
    0.885352f, 0.893698f, 0.901468f, 0.908698f, 0.915420f,
    0.921669f, 0.927473f, 0.932862f, 0.937863f, 0.942503f,
    0.946806f, 0.950795f, 0.954492f, 0.957917f, 0.961090f,
    0.964028f, 0.966747f, 0.969265f, 0.971594f, 0.973749f,
    0.975743f, 0.977587f, 0.979293f, 0.980869f, 0.982327f,
    0.983675f, 0.984921f, 0.986072f, 0.987136f, 0.988119f,
    0.989027f, 0.989867f, 0.990642f, 0.991359f, 0.992020f,
    0.992631f, 0.993196f, 0.993718f, 0.994199f, 0.994644f,
    0.995055f, 0.995434f, 0.995784f, 0.996108f, 0.996407f,
    0.996682f, 0.996937f, 0.997172f, 0.997389f, 0.997590f,
    0.997775f, 0.997946f, 0.998104f, 0.998249f, 0.998384f,
    0.998508f, 0.998623f, 0.998728f, 0.998826f, 0.998916f,
    0.999000f, 0.999076f, 0.999147f, 0.999213f, 0.999273f,
    0.999329f, 0.999381f, 0.999428f, 0.999472f, 0.999513f,
    0.999550f, 0.999585f, 0.999617f, 0.999646f, 0.999673f,
    0.999699f, 0.999722f, 0.999743f, 0.999763f, 0.999781f,
    0.999798f, 0.999813f, 0.999828f, 0.999841f, 0.999853f,
    0.999865f, 0.999875f, 0.999885f, 0.999893f, 0.999902f,
    0.999909f, 0.999916f, 0.999923f, 0.999929f, 0.999934f,
    0.999939f, 0.999944f, 0.999948f, 0.999952f, 0.999956f,
    0.999959f, 0.999962f, 0.999965f, 0.999968f, 0.999970f,
    0.999973f, 0.999975f, 0.999977f, 0.999978f, 0.999980f,
    0.999982f, 0.999983f, 0.999984f, 0.999986f, 0.999987f,
    0.999988f, 0.999989f, 0.999990f, 0.999990f, 0.999991f,
    0.999992f, 0.999992f, 0.999993f, 0.999994f, 0.999994f,
    0.999994f, 0.999995f, 0.999995f, 0.999996f, 0.999996f,
    0.999996f, 0.999997f, 0.999997f, 0.999997f, 0.999997f,
    0.999997f, 0.999998f, 0.999998f, 0.999998f, 0.999998f,
    0.999998f, 0.999998f, 0.999999f, 0.999999f, 0.999999f,
    0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
    0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
    1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
    1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
    1.000000f,
};

/**
 * Table-based tanh approximation.
 *
 * Inputs at or beyond +/-8 (and NaN, via the reversed comparisons) saturate
 * to +/-1. Otherwise the nearest table entry is looked up for |x| and
 * corrected with the residual distance to that grid point; the sign is
 * restored at the end.
 */
static inline float tansig_approx(float x)
{
    float y, dy;
    float sign=1;
    int i;

    /* Tests are reversed to catch NaNs */
    if (!(x<8))
        return 1;
    if (!(x>-8))
        return -1;
    /* Another check in case of -ffast-math */

    if (isnan(x))
       return 0;

    if (x < 0) {
       x=-x;
       sign=-1;
    }
    /* x is in [0, 8) here, so truncation rounds to the nearest grid index
     * exactly as floor() would. */
    i = (int)(.5f+25*x);
    x -= .04f*i;
    y = tansig_table[i];
    dy = 1-y*y;
    y = y + x*dy*(1 - y*x);
    return sign*y;
}
/**
 * Logistic sigmoid approximation built on tansig_approx(), using
 * sigmoid(x) = 0.5 + 0.5 * tanh(x / 2).
 */
static inline float sigmoid_approx(float x)
{
    return .5f + .5f*tansig_approx(.5f*x);
}
1230 static void compute_dense(const DenseLayer *layer, float *output, const float *input)
1232 const int N = layer->nb_neurons, M = layer->nb_inputs, stride = N;
1234 for (int i = 0; i < N; i++) {
1235 /* Compute update gate. */
1236 float sum = layer->bias[i];
1238 for (int j = 0; j < M; j++)
1239 sum += layer->input_weights[j * stride + i] * input[j];
1241 output[i] = WEIGHTS_SCALE * sum;
1244 if (layer->activation == ACTIVATION_SIGMOID) {
1245 for (int i = 0; i < N; i++)
1246 output[i] = sigmoid_approx(output[i]);
1247 } else if (layer->activation == ACTIVATION_TANH) {
1248 for (int i = 0; i < N; i++)
1249 output[i] = tansig_approx(output[i]);
1250 } else if (layer->activation == ACTIVATION_RELU) {
1251 for (int i = 0; i < N; i++)
1252 output[i] = FFMAX(0, output[i]);
1258 static void compute_gru(AudioRNNContext *s, const GRULayer *gru, float *state, const float *input)
1260 LOCAL_ALIGNED_32(float, z, [MAX_NEURONS]);
1261 LOCAL_ALIGNED_32(float, r, [MAX_NEURONS]);
1262 LOCAL_ALIGNED_32(float, h, [MAX_NEURONS]);
1263 const int M = gru->nb_inputs;
1264 const int N = gru->nb_neurons;
1265 const int AN = FFALIGN(N, 4);
1266 const int AM = FFALIGN(M, 4);
1267 const int stride = 3 * AN, istride = 3 * AM;
1269 for (int i = 0; i < N; i++) {
1270 /* Compute update gate. */
1271 float sum = gru->bias[i];
1273 sum += s->fdsp->scalarproduct_float(gru->input_weights + i * istride, input, AM);
1274 sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + i * stride, state, AN);
1275 z[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1278 for (int i = 0; i < N; i++) {
1279 /* Compute reset gate. */
1280 float sum = gru->bias[N + i];
1282 sum += s->fdsp->scalarproduct_float(gru->input_weights + AM + i * istride, input, AM);
1283 sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + AN + i * stride, state, AN);
1284 r[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1287 for (int i = 0; i < N; i++) {
1288 /* Compute output. */
1289 float sum = gru->bias[2 * N + i];
1291 sum += s->fdsp->scalarproduct_float(gru->input_weights + 2 * AM + i * istride, input, AM);
1292 for (int j = 0; j < N; j++)
1293 sum += gru->recurrent_weights[2 * AN + i * stride + j] * state[j] * r[j];
1295 if (gru->activation == ACTIVATION_SIGMOID)
1296 sum = sigmoid_approx(WEIGHTS_SCALE * sum);
1297 else if (gru->activation == ACTIVATION_TANH)
1298 sum = tansig_approx(WEIGHTS_SCALE * sum);
1299 else if (gru->activation == ACTIVATION_RELU)
1300 sum = FFMAX(0, WEIGHTS_SCALE * sum);
1303 h[i] = z[i] * state[i] + (1.f - z[i]) * sum;
1306 RNN_COPY(state, h, N);
1309 #define INPUT_SIZE 42
1311 static void compute_rnn(AudioRNNContext *s, RNNState *rnn, float *gains, float *vad, const float *input)
1313 LOCAL_ALIGNED_32(float, dense_out, [MAX_NEURONS]);
1314 LOCAL_ALIGNED_32(float, noise_input, [MAX_NEURONS * 3]);
1315 LOCAL_ALIGNED_32(float, denoise_input, [MAX_NEURONS * 3]);
1317 compute_dense(rnn->model->input_dense, dense_out, input);
1318 compute_gru(s, rnn->model->vad_gru, rnn->vad_gru_state, dense_out);
1319 compute_dense(rnn->model->vad_output, vad, rnn->vad_gru_state);
1321 for (int i = 0; i < rnn->model->input_dense_size; i++)
1322 noise_input[i] = dense_out[i];
1323 for (int i = 0; i < rnn->model->vad_gru_size; i++)
1324 noise_input[i + rnn->model->input_dense_size] = rnn->vad_gru_state[i];
1325 for (int i = 0; i < INPUT_SIZE; i++)
1326 noise_input[i + rnn->model->input_dense_size + rnn->model->vad_gru_size] = input[i];
1328 compute_gru(s, rnn->model->noise_gru, rnn->noise_gru_state, noise_input);
1330 for (int i = 0; i < rnn->model->vad_gru_size; i++)
1331 denoise_input[i] = rnn->vad_gru_state[i];
1332 for (int i = 0; i < rnn->model->noise_gru_size; i++)
1333 denoise_input[i + rnn->model->vad_gru_size] = rnn->noise_gru_state[i];
1334 for (int i = 0; i < INPUT_SIZE; i++)
1335 denoise_input[i + rnn->model->vad_gru_size + rnn->model->noise_gru_size] = input[i];
1337 compute_gru(s, rnn->model->denoise_gru, rnn->denoise_gru_state, denoise_input);
1338 compute_dense(rnn->model->denoise_output, gains, rnn->denoise_gru_state);
1341 static float rnnoise_channel(AudioRNNContext *s, DenoiseState *st, float *out, const float *in)
1343 AVComplexFloat X[FREQ_SIZE];
1344 AVComplexFloat P[WINDOW_SIZE];
1345 float x[FRAME_SIZE];
1346 float Ex[NB_BANDS], Ep[NB_BANDS];
1347 float Exp[NB_BANDS];
1348 float features[NB_FEATURES];
1350 float gf[FREQ_SIZE];
1352 static const float a_hp[2] = {-1.99599, 0.99600};
1353 static const float b_hp[2] = {-2, 1};
1356 biquad(x, st->mem_hp_x, in, b_hp, a_hp, FRAME_SIZE);
1357 silence = compute_frame_features(s, st, X, P, Ex, Ep, Exp, features, x);
1360 compute_rnn(s, &st->rnn, g, &vad_prob, features);
1361 pitch_filter(X, P, Ex, Ep, Exp, g);
1362 for (int i = 0; i < NB_BANDS; i++) {
1365 g[i] = FFMAX(g[i], alpha * st->lastg[i]);
1366 st->lastg[i] = g[i];
1369 interp_band_gain(gf, g);
1371 for (int i = 0; i < FREQ_SIZE; i++) {
1377 frame_synthesis(s, st, out, X);
1382 typedef struct ThreadData {
1386 static int rnnoise_channels(AVFilterContext *ctx, void *arg, int jobnr, int nb_jobs)
1388 AudioRNNContext *s = ctx->priv;
1389 ThreadData *td = arg;
1390 AVFrame *in = td->in;
1391 AVFrame *out = td->out;
1392 const int start = (out->channels * jobnr) / nb_jobs;
1393 const int end = (out->channels * (jobnr+1)) / nb_jobs;
1395 for (int ch = start; ch < end; ch++) {
1396 rnnoise_channel(s, &s->st[ch],
1397 (float *)out->extended_data[ch],
1398 (const float *)in->extended_data[ch]);
1404 static int filter_frame(AVFilterLink *inlink, AVFrame *in)
1406 AVFilterContext *ctx = inlink->dst;
1407 AVFilterLink *outlink = ctx->outputs[0];
1408 AVFrame *out = NULL;
1411 out = ff_get_audio_buffer(outlink, FRAME_SIZE);
1414 return AVERROR(ENOMEM);
1418 td.in = in; td.out = out;
1419 ctx->internal->execute(ctx, rnnoise_channels, &td, NULL, FFMIN(outlink->channels,
1420 ff_filter_get_nb_threads(ctx)));
1423 return ff_filter_frame(outlink, out);
1426 static int activate(AVFilterContext *ctx)
1428 AVFilterLink *inlink = ctx->inputs[0];
1429 AVFilterLink *outlink = ctx->outputs[0];
1433 FF_FILTER_FORWARD_STATUS_BACK(outlink, inlink);
1435 ret = ff_inlink_consume_samples(inlink, FRAME_SIZE, FRAME_SIZE, &in);
1440 return filter_frame(inlink, in);
1442 FF_FILTER_FORWARD_STATUS(inlink, outlink);
1443 FF_FILTER_FORWARD_WANTED(outlink, inlink);
1445 return FFERROR_NOT_READY;
1448 static av_cold int init(AVFilterContext *ctx)
1450 AudioRNNContext *s = ctx->priv;
1453 s->fdsp = avpriv_float_dsp_alloc(0);
1455 return AVERROR(ENOMEM);
1458 return AVERROR(EINVAL);
1459 f = av_fopen_utf8(s->model_name, "r");
1461 return AVERROR(EINVAL);
1463 s->model = rnnoise_model_from_file(f);
1466 return AVERROR(EINVAL);
1468 for (int i = 0; i < FRAME_SIZE; i++) {
1469 s->window[i] = sin(.5*M_PI*sin(.5*M_PI*(i+.5)/FRAME_SIZE) * sin(.5*M_PI*(i+.5)/FRAME_SIZE));
1470 s->window[WINDOW_SIZE - 1 - i] = s->window[i];
1473 for (int i = 0; i < NB_BANDS; i++) {
1474 for (int j = 0; j < NB_BANDS; j++) {
1475 s->dct_table[i*NB_BANDS + j] = cosf((i + .5f) * j * M_PI / NB_BANDS);
1477 s->dct_table[i*NB_BANDS + j] *= sqrtf(.5);
1484 static av_cold void uninit(AVFilterContext *ctx)
1486 AudioRNNContext *s = ctx->priv;
1489 rnnoise_model_free(s->model);
1493 for (int ch = 0; ch < s->channels; ch++) {
1494 av_freep(&s->st[ch].rnn.vad_gru_state);
1495 av_freep(&s->st[ch].rnn.noise_gru_state);
1496 av_freep(&s->st[ch].rnn.denoise_gru_state);
1497 av_tx_uninit(&s->st[ch].tx);
1498 av_tx_uninit(&s->st[ch].txi);
1504 static const AVFilterPad inputs[] = {
1507 .type = AVMEDIA_TYPE_AUDIO,
1508 .config_props = config_input,
1513 static const AVFilterPad outputs[] = {
1516 .type = AVMEDIA_TYPE_AUDIO,
1521 #define OFFSET(x) offsetof(AudioRNNContext, x)
1522 #define AF AV_OPT_FLAG_AUDIO_PARAM|AV_OPT_FLAG_FILTERING_PARAM
1524 static const AVOption arnndn_options[] = {
1525 { "model", "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1526 { "m", "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1530 AVFILTER_DEFINE_CLASS(arnndn);
/*
 * Filter definition: ties the callbacks, private context and option class
 * together and enables slice threading.
 * NOTE(review): some initializer lines and the closing brace are not visible
 * in this excerpt.
 */
1532 AVFilter ff_af_arnndn = {
1534 .description = NULL_IF_CONFIG_SMALL("Reduce noise from speech using Recurrent Neural Networks."),
1535 .query_formats = query_formats,
1536 .priv_size = sizeof(AudioRNNContext),
1537 .priv_class = &arnndn_class,
1538 .activate = activate,
/* Lets the framework run rnnoise_channels() across worker threads. */
1543 .flags = AVFILTER_FLAG_SLICE_THREADS,