]> git.sesse.net Git - ffmpeg/blobdiff - libavutil/tx_template.c
doc/APIchanges: add hashes and version numbers for recent entries
[ffmpeg] / libavutil / tx_template.c
index f78e7abfb111527aedd64d01317dcf1b1d58532e..cad66a8bc0932ab2e180b412c1f41ce212a78305 100644 (file)
@@ -40,6 +40,8 @@ COSTABLE(32768);
 COSTABLE(65536);
 COSTABLE(131072);
 DECLARE_ALIGNED(32, FFTComplex, TX_NAME(ff_cos_53))[4];
+DECLARE_ALIGNED(32, FFTComplex, TX_NAME(ff_cos_7))[3];
+DECLARE_ALIGNED(32, FFTComplex, TX_NAME(ff_cos_9))[4];
 
 static FFTSample * const cos_tabs[18] = {
     NULL,
@@ -103,10 +105,26 @@ static av_cold void ff_init_53_tabs(void)
     TX_NAME(ff_cos_53)[3] = (FFTComplex){ RESCALE(cos(2 * M_PI / 10)), RESCALE(sin(2 * M_PI / 10)) };
 }
 
+static av_cold void ff_init_7_tabs(void)
+{
+    TX_NAME(ff_cos_7)[0] = (FFTComplex){ RESCALE(cos(2 * M_PI /  7)), RESCALE(sin(2 * M_PI /  7)) };
+    TX_NAME(ff_cos_7)[1] = (FFTComplex){ RESCALE(sin(2 * M_PI / 28)), RESCALE(cos(2 * M_PI / 28)) };
+    TX_NAME(ff_cos_7)[2] = (FFTComplex){ RESCALE(cos(2 * M_PI / 14)), RESCALE(sin(2 * M_PI / 14)) };
+}
+
+static av_cold void ff_init_9_tabs(void)
+{
+    TX_NAME(ff_cos_9)[0] = (FFTComplex){ RESCALE(cos(2 * M_PI /  3)), RESCALE( sin(2 * M_PI /  3)) };
+    TX_NAME(ff_cos_9)[1] = (FFTComplex){ RESCALE(cos(2 * M_PI /  9)), RESCALE( sin(2 * M_PI /  9)) };
+    TX_NAME(ff_cos_9)[2] = (FFTComplex){ RESCALE(cos(2 * M_PI / 36)), RESCALE( sin(2 * M_PI / 36)) };
+    TX_NAME(ff_cos_9)[3] = (FFTComplex){ TX_NAME(ff_cos_9)[1].re + TX_NAME(ff_cos_9)[2].im,
+                                         TX_NAME(ff_cos_9)[1].im - TX_NAME(ff_cos_9)[2].re };
+}
+
 static CosTabsInitOnce cos_tabs_init_once[] = {
     { ff_init_53_tabs, AV_ONCE_INIT },
-    { NULL },
-    { NULL },
+    { ff_init_7_tabs, AV_ONCE_INIT },
+    { ff_init_9_tabs, AV_ONCE_INIT },
     { NULL },
     { init_cos_tabs_16, AV_ONCE_INIT },
     { init_cos_tabs_32, AV_ONCE_INIT },
@@ -204,6 +222,217 @@ DECL_FFT5(fft5_m1,  0,  6, 12,  3,  9)
 DECL_FFT5(fft5_m2, 10,  1,  7, 13,  4)
 DECL_FFT5(fft5_m3,  5, 11,  2,  8, 14)
 
+static av_always_inline void fft7(FFTComplex *out, FFTComplex *in,
+                                  ptrdiff_t stride)
+{
+    FFTComplex t[6], z[3];
+    const FFTComplex *tab = TX_NAME(ff_cos_7);
+#ifdef TX_INT32
+    int64_t mtmp[12];
+#endif
+
+    BF(t[1].re, t[0].re, in[1].re, in[6].re);
+    BF(t[1].im, t[0].im, in[1].im, in[6].im);
+    BF(t[3].re, t[2].re, in[2].re, in[5].re);
+    BF(t[3].im, t[2].im, in[2].im, in[5].im);
+    BF(t[5].re, t[4].re, in[3].re, in[4].re);
+    BF(t[5].im, t[4].im, in[3].im, in[4].im);
+
+    out[0*stride].re = in[0].re + t[0].re + t[2].re + t[4].re;
+    out[0*stride].im = in[0].im + t[0].im + t[2].im + t[4].im;
+
+#ifdef TX_INT32 /* NOTE: it's possible to do this with 16 mults but 72 adds */
+    mtmp[ 0] = ((int64_t)tab[0].re)*t[0].re - ((int64_t)tab[2].re)*t[4].re;
+    mtmp[ 1] = ((int64_t)tab[0].re)*t[4].re - ((int64_t)tab[1].re)*t[0].re;
+    mtmp[ 2] = ((int64_t)tab[0].re)*t[2].re - ((int64_t)tab[2].re)*t[0].re;
+    mtmp[ 3] = ((int64_t)tab[0].re)*t[0].im - ((int64_t)tab[1].re)*t[2].im;
+    mtmp[ 4] = ((int64_t)tab[0].re)*t[4].im - ((int64_t)tab[1].re)*t[0].im;
+    mtmp[ 5] = ((int64_t)tab[0].re)*t[2].im - ((int64_t)tab[2].re)*t[0].im;
+
+    mtmp[ 6] = ((int64_t)tab[2].im)*t[1].im + ((int64_t)tab[1].im)*t[5].im;
+    mtmp[ 7] = ((int64_t)tab[0].im)*t[5].im + ((int64_t)tab[2].im)*t[3].im;
+    mtmp[ 8] = ((int64_t)tab[2].im)*t[5].im + ((int64_t)tab[1].im)*t[3].im;
+    mtmp[ 9] = ((int64_t)tab[0].im)*t[1].re + ((int64_t)tab[1].im)*t[3].re;
+    mtmp[10] = ((int64_t)tab[2].im)*t[3].re + ((int64_t)tab[0].im)*t[5].re;
+    mtmp[11] = ((int64_t)tab[2].im)*t[1].re + ((int64_t)tab[1].im)*t[5].re;
+
+    z[0].re = (int32_t)(mtmp[ 0] - ((int64_t)tab[1].re)*t[2].re + 0x40000000 >> 31);
+    z[1].re = (int32_t)(mtmp[ 1] - ((int64_t)tab[2].re)*t[2].re + 0x40000000 >> 31);
+    z[2].re = (int32_t)(mtmp[ 2] - ((int64_t)tab[1].re)*t[4].re + 0x40000000 >> 31);
+    z[0].im = (int32_t)(mtmp[ 3] - ((int64_t)tab[2].re)*t[4].im + 0x40000000 >> 31);
+    z[1].im = (int32_t)(mtmp[ 4] - ((int64_t)tab[2].re)*t[2].im + 0x40000000 >> 31);
+    z[2].im = (int32_t)(mtmp[ 5] - ((int64_t)tab[1].re)*t[4].im + 0x40000000 >> 31);
+
+    t[0].re = (int32_t)(mtmp[ 6] - ((int64_t)tab[0].im)*t[3].im + 0x40000000 >> 31);
+    t[2].re = (int32_t)(mtmp[ 7] - ((int64_t)tab[1].im)*t[1].im + 0x40000000 >> 31);
+    t[4].re = (int32_t)(mtmp[ 8] + ((int64_t)tab[0].im)*t[1].im + 0x40000000 >> 31);
+    t[0].im = (int32_t)(mtmp[ 9] + ((int64_t)tab[2].im)*t[5].re + 0x40000000 >> 31);
+    t[2].im = (int32_t)(mtmp[10] - ((int64_t)tab[1].im)*t[1].re + 0x40000000 >> 31);
+    t[4].im = (int32_t)(mtmp[11] - ((int64_t)tab[0].im)*t[3].re + 0x40000000 >> 31);
+#else
+    z[0].re = tab[0].re*t[0].re - tab[2].re*t[4].re - tab[1].re*t[2].re;
+    z[1].re = tab[0].re*t[4].re - tab[1].re*t[0].re - tab[2].re*t[2].re;
+    z[2].re = tab[0].re*t[2].re - tab[2].re*t[0].re - tab[1].re*t[4].re;
+    z[0].im = tab[0].re*t[0].im - tab[1].re*t[2].im - tab[2].re*t[4].im;
+    z[1].im = tab[0].re*t[4].im - tab[1].re*t[0].im - tab[2].re*t[2].im;
+    z[2].im = tab[0].re*t[2].im - tab[2].re*t[0].im - tab[1].re*t[4].im;
+
+    /* It's possible to do t[4].re and t[0].im with 2 multiplies only by
+     * multiplying the sum of all with the average of the twiddles */
+
+    t[0].re = tab[2].im*t[1].im + tab[1].im*t[5].im - tab[0].im*t[3].im;
+    t[2].re = tab[0].im*t[5].im + tab[2].im*t[3].im - tab[1].im*t[1].im;
+    t[4].re = tab[2].im*t[5].im + tab[1].im*t[3].im + tab[0].im*t[1].im;
+    t[0].im = tab[0].im*t[1].re + tab[1].im*t[3].re + tab[2].im*t[5].re;
+    t[2].im = tab[2].im*t[3].re + tab[0].im*t[5].re - tab[1].im*t[1].re;
+    t[4].im = tab[2].im*t[1].re + tab[1].im*t[5].re - tab[0].im*t[3].re;
+#endif
+
+    BF(t[1].re, z[0].re, z[0].re, t[4].re);
+    BF(t[3].re, z[1].re, z[1].re, t[2].re);
+    BF(t[5].re, z[2].re, z[2].re, t[0].re);
+    BF(t[1].im, z[0].im, z[0].im, t[0].im);
+    BF(t[3].im, z[1].im, z[1].im, t[2].im);
+    BF(t[5].im, z[2].im, z[2].im, t[4].im);
+
+    out[1*stride].re = in[0].re + z[0].re;
+    out[1*stride].im = in[0].im + t[1].im;
+    out[2*stride].re = in[0].re + t[3].re;
+    out[2*stride].im = in[0].im + z[1].im;
+    out[3*stride].re = in[0].re + z[2].re;
+    out[3*stride].im = in[0].im + t[5].im;
+    out[4*stride].re = in[0].re + t[5].re;
+    out[4*stride].im = in[0].im + z[2].im;
+    out[5*stride].re = in[0].re + z[1].re;
+    out[5*stride].im = in[0].im + t[3].im;
+    out[6*stride].re = in[0].re + t[1].re;
+    out[6*stride].im = in[0].im + z[0].im;
+}
+
+static av_always_inline void fft9(FFTComplex *out, FFTComplex *in,
+                                  ptrdiff_t stride)
+{
+    const FFTComplex *tab = TX_NAME(ff_cos_9);
+    FFTComplex t[16], w[4], x[5], y[5], z[2];
+#ifdef TX_INT32
+    int64_t mtmp[12];
+#endif
+
+    BF(t[1].re, t[0].re, in[1].re, in[8].re);
+    BF(t[1].im, t[0].im, in[1].im, in[8].im);
+    BF(t[3].re, t[2].re, in[2].re, in[7].re);
+    BF(t[3].im, t[2].im, in[2].im, in[7].im);
+    BF(t[5].re, t[4].re, in[3].re, in[6].re);
+    BF(t[5].im, t[4].im, in[3].im, in[6].im);
+    BF(t[7].re, t[6].re, in[4].re, in[5].re);
+    BF(t[7].im, t[6].im, in[4].im, in[5].im);
+
+    w[0].re = t[0].re - t[6].re;
+    w[0].im = t[0].im - t[6].im;
+    w[1].re = t[2].re - t[6].re;
+    w[1].im = t[2].im - t[6].im;
+    w[2].re = t[1].re - t[7].re;
+    w[2].im = t[1].im - t[7].im;
+    w[3].re = t[3].re + t[7].re;
+    w[3].im = t[3].im + t[7].im;
+
+    z[0].re = in[0].re + t[4].re;
+    z[0].im = in[0].im + t[4].im;
+
+    z[1].re = t[0].re + t[2].re + t[6].re;
+    z[1].im = t[0].im + t[2].im + t[6].im;
+
+    out[0*stride].re = z[0].re + z[1].re;
+    out[0*stride].im = z[0].im + z[1].im;
+
+#ifdef TX_INT32
+    mtmp[0] = t[1].re - t[3].re + t[7].re;
+    mtmp[1] = t[1].im - t[3].im + t[7].im;
+
+    y[3].re = (int32_t)(((int64_t)tab[0].im)*mtmp[0] + 0x40000000 >> 31);
+    y[3].im = (int32_t)(((int64_t)tab[0].im)*mtmp[1] + 0x40000000 >> 31);
+
+    mtmp[0] = (int32_t)(((int64_t)tab[0].re)*z[1].re + 0x40000000 >> 31);
+    mtmp[1] = (int32_t)(((int64_t)tab[0].re)*z[1].im + 0x40000000 >> 31);
+    mtmp[2] = (int32_t)(((int64_t)tab[0].re)*t[4].re + 0x40000000 >> 31);
+    mtmp[3] = (int32_t)(((int64_t)tab[0].re)*t[4].im + 0x40000000 >> 31);
+
+    x[3].re = z[0].re  + (int32_t)mtmp[0];
+    x[3].im = z[0].im  + (int32_t)mtmp[1];
+    z[0].re = in[0].re + (int32_t)mtmp[2];
+    z[0].im = in[0].im + (int32_t)mtmp[3];
+
+    mtmp[0] = ((int64_t)tab[1].re)*w[0].re;
+    mtmp[1] = ((int64_t)tab[1].re)*w[0].im;
+    mtmp[2] = ((int64_t)tab[2].im)*w[0].re;
+    mtmp[3] = ((int64_t)tab[2].im)*w[0].im;
+    mtmp[4] = ((int64_t)tab[1].im)*w[2].re;
+    mtmp[5] = ((int64_t)tab[1].im)*w[2].im;
+    mtmp[6] = ((int64_t)tab[2].re)*w[2].re;
+    mtmp[7] = ((int64_t)tab[2].re)*w[2].im;
+
+    x[1].re = (int32_t)(mtmp[0] + ((int64_t)tab[2].im)*w[1].re + 0x40000000 >> 31);
+    x[1].im = (int32_t)(mtmp[1] + ((int64_t)tab[2].im)*w[1].im + 0x40000000 >> 31);
+    x[2].re = (int32_t)(mtmp[2] - ((int64_t)tab[3].re)*w[1].re + 0x40000000 >> 31);
+    x[2].im = (int32_t)(mtmp[3] - ((int64_t)tab[3].re)*w[1].im + 0x40000000 >> 31);
+    y[1].re = (int32_t)(mtmp[4] + ((int64_t)tab[2].re)*w[3].re + 0x40000000 >> 31);
+    y[1].im = (int32_t)(mtmp[5] + ((int64_t)tab[2].re)*w[3].im + 0x40000000 >> 31);
+    y[2].re = (int32_t)(mtmp[6] - ((int64_t)tab[3].im)*w[3].re + 0x40000000 >> 31);
+    y[2].im = (int32_t)(mtmp[7] - ((int64_t)tab[3].im)*w[3].im + 0x40000000 >> 31);
+
+    y[0].re = (int32_t)(((int64_t)tab[0].im)*t[5].re + 0x40000000 >> 31);
+    y[0].im = (int32_t)(((int64_t)tab[0].im)*t[5].im + 0x40000000 >> 31);
+
+#else
+    y[3].re = tab[0].im*(t[1].re - t[3].re + t[7].re);
+    y[3].im = tab[0].im*(t[1].im - t[3].im + t[7].im);
+
+    x[3].re = z[0].re  + tab[0].re*z[1].re;
+    x[3].im = z[0].im  + tab[0].re*z[1].im;
+    z[0].re = in[0].re + tab[0].re*t[4].re;
+    z[0].im = in[0].im + tab[0].re*t[4].im;
+
+    x[1].re = tab[1].re*w[0].re + tab[2].im*w[1].re;
+    x[1].im = tab[1].re*w[0].im + tab[2].im*w[1].im;
+    x[2].re = tab[2].im*w[0].re - tab[3].re*w[1].re;
+    x[2].im = tab[2].im*w[0].im - tab[3].re*w[1].im;
+    y[1].re = tab[1].im*w[2].re + tab[2].re*w[3].re;
+    y[1].im = tab[1].im*w[2].im + tab[2].re*w[3].im;
+    y[2].re = tab[2].re*w[2].re - tab[3].im*w[3].re;
+    y[2].im = tab[2].re*w[2].im - tab[3].im*w[3].im;
+
+    y[0].re = tab[0].im*t[5].re;
+    y[0].im = tab[0].im*t[5].im;
+#endif
+
+    x[4].re = x[1].re + x[2].re;
+    x[4].im = x[1].im + x[2].im;
+
+    y[4].re = y[1].re - y[2].re;
+    y[4].im = y[1].im - y[2].im;
+    x[1].re = z[0].re + x[1].re;
+    x[1].im = z[0].im + x[1].im;
+    y[1].re = y[0].re + y[1].re;
+    y[1].im = y[0].im + y[1].im;
+    x[2].re = z[0].re + x[2].re;
+    x[2].im = z[0].im + x[2].im;
+    y[2].re = y[2].re - y[0].re;
+    y[2].im = y[2].im - y[0].im;
+    x[4].re = z[0].re - x[4].re;
+    x[4].im = z[0].im - x[4].im;
+    y[4].re = y[0].re - y[4].re;
+    y[4].im = y[0].im - y[4].im;
+
+    out[1*stride] = (FFTComplex){ x[1].re + y[1].im, x[1].im - y[1].re };
+    out[2*stride] = (FFTComplex){ x[2].re + y[2].im, x[2].im - y[2].re };
+    out[3*stride] = (FFTComplex){ x[3].re + y[3].im, x[3].im - y[3].re };
+    out[4*stride] = (FFTComplex){ x[4].re + y[4].im, x[4].im - y[4].re };
+    out[5*stride] = (FFTComplex){ x[4].re - y[4].im, x[4].im + y[4].re };
+    out[6*stride] = (FFTComplex){ x[3].re - y[3].im, x[3].im + y[3].re };
+    out[7*stride] = (FFTComplex){ x[2].re - y[2].im, x[2].im + y[2].re };
+    out[8*stride] = (FFTComplex){ x[1].re - y[1].im, x[1].im + y[1].re };
+}
+
 static av_always_inline void fft15(FFTComplex *out, FFTComplex *in,
                                    ptrdiff_t stride)
 {
@@ -364,7 +593,7 @@ static void compound_fft_##N##xM(AVTXContext *s, void *_out,                   \
     for (int i = 0; i < m; i++) {                                              \
         for (int j = 0; j < N; j++)                                            \
             fft##N##in[j] = in[in_map[i*N + j]];                               \
-        fft##N(s->tmp + s->revtab[i], fft##N##in, m);                          \
+        fft##N(s->tmp + s->revtab_c[i], fft##N##in, m);                        \
     }                                                                          \
                                                                                \
     for (int i = 0; i < N; i++)                                                \
@@ -376,6 +605,8 @@ static void compound_fft_##N##xM(AVTXContext *s, void *_out,                   \
 
 DECL_COMP_FFT(3)
 DECL_COMP_FFT(5)
+DECL_COMP_FFT(7)
+DECL_COMP_FFT(9)
 DECL_COMP_FFT(15)
 
 static void split_radix_fft(AVTXContext *s, void *_out, void *_in,
@@ -393,16 +624,16 @@ static void split_radix_fft(AVTXContext *s, void *_out, void *_in,
 
         do {
             tmp = out[src];
-            dst = s->revtab[src];
+            dst = s->revtab_c[src];
             do {
                 FFSWAP(FFTComplex, tmp, out[dst]);
-                dst = s->revtab[dst];
+                dst = s->revtab_c[dst];
             } while (dst != src); /* Can be > as well, but is less predictable */
             out[dst] = tmp;
         } while ((src = *inplace_idx++));
     } else {
         for (int i = 0; i < m; i++)
-            out[i] = in[s->revtab[i]];
+            out[i] = in[s->revtab_c[i]];
     }
 
     fft_dispatch[mb](out);
@@ -454,7 +685,7 @@ static void compound_imdct_##N##xM(AVTXContext *s, void *_dst, void *_src,     \
             FFTComplex tmp = { in2[-k*stride], in1[k*stride] };                \
             CMUL3(fft##N##in[j], tmp, exp[k >> 1]);                            \
         }                                                                      \
-        fft##N(s->tmp + s->revtab[i], fft##N##in, m);                          \
+        fft##N(s->tmp + s->revtab_c[i], fft##N##in, m);                        \
     }                                                                          \
                                                                                \
     for (int i = 0; i < N; i++)                                                \
@@ -473,6 +704,8 @@ static void compound_imdct_##N##xM(AVTXContext *s, void *_dst, void *_src,     \
 
 DECL_COMP_IMDCT(3)
 DECL_COMP_IMDCT(5)
+DECL_COMP_IMDCT(7)
+DECL_COMP_IMDCT(9)
 DECL_COMP_IMDCT(15)
 
 #define DECL_COMP_MDCT(N)                                                      \
@@ -500,7 +733,7 @@ static void compound_mdct_##N##xM(AVTXContext *s, void *_dst, void *_src,      \
             CMUL(fft##N##in[j].im, fft##N##in[j].re, tmp.re, tmp.im,           \
                  exp[k >> 1].re, exp[k >> 1].im);                              \
         }                                                                      \
-        fft##N(s->tmp + s->revtab[i], fft##N##in, m);                          \
+        fft##N(s->tmp + s->revtab_c[i], fft##N##in, m);                        \
     }                                                                          \
                                                                                \
     for (int i = 0; i < N; i++)                                                \
