2 * Copyright (c) 2018 Gregor Richards
3 * Copyright (c) 2017 Mozilla
4 * Copyright (c) 2005-2009 Xiph.Org Foundation
5 * Copyright (c) 2007-2008 CSIRO
6 * Copyright (c) 2008-2011 Octasic Inc.
7 * Copyright (c) Jean-Marc Valin
8 * Copyright (c) 2019 Paul B Mahol
10 * Redistribution and use in source and binary forms, with or without
11 * modification, are permitted provided that the following conditions
14 * - Redistributions of source code must retain the above copyright
15 * notice, this list of conditions and the following disclaimer.
17 * - Redistributions in binary form must reproduce the above copyright
18 * notice, this list of conditions and the following disclaimer in the
19 * documentation and/or other materials provided with the distribution.
21 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR
25 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
36 #include "libavutil/avassert.h"
37 #include "libavutil/avstring.h"
38 #include "libavutil/float_dsp.h"
39 #include "libavutil/opt.h"
40 #include "libavutil/tx.h"
/* Frame geometry: FRAME_SIZE is 120 << FRAME_SIZE_SHIFT (480) samples, the
 * analysis window spans two frames and the spectrum keeps FRAME_SIZE + 1 bins. */
46 #define FRAME_SIZE_SHIFT 2
47 #define FRAME_SIZE (120<<FRAME_SIZE_SHIFT)
48 #define WINDOW_SIZE (2*FRAME_SIZE)
49 #define FREQ_SIZE (FRAME_SIZE + 1)
/* Pitch lag bounds handed to remove_doubling() and the sizes of the pitch
 * history buffers kept in DenoiseState. */
51 #define PITCH_MIN_PERIOD 60
52 #define PITCH_MAX_PERIOD 768
53 #define PITCH_FRAME_SIZE 960
54 #define PITCH_BUF_SIZE (PITCH_MAX_PERIOD+PITCH_FRAME_SIZE)
56 #define SQUARE(x) ((x)*(x))
/* Feature vector layout: NB_BANDS values, three groups of NB_DELTA_CEPS values
 * and two extra scalars.
 * NOTE(review): NB_BANDS and CEPS_MEM are used in this file but their
 * definitions are not among the lines visible here. */
61 #define NB_DELTA_CEPS 6
63 #define NB_FEATURES (NB_BANDS+3*NB_DELTA_CEPS+2)
/* Factor applied to every accumulated weight sum before the activation
 * function in compute_dense() and compute_gru(). */
65 #define WEIGHTS_SCALE (1.f/256)
/* Size of the per-layer scratch arrays in compute_gru() and compute_rnn(). */
67 #define MAX_NEURONS 128
/* Activation identifiers tested by compute_dense() and compute_gru(). */
69 #define ACTIVATION_TANH 0
70 #define ACTIVATION_SIGMOID 1
71 #define ACTIVATION_RELU 2
/* Fully connected layer description.
 * compute_dense() reads input_weights as [input * nb_neurons + neuron].
 * NOTE(review): other code in this file also accesses bias, nb_inputs,
 * nb_neurons and activation members whose declarations are not visible here. */
