]> git.sesse.net Git - ffmpeg/blobdiff - libavfilter/af_arnndn.c
avfilter/vf_scale: store the offset in a local variable before adding it
[ffmpeg] / libavfilter / af_arnndn.c
index 781d0dc9347a8b5d3ae818e929a5868edf80588a..0c70a32271ffe0d3c06e43e1b4e9a31374811faa 100644 (file)
@@ -36,6 +36,7 @@
 #include "libavutil/avassert.h"
 #include "libavutil/avstring.h"
 #include "libavutil/float_dsp.h"
+#include "libavutil/mem_internal.h"
 #include "libavutil/opt.h"
 #include "libavutil/tx.h"
 #include "avfilter.h"
@@ -127,7 +128,8 @@ typedef struct DenoiseState {
     int last_period;
     float mem_hp_x[2];
     float lastg[NB_BANDS];
-    RNNState rnn;
+    float history[FRAME_SIZE];
+    RNNState rnn[2];
     AVTXContext *tx, *txi;
     av_tx_fn tx_fn, txi_fn;
 } DenoiseState;
@@ -136,14 +138,15 @@ typedef struct AudioRNNContext {
     const AVClass *class;
 
     char *model_name;
+    float mix;
 
     int channels;
     DenoiseState *st;
 
     DECLARE_ALIGNED(32, float, window)[WINDOW_SIZE];
-    float dct_table[NB_BANDS*NB_BANDS];
+    DECLARE_ALIGNED(32, float, dct_table)[FFALIGN(NB_BANDS, 4)][FFALIGN(NB_BANDS, 4)];
 
-    RNNModel *model;
+    RNNModel *model[2];
 
     AVFloatDSPContext *fdsp;
 } AudioRNNContext;
@@ -182,9 +185,9 @@ static void rnnoise_model_free(RNNModel *model)
     av_free(model);
 }
 
-static RNNModel *rnnoise_model_from_file(FILE *f)
+static int rnnoise_model_from_file(FILE *f, RNNModel **rnn)
 {
-    RNNModel *ret;
+    RNNModel *ret = NULL;
     DenseLayer *input_dense;
     GRULayer *vad_gru;
     GRULayer *noise_gru;
@@ -194,17 +197,17 @@ static RNNModel *rnnoise_model_from_file(FILE *f)
     int in;
 
     if (fscanf(f, "rnnoise-nu model file version %d\n", &in) != 1 || in != 1)
-        return NULL;
+        return AVERROR_INVALIDDATA;
 
     ret = av_calloc(1, sizeof(RNNModel));
     if (!ret)
-        return NULL;
+        return AVERROR(ENOMEM);
 
 #define ALLOC_LAYER(type, name) \
     name = av_calloc(1, sizeof(type)); \
     if (!name) { \
         rnnoise_model_free(ret); \
-        return NULL; \
+        return AVERROR(ENOMEM); \
     } \
     ret->name = name
 
@@ -218,7 +221,7 @@ static RNNModel *rnnoise_model_from_file(FILE *f)
 #define INPUT_VAL(name) do { \
     if (fscanf(f, "%d", &in) != 1 || in < 0 || in > 128) { \
         rnnoise_model_free(ret); \
-        return NULL; \
+        return AVERROR(EINVAL); \
     } \
     name = in; \
     } while (0)
@@ -242,13 +245,13 @@ static RNNModel *rnnoise_model_from_file(FILE *f)
     float *values = av_calloc((len), sizeof(float)); \
     if (!values) { \
         rnnoise_model_free(ret); \
-        return NULL; \
+        return AVERROR(ENOMEM); \
     } \
     name = values; \
     for (int i = 0; i < (len); i++) { \
         if (fscanf(f, "%d", &in) != 1) { \
             rnnoise_model_free(ret); \
-            return NULL; \
+            return AVERROR(EINVAL); \
         } \
         values[i] = in; \
     } \
