2 * Copyright (c) 2018 Gregor Richards
3 * Copyright (c) 2017 Mozilla
4 * Copyright (c) 2005-2009 Xiph.Org Foundation
5 * Copyright (c) 2007-2008 CSIRO
6 * Copyright (c) 2008-2011 Octasic Inc.
7 * Copyright (c) Jean-Marc Valin
8 * Copyright (c) 2019 Paul B Mahol
10 * Redistribution and use in source and binary forms, with or without
11 * modification, are permitted provided that the following conditions
14 * - Redistributions of source code must retain the above copyright
15 * notice, this list of conditions and the following disclaimer.
17 * - Redistributions in binary form must reproduce the above copyright
18 * notice, this list of conditions and the following disclaimer in the
19 * documentation and/or other materials provided with the distribution.
21 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR
25 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
36 #include "libavutil/avassert.h"
37 #include "libavutil/avstring.h"
38 #include "libavutil/float_dsp.h"
39 #include "libavutil/opt.h"
40 #include "libavutil/tx.h"
/* Analysis geometry. One frame is 120 << 2 = 480 samples (10 ms at the
 * 48 kHz rate this filter negotiates). The FFT window covers two frames,
 * and only WINDOW_SIZE / 2 + 1 bins are kept because the inverse transform
 * rebuilds the upper half from complex conjugates. */
#define FRAME_SIZE_SHIFT 2
#define FRAME_SIZE       (120 << FRAME_SIZE_SHIFT)
#define WINDOW_SIZE      (FRAME_SIZE * 2)
#define FREQ_SIZE        (1 + FRAME_SIZE)
/* Pitch-analysis parameters, in samples. The pitch history buffer holds
 * one pitch frame plus the longest period that can be searched. */
#define PITCH_MIN_PERIOD  60
#define PITCH_MAX_PERIOD  768
#define PITCH_FRAME_SIZE  960
#define PITCH_BUF_SIZE    (PITCH_FRAME_SIZE + PITCH_MAX_PERIOD)
/* Square of an expression. The argument is expanded (and evaluated) twice,
 * so only pass expressions without side effects. */
#define SQUARE(v) ((v) * (v))
/* How many leading cepstral coefficients also get first- and second-order
 * temporal deltas. */
#define NB_DELTA_CEPS 6

/* Length of the per-frame feature vector: NB_BANDS cepstral values, three
 * groups of NB_DELTA_CEPS values (first delta, second delta and one more
 * per-band group), then the pitch-period and spectral-variability values. */
#define NB_FEATURES (NB_BANDS + 3 * NB_DELTA_CEPS + 2)
/* Model weights are read from the file as integers; every weighted sum is
 * multiplied by this factor before the activation is applied. */
#define WEIGHTS_SCALE (1.f / 256)

/* Largest layer size supported (the model loader rejects values above it). */
#define MAX_NEURONS 128

