2 * Copyright (c) 2018 Gregor Richards
3 * Copyright (c) 2017 Mozilla
4 * Copyright (c) 2005-2009 Xiph.Org Foundation
5 * Copyright (c) 2007-2008 CSIRO
6 * Copyright (c) 2008-2011 Octasic Inc.
7 * Copyright (c) Jean-Marc Valin
8 * Copyright (c) 2019 Paul B Mahol
10 * Redistribution and use in source and binary forms, with or without
11 * modification, are permitted provided that the following conditions
14 * - Redistributions of source code must retain the above copyright
15 * notice, this list of conditions and the following disclaimer.
17 * - Redistributions in binary form must reproduce the above copyright
18 * notice, this list of conditions and the following disclaimer in the
19 * documentation and/or other materials provided with the distribution.
21 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR
25 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
36 #include "libavutil/avassert.h"
37 #include "libavutil/avstring.h"
38 #include "libavutil/float_dsp.h"
39 #include "libavutil/opt.h"
40 #include "libavutil/tx.h"
/* Framing: audio is processed in FRAME_SIZE-sample frames (120 << 2 = 480)
 * through a window spanning two frames; FREQ_SIZE is the number of spectrum
 * bins kept from one windowed transform. */
46 #define FRAME_SIZE_SHIFT 2
47 #define FRAME_SIZE (120<<FRAME_SIZE_SHIFT)
48 #define WINDOW_SIZE (2*FRAME_SIZE)
49 #define FREQ_SIZE (FRAME_SIZE + 1)
/* Pitch analysis: candidate period range (in samples) and the length of the
 * sample history kept for it. */
51 #define PITCH_MIN_PERIOD 60
52 #define PITCH_MAX_PERIOD 768
53 #define PITCH_FRAME_SIZE 960
54 #define PITCH_BUF_SIZE (PITCH_MAX_PERIOD+PITCH_FRAME_SIZE)
/* Evaluates its argument twice: do not pass expressions with side effects. */
56 #define SQUARE(x) ((x)*(x))
/* Number of leading cepstral coefficients that also get delta features. */
61 #define NB_DELTA_CEPS 6
/* Length of the feature vector built by compute_frame_features(). */
63 #define NB_FEATURES (NB_BANDS+3*NB_DELTA_CEPS+2)
/* Factor applied to every weighted sum before the activation function. */
65 #define WEIGHTS_SCALE (1.f/256)
/* Size of the per-layer scratch arrays in compute_gru() and compute_rnn(). */
67 #define MAX_NEURONS 128
/* Activation identifiers stored in the layers and tested by compute_dense()
 * and compute_gru(). */