@@ -258,7 +261,7 @@ static RNNModel *rnnoise_model_from_file(FILE *f)
     float *values = av_calloc(FFALIGN((len0), 4) * FFALIGN((len1), 4) * (len2), sizeof(float)); \
     if (!values) { \
         rnnoise_model_free(ret); \
-        return NULL; \
+        return AVERROR(ENOMEM); \
     } \
     name = values; \
     for (int k = 0; k < (len0); k++) { \
@@ -266,7 +269,7 @@ static RNNModel *rnnoise_model_from_file(FILE *f)
             for (int j = 0; j < (len1); j++) { \
                 if (fscanf(f, "%d", &in) != 1) { \
                     rnnoise_model_free(ret); \
-                    return NULL; \
+                    return AVERROR(EINVAL); \
                 } \
                 values[j * (len2) * FFALIGN((len0), 4) + i * FFALIGN((len0), 4) + k] = in; \
             } \
@@ -274,13 +277,24 @@ static RNNModel *rnnoise_model_from_file(FILE *f)
     } \
     } while (0)
 
+#define NEW_LINE() do { \
+    int c; \
+    while ((c = fgetc(f)) != EOF) { \
+        if (c == '\n') \
+        break; \
+    } \
+    } while (0)
+
 #define INPUT_DENSE(name) do { \
     INPUT_VAL(name->nb_inputs); \
     INPUT_VAL(name->nb_neurons); \
     ret->name ## _size = name->nb_neurons; \
     INPUT_ACTIVATION(name->activation); \
+    NEW_LINE(); \
     INPUT_ARRAY(name->input_weights, name->nb_inputs * name->nb_neurons); \
+    NEW_LINE(); \
     INPUT_ARRAY(name->bias, name->nb_neurons); \
+    NEW_LINE(); \
     } while (0)
 
 #define INPUT_GRU(name) do { \
@@ -288,9 +302,13 @@ static RNNModel *rnnoise_model_from_file(FILE *f)
     INPUT_VAL(name->nb_neurons); \
     ret->name ## _size = name->nb_neurons; \
     INPUT_ACTIVATION(name->activation); \
+    NEW_LINE(); \
     INPUT_ARRAY3(name->input_weights, name->nb_inputs, name->nb_neurons, 3); \
+    NEW_LINE(); \
     INPUT_ARRAY3(name->recurrent_weights, name->nb_neurons, name->nb_neurons, 3); \
+    NEW_LINE(); \
     INPUT_ARRAY(name->bias, name->nb_neurons * 3); \
+    NEW_LINE(); \
     } while (0)
 
     INPUT_DENSE(input_dense);
@@ -300,7 +318,14 @@ static RNNModel *rnnoise_model_from_file(FILE *f)
     INPUT_DENSE(denoise_output);
     INPUT_DENSE(vad_output);
 
-    return ret;
+    if (vad_output->nb_neurons != 1) {
+        rnnoise_model_free(ret);
+        return AVERROR(EINVAL);
+    }
+
+    *rnn = ret;
+
+    return 0;
 }
 
 static int query_formats(AVFilterContext *ctx)
@@ -342,27 +367,34 @@ static int config_input(AVFilterLink *inlink)
 
     s->channels = inlink->channels;
 
-    s->st = av_calloc(s->channels, sizeof(DenoiseState));
+    if (!s->st)
+        s->st = av_calloc(s->channels, sizeof(DenoiseState));
     if (!s->st)
         return AVERROR(ENOMEM);
 
     for (int i = 0; i < s->channels; i++) {
         DenoiseState *st = &s->st[i];
 
-        st->rnn.model = s->model;
-        st->rnn.vad_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->vad_gru_size, 16));
-        st->rnn.noise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->noise_gru_size, 16));
-        st->rnn.denoise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->denoise_gru_size, 16));
-        if (!st->rnn.vad_gru_state ||
-            !st->rnn.noise_gru_state ||
-            !st->rnn.denoise_gru_state)
+        st->rnn[0].model = s->model[0];
+        st->rnn[0].vad_gru_state = av_calloc(sizeof(float), FFALIGN(s->model[0]->vad_gru_size, 16));
+        st->rnn[0].noise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model[0]->noise_gru_size, 16));
+        st->rnn[0].denoise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model[0]->denoise_gru_size, 16));
+        if (!st->rnn[0].vad_gru_state ||
+            !st->rnn[0].noise_gru_state ||
+            !st->rnn[0].denoise_gru_state)
             return AVERROR(ENOMEM);