75 typedef struct DenseLayer {
77 const float *input_weights;
/* Gated recurrent layer description.
 * Both weight arrays are filled by INPUT_ARRAY3 in rnnoise_model_from_file()
 * with dimensions padded to multiples of 4, and are read by compute_gru().
 * NOTE(review): bias, nb_inputs, nb_neurons and activation members are used
 * elsewhere in this file but their declarations are not visible here. */
83 typedef struct GRULayer {
85 const float *input_weights;
86 const float *recurrent_weights;
/* Complete network: one pointer per layer plus a matching <layer>_size field.
 * rnnoise_model_from_file() sets each <layer>_size to that layer's nb_neurons.
 * NOTE(review): only some of the size fields are visible here; input_dense_size,
 * vad_gru_size and noise_gru_size are referenced by compute_rnn()/config_input(). */
92 typedef struct RNNModel {
94 const DenseLayer *input_dense;
97 const GRULayer *vad_gru;
100 const GRULayer *noise_gru;
102 int denoise_gru_size;
103 const GRULayer *denoise_gru;
105 int denoise_output_size;
106 const DenseLayer *denoise_output;
109 const DenseLayer *vad_output;
/* Recurrent state of the three GRU layers for one channel.
 * The vectors are allocated in config_input() (length rounded up to a multiple
 * of 16 floats) and released in uninit().
 * NOTE(review): a `model` member is assigned in config_input() but its
 * declaration is not visible here. */
112 typedef struct RNNState {
113 float *vad_gru_state;
114 float *noise_gru_state;
115 float *denoise_gru_state;
/* Per-channel processing state; config_input() allocates one per channel.
 * NOTE(review): members such as memid, mem_hp_x, last_period, last_gain and rnn
 * are used elsewhere in this file but are not visible in these lines. */
119 typedef struct DenoiseState {
/* Previous input frame, used as the first half of the next analysis window. */
120 float analysis_mem[FRAME_SIZE];
/* History of cepstral vectors indexed by memid in compute_frame_features(). */
121 float cepstral_mem[CEPS_MEM][NB_BANDS];
/* Second half of the last synthesis window, overlap-added into the next frame. */
123 DECLARE_ALIGNED(32, float, synthesis_mem)[FRAME_SIZE];
124 float pitch_buf[PITCH_BUF_SIZE];
125 float pitch_enh_buf[PITCH_BUF_SIZE];
/* Band gains of the previous frame, used for gain smoothing in rnnoise_channel(). */
129 float lastg[NB_BANDS];
/* Forward (tx) and inverse (txi) FFT contexts and their entry points. */
131 AVTXContext *tx, *txi;
132 av_tx_fn tx_fn, txi_fn;
/* Private context of the filter (see .priv_size in ff_af_arnndn).
 * NOTE(review): model_name, model, channels and st members are used elsewhere
 * in this file but are not visible in these lines. */
135 typedef struct AudioRNNContext {
136 const AVClass *class;
/* Analysis/synthesis window, filled symmetrically in init(). */
143 DECLARE_ALIGNED(32, float, window)[WINDOW_SIZE];
/* NB_BANDS x NB_BANDS cosine table, filled in init() and used by dct(). */
144 float dct_table[NB_BANDS*NB_BANDS];
148 AVFloatDSPContext *fdsp;
/* Activation codes as read from the model file; INPUT_ACTIVATION in
 * rnnoise_model_from_file() maps them to the ACTIVATION_* identifiers. */
151 #define F_ACTIVATION_TANH 0
152 #define F_ACTIVATION_SIGMOID 1
153 #define F_ACTIVATION_RELU 2
/*
 * Release a model built by rnnoise_model_from_file().
 *
 * FREE_DENSE frees a dense layer's input_weights, bias and the layer itself;
 * FREE_GRU additionally frees recurrent_weights. The const qualifiers on the
 * model pointers are cast away for av_free().
 *
 * NOTE(review): FREE_MAYBE is not used in any visible line, and it calls
 * free() rather than av_free() behind a redundant NULL check; it looks like a
 * candidate for removal.
 * NOTE(review): the FREE_GRU calls for vad_gru and noise_gru and any NULL
 * guards inside the macros are not visible here; confirm they exist.
 */
155 static void rnnoise_model_free(RNNModel *model)
157 #define FREE_MAYBE(ptr) do { if (ptr) free(ptr); } while (0)
158 #define FREE_DENSE(name) do { \
160 av_free((void *) model->name->input_weights); \
161 av_free((void *) model->name->bias); \
162 av_free((void *) model->name); \
165 #define FREE_GRU(name) do { \
167 av_free((void *) model->name->input_weights); \
168 av_free((void *) model->name->recurrent_weights); \
169 av_free((void *) model->name->bias); \
170 av_free((void *) model->name); \
176 FREE_DENSE(input_dense);
179 FREE_GRU(denoise_gru);
180 FREE_DENSE(denoise_output);
181 FREE_DENSE(vad_output);
/*
 * Parse a text model from the open stream f and build an RNNModel.
 *
 * The stream must begin with "rnnoise-nu model file version 1". Each layer is
 * then read in this order: input_dense, noise_gru, denoise_gru, denoise_output,
 * vad_output (NOTE(review): vad_gru is presumably read between input_dense and
 * noise_gru; that line is not visible here).
 *
 * Helper macros:
 *  - ALLOC_LAYER      allocates one zeroed layer structure.
 *  - INPUT_VAL        reads one integer and rejects values outside [0, 128].
 *  - INPUT_ACTIVATION maps a file activation code to ACTIVATION_*.
 *  - INPUT_ARRAY      allocates len floats and reads len integers into them.
 *  - INPUT_ARRAY3     allocates FFALIGN(len0,4) * FFALIGN(len1,4) * len2 floats
 *                     and stores each value read at
 *                     [j * len2 * FFALIGN(len0,4) + i * FFALIGN(len0,4) + k],
 *                     so the padded len0 dimension is contiguous.
 *  - INPUT_DENSE / INPUT_GRU read a whole layer and copy its nb_neurons into
 *                     the matching ret-><layer>_size field.
 *
 * Every visible failure path calls rnnoise_model_free(ret).
 * NOTE(review): the return statements are not visible; callers treat a NULL
 * result as failure (see init()) - confirm.
 */
185 static RNNModel *rnnoise_model_from_file(FILE *f)
188 DenseLayer *input_dense;
191 GRULayer *denoise_gru;
192 DenseLayer *denoise_output;
193 DenseLayer *vad_output;
/* Only format version 1 is accepted. */
196 if (fscanf(f, "rnnoise-nu model file version %d\n", &in) != 1 || in != 1)
199 ret = av_calloc(1, sizeof(RNNModel));
203 #define ALLOC_LAYER(type, name) \
204 name = av_calloc(1, sizeof(type)); \
206 rnnoise_model_free(ret); \
211 ALLOC_LAYER(DenseLayer, input_dense);
212 ALLOC_LAYER(GRULayer, vad_gru);
213 ALLOC_LAYER(GRULayer, noise_gru);
214 ALLOC_LAYER(GRULayer, denoise_gru);
215 ALLOC_LAYER(DenseLayer, denoise_output);
216 ALLOC_LAYER(DenseLayer, vad_output);
218 #define INPUT_VAL(name) do { \
219 if (fscanf(f, "%d", &in) != 1 || in < 0 || in > 128) { \
220 rnnoise_model_free(ret); \
226 #define INPUT_ACTIVATION(name) do { \
228 INPUT_VAL(activation); \
229 switch (activation) { \
230 case F_ACTIVATION_SIGMOID: \
231 name = ACTIVATION_SIGMOID; \
233 case F_ACTIVATION_RELU: \
234 name = ACTIVATION_RELU; \
237 name = ACTIVATION_TANH; \
241 #define INPUT_ARRAY(name, len) do { \
242 float *values = av_calloc((len), sizeof(float)); \
244 rnnoise_model_free(ret); \
248 for (int i = 0; i < (len); i++) { \
249 if (fscanf(f, "%d", &in) != 1) { \
250 rnnoise_model_free(ret); \
257 #define INPUT_ARRAY3(name, len0, len1, len2) do { \
258 float *values = av_calloc(FFALIGN((len0), 4) * FFALIGN((len1), 4) * (len2), sizeof(float)); \
260 rnnoise_model_free(ret); \
264 for (int k = 0; k < (len0); k++) { \
265 for (int i = 0; i < (len2); i++) { \
266 for (int j = 0; j < (len1); j++) { \
267 if (fscanf(f, "%d", &in) != 1) { \
268 rnnoise_model_free(ret); \
271 values[j * (len2) * FFALIGN((len0), 4) + i * FFALIGN((len0), 4) + k] = in; \
277 #define INPUT_DENSE(name) do { \
278 INPUT_VAL(name->nb_inputs); \
279 INPUT_VAL(name->nb_neurons); \
280 ret->name ## _size = name->nb_neurons; \
281 INPUT_ACTIVATION(name->activation); \
282 INPUT_ARRAY(name->input_weights, name->nb_inputs * name->nb_neurons); \
283 INPUT_ARRAY(name->bias, name->nb_neurons); \
286 #define INPUT_GRU(name) do { \
287 INPUT_VAL(name->nb_inputs); \
288 INPUT_VAL(name->nb_neurons); \
289 ret->name ## _size = name->nb_neurons; \
290 INPUT_ACTIVATION(name->activation); \
291 INPUT_ARRAY3(name->input_weights, name->nb_inputs, name->nb_neurons, 3); \
292 INPUT_ARRAY3(name->recurrent_weights, name->nb_neurons, name->nb_neurons, 3); \
293 INPUT_ARRAY(name->bias, name->nb_neurons * 3); \
296 INPUT_DENSE(input_dense);
298 INPUT_GRU(noise_gru);
299 INPUT_GRU(denoise_gru);
300 INPUT_DENSE(denoise_output);
301 INPUT_DENSE(vad_output);
/* compute_rnn() writes the VAD output into a single float, so the VAD output
 * layer must have exactly one neuron. */
303 if (vad_output->nb_neurons != 1) {
304 rnnoise_model_free(ret);
/*
 * Negotiate the supported formats: the sample formats listed in sample_fmts
 * (entries not visible here), any channel count, and a single sample rate of
 * 48000 Hz.
 *
 * Returns the result of the last ff_set_common_*() call; the visible failure
 * paths return AVERROR(ENOMEM).
 */
311 static int query_formats(AVFilterContext *ctx)
313 AVFilterFormats *formats = NULL;
314 AVFilterChannelLayouts *layouts = NULL;
315 static const enum AVSampleFormat sample_fmts[] = {
/* -1 terminates the sample rate list. */
319 int ret, sample_rates[] = { 48000, -1 };
321 formats = ff_make_format_list(sample_fmts);
323 return AVERROR(ENOMEM);
324 ret = ff_set_common_formats(ctx, formats);
328 layouts = ff_all_channel_counts();
330 return AVERROR(ENOMEM);
332 ret = ff_set_common_channel_layouts(ctx, layouts);
336 formats = ff_make_format_list(sample_rates);
338 return AVERROR(ENOMEM);
339 return ff_set_common_samplerates(ctx, formats);
/*
 * Input link configuration: allocate one DenoiseState per channel, its three
 * GRU state vectors and a forward/inverse FFT pair of WINDOW_SIZE points.
 *
 * Returns AVERROR(ENOMEM) when an allocation fails.
 * NOTE(review): the handling of the av_tx_init() results and the final return
 * are not visible here. Partially allocated state is released by uninit().
 */
342 static int config_input(AVFilterLink *inlink)
344 AVFilterContext *ctx = inlink->dst;
345 AudioRNNContext *s = ctx->priv;
348 s->channels = inlink->channels;
350 s->st = av_calloc(s->channels, sizeof(DenoiseState));
352 return AVERROR(ENOMEM);
354 for (int i = 0; i < s->channels; i++) {
355 DenoiseState *st = &s->st[i];
/* All channels share the same read-only model. */
357 st->rnn.model = s->model;
/* State lengths are rounded up to a multiple of 16 floats; compute_gru()
 * reads FFALIGN(N, 4) state elements, so the zeroed padding is part of the
 * dot products. */
358 st->rnn.vad_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->vad_gru_size, 16));
359 st->rnn.noise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->noise_gru_size, 16));
360 st->rnn.denoise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->denoise_gru_size, 16));
361 if (!st->rnn.vad_gru_state ||
362 !st->rnn.noise_gru_state ||
363 !st->rnn.denoise_gru_state)
364 return AVERROR(ENOMEM);
/* Forward transform (inv = 0). */
366 ret = av_tx_init(&st->tx, &st->tx_fn, AV_TX_FLOAT_FFT, 0, WINDOW_SIZE, NULL, 0);
/* Inverse transform (inv = 1). */
370 ret = av_tx_init(&st->txi, &st->txi_fn, AV_TX_FLOAT_FFT, 1, WINDOW_SIZE, NULL, 0);
/*
 * Run a two-coefficient recursive filter over N samples of x.
 *
 * b and a hold two feed-forward and two feedback coefficients; mem carries the
 * two state values between calls.
 * NOTE(review): the lines computing xi/yi and storing into y are not visible
 * here; presumably yi = x[i] + mem[0] and y[i] = yi - confirm.
 */
378 static void biquad(float *y, float mem[2], const float *x,
379 const float *b, const float *a, int N)
381 for (int i = 0; i < N; i++) {
/* State update for the next sample. */
386 mem[0] = mem[1] + (b[0]*xi - a[0]*yi);
387 mem[1] = (b[1]*xi - a[1]*yi);
/* Element-count wrappers around memmove/memset/memcpy; n is a number of
 * elements of *dst, not bytes. The "0*((dst)-(src))" term contributes nothing
 * to the size but makes the compiler check that dst and src have compatible
 * pointer types. RNN_MOVE allows overlapping regions, RNN_COPY does not. */
392 #define RNN_MOVE(dst, src, n) (memmove((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
393 #define RNN_CLEAR(dst, n) (memset((dst), 0, (n)*sizeof(*(dst))))
394 #define RNN_COPY(dst, src, n) (memcpy((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
/*
 * Forward FFT of a WINDOW_SIZE block; only the first FREQ_SIZE output bins are
 * copied to out.
 * NOTE(review): the loop body that fills x from in is not visible here;
 * presumably it stores in[i] as the real part with a zero imaginary part.
 */
396 static void forward_transform(DenoiseState *st, AVComplexFloat *out, const float *in)
398 AVComplexFloat x[WINDOW_SIZE];
399 AVComplexFloat y[WINDOW_SIZE];
401 for (int i = 0; i < WINDOW_SIZE; i++) {
406 st->tx_fn(st->tx, y, x, sizeof(float));
408 RNN_COPY(out, y, FREQ_SIZE);
/*
 * Inverse FFT from FREQ_SIZE spectrum bins to WINDOW_SIZE real samples.
 * The result is normalised by 1 / WINDOW_SIZE.
 */
411 static void inverse_transform(DenoiseState *st, float *out, const AVComplexFloat *in)
413 AVComplexFloat x[WINDOW_SIZE];
414 AVComplexFloat y[WINDOW_SIZE];
416 RNN_COPY(x, in, FREQ_SIZE);
/* Rebuild the upper bins as complex conjugates of the mirrored lower bins,
 * which is the symmetry of the spectrum of a real signal. */
418 for (int i = FREQ_SIZE; i < WINDOW_SIZE; i++) {
419 x[i].re = x[WINDOW_SIZE - i].re;
420 x[i].im = -x[WINDOW_SIZE - i].im;
423 st->txi_fn(st->txi, y, x, sizeof(float));
425 for (int i = 0; i < WINDOW_SIZE; i++)
426 out[i] = y[i].re / WINDOW_SIZE;
/* Band edge table with 22 entries; the comment row below gives the frequency
 * of each edge. Users shift each entry left by FRAME_SIZE_SHIFT to obtain a
 * spectrum bin index. */
429 static const uint8_t eband5ms[] = {
430 /*0 200 400 600 800 1k 1.2 1.4 1.6 2k 2.4 2.8 3.2 4k 4.8 5.6 6.8 8k 9.6 12k 15.6 20k*/
431 0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 34, 40, 48, 60, 78, 100
/*
 * Accumulate the squared magnitude of spectrum X into NB_BANDS band sums.
 * Each bin between two band edges is shared between band i and band i + 1 with
 * weights (1 - frac) and frac, where frac grows linearly across the band.
 * NOTE(review): the final loop presumably copies sum[] into bandE, and the
 * first band is presumably doubled like the last one; those lines are not
 * visible here.
 */
434 static void compute_band_energy(float *bandE, const AVComplexFloat *X)
436 float sum[NB_BANDS] = {0};
438 for (int i = 0; i < NB_BANDS - 1; i++) {
441 band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
442 for (int j = 0; j < band_size; j++) {
443 float tmp, frac = (float)j / band_size;
445 tmp = SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].re);
446 tmp += SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].im);
447 sum[i] += (1.f - frac) * tmp;
448 sum[i + 1] += frac * tmp;
/* The last band only receives contributions from one side, so it is doubled. */
453 sum[NB_BANDS - 1] *= 2;
455 for (int i = 0; i < NB_BANDS; i++)
/*
 * Per-band correlation between spectra X and P: accumulates
 * X.re * P.re + X.im * P.im with the same linear band weighting as
 * compute_band_energy().
 * NOTE(review): the final loop presumably copies sum[] into bandE; its body is
 * not visible here.
 */
459 static void compute_band_corr(float *bandE, const AVComplexFloat *X, const AVComplexFloat *P)
461 float sum[NB_BANDS] = { 0 };
463 for (int i = 0; i < NB_BANDS - 1; i++) {
466 band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
467 for (int j = 0; j < band_size; j++) {
468 float tmp, frac = (float)j / band_size;
470 tmp = X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re;
471 tmp += X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im;
472 sum[i] += (1 - frac) * tmp;
473 sum[i + 1] += frac * tmp;
/* The last band only receives contributions from one side, so it is doubled. */
478 sum[NB_BANDS-1] *= 2;
480 for (int i = 0; i < NB_BANDS; i++)
/*
 * Analyse one input frame: build a WINDOW_SIZE block from the previous frame
 * followed by the current one, window it, transform it into X and compute the
 * band energies Ex.
 */
484 static void frame_analysis(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, float *Ex, const float *in)
486 LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
488 RNN_COPY(x, st->analysis_mem, FRAME_SIZE);
489 RNN_COPY(x + FRAME_SIZE, in, FRAME_SIZE);
/* Remember the current frame as the overlap for the next call. */
490 RNN_COPY(st->analysis_mem, in, FRAME_SIZE);
491 s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
492 forward_transform(st, X, x);
493 compute_band_energy(Ex, X);
/*
 * Synthesise one output frame from spectrum y: inverse transform, window,
 * overlap-add with the tail saved from the previous call, then emit the first
 * FRAME_SIZE samples.
 */
496 static void frame_synthesis(AudioRNNContext *s, DenoiseState *st, float *out, const AVComplexFloat *y)
498 LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
500 inverse_transform(st, x, y);
501 s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
/* x[0..FRAME_SIZE) += 1.0 * synthesis_mem: add the previous window's tail. */
502 s->fdsp->vector_fmac_scalar(x, st->synthesis_mem, 1.f, FRAME_SIZE);
503 RNN_COPY(out, x, FRAME_SIZE);
/* Keep the second half for the next frame's overlap-add. */
504 RNN_COPY(st->synthesis_mem, &x[FRAME_SIZE], FRAME_SIZE);
/*
 * Accumulate into sum[0..3]; celt_pitch_xcorr() stores these as the
 * correlations at four consecutive lags.
 * The main loop advances four samples per iteration.
 * NOTE(review): the loop body and the tail handling are not visible here.
 */
507 static inline void xcorr_kernel(const float *x, const float *y, float sum[4], int len)
509 float y_0, y_1, y_2, y_3 = 0;
516 for (j = 0; j < len - 3; j += 4) {
/*
 * Reduce the N-element vectors x and y to a single float.
 * NOTE(review): the accumulation statement is not visible here; presumably the
 * result is the sum of x[i] * y[i] (inner product) - confirm.
 */
576 static inline float celt_inner_prod(const float *x,
577 const float *y, int N)
581 for (int i = 0; i < N; i++)
/*
 * Fill xcorr[0..max_pitch) with the correlation of x against y shifted by each
 * lag, over len samples. Lags are handled four at a time through
 * xcorr_kernel(); leftover lags use celt_inner_prod().
 * NOTE(review): the store of sum[0] into xcorr[i] is not visible here.
 */
587 static void celt_pitch_xcorr(const float *x, const float *y,
588 float *xcorr, int len, int max_pitch)
592 for (i = 0; i < max_pitch - 3; i += 4) {
593 float sum[4] = { 0, 0, 0, 0};
595 xcorr_kernel(x, y + i, sum, len);
598 xcorr[i + 1] = sum[1];
599 xcorr[i + 2] = sum[2];
600 xcorr[i + 3] = sum[3];
602 /* In case max_pitch isn't a multiple of 4, do non-unrolled version. */
603 for (; i < max_pitch; i++) {
604 xcorr[i] = celt_inner_prod(x, y + i, len);
/*
 * Autocorrelation of x for lags 0..lag, written to ac.
 * When windowing is applied, the first and last `overlap` samples are
 * multiplied by window[] into the local copy xx.
 * The bulk of each lag is computed by celt_pitch_xcorr() over fastN samples;
 * the second loop adds the products for the remaining samples up to n.
 * NOTE(review): the remaining parameters, the definition of fastN/xptr and the
 * return value are not visible here.
 */
608 static int celt_autocorr(const float *x, /* in: [0...n-1] samples x */
609 float *ac, /* out: [0...lag-1] ac values */
618 float xx[PITCH_BUF_SIZE>>1];
623 for (int i = 0; i < n; i++)
625 for (int i = 0; i < overlap; i++) {
626 xx[i] = x[i] * window[i];
627 xx[n-i-1] = x[n-i-1] * window[i];
633 celt_pitch_xcorr(xptr, xptr, ac, fastN, lag+1);
635 for (int k = 0; k <= lag; k++) {
638 for (int i = k + fastN; i < n; i++)
639 d += xptr[i] * xptr[i-k];
/*
 * Derive p LPC coefficients from the autocorrelation values ac[0..p].
 * Each iteration computes one reflection coefficient r, updates the
 * coefficients found so far symmetrically from both ends, and shrinks the
 * prediction error by the factor (1 - r * r).
 * NOTE(review): this looks like the Levinson-Durbin recursion; the lines that
 * compute r from rr and that initialise lpc are not visible here.
 */
646 static void celt_lpc(float *lpc, /* out: [0...p-1] LPC coefficients */
647 const float *ac, /* in: [0...p] autocorrelation values */
650 float r, error = ac[0];
654 for (int i = 0; i < p; i++) {
655 /* Sum up this iteration's reflection coefficient */
657 for (int j = 0; j < i; j++)
658 rr += (lpc[j] * ac[i - j]);
661 /* Update LPC coefficients and total error */
663 for (int j = 0; j < (i + 1) >> 1; j++) {
667 lpc[j] = tmp1 + (r*tmp2);
668 lpc[i-1-j] = tmp2 + (r*tmp1);
671 error = error - (r * r *error);
672 /* Bail out once we get 30 dB gain */
673 if (error < .001f * ac[0])
/*
 * Filter N samples using five coefficients (num0..num4) and five state values
 * (mem0..mem4); pitch_downsample() calls it in place on x_lp.
 * NOTE(review): the parameter list and the loop body are not visible here;
 * presumably this is a 5-tap FIR filter as the name suggests.
 */
679 static void celt_fir5(const float *x,
685 float num0, num1, num2, num3, num4;
686 float mem0, mem1, mem2, mem3, mem4;
699 for (int i = 0; i < N; i++) {
/*
 * Produce a half-rate signal x_lp for pitch analysis.
 *
 * Each output sample is 0.25 * x[2i-1] + 0.5 * x[2i] + 0.25 * x[2i+1]; the
 * first sample has no left neighbour. A second channel x[1] is added in the
 * same way (NOTE(review): presumably only when two channels are passed; the
 * condition is not visible here).
 *
 * The result is then filtered in place with a 5-coefficient filter built from a
 * 4th-order LPC analysis of x_lp.
 */
722 static void pitch_downsample(float *x[], float *x_lp,
727 float lpc[4], mem[5]={0,0,0,0,0};
731 for (int i = 1; i < len >> 1; i++)
732 x_lp[i] = .5f * (.5f * (x[0][(2*i-1)]+x[0][(2*i+1)])+x[0][2*i]);
733 x_lp[0] = .5f * (.5f * (x[0][1])+x[0][0]);
735 for (int i = 1; i < len >> 1; i++)
736 x_lp[i] += (.5f * (.5f * (x[1][(2*i-1)]+x[1][(2*i+1)])+x[1][2*i]));
737 x_lp[0] += .5f * (.5f * (x[1][1])+x[1][0]);
/* Autocorrelation for lags 0..4 with no windowing. */
740 celt_autocorr(x_lp, ac, NULL, 0, 4, len>>1);
742 /* Noise floor -40 dB */
/* Attenuate higher lags slightly (lag windowing). */
745 for (int i = 1; i <= 4; i++) {
746 /*ac[i] *= exp(-.5*(2*M_PI*.002*i)*(2*M_PI*.002*i));*/
747 ac[i] -= ac[i]*(.008f*i)*(.008f*i);
750 celt_lpc(lpc, ac, 4);
/* NOTE(review): tmp is presumably a factor that decays with i; its update is
 * not visible here. */
751 for (int i = 0; i < 4; i++) {
753 lpc[i] = (lpc[i] * tmp);
/* Combine the 4 LPC coefficients with one extra first-order term, giving the
 * 5 coefficients passed to celt_fir5().
 * NOTE(review): c1 is presumably .8f, matching lpc2[0]; its definition is not
 * visible here. */
756 lpc2[0] = lpc[0] + .8f;
757 lpc2[1] = lpc[1] + (c1 * lpc[0]);
758 lpc2[2] = lpc[2] + (c1 * lpc[1]);
759 lpc2[3] = lpc[3] + (c1 * lpc[2]);
760 lpc2[4] = (c1 * lpc[3]);
761 celt_fir5(x_lp, lpc2, x_lp, len>>1, mem);
/*
 * Compute two inner products in one pass: x with y01 and x with y02, over N
 * elements.
 * NOTE(review): the stores into *xy1 and *xy2 are not visible here.
 */
764 static inline void dual_inner_prod(const float *x, const float *y01, const float *y02,
765 int N, float *xy1, float *xy2)
767 float xy01 = 0, xy02 = 0;
769 for (int i = 0; i < N; i++) {
770 xy01 += (x[i] * y01[i]);
771 xy02 += (x[i] * y02[i]);
/* Normalised correlation xy / sqrt(1 + xx * yy); the added 1 keeps the
 * denominator non-zero when both energies are zero. */
778 static float compute_pitch_gain(float xy, float xx, float yy)
780 return xy / sqrtf(1.f + xx * yy);
/* For each divisor k, the multiplier used by remove_doubling() to derive the
 * secondary lag T1b = (2 * second_check[k] * T0 + k) / (2 * k). */
783 static const int second_check[16] = {0, 0, 3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2};
/*
 * Re-examine a pitch estimate for period-multiple errors: for k = 2..15 the
 * candidate period T0 / k is tested against a threshold derived from the gain
 * g0 at T0, with a bonus (cont) when the candidate is close to prev_period.
 *
 * The caller stores the return value as the pitch gain and passes the pitch
 * period by pointer in T0_.
 * NOTE(review): many lines (computation of T1, candidate selection, the write
 * back to *T0_ and the return) are not visible here.
 */
784 static float remove_doubling(float *x, int maxperiod, int minperiod, int N,
785 int *T0_, int prev_period, float prev_gain)
792 float best_xy, best_yy;
795 float yy_lookup[PITCH_MAX_PERIOD+1];
797 minperiod0 = minperiod;
808 dual_inner_prod(x, x, x-T0, N, &xx, &xy);
/* Build the energy of x delayed by each lag i incrementally: add the sample
 * entering the N-sample window and remove the one leaving it. */
811 for (i = 1; i <= maxperiod; i++) {
812 yy = yy+(x[-i] * x[-i])-(x[N-i] * x[N-i]);
813 yy_lookup[i] = FFMAX(0, yy);
818 g = g0 = compute_pitch_gain(xy, xx, yy);
819 /* Look for any pitch at T/k */
820 for (k = 2; k <= 15; k++) {
828 /* Look for another strong correlation at T1b */
837 T1b = (2*second_check[k]*T0+k)/(2*k);
839 dual_inner_prod(x, &x[-T1], &x[-T1b], N, &xy, &xy2);
/* Average the statistics of the two lags before computing the gain. */
840 xy = .5f * (xy + xy2);
841 yy = .5f * (yy_lookup[T1] + yy_lookup[T1b]);
842 g1 = compute_pitch_gain(xy, xx, yy);
843 if (FFABS(T1-prev_period)<=1)
845 else if (FFABS(T1-prev_period)<=2 && 5 * k * k < T0)
846 cont = prev_gain * .5f;
849 thresh = FFMAX(.3f, (.7f * g0) - cont);
850 /* Bias against very high pitch (very short period) to avoid false-positives
851 due to short-term correlation */
853 thresh = FFMAX(.4f, (.85f * g0) - cont);
854 else if (T1<2*minperiod)
855 thresh = FFMAX(.5f, (.9f * g0) - cont);
864 best_xy = FFMAX(0, best_xy);
865 if (best_yy <= best_xy)
868 pg = best_xy/(best_yy + 1);
/* Correlations at the three lags T - 1, T and T + 1, compared below to decide
 * a one-sample offset. */
870 for (k = 0; k < 3; k++)
871 xcorr[k] = celt_inner_prod(x, x-(T+k-1), N);
872 if ((xcorr[2]-xcorr[0]) > .7f * (xcorr[1]-xcorr[0]))
874 else if ((xcorr[0]-xcorr[2]) > (.7f * (xcorr[1] - xcorr[2])))
/*
 * Scan xcorr[0..max_pitch) and keep the two best lags in best_pitch[0] (best)
 * and best_pitch[1] (runner-up).
 *
 * Candidates are ranked by xcorr^2 / Syy, where Syy is the energy of the
 * len-sample slice of y at the current lag. The ratios are compared by
 * cross-multiplication, so no division is needed.
 * NOTE(review): initialisation of best_num/best_den and the stores of the new
 * best candidate are not visible here.
 */
887 static void find_best_pitch(float *xcorr, float *y, int len,
888 int max_pitch, int *best_pitch)
901 for (int j = 0; j < len; j++)
904 for (int i = 0; i < max_pitch; i++) {
910 /* Considering the range of xcorr16, this should avoid both underflows
911 and overflows (inf) when squaring xcorr16 */
913 num = xcorr16 * xcorr16;
914 if ((num * best_den[1]) > (best_num[1] * Syy)) {
915 if ((num * best_den[0]) > (best_num[0] * Syy)) {
/* New overall best: demote the previous best to second place. */
916 best_num[1] = best_num[0];
917 best_den[1] = best_den[0];
918 best_pitch[1] = best_pitch[0];
/* Slide the energy window by one sample. */
929 Syy += y[i+len]*y[i+len] - y[i] * y[i];
/*
 * Two-stage pitch lag search.
 *
 * Stage 1 decimates x_lp and y by a further factor of 2 and finds the two best
 * lags there. Stage 2 recomputes the correlation at the input resolution only
 * within 2 lags of twice each stage-1 candidate. The best lag is doubled and
 * adjusted by an offset derived from its neighbours, then stored in *pitch.
 * NOTE(review): the definition of lag, the filling of y_lp4 and the values
 * assigned to offset are not visible here.
 */
934 static void pitch_search(const float *x_lp, float *y,
935 int len, int max_pitch, int *pitch)
938 int best_pitch[2]={0,0};
941 float x_lp4[WINDOW_SIZE];
942 float y_lp4[WINDOW_SIZE];
943 float xcorr[WINDOW_SIZE];
947 /* Downsample by 2 again */
948 for (int j = 0; j < len >> 2; j++)
949 x_lp4[j] = x_lp[2*j];
950 for (int j = 0; j < lag >> 2; j++)
953 /* Coarse search with 4x decimation */
955 celt_pitch_xcorr(x_lp4, y_lp4, xcorr, len>>2, max_pitch>>2);
957 find_best_pitch(xcorr, y_lp4, len>>2, max_pitch>>2, best_pitch);
959 /* Finer search with 2x decimation */
960 for (int i = 0; i < max_pitch >> 1; i++) {
/* Skip lags that are not within 2 of either coarse candidate. */
963 if (FFABS(i-2*best_pitch[0])>2 && FFABS(i-2*best_pitch[1])>2)
965 sum = celt_inner_prod(x_lp, y+i, len>>1);
966 xcorr[i] = FFMAX(-1, sum);
969 find_best_pitch(xcorr, y, len>>1, max_pitch>>1, best_pitch);
971 /* Refine by pseudo-interpolation */
972 if (best_pitch[0] > 0 && best_pitch[0] < (max_pitch >> 1) - 1) {
975 a = xcorr[best_pitch[0] - 1];
976 b = xcorr[best_pitch[0]];
977 c = xcorr[best_pitch[0] + 1];
978 if (c - a > .7f * (b - a))
980 else if (a - c > .7f * (b-c))
988 *pitch = 2 * best_pitch[0] - offset;
/*
 * Transform NB_BANDS input values with the cosine table built in init():
 * out[i] = sqrt(2 / 22) * sum over j of in[j] * dct_table[j * NB_BANDS + i].
 * NOTE(review): the literal 22 presumably equals NB_BANDS (eband5ms has 22
 * entries); confirm, and prefer the macro if so.
 */
991 static void dct(AudioRNNContext *s, float *out, const float *in)
993 for (int i = 0; i < NB_BANDS; i++) {
996 for (int j = 0; j < NB_BANDS; j++) {
997 sum += in[j] * s->dct_table[j * NB_BANDS + i];
999 out[i] = sum * sqrtf(2.f / 22);
/*
 * Analyse one frame and fill the NB_FEATURES-element feature vector.
 *
 * Outputs: X (spectrum of the frame), P (spectrum of the pitch-delayed signal),
 * Ex/Ep (their band energies), Exp (normalised band correlation between them)
 * and features.
 *
 * The caller stores the return value in a variable named `silence`.
 * NOTE(review): presumably non-zero means the frame carried no audio, in which
 * case features is cleared; the test and the return statements are not visible
 * here.
 */
1003 static int compute_frame_features(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, AVComplexFloat *P,
1004 float *Ex, float *Ep, float *Exp, float *features, const float *in)
1007 float *ceps_0, *ceps_1, *ceps_2;
1008 float spec_variability = 0;
1010 LOCAL_ALIGNED_32(float, p, [WINDOW_SIZE]);
1011 float pitch_buf[PITCH_BUF_SIZE>>1];
1015 float tmp[NB_BANDS];
1016 float follow, logMax;
1018 frame_analysis(s, st, X, Ex, in);
/* Shift the pitch history left by one frame and append the new frame. */
1019 RNN_MOVE(st->pitch_buf, &st->pitch_buf[FRAME_SIZE], PITCH_BUF_SIZE-FRAME_SIZE);
1020 RNN_COPY(&st->pitch_buf[PITCH_BUF_SIZE-FRAME_SIZE], in, FRAME_SIZE);
1021 pre[0] = &st->pitch_buf[0];
/* Pitch analysis runs on a half-rate copy of the history (local pitch_buf). */
1022 pitch_downsample(pre, pitch_buf, PITCH_BUF_SIZE, 1);
1023 pitch_search(pitch_buf+(PITCH_MAX_PERIOD>>1), pitch_buf, PITCH_FRAME_SIZE,
1024 PITCH_MAX_PERIOD-3*PITCH_MIN_PERIOD, &pitch_index);
/* Convert the search result into a delay measured back from the newest sample. */
1025 pitch_index = PITCH_MAX_PERIOD-pitch_index;
1027 gain = remove_doubling(pitch_buf, PITCH_MAX_PERIOD, PITCH_MIN_PERIOD,
1028 PITCH_FRAME_SIZE, &pitch_index, st->last_period, st->last_gain);
1029 st->last_period = pitch_index;
1030 st->last_gain = gain;
/* Extract the window that lies pitch_index samples before the current one. */
1032 for (int i = 0; i < WINDOW_SIZE; i++)
1033 p[i] = st->pitch_buf[PITCH_BUF_SIZE-WINDOW_SIZE-pitch_index+i];
1035 s->fdsp->vector_fmul(p, p, s->window, WINDOW_SIZE);
1036 forward_transform(st, P, p);
1037 compute_band_energy(Ep, P);
1038 compute_band_corr(Exp, X, P);
/* Normalise the band correlation by the geometric mean of both energies. */
1040 for (int i = 0; i < NB_BANDS; i++)
1041 Exp[i] = Exp[i] / sqrtf(.001f+Ex[i]*Ep[i]);
/* NOTE(review): tmp presumably holds the DCT of Exp at this point; the call
 * that fills it is not visible here. */
1045 for (int i = 0; i < NB_DELTA_CEPS; i++)
1046 features[NB_BANDS+2*NB_DELTA_CEPS+i] = tmp[i];
/* Fixed offsets on the first two pitch-correlation features and a scaled,
 * offset pitch period. NOTE(review): presumably these centre the inputs as the
 * trained model expects - confirm. */
1048 features[NB_BANDS+2*NB_DELTA_CEPS] -= 1.3;
1049 features[NB_BANDS+2*NB_DELTA_CEPS+1] -= 0.9;
1050 features[NB_BANDS+3*NB_DELTA_CEPS] = .01*(pitch_index-300);
/* Log band energies, floored at 7 below the running maximum and at 1.5 below
 * a follower that decays by 1.5 per band. */
1054 for (int i = 0; i < NB_BANDS; i++) {
1055 Ly[i] = log10f(1e-2f + Ex[i]);
1056 Ly[i] = FFMAX(logMax-7, FFMAX(follow-1.5, Ly[i]));
1057 logMax = FFMAX(logMax, Ly[i]);
1058 follow = FFMAX(follow-1.5, Ly[i]);
1063 /* If there's no audio, avoid messing up the state. */
1064 RNN_CLEAR(features, NB_FEATURES);
1068 dct(s, features, Ly);
/* Select the current and the two previous cepstral history slots, wrapping
 * around the CEPS_MEM-entry buffer. */
1071 ceps_0 = st->cepstral_mem[st->memid];
1072 ceps_1 = (st->memid < 1) ? st->cepstral_mem[CEPS_MEM+st->memid-1] : st->cepstral_mem[st->memid-1];
1073 ceps_2 = (st->memid < 2) ? st->cepstral_mem[CEPS_MEM+st->memid-2] : st->cepstral_mem[st->memid-2];
1075 for (int i = 0; i < NB_BANDS; i++)
1076 ceps_0[i] = features[i];
/* Sum, first difference and second difference across the three latest frames
 * for the first NB_DELTA_CEPS coefficients. */
1079 for (int i = 0; i < NB_DELTA_CEPS; i++) {
1080 features[i] = ceps_0[i] + ceps_1[i] + ceps_2[i];
1081 features[NB_BANDS+i] = ceps_0[i] - ceps_2[i];
1082 features[NB_BANDS+NB_DELTA_CEPS+i] = ceps_0[i] - 2*ceps_1[i] + ceps_2[i];
1084 /* Spectral variability features. */
/* NOTE(review): presumably wraps memid back to 0; the statement is not
 * visible here. */
1085 if (st->memid == CEPS_MEM)
/* For every stored frame, take the smallest distance to another stored frame
 * and accumulate it. */
1088 for (int i = 0; i < CEPS_MEM; i++) {
1089 float mindist = 1e15f;
1090 for (int j = 0; j < CEPS_MEM; j++) {
1092 for (int k = 0; k < NB_BANDS; k++) {
1095 tmp = st->cepstral_mem[i][k] - st->cepstral_mem[j][k];
1100 mindist = FFMIN(mindist, dist);
1103 spec_variability += mindist;
1106 features[NB_BANDS+3*NB_DELTA_CEPS+1] = spec_variability/CEPS_MEM-2.1;
/*
 * Expand NB_BANDS band values into FREQ_SIZE per-bin values by linear
 * interpolation between adjacent band edges. Bins not covered by any band keep
 * the value 0 set by the initial memset.
 */
1111 static void interp_band_gain(float *g, const float *bandE)
1113 memset(g, 0, sizeof(*g) * FREQ_SIZE);
1115 for (int i = 0; i < NB_BANDS - 1; i++) {
1116 const int band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
1118 for (int j = 0; j < band_size; j++) {
1119 float frac = (float)j / band_size;
1121 g[(eband5ms[i] << FRAME_SIZE_SHIFT) + j] = (1.f - frac) * bandE[i] + frac * bandE[i + 1];
/*
 * Add a scaled copy of the pitch spectrum P to X, then rescale X so that each
 * band keeps its original energy Ex.
 *
 * The per-band mixing amount r depends on the pitch correlation Exp and the
 * band gain g: it is 1 when Exp exceeds g, otherwise a ratio clipped to
 * [0, 1]; its square root is then scaled by sqrt(Ex / Ep).
 */
1126 static void pitch_filter(AVComplexFloat *X, const AVComplexFloat *P, const float *Ex, const float *Ep,
1127 const float *Exp, const float *g)
1129 float newE[NB_BANDS];
1131 float norm[NB_BANDS];
1132 float rf[FREQ_SIZE] = {0};
1133 float normf[FREQ_SIZE]={0};
1135 for (int i = 0; i < NB_BANDS; i++) {
1136 if (Exp[i]>g[i]) r[i] = 1;
1137 else r[i] = SQUARE(Exp[i])*(1-SQUARE(g[i]))/(.001 + SQUARE(g[i])*(1-SQUARE(Exp[i])));
1138 r[i] = sqrtf(av_clipf(r[i], 0, 1));
/* The small constant guards against division by zero. */
1139 r[i] *= sqrtf(Ex[i]/(1e-8+Ep[i]));
1141 interp_band_gain(rf, r);
1142 for (int i = 0; i < FREQ_SIZE; i++) {
1143 X[i].re += rf[i]*P[i].re;
1144 X[i].im += rf[i]*P[i].im;
/* Renormalise so the mixed spectrum has the same band energies as before. */
1146 compute_band_energy(newE, X);
1147 for (int i = 0; i < NB_BANDS; i++) {
1148 norm[i] = sqrtf(Ex[i] / (1e-8+newE[i]));
1150 interp_band_gain(normf, norm);
1151 for (int i = 0; i < FREQ_SIZE; i++) {
1152 X[i].re *= normf[i];
1153 X[i].im *= normf[i];
/* Lookup table used by tansig_approx(), which indexes it with
 * floor(.5 + 25 * x), i.e. one entry per 0.04 step of x. The values match
 * tanh(0.04 * index), e.g. entry 25 is 0.761594 = tanh(1). */
1157 static const float tansig_table[201] = {
1158 0.000000f, 0.039979f, 0.079830f, 0.119427f, 0.158649f,
1159 0.197375f, 0.235496f, 0.272905f, 0.309507f, 0.345214f,
1160 0.379949f, 0.413644f, 0.446244f, 0.477700f, 0.507977f,
1161 0.537050f, 0.564900f, 0.591519f, 0.616909f, 0.641077f,
1162 0.664037f, 0.685809f, 0.706419f, 0.725897f, 0.744277f,
1163 0.761594f, 0.777888f, 0.793199f, 0.807569f, 0.821040f,
1164 0.833655f, 0.845456f, 0.856485f, 0.866784f, 0.876393f,
1165 0.885352f, 0.893698f, 0.901468f, 0.908698f, 0.915420f,
1166 0.921669f, 0.927473f, 0.932862f, 0.937863f, 0.942503f,
1167 0.946806f, 0.950795f, 0.954492f, 0.957917f, 0.961090f,
1168 0.964028f, 0.966747f, 0.969265f, 0.971594f, 0.973749f,
1169 0.975743f, 0.977587f, 0.979293f, 0.980869f, 0.982327f,
1170 0.983675f, 0.984921f, 0.986072f, 0.987136f, 0.988119f,
1171 0.989027f, 0.989867f, 0.990642f, 0.991359f, 0.992020f,
1172 0.992631f, 0.993196f, 0.993718f, 0.994199f, 0.994644f,
1173 0.995055f, 0.995434f, 0.995784f, 0.996108f, 0.996407f,
1174 0.996682f, 0.996937f, 0.997172f, 0.997389f, 0.997590f,
1175 0.997775f, 0.997946f, 0.998104f, 0.998249f, 0.998384f,
1176 0.998508f, 0.998623f, 0.998728f, 0.998826f, 0.998916f,
1177 0.999000f, 0.999076f, 0.999147f, 0.999213f, 0.999273f,
1178 0.999329f, 0.999381f, 0.999428f, 0.999472f, 0.999513f,
1179 0.999550f, 0.999585f, 0.999617f, 0.999646f, 0.999673f,
1180 0.999699f, 0.999722f, 0.999743f, 0.999763f, 0.999781f,
1181 0.999798f, 0.999813f, 0.999828f, 0.999841f, 0.999853f,
1182 0.999865f, 0.999875f, 0.999885f, 0.999893f, 0.999902f,
1183 0.999909f, 0.999916f, 0.999923f, 0.999929f, 0.999934f,
1184 0.999939f, 0.999944f, 0.999948f, 0.999952f, 0.999956f,
1185 0.999959f, 0.999962f, 0.999965f, 0.999968f, 0.999970f,
1186 0.999973f, 0.999975f, 0.999977f, 0.999978f, 0.999980f,
1187 0.999982f, 0.999983f, 0.999984f, 0.999986f, 0.999987f,
1188 0.999988f, 0.999989f, 0.999990f, 0.999990f, 0.999991f,
1189 0.999992f, 0.999992f, 0.999993f, 0.999994f, 0.999994f,
1190 0.999994f, 0.999995f, 0.999995f, 0.999996f, 0.999996f,
1191 0.999996f, 0.999997f, 0.999997f, 0.999997f, 0.999997f,
1192 0.999997f, 0.999998f, 0.999998f, 0.999998f, 0.999998f,
1193 0.999998f, 0.999998f, 0.999999f, 0.999999f, 0.999999f,
1194 0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
1195 0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
1196 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1197 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
/*
 * Approximate tanh(x) from tansig_table plus a correction term.
 * NOTE(review): the range checks, sign handling, the reduction of x to the
 * residual from the table point and the definition of dy are not visible
 * here; presumably dy = 1 - y * y, making the last step a series correction
 * around the table entry.
 */
1201 static inline float tansig_approx(float x)
1207 /* Tests are reversed to catch NaNs */
1212 /* Another check in case of -ffast-math */
/* Nearest table index for a step of 1/25. */
1221 i = (int)floor(.5f+25*x);
1223 y = tansig_table[i];
1225 y = y + x*dy*(1 - y*x);
/* Logistic sigmoid via the identity sigmoid(x) = 0.5 + 0.5 * tanh(x / 2),
 * reusing tansig_approx(). */
1229 static inline float sigmoid_approx(float x)
1231 return .5f + .5f*tansig_approx(.5f*x);
/*
 * Evaluate a fully connected layer.
 *
 * For each neuron i: output[i] = activation(WEIGHTS_SCALE *
 * (bias[i] + sum over j of input_weights[j * nb_neurons + i] * input[j])).
 * input must hold nb_inputs values and output nb_neurons values.
 */
1234 static void compute_dense(const DenseLayer *layer, float *output, const float *input)
1236 const int N = layer->nb_neurons, M = layer->nb_inputs, stride = N;
1238 for (int i = 0; i < N; i++) {
1239 /* Start the weighted sum from this neuron's bias. */
1240 float sum = layer->bias[i];
1242 for (int j = 0; j < M; j++)
1243 sum += layer->input_weights[j * stride + i] * input[j];
1245 output[i] = WEIGHTS_SCALE * sum;
/* Apply the layer's activation in place. */
1248 if (layer->activation == ACTIVATION_SIGMOID) {
1249 for (int i = 0; i < N; i++)
1250 output[i] = sigmoid_approx(output[i]);
1251 } else if (layer->activation == ACTIVATION_TANH) {
1252 for (int i = 0; i < N; i++)
1253 output[i] = tansig_approx(output[i]);
1254 } else if (layer->activation == ACTIVATION_RELU) {
1255 for (int i = 0; i < N; i++)
1256 output[i] = FFMAX(0, output[i]);
/*
 * Advance a GRU layer by one step, updating state in place.
 *
 * Weight rows are laid out per neuron with three consecutive gate sections
 * (update, reset, output), each padded to a multiple of 4 elements: the row
 * stride is 3 * AM for input weights and 3 * AN for recurrent weights.
 *
 * The dot products read AM input values and AN state values, so both buffers
 * must be readable up to the padded length; config_input() allocates the
 * state vectors rounded up to 16 floats.
 */
1262 static void compute_gru(AudioRNNContext *s, const GRULayer *gru, float *state, const float *input)
1264 LOCAL_ALIGNED_32(float, z, [MAX_NEURONS]);
1265 LOCAL_ALIGNED_32(float, r, [MAX_NEURONS]);
1266 LOCAL_ALIGNED_32(float, h, [MAX_NEURONS]);
1267 const int M = gru->nb_inputs;
1268 const int N = gru->nb_neurons;
1269 const int AN = FFALIGN(N, 4);
1270 const int AM = FFALIGN(M, 4);
1271 const int stride = 3 * AN, istride = 3 * AM;
1273 for (int i = 0; i < N; i++) {
1274 /* Compute update gate. */
1275 float sum = gru->bias[i];
1277 sum += s->fdsp->scalarproduct_float(gru->input_weights + i * istride, input, AM);
1278 sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + i * stride, state, AN);
1279 z[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1282 for (int i = 0; i < N; i++) {
1283 /* Compute reset gate. */
1284 float sum = gru->bias[N + i];
1286 sum += s->fdsp->scalarproduct_float(gru->input_weights + AM + i * istride, input, AM);
1287 sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + AN + i * stride, state, AN);
1288 r[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1291 for (int i = 0; i < N; i++) {
1292 /* Compute output. */
1293 float sum = gru->bias[2 * N + i];
1295 sum += s->fdsp->scalarproduct_float(gru->input_weights + 2 * AM + i * istride, input, AM);
/* The recurrent term uses the state gated element-wise by the reset gate, so
 * it is accumulated with an explicit loop. */
1296 for (int j = 0; j < N; j++)
1297 sum += gru->recurrent_weights[2 * AN + i * stride + j] * state[j] * r[j];
1299 if (gru->activation == ACTIVATION_SIGMOID)
1300 sum = sigmoid_approx(WEIGHTS_SCALE * sum);
1301 else if (gru->activation == ACTIVATION_TANH)
1302 sum = tansig_approx(WEIGHTS_SCALE * sum);
1303 else if (gru->activation == ACTIVATION_RELU)
1304 sum = FFMAX(0, WEIGHTS_SCALE * sum);
/* Blend the old state and the candidate using the update gate. */
1307 h[i] = z[i] * state[i] + (1.f - z[i]) * sum;
/* The new state is written only after all neurons have used the old one. */
1310 RNN_COPY(state, h, N);
/* Number of feature values copied from `input` into the GRU input vectors. */
1313 #define INPUT_SIZE 42
/*
 * Run the whole network for one frame.
 *
 * Data flow:
 *   input -> input_dense -> vad_gru -> vad_output -> *vad
 *   [dense_out, vad_gru_state, input]       -> noise_gru
 *   [vad_gru_state, noise_gru_state, input] -> denoise_gru -> denoise_output -> gains
 *
 * The concatenated vectors are held in MAX_NEURONS * 3 element buffers.
 */
1315 static void compute_rnn(AudioRNNContext *s, RNNState *rnn, float *gains, float *vad, const float *input)
1317 LOCAL_ALIGNED_32(float, dense_out, [MAX_NEURONS]);
1318 LOCAL_ALIGNED_32(float, noise_input, [MAX_NEURONS * 3]);
1319 LOCAL_ALIGNED_32(float, denoise_input, [MAX_NEURONS * 3]);
1321 compute_dense(rnn->model->input_dense, dense_out, input);
1322 compute_gru(s, rnn->model->vad_gru, rnn->vad_gru_state, dense_out);
1323 compute_dense(rnn->model->vad_output, vad, rnn->vad_gru_state);
/* Assemble the noise GRU input. */
1325 for (int i = 0; i < rnn->model->input_dense_size; i++)
1326 noise_input[i] = dense_out[i];
1327 for (int i = 0; i < rnn->model->vad_gru_size; i++)
1328 noise_input[i + rnn->model->input_dense_size] = rnn->vad_gru_state[i];
1329 for (int i = 0; i < INPUT_SIZE; i++)
1330 noise_input[i + rnn->model->input_dense_size + rnn->model->vad_gru_size] = input[i];
1332 compute_gru(s, rnn->model->noise_gru, rnn->noise_gru_state, noise_input);
/* Assemble the denoise GRU input. */
1334 for (int i = 0; i < rnn->model->vad_gru_size; i++)
1335 denoise_input[i] = rnn->vad_gru_state[i];
1336 for (int i = 0; i < rnn->model->noise_gru_size; i++)
1337 denoise_input[i + rnn->model->vad_gru_size] = rnn->noise_gru_state[i];
1338 for (int i = 0; i < INPUT_SIZE; i++)
1339 denoise_input[i + rnn->model->vad_gru_size + rnn->model->noise_gru_size] = input[i];
1341 compute_gru(s, rnn->model->denoise_gru, rnn->denoise_gru_state, denoise_input);
1342 compute_dense(rnn->model->denoise_output, gains, rnn->denoise_gru_state);
/*
 * Denoise one FRAME_SIZE frame of a single channel from in to out.
 *
 * Steps visible here: pre-filter the input with biquad(), compute the frame
 * features, run the network to obtain band gains, apply the pitch filter,
 * smooth the gains against the previous frame, interpolate them per bin and
 * resynthesise the frame.
 * NOTE(review): the handling of `silence`, the application of gf to X and the
 * return value (presumably the VAD probability) are not visible here.
 */
1345 static float rnnoise_channel(AudioRNNContext *s, DenoiseState *st, float *out, const float *in)
1347 AVComplexFloat X[FREQ_SIZE];
1348 AVComplexFloat P[WINDOW_SIZE];
1349 float x[FRAME_SIZE];
1350 float Ex[NB_BANDS], Ep[NB_BANDS];
1351 float Exp[NB_BANDS];
1352 float features[NB_FEATURES];
1354 float gf[FREQ_SIZE];
/* Coefficients of the input pre-filter. NOTE(review): the _hp suffix suggests
 * a high-pass filter. */
1356 static const float a_hp[2] = {-1.99599, 0.99600};
1357 static const float b_hp[2] = {-2, 1};
1360 biquad(x, st->mem_hp_x, in, b_hp, a_hp, FRAME_SIZE);
1361 silence = compute_frame_features(s, st, X, P, Ex, Ep, Exp, features, x);
1364 compute_rnn(s, &st->rnn, g, &vad_prob, features);
1365 pitch_filter(X, P, Ex, Ep, Exp, g);
1366 for (int i = 0; i < NB_BANDS; i++) {
/* A band gain may not drop below alpha times its previous value. */
1369 g[i] = FFMAX(g[i], alpha * st->lastg[i]);
1370 st->lastg[i] = g[i];
1373 interp_band_gain(gf, g);
1375 for (int i = 0; i < FREQ_SIZE; i++) {
1381 frame_synthesis(s, st, out, X);
/* Arguments passed to the slice-threaded worker; filter_frame() sets its in
 * and out frame pointers. */
1386 typedef struct ThreadData {
/*
 * Worker for ctx->internal->execute(): job jobnr of nb_jobs processes the
 * contiguous channel range [start, end), one rnnoise_channel() call per
 * channel with that channel's own DenoiseState.
 */
1390 static int rnnoise_channels(AVFilterContext *ctx, void *arg, int jobnr, int nb_jobs)
1392 AudioRNNContext *s = ctx->priv;
1393 ThreadData *td = arg;
1394 AVFrame *in = td->in;
1395 AVFrame *out = td->out;
1396 const int start = (out->channels * jobnr) / nb_jobs;
1397 const int end = (out->channels * (jobnr+1)) / nb_jobs;
1399 for (int ch = start; ch < end; ch++) {
1400 rnnoise_channel(s, &s->st[ch],
1401 (float *)out->extended_data[ch],
1402 (const float *)in->extended_data[ch]);
/*
 * Process one input frame: allocate a FRAME_SIZE-sample output buffer, denoise
 * all channels with up to min(channels, available threads) jobs, and pass the
 * result downstream.
 *
 * Returns AVERROR(ENOMEM) on the visible failure path, otherwise the result of
 * ff_filter_frame().
 * NOTE(review): property copying and freeing of `in` are not visible here.
 */
1408 static int filter_frame(AVFilterLink *inlink, AVFrame *in)
1410 AVFilterContext *ctx = inlink->dst;
1411 AVFilterLink *outlink = ctx->outputs[0];
1412 AVFrame *out = NULL;
1415 out = ff_get_audio_buffer(outlink, FRAME_SIZE);
1418 return AVERROR(ENOMEM);
1422 td.in = in; td.out = out;
1423 ctx->internal->execute(ctx, rnnoise_channels, &td, NULL, FFMIN(outlink->channels,
1424 ff_filter_get_nb_threads(ctx)));
1427 return ff_filter_frame(outlink, out);
/*
 * Scheduling callback: consume input in blocks of exactly FRAME_SIZE samples
 * (minimum and maximum are both FRAME_SIZE) and filter each block; otherwise
 * forward status and frame requests between the links.
 *
 * Returns FFERROR_NOT_READY when there is nothing to do.
 */
1430 static int activate(AVFilterContext *ctx)
1432 AVFilterLink *inlink = ctx->inputs[0];
1433 AVFilterLink *outlink = ctx->outputs[0];
1437 FF_FILTER_FORWARD_STATUS_BACK(outlink, inlink);
1439 ret = ff_inlink_consume_samples(inlink, FRAME_SIZE, FRAME_SIZE, &in);
1444 return filter_frame(inlink, in);
1446 FF_FILTER_FORWARD_STATUS(inlink, outlink);
1447 FF_FILTER_FORWARD_WANTED(outlink, inlink);
1449 return FFERROR_NOT_READY;
/*
 * Filter initialisation: allocate the float DSP context, load the model from
 * the file named by the "model" option, and precompute the window and the
 * cosine table.
 *
 * Returns AVERROR(ENOMEM) or AVERROR(EINVAL) on the visible failure paths.
 * NOTE(review): the conditions guarding each return, and the fclose() of the
 * model file, are not visible here - confirm the file is closed on all paths.
 */
1452 static av_cold int init(AVFilterContext *ctx)
1454 AudioRNNContext *s = ctx->priv;
1457 s->fdsp = avpriv_float_dsp_alloc(0);
1459 return AVERROR(ENOMEM);
1462 return AVERROR(EINVAL);
1463 f = av_fopen_utf8(s->model_name, "r");
1465 return AVERROR(EINVAL);
1467 s->model = rnnoise_model_from_file(f);
1470 return AVERROR(EINVAL);
/* Symmetric window: sin(pi/2 * sin^2(pi/2 * (i + .5) / FRAME_SIZE)) for the
 * first half, mirrored into the second half. */
1472 for (int i = 0; i < FRAME_SIZE; i++) {
1473 s->window[i] = sin(.5*M_PI*sin(.5*M_PI*(i+.5)/FRAME_SIZE) * sin(.5*M_PI*(i+.5)/FRAME_SIZE));
1474 s->window[WINDOW_SIZE - 1 - i] = s->window[i];
/* Cosine table used by dct(): entry [i][j] = cos((i + .5) * j * pi / NB_BANDS).
 * NOTE(review): the sqrt(.5) scaling presumably applies only to j == 0; the
 * condition is not visible here. */
1477 for (int i = 0; i < NB_BANDS; i++) {
1478 for (int j = 0; j < NB_BANDS; j++) {
1479 s->dct_table[i*NB_BANDS + j] = cosf((i + .5f) * j * M_PI / NB_BANDS);
1481 s->dct_table[i*NB_BANDS + j] *= sqrtf(.5);
/*
 * Release everything owned by the filter: the model and, for each channel, the
 * three GRU state vectors and both FFT contexts.
 * NOTE(review): freeing of fdsp and of the st array itself, and any NULL guard
 * on s->st before the loop, are not visible here.
 */
1488 static av_cold void uninit(AVFilterContext *ctx)
1490 AudioRNNContext *s = ctx->priv;
1493 rnnoise_model_free(s->model);
1497 for (int ch = 0; ch < s->channels; ch++) {
1498 av_freep(&s->st[ch].rnn.vad_gru_state);
1499 av_freep(&s->st[ch].rnn.noise_gru_state);
1500 av_freep(&s->st[ch].rnn.denoise_gru_state);
1501 av_tx_uninit(&s->st[ch].tx);
1502 av_tx_uninit(&s->st[ch].txi);
/* Input pad: audio; per-link setup is done by config_input(). */
1508 static const AVFilterPad inputs[] = {
1511 .type = AVMEDIA_TYPE_AUDIO,
1512 .config_props = config_input,
/* Output pad: audio. */
1517 static const AVFilterPad outputs[] = {
1520 .type = AVMEDIA_TYPE_AUDIO,
/* Helpers for the option table: field offset in AudioRNNContext and the flag
 * set shared by all options. */
1525 #define OFFSET(x) offsetof(AudioRNNContext, x)
1526 #define AF AV_OPT_FLAG_AUDIO_PARAM|AV_OPT_FLAG_FILTERING_PARAM
/* User options: "model" and its short alias "m" both set model_name, the path
 * opened by init(). The default is NULL. */
1528 static const AVOption arnndn_options[] = {
1529 { "model", "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1530 { "m", "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
/* Defines arnndn_class, referenced by .priv_class below. */
1534 AVFILTER_DEFINE_CLASS(arnndn);
/* Filter definition. Slice threading is enabled because filter_frame()
 * distributes channels over jobs through ctx->internal->execute().
 * NOTE(review): the .name, .init, .uninit, .inputs and .outputs members are not
 * visible here. */
1536 AVFilter ff_af_arnndn = {
1538 .description = NULL_IF_CONFIG_SMALL("Reduce noise from speech using Recurrent Neural Networks."),
1539 .query_formats = query_formats,
1540 .priv_size = sizeof(AudioRNNContext),
1541 .priv_class = &arnndn_class,
1542 .activate = activate,
1547 .flags = AVFILTER_FLAG_SLICE_THREADS,