69 #define ACTIVATION_TANH 0
70 #define ACTIVATION_SIGMOID 1
71 #define ACTIVATION_RELU 2
/* Fully connected layer description consumed by compute_dense(). */
75 typedef struct DenseLayer {
/* Weights, addressed as input_weights[j * nb_neurons + i] for input j and
 * output neuron i (see compute_dense()). */
77 const float *input_weights;
/* Gated recurrent layer description consumed by compute_gru(). */
83 typedef struct GRULayer {
/* Per neuron: three consecutive runs of FFALIGN(nb_inputs, 4) input weights,
 * read by compute_gru() for the update, reset and output stages. */
85 const float *input_weights;
/* Same arrangement for the state feedback, with FFALIGN(nb_neurons, 4)
 * weights per run. */
86 const float *recurrent_weights;
/* The complete network: the dense and GRU layers that compute_rnn() chains
 * together, as loaded by rnnoise_model_from_file(). */
92 typedef struct RNNModel {
94 const DenseLayer *input_dense;
97 const GRULayer *vad_gru;
100 const GRULayer *noise_gru;
/* Each *_size member mirrors the nb_neurons of the layer of the same name;
 * the model loader sets both together. */
102 int denoise_gru_size;
103 const GRULayer *denoise_gru;
105 int denoise_output_size;
106 const DenseLayer *denoise_output;
109 const DenseLayer *vad_output;
/* Recurrent state of the three GRU layers. The vectors are allocated zeroed
 * in config_input() and updated in place by compute_gru() on every frame. */
112 typedef struct RNNState {
113 float *vad_gru_state;
114 float *noise_gru_state;
115 float *denoise_gru_state;
/* Per-channel processing state; config_input() allocates one per channel. */
119 typedef struct DenoiseState {
/* Previous input frame, placed in front of the current one by frame_analysis(). */
120 float analysis_mem[FRAME_SIZE];
/* History of the last CEPS_MEM cepstra; the current one is written at index memid. */
121 float cepstral_mem[CEPS_MEM][NB_BANDS];
/* Second half of the last synthesis window, overlap-added into the next frame. */
123 DECLARE_ALIGNED(32, float, synthesis_mem)[FRAME_SIZE];
/* Sliding history of input samples used for the pitch analysis. */
124 float pitch_buf[PITCH_BUF_SIZE];
/* NOTE(review): not referenced by any code visible in this section - confirm it is still needed. */
125 float pitch_enh_buf[PITCH_BUF_SIZE];
/* Band gains used for the previous frame (see rnnoise_channel()). */
129 float lastg[NB_BANDS];
/* Forward (tx) and inverse (txi) FFT contexts and their transform callbacks. */
131 AVTXContext *tx, *txi;
132 av_tx_fn tx_fn, txi_fn;
/* Filter private context, shared by all channels. */
135 typedef struct AudioRNNContext {
136 const AVClass *class;
/* Analysis/synthesis window; init() fills it symmetrically around its centre. */
143 DECLARE_ALIGNED(32, float, window)[WINDOW_SIZE];
/* Cosine table used by dct(); both dimensions are padded to a multiple of 4. */
144 DECLARE_ALIGNED(32, float, dct_table)[FFALIGN(NB_BANDS, 4)][FFALIGN(NB_BANDS, 4)];
/* Float DSP helpers (vector multiply, scalar product, ...). */
148 AVFloatDSPContext *fdsp;
/* Activation codes as they appear in the model file; the INPUT_ACTIVATION
 * macro of rnnoise_model_from_file() translates them to ACTIVATION_*. */
151 #define F_ACTIVATION_TANH 0
152 #define F_ACTIVATION_SIGMOID 1
153 #define F_ACTIVATION_RELU 2
/*
 * Release a model built by rnnoise_model_from_file(). For each layer the
 * weight and bias arrays are freed first, then the layer structure itself;
 * FREE_GRU additionally frees the recurrent weights. The layer members are
 * const-qualified, so the macros cast the qualifier away only to hand the
 * pointers to av_free().
 *
 * NOTE(review): FREE_MAYBE is not used by any statement visible here and
 * calls libc free() although the model is allocated with av_calloc(); its
 * "if (ptr)" guard is redundant as well. Consider deleting the macro once
 * it is confirmed to have no users.
 */
155 static void rnnoise_model_free(RNNModel *model)
157 #define FREE_MAYBE(ptr) do { if (ptr) free(ptr); } while (0)
158 #define FREE_DENSE(name) do { \
160 av_free((void *) model->name->input_weights); \
161 av_free((void *) model->name->bias); \
162 av_free((void *) model->name); \
165 #define FREE_GRU(name) do { \
167 av_free((void *) model->name->input_weights); \
168 av_free((void *) model->name->recurrent_weights); \
169 av_free((void *) model->name->bias); \
170 av_free((void *) model->name); \
176 FREE_DENSE(input_dense);
179 FREE_GRU(denoise_gru);
180 FREE_DENSE(denoise_output);
181 FREE_DENSE(vad_output);
/*
 * Parse a text model from f and build the RNNModel describing it.
 *
 * The file starts with the line "rnnoise-nu model file version 1"; any other
 * version is rejected. The layers then follow in the order input_dense,
 * ..., noise_gru, denoise_gru, denoise_output, vad_output, all values being
 * whitespace-separated decimal integers.
 *
 * Helper macros (each one calls rnnoise_model_free(ret) when it fails):
 *   ALLOC_LAYER      - zero-allocate one layer structure.
 *   INPUT_VAL        - read one integer and require 0 <= value <= 128.
 *   INPUT_ACTIVATION - read an F_ACTIVATION_* code and store the matching
 *                      ACTIVATION_* value; unlisted codes select TANH.
 *   INPUT_ARRAY      - read len integers into a new float array.
 *   INPUT_ARRAY3     - read len0 * len2 * len1 integers (in that loop order)
 *                      into a zeroed array whose len0 and len1 dimensions are
 *                      padded to a multiple of 4. A value for input k, gate i
 *                      and neuron j lands at
 *                      j * len2 * FFALIGN(len0, 4) + i * FFALIGN(len0, 4) + k,
 *                      which is the layout compute_gru() indexes.
 *   INPUT_DENSE      - sizes, activation, weights and biases of a dense
 *                      layer; also records ret->NAME_size.
 *   INPUT_GRU        - same for a GRU layer, with three gates per neuron.
 *
 * NOTE(review): the literal 128 in INPUT_VAL looks like it is meant to be
 * MAX_NEURONS - confirm and use the macro. Weight and bias values are stored
 * as read, without any range check.
 * NOTE(review): the failure return value is not visible here; presumably
 * NULL, which is what init() appears to test for.
 */
185 static RNNModel *rnnoise_model_from_file(FILE *f)
188 DenseLayer *input_dense;
191 GRULayer *denoise_gru;
192 DenseLayer *denoise_output;
193 DenseLayer *vad_output;
/* Header line: only format version 1 is understood. */
196 if (fscanf(f, "rnnoise-nu model file version %d\n", &in) != 1 || in != 1)
199 ret = av_calloc(1, sizeof(RNNModel));
203 #define ALLOC_LAYER(type, name) \
204 name = av_calloc(1, sizeof(type)); \
206 rnnoise_model_free(ret); \
211 ALLOC_LAYER(DenseLayer, input_dense);
212 ALLOC_LAYER(GRULayer, vad_gru);
213 ALLOC_LAYER(GRULayer, noise_gru);
214 ALLOC_LAYER(GRULayer, denoise_gru);
215 ALLOC_LAYER(DenseLayer, denoise_output);
216 ALLOC_LAYER(DenseLayer, vad_output);
218 #define INPUT_VAL(name) do { \
219 if (fscanf(f, "%d", &in) != 1 || in < 0 || in > 128) { \
220 rnnoise_model_free(ret); \
226 #define INPUT_ACTIVATION(name) do { \
228 INPUT_VAL(activation); \
229 switch (activation) { \
230 case F_ACTIVATION_SIGMOID: \
231 name = ACTIVATION_SIGMOID; \
233 case F_ACTIVATION_RELU: \
234 name = ACTIVATION_RELU; \
237 name = ACTIVATION_TANH; \
241 #define INPUT_ARRAY(name, len) do { \
242 float *values = av_calloc((len), sizeof(float)); \
244 rnnoise_model_free(ret); \
248 for (int i = 0; i < (len); i++) { \
249 if (fscanf(f, "%d", &in) != 1) { \
250 rnnoise_model_free(ret); \
257 #define INPUT_ARRAY3(name, len0, len1, len2) do { \
258 float *values = av_calloc(FFALIGN((len0), 4) * FFALIGN((len1), 4) * (len2), sizeof(float)); \
260 rnnoise_model_free(ret); \
264 for (int k = 0; k < (len0); k++) { \
265 for (int i = 0; i < (len2); i++) { \
266 for (int j = 0; j < (len1); j++) { \
267 if (fscanf(f, "%d", &in) != 1) { \
268 rnnoise_model_free(ret); \
271 values[j * (len2) * FFALIGN((len0), 4) + i * FFALIGN((len0), 4) + k] = in; \
277 #define INPUT_DENSE(name) do { \
278 INPUT_VAL(name->nb_inputs); \
279 INPUT_VAL(name->nb_neurons); \
280 ret->name ## _size = name->nb_neurons; \
281 INPUT_ACTIVATION(name->activation); \
282 INPUT_ARRAY(name->input_weights, name->nb_inputs * name->nb_neurons); \
283 INPUT_ARRAY(name->bias, name->nb_neurons); \
286 #define INPUT_GRU(name) do { \
287 INPUT_VAL(name->nb_inputs); \
288 INPUT_VAL(name->nb_neurons); \
289 ret->name ## _size = name->nb_neurons; \
290 INPUT_ACTIVATION(name->activation); \
291 INPUT_ARRAY3(name->input_weights, name->nb_inputs, name->nb_neurons, 3); \
292 INPUT_ARRAY3(name->recurrent_weights, name->nb_neurons, name->nb_neurons, 3); \
293 INPUT_ARRAY(name->bias, name->nb_neurons * 3); \
296 INPUT_DENSE(input_dense);
298 INPUT_GRU(noise_gru);
299 INPUT_GRU(denoise_gru);
300 INPUT_DENSE(denoise_output);
301 INPUT_DENSE(vad_output);
/* compute_rnn() writes the VAD layer output into a single float, so the
 * layer must have exactly one neuron. */
303 if (vad_output->nb_neurons != 1) {
304 rnnoise_model_free(ret);
/*
 * Advertise what the filter accepts: the sample formats listed in
 * sample_fmts, any channel count, and a single sample rate of 48000 Hz
 * (the rate list is terminated by -1). A failure to build one of the lists
 * is reported as AVERROR(ENOMEM).
 */
311 static int query_formats(AVFilterContext *ctx)
313 AVFilterFormats *formats = NULL;
314 AVFilterChannelLayouts *layouts = NULL;
315 static const enum AVSampleFormat sample_fmts[] = {
319 int ret, sample_rates[] = { 48000, -1 };
321 formats = ff_make_format_list(sample_fmts);
323 return AVERROR(ENOMEM);
324 ret = ff_set_common_formats(ctx, formats);
328 layouts = ff_all_channel_counts();
330 return AVERROR(ENOMEM);
332 ret = ff_set_common_channel_layouts(ctx, layouts);
336 formats = ff_make_format_list(sample_rates);
338 return AVERROR(ENOMEM);
339 return ff_set_common_samplerates(ctx, formats);
/*
 * Input link configuration: allocate one zeroed DenoiseState per channel,
 * give each its three GRU state vectors (zeroed, length rounded up to a
 * multiple of 16 floats) and create the forward and inverse WINDOW_SIZE FFT
 * contexts. Allocation failures return AVERROR(ENOMEM); the states that were
 * already allocated are released later by uninit().
 *
 * NOTE(review): nothing visible here releases a previous s->st, so a second
 * call on the same context would presumably leak it - confirm whether the
 * link can be reconfigured.
 * NOTE(review): the av_calloc() calls for the GRU states pass the element
 * size first and the count second; the product is the same, but the
 * conventional order is (count, size).
 */
342 static int config_input(AVFilterLink *inlink)
344 AVFilterContext *ctx = inlink->dst;
345 AudioRNNContext *s = ctx->priv;
348 s->channels = inlink->channels;
350 s->st = av_calloc(s->channels, sizeof(DenoiseState));
352 return AVERROR(ENOMEM);
354 for (int i = 0; i < s->channels; i++) {
355 DenoiseState *st = &s->st[i];
/* Every channel shares the one model loaded in init(). */
357 st->rnn.model = s->model;
358 st->rnn.vad_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->vad_gru_size, 16));
359 st->rnn.noise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->noise_gru_size, 16));
360 st->rnn.denoise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->denoise_gru_size, 16));
361 if (!st->rnn.vad_gru_state ||
362 !st->rnn.noise_gru_state ||
363 !st->rnn.denoise_gru_state)
364 return AVERROR(ENOMEM);
/* Forward transform (inverse flag 0), then inverse transform (flag 1). */
366 ret = av_tx_init(&st->tx, &st->tx_fn, AV_TX_FLOAT_FFT, 0, WINDOW_SIZE, NULL, 0);
370 ret = av_tx_init(&st->txi, &st->txi_fn, AV_TX_FLOAT_FFT, 1, WINDOW_SIZE, NULL, 0);
/*
 * Second-order recursive filter applied to N samples of x, writing to y.
 * mem[] holds the two state values and is updated after every sample, so it
 * must be preserved between calls to filter a continuous stream. b[] and
 * a[] each supply two coefficients.
 * NOTE(review): how xi and yi are derived is not visible here; presumably
 * xi = x[i] and yi = x[i] + mem[0] - confirm.
 */
378 static void biquad(float *y, float mem[2], const float *x,
379 const float *b, const float *a, int N)
381 for (int i = 0; i < N; i++) {
386 mem[0] = mem[1] + (b[0]*xi - a[0]*yi);
387 mem[1] = (b[1]*xi - a[1]*yi);
/* Typed wrappers around memmove/memset/memcpy taking an element count n.
 * The "0*((dst)-(src))" term adds nothing to the size; subtracting the two
 * pointers only makes the compiler reject dst/src of incompatible types. */
392 #define RNN_MOVE(dst, src, n) (memmove((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
393 #define RNN_CLEAR(dst, n) (memset((dst), 0, (n)*sizeof(*(dst))))
394 #define RNN_COPY(dst, src, n) (memcpy((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
/*
 * Forward FFT of one window: the WINDOW_SIZE samples of in are loaded into
 * the complex scratch buffer x, transformed into y with the channel's
 * forward context, and the first FREQ_SIZE bins are copied to out.
 * Two WINDOW_SIZE complex arrays live on the stack.
 */
396 static void forward_transform(DenoiseState *st, AVComplexFloat *out, const float *in)
398 AVComplexFloat x[WINDOW_SIZE];
399 AVComplexFloat y[WINDOW_SIZE];
401 for (int i = 0; i < WINDOW_SIZE; i++) {
406 st->tx_fn(st->tx, y, x, sizeof(float));
/* Only the first FREQ_SIZE bins are handed back to the caller. */
408 RNN_COPY(out, y, FREQ_SIZE);
/*
 * Inverse FFT of a half spectrum: in supplies FREQ_SIZE bins, the remaining
 * bins are rebuilt as complex conjugates of their mirror bins, and the real
 * part of the inverse transform, divided by WINDOW_SIZE, is written to the
 * WINDOW_SIZE floats of out.
 */
411 static void inverse_transform(DenoiseState *st, float *out, const AVComplexFloat *in)
413 AVComplexFloat x[WINDOW_SIZE];
414 AVComplexFloat y[WINDOW_SIZE];
416 RNN_COPY(x, in, FREQ_SIZE);
/* Bin i mirrors bin WINDOW_SIZE - i with the imaginary part negated. */
418 for (int i = FREQ_SIZE; i < WINDOW_SIZE; i++) {
419 x[i].re = x[WINDOW_SIZE - i].re;
420 x[i].im = -x[WINDOW_SIZE - i].im;
423 st->txi_fn(st->txi, y, x, sizeof(float));
/* The division by WINDOW_SIZE normalises the transform output. */
425 for (int i = 0; i < WINDOW_SIZE; i++)
426 out[i] = y[i].re / WINDOW_SIZE;
/* Band edges of the 22 analysis bands. Users shift each entry left by
 * FRAME_SIZE_SHIFT to obtain the index of the first spectrum bin of the
 * band; the comment row gives the corresponding frequencies. */
429 static const uint8_t eband5ms[] = {
430 /*0 200 400 600 800 1k 1.2 1.4 1.6 2k 2.4 2.8 3.2 4k 4.8 5.6 6.8 8k 9.6 12k 15.6 20k*/
431 0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 34, 40, 48, 60, 78, 100
/*
 * Per-band energy of spectrum X. The squared magnitude of every bin between
 * two band edges is split between the two neighbouring bands with linear
 * weights (1 - frac) and frac, frac growing from 0 towards 1 across the
 * interval. The last band only receives the rising half of a triangle, so
 * its sum is doubled. Results go to bandE[0..NB_BANDS-1].
 */
434 static void compute_band_energy(float *bandE, const AVComplexFloat *X)
436 float sum[NB_BANDS] = {0};
438 for (int i = 0; i < NB_BANDS - 1; i++) {
441 band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
442 for (int j = 0; j < band_size; j++) {
443 float tmp, frac = (float)j / band_size;
445 tmp = SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].re);
446 tmp += SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].im);
447 sum[i] += (1.f - frac) * tmp;
448 sum[i + 1] += frac * tmp;
453 sum[NB_BANDS - 1] *= 2;
455 for (int i = 0; i < NB_BANDS; i++)
/*
 * Per-band correlation between spectra X and P: same band weighting as
 * compute_band_energy(), but accumulating re*re + im*im of the two spectra
 * (the real part of X times the conjugate of P) instead of the squared
 * magnitude. Results go to bandE[0..NB_BANDS-1].
 */
459 static void compute_band_corr(float *bandE, const AVComplexFloat *X, const AVComplexFloat *P)
461 float sum[NB_BANDS] = { 0 };
463 for (int i = 0; i < NB_BANDS - 1; i++) {
466 band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
467 for (int j = 0; j < band_size; j++) {
468 float tmp, frac = (float)j / band_size;
470 tmp = X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re;
471 tmp += X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im;
472 sum[i] += (1 - frac) * tmp;
473 sum[i + 1] += frac * tmp;
478 sum[NB_BANDS-1] *= 2;
480 for (int i = 0; i < NB_BANDS; i++)
/*
 * Analyse one input frame: build a two-frame block from the previous frame
 * (st->analysis_mem) followed by in, remember in as the new previous frame,
 * apply the window, transform to X and compute the band energies Ex.
 */
484 static void frame_analysis(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, float *Ex, const float *in)
486 LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
488 RNN_COPY(x, st->analysis_mem, FRAME_SIZE);
489 RNN_COPY(x + FRAME_SIZE, in, FRAME_SIZE);
490 RNN_COPY(st->analysis_mem, in, FRAME_SIZE);
491 s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
492 forward_transform(st, X, x);
493 compute_band_energy(Ex, X);
/*
 * Synthesise one output frame from spectrum y: inverse transform, apply the
 * window, add the tail saved from the previous call (st->synthesis_mem) to
 * the first half, output that half and keep the second half as the new tail.
 */
496 static void frame_synthesis(AudioRNNContext *s, DenoiseState *st, float *out, const AVComplexFloat *y)
498 LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
500 inverse_transform(st, x, y);
501 s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
/* Overlap-add: x[0..FRAME_SIZE-1] += 1.0 * synthesis_mem[]. */
502 s->fdsp->vector_fmac_scalar(x, st->synthesis_mem, 1.f, FRAME_SIZE);
503 RNN_COPY(out, x, FRAME_SIZE);
504 RNN_COPY(st->synthesis_mem, &x[FRAME_SIZE], FRAME_SIZE);
/*
 * Correlation kernel used by celt_pitch_xcorr(): accumulates into sum[0..3]
 * the correlation of x with y at four consecutive offsets, handling the
 * samples in groups of four in the main loop.
 * NOTE(review): the loop body and the tail handling are not visible here;
 * the description follows how celt_pitch_xcorr() consumes sum[] - confirm.
 */
507 static inline void xcorr_kernel(const float *x, const float *y, float sum[4], int len)
509 float y_0, y_1, y_2, y_3 = 0;
516 for (j = 0; j < len - 3; j += 4) {
/*
 * Inner product helper over N elements of x and y.
 * NOTE(review): the accumulation statement is not visible here; presumably
 * it sums x[i] * y[i] and returns the total - confirm.
 */
576 static inline float celt_inner_prod(const float *x,
577 const float *y, int N)
581 for (int i = 0; i < N; i++)
/*
 * Cross-correlation of x against y for lags 0..max_pitch-1, each computed
 * over len samples: xcorr[i] correlates x with y + i. Lags are processed
 * four at a time through xcorr_kernel(); the remaining lags fall back to
 * celt_inner_prod().
 */
587 static void celt_pitch_xcorr(const float *x, const float *y,
588 float *xcorr, int len, int max_pitch)
592 for (i = 0; i < max_pitch - 3; i += 4) {
593 float sum[4] = { 0, 0, 0, 0};
595 xcorr_kernel(x, y + i, sum, len);
598 xcorr[i + 1] = sum[1];
599 xcorr[i + 2] = sum[2];
600 xcorr[i + 3] = sum[3];
602 /* In case max_pitch isn't a multiple of 4, do non-unrolled version. */
603 for (; i < max_pitch; i++) {
604 xcorr[i] = celt_inner_prod(x, y + i, len);
/*
 * Autocorrelation of x for lags 0..lag. When an overlap window is supplied,
 * a copy of the input is tapered at both ends with it before correlating.
 * The bulk of every lag is computed by celt_pitch_xcorr() over fastN
 * samples; the loop at the end adds the products that call did not cover.
 * The local copy xx holds PITCH_BUF_SIZE / 2 samples.
 * NOTE(review): the parameter list is only partly visible; the call in
 * pitch_downsample() passes (x, ac, window, overlap, lag, n) - confirm.
 */
608 static int celt_autocorr(const float *x, /* in: [0...n-1] samples x */
609 float *ac, /* out: [0...lag-1] ac values */
618 float xx[PITCH_BUF_SIZE>>1];
623 for (int i = 0; i < n; i++)
625 for (int i = 0; i < overlap; i++) {
626 xx[i] = x[i] * window[i];
627 xx[n-i-1] = x[n-i-1] * window[i];
633 celt_pitch_xcorr(xptr, xptr, ac, fastN, lag+1);
635 for (int k = 0; k <= lag; k++) {
638 for (int i = k + fastN; i < n; i++)
639 d += xptr[i] * xptr[i-k];
/*
 * Derive p linear-prediction coefficients from the autocorrelation values
 * ac[0..p]. Each iteration computes one reflection coefficient r, updates
 * the coefficients symmetrically from both ends and shrinks the prediction
 * error by the factor (1 - r * r); the loop stops early once the error has
 * dropped below 0.001 * ac[0].
 * NOTE(review): this has the shape of a Levinson-Durbin recursion, but the
 * statements computing r are not visible here - confirm.
 */
646 static void celt_lpc(float *lpc, /* out: [0...p-1] LPC coefficients */
647 const float *ac, /* in: [0...p] autocorrelation values */
650 float r, error = ac[0];
654 for (int i = 0; i < p; i++) {
655 /* Sum up this iteration's reflection coefficient */
657 for (int j = 0; j < i; j++)
658 rr += (lpc[j] * ac[i - j]);
661 /* Update LPC coefficients and total error */
663 for (int j = 0; j < (i + 1) >> 1; j++) {
667 lpc[j] = tmp1 + (r*tmp2);
668 lpc[i-1-j] = tmp2 + (r*tmp1);
671 error = error - (r * r *error);
672 /* Bail out once we get 30 dB gain */
673 if (error < .001f * ac[0])
/*
 * Five-coefficient filter over N samples, with five state values (mem0..mem4
 * locally). pitch_downsample() calls it as (x, num, y, N, mem) with x and y
 * being the same buffer, so it has to work in place.
 * NOTE(review): the filter equation is not visible here; presumably
 * y[i] = x[i] + sum(num[k] * mem[k]) with mem carried between calls -
 * confirm.
 */
679 static void celt_fir5(const float *x,
685 float num0, num1, num2, num3, num4;
686 float mem0, mem1, mem2, mem3, mem4;
699 for (int i = 0; i < N; i++) {
/*
 * Prepare the half-rate signal used by the pitch search. Each output sample
 * is a (0.25, 0.5, 0.25) weighted average of three neighbouring input
 * samples taken with stride 2; sample 0 has no left neighbour and uses only
 * two. A second channel x[1] is summed in the same way. The result is then
 * filtered in place by celt_fir5() with coefficients derived from a
 * 4th-order LPC analysis of the downsampled signal itself.
 */
722 static void pitch_downsample(float *x[], float *x_lp,
727 float lpc[4], mem[5]={0,0,0,0,0};
731 for (int i = 1; i < len >> 1; i++)
732 x_lp[i] = .5f * (.5f * (x[0][(2*i-1)]+x[0][(2*i+1)])+x[0][2*i]);
733 x_lp[0] = .5f * (.5f * (x[0][1])+x[0][0]);
735 for (int i = 1; i < len >> 1; i++)
736 x_lp[i] += (.5f * (.5f * (x[1][(2*i-1)]+x[1][(2*i+1)])+x[1][2*i]));
737 x_lp[0] += .5f * (.5f * (x[1][1])+x[1][0]);
/* Autocorrelation for lags 0..4, without an overlap window. */
740 celt_autocorr(x_lp, ac, NULL, 0, 4, len>>1);
742 /* Noise floor -40 dB */
/* Attenuate ac[i] progressively with the lag (lag windowing). */
745 for (int i = 1; i <= 4; i++) {
746 /*ac[i] *= exp(-.5*(2*M_PI*.002*i)*(2*M_PI*.002*i));*/
747 ac[i] -= ac[i]*(.008f*i)*(.008f*i);
750 celt_lpc(lpc, ac, 4);
751 for (int i = 0; i < 4; i++) {
753 lpc[i] = (lpc[i] * tmp);
/* Extend the four LPC taps to the five taps handed to celt_fir5(). */
756 lpc2[0] = lpc[0] + .8f;
757 lpc2[1] = lpc[1] + (c1 * lpc[0]);
758 lpc2[2] = lpc[2] + (c1 * lpc[1]);
759 lpc2[3] = lpc[3] + (c1 * lpc[2]);
760 lpc2[4] = (c1 * lpc[3]);
761 celt_fir5(x_lp, lpc2, x_lp, len>>1, mem);
/*
 * Compute two inner products in a single pass over x: x . y01 and x . y02,
 * each over N elements. The callers receive them through xy1 and xy2.
 */
764 static inline void dual_inner_prod(const float *x, const float *y01, const float *y02,
765 int N, float *xy1, float *xy2)
767 float xy01 = 0, xy02 = 0;
769 for (int i = 0; i < N; i++) {
770 xy01 += (x[i] * y01[i]);
771 xy02 += (x[i] * y02[i]);
/*
 * Normalised correlation xy / sqrt(1 + xx * yy). The added 1 keeps the
 * denominator at or above 1 for non-negative energies, so all-zero input
 * yields 0 rather than a division by zero.
 */
778 static float compute_pitch_gain(float xy, float xx, float yy)
780 return xy / sqrtf(1.f + xx * yy);
/* Indexed by the divisor k in remove_doubling(): selects the multiple of
 * T0 / k at which a second correlation (period T1b) is checked. */
783 static const uint8_t second_check[16] = {0, 0, 3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2};
/*
 * Refine a pitch period estimate by testing shorter periods T0 / k for
 * k = 2..15. Each candidate T1 is scored by the average of its correlation
 * and that of a second period T1b taken from second_check[]; it has to beat
 * a threshold derived from the gain g0 of the initial period. The threshold
 * is lowered when T1 is within one or two samples of prev_period (by an
 * amount based on prev_gain) and raised for short periods. *T0_ is the
 * in/out period. x must be readable at negative indices down to -maxperiod
 * (the code reads x[-i]).
 * NOTE(review): the returned value is stored as the frame's pitch gain by
 * the caller; how it is derived from pg is not visible here - confirm.
 */
784 static float remove_doubling(float *x, int maxperiod, int minperiod, int N,
785 int *T0_, int prev_period, float prev_gain)
792 float best_xy, best_yy;
795 float yy_lookup[PITCH_MAX_PERIOD+1];
797 minperiod0 = minperiod;
808 dual_inner_prod(x, x, x-T0, N, &xx, &xy);
/* Running energy of the N-sample segment starting at x[-i], clamped at 0. */
811 for (i = 1; i <= maxperiod; i++) {
812 yy = yy+(x[-i] * x[-i])-(x[N-i] * x[N-i]);
813 yy_lookup[i] = FFMAX(0, yy);
818 g = g0 = compute_pitch_gain(xy, xx, yy);
819 /* Look for any pitch at T/k */
820 for (k = 2; k <= 15; k++) {
828 /* Look for another strong correlation at T1b */
837 T1b = (2*second_check[k]*T0+k)/(2*k);
839 dual_inner_prod(x, &x[-T1], &x[-T1b], N, &xy, &xy2);
840 xy = .5f * (xy + xy2);
841 yy = .5f * (yy_lookup[T1] + yy_lookup[T1b]);
842 g1 = compute_pitch_gain(xy, xx, yy);
843 if (FFABS(T1-prev_period)<=1)
845 else if (FFABS(T1-prev_period)<=2 && 5 * k * k < T0)
846 cont = prev_gain * .5f;
849 thresh = FFMAX(.3f, (.7f * g0) - cont);
850 /* Bias against very high pitch (very short period) to avoid false-positives
851 due to short-term correlation */
853 thresh = FFMAX(.4f, (.85f * g0) - cont);
854 else if (T1<2*minperiod)
855 thresh = FFMAX(.5f, (.9f * g0) - cont);
864 best_xy = FFMAX(0, best_xy);
865 if (best_yy <= best_xy)
868 pg = best_xy/(best_yy + 1);
/* Correlations at T-1, T and T+1 decide a final one-sample adjustment. */
870 for (k = 0; k < 3; k++)
871 xcorr[k] = celt_inner_prod(x, x-(T+k-1), N);
872 if ((xcorr[2]-xcorr[0]) > .7f * (xcorr[1]-xcorr[0]))
874 else if ((xcorr[0]-xcorr[2]) > (.7f * (xcorr[1] - xcorr[2])))
/*
 * Find the two lags in 0..max_pitch-1 with the largest xcorr^2 / Syy ratio,
 * Syy being the energy of the len-sample segment of y starting at the lag.
 * The ratios are compared by cross-multiplication, so no division is
 * needed. best_pitch[0] receives the best lag and best_pitch[1] the
 * runner-up. Syy is updated incrementally as the segment slides by one
 * sample.
 */
887 static void find_best_pitch(float *xcorr, float *y, int len,
888 int max_pitch, int *best_pitch)
901 for (int j = 0; j < len; j++)
904 for (int i = 0; i < max_pitch; i++) {
910 /* Considering the range of xcorr16, this should avoid both underflows
911 and overflows (inf) when squaring xcorr16 */
913 num = xcorr16 * xcorr16;
914 if ((num * best_den[1]) > (best_num[1] * Syy)) {
915 if ((num * best_den[0]) > (best_num[0] * Syy)) {
/* New overall best: the previous best becomes the runner-up. */
916 best_num[1] = best_num[0];
917 best_den[1] = best_den[0];
918 best_pitch[1] = best_pitch[0];
929 Syy += y[i+len]*y[i+len] - y[i] * y[i];
/*
 * Two-stage pitch lag search on the half-rate signals x_lp and y.
 * Stage 1 decimates both by another factor of 2 and finds the two best lags
 * at that coarse resolution. Stage 2 recomputes the correlation at the
 * original half rate, only for lags within 2 of twice a coarse candidate.
 * Finally the winner is nudged by an offset chosen from the correlations of
 * its two neighbours, and *pitch = 2 * best_lag - offset is stored.
 * Three WINDOW_SIZE float scratch arrays live on the stack.
 */
934 static void pitch_search(const float *x_lp, float *y,
935 int len, int max_pitch, int *pitch)
938 int best_pitch[2]={0,0};
941 float x_lp4[WINDOW_SIZE];
942 float y_lp4[WINDOW_SIZE];
943 float xcorr[WINDOW_SIZE];
947 /* Downsample by 2 again */
948 for (int j = 0; j < len >> 2; j++)
949 x_lp4[j] = x_lp[2*j];
950 for (int j = 0; j < lag >> 2; j++)
953 /* Coarse search with 4x decimation */
955 celt_pitch_xcorr(x_lp4, y_lp4, xcorr, len>>2, max_pitch>>2);
957 find_best_pitch(xcorr, y_lp4, len>>2, max_pitch>>2, best_pitch);
959 /* Finer search with 2x decimation */
960 for (int i = 0; i < max_pitch >> 1; i++) {
/* Lags far from both coarse candidates are not evaluated. */
963 if (FFABS(i-2*best_pitch[0])>2 && FFABS(i-2*best_pitch[1])>2)
965 sum = celt_inner_prod(x_lp, y+i, len>>1);
966 xcorr[i] = FFMAX(-1, sum);
969 find_best_pitch(xcorr, y, len>>1, max_pitch>>1, best_pitch);
971 /* Refine by pseudo-interpolation */
972 if (best_pitch[0] > 0 && best_pitch[0] < (max_pitch >> 1) - 1) {
975 a = xcorr[best_pitch[0] - 1];
976 b = xcorr[best_pitch[0]];
977 c = xcorr[best_pitch[0] + 1];
978 if (c - a > .7f * (b - a))
980 else if (a - c > .7f * (b-c))
988 *pitch = 2 * best_pitch[0] - offset;
991 static void dct(AudioRNNContext *s, float *out, const float *in)
993 for (int i = 0; i < NB_BANDS; i++) {
996 sum = s->fdsp->scalarproduct_float(in, s->dct_table[i], FFALIGN(NB_BANDS, 4));
997 out[i] = sum * sqrtf(2.f / 22);
/*
 * Build the feature vector of one frame.
 *
 * Steps visible here: analyse the frame (X, Ex); append it to the pitch
 * history; downsample that history and search the pitch period, refined by
 * remove_doubling(); transform the window of history that starts
 * pitch_index samples earlier into P and derive Ep plus the normalised
 * band correlation Exp; take the log band energies with a floor relative to
 * the running maximum and to the previous band, and transform them with
 * dct(); store the cepstrum in the history ring and derive the sum, first
 * difference and second difference of the first NB_DELTA_CEPS coefficients
 * across the three most recent frames; finally add the pitch period feature
 * and a spectral variability measure computed over the whole ring.
 *
 * NOTE(review): the caller stores the return value in a variable named
 * "silence"; the features are cleared on the no-audio path. The return
 * statements are not visible here - confirm the exact contract.
 * NOTE(review): Ly holds NB_BANDS floats, while dct() reads
 * FFALIGN(NB_BANDS, 4) floats from it; unless NB_BANDS is a multiple of 4
 * this reads past the end of Ly - confirm and pad the array.
 */
1001 static int compute_frame_features(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, AVComplexFloat *P,
1002 float *Ex, float *Ep, float *Exp, float *features, const float *in)
1005 float *ceps_0, *ceps_1, *ceps_2;
1006 float spec_variability = 0;
1007 LOCAL_ALIGNED_32(float, Ly, [NB_BANDS]);
1008 LOCAL_ALIGNED_32(float, p, [WINDOW_SIZE]);
1009 float pitch_buf[PITCH_BUF_SIZE>>1];
1013 float tmp[NB_BANDS];
1014 float follow, logMax;
1016 frame_analysis(s, st, X, Ex, in);
/* Slide the pitch history by one frame and append the new samples. */
1017 RNN_MOVE(st->pitch_buf, &st->pitch_buf[FRAME_SIZE], PITCH_BUF_SIZE-FRAME_SIZE);
1018 RNN_COPY(&st->pitch_buf[PITCH_BUF_SIZE-FRAME_SIZE], in, FRAME_SIZE);
1019 pre[0] = &st->pitch_buf[0];
1020 pitch_downsample(pre, pitch_buf, PITCH_BUF_SIZE, 1);
1021 pitch_search(pitch_buf+(PITCH_MAX_PERIOD>>1), pitch_buf, PITCH_FRAME_SIZE,
1022 PITCH_MAX_PERIOD-3*PITCH_MIN_PERIOD, &pitch_index);
/* pitch_search() reports an offset into the history; convert it to a period. */
1023 pitch_index = PITCH_MAX_PERIOD-pitch_index;
1025 gain = remove_doubling(pitch_buf, PITCH_MAX_PERIOD, PITCH_MIN_PERIOD,
1026 PITCH_FRAME_SIZE, &pitch_index, st->last_period, st->last_gain);
1027 st->last_period = pitch_index;
1028 st->last_gain = gain;
/* Window of history delayed by one pitch period. */
1030 for (int i = 0; i < WINDOW_SIZE; i++)
1031 p[i] = st->pitch_buf[PITCH_BUF_SIZE-WINDOW_SIZE-pitch_index+i];
1033 s->fdsp->vector_fmul(p, p, s->window, WINDOW_SIZE);
1034 forward_transform(st, P, p);
1035 compute_band_energy(Ep, P);
1036 compute_band_corr(Exp, X, P);
/* Normalise the band correlation; .001f keeps the denominator non-zero. */
1038 for (int i = 0; i < NB_BANDS; i++)
1039 Exp[i] = Exp[i] / sqrtf(.001f+Ex[i]*Ep[i]);
1043 for (int i = 0; i < NB_DELTA_CEPS; i++)
1044 features[NB_BANDS+2*NB_DELTA_CEPS+i] = tmp[i];
1046 features[NB_BANDS+2*NB_DELTA_CEPS] -= 1.3;
1047 features[NB_BANDS+2*NB_DELTA_CEPS+1] -= 0.9;
1048 features[NB_BANDS+3*NB_DELTA_CEPS] = .01*(pitch_index-300);
/* Log energies, floored 7 below the running maximum and 1.5 below the
 * value tracked from the previous band. */
1052 for (int i = 0; i < NB_BANDS; i++) {
1053 Ly[i] = log10f(1e-2f + Ex[i]);
1054 Ly[i] = FFMAX(logMax-7, FFMAX(follow-1.5, Ly[i]));
1055 logMax = FFMAX(logMax, Ly[i]);
1056 follow = FFMAX(follow-1.5, Ly[i]);
1061 /* If there's no audio, avoid messing up the state. */
1062 RNN_CLEAR(features, NB_FEATURES);
1066 dct(s, features, Ly);
/* ceps_0 is the slot for this frame; ceps_1 and ceps_2 are the cepstra of
 * the two previous frames, with wrap-around in the ring. */
1069 ceps_0 = st->cepstral_mem[st->memid];
1070 ceps_1 = (st->memid < 1) ? st->cepstral_mem[CEPS_MEM+st->memid-1] : st->cepstral_mem[st->memid-1];
1071 ceps_2 = (st->memid < 2) ? st->cepstral_mem[CEPS_MEM+st->memid-2] : st->cepstral_mem[st->memid-2];
1073 for (int i = 0; i < NB_BANDS; i++)
1074 ceps_0[i] = features[i];
1077 for (int i = 0; i < NB_DELTA_CEPS; i++) {
1078 features[i] = ceps_0[i] + ceps_1[i] + ceps_2[i];
1079 features[NB_BANDS+i] = ceps_0[i] - ceps_2[i];
1080 features[NB_BANDS+NB_DELTA_CEPS+i] = ceps_0[i] - 2*ceps_1[i] + ceps_2[i];
1082 /* Spectral variability features. */
1083 if (st->memid == CEPS_MEM)
/* For every stored cepstrum take the distance to its nearest other entry;
 * the feature is based on the average of those minima. */
1086 for (int i = 0; i < CEPS_MEM; i++) {
1087 float mindist = 1e15f;
1088 for (int j = 0; j < CEPS_MEM; j++) {
1090 for (int k = 0; k < NB_BANDS; k++) {
1093 tmp = st->cepstral_mem[i][k] - st->cepstral_mem[j][k];
1098 mindist = FFMIN(mindist, dist);
1101 spec_variability += mindist;
1104 features[NB_BANDS+3*NB_DELTA_CEPS+1] = spec_variability/CEPS_MEM-2.1;
/*
 * Expand NB_BANDS per-band values into FREQ_SIZE per-bin values by linear
 * interpolation between consecutive band edges. g is zeroed first, so bins
 * at and above the last band edge stay at 0.
 */
1109 static void interp_band_gain(float *g, const float *bandE)
1111 memset(g, 0, sizeof(*g) * FREQ_SIZE);
1113 for (int i = 0; i < NB_BANDS - 1; i++) {
1114 const int band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
1116 for (int j = 0; j < band_size; j++) {
1117 float frac = (float)j / band_size;
1119 g[(eband5ms[i] << FRAME_SIZE_SHIFT) + j] = (1.f - frac) * bandE[i] + frac * bandE[i + 1];
/*
 * Add a scaled copy of the pitch-delayed spectrum P to X and restore the
 * original band energies afterwards.
 * Per band, the mixing ratio r is 1 when the pitch correlation Exp exceeds
 * the gain g; otherwise it is derived from both, clipped to [0, 1] and
 * square-rooted. r is then scaled by sqrt(Ex / Ep) so that P is brought to
 * the level of X. After mixing, every band of X is renormalised to its
 * energy before the mix (Ex). X is modified in place.
 */
1124 static void pitch_filter(AVComplexFloat *X, const AVComplexFloat *P, const float *Ex, const float *Ep,
1125 const float *Exp, const float *g)
1127 float newE[NB_BANDS];
1129 float norm[NB_BANDS];
1130 float rf[FREQ_SIZE] = {0};
1131 float normf[FREQ_SIZE]={0};
1133 for (int i = 0; i < NB_BANDS; i++) {
1134 if (Exp[i]>g[i]) r[i] = 1;
1135 else r[i] = SQUARE(Exp[i])*(1-SQUARE(g[i]))/(.001 + SQUARE(g[i])*(1-SQUARE(Exp[i])));
1136 r[i] = sqrtf(av_clipf(r[i], 0, 1));
1137 r[i] *= sqrtf(Ex[i]/(1e-8+Ep[i]));
1139 interp_band_gain(rf, r);
1140 for (int i = 0; i < FREQ_SIZE; i++) {
1141 X[i].re += rf[i]*P[i].re;
1142 X[i].im += rf[i]*P[i].im;
/* Renormalise: scale each band back to its energy before the mix. */
1144 compute_band_energy(newE, X);
1145 for (int i = 0; i < NB_BANDS; i++) {
1146 norm[i] = sqrtf(Ex[i] / (1e-8+newE[i]));
1148 interp_band_gain(normf, norm);
1149 for (int i = 0; i < FREQ_SIZE; i++) {
1150 X[i].re *= normf[i];
1151 X[i].im *= normf[i];
/* Lookup table for tansig_approx(): 201 samples of the hyperbolic tangent
 * taken every 0.04, i.e. entry i holds tanh(i / 25) for arguments 0..8. */
1155 static const float tansig_table[201] = {
1156 0.000000f, 0.039979f, 0.079830f, 0.119427f, 0.158649f,
1157 0.197375f, 0.235496f, 0.272905f, 0.309507f, 0.345214f,
1158 0.379949f, 0.413644f, 0.446244f, 0.477700f, 0.507977f,
1159 0.537050f, 0.564900f, 0.591519f, 0.616909f, 0.641077f,
1160 0.664037f, 0.685809f, 0.706419f, 0.725897f, 0.744277f,
1161 0.761594f, 0.777888f, 0.793199f, 0.807569f, 0.821040f,
1162 0.833655f, 0.845456f, 0.856485f, 0.866784f, 0.876393f,
1163 0.885352f, 0.893698f, 0.901468f, 0.908698f, 0.915420f,
1164 0.921669f, 0.927473f, 0.932862f, 0.937863f, 0.942503f,
1165 0.946806f, 0.950795f, 0.954492f, 0.957917f, 0.961090f,
1166 0.964028f, 0.966747f, 0.969265f, 0.971594f, 0.973749f,
1167 0.975743f, 0.977587f, 0.979293f, 0.980869f, 0.982327f,
1168 0.983675f, 0.984921f, 0.986072f, 0.987136f, 0.988119f,
1169 0.989027f, 0.989867f, 0.990642f, 0.991359f, 0.992020f,
1170 0.992631f, 0.993196f, 0.993718f, 0.994199f, 0.994644f,
1171 0.995055f, 0.995434f, 0.995784f, 0.996108f, 0.996407f,
1172 0.996682f, 0.996937f, 0.997172f, 0.997389f, 0.997590f,
1173 0.997775f, 0.997946f, 0.998104f, 0.998249f, 0.998384f,
1174 0.998508f, 0.998623f, 0.998728f, 0.998826f, 0.998916f,
1175 0.999000f, 0.999076f, 0.999147f, 0.999213f, 0.999273f,
1176 0.999329f, 0.999381f, 0.999428f, 0.999472f, 0.999513f,
1177 0.999550f, 0.999585f, 0.999617f, 0.999646f, 0.999673f,
1178 0.999699f, 0.999722f, 0.999743f, 0.999763f, 0.999781f,
1179 0.999798f, 0.999813f, 0.999828f, 0.999841f, 0.999853f,
1180 0.999865f, 0.999875f, 0.999885f, 0.999893f, 0.999902f,
1181 0.999909f, 0.999916f, 0.999923f, 0.999929f, 0.999934f,
1182 0.999939f, 0.999944f, 0.999948f, 0.999952f, 0.999956f,
1183 0.999959f, 0.999962f, 0.999965f, 0.999968f, 0.999970f,
1184 0.999973f, 0.999975f, 0.999977f, 0.999978f, 0.999980f,
1185 0.999982f, 0.999983f, 0.999984f, 0.999986f, 0.999987f,
1186 0.999988f, 0.999989f, 0.999990f, 0.999990f, 0.999991f,
1187 0.999992f, 0.999992f, 0.999993f, 0.999994f, 0.999994f,
1188 0.999994f, 0.999995f, 0.999995f, 0.999996f, 0.999996f,
1189 0.999996f, 0.999997f, 0.999997f, 0.999997f, 0.999997f,
1190 0.999997f, 0.999998f, 0.999998f, 0.999998f, 0.999998f,
1191 0.999998f, 0.999998f, 0.999999f, 0.999999f, 0.999999f,
1192 0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
1193 0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
1194 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1195 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
/*
 * Table-based approximation of tanh(x): the nearest table entry
 * (index round(25 * x)) is corrected using the distance of x from that
 * sample point and the slope term dy.
 * NOTE(review): the range checks, the sign handling and the computation of
 * dy are not visible here; presumably inputs beyond the table range
 * saturate to +/-1 and dy = 1 - y * y - confirm.
 */
1199 static inline float tansig_approx(float x)
1205 /* Tests are reversed to catch NaNs */
1210 /* Another check in case of -ffast-math */
/* Nearest sample index; floor(.5f + v) rounds the non-negative v. */
1219 i = (int)floor(.5f+25*x);
1221 y = tansig_table[i];
1223 y = y + x*dy*(1 - y*x);
/* Logistic sigmoid built on the tanh approximation, using the identity
 * sigmoid(x) = 0.5 + 0.5 * tanh(x / 2). */
1227 static inline float sigmoid_approx(float x)
1229 return .5f + .5f*tansig_approx(.5f*x);
/*
 * Evaluate a fully connected layer: for each of the N neurons,
 * output[i] = activation(WEIGHTS_SCALE * (bias[i] + sum_j w[j][i] * input[j])).
 * Weights are stored input-major, hence the stride of N between the weights
 * of one neuron. input must hold nb_inputs values and output nb_neurons.
 */
1232 static void compute_dense(const DenseLayer *layer, float *output, const float *input)
1234 const int N = layer->nb_neurons, M = layer->nb_inputs, stride = N;
1236 for (int i = 0; i < N; i++) {
1237 /* Weighted sum of the inputs for neuron i, starting from its bias. */
1238 float sum = layer->bias[i];
1240 for (int j = 0; j < M; j++)
1241 sum += layer->input_weights[j * stride + i] * input[j];
1243 output[i] = WEIGHTS_SCALE * sum;
/* Apply the layer's activation function in place. */
1246 if (layer->activation == ACTIVATION_SIGMOID) {
1247 for (int i = 0; i < N; i++)
1248 output[i] = sigmoid_approx(output[i]);
1249 } else if (layer->activation == ACTIVATION_TANH) {
1250 for (int i = 0; i < N; i++)
1251 output[i] = tansig_approx(output[i]);
1252 } else if (layer->activation == ACTIVATION_RELU) {
1253 for (int i = 0; i < N; i++)
1254 output[i] = FFMAX(0, output[i]);
/*
 * Advance a GRU layer by one step, updating state in place.
 *
 * For every neuron the update gate z and the reset gate r are sigmoids of a
 * weighted sum of the input and of the current state. The candidate output
 * uses the input and the state gated element-wise by r, followed by the
 * layer's activation. The new state is z * state + (1 - z) * candidate.
 * All three passes read the old state; the result is built in h and copied
 * back at the end.
 *
 * The weights of one neuron occupy 3 * AM (input) or 3 * AN (recurrent)
 * floats, one padded run per stage, AM and AN being the sizes rounded up to
 * a multiple of 4.
 *
 * NOTE(review): the scalar products read AM floats of input and AN floats
 * of state. The padding weights are zero, but padded input elements may be
 * uninitialised stack data in the callers' buffers, and 0 * NaN/Inf is not
 * 0 - confirm that the callers clear the padding.
 */
1260 static void compute_gru(AudioRNNContext *s, const GRULayer *gru, float *state, const float *input)
1262 LOCAL_ALIGNED_32(float, z, [MAX_NEURONS]);
1263 LOCAL_ALIGNED_32(float, r, [MAX_NEURONS]);
1264 LOCAL_ALIGNED_32(float, h, [MAX_NEURONS]);
1265 const int M = gru->nb_inputs;
1266 const int N = gru->nb_neurons;
1267 const int AN = FFALIGN(N, 4);
1268 const int AM = FFALIGN(M, 4);
1269 const int stride = 3 * AN, istride = 3 * AM;
1271 for (int i = 0; i < N; i++) {
1272 /* Compute update gate. */
1273 float sum = gru->bias[i];
1275 sum += s->fdsp->scalarproduct_float(gru->input_weights + i * istride, input, AM);
1276 sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + i * stride, state, AN);
1277 z[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1280 for (int i = 0; i < N; i++) {
1281 /* Compute reset gate. */
1282 float sum = gru->bias[N + i];
1284 sum += s->fdsp->scalarproduct_float(gru->input_weights + AM + i * istride, input, AM);
1285 sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + AN + i * stride, state, AN);
1286 r[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1289 for (int i = 0; i < N; i++) {
1290 /* Compute output. */
1291 float sum = gru->bias[2 * N + i];
1293 sum += s->fdsp->scalarproduct_float(gru->input_weights + 2 * AM + i * istride, input, AM);
/* The state is gated by r element-wise, so this sum is written out by hand. */
1294 for (int j = 0; j < N; j++)
1295 sum += gru->recurrent_weights[2 * AN + i * stride + j] * state[j] * r[j];
1297 if (gru->activation == ACTIVATION_SIGMOID)
1298 sum = sigmoid_approx(WEIGHTS_SCALE * sum);
1299 else if (gru->activation == ACTIVATION_TANH)
1300 sum = tansig_approx(WEIGHTS_SCALE * sum);
1301 else if (gru->activation == ACTIVATION_RELU)
1302 sum = FFMAX(0, WEIGHTS_SCALE * sum);
/* Blend the old state and the candidate with the update gate. */
1305 h[i] = z[i] * state[i] + (1.f - z[i]) * sum;
1308 RNN_COPY(state, h, N);
/* Number of feature values appended to the inputs of the noise and denoise
 * GRU layers.
 * NOTE(review): this looks like it must equal NB_FEATURES - confirm and
 * derive it from that macro. */
1311 #define INPUT_SIZE 42
/*
 * Run the whole network on one feature vector.
 *   input  -> input_dense -> vad_gru -> vad_output     => *vad
 *   [dense_out, vad state, input]   -> noise_gru
 *   [vad state, noise state, input] -> denoise_gru -> denoise_output => gains
 * The GRU states in rnn are updated in place. The concatenated inputs are
 * assembled in stack buffers of 3 * MAX_NEURONS floats.
 */
1313 static void compute_rnn(AudioRNNContext *s, RNNState *rnn, float *gains, float *vad, const float *input)
1315 LOCAL_ALIGNED_32(float, dense_out, [MAX_NEURONS]);
1316 LOCAL_ALIGNED_32(float, noise_input, [MAX_NEURONS * 3]);
1317 LOCAL_ALIGNED_32(float, denoise_input, [MAX_NEURONS * 3]);
1319 compute_dense(rnn->model->input_dense, dense_out, input);
1320 compute_gru(s, rnn->model->vad_gru, rnn->vad_gru_state, dense_out);
1321 compute_dense(rnn->model->vad_output, vad, rnn->vad_gru_state);
/* Input of the noise GRU: dense output, VAD state, raw features. */
1323 memcpy(noise_input, dense_out, rnn->model->input_dense_size * sizeof(float));
1324 memcpy(noise_input + rnn->model->input_dense_size,
1325 rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
1326 memcpy(noise_input + rnn->model->input_dense_size + rnn->model->vad_gru_size,
1327 input, INPUT_SIZE * sizeof(float));
1329 compute_gru(s, rnn->model->noise_gru, rnn->noise_gru_state, noise_input);
/* Input of the denoise GRU: VAD state, noise state, raw features. */
1331 memcpy(denoise_input, rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
1332 memcpy(denoise_input + rnn->model->vad_gru_size,
1333 rnn->noise_gru_state, rnn->model->noise_gru_size * sizeof(float));
1334 memcpy(denoise_input + rnn->model->vad_gru_size + rnn->model->noise_gru_size,
1335 input, INPUT_SIZE * sizeof(float));
1337 compute_gru(s, rnn->model->denoise_gru, rnn->denoise_gru_state, denoise_input);
1338 compute_dense(rnn->model->denoise_output, gains, rnn->denoise_gru_state);
/*
 * Denoise one FRAME_SIZE frame of a single channel.
 * The input is first passed through a fixed second-order filter (a_hp/b_hp)
 * whose state is kept in st->mem_hp_x. Features are then extracted, the
 * network produces per-band gains g and a voice activity value, the pitch
 * filter is applied to X, and each gain is limited from below by a fraction
 * (alpha) of the gain used for the previous frame. The gains are
 * interpolated to per-bin values gf before the frame is synthesised into
 * out.
 * NOTE(review): the handling of the silence flag, the use of gf inside the
 * last loop and the return value are not visible here; presumably the
 * network is skipped on silence and the voice activity value is returned -
 * confirm.
 */
1341 static float rnnoise_channel(AudioRNNContext *s, DenoiseState *st, float *out, const float *in)
1343 AVComplexFloat X[FREQ_SIZE];
1344 AVComplexFloat P[WINDOW_SIZE];
1345 float x[FRAME_SIZE];
1346 float Ex[NB_BANDS], Ep[NB_BANDS];
1347 LOCAL_ALIGNED_32(float, Exp, [NB_BANDS]);
1348 float features[NB_FEATURES];
1350 float gf[FREQ_SIZE];
1352 static const float a_hp[2] = {-1.99599, 0.99600};
1353 static const float b_hp[2] = {-2, 1};
1356 biquad(x, st->mem_hp_x, in, b_hp, a_hp, FRAME_SIZE);
1357 silence = compute_frame_features(s, st, X, P, Ex, Ep, Exp, features, x);
1360 compute_rnn(s, &st->rnn, g, &vad_prob, features);
1361 pitch_filter(X, P, Ex, Ep, Exp, g);
1362 for (int i = 0; i < NB_BANDS; i++) {
/* A gain may not drop below alpha times its previous value. */
1365 g[i] = FFMAX(g[i], alpha * st->lastg[i]);
1366 st->lastg[i] = g[i];
1369 interp_band_gain(gf, g);
1371 for (int i = 0; i < FREQ_SIZE; i++) {
1377 frame_synthesis(s, st, out, X);
/* Job argument for rnnoise_channels(): carries the input and output frames
 * (members in and out). */
1382 typedef struct ThreadData {
/*
 * Slice-threading worker: job jobnr of nb_jobs denoises the channels in
 * [channels * jobnr / nb_jobs, channels * (jobnr + 1) / nb_jobs), so the
 * jobs together cover every channel exactly once. Each channel uses its own
 * DenoiseState, so jobs do not share mutable per-channel state.
 */
1386 static int rnnoise_channels(AVFilterContext *ctx, void *arg, int jobnr, int nb_jobs)
1388 AudioRNNContext *s = ctx->priv;
1389 ThreadData *td = arg;
1390 AVFrame *in = td->in;
1391 AVFrame *out = td->out;
1392 const int start = (out->channels * jobnr) / nb_jobs;
1393 const int end = (out->channels * (jobnr+1)) / nb_jobs;
1395 for (int ch = start; ch < end; ch++) {
1396 rnnoise_channel(s, &s->st[ch],
1397 (float *)out->extended_data[ch],
1398 (const float *)in->extended_data[ch]);
/*
 * Process one input frame: allocate a FRAME_SIZE output frame, denoise all
 * channels through the filter's execute callback (at most one job per
 * channel, limited by the configured thread count) and forward the result
 * to the output link. Returns AVERROR(ENOMEM) when the output frame cannot
 * be allocated, otherwise the result of ff_filter_frame().
 * NOTE(review): the release of the input frame is not visible here;
 * presumably it is freed on both paths - confirm.
 */
1404 static int filter_frame(AVFilterLink *inlink, AVFrame *in)
1406 AVFilterContext *ctx = inlink->dst;
1407 AVFilterLink *outlink = ctx->outputs[0];
1408 AVFrame *out = NULL;
1411 out = ff_get_audio_buffer(outlink, FRAME_SIZE);
1414 return AVERROR(ENOMEM);
1418 td.in = in; td.out = out;
1419 ctx->internal->execute(ctx, rnnoise_channels, &td, NULL, FFMIN(outlink->channels,
1420 ff_filter_get_nb_threads(ctx)));
1423 return ff_filter_frame(outlink, out);
/*
 * Activation callback: forward status from the output back to the input,
 * then try to take exactly FRAME_SIZE samples from the input queue (minimum
 * and maximum are both FRAME_SIZE) and filter them. When no full frame is
 * available, forward the input status and the output's frame request, and
 * report FFERROR_NOT_READY if neither applies.
 */
1426 static int activate(AVFilterContext *ctx)
1428 AVFilterLink *inlink = ctx->inputs[0];
1429 AVFilterLink *outlink = ctx->outputs[0];
1433 FF_FILTER_FORWARD_STATUS_BACK(outlink, inlink);
1435 ret = ff_inlink_consume_samples(inlink, FRAME_SIZE, FRAME_SIZE, &in);
1440 return filter_frame(inlink, in);
1442 FF_FILTER_FORWARD_STATUS(inlink, outlink);
1443 FF_FILTER_FORWARD_WANTED(outlink, inlink);
1445 return FFERROR_NOT_READY;
/*
 * Filter initialisation: allocate the float DSP context, open the file named
 * by the "model" option and load the network from it, then precompute the
 * analysis/synthesis window and the cosine table used by dct().
 * Returns AVERROR(ENOMEM) when the DSP context cannot be allocated and
 * AVERROR(EINVAL) for the model-related failures.
 * NOTE(review): the conditions guarding the EINVAL returns and the closing
 * of the file are not visible here; presumably they cover a missing model
 * name, a failed open and a failed parse - confirm.
 */
1448 static av_cold int init(AVFilterContext *ctx)
1450 AudioRNNContext *s = ctx->priv;
1453 s->fdsp = avpriv_float_dsp_alloc(0);
1455 return AVERROR(ENOMEM);
1458 return AVERROR(EINVAL);
1459 f = av_fopen_utf8(s->model_name, "r");
1461 return AVERROR(EINVAL);
1463 s->model = rnnoise_model_from_file(f);
1466 return AVERROR(EINVAL);
/* Window: sin(pi/2 * sin^2(pi/2 * (i + 0.5) / FRAME_SIZE)) on the first
 * half, mirrored onto the second half. */
1468 for (int i = 0; i < FRAME_SIZE; i++) {
1469 s->window[i] = sin(.5*M_PI*sin(.5*M_PI*(i+.5)/FRAME_SIZE) * sin(.5*M_PI*(i+.5)/FRAME_SIZE));
1470 s->window[WINDOW_SIZE - 1 - i] = s->window[i];
/* Cosine table, stored with the output index j as the row so that dct()
 * can use one contiguous row per coefficient. */
1473 for (int i = 0; i < NB_BANDS; i++) {
1474 for (int j = 0; j < NB_BANDS; j++) {
1475 s->dct_table[j][i] = cosf((i + .5f) * j * M_PI / NB_BANDS);
1477 s->dct_table[j][i] *= sqrtf(.5);
/*
 * Release everything owned by the filter: the model, and for every channel
 * the three GRU state vectors and both FFT contexts.
 * NOTE(review): the release of s->fdsp and of the s->st array itself is not
 * visible here - confirm both are freed.
 */
1484 static av_cold void uninit(AVFilterContext *ctx)
1486 AudioRNNContext *s = ctx->priv;
1489 rnnoise_model_free(s->model);
1493 for (int ch = 0; ch < s->channels; ch++) {
1494 av_freep(&s->st[ch].rnn.vad_gru_state);
1495 av_freep(&s->st[ch].rnn.noise_gru_state);
1496 av_freep(&s->st[ch].rnn.denoise_gru_state);
1497 av_tx_uninit(&s->st[ch].tx);
1498 av_tx_uninit(&s->st[ch].txi);
/* Input pad: audio, configured through config_input(). */
1504 static const AVFilterPad inputs[] = {
1507 .type = AVMEDIA_TYPE_AUDIO,
1508 .config_props = config_input,
/* Output pad: audio. */
1513 static const AVFilterPad outputs[] = {
1516 .type = AVMEDIA_TYPE_AUDIO,
/* Option table helpers: OFFSET gives the position of a field inside the
 * private context, AF marks an option as an audio filtering parameter. */
1521 #define OFFSET(x) offsetof(AudioRNNContext, x)
1522 #define AF AV_OPT_FLAG_AUDIO_PARAM|AV_OPT_FLAG_FILTERING_PARAM
/* User options: "model" and its short alias "m" both set model_name, the
 * path of the model file read by init(). There is no default. */
1524 static const AVOption arnndn_options[] = {
1525 { "model", "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1526 { "m", "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1530 AVFILTER_DEFINE_CLASS(arnndn);
/* Filter definition: ties the callbacks above together and declares that the
 * filter supports slice threading (used by filter_frame()). */
1532 AVFilter ff_af_arnndn = {
1534 .description = NULL_IF_CONFIG_SMALL("Reduce noise from speech using Recurrent Neural Networks."),
1535 .query_formats = query_formats,
1536 .priv_size = sizeof(AudioRNNContext),
1537 .priv_class = &arnndn_class,
1538 .activate = activate,
1543 .flags = AVFILTER_FLAG_SLICE_THREADS,