+    }
 
-        ret = av_tx_init(&st->tx, &st->tx_fn, AV_TX_FLOAT_FFT, 0, WINDOW_SIZE, NULL, 0);
+    for (int i = 0; i < s->channels; i++) {
+        DenoiseState *st = &s->st[i];
+
+        if (!st->tx)
+            ret = av_tx_init(&st->tx, &st->tx_fn, AV_TX_FLOAT_FFT, 0, WINDOW_SIZE, NULL, 0);
         if (ret < 0)
             return ret;
 
-        ret = av_tx_init(&st->txi, &st->txi_fn, AV_TX_FLOAT_FFT, 1, WINDOW_SIZE, NULL, 0);
+        if (!st->txi)
+            ret = av_tx_init(&st->txi, &st->txi_fn, AV_TX_FLOAT_FFT, 1, WINDOW_SIZE, NULL, 0);
         if (ret < 0)
             return ret;
     }
@@ -408,8 +440,7 @@ static void inverse_transform(DenoiseState *st, float *out, const AVComplexFloat
     AVComplexFloat x[WINDOW_SIZE];
     AVComplexFloat y[WINDOW_SIZE];
 
-    for (int i = 0; i < FREQ_SIZE; i++)
-        x[i] = in[i];
+    RNN_COPY(x, in, FREQ_SIZE);
 
     for (int i = FREQ_SIZE; i < WINDOW_SIZE; i++) {
         x[i].re =  x[WINDOW_SIZE - i].re;
@@ -492,12 +523,18 @@ static void frame_analysis(AudioRNNContext *s, DenoiseState *st, AVComplexFloat
 static void frame_synthesis(AudioRNNContext *s, DenoiseState *st, float *out, const AVComplexFloat *y)
 {
     LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
+    const float *src = st->history;
+    const float mix = s->mix;
+    const float imix = 1.f - FFMAX(mix, 0.f);
 
     inverse_transform(st, x, y);
     s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
     s->fdsp->vector_fmac_scalar(x, st->synthesis_mem, 1.f, FRAME_SIZE);
     RNN_COPY(out, x, FRAME_SIZE);
     RNN_COPY(st->synthesis_mem, &x[FRAME_SIZE], FRAME_SIZE);
+
+    for (int n = 0; n < FRAME_SIZE; n++)
+        out[n] = out[n] * mix + src[n] * imix;
 }
 
 static inline void xcorr_kernel(const float *x, const float *y, float sum[4], int len)
@@ -776,9 +813,9 @@ static float compute_pitch_gain(float xy, float xx, float yy)
     return xy / sqrtf(1.f + xx * yy);
 }
 
-static const int second_check[16] = {0, 0, 3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2};
-static const float remove_doubling(float *x, int maxperiod, int minperiod,
-                                   int N, int *T0_, int prev_period, float prev_gain)
+static const uint8_t second_check[16] = {0, 0, 3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2};
+static float remove_doubling(float *x, int maxperiod, int minperiod, int N,
+                             int *T0_, int prev_period, float prev_gain)
 {
     int k, i, T, T0;
     float g, g0;
@@ -987,11 +1024,9 @@ static void pitch_search(const float *x_lp, float *y,
 static void dct(AudioRNNContext *s, float *out, const float *in)
 {
     for (int i = 0; i < NB_BANDS; i++) {
-        float sum = 0.f;
+        float sum;
 
-        for (int j = 0; j < NB_BANDS; j++) {
-            sum += in[j] * s->dct_table[j * NB_BANDS + i];
-        }
+        sum = s->fdsp->scalarproduct_float(in, s->dct_table[i], FFALIGN(NB_BANDS, 4));
         out[i] = sum * sqrtf(2.f / 22);
     }
 }
@@ -1002,7 +1037,7 @@ static int compute_frame_features(AudioRNNContext *s, DenoiseState *st, AVComple
     float E = 0;
     float *ceps_0, *ceps_1, *ceps_2;
     float spec_variability = 0;
-    float Ly[NB_BANDS];
+    LOCAL_ALIGNED_32(float, Ly, [NB_BANDS]);
     LOCAL_ALIGNED_32(float, p, [WINDOW_SIZE]);
     float pitch_buf[PITCH_BUF_SIZE>>1];
     int pitch_index;
@@ -1318,37 +1353,37 @@ static void compute_rnn(AudioRNNContext *s, RNNState *rnn, float *gains, float *
     compute_gru(s, rnn->model->vad_gru, rnn->vad_gru_state, dense_out);
     compute_dense(rnn->model->vad_output, vad, rnn->vad_gru_state);
 
-    for (int i = 0; i < rnn->model->input_dense_size; i++)
-        noise_input[i] = dense_out[i];
-    for (int i = 0; i < rnn->model->vad_gru_size; i++)
-        noise_input[i + rnn->model->input_dense_size] = rnn->vad_gru_state[i];
-    for (int i = 0; i < INPUT_SIZE; i++)
-        noise_input[i + rnn->model->input_dense_size + rnn->model->vad_gru_size] = input[i];
+    memcpy(noise_input, dense_out, rnn->model->input_dense_size * sizeof(float));
+    memcpy(noise_input + rnn->model->input_dense_size,
+           rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
+    memcpy(noise_input + rnn->model->input_dense_size + rnn->model->vad_gru_size,
+           input, INPUT_SIZE * sizeof(float));
 
     compute_gru(s, rnn->model->noise_gru, rnn->noise_gru_state, noise_input);
 
-    for (int i = 0; i < rnn->model->vad_gru_size; i++)
-        denoise_input[i] = rnn->vad_gru_state[i];
-    for (int i = 0; i < rnn->model->noise_gru_size; i++)
-        denoise_input[i + rnn->model->vad_gru_size] = rnn->noise_gru_state[i];
-    for (int i = 0; i < INPUT_SIZE; i++)
-        denoise_input[i + rnn->model->vad_gru_size + rnn->model->noise_gru_size] = input[i];
+    memcpy(denoise_input, rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
+    memcpy(denoise_input + rnn->model->vad_gru_size,
+           rnn->noise_gru_state, rnn->model->noise_gru_size * sizeof(float));
+    memcpy(denoise_input + rnn->model->vad_gru_size + rnn->model->noise_gru_size,
+           input, INPUT_SIZE * sizeof(float));
 
     compute_gru(s, rnn->model->denoise_gru, rnn->denoise_gru_state, denoise_input);
     compute_dense(rnn->model->denoise_output, gains, rnn->denoise_gru_state);
 }
 
-static float rnnoise_channel(AudioRNNContext *s, DenoiseState *st, float *out, const float *in)
+static float rnnoise_channel(AudioRNNContext *s, DenoiseState *st, float *out, const float *in,
+                             int disabled)
 {
     AVComplexFloat X[FREQ_SIZE];
     AVComplexFloat P[WINDOW_SIZE];
     float x[FRAME_SIZE];
     float Ex[NB_BANDS], Ep[NB_BANDS];
-    float Exp[NB_BANDS];
+    LOCAL_ALIGNED_32(float, Exp, [NB_BANDS]);
     float features[NB_FEATURES];
     float g[NB_BANDS];
     float gf[FREQ_SIZE];
     float vad_prob = 0;
+    float *history = st->history;
     static const float a_hp[2] = {-1.99599, 0.99600};
     static const float b_hp[2] = {-2, 1};
     int silence;
@@ -1356,8 +1391,8 @@ static float rnnoise_channel(AudioRNNContext *s, DenoiseState *st, float *out, c
     biquad(x, st->mem_hp_x, in, b_hp, a_hp, FRAME_SIZE);
     silence = compute_frame_features(s, st, X, P, Ex, Ep, Exp, features, x);
 
-    if (!silence) {
-        compute_rnn(s, &st->rnn, g, &vad_prob, features);
+    if (!silence && !disabled) {
+        compute_rnn(s, &st->rnn[0], g, &vad_prob, features);
         pitch_filter(X, P, Ex, Ep, Exp, g);
         for (int i = 0; i < NB_BANDS; i++) {
             float alpha = .6f;
@@ -1375,6 +1410,7 @@ static float rnnoise_channel(AudioRNNContext *s, DenoiseState *st, float *out, c
     }
 
     frame_synthesis(s, st, out, X);
+    memcpy(history, in, FRAME_SIZE * sizeof(*history));
 
     return vad_prob;
 }
@@ -1395,7 +1431,8 @@ static int rnnoise_channels(AVFilterContext *ctx, void *arg, int jobnr, int nb_j
     for (int ch = start; ch < end; ch++) {
         rnnoise_channel(s, &s->st[ch],
                         (float *)out->extended_data[ch],
-                        (const float *)in->extended_data[ch]);
+                        (const float *)in->extended_data[ch],
+                        ctx->is_disabled);
     }
 
     return 0;
@@ -1445,25 +1482,40 @@ static int activate(AVFilterContext *ctx)
     return FFERROR_NOT_READY;
 }
 
-static av_cold int init(AVFilterContext *ctx)
+static int open_model(AVFilterContext *ctx, RNNModel **model)
 {
     AudioRNNContext *s = ctx->priv;
+    int ret;
     FILE *f;
 
-    s->fdsp = avpriv_float_dsp_alloc(0);
-    if (!s->fdsp)
-        return AVERROR(ENOMEM);
-
     if (!s->model_name)
         return AVERROR(EINVAL);
     f = av_fopen_utf8(s->model_name, "r");
-    if (!f)
+    if (!f) {
+        av_log(ctx, AV_LOG_ERROR, "Failed to open model file: %s\n", s->model_name);
         return AVERROR(EINVAL);
+    }
 
-    s->model = rnnoise_model_from_file(f);
+    ret = rnnoise_model_from_file(f, model);
     fclose(f);
-    if (!s->model)
-        return AVERROR(EINVAL);
+    if (!*model || ret < 0)
+        return ret;
+
+    return 0;
+}
+
+static av_cold int init(AVFilterContext *ctx)
+{
+    AudioRNNContext *s = ctx->priv;
+    int ret;
+
+    s->fdsp = avpriv_float_dsp_alloc(0);
+    if (!s->fdsp)
+        return AVERROR(ENOMEM);
+
+    ret = open_model(ctx, &s->model[0]);
+    if (ret < 0)
+        return ret;
 
     for (int i = 0; i < FRAME_SIZE; i++) {
         s->window[i] = sin(.5*M_PI*sin(.5*M_PI*(i+.5)/FRAME_SIZE) * sin(.5*M_PI*(i+.5)/FRAME_SIZE));
@@ -1472,31 +1524,68 @@ static av_cold int init(AVFilterContext *ctx)
 
     for (int i = 0; i < NB_BANDS; i++) {
         for (int j = 0; j < NB_BANDS; j++) {
-            s->dct_table[i*NB_BANDS + j] = cosf((i + .5f) * j * M_PI / NB_BANDS);
+            s->dct_table[j][i] = cosf((i + .5f) * j * M_PI / NB_BANDS);
             if (j == 0)
-                s->dct_table[i*NB_BANDS + j] *= sqrtf(.5);
+                s->dct_table[j][i] *= sqrtf(.5);
         }
     }
 
     return 0;
 }
 
+static void free_model(AVFilterContext *ctx, int n)
+{
+    AudioRNNContext *s = ctx->priv;
+
+    rnnoise_model_free(s->model[n]);
+    s->model[n] = NULL;
+
+    for (int ch = 0; ch < s->channels && s->st; ch++) {
+        av_freep(&s->st[ch].rnn[n].vad_gru_state);
+        av_freep(&s->st[ch].rnn[n].noise_gru_state);
+        av_freep(&s->st[ch].rnn[n].denoise_gru_state);
+    }
+}
+
+static int process_command(AVFilterContext *ctx, const char *cmd, const char *args,
+                           char *res, int res_len, int flags)
+{
+    AudioRNNContext *s = ctx->priv;
+    int ret;
+
+    ret = ff_filter_process_command(ctx, cmd, args, res, res_len, flags);
+    if (ret < 0)
+        return ret;
+
+    ret = open_model(ctx, &s->model[1]);
+    if (ret < 0)
+        return ret;
+
+    FFSWAP(RNNModel *, s->model[0], s->model[1]);
+    for (int ch = 0; ch < s->channels; ch++)
+        FFSWAP(RNNState, s->st[ch].rnn[0], s->st[ch].rnn[1]);
+
+    ret = config_input(ctx->inputs[0]);
+    if (ret < 0) {
+        for (int ch = 0; ch < s->channels; ch++)
+            FFSWAP(RNNState, s->st[ch].rnn[0], s->st[ch].rnn[1]);
+        FFSWAP(RNNModel *, s->model[0], s->model[1]);
+        return ret;
+    }
+
+    free_model(ctx, 1);
+    return 0;
+}
+
 static av_cold void uninit(AVFilterContext *ctx)
 {
     AudioRNNContext *s = ctx->priv;
 
     av_freep(&s->fdsp);
-    rnnoise_model_free(s->model);
-    s->model = NULL;
-
-    if (s->st) {
-        for (int ch = 0; ch < s->channels; ch++) {
-            av_freep(&s->st[ch].rnn.vad_gru_state);
-            av_freep(&s->st[ch].rnn.noise_gru_state);
-            av_freep(&s->st[ch].rnn.denoise_gru_state);
-            av_tx_uninit(&s->st[ch].tx);
-            av_tx_uninit(&s->st[ch].txi);
-        }
+    free_model(ctx, 0);
+    for (int ch = 0; ch < s->channels && s->st; ch++) {
+        av_tx_uninit(&s->st[ch].tx);
+        av_tx_uninit(&s->st[ch].txi);
     }
     av_freep(&s->st);
 }
@@ -1519,17 +1608,18 @@ static const AVFilterPad outputs[] = {
 };
 
 #define OFFSET(x) offsetof(AudioRNNContext, x)
-#define AF AV_OPT_FLAG_AUDIO_PARAM|AV_OPT_FLAG_FILTERING_PARAM
+#define AF AV_OPT_FLAG_AUDIO_PARAM|AV_OPT_FLAG_FILTERING_PARAM|AV_OPT_FLAG_RUNTIME_PARAM
 
 static const AVOption arnndn_options[] = {
     { "model", "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
     { "m",     "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
+    { "mix",   "set output vs input mix", OFFSET(mix), AV_OPT_TYPE_FLOAT, {.dbl=1.0},-1, 1, AF },
     { NULL }
 };
 
 AVFILTER_DEFINE_CLASS(arnndn);
 
-AVFilter ff_af_arnndn = {
+const AVFilter ff_af_arnndn = {
     .name          = "arnndn",
     .description   = NULL_IF_CONFIG_SMALL("Reduce noise from speech using Recurrent Neural Networks."),
     .query_formats = query_formats,
@@ -1540,5 +1630,7 @@ AVFilter ff_af_arnndn = {
     .uninit        = uninit,
     .inputs        = inputs,
     .outputs       = outputs,
-    .flags         = AVFILTER_FLAG_SLICE_THREADS,
+    .flags         = AVFILTER_FLAG_SUPPORT_TIMELINE_INTERNAL |
+                     AVFILTER_FLAG_SLICE_THREADS,
+    .process_command = process_command,
 };