2 * Copyright (c) 2018 Gregor Richards
3 * Copyright (c) 2017 Mozilla
4 * Copyright (c) 2005-2009 Xiph.Org Foundation
5 * Copyright (c) 2007-2008 CSIRO
6 * Copyright (c) 2008-2011 Octasic Inc.
7 * Copyright (c) Jean-Marc Valin
8 * Copyright (c) 2019 Paul B Mahol
10 * Redistribution and use in source and binary forms, with or without
11 * modification, are permitted provided that the following conditions
14 * - Redistributions of source code must retain the above copyright
15 * notice, this list of conditions and the following disclaimer.
17 * - Redistributions in binary form must reproduce the above copyright
18 * notice, this list of conditions and the following disclaimer in the
19 * documentation and/or other materials provided with the distribution.
21 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR
25 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
36 #include "libavutil/avassert.h"
37 #include "libavutil/avstring.h"
38 #include "libavutil/float_dsp.h"
39 #include "libavutil/mem_internal.h"
40 #include "libavutil/opt.h"
41 #include "libavutil/tx.h"
/* Frame geometry: FRAME_SIZE is 120 << FRAME_SIZE_SHIFT = 480 samples, the
 * analysis window covers two frames (960 samples) and FREQ_SIZE = 481 is the
 * number of spectrum bins kept per window. */
47 #define FRAME_SIZE_SHIFT 2
48 #define FRAME_SIZE (120<<FRAME_SIZE_SHIFT)
49 #define WINDOW_SIZE (2*FRAME_SIZE)
50 #define FREQ_SIZE (FRAME_SIZE + 1)
/* Pitch analysis limits; the pitch history buffer holds
 * PITCH_MAX_PERIOD + PITCH_FRAME_SIZE = 1728 samples. */
52 #define PITCH_MIN_PERIOD 60
53 #define PITCH_MAX_PERIOD 768
54 #define PITCH_FRAME_SIZE 960
55 #define PITCH_BUF_SIZE (PITCH_MAX_PERIOD+PITCH_FRAME_SIZE)
/* Evaluates its argument twice: do not pass expressions with side effects. */
57 #define SQUARE(x) ((x)*(x))
/* Number of leading cepstral coefficients for which difference features
 * across stored frames are computed. */
62 #define NB_DELTA_CEPS 6
/* Length of the feature vector handed to the network. */
64 #define NB_FEATURES (NB_BANDS+3*NB_DELTA_CEPS+2)
/* Factor applied to every layer's weighted sum before the activation. */
66 #define WEIGHTS_SCALE (1.f/256)
/* Size of the per-layer scratch buffers; the model loader rejects layer
 * dimensions above 128. */