@@ -521,6 +754,8 @@ static void compound_mdct_##N##xM(AVTXContext *s, void *_dst, void *_src,      \
 
 DECL_COMP_MDCT(3)
 DECL_COMP_MDCT(5)
+DECL_COMP_MDCT(7)
+DECL_COMP_MDCT(9)
 DECL_COMP_MDCT(15)
 
 static void monolithic_imdct(AVTXContext *s, void *_dst, void *_src,
@@ -537,7 +772,7 @@ static void monolithic_imdct(AVTXContext *s, void *_dst, void *_src,
 
     for (int i = 0; i < m; i++) {
         FFTComplex tmp = { in2[-2*i*stride], in1[2*i*stride] };
-        CMUL3(z[s->revtab[i]], tmp, exp[i]);
+        CMUL3(z[s->revtab_c[i]], tmp, exp[i]);
     }
 
     fftp(z);
@@ -571,7 +806,7 @@ static void monolithic_mdct(AVTXContext *s, void *_dst, void *_src,
             tmp.re = FOLD(-src[ len4 + k], -src[5*len4 - 1 - k]);
             tmp.im = FOLD( src[-len4 + k], -src[1*len3 - 1 - k]);
         }
-        CMUL(z[s->revtab[i]].im, z[s->revtab[i]].re, tmp.re, tmp.im,
+        CMUL(z[s->revtab_c[i]].im, z[s->revtab_c[i]].re, tmp.re, tmp.im,
              exp[i].re, exp[i].im);
     }
 
@@ -640,6 +875,24 @@ static void naive_mdct(AVTXContext *s, void *_dst, void *_src,
     }
 }
 
+static void full_imdct_wrapper_fn(AVTXContext *s, void *_dst, void *_src,
+                                  ptrdiff_t stride)
+{
+    int len = s->m*s->n*4;
+    int len2 = len >> 1;
+    int len4 = len >> 2;
+    FFTSample *dst = _dst;
+
+    s->top_tx(s, dst + len4, _src, stride);
+
+    stride /= sizeof(*dst);
+
+    for (int i = 0; i < len4; i++) {
+        dst[            i*stride] = -dst[(len2 - i - 1)*stride];
+        dst[(len - i - 1)*stride] =  dst[(len2 + i + 0)*stride];
+    }
+}
+
 static int gen_mdct_exptab(AVTXContext *s, int len4, double scale)
 {
     const double theta = (scale < 0 ? len4 : 0) + 1.0/8.0;
@@ -675,6 +928,8 @@ int TX_NAME(ff_tx_init_mdct_fft)(AVTXContext *s, av_tx_fn *tx,
         SRC /= FACTOR;                                                         \
     }
     CHECK_FACTOR(n, 15, len)
+    CHECK_FACTOR(n,  9, len)
+    CHECK_FACTOR(n,  7, len)
     CHECK_FACTOR(n,  5, len)
     CHECK_FACTOR(n,  3, len)
 #undef CHECK_FACTOR
@@ -705,6 +960,10 @@ int TX_NAME(ff_tx_init_mdct_fft)(AVTXContext *s, av_tx_fn *tx,
         if (is_mdct) {
             s->scale = *((SCALE_TYPE *)scale);
             *tx = inv ? naive_imdct : naive_mdct;
+            if (inv && (flags & AV_TX_FULL_IMDCT)) {
+                s->top_tx = *tx;
+                *tx = full_imdct_wrapper_fn;
+            }
         }
         return 0;
     }
@@ -714,36 +973,52 @@ int TX_NAME(ff_tx_init_mdct_fft)(AVTXContext *s, av_tx_fn *tx,
             return err;
         if (!(s->tmp = av_malloc(n*m*sizeof(*s->tmp))))
             return AVERROR(ENOMEM);
-        *tx = n == 3 ? compound_fft_3xM :
-              n == 5 ? compound_fft_5xM :
-                       compound_fft_15xM;
-        if (is_mdct)
-            *tx = n == 3 ? inv ? compound_imdct_3xM  : compound_mdct_3xM :
-                  n == 5 ? inv ? compound_imdct_5xM  : compound_mdct_5xM :
-                           inv ? compound_imdct_15xM : compound_mdct_15xM;
+        if (!(m & (m - 1))) {
+            *tx = n == 3 ? compound_fft_3xM :
+                  n == 5 ? compound_fft_5xM :
+                  n == 7 ? compound_fft_7xM :
+                  n == 9 ? compound_fft_9xM :
+                           compound_fft_15xM;
+            if (is_mdct)
+                *tx = n == 3 ? inv ? compound_imdct_3xM  : compound_mdct_3xM :
+                      n == 5 ? inv ? compound_imdct_5xM  : compound_mdct_5xM :
+                      n == 7 ? inv ? compound_imdct_7xM  : compound_mdct_7xM :
+                      n == 9 ? inv ? compound_imdct_9xM  : compound_mdct_9xM :
+                               inv ? compound_imdct_15xM : compound_mdct_15xM;
+        }
     } else { /* Direct transform case */
         *tx = split_radix_fft;
         if (is_mdct)
             *tx = inv ? monolithic_imdct : monolithic_mdct;
     }
 
-    if (n != 1)
+    if (n == 3 || n == 5 || n == 15)
         init_cos_tabs(0);
-    if (m != 1) {
+    else if (n == 7)
+        init_cos_tabs(1);
+    else if (n == 9)
+        init_cos_tabs(2);
+
+    if (m != 1 && !(m & (m - 1))) {
         if ((err = ff_tx_gen_ptwo_revtab(s, n == 1 && !is_mdct && !(flags & AV_TX_INPLACE))))
             return err;
         if (flags & AV_TX_INPLACE) {
             if (is_mdct) /* In-place MDCTs are not supported yet */
                 return AVERROR(ENOSYS);
-            if ((err = ff_tx_gen_ptwo_inplace_revtab_idx(s)))
+            if ((err = ff_tx_gen_ptwo_inplace_revtab_idx(s, s->revtab_c)))
                 return err;
         }
         for (int i = 4; i <= av_log2(m); i++)
             init_cos_tabs(i);
     }
 
-    if (is_mdct)
+    if (is_mdct) {
+        if (inv && (flags & AV_TX_FULL_IMDCT)) {
+            s->top_tx = *tx;
+            *tx = full_imdct_wrapper_fn;
+        }
         return gen_mdct_exptab(s, n*m, *((SCALE_TYPE *)scale));
+    }
 
     return 0;
 }