/* Internal identifiers for the activation function of a layer. */
#define ACTIVATION_TANH    0
#define ACTIVATION_SIGMOID 1
#define ACTIVATION_RELU    2
75 typedef struct DenseLayer {
77 const float *input_weights;
83 typedef struct GRULayer {
85 const float *input_weights;
86 const float *recurrent_weights;
92 typedef struct RNNModel {
94 const DenseLayer *input_dense;
97 const GRULayer *vad_gru;
100 const GRULayer *noise_gru;
102 int denoise_gru_size;
103 const GRULayer *denoise_gru;
105 int denoise_output_size;
106 const DenseLayer *denoise_output;
109 const DenseLayer *vad_output;
112 typedef struct RNNState {
113 float *vad_gru_state;
114 float *noise_gru_state;
115 float *denoise_gru_state;
119 typedef struct DenoiseState {
120 float analysis_mem[FRAME_SIZE];
121 float cepstral_mem[CEPS_MEM][NB_BANDS];
123 DECLARE_ALIGNED(32, float, synthesis_mem)[FRAME_SIZE];
124 float pitch_buf[PITCH_BUF_SIZE];
125 float pitch_enh_buf[PITCH_BUF_SIZE];
129 float lastg[NB_BANDS];
130 float history[FRAME_SIZE];
132 AVTXContext *tx, *txi;
133 av_tx_fn tx_fn, txi_fn;
136 typedef struct AudioRNNContext {
137 const AVClass *class;
145 DECLARE_ALIGNED(32, float, window)[WINDOW_SIZE];
146 DECLARE_ALIGNED(32, float, dct_table)[FFALIGN(NB_BANDS, 4)][FFALIGN(NB_BANDS, 4)];
150 AVFloatDSPContext *fdsp;
/* Activation codes exactly as they appear in a model file. The loader's
 * INPUT_ACTIVATION macro translates them into the ACTIVATION_* values. */
#define F_ACTIVATION_TANH    0
#define F_ACTIVATION_SIGMOID 1
#define F_ACTIVATION_RELU    2
157 static void rnnoise_model_free(RNNModel *model)
159 #define FREE_MAYBE(ptr) do { if (ptr) free(ptr); } while (0)
160 #define FREE_DENSE(name) do { \
162 av_free((void *) model->name->input_weights); \
163 av_free((void *) model->name->bias); \
164 av_free((void *) model->name); \
167 #define FREE_GRU(name) do { \
169 av_free((void *) model->name->input_weights); \
170 av_free((void *) model->name->recurrent_weights); \
171 av_free((void *) model->name->bias); \
172 av_free((void *) model->name); \
178 FREE_DENSE(input_dense);
181 FREE_GRU(denoise_gru);
182 FREE_DENSE(denoise_output);
183 FREE_DENSE(vad_output);
187 static RNNModel *rnnoise_model_from_file(FILE *f)
190 DenseLayer *input_dense;
193 GRULayer *denoise_gru;
194 DenseLayer *denoise_output;
195 DenseLayer *vad_output;
198 if (fscanf(f, "rnnoise-nu model file version %d\n", &in) != 1 || in != 1)
201 ret = av_calloc(1, sizeof(RNNModel));
205 #define ALLOC_LAYER(type, name) \
206 name = av_calloc(1, sizeof(type)); \
208 rnnoise_model_free(ret); \
213 ALLOC_LAYER(DenseLayer, input_dense);
214 ALLOC_LAYER(GRULayer, vad_gru);
215 ALLOC_LAYER(GRULayer, noise_gru);
216 ALLOC_LAYER(GRULayer, denoise_gru);
217 ALLOC_LAYER(DenseLayer, denoise_output);
218 ALLOC_LAYER(DenseLayer, vad_output);
220 #define INPUT_VAL(name) do { \
221 if (fscanf(f, "%d", &in) != 1 || in < 0 || in > 128) { \
222 rnnoise_model_free(ret); \
228 #define INPUT_ACTIVATION(name) do { \
230 INPUT_VAL(activation); \
231 switch (activation) { \
232 case F_ACTIVATION_SIGMOID: \
233 name = ACTIVATION_SIGMOID; \
235 case F_ACTIVATION_RELU: \
236 name = ACTIVATION_RELU; \
239 name = ACTIVATION_TANH; \
243 #define INPUT_ARRAY(name, len) do { \
244 float *values = av_calloc((len), sizeof(float)); \
246 rnnoise_model_free(ret); \
250 for (int i = 0; i < (len); i++) { \
251 if (fscanf(f, "%d", &in) != 1) { \
252 rnnoise_model_free(ret); \
259 #define INPUT_ARRAY3(name, len0, len1, len2) do { \
260 float *values = av_calloc(FFALIGN((len0), 4) * FFALIGN((len1), 4) * (len2), sizeof(float)); \
262 rnnoise_model_free(ret); \
266 for (int k = 0; k < (len0); k++) { \
267 for (int i = 0; i < (len2); i++) { \
268 for (int j = 0; j < (len1); j++) { \
269 if (fscanf(f, "%d", &in) != 1) { \
270 rnnoise_model_free(ret); \
273 values[j * (len2) * FFALIGN((len0), 4) + i * FFALIGN((len0), 4) + k] = in; \
279 #define INPUT_DENSE(name) do { \
280 INPUT_VAL(name->nb_inputs); \
281 INPUT_VAL(name->nb_neurons); \
282 ret->name ## _size = name->nb_neurons; \
283 INPUT_ACTIVATION(name->activation); \
284 INPUT_ARRAY(name->input_weights, name->nb_inputs * name->nb_neurons); \
285 INPUT_ARRAY(name->bias, name->nb_neurons); \
288 #define INPUT_GRU(name) do { \
289 INPUT_VAL(name->nb_inputs); \
290 INPUT_VAL(name->nb_neurons); \
291 ret->name ## _size = name->nb_neurons; \
292 INPUT_ACTIVATION(name->activation); \
293 INPUT_ARRAY3(name->input_weights, name->nb_inputs, name->nb_neurons, 3); \
294 INPUT_ARRAY3(name->recurrent_weights, name->nb_neurons, name->nb_neurons, 3); \
295 INPUT_ARRAY(name->bias, name->nb_neurons * 3); \
298 INPUT_DENSE(input_dense);
300 INPUT_GRU(noise_gru);
301 INPUT_GRU(denoise_gru);
302 INPUT_DENSE(denoise_output);
303 INPUT_DENSE(vad_output);
305 if (vad_output->nb_neurons != 1) {
306 rnnoise_model_free(ret);
313 static int query_formats(AVFilterContext *ctx)
315 AVFilterFormats *formats = NULL;
316 AVFilterChannelLayouts *layouts = NULL;
317 static const enum AVSampleFormat sample_fmts[] = {
321 int ret, sample_rates[] = { 48000, -1 };
323 formats = ff_make_format_list(sample_fmts);
325 return AVERROR(ENOMEM);
326 ret = ff_set_common_formats(ctx, formats);
330 layouts = ff_all_channel_counts();
332 return AVERROR(ENOMEM);
334 ret = ff_set_common_channel_layouts(ctx, layouts);
338 formats = ff_make_format_list(sample_rates);
340 return AVERROR(ENOMEM);
341 return ff_set_common_samplerates(ctx, formats);
344 static int config_input(AVFilterLink *inlink)
346 AVFilterContext *ctx = inlink->dst;
347 AudioRNNContext *s = ctx->priv;
350 s->channels = inlink->channels;
352 s->st = av_calloc(s->channels, sizeof(DenoiseState));
354 return AVERROR(ENOMEM);
356 for (int i = 0; i < s->channels; i++) {
357 DenoiseState *st = &s->st[i];
359 st->rnn.model = s->model;
360 st->rnn.vad_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->vad_gru_size, 16));
361 st->rnn.noise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->noise_gru_size, 16));
362 st->rnn.denoise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->denoise_gru_size, 16));
363 if (!st->rnn.vad_gru_state ||
364 !st->rnn.noise_gru_state ||
365 !st->rnn.denoise_gru_state)
366 return AVERROR(ENOMEM);
368 ret = av_tx_init(&st->tx, &st->tx_fn, AV_TX_FLOAT_FFT, 0, WINDOW_SIZE, NULL, 0);
372 ret = av_tx_init(&st->txi, &st->txi_fn, AV_TX_FLOAT_FFT, 1, WINDOW_SIZE, NULL, 0);
380 static void biquad(float *y, float mem[2], const float *x,
381 const float *b, const float *a, int N)
383 for (int i = 0; i < N; i++) {
388 mem[0] = mem[1] + (b[0]*xi - a[0]*yi);
389 mem[1] = (b[1]*xi - a[1]*yi);
/* Element-count wrappers around memmove/memset/memcpy: n is a number of
 * elements of *dst, not a byte count.
 *
 * The "0 * sizeof((dst) - (src))" term contributes nothing to the size. It
 * makes the compiler reject dst/src pointers of incompatible types. The
 * subtraction sits inside sizeof so it is never evaluated: actually
 * subtracting pointers into two different objects (e.g. a local buffer and
 * a struct member) is undefined behaviour. */
#define RNN_MOVE(dst, src, n) \
    (memmove((dst), (src), (n) * sizeof(*(dst)) + 0 * sizeof((dst) - (src))))
#define RNN_CLEAR(dst, n) \
    (memset((dst), 0, (n) * sizeof(*(dst))))
#define RNN_COPY(dst, src, n) \
    (memcpy((dst), (src), (n) * sizeof(*(dst)) + 0 * sizeof((dst) - (src))))
398 static void forward_transform(DenoiseState *st, AVComplexFloat *out, const float *in)
400 AVComplexFloat x[WINDOW_SIZE];
401 AVComplexFloat y[WINDOW_SIZE];
403 for (int i = 0; i < WINDOW_SIZE; i++) {
408 st->tx_fn(st->tx, y, x, sizeof(float));
410 RNN_COPY(out, y, FREQ_SIZE);
413 static void inverse_transform(DenoiseState *st, float *out, const AVComplexFloat *in)
415 AVComplexFloat x[WINDOW_SIZE];
416 AVComplexFloat y[WINDOW_SIZE];
418 RNN_COPY(x, in, FREQ_SIZE);
420 for (int i = FREQ_SIZE; i < WINDOW_SIZE; i++) {
421 x[i].re = x[WINDOW_SIZE - i].re;
422 x[i].im = -x[WINDOW_SIZE - i].im;
425 st->txi_fn(st->txi, y, x, sizeof(float));
427 for (int i = 0; i < WINDOW_SIZE; i++)
428 out[i] = y[i].re / WINDOW_SIZE;
431 static const uint8_t eband5ms[] = {
432 /*0 200 400 600 800 1k 1.2 1.4 1.6 2k 2.4 2.8 3.2 4k 4.8 5.6 6.8 8k 9.6 12k 15.6 20k*/
433 0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 34, 40, 48, 60, 78, 100
436 static void compute_band_energy(float *bandE, const AVComplexFloat *X)
438 float sum[NB_BANDS] = {0};
440 for (int i = 0; i < NB_BANDS - 1; i++) {
443 band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
444 for (int j = 0; j < band_size; j++) {
445 float tmp, frac = (float)j / band_size;
447 tmp = SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].re);
448 tmp += SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].im);
449 sum[i] += (1.f - frac) * tmp;
450 sum[i + 1] += frac * tmp;
455 sum[NB_BANDS - 1] *= 2;
457 for (int i = 0; i < NB_BANDS; i++)
461 static void compute_band_corr(float *bandE, const AVComplexFloat *X, const AVComplexFloat *P)
463 float sum[NB_BANDS] = { 0 };
465 for (int i = 0; i < NB_BANDS - 1; i++) {
468 band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
469 for (int j = 0; j < band_size; j++) {
470 float tmp, frac = (float)j / band_size;
472 tmp = X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re;
473 tmp += X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im;
474 sum[i] += (1 - frac) * tmp;
475 sum[i + 1] += frac * tmp;
480 sum[NB_BANDS-1] *= 2;
482 for (int i = 0; i < NB_BANDS; i++)
486 static void frame_analysis(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, float *Ex, const float *in)
488 LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
490 RNN_COPY(x, st->analysis_mem, FRAME_SIZE);
491 RNN_COPY(x + FRAME_SIZE, in, FRAME_SIZE);
492 RNN_COPY(st->analysis_mem, in, FRAME_SIZE);
493 s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
494 forward_transform(st, X, x);
495 compute_band_energy(Ex, X);
498 static void frame_synthesis(AudioRNNContext *s, DenoiseState *st, float *out, const AVComplexFloat *y)
500 LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
501 const float *src = st->history;
502 const float mix = s->mix;
503 const float imix = 1.f - FFMAX(mix, 0.f);
505 inverse_transform(st, x, y);
506 s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
507 s->fdsp->vector_fmac_scalar(x, st->synthesis_mem, 1.f, FRAME_SIZE);
508 RNN_COPY(out, x, FRAME_SIZE);
509 RNN_COPY(st->synthesis_mem, &x[FRAME_SIZE], FRAME_SIZE);
511 for (int n = 0; n < FRAME_SIZE; n++)
512 out[n] = out[n] * mix + src[n] * imix;
515 static inline void xcorr_kernel(const float *x, const float *y, float sum[4], int len)
517 float y_0, y_1, y_2, y_3 = 0;
524 for (j = 0; j < len - 3; j += 4) {
584 static inline float celt_inner_prod(const float *x,
585 const float *y, int N)
589 for (int i = 0; i < N; i++)
595 static void celt_pitch_xcorr(const float *x, const float *y,
596 float *xcorr, int len, int max_pitch)
600 for (i = 0; i < max_pitch - 3; i += 4) {
601 float sum[4] = { 0, 0, 0, 0};
603 xcorr_kernel(x, y + i, sum, len);
606 xcorr[i + 1] = sum[1];
607 xcorr[i + 2] = sum[2];
608 xcorr[i + 3] = sum[3];
610 /* In case max_pitch isn't a multiple of 4, do non-unrolled version. */
611 for (; i < max_pitch; i++) {
612 xcorr[i] = celt_inner_prod(x, y + i, len);
616 static int celt_autocorr(const float *x, /* in: [0...n-1] samples x */
617 float *ac, /* out: [0...lag-1] ac values */
626 float xx[PITCH_BUF_SIZE>>1];
631 for (int i = 0; i < n; i++)
633 for (int i = 0; i < overlap; i++) {
634 xx[i] = x[i] * window[i];
635 xx[n-i-1] = x[n-i-1] * window[i];
641 celt_pitch_xcorr(xptr, xptr, ac, fastN, lag+1);
643 for (int k = 0; k <= lag; k++) {
646 for (int i = k + fastN; i < n; i++)
647 d += xptr[i] * xptr[i-k];
654 static void celt_lpc(float *lpc, /* out: [0...p-1] LPC coefficients */
655 const float *ac, /* in: [0...p] autocorrelation values */
658 float r, error = ac[0];
662 for (int i = 0; i < p; i++) {
663 /* Sum up this iteration's reflection coefficient */
665 for (int j = 0; j < i; j++)
666 rr += (lpc[j] * ac[i - j]);
669 /* Update LPC coefficients and total error */
671 for (int j = 0; j < (i + 1) >> 1; j++) {
675 lpc[j] = tmp1 + (r*tmp2);
676 lpc[i-1-j] = tmp2 + (r*tmp1);
679 error = error - (r * r *error);
680 /* Bail out once we get 30 dB gain */
681 if (error < .001f * ac[0])
687 static void celt_fir5(const float *x,
693 float num0, num1, num2, num3, num4;
694 float mem0, mem1, mem2, mem3, mem4;
707 for (int i = 0; i < N; i++) {
730 static void pitch_downsample(float *x[], float *x_lp,
735 float lpc[4], mem[5]={0,0,0,0,0};
739 for (int i = 1; i < len >> 1; i++)
740 x_lp[i] = .5f * (.5f * (x[0][(2*i-1)]+x[0][(2*i+1)])+x[0][2*i]);
741 x_lp[0] = .5f * (.5f * (x[0][1])+x[0][0]);
743 for (int i = 1; i < len >> 1; i++)
744 x_lp[i] += (.5f * (.5f * (x[1][(2*i-1)]+x[1][(2*i+1)])+x[1][2*i]));
745 x_lp[0] += .5f * (.5f * (x[1][1])+x[1][0]);
748 celt_autocorr(x_lp, ac, NULL, 0, 4, len>>1);
750 /* Noise floor -40 dB */
753 for (int i = 1; i <= 4; i++) {
754 /*ac[i] *= exp(-.5*(2*M_PI*.002*i)*(2*M_PI*.002*i));*/
755 ac[i] -= ac[i]*(.008f*i)*(.008f*i);
758 celt_lpc(lpc, ac, 4);
759 for (int i = 0; i < 4; i++) {
761 lpc[i] = (lpc[i] * tmp);
764 lpc2[0] = lpc[0] + .8f;
765 lpc2[1] = lpc[1] + (c1 * lpc[0]);
766 lpc2[2] = lpc[2] + (c1 * lpc[1]);
767 lpc2[3] = lpc[3] + (c1 * lpc[2]);
768 lpc2[4] = (c1 * lpc[3]);
769 celt_fir5(x_lp, lpc2, x_lp, len>>1, mem);
772 static inline void dual_inner_prod(const float *x, const float *y01, const float *y02,
773 int N, float *xy1, float *xy2)
775 float xy01 = 0, xy02 = 0;
777 for (int i = 0; i < N; i++) {
778 xy01 += (x[i] * y01[i]);
779 xy02 += (x[i] * y02[i]);
786 static float compute_pitch_gain(float xy, float xx, float yy)
788 return xy / sqrtf(1.f + xx * yy);
/* Used by remove_doubling(): for a candidate period T0/k (k = 2..15), the
 * entry gives the multiple m of T0/k (T1b = round(m * T0 / k)) at which a
 * second, confirming correlation is measured. Entries 0 and 1 are unused. */
static const uint8_t second_check[16] = {
    0, 0, 3, 2, 3, 2, 5, 2,
    3, 2, 3, 2, 5, 2, 3, 2,
};
792 static float remove_doubling(float *x, int maxperiod, int minperiod, int N,
793 int *T0_, int prev_period, float prev_gain)
800 float best_xy, best_yy;
803 float yy_lookup[PITCH_MAX_PERIOD+1];
805 minperiod0 = minperiod;
816 dual_inner_prod(x, x, x-T0, N, &xx, &xy);
819 for (i = 1; i <= maxperiod; i++) {
820 yy = yy+(x[-i] * x[-i])-(x[N-i] * x[N-i]);
821 yy_lookup[i] = FFMAX(0, yy);
826 g = g0 = compute_pitch_gain(xy, xx, yy);
827 /* Look for any pitch at T/k */
828 for (k = 2; k <= 15; k++) {
836 /* Look for another strong correlation at T1b */
845 T1b = (2*second_check[k]*T0+k)/(2*k);
847 dual_inner_prod(x, &x[-T1], &x[-T1b], N, &xy, &xy2);
848 xy = .5f * (xy + xy2);
849 yy = .5f * (yy_lookup[T1] + yy_lookup[T1b]);
850 g1 = compute_pitch_gain(xy, xx, yy);
851 if (FFABS(T1-prev_period)<=1)
853 else if (FFABS(T1-prev_period)<=2 && 5 * k * k < T0)
854 cont = prev_gain * .5f;
857 thresh = FFMAX(.3f, (.7f * g0) - cont);
858 /* Bias against very high pitch (very short period) to avoid false-positives
859 due to short-term correlation */
861 thresh = FFMAX(.4f, (.85f * g0) - cont);
862 else if (T1<2*minperiod)
863 thresh = FFMAX(.5f, (.9f * g0) - cont);
872 best_xy = FFMAX(0, best_xy);
873 if (best_yy <= best_xy)
876 pg = best_xy/(best_yy + 1);
878 for (k = 0; k < 3; k++)
879 xcorr[k] = celt_inner_prod(x, x-(T+k-1), N);
880 if ((xcorr[2]-xcorr[0]) > .7f * (xcorr[1]-xcorr[0]))
882 else if ((xcorr[0]-xcorr[2]) > (.7f * (xcorr[1] - xcorr[2])))
895 static void find_best_pitch(float *xcorr, float *y, int len,
896 int max_pitch, int *best_pitch)
909 for (int j = 0; j < len; j++)
912 for (int i = 0; i < max_pitch; i++) {
918 /* Considering the range of xcorr16, this should avoid both underflows
919 and overflows (inf) when squaring xcorr16 */
921 num = xcorr16 * xcorr16;
922 if ((num * best_den[1]) > (best_num[1] * Syy)) {
923 if ((num * best_den[0]) > (best_num[0] * Syy)) {
924 best_num[1] = best_num[0];
925 best_den[1] = best_den[0];
926 best_pitch[1] = best_pitch[0];
937 Syy += y[i+len]*y[i+len] - y[i] * y[i];
942 static void pitch_search(const float *x_lp, float *y,
943 int len, int max_pitch, int *pitch)
946 int best_pitch[2]={0,0};
949 float x_lp4[WINDOW_SIZE];
950 float y_lp4[WINDOW_SIZE];
951 float xcorr[WINDOW_SIZE];
955 /* Downsample by 2 again */
956 for (int j = 0; j < len >> 2; j++)
957 x_lp4[j] = x_lp[2*j];
958 for (int j = 0; j < lag >> 2; j++)
961 /* Coarse search with 4x decimation */
963 celt_pitch_xcorr(x_lp4, y_lp4, xcorr, len>>2, max_pitch>>2);
965 find_best_pitch(xcorr, y_lp4, len>>2, max_pitch>>2, best_pitch);
967 /* Finer search with 2x decimation */
968 for (int i = 0; i < max_pitch >> 1; i++) {
971 if (FFABS(i-2*best_pitch[0])>2 && FFABS(i-2*best_pitch[1])>2)
973 sum = celt_inner_prod(x_lp, y+i, len>>1);
974 xcorr[i] = FFMAX(-1, sum);
977 find_best_pitch(xcorr, y, len>>1, max_pitch>>1, best_pitch);
979 /* Refine by pseudo-interpolation */
980 if (best_pitch[0] > 0 && best_pitch[0] < (max_pitch >> 1) - 1) {
983 a = xcorr[best_pitch[0] - 1];
984 b = xcorr[best_pitch[0]];
985 c = xcorr[best_pitch[0] + 1];
986 if (c - a > .7f * (b - a))
988 else if (a - c > .7f * (b-c))
996 *pitch = 2 * best_pitch[0] - offset;
999 static void dct(AudioRNNContext *s, float *out, const float *in)
1001 for (int i = 0; i < NB_BANDS; i++) {
1004 sum = s->fdsp->scalarproduct_float(in, s->dct_table[i], FFALIGN(NB_BANDS, 4));
1005 out[i] = sum * sqrtf(2.f / 22);
1009 static int compute_frame_features(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, AVComplexFloat *P,
1010 float *Ex, float *Ep, float *Exp, float *features, const float *in)
1013 float *ceps_0, *ceps_1, *ceps_2;
1014 float spec_variability = 0;
1015 LOCAL_ALIGNED_32(float, Ly, [NB_BANDS]);
1016 LOCAL_ALIGNED_32(float, p, [WINDOW_SIZE]);
1017 float pitch_buf[PITCH_BUF_SIZE>>1];
1021 float tmp[NB_BANDS];
1022 float follow, logMax;
1024 frame_analysis(s, st, X, Ex, in);
1025 RNN_MOVE(st->pitch_buf, &st->pitch_buf[FRAME_SIZE], PITCH_BUF_SIZE-FRAME_SIZE);
1026 RNN_COPY(&st->pitch_buf[PITCH_BUF_SIZE-FRAME_SIZE], in, FRAME_SIZE);
1027 pre[0] = &st->pitch_buf[0];
1028 pitch_downsample(pre, pitch_buf, PITCH_BUF_SIZE, 1);
1029 pitch_search(pitch_buf+(PITCH_MAX_PERIOD>>1), pitch_buf, PITCH_FRAME_SIZE,
1030 PITCH_MAX_PERIOD-3*PITCH_MIN_PERIOD, &pitch_index);
1031 pitch_index = PITCH_MAX_PERIOD-pitch_index;
1033 gain = remove_doubling(pitch_buf, PITCH_MAX_PERIOD, PITCH_MIN_PERIOD,
1034 PITCH_FRAME_SIZE, &pitch_index, st->last_period, st->last_gain);
1035 st->last_period = pitch_index;
1036 st->last_gain = gain;
1038 for (int i = 0; i < WINDOW_SIZE; i++)
1039 p[i] = st->pitch_buf[PITCH_BUF_SIZE-WINDOW_SIZE-pitch_index+i];
1041 s->fdsp->vector_fmul(p, p, s->window, WINDOW_SIZE);
1042 forward_transform(st, P, p);
1043 compute_band_energy(Ep, P);
1044 compute_band_corr(Exp, X, P);
1046 for (int i = 0; i < NB_BANDS; i++)
1047 Exp[i] = Exp[i] / sqrtf(.001f+Ex[i]*Ep[i]);
1051 for (int i = 0; i < NB_DELTA_CEPS; i++)
1052 features[NB_BANDS+2*NB_DELTA_CEPS+i] = tmp[i];
1054 features[NB_BANDS+2*NB_DELTA_CEPS] -= 1.3;
1055 features[NB_BANDS+2*NB_DELTA_CEPS+1] -= 0.9;
1056 features[NB_BANDS+3*NB_DELTA_CEPS] = .01*(pitch_index-300);
1060 for (int i = 0; i < NB_BANDS; i++) {
1061 Ly[i] = log10f(1e-2f + Ex[i]);
1062 Ly[i] = FFMAX(logMax-7, FFMAX(follow-1.5, Ly[i]));
1063 logMax = FFMAX(logMax, Ly[i]);
1064 follow = FFMAX(follow-1.5, Ly[i]);
1069 /* If there's no audio, avoid messing up the state. */
1070 RNN_CLEAR(features, NB_FEATURES);
1074 dct(s, features, Ly);
1077 ceps_0 = st->cepstral_mem[st->memid];
1078 ceps_1 = (st->memid < 1) ? st->cepstral_mem[CEPS_MEM+st->memid-1] : st->cepstral_mem[st->memid-1];
1079 ceps_2 = (st->memid < 2) ? st->cepstral_mem[CEPS_MEM+st->memid-2] : st->cepstral_mem[st->memid-2];
1081 for (int i = 0; i < NB_BANDS; i++)
1082 ceps_0[i] = features[i];
1085 for (int i = 0; i < NB_DELTA_CEPS; i++) {
1086 features[i] = ceps_0[i] + ceps_1[i] + ceps_2[i];
1087 features[NB_BANDS+i] = ceps_0[i] - ceps_2[i];
1088 features[NB_BANDS+NB_DELTA_CEPS+i] = ceps_0[i] - 2*ceps_1[i] + ceps_2[i];
1090 /* Spectral variability features. */
1091 if (st->memid == CEPS_MEM)
1094 for (int i = 0; i < CEPS_MEM; i++) {
1095 float mindist = 1e15f;
1096 for (int j = 0; j < CEPS_MEM; j++) {
1098 for (int k = 0; k < NB_BANDS; k++) {
1101 tmp = st->cepstral_mem[i][k] - st->cepstral_mem[j][k];
1106 mindist = FFMIN(mindist, dist);
1109 spec_variability += mindist;
1112 features[NB_BANDS+3*NB_DELTA_CEPS+1] = spec_variability/CEPS_MEM-2.1;
1117 static void interp_band_gain(float *g, const float *bandE)
1119 memset(g, 0, sizeof(*g) * FREQ_SIZE);
1121 for (int i = 0; i < NB_BANDS - 1; i++) {
1122 const int band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
1124 for (int j = 0; j < band_size; j++) {
1125 float frac = (float)j / band_size;
1127 g[(eband5ms[i] << FRAME_SIZE_SHIFT) + j] = (1.f - frac) * bandE[i] + frac * bandE[i + 1];
1132 static void pitch_filter(AVComplexFloat *X, const AVComplexFloat *P, const float *Ex, const float *Ep,
1133 const float *Exp, const float *g)
1135 float newE[NB_BANDS];
1137 float norm[NB_BANDS];
1138 float rf[FREQ_SIZE] = {0};
1139 float normf[FREQ_SIZE]={0};
1141 for (int i = 0; i < NB_BANDS; i++) {
1142 if (Exp[i]>g[i]) r[i] = 1;
1143 else r[i] = SQUARE(Exp[i])*(1-SQUARE(g[i]))/(.001 + SQUARE(g[i])*(1-SQUARE(Exp[i])));
1144 r[i] = sqrtf(av_clipf(r[i], 0, 1));
1145 r[i] *= sqrtf(Ex[i]/(1e-8+Ep[i]));
1147 interp_band_gain(rf, r);
1148 for (int i = 0; i < FREQ_SIZE; i++) {
1149 X[i].re += rf[i]*P[i].re;
1150 X[i].im += rf[i]*P[i].im;
1152 compute_band_energy(newE, X);
1153 for (int i = 0; i < NB_BANDS; i++) {
1154 norm[i] = sqrtf(Ex[i] / (1e-8+newE[i]));
1156 interp_band_gain(normf, norm);
1157 for (int i = 0; i < FREQ_SIZE; i++) {
1158 X[i].re *= normf[i];
1159 X[i].im *= normf[i];
1163 static const float tansig_table[201] = {
1164 0.000000f, 0.039979f, 0.079830f, 0.119427f, 0.158649f,
1165 0.197375f, 0.235496f, 0.272905f, 0.309507f, 0.345214f,
1166 0.379949f, 0.413644f, 0.446244f, 0.477700f, 0.507977f,
1167 0.537050f, 0.564900f, 0.591519f, 0.616909f, 0.641077f,
1168 0.664037f, 0.685809f, 0.706419f, 0.725897f, 0.744277f,
1169 0.761594f, 0.777888f, 0.793199f, 0.807569f, 0.821040f,
1170 0.833655f, 0.845456f, 0.856485f, 0.866784f, 0.876393f,
1171 0.885352f, 0.893698f, 0.901468f, 0.908698f, 0.915420f,
1172 0.921669f, 0.927473f, 0.932862f, 0.937863f, 0.942503f,
1173 0.946806f, 0.950795f, 0.954492f, 0.957917f, 0.961090f,
1174 0.964028f, 0.966747f, 0.969265f, 0.971594f, 0.973749f,
1175 0.975743f, 0.977587f, 0.979293f, 0.980869f, 0.982327f,
1176 0.983675f, 0.984921f, 0.986072f, 0.987136f, 0.988119f,
1177 0.989027f, 0.989867f, 0.990642f, 0.991359f, 0.992020f,
1178 0.992631f, 0.993196f, 0.993718f, 0.994199f, 0.994644f,
1179 0.995055f, 0.995434f, 0.995784f, 0.996108f, 0.996407f,
1180 0.996682f, 0.996937f, 0.997172f, 0.997389f, 0.997590f,
1181 0.997775f, 0.997946f, 0.998104f, 0.998249f, 0.998384f,
1182 0.998508f, 0.998623f, 0.998728f, 0.998826f, 0.998916f,
1183 0.999000f, 0.999076f, 0.999147f, 0.999213f, 0.999273f,
1184 0.999329f, 0.999381f, 0.999428f, 0.999472f, 0.999513f,
1185 0.999550f, 0.999585f, 0.999617f, 0.999646f, 0.999673f,
1186 0.999699f, 0.999722f, 0.999743f, 0.999763f, 0.999781f,
1187 0.999798f, 0.999813f, 0.999828f, 0.999841f, 0.999853f,
1188 0.999865f, 0.999875f, 0.999885f, 0.999893f, 0.999902f,
1189 0.999909f, 0.999916f, 0.999923f, 0.999929f, 0.999934f,
1190 0.999939f, 0.999944f, 0.999948f, 0.999952f, 0.999956f,
1191 0.999959f, 0.999962f, 0.999965f, 0.999968f, 0.999970f,
1192 0.999973f, 0.999975f, 0.999977f, 0.999978f, 0.999980f,
1193 0.999982f, 0.999983f, 0.999984f, 0.999986f, 0.999987f,
1194 0.999988f, 0.999989f, 0.999990f, 0.999990f, 0.999991f,
1195 0.999992f, 0.999992f, 0.999993f, 0.999994f, 0.999994f,
1196 0.999994f, 0.999995f, 0.999995f, 0.999996f, 0.999996f,
1197 0.999996f, 0.999997f, 0.999997f, 0.999997f, 0.999997f,
1198 0.999997f, 0.999998f, 0.999998f, 0.999998f, 0.999998f,
1199 0.999998f, 0.999998f, 0.999999f, 0.999999f, 0.999999f,
1200 0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
1201 0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
1202 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1203 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1207 static inline float tansig_approx(float x)
1213 /* Tests are reversed to catch NaNs */
1218 /* Another check in case of -ffast-math */
1227 i = (int)floor(.5f+25*x);
1229 y = tansig_table[i];
1231 y = y + x*dy*(1 - y*x);
1235 static inline float sigmoid_approx(float x)
1237 return .5f + .5f*tansig_approx(.5f*x);
1240 static void compute_dense(const DenseLayer *layer, float *output, const float *input)
1242 const int N = layer->nb_neurons, M = layer->nb_inputs, stride = N;
1244 for (int i = 0; i < N; i++) {
1245 /* Compute update gate. */
1246 float sum = layer->bias[i];
1248 for (int j = 0; j < M; j++)
1249 sum += layer->input_weights[j * stride + i] * input[j];
1251 output[i] = WEIGHTS_SCALE * sum;
1254 if (layer->activation == ACTIVATION_SIGMOID) {
1255 for (int i = 0; i < N; i++)
1256 output[i] = sigmoid_approx(output[i]);
1257 } else if (layer->activation == ACTIVATION_TANH) {
1258 for (int i = 0; i < N; i++)
1259 output[i] = tansig_approx(output[i]);
1260 } else if (layer->activation == ACTIVATION_RELU) {
1261 for (int i = 0; i < N; i++)
1262 output[i] = FFMAX(0, output[i]);
1268 static void compute_gru(AudioRNNContext *s, const GRULayer *gru, float *state, const float *input)
1270 LOCAL_ALIGNED_32(float, z, [MAX_NEURONS]);
1271 LOCAL_ALIGNED_32(float, r, [MAX_NEURONS]);
1272 LOCAL_ALIGNED_32(float, h, [MAX_NEURONS]);
1273 const int M = gru->nb_inputs;
1274 const int N = gru->nb_neurons;
1275 const int AN = FFALIGN(N, 4);
1276 const int AM = FFALIGN(M, 4);
1277 const int stride = 3 * AN, istride = 3 * AM;
1279 for (int i = 0; i < N; i++) {
1280 /* Compute update gate. */
1281 float sum = gru->bias[i];
1283 sum += s->fdsp->scalarproduct_float(gru->input_weights + i * istride, input, AM);
1284 sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + i * stride, state, AN);
1285 z[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1288 for (int i = 0; i < N; i++) {
1289 /* Compute reset gate. */
1290 float sum = gru->bias[N + i];
1292 sum += s->fdsp->scalarproduct_float(gru->input_weights + AM + i * istride, input, AM);
1293 sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + AN + i * stride, state, AN);
1294 r[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1297 for (int i = 0; i < N; i++) {
1298 /* Compute output. */
1299 float sum = gru->bias[2 * N + i];
1301 sum += s->fdsp->scalarproduct_float(gru->input_weights + 2 * AM + i * istride, input, AM);
1302 for (int j = 0; j < N; j++)
1303 sum += gru->recurrent_weights[2 * AN + i * stride + j] * state[j] * r[j];
1305 if (gru->activation == ACTIVATION_SIGMOID)
1306 sum = sigmoid_approx(WEIGHTS_SCALE * sum);
1307 else if (gru->activation == ACTIVATION_TANH)
1308 sum = tansig_approx(WEIGHTS_SCALE * sum);
1309 else if (gru->activation == ACTIVATION_RELU)
1310 sum = FFMAX(0, WEIGHTS_SCALE * sum);
1313 h[i] = z[i] * state[i] + (1.f - z[i]) * sum;
1316 RNN_COPY(state, h, N);
/* Number of floats from the per-frame feature vector that compute_rnn()
 * appends to the noise and denoise GRU inputs. It must equal the length of
 * the features[] array passed in (NB_FEATURES), otherwise the memcpy reads
 * past it. Derive it instead of hard-coding 42. */
#define INPUT_SIZE NB_FEATURES
1321 static void compute_rnn(AudioRNNContext *s, RNNState *rnn, float *gains, float *vad, const float *input)
1323 LOCAL_ALIGNED_32(float, dense_out, [MAX_NEURONS]);
1324 LOCAL_ALIGNED_32(float, noise_input, [MAX_NEURONS * 3]);
1325 LOCAL_ALIGNED_32(float, denoise_input, [MAX_NEURONS * 3]);
1327 compute_dense(rnn->model->input_dense, dense_out, input);
1328 compute_gru(s, rnn->model->vad_gru, rnn->vad_gru_state, dense_out);
1329 compute_dense(rnn->model->vad_output, vad, rnn->vad_gru_state);
1331 memcpy(noise_input, dense_out, rnn->model->input_dense_size * sizeof(float));
1332 memcpy(noise_input + rnn->model->input_dense_size,
1333 rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
1334 memcpy(noise_input + rnn->model->input_dense_size + rnn->model->vad_gru_size,
1335 input, INPUT_SIZE * sizeof(float));
1337 compute_gru(s, rnn->model->noise_gru, rnn->noise_gru_state, noise_input);
1339 memcpy(denoise_input, rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
1340 memcpy(denoise_input + rnn->model->vad_gru_size,
1341 rnn->noise_gru_state, rnn->model->noise_gru_size * sizeof(float));
1342 memcpy(denoise_input + rnn->model->vad_gru_size + rnn->model->noise_gru_size,
1343 input, INPUT_SIZE * sizeof(float));
1345 compute_gru(s, rnn->model->denoise_gru, rnn->denoise_gru_state, denoise_input);
1346 compute_dense(rnn->model->denoise_output, gains, rnn->denoise_gru_state);
1349 static float rnnoise_channel(AudioRNNContext *s, DenoiseState *st, float *out, const float *in,
1352 AVComplexFloat X[FREQ_SIZE];
1353 AVComplexFloat P[WINDOW_SIZE];
1354 float x[FRAME_SIZE];
1355 float Ex[NB_BANDS], Ep[NB_BANDS];
1356 LOCAL_ALIGNED_32(float, Exp, [NB_BANDS]);
1357 float features[NB_FEATURES];
1359 float gf[FREQ_SIZE];
1361 float *history = st->history;
1362 static const float a_hp[2] = {-1.99599, 0.99600};
1363 static const float b_hp[2] = {-2, 1};
1366 biquad(x, st->mem_hp_x, in, b_hp, a_hp, FRAME_SIZE);
1367 silence = compute_frame_features(s, st, X, P, Ex, Ep, Exp, features, x);
1369 if (!silence && !disabled) {
1370 compute_rnn(s, &st->rnn, g, &vad_prob, features);
1371 pitch_filter(X, P, Ex, Ep, Exp, g);
1372 for (int i = 0; i < NB_BANDS; i++) {
1375 g[i] = FFMAX(g[i], alpha * st->lastg[i]);
1376 st->lastg[i] = g[i];
1379 interp_band_gain(gf, g);
1381 for (int i = 0; i < FREQ_SIZE; i++) {
1387 frame_synthesis(s, st, out, X);
1388 memcpy(history, in, FRAME_SIZE * sizeof(*history));
1393 typedef struct ThreadData {
1397 static int rnnoise_channels(AVFilterContext *ctx, void *arg, int jobnr, int nb_jobs)
1399 AudioRNNContext *s = ctx->priv;
1400 ThreadData *td = arg;
1401 AVFrame *in = td->in;
1402 AVFrame *out = td->out;
1403 const int start = (out->channels * jobnr) / nb_jobs;
1404 const int end = (out->channels * (jobnr+1)) / nb_jobs;
1406 for (int ch = start; ch < end; ch++) {
1407 rnnoise_channel(s, &s->st[ch],
1408 (float *)out->extended_data[ch],
1409 (const float *)in->extended_data[ch],
1416 static int filter_frame(AVFilterLink *inlink, AVFrame *in)
1418 AVFilterContext *ctx = inlink->dst;
1419 AVFilterLink *outlink = ctx->outputs[0];
1420 AVFrame *out = NULL;
1423 out = ff_get_audio_buffer(outlink, FRAME_SIZE);
1426 return AVERROR(ENOMEM);
1430 td.in = in; td.out = out;
1431 ctx->internal->execute(ctx, rnnoise_channels, &td, NULL, FFMIN(outlink->channels,
1432 ff_filter_get_nb_threads(ctx)));
1435 return ff_filter_frame(outlink, out);
1438 static int activate(AVFilterContext *ctx)
1440 AVFilterLink *inlink = ctx->inputs[0];
1441 AVFilterLink *outlink = ctx->outputs[0];
1445 FF_FILTER_FORWARD_STATUS_BACK(outlink, inlink);
1447 ret = ff_inlink_consume_samples(inlink, FRAME_SIZE, FRAME_SIZE, &in);
1452 return filter_frame(inlink, in);
1454 FF_FILTER_FORWARD_STATUS(inlink, outlink);
1455 FF_FILTER_FORWARD_WANTED(outlink, inlink);
1457 return FFERROR_NOT_READY;
1460 static av_cold int init(AVFilterContext *ctx)
1462 AudioRNNContext *s = ctx->priv;
1465 s->fdsp = avpriv_float_dsp_alloc(0);
1467 return AVERROR(ENOMEM);
1470 return AVERROR(EINVAL);
1471 f = av_fopen_utf8(s->model_name, "r");
1473 return AVERROR(EINVAL);
1475 s->model = rnnoise_model_from_file(f);
1478 return AVERROR(EINVAL);
1480 for (int i = 0; i < FRAME_SIZE; i++) {
1481 s->window[i] = sin(.5*M_PI*sin(.5*M_PI*(i+.5)/FRAME_SIZE) * sin(.5*M_PI*(i+.5)/FRAME_SIZE));
1482 s->window[WINDOW_SIZE - 1 - i] = s->window[i];
1485 for (int i = 0; i < NB_BANDS; i++) {
1486 for (int j = 0; j < NB_BANDS; j++) {
1487 s->dct_table[j][i] = cosf((i + .5f) * j * M_PI / NB_BANDS);
1489 s->dct_table[j][i] *= sqrtf(.5);
1496 static av_cold void uninit(AVFilterContext *ctx)
1498 AudioRNNContext *s = ctx->priv;
1501 rnnoise_model_free(s->model);
1505 for (int ch = 0; ch < s->channels; ch++) {
1506 av_freep(&s->st[ch].rnn.vad_gru_state);
1507 av_freep(&s->st[ch].rnn.noise_gru_state);
1508 av_freep(&s->st[ch].rnn.denoise_gru_state);
1509 av_tx_uninit(&s->st[ch].tx);
1510 av_tx_uninit(&s->st[ch].txi);
1516 static const AVFilterPad inputs[] = {
1519 .type = AVMEDIA_TYPE_AUDIO,
1520 .config_props = config_input,
1525 static const AVFilterPad outputs[] = {
1528 .type = AVMEDIA_TYPE_AUDIO,
/* Byte offset of a field of AudioRNNContext, for the option table. */
#define OFFSET(x) offsetof(AudioRNNContext, x)
/* Flags shared by every option of this filter. The expansion is fully
 * parenthesized so AF always acts as a single operand. */
#define AF (AV_OPT_FLAG_AUDIO_PARAM | AV_OPT_FLAG_FILTERING_PARAM)
1536 static const AVOption arnndn_options[] = {
1537 { "model", "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1538 { "m", "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1539 { "mix", "set output vs input mix", OFFSET(mix), AV_OPT_TYPE_FLOAT, {.dbl=1.0},-1, 1, AF },
/* Defines arnndn_class, which ff_af_arnndn below uses as its .priv_class.
 * NOTE(review): the macro lives in libavfilter's internal headers and
 * presumably wires in arnndn_options; confirm there. */
1543 AVFILTER_DEFINE_CLASS(arnndn);
1545 AVFilter ff_af_arnndn = {
1547 .description = NULL_IF_CONFIG_SMALL("Reduce noise from speech using Recurrent Neural Networks."),
1548 .query_formats = query_formats,
1549 .priv_size = sizeof(AudioRNNContext),
1550 .priv_class = &arnndn_class,
1551 .activate = activate,
1556 .flags = AVFILTER_FLAG_SUPPORT_TIMELINE_INTERNAL |
1557 AVFILTER_FLAG_SLICE_THREADS,