68 #define MAX_NEURONS 128
/* Activation identifiers as stored in a layer's activation field. */
70 #define ACTIVATION_TANH 0
71 #define ACTIVATION_SIGMOID 1
72 #define ACTIVATION_RELU 2
/* Fully connected layer description (only part of the structure is visible here). */
76 typedef struct DenseLayer {
/* Weights, read as input_weights[j * nb_neurons + i] for input j and neuron i. */
78 const float *input_weights;
/* Gated recurrent layer description (only part of the structure is visible here).
 * Both weight arrays hold, per neuron, three consecutive gate rows (update,
 * reset, output), each padded to a multiple of 4 floats. */
84 typedef struct GRULayer {
/* Weights applied to the layer input. */
86 const float *input_weights;
/* Weights applied to the layer's own previous state. */
87 const float *recurrent_weights;
/* Complete network: pointers to the layers filled in by the model loader,
 * together with cached neuron counts (the *_size fields) that the loader
 * sets from each layer's nb_neurons. */
93 typedef struct RNNModel {
95 const DenseLayer *input_dense;
98 const GRULayer *vad_gru;
101 const GRULayer *noise_gru;
103 int denoise_gru_size;
104 const GRULayer *denoise_gru;
106 int denoise_output_size;
107 const DenseLayer *denoise_output;
/* The loader requires this layer to have exactly one neuron. */
110 const DenseLayer *vad_output;
/* Recurrent state of the three GRU layers. config_input() allocates one
 * zeroed vector per layer and per channel, padded to a multiple of 16 floats. */
113 typedef struct RNNState {
114 float *vad_gru_state;
115 float *noise_gru_state;
116 float *denoise_gru_state;
/* Per-channel processing state (only part of the structure is visible here). */
120 typedef struct DenoiseState {
/* Previous input frame, used as the first half of the next analysis window. */
121 float analysis_mem[FRAME_SIZE];
/* Ring of recent cepstra indexed by memid. */
122 float cepstral_mem[CEPS_MEM][NB_BANDS];
/* Second half of the last synthesised window, added to the next output frame. */
124 DECLARE_ALIGNED(32, float, synthesis_mem)[FRAME_SIZE];
/* Sliding history of filtered input samples used for pitch analysis. */
125 float pitch_buf[PITCH_BUF_SIZE];
126 float pitch_enh_buf[PITCH_BUF_SIZE];
/* Band gains applied to the previous frame. */
130 float lastg[NB_BANDS];
/* Previous unprocessed input frame, mixed with the denoised output. */
131 float history[FRAME_SIZE];
/* Forward (tx) and inverse (txi) FFT contexts of length WINDOW_SIZE. */
133 AVTXContext *tx, *txi;
134 av_tx_fn tx_fn, txi_fn;
/* Private context of the filter (only part of the structure is visible here). */
137 typedef struct AudioRNNContext {
138 const AVClass *class;
/* Symmetric analysis/synthesis window computed in init(). */
146 DECLARE_ALIGNED(32, float, window)[WINDOW_SIZE];
/* Cosine table computed in init(); both dimensions are padded to a multiple of 4. */
147 DECLARE_ALIGNED(32, float, dct_table)[FFALIGN(NB_BANDS, 4)][FFALIGN(NB_BANDS, 4)];
/* Float DSP helpers allocated in init(). */
151 AVFloatDSPContext *fdsp;
/* Activation identifiers as they appear in a model file; the loader's
 * INPUT_ACTIVATION macro maps them to the ACTIVATION_* values. */
154 #define F_ACTIVATION_TANH 0
155 #define F_ACTIVATION_SIGMOID 1
156 #define F_ACTIVATION_RELU 2
158 static void rnnoise_model_free(RNNModel *model)
160 #define FREE_MAYBE(ptr) do { if (ptr) free(ptr); } while (0)
161 #define FREE_DENSE(name) do { \
163 av_free((void *) model->name->input_weights); \
164 av_free((void *) model->name->bias); \
165 av_free((void *) model->name); \
168 #define FREE_GRU(name) do { \
170 av_free((void *) model->name->input_weights); \
171 av_free((void *) model->name->recurrent_weights); \
172 av_free((void *) model->name->bias); \
173 av_free((void *) model->name); \
179 FREE_DENSE(input_dense);
182 FREE_GRU(denoise_gru);
183 FREE_DENSE(denoise_output);
184 FREE_DENSE(vad_output);
/*
 * Parse a text model from f and build an RNNModel.
 *
 * The stream must start with the line "rnnoise-nu model file version 1".
 * The layers follow as whitespace-separated decimal integers: input count,
 * neuron count, activation identifier, then weights and biases. Every
 * integer read for a weight or bias is stored as a float.
 *
 * Each visible failure path (bad value, allocation failure) releases the
 * partially built model with rnnoise_model_free(); the return statements
 * themselves are not visible here, presumably NULL is returned.
 *
 * Helper macros defined inside the function:
 *  - ALLOC_LAYER(type, name): zero-allocate one layer.
 *  - INPUT_VAL(name): read one integer and accept it only within [0, 128].
 *  - INPUT_ACTIVATION(name): read a file activation id and translate it to
 *    an ACTIVATION_* value.
 *  - INPUT_ARRAY(name, len): allocate len floats and fill them in file order.
 *  - INPUT_ARRAY3(name, len0, len1, len2): read len0 * len2 * len1 integers
 *    (k outermost, j innermost) and store each at
 *    values[j * len2 * FFALIGN(len0, 4) + i * FFALIGN(len0, 4) + k], i.e. one
 *    group of len2 rows per j, every row padded to a multiple of 4 floats.
 *  - INPUT_DENSE(name) / INPUT_GRU(name): read a whole layer and record its
 *    neuron count in ret->name_size.
 */
188 static RNNModel *rnnoise_model_from_file(FILE *f)
191 DenseLayer *input_dense;
194 GRULayer *denoise_gru;
195 DenseLayer *denoise_output;
196 DenseLayer *vad_output;
/* Only format version 1 is accepted. */
199 if (fscanf(f, "rnnoise-nu model file version %d\n", &in) != 1 || in != 1)
202 ret = av_calloc(1, sizeof(RNNModel));
206 #define ALLOC_LAYER(type, name) \
207 name = av_calloc(1, sizeof(type)); \
209 rnnoise_model_free(ret); \
214 ALLOC_LAYER(DenseLayer, input_dense);
215 ALLOC_LAYER(GRULayer, vad_gru);
216 ALLOC_LAYER(GRULayer, noise_gru);
217 ALLOC_LAYER(GRULayer, denoise_gru);
218 ALLOC_LAYER(DenseLayer, denoise_output);
219 ALLOC_LAYER(DenseLayer, vad_output);
221 #define INPUT_VAL(name) do { \
222 if (fscanf(f, "%d", &in) != 1 || in < 0 || in > 128) { \
223 rnnoise_model_free(ret); \
229 #define INPUT_ACTIVATION(name) do { \
231 INPUT_VAL(activation); \
232 switch (activation) { \
233 case F_ACTIVATION_SIGMOID: \
234 name = ACTIVATION_SIGMOID; \
236 case F_ACTIVATION_RELU: \
237 name = ACTIVATION_RELU; \
240 name = ACTIVATION_TANH; \
244 #define INPUT_ARRAY(name, len) do { \
245 float *values = av_calloc((len), sizeof(float)); \
247 rnnoise_model_free(ret); \
251 for (int i = 0; i < (len); i++) { \
252 if (fscanf(f, "%d", &in) != 1) { \
253 rnnoise_model_free(ret); \
260 #define INPUT_ARRAY3(name, len0, len1, len2) do { \
261 float *values = av_calloc(FFALIGN((len0), 4) * FFALIGN((len1), 4) * (len2), sizeof(float)); \
263 rnnoise_model_free(ret); \
267 for (int k = 0; k < (len0); k++) { \
268 for (int i = 0; i < (len2); i++) { \
269 for (int j = 0; j < (len1); j++) { \
270 if (fscanf(f, "%d", &in) != 1) { \
271 rnnoise_model_free(ret); \
274 values[j * (len2) * FFALIGN((len0), 4) + i * FFALIGN((len0), 4) + k] = in; \
280 #define INPUT_DENSE(name) do { \
281 INPUT_VAL(name->nb_inputs); \
282 INPUT_VAL(name->nb_neurons); \
283 ret->name ## _size = name->nb_neurons; \
284 INPUT_ACTIVATION(name->activation); \
285 INPUT_ARRAY(name->input_weights, name->nb_inputs * name->nb_neurons); \
286 INPUT_ARRAY(name->bias, name->nb_neurons); \
289 #define INPUT_GRU(name) do { \
290 INPUT_VAL(name->nb_inputs); \
291 INPUT_VAL(name->nb_neurons); \
292 ret->name ## _size = name->nb_neurons; \
293 INPUT_ACTIVATION(name->activation); \
294 INPUT_ARRAY3(name->input_weights, name->nb_inputs, name->nb_neurons, 3); \
295 INPUT_ARRAY3(name->recurrent_weights, name->nb_neurons, name->nb_neurons, 3); \
296 INPUT_ARRAY(name->bias, name->nb_neurons * 3); \
/* Layers are read in the order they are stored in the file. */
299 INPUT_DENSE(input_dense);
301 INPUT_GRU(noise_gru);
302 INPUT_GRU(denoise_gru);
303 INPUT_DENSE(denoise_output);
304 INPUT_DENSE(vad_output);
/* The voice-activity output must be a single value. */
306 if (vad_output->nb_neurons != 1) {
307 rnnoise_model_free(ret);
/*
 * Declare what the filter accepts: the sample formats listed in sample_fmts,
 * any channel count, and a single sample rate of 48000 Hz.
 *
 * Returns AVERROR(ENOMEM) when one of the lists cannot be allocated; the
 * final result is the return value of ff_set_common_samplerates().
 */
314 static int query_formats(AVFilterContext *ctx)
316 AVFilterFormats *formats = NULL;
317 AVFilterChannelLayouts *layouts = NULL;
318 static const enum AVSampleFormat sample_fmts[] = {
/* -1 terminates the sample rate list. */
322 int ret, sample_rates[] = { 48000, -1 };
324 formats = ff_make_format_list(sample_fmts);
326 return AVERROR(ENOMEM);
327 ret = ff_set_common_formats(ctx, formats);
331 layouts = ff_all_channel_counts();
333 return AVERROR(ENOMEM);
335 ret = ff_set_common_channel_layouts(ctx, layouts);
339 formats = ff_make_format_list(sample_rates);
341 return AVERROR(ENOMEM);
342 return ff_set_common_samplerates(ctx, formats);
/*
 * Input link configuration: allocate one zeroed DenoiseState per channel,
 * attach the loaded model to it, allocate the three GRU state vectors and
 * create the forward and inverse FFT contexts of length WINDOW_SIZE.
 *
 * Returns AVERROR(ENOMEM) on allocation failure.
 *
 * NOTE(review): s->st is assigned unconditionally; if this callback can run
 * more than once for the same filter instance the previous array would be
 * overwritten without being freed - confirm against the framework.
 */
345 static int config_input(AVFilterLink *inlink)
347 AVFilterContext *ctx = inlink->dst;
348 AudioRNNContext *s = ctx->priv;
351 s->channels = inlink->channels;
353 s->st = av_calloc(s->channels, sizeof(DenoiseState));
355 return AVERROR(ENOMEM);
357 for (int i = 0; i < s->channels; i++) {
358 DenoiseState *st = &s->st[i];
360 st->rnn.model = s->model;
/* State vectors are padded to a multiple of 16 floats and start zeroed. */
361 st->rnn.vad_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->vad_gru_size, 16));
362 st->rnn.noise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->noise_gru_size, 16));
363 st->rnn.denoise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->denoise_gru_size, 16));
364 if (!st->rnn.vad_gru_state ||
365 !st->rnn.noise_gru_state ||
366 !st->rnn.denoise_gru_state)
367 return AVERROR(ENOMEM);
/* Forward transform (inverse flag 0). */
369 ret = av_tx_init(&st->tx, &st->tx_fn, AV_TX_FLOAT_FFT, 0, WINDOW_SIZE, NULL, 0);
/* Inverse transform (inverse flag 1). */
373 ret = av_tx_init(&st->txi, &st->txi_fn, AV_TX_FLOAT_FFT, 1, WINDOW_SIZE, NULL, 0);
/*
 * Run N samples of x through a second-order recursive filter with
 * coefficient pairs b and a. mem[] carries the two filter state values and
 * is updated in place, so consecutive calls continue the same filter.
 * The output is presumably stored in y (the store is not visible here).
 */
381 static void biquad(float *y, float mem[2], const float *x,
382 const float *b, const float *a, int N)
384 for (int i = 0; i < N; i++) {
/* State update from the current input xi and output yi. */
389 mem[0] = mem[1] + (b[0]*xi - a[0]*yi);
390 mem[1] = (b[1]*xi - a[1]*yi);
/* Element-count wrappers around memmove/memset/memcpy: n is a number of
 * elements of *dst, not bytes. The "+ 0*((dst)-(src))" term adds nothing to
 * the size; it only makes the compiler check that dst and src are pointers
 * to compatible types. RNN_COPY requires non-overlapping buffers, RNN_MOVE
 * allows overlap. */
395 #define RNN_MOVE(dst, src, n) (memmove((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
396 #define RNN_CLEAR(dst, n) (memset((dst), 0, (n)*sizeof(*(dst))))
397 #define RNN_COPY(dst, src, n) (memcpy((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
399 static void forward_transform(DenoiseState *st, AVComplexFloat *out, const float *in)
401 AVComplexFloat x[WINDOW_SIZE];
402 AVComplexFloat y[WINDOW_SIZE];
404 for (int i = 0; i < WINDOW_SIZE; i++) {
409 st->tx_fn(st->tx, y, x, sizeof(float));
411 RNN_COPY(out, y, FREQ_SIZE);
414 static void inverse_transform(DenoiseState *st, float *out, const AVComplexFloat *in)
416 AVComplexFloat x[WINDOW_SIZE];
417 AVComplexFloat y[WINDOW_SIZE];
419 RNN_COPY(x, in, FREQ_SIZE);
421 for (int i = FREQ_SIZE; i < WINDOW_SIZE; i++) {
422 x[i].re = x[WINDOW_SIZE - i].re;
423 x[i].im = -x[WINDOW_SIZE - i].im;
426 st->txi_fn(st->txi, y, x, sizeof(float));
428 for (int i = 0; i < WINDOW_SIZE; i++)
429 out[i] = y[i].re / WINDOW_SIZE;
/* Band edge table. Users shift an entry left by FRAME_SIZE_SHIFT to obtain
 * the spectrum bin where a band starts; the comment row inside the table
 * gives the frequency in Hz each edge corresponds to. */
432 static const uint8_t eband5ms[] = {
433 /*0 200 400 600 800 1k 1.2 1.4 1.6 2k 2.4 2.8 3.2 4k 4.8 5.6 6.8 8k 9.6 12k 15.6 20k*/
434 0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 34, 40, 48, 60, 78, 100
/*
 * Accumulate the energy (re^2 + im^2) of spectrum X into NB_BANDS bands.
 * Every bin between two band edges is shared by the two neighbouring bands
 * with linear weights (1 - frac) and frac, i.e. triangular overlapping bands.
 * The result is presumably copied to bandE in the final loop (store not
 * visible here).
 */
437 static void compute_band_energy(float *bandE, const AVComplexFloat *X)
439 float sum[NB_BANDS] = {0};
441 for (int i = 0; i < NB_BANDS - 1; i++) {
/* Width of the band in bins. */
444 band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
445 for (int j = 0; j < band_size; j++) {
/* frac is the relative position of the bin between the two band edges. */
446 float tmp, frac = (float)j / band_size;
448 tmp = SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].re);
449 tmp += SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].im);
450 sum[i] += (1.f - frac) * tmp;
451 sum[i + 1] += frac * tmp;
/* The last band only receives one side of the triangle, so it is doubled. */
456 sum[NB_BANDS - 1] *= 2;
458 for (int i = 0; i < NB_BANDS; i++)
/*
 * Accumulate the per-band correlation of spectra X and P, using
 * re(X) * re(P) + im(X) * im(P) for every bin and the same triangular band
 * weighting as compute_band_energy().
 * The result is presumably copied to bandE in the final loop (store not
 * visible here).
 */
462 static void compute_band_corr(float *bandE, const AVComplexFloat *X, const AVComplexFloat *P)
464 float sum[NB_BANDS] = { 0 };
466 for (int i = 0; i < NB_BANDS - 1; i++) {
/* Width of the band in bins. */
469 band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
470 for (int j = 0; j < band_size; j++) {
/* frac is the relative position of the bin between the two band edges. */
471 float tmp, frac = (float)j / band_size;
473 tmp = X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re;
474 tmp += X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im;
475 sum[i] += (1 - frac) * tmp;
476 sum[i + 1] += frac * tmp;
/* The last band only receives one side of the triangle, so it is doubled. */
481 sum[NB_BANDS-1] *= 2;
483 for (int i = 0; i < NB_BANDS; i++)
487 static void frame_analysis(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, float *Ex, const float *in)
489 LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
491 RNN_COPY(x, st->analysis_mem, FRAME_SIZE);
492 RNN_COPY(x + FRAME_SIZE, in, FRAME_SIZE);
493 RNN_COPY(st->analysis_mem, in, FRAME_SIZE);
494 s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
495 forward_transform(st, X, x);
496 compute_band_energy(Ex, X);
499 static void frame_synthesis(AudioRNNContext *s, DenoiseState *st, float *out, const AVComplexFloat *y)
501 LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
502 const float *src = st->history;
503 const float mix = s->mix;
504 const float imix = 1.f - FFMAX(mix, 0.f);
506 inverse_transform(st, x, y);
507 s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
508 s->fdsp->vector_fmac_scalar(x, st->synthesis_mem, 1.f, FRAME_SIZE);
509 RNN_COPY(out, x, FRAME_SIZE);
510 RNN_COPY(st->synthesis_mem, &x[FRAME_SIZE], FRAME_SIZE);
512 for (int n = 0; n < FRAME_SIZE; n++)
513 out[n] = out[n] * mix + src[n] * imix;
/*
 * Correlation kernel producing four results at once in sum[4]; it is called
 * by celt_pitch_xcorr() with y advanced by the lag of the first result.
 * The main loop consumes four samples per iteration (most of the body is
 * not visible here).
 */
516 static inline void xcorr_kernel(const float *x, const float *y, float sum[4], int len)
518 float y_0, y_1, y_2, y_3 = 0;
525 for (j = 0; j < len - 3; j += 4) {
/* Loops over the N elements of x and y; judging by the name and its callers
 * it returns their inner product (the accumulation is not visible here). */
585 static inline float celt_inner_prod(const float *x,
586 const float *y, int N)
590 for (int i = 0; i < N; i++)
/*
 * Fill xcorr[0 .. max_pitch-1] with correlations of x against y shifted by
 * each lag, over len samples. Lags are handled four at a time through
 * xcorr_kernel(); leftover lags use celt_inner_prod().
 */
596 static void celt_pitch_xcorr(const float *x, const float *y,
597 float *xcorr, int len, int max_pitch)
601 for (i = 0; i < max_pitch - 3; i += 4) {
602 float sum[4] = { 0, 0, 0, 0};
604 xcorr_kernel(x, y + i, sum, len);
607 xcorr[i + 1] = sum[1];
608 xcorr[i + 2] = sum[2];
609 xcorr[i + 3] = sum[3];
611 /* In case max_pitch isn't a multiple of 4, do non-unrolled version. */
612 for (; i < max_pitch; i++) {
613 xcorr[i] = celt_inner_prod(x, y + i, len);
/*
 * Autocorrelation of x for lags 0..lag. When an overlap is given, both ends
 * of the signal are first multiplied by the window in a local copy. The
 * bulk of each lag is computed with celt_pitch_xcorr() over fastN samples
 * and the remaining samples are added in the final loop.
 */
617 static int celt_autocorr(const float *x, /* in: [0...n-1] samples x */
618 float *ac, /* out: [0...lag-1] ac values */
/* Scratch copy used when windowing is requested. */
627 float xx[PITCH_BUF_SIZE>>1];
632 for (int i = 0; i < n; i++)
/* Apply the window symmetrically to the first and last overlap samples. */
634 for (int i = 0; i < overlap; i++) {
635 xx[i] = x[i] * window[i];
636 xx[n-i-1] = x[n-i-1] * window[i];
642 celt_pitch_xcorr(xptr, xptr, ac, fastN, lag+1);
644 for (int k = 0; k <= lag; k++) {
/* Samples beyond fastN that the call above did not cover. */
647 for (int i = k + fastN; i < n; i++)
648 d += xptr[i] * xptr[i-k];
/*
 * Derive p prediction coefficients from autocorrelation values ac[0..p] by
 * an iterative recursion: each round computes a reflection coefficient r,
 * updates the coefficients found so far in symmetric pairs and reduces the
 * residual error, stopping early once the error has dropped below 0.001
 * times ac[0].
 */
655 static void celt_lpc(float *lpc, /* out: [0...p-1] LPC coefficients */
656 const float *ac, /* in: [0...p] autocorrelation values */
659 float r, error = ac[0];
663 for (int i = 0; i < p; i++) {
664 /* Sum up this iteration's reflection coefficient */
666 for (int j = 0; j < i; j++)
667 rr += (lpc[j] * ac[i - j]);
670 /* Update LPC coefficients and total error */
672 for (int j = 0; j < (i + 1) >> 1; j++) {
676 lpc[j] = tmp1 + (r*tmp2);
677 lpc[i-1-j] = tmp2 + (r*tmp1);
680 error = error - (r * r *error);
681 /* Bail out once we get 30 dB gain */
682 if (error < .001f * ac[0])
/* Five-tap filter applied over N samples; the locals hold the five
 * coefficients (num0..num4) and five state values (mem0..mem4). The
 * per-sample computation is not visible here. */
688 static void celt_fir5(const float *x,
694 float num0, num1, num2, num3, num4;
695 float mem0, mem1, mem2, mem3, mem4;
708 for (int i = 0; i < N; i++) {
/*
 * Prepare a half-rate signal for pitch analysis.
 *
 * x_lp[i] is built from x[0] with the 3-tap kernel (0.25, 0.5, 0.25) centred
 * on sample 2*i; a second input x[1] is added the same way (presumably only
 * when two channels are requested - the condition is not visible here).
 * The result is then filtered in place by a 5-tap filter derived from a
 * 4th-order LPC analysis of its own autocorrelation.
 */
731 static void pitch_downsample(float *x[], float *x_lp,
736 float lpc[4], mem[5]={0,0,0,0,0};
740 for (int i = 1; i < len >> 1; i++)
741 x_lp[i] = .5f * (.5f * (x[0][(2*i-1)]+x[0][(2*i+1)])+x[0][2*i]);
/* Sample 0 has no left neighbour. */
742 x_lp[0] = .5f * (.5f * (x[0][1])+x[0][0]);
744 for (int i = 1; i < len >> 1; i++)
745 x_lp[i] += (.5f * (.5f * (x[1][(2*i-1)]+x[1][(2*i+1)])+x[1][2*i]));
746 x_lp[0] += .5f * (.5f * (x[1][1])+x[1][0]);
/* No window, no overlap, lags 0..4. */
749 celt_autocorr(x_lp, ac, NULL, 0, 4, len>>1);
751 /* Noise floor -40 dB */
/* Attenuate higher lags slightly (lag windowing). */
754 for (int i = 1; i <= 4; i++) {
755 /*ac[i] *= exp(-.5*(2*M_PI*.002*i)*(2*M_PI*.002*i));*/
756 ac[i] -= ac[i]*(.008f*i)*(.008f*i);
759 celt_lpc(lpc, ac, 4);
760 for (int i = 0; i < 4; i++) {
762 lpc[i] = (lpc[i] * tmp);
/* Combine the 4 LPC taps with one extra first-order term (gain c1) into 5 taps. */
765 lpc2[0] = lpc[0] + .8f;
766 lpc2[1] = lpc[1] + (c1 * lpc[0]);
767 lpc2[2] = lpc[2] + (c1 * lpc[1]);
768 lpc2[3] = lpc[3] + (c1 * lpc[2]);
769 lpc2[4] = (c1 * lpc[3]);
770 celt_fir5(x_lp, lpc2, x_lp, len>>1, mem);
/* Accumulate two inner products in one pass over N samples: x with y01 and
 * x with y02. The sums are presumably returned through xy1 and xy2 (the
 * stores are not visible here). */
773 static inline void dual_inner_prod(const float *x, const float *y01, const float *y02,
774 int N, float *xy1, float *xy2)
776 float xy01 = 0, xy02 = 0;
778 for (int i = 0; i < N; i++) {
779 xy01 += (x[i] * y01[i]);
780 xy02 += (x[i] * y02[i]);
787 static float compute_pitch_gain(float xy, float xx, float yy)
789 return xy / sqrtf(1.f + xx * yy);
/* Indexed by the divisor k in remove_doubling(): selects which multiple of
 * T0/k is tested as the second candidate period T1b. */
792 static const uint8_t second_check[16] = {0, 0, 3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2};
/*
 * Refine a pitch estimate by testing shorter candidate periods T0/k for
 * k = 2..15. Each candidate is scored with compute_pitch_gain() on the
 * average of two correlations (at T1 and at a second period T1b) and
 * compared with a threshold derived from the gain g0 of the initial period;
 * candidates close to the previous frame's period get a bonus (cont).
 * The chosen period is reported through T0_ and a gain is returned.
 *
 * x is indexed with negative offsets (x[-i]), so the caller must pass a
 * pointer with at least maxperiod valid samples in front of it.
 * NOTE(review): the visible caller passes the start of its local buffer -
 * confirm that the pointer is advanced in lines not visible here.
 */
793 static float remove_doubling(float *x, int maxperiod, int minperiod, int N,
794 int *T0_, int prev_period, float prev_gain)
801 float best_xy, best_yy;
/* Energy of the delayed signal for every lag up to maxperiod. */
804 float yy_lookup[PITCH_MAX_PERIOD+1];
806 minperiod0 = minperiod;
817 dual_inner_prod(x, x, x-T0, N, &xx, &xy);
/* Running update: add the sample entering the window, drop the one leaving it. */
820 for (i = 1; i <= maxperiod; i++) {
821 yy = yy+(x[-i] * x[-i])-(x[N-i] * x[N-i]);
822 yy_lookup[i] = FFMAX(0, yy);
827 g = g0 = compute_pitch_gain(xy, xx, yy);
828 /* Look for any pitch at T/k */
829 for (k = 2; k <= 15; k++) {
837 /* Look for another strong correlation at T1b */
/* Rounded integer division of second_check[k] * T0 by k. */
846 T1b = (2*second_check[k]*T0+k)/(2*k);
848 dual_inner_prod(x, &x[-T1], &x[-T1b], N, &xy, &xy2);
849 xy = .5f * (xy + xy2);
850 yy = .5f * (yy_lookup[T1] + yy_lookup[T1b]);
851 g1 = compute_pitch_gain(xy, xx, yy);
/* Favour candidates near the period found for the previous frame. */
852 if (FFABS(T1-prev_period)<=1)
854 else if (FFABS(T1-prev_period)<=2 && 5 * k * k < T0)
855 cont = prev_gain * .5f;
858 thresh = FFMAX(.3f, (.7f * g0) - cont);
859 /* Bias against very high pitch (very short period) to avoid false-positives
860 due to short-term correlation */
862 thresh = FFMAX(.4f, (.85f * g0) - cont);
863 else if (T1<2*minperiod)
864 thresh = FFMAX(.5f, (.9f * g0) - cont);
873 best_xy = FFMAX(0, best_xy);
874 if (best_yy <= best_xy)
877 pg = best_xy/(best_yy + 1);
/* Correlations at T-1, T and T+1, compared below to decide an offset of the final period. */
879 for (k = 0; k < 3; k++)
880 xcorr[k] = celt_inner_prod(x, x-(T+k-1), N);
881 if ((xcorr[2]-xcorr[0]) > .7f * (xcorr[1]-xcorr[0]))
883 else if ((xcorr[0]-xcorr[2]) > (.7f * (xcorr[1] - xcorr[2])))
/*
 * Scan max_pitch lags and keep the two best in best_pitch[0] (best) and
 * best_pitch[1] (runner-up). Candidates are ranked by the squared
 * correlation divided by the energy Syy of the matching y segment; the
 * comparison is done by cross-multiplication instead of division.
 */
896 static void find_best_pitch(float *xcorr, float *y, int len,
897 int max_pitch, int *best_pitch)
910 for (int j = 0; j < len; j++)
913 for (int i = 0; i < max_pitch; i++) {
919 /* Considering the range of xcorr16, this should avoid both underflows
920 and overflows (inf) when squaring xcorr16 */
922 num = xcorr16 * xcorr16;
923 if ((num * best_den[1]) > (best_num[1] * Syy)) {
924 if ((num * best_den[0]) > (best_num[0] * Syy)) {
/* New best: the previous best becomes the runner-up. */
925 best_num[1] = best_num[0];
926 best_den[1] = best_den[0];
927 best_pitch[1] = best_pitch[0];
/* Slide the energy window of y by one sample. */
938 Syy += y[i+len]*y[i+len] - y[i] * y[i];
/*
 * Two-stage pitch search. A coarse search runs on signals decimated by a
 * further factor of 2; a finer search then evaluates only the lags within
 * 2 of twice either coarse candidate. The best lag is finally adjusted by
 * an offset chosen from its two neighbours, and *pitch is set to
 * 2 * best lag - offset.
 */
943 static void pitch_search(const float *x_lp, float *y,
944 int len, int max_pitch, int *pitch)
947 int best_pitch[2]={0,0};
950 float x_lp4[WINDOW_SIZE];
951 float y_lp4[WINDOW_SIZE];
952 float xcorr[WINDOW_SIZE];
956 /* Downsample by 2 again */
957 for (int j = 0; j < len >> 2; j++)
958 x_lp4[j] = x_lp[2*j];
959 for (int j = 0; j < lag >> 2; j++)
962 /* Coarse search with 4x decimation */
964 celt_pitch_xcorr(x_lp4, y_lp4, xcorr, len>>2, max_pitch>>2);
966 find_best_pitch(xcorr, y_lp4, len>>2, max_pitch>>2, best_pitch);
968 /* Finer search with 2x decimation */
969 for (int i = 0; i < max_pitch >> 1; i++) {
/* Only lags next to one of the two coarse candidates are evaluated. */
972 if (FFABS(i-2*best_pitch[0])>2 && FFABS(i-2*best_pitch[1])>2)
974 sum = celt_inner_prod(x_lp, y+i, len>>1);
975 xcorr[i] = FFMAX(-1, sum);
978 find_best_pitch(xcorr, y, len>>1, max_pitch>>1, best_pitch);
980 /* Refine by pseudo-interpolation */
981 if (best_pitch[0] > 0 && best_pitch[0] < (max_pitch >> 1) - 1) {
984 a = xcorr[best_pitch[0] - 1];
985 b = xcorr[best_pitch[0]];
986 c = xcorr[best_pitch[0] + 1];
987 if (c - a > .7f * (b - a))
989 else if (a - c > .7f * (b-c))
997 *pitch = 2 * best_pitch[0] - offset;
1000 static void dct(AudioRNNContext *s, float *out, const float *in)
1002 for (int i = 0; i < NB_BANDS; i++) {
1005 sum = s->fdsp->scalarproduct_float(in, s->dct_table[i], FFALIGN(NB_BANDS, 4));
1006 out[i] = sum * sqrtf(2.f / 22);
/*
 * Analyse one frame and build the feature vector for the network.
 *
 * Steps visible here: spectrum and band energies of the input (X, Ex);
 * update of the pitch history and pitch estimation on a half-rate copy;
 * spectrum and band energies of the pitch-delayed signal (P, Ep) and their
 * normalised band correlation with the input (Exp); log band energies
 * turned into cepstral coefficients; difference features over the last
 * three stored cepstra; a pitch feature and a spectral variability feature.
 *
 * The int result is used by the caller as a "silence" flag; in that case
 * the features are cleared (the return statements are not visible here).
 *
 * NOTE(review): dct() runs a scalar product over FFALIGN(NB_BANDS, 4) floats
 * while Ly is declared with NB_BANDS elements - confirm that this read stays
 * inside the buffer.
 */
1010 static int compute_frame_features(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, AVComplexFloat *P,
1011 float *Ex, float *Ep, float *Exp, float *features, const float *in)
1014 float *ceps_0, *ceps_1, *ceps_2;
1015 float spec_variability = 0;
1016 LOCAL_ALIGNED_32(float, Ly, [NB_BANDS]);
1017 LOCAL_ALIGNED_32(float, p, [WINDOW_SIZE]);
1018 float pitch_buf[PITCH_BUF_SIZE>>1];
1022 float tmp[NB_BANDS];
1023 float follow, logMax;
1025 frame_analysis(s, st, X, Ex, in);
/* Slide the pitch history left by one frame and append the new frame. */
1026 RNN_MOVE(st->pitch_buf, &st->pitch_buf[FRAME_SIZE], PITCH_BUF_SIZE-FRAME_SIZE);
1027 RNN_COPY(&st->pitch_buf[PITCH_BUF_SIZE-FRAME_SIZE], in, FRAME_SIZE);
1028 pre[0] = &st->pitch_buf[0];
/* The local pitch_buf is the half-rate copy of the history. */
1029 pitch_downsample(pre, pitch_buf, PITCH_BUF_SIZE, 1);
1030 pitch_search(pitch_buf+(PITCH_MAX_PERIOD>>1), pitch_buf, PITCH_FRAME_SIZE,
1031 PITCH_MAX_PERIOD-3*PITCH_MIN_PERIOD, &pitch_index);
/* Convert the search result into a period (distance back in time). */
1032 pitch_index = PITCH_MAX_PERIOD-pitch_index;
1034 gain = remove_doubling(pitch_buf, PITCH_MAX_PERIOD, PITCH_MIN_PERIOD,
1035 PITCH_FRAME_SIZE, &pitch_index, st->last_period, st->last_gain);
1036 st->last_period = pitch_index;
1037 st->last_gain = gain;
/* Window of history taken pitch_index samples before the most recent window. */
1039 for (int i = 0; i < WINDOW_SIZE; i++)
1040 p[i] = st->pitch_buf[PITCH_BUF_SIZE-WINDOW_SIZE-pitch_index+i];
1042 s->fdsp->vector_fmul(p, p, s->window, WINDOW_SIZE);
1043 forward_transform(st, P, p);
1044 compute_band_energy(Ep, P);
1045 compute_band_corr(Exp, X, P);
/* Normalise the correlation by the band energies; .001f avoids a zero denominator. */
1047 for (int i = 0; i < NB_BANDS; i++)
1048 Exp[i] = Exp[i] / sqrtf(.001f+Ex[i]*Ep[i]);
1052 for (int i = 0; i < NB_DELTA_CEPS; i++)
1053 features[NB_BANDS+2*NB_DELTA_CEPS+i] = tmp[i];
/* Constant offsets on the first two of these features. */
1055 features[NB_BANDS+2*NB_DELTA_CEPS] -= 1.3;
1056 features[NB_BANDS+2*NB_DELTA_CEPS+1] -= 0.9;
/* Pitch period feature, centred on 300 and scaled by 0.01. */
1057 features[NB_BANDS+3*NB_DELTA_CEPS] = .01*(pitch_index-300);
1061 for (int i = 0; i < NB_BANDS; i++) {
1062 Ly[i] = log10f(1e-2f + Ex[i]);
/* Floor each band at 7 below the running maximum and 1.5 below the follower. */
1063 Ly[i] = FFMAX(logMax-7, FFMAX(follow-1.5, Ly[i]));
1064 logMax = FFMAX(logMax, Ly[i]);
1065 follow = FFMAX(follow-1.5, Ly[i]);
1070 /* If there's no audio, avoid messing up the state. */
1071 RNN_CLEAR(features, NB_FEATURES);
1075 dct(s, features, Ly);
/* Current slot of the cepstral ring and the two preceding ones (with wrap-around). */
1078 ceps_0 = st->cepstral_mem[st->memid];
1079 ceps_1 = (st->memid < 1) ? st->cepstral_mem[CEPS_MEM+st->memid-1] : st->cepstral_mem[st->memid-1];
1080 ceps_2 = (st->memid < 2) ? st->cepstral_mem[CEPS_MEM+st->memid-2] : st->cepstral_mem[st->memid-2];
1082 for (int i = 0; i < NB_BANDS; i++)
1083 ceps_0[i] = features[i];
/* Sum, first difference and second difference over the three stored frames. */
1086 for (int i = 0; i < NB_DELTA_CEPS; i++) {
1087 features[i] = ceps_0[i] + ceps_1[i] + ceps_2[i];
1088 features[NB_BANDS+i] = ceps_0[i] - ceps_2[i];
1089 features[NB_BANDS+NB_DELTA_CEPS+i] = ceps_0[i] - 2*ceps_1[i] + ceps_2[i];
1091 /* Spectral variability features. */
1092 if (st->memid == CEPS_MEM)
/* For every stored cepstrum, accumulate the smallest distance found over the inner loop. */
1095 for (int i = 0; i < CEPS_MEM; i++) {
1096 float mindist = 1e15f;
1097 for (int j = 0; j < CEPS_MEM; j++) {
1099 for (int k = 0; k < NB_BANDS; k++) {
1102 tmp = st->cepstral_mem[i][k] - st->cepstral_mem[j][k];
1107 mindist = FFMIN(mindist, dist);
1110 spec_variability += mindist;
1113 features[NB_BANDS+3*NB_DELTA_CEPS+1] = spec_variability/CEPS_MEM-2.1;
1118 static void interp_band_gain(float *g, const float *bandE)
1120 memset(g, 0, sizeof(*g) * FREQ_SIZE);
1122 for (int i = 0; i < NB_BANDS - 1; i++) {
1123 const int band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
1125 for (int j = 0; j < band_size; j++) {
1126 float frac = (float)j / band_size;
1128 g[(eband5ms[i] << FRAME_SIZE_SHIFT) + j] = (1.f - frac) * bandE[i] + frac * bandE[i + 1];
/*
 * Add a band-dependent amount of the pitch spectrum P to X, then rescale X
 * so that each band keeps its original energy Ex.
 *
 * The amount r[i] is 1 where the pitch correlation Exp exceeds the gain g,
 * otherwise a ratio built from Exp and g; it is clipped to [0, 1], square
 * rooted and scaled by sqrt(Ex / Ep).
 */
1133 static void pitch_filter(AVComplexFloat *X, const AVComplexFloat *P, const float *Ex, const float *Ep,
1134 const float *Exp, const float *g)
1136 float newE[NB_BANDS];
1138 float norm[NB_BANDS];
1139 float rf[FREQ_SIZE] = {0};
1140 float normf[FREQ_SIZE]={0};
1142 for (int i = 0; i < NB_BANDS; i++) {
1143 if (Exp[i]>g[i]) r[i] = 1;
1144 else r[i] = SQUARE(Exp[i])*(1-SQUARE(g[i]))/(.001 + SQUARE(g[i])*(1-SQUARE(Exp[i])));
1145 r[i] = sqrtf(av_clipf(r[i], 0, 1));
/* 1e-8 avoids a division by zero. */
1146 r[i] *= sqrtf(Ex[i]/(1e-8+Ep[i]));
/* Per-band amounts expanded to per-bin amounts. */
1148 interp_band_gain(rf, r);
1149 for (int i = 0; i < FREQ_SIZE; i++) {
1150 X[i].re += rf[i]*P[i].re;
1151 X[i].im += rf[i]*P[i].im;
/* Renormalise so that the band energies match Ex again. */
1153 compute_band_energy(newE, X);
1154 for (int i = 0; i < NB_BANDS; i++) {
1155 norm[i] = sqrtf(Ex[i] / (1e-8+newE[i]));
1157 interp_band_gain(normf, norm);
1158 for (int i = 0; i < FREQ_SIZE; i++) {
1159 X[i].re *= normf[i];
1160 X[i].im *= normf[i];
/* Lookup table for tansig_approx(), which indexes it with the nearest
 * integer to 25 * x. The values match tanh(i / 25) to six decimals (for
 * example entry 25 is 0.761594 = tanh(1)) and reach 1.0 at the end. */
1164 static const float tansig_table[201] = {
1165 0.000000f, 0.039979f, 0.079830f, 0.119427f, 0.158649f,
1166 0.197375f, 0.235496f, 0.272905f, 0.309507f, 0.345214f,
1167 0.379949f, 0.413644f, 0.446244f, 0.477700f, 0.507977f,
1168 0.537050f, 0.564900f, 0.591519f, 0.616909f, 0.641077f,
1169 0.664037f, 0.685809f, 0.706419f, 0.725897f, 0.744277f,
1170 0.761594f, 0.777888f, 0.793199f, 0.807569f, 0.821040f,
1171 0.833655f, 0.845456f, 0.856485f, 0.866784f, 0.876393f,
1172 0.885352f, 0.893698f, 0.901468f, 0.908698f, 0.915420f,
1173 0.921669f, 0.927473f, 0.932862f, 0.937863f, 0.942503f,
1174 0.946806f, 0.950795f, 0.954492f, 0.957917f, 0.961090f,
1175 0.964028f, 0.966747f, 0.969265f, 0.971594f, 0.973749f,
1176 0.975743f, 0.977587f, 0.979293f, 0.980869f, 0.982327f,
1177 0.983675f, 0.984921f, 0.986072f, 0.987136f, 0.988119f,
1178 0.989027f, 0.989867f, 0.990642f, 0.991359f, 0.992020f,
1179 0.992631f, 0.993196f, 0.993718f, 0.994199f, 0.994644f,
1180 0.995055f, 0.995434f, 0.995784f, 0.996108f, 0.996407f,
1181 0.996682f, 0.996937f, 0.997172f, 0.997389f, 0.997590f,
1182 0.997775f, 0.997946f, 0.998104f, 0.998249f, 0.998384f,
1183 0.998508f, 0.998623f, 0.998728f, 0.998826f, 0.998916f,
1184 0.999000f, 0.999076f, 0.999147f, 0.999213f, 0.999273f,
1185 0.999329f, 0.999381f, 0.999428f, 0.999472f, 0.999513f,
1186 0.999550f, 0.999585f, 0.999617f, 0.999646f, 0.999673f,
1187 0.999699f, 0.999722f, 0.999743f, 0.999763f, 0.999781f,
1188 0.999798f, 0.999813f, 0.999828f, 0.999841f, 0.999853f,
1189 0.999865f, 0.999875f, 0.999885f, 0.999893f, 0.999902f,
1190 0.999909f, 0.999916f, 0.999923f, 0.999929f, 0.999934f,
1191 0.999939f, 0.999944f, 0.999948f, 0.999952f, 0.999956f,
1192 0.999959f, 0.999962f, 0.999965f, 0.999968f, 0.999970f,
1193 0.999973f, 0.999975f, 0.999977f, 0.999978f, 0.999980f,
1194 0.999982f, 0.999983f, 0.999984f, 0.999986f, 0.999987f,
1195 0.999988f, 0.999989f, 0.999990f, 0.999990f, 0.999991f,
1196 0.999992f, 0.999992f, 0.999993f, 0.999994f, 0.999994f,
1197 0.999994f, 0.999995f, 0.999995f, 0.999996f, 0.999996f,
1198 0.999996f, 0.999997f, 0.999997f, 0.999997f, 0.999997f,
1199 0.999997f, 0.999998f, 0.999998f, 0.999998f, 0.999998f,
1200 0.999998f, 0.999998f, 0.999999f, 0.999999f, 0.999999f,
1201 0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
1202 0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
1203 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1204 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
/*
 * Table-based approximation of the hyperbolic tangent: the nearest table
 * entry is looked up and then corrected with a term in the remaining
 * distance dy. Range and NaN handling happen in lines not visible here.
 */
1208 static inline float tansig_approx(float x)
1214 /* Tests are reversed to catch NaNs */
1219 /* Another check in case of -ffast-math */
/* Nearest table index for a step of 1/25. */
1228 i = (int)floor(.5f+25*x);
1230 y = tansig_table[i];
/* Correction using the distance dy from the table point. */
1232 y = y + x*dy*(1 - y*x);
/**
 * Approximate logistic sigmoid, derived from the tanh approximation through
 * sigmoid(x) = 0.5 + 0.5 * tanh(x / 2).
 *
 * @param x input value
 * @return approximation of 1 / (1 + exp(-x))
 */
static inline float sigmoid_approx(float x)
{
    const float half_tanh = .5f*tansig_approx(.5f*x);

    return .5f + half_tanh;
}
/*
 * Evaluate a fully connected layer: for each neuron, bias plus the weighted
 * sum of the inputs, scaled by WEIGHTS_SCALE and passed through the layer's
 * activation (sigmoid, tanh or ReLU).
 *
 * output must hold nb_neurons floats and input nb_inputs floats.
 */
1241 static void compute_dense(const DenseLayer *layer, float *output, const float *input)
1243 const int N = layer->nb_neurons, M = layer->nb_inputs, stride = N;
1245 for (int i = 0; i < N; i++) {
1246 /* Start from the neuron's bias. */
1247 float sum = layer->bias[i];
/* Weights of one neuron are strided by the neuron count. */
1249 for (int j = 0; j < M; j++)
1250 sum += layer->input_weights[j * stride + i] * input[j];
1252 output[i] = WEIGHTS_SCALE * sum;
1255 if (layer->activation == ACTIVATION_SIGMOID) {
1256 for (int i = 0; i < N; i++)
1257 output[i] = sigmoid_approx(output[i]);
1258 } else if (layer->activation == ACTIVATION_TANH) {
1259 for (int i = 0; i < N; i++)
1260 output[i] = tansig_approx(output[i]);
1261 } else if (layer->activation == ACTIVATION_RELU) {
1262 for (int i = 0; i < N; i++)
1263 output[i] = FFMAX(0, output[i]);
/*
 * Advance a GRU layer by one step, updating state in place.
 *
 * For every neuron an update gate z and a reset gate r are computed from
 * the input and the previous state; the candidate output uses the input
 * and the previous state multiplied element-wise by r; the new state is
 * z * old_state + (1 - z) * candidate.
 *
 * The scalar products run over FFALIGN(nb_inputs, 4) and
 * FFALIGN(nb_neurons, 4) floats, so input and state must be readable up to
 * those padded lengths.
 */
1269 static void compute_gru(AudioRNNContext *s, const GRULayer *gru, float *state, const float *input)
1271 LOCAL_ALIGNED_32(float, z, [MAX_NEURONS]);
1272 LOCAL_ALIGNED_32(float, r, [MAX_NEURONS]);
1273 LOCAL_ALIGNED_32(float, h, [MAX_NEURONS]);
1274 const int M = gru->nb_inputs;
1275 const int N = gru->nb_neurons;
1276 const int AN = FFALIGN(N, 4);
1277 const int AM = FFALIGN(M, 4);
/* Each neuron owns three padded rows (update, reset, output) in both weight arrays. */
1278 const int stride = 3 * AN, istride = 3 * AM;
1280 for (int i = 0; i < N; i++) {
1281 /* Compute update gate. */
1282 float sum = gru->bias[i];
1284 sum += s->fdsp->scalarproduct_float(gru->input_weights + i * istride, input, AM);
1285 sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + i * stride, state, AN);
1286 z[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1289 for (int i = 0; i < N; i++) {
1290 /* Compute reset gate. */
1291 float sum = gru->bias[N + i];
1293 sum += s->fdsp->scalarproduct_float(gru->input_weights + AM + i * istride, input, AM);
1294 sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + AN + i * stride, state, AN);
1295 r[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1298 for (int i = 0; i < N; i++) {
1299 /* Compute output. */
1300 float sum = gru->bias[2 * N + i];
1302 sum += s->fdsp->scalarproduct_float(gru->input_weights + 2 * AM + i * istride, input, AM);
/* The recurrent term uses the state gated by r, so it is summed explicitly. */
1303 for (int j = 0; j < N; j++)
1304 sum += gru->recurrent_weights[2 * AN + i * stride + j] * state[j] * r[j];
1306 if (gru->activation == ACTIVATION_SIGMOID)
1307 sum = sigmoid_approx(WEIGHTS_SCALE * sum);
1308 else if (gru->activation == ACTIVATION_TANH)
1309 sum = tansig_approx(WEIGHTS_SCALE * sum);
1310 else if (gru->activation == ACTIVATION_RELU)
1311 sum = FFMAX(0, WEIGHTS_SCALE * sum);
/* Blend the previous state and the candidate with the update gate. */
1314 h[i] = z[i] * state[i] + (1.f - z[i]) * sum;
/* The new state is written only after every neuron has read the old one. */
1317 RNN_COPY(state, h, N);
/* Number of feature values copied into the GRU inputs below. */
1320 #define INPUT_SIZE 42
/*
 * Run the whole network for one frame.
 *
 * input -> input_dense -> vad_gru -> vad_output gives the value stored in
 * *vad. The noise GRU receives the dense output, the VAD state and the raw
 * features; the denoise GRU receives the VAD state, the noise state and the
 * raw features; denoise_output finally produces the band gains.
 */
1322 static void compute_rnn(AudioRNNContext *s, RNNState *rnn, float *gains, float *vad, const float *input)
1324 LOCAL_ALIGNED_32(float, dense_out, [MAX_NEURONS]);
1325 LOCAL_ALIGNED_32(float, noise_input, [MAX_NEURONS * 3]);
1326 LOCAL_ALIGNED_32(float, denoise_input, [MAX_NEURONS * 3]);
1328 compute_dense(rnn->model->input_dense, dense_out, input);
1329 compute_gru(s, rnn->model->vad_gru, rnn->vad_gru_state, dense_out);
1330 compute_dense(rnn->model->vad_output, vad, rnn->vad_gru_state);
/* noise_input = [dense_out | vad state | features]. */
1332 memcpy(noise_input, dense_out, rnn->model->input_dense_size * sizeof(float));
1333 memcpy(noise_input + rnn->model->input_dense_size,
1334 rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
1335 memcpy(noise_input + rnn->model->input_dense_size + rnn->model->vad_gru_size,
1336 input, INPUT_SIZE * sizeof(float));
1338 compute_gru(s, rnn->model->noise_gru, rnn->noise_gru_state, noise_input);
/* denoise_input = [vad state | noise state | features]. */
1340 memcpy(denoise_input, rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
1341 memcpy(denoise_input + rnn->model->vad_gru_size,
1342 rnn->noise_gru_state, rnn->model->noise_gru_size * sizeof(float));
1343 memcpy(denoise_input + rnn->model->vad_gru_size + rnn->model->noise_gru_size,
1344 input, INPUT_SIZE * sizeof(float));
1346 compute_gru(s, rnn->model->denoise_gru, rnn->denoise_gru_state, denoise_input);
1347 compute_dense(rnn->model->denoise_output, gains, rnn->denoise_gru_state);
/*
 * Process one frame of one channel.
 *
 * The input is first filtered with a fixed biquad and analysed. Unless the
 * frame is flagged as silence or processing is disabled, the network
 * produces band gains, the pitch filter is applied and each gain is kept
 * from falling below a fraction (alpha) of the gain used for the previous
 * frame. The frame is then synthesised into out, and the unprocessed input
 * is stored in st->history for the mix stage of the next call.
 */
1350 static float rnnoise_channel(AudioRNNContext *s, DenoiseState *st, float *out, const float *in,
1353 AVComplexFloat X[FREQ_SIZE];
1354 AVComplexFloat P[WINDOW_SIZE];
1355 float x[FRAME_SIZE];
1356 float Ex[NB_BANDS], Ep[NB_BANDS];
1357 LOCAL_ALIGNED_32(float, Exp, [NB_BANDS]);
1358 float features[NB_FEATURES];
1360 float gf[FREQ_SIZE];
1362 float *history = st->history;
/* Coefficients of the input biquad. */
1363 static const float a_hp[2] = {-1.99599, 0.99600};
1364 static const float b_hp[2] = {-2, 1};
1367 biquad(x, st->mem_hp_x, in, b_hp, a_hp, FRAME_SIZE);
1368 silence = compute_frame_features(s, st, X, P, Ex, Ep, Exp, features, x);
1370 if (!silence && !disabled) {
1371 compute_rnn(s, &st->rnn, g, &vad_prob, features);
1372 pitch_filter(X, P, Ex, Ep, Exp, g);
1373 for (int i = 0; i < NB_BANDS; i++) {
/* Limit how fast a band gain may fall from one frame to the next. */
1376 g[i] = FFMAX(g[i], alpha * st->lastg[i]);
1377 st->lastg[i] = g[i];
1380 interp_band_gain(gf, g);
1382 for (int i = 0; i < FREQ_SIZE; i++) {
1388 frame_synthesis(s, st, out, X);
/* Stored after synthesis, so the mix stage of this call used the previous frame. */
1389 memcpy(history, in, FRAME_SIZE * sizeof(*history));
/* Argument passed to the threaded worker: carries the input and output frames. */
1394 typedef struct ThreadData {
/*
 * Threaded worker: job jobnr out of nb_jobs processes the contiguous range
 * of channels [start, end), calling rnnoise_channel() on the planar data of
 * each channel.
 */
1398 static int rnnoise_channels(AVFilterContext *ctx, void *arg, int jobnr, int nb_jobs)
1400 AudioRNNContext *s = ctx->priv;
1401 ThreadData *td = arg;
1402 AVFrame *in = td->in;
1403 AVFrame *out = td->out;
/* Even split of the channels across jobs. */
1404 const int start = (out->channels * jobnr) / nb_jobs;
1405 const int end = (out->channels * (jobnr+1)) / nb_jobs;
1407 for (int ch = start; ch < end; ch++) {
1408 rnnoise_channel(s, &s->st[ch],
1409 (float *)out->extended_data[ch],
1410 (const float *)in->extended_data[ch],
/*
 * Process one input frame: allocate an output buffer of FRAME_SIZE samples,
 * run rnnoise_channels() across the channels using at most one job per
 * channel, and pass the result downstream.
 *
 * Returns AVERROR(ENOMEM) if the output buffer cannot be allocated,
 * otherwise the result of ff_filter_frame().
 */
1417 static int filter_frame(AVFilterLink *inlink, AVFrame *in)
1419 AVFilterContext *ctx = inlink->dst;
1420 AVFilterLink *outlink = ctx->outputs[0];
1421 AVFrame *out = NULL;
1424 out = ff_get_audio_buffer(outlink, FRAME_SIZE);
1427 return AVERROR(ENOMEM);
1431 td.in = in; td.out = out;
/* Never start more jobs than there are channels. */
1432 ctx->internal->execute(ctx, rnnoise_channels, &td, NULL, FFMIN(outlink->channels,
1433 ff_filter_get_nb_threads(ctx)));
1436 return ff_filter_frame(outlink, out);
/*
 * Scheduling callback: consume input in chunks of exactly FRAME_SIZE samples
 * (minimum and maximum are both FRAME_SIZE) and filter each chunk; otherwise
 * forward status and frame requests between the two links.
 */
1439 static int activate(AVFilterContext *ctx)
1441 AVFilterLink *inlink = ctx->inputs[0];
1442 AVFilterLink *outlink = ctx->outputs[0];
1446 FF_FILTER_FORWARD_STATUS_BACK(outlink, inlink);
1448 ret = ff_inlink_consume_samples(inlink, FRAME_SIZE, FRAME_SIZE, &in);
1453 return filter_frame(inlink, in);
1455 FF_FILTER_FORWARD_STATUS(inlink, outlink);
1456 FF_FILTER_FORWARD_WANTED(outlink, inlink);
1458 return FFERROR_NOT_READY;
/*
 * Filter initialisation: allocate the float DSP helpers, load the model
 * from the file named by the "model" option, and precompute the window and
 * the cosine table.
 *
 * Returns AVERROR(ENOMEM) or AVERROR(EINVAL) on the failure paths visible
 * here (the conditions guarding them are partly not visible).
 */
1461 static av_cold int init(AVFilterContext *ctx)
1463 AudioRNNContext *s = ctx->priv;
1466 s->fdsp = avpriv_float_dsp_alloc(0);
1468 return AVERROR(ENOMEM);
1471 return AVERROR(EINVAL);
1472 f = av_fopen_utf8(s->model_name, "r");
1474 return AVERROR(EINVAL);
1476 s->model = rnnoise_model_from_file(f);
1479 return AVERROR(EINVAL);
/* Symmetric window: the first half is computed and mirrored into the second half. */
1481 for (int i = 0; i < FRAME_SIZE; i++) {
1482 s->window[i] = sin(.5*M_PI*sin(.5*M_PI*(i+.5)/FRAME_SIZE) * sin(.5*M_PI*(i+.5)/FRAME_SIZE));
1483 s->window[WINDOW_SIZE - 1 - i] = s->window[i];
/* Cosine table; the sqrt(.5) factor is applied to selected entries (condition not visible here). */
1486 for (int i = 0; i < NB_BANDS; i++) {
1487 for (int j = 0; j < NB_BANDS; j++) {
1488 s->dct_table[j][i] = cosf((i + .5f) * j * M_PI / NB_BANDS);
1490 s->dct_table[j][i] *= sqrtf(.5);
/*
 * Release everything owned by the filter: the model and, for every channel,
 * the three GRU state vectors and both FFT contexts.
 */
1497 static av_cold void uninit(AVFilterContext *ctx)
1499 AudioRNNContext *s = ctx->priv;
1502 rnnoise_model_free(s->model);
1506 for (int ch = 0; ch < s->channels; ch++) {
1507 av_freep(&s->st[ch].rnn.vad_gru_state);
1508 av_freep(&s->st[ch].rnn.noise_gru_state);
1509 av_freep(&s->st[ch].rnn.denoise_gru_state);
1510 av_tx_uninit(&s->st[ch].tx);
1511 av_tx_uninit(&s->st[ch].txi);
/* Single audio input pad; config_input() runs when the link is configured. */
1517 static const AVFilterPad inputs[] = {
1520 .type = AVMEDIA_TYPE_AUDIO,
1521 .config_props = config_input,
/* Audio output pad. */
1526 static const AVFilterPad outputs[] = {
1529 .type = AVMEDIA_TYPE_AUDIO,
/* Helpers for the option table: field offset inside AudioRNNContext and the
 * flags shared by all options. */
1534 #define OFFSET(x) offsetof(AudioRNNContext, x)
1535 #define AF AV_OPT_FLAG_AUDIO_PARAM|AV_OPT_FLAG_FILTERING_PARAM
/*
 * User options:
 *  - model / m: path of the model file opened in init().
 *  - mix: range [-1, 1], default 1. The output is
 *    denoised * mix + input * (1 - max(mix, 0)).
 */
1537 static const AVOption arnndn_options[] = {
1538 { "model", "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1539 { "m", "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1540 { "mix", "set output vs input mix", OFFSET(mix), AV_OPT_TYPE_FLOAT, {.dbl=1.0},-1, 1, AF },
1544 AVFILTER_DEFINE_CLASS(arnndn);
/* Filter definition. It is driven through activate(), handles timeline
 * enabling internally and runs its per-channel work in slice threads. */
1546 AVFilter ff_af_arnndn = {
1548 .description = NULL_IF_CONFIG_SMALL("Reduce noise from speech using Recurrent Neural Networks."),
1549 .query_formats = query_formats,
1550 .priv_size = sizeof(AudioRNNContext),
1551 .priv_class = &arnndn_class,
1552 .activate = activate,
1557 .flags = AVFILTER_FLAG_SUPPORT_TIMELINE_INTERNAL |
1558 AVFILTER_FLAG_SLICE_THREADS,