]> git.sesse.net Git - movit/commitdiff
Add a new effect that can do FFT/IFFT.
authorSteinar H. Gunderson <sgunderson@bigfoot.com>
Sun, 19 Jan 2014 17:27:55 +0000 (18:27 +0100)
committerSteinar H. Gunderson <sgunderson@bigfoot.com>
Sun, 19 Jan 2014 17:32:54 +0000 (18:32 +0100)
Makefile.in
fft_pass_effect.cpp [new file with mode: 0644]
fft_pass_effect.frag [new file with mode: 0644]
fft_pass_effect.h [new file with mode: 0644]
fft_pass_effect_test.cpp [new file with mode: 0644]

index 78b82fba6bcb511a8b3302cad3a3c1cb014e7c67..73766ac75b31b69765ab4be9da5d4519a63559a9 100644 (file)
@@ -49,6 +49,7 @@ TESTED_EFFECTS += padding_effect
 TESTED_EFFECTS += resample_effect
 TESTED_EFFECTS += dither_effect
 TESTED_EFFECTS += deconvolution_sharpen_effect
+TESTED_EFFECTS += fft_pass_effect
 
 UNTESTED_EFFECTS = sandbox_effect
 UNTESTED_EFFECTS += mirror_effect
diff --git a/fft_pass_effect.cpp b/fft_pass_effect.cpp
new file mode 100644 (file)
index 0000000..a3de379
--- /dev/null
@@ -0,0 +1,155 @@
+#include <GL/glew.h>
+
+#include "fft_pass_effect.h"
+#include "effect_util.h"
+#include "util.h"
+
+FFTPassEffect::FFTPassEffect()
+       : input_width(1280),
+         input_height(720),
+         direction(HORIZONTAL)
+{
+       register_int("fft_size", &fft_size);
+       register_int("direction", (int *)&direction);
+       register_int("pass_number", &pass_number);
+       register_int("inverse", &inverse);
+       glGenTextures(1, &tex);
+}
+
+FFTPassEffect::~FFTPassEffect()
+{
+       glDeleteTextures(1, &tex);
+}
+
+std::string FFTPassEffect::output_fragment_shader()
+{
+       char buf[256];
+       sprintf(buf, "#define DIRECTION_VERTICAL %d\n", (direction == VERTICAL));
+       return buf + read_file("fft_pass_effect.frag");
+}
+
+void FFTPassEffect::set_gl_state(GLuint glsl_program_num, const std::string &prefix, unsigned *sampler_num)
+{
+       Effect::set_gl_state(glsl_program_num, prefix, sampler_num);
+
+       int input_size = (direction == VERTICAL) ? input_height : input_width;
+
+       // See the comments on changes_output_size() in the .h file to see
+       // why this is legal. It is _needed_ because it counteracts the
+       // precision issues we get because we sample the input texture with
+       // normalized coordinates (especially when the repeat count along
+       // the axis is not a power of two); we very rapidly end up in narrowly
+       // missing a texel center, which causes precision loss to propagate
+       // throughout the FFT.
+       assert(*sampler_num == 1);
+       glActiveTexture(GL_TEXTURE0);
+       check_error();
+       glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
+       check_error();
+       glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
+       check_error();
+
+       // The memory layout follows figure 5.2 on page 25 of
+       // http://gpuwave.sesse.net/gpuwave.pdf -- it can be a bit confusing
+       // at first, but is classically explained more or less as follows:
+       //
+       // The classic Cooley-Tukey decimation-in-time FFT algorithm works
+       // by first splitting input data into odd and even elements
+       // (e.g. bit-wise xxxxx0 and xxxxx1 for a size-32 FFT), then FFTing
+       // them separately and combining them using twiddle factors.
+       // So the outer pass (done _last_) looks only at the last bit,
+       // and does one such merge pass of sub-size N/2 (FFT size N).
+       //
+       // FFT of the first part must then necessarily be split into xxxx00 and
+       // xxxx10, and similarly xxxx01 and xxxx11 for the other part. Since
+       // these two FFTs are handled identically, it means we split into xxxx0x
+       // and xxxx1x, so that the second-outer pass (done second-to-last)
+       // looks only at the second last bit, and so on. We do two such merge
+       // passes of sub-size N/4 (sub-FFT size N/2).
+       //
+       // Thus, the inner, Nth pass (done first) splits at the first bit,
+       // so 0 is paired with 16, 1 with 17 and so on, doing N/2 such merge
+       // passes of sub-size 1 (sub-FFT size 2). We say that the stride is 16.
+       // The second-inner, (N-1)th pass (done second) splits at the second
+       // bit, so the stride is 8, and so on.
+
+       assert((fft_size & (fft_size - 1)) == 0);  // Must be power of two.
+       float *tmp = new float[fft_size * 4];
+       int subfft_size = 1 << pass_number;
+       double mulfac;
+       if (inverse) {
+               mulfac = 2.0 * M_PI;
+       } else {
+               mulfac = -2.0 * M_PI;
+       }
+
+       assert((fft_size & (fft_size - 1)) == 0);  // Must be power of two.
+       assert(fft_size % subfft_size == 0);
+       int stride = fft_size / subfft_size;
+       for (int i = 0; i < fft_size; ++i) {
+               int k = i / stride;         // Element number within this sub-FFT.
+               int offset = i % stride;    // Sub-FFT number.
+               double twiddle_real, twiddle_imag;
+
+               if (k < subfft_size / 2) {
+                       twiddle_real = cos(mulfac * (k / double(subfft_size)));
+                       twiddle_imag = sin(mulfac * (k / double(subfft_size)));
+               } else {
+                       // This is mathematically equivalent to the twiddle factor calculations
+                       // in the other branch of the if, but not numerically; the range
+                       // reductions on x87 are not all that precise, and this keeps us within
+                       // [0,pi>.
+                       k -= subfft_size / 2;
+                       twiddle_real = -cos(mulfac * (k / double(subfft_size)));
+                       twiddle_imag = -sin(mulfac * (k / double(subfft_size)));
+               }
+
+               // The support texture contains everything we need for the FFT:
+               // Obviously, the twiddle factor (in the Z and W components), but also
+               // which two samples to fetch. These are stored as normalized
+               // X coordinate offsets (Y coordinate for a vertical FFT); the reason
+               // for using offsets and not direct coordinates as in GPUwave
+               // is that we can have multiple FFTs along the same line,
+               // and want to reuse the support texture by repeating it.
+               int base = k * stride * 2 + offset;
+               int support_texture_index;
+               if (direction == FFTPassEffect::VERTICAL) {
+                       // Compensate for OpenGL's bottom-left convention.
+                       support_texture_index = fft_size - i - 1;
+               } else {
+                       support_texture_index = i;
+               }
+               tmp[support_texture_index * 4 + 0] = (base - support_texture_index) / double(input_size);
+               tmp[support_texture_index * 4 + 1] = (base + stride - support_texture_index) / double(input_size);
+               tmp[support_texture_index * 4 + 2] = twiddle_real;
+               tmp[support_texture_index * 4 + 3] = twiddle_imag;
+       }
+
+       glActiveTexture(GL_TEXTURE0 + *sampler_num);
+       check_error();
+       glBindTexture(GL_TEXTURE_1D, tex);
+       check_error();
+       glTexParameteri(GL_TEXTURE_1D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
+       check_error();
+       glTexParameteri(GL_TEXTURE_1D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
+       check_error();
+       glTexParameteri(GL_TEXTURE_1D, GL_TEXTURE_WRAP_S, GL_REPEAT);
+       check_error();
+
+       // Supposedly FFTs are very sensitive to inaccuracies in the twiddle factors,
+       // at least according to a paper by Schatzman (see gpuwave.pdf reference [30]
+       // for the full reference), so we keep them at 32-bit. However, for
+       // small sizes, all components are exact anyway, so we can cheat there
+       // (although noting that the source coordinates become somewhat less
+       // accurate then, too).
+       glTexImage1D(GL_TEXTURE_1D, 0, (subfft_size <= 4) ? GL_RGBA16F : GL_RGBA32F, fft_size, 0, GL_RGBA, GL_FLOAT, tmp);
+       check_error();
+
+       delete[] tmp;
+
+       set_uniform_int(glsl_program_num, prefix, "support_tex", *sampler_num);
+       ++*sampler_num;
+
+       assert(input_size % fft_size == 0);
+       set_uniform_float(glsl_program_num, prefix, "num_repeats", input_size / fft_size);
+}
diff --git a/fft_pass_effect.frag b/fft_pass_effect.frag
new file mode 100644 (file)
index 0000000..462a673
--- /dev/null
@@ -0,0 +1,24 @@
+// DIRECTION_VERTICAL will be #defined to 1 if we are doing a vertical FFT,
+// and 0 otherwise.
+
+uniform float PREFIX(num_repeats);
+uniform sampler1D PREFIX(support_tex);
+
+vec4 FUNCNAME(vec2 tc) {
+#if DIRECTION_VERTICAL
+       vec4 support = texture1D(PREFIX(support_tex), tc.y * PREFIX(num_repeats));
+        vec4 c1 = INPUT(vec2(tc.x, 1.0 - (tc.y + support.x)));
+        vec4 c2 = INPUT(vec2(tc.x, 1.0 - (tc.y + support.y)));
+#else
+       vec4 support = texture1D(PREFIX(support_tex), tc.x * PREFIX(num_repeats));
+        vec4 c1 = INPUT(vec2(tc.x + support.x, tc.y));
+        vec4 c2 = INPUT(vec2(tc.x + support.y, tc.y));
+#endif
+       // Two complex additions and multiplications in parallel; essentially
+       //
+       //   result.xy = c1.xy + twiddle * c2.xy
+       //   result.zw = c1.zw + twiddle * c2.zw
+       //
+       // where * is complex multiplication.
+       return c1 + support.z * c2 + support.w * vec4(-c2.y, c2.x, -c2.w, c2.z);
+}
diff --git a/fft_pass_effect.h b/fft_pass_effect.h
new file mode 100644 (file)
index 0000000..b3e025b
--- /dev/null
@@ -0,0 +1,110 @@
+#ifndef _MOVIT_FFT_PASS_EFFECT_H
+#define _MOVIT_FFT_PASS_EFFECT_H 1
+
+// One pass of a radix-2, in-order, decimation-in-time 1D FFT/IFFT. If you
+// connect multiple ones of these together, you will eventually have a complete
+// FFT or IFFT. The FFTed data is not so useful for video effects in itself,
+// but enables faster convolutions (especially non-separable 2D convolutions)
+// than can be done directly, by doing FFT -> multiply -> IFFT. The utilities
+// for doing this efficiently will probably be added to Movit at a later date;
+// for now, this effect isn't the most useful.
+//
+// An introduction to FFTs is outside the scope of a file-level comment; see
+// http://en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm#The_radix-2_DIT_case .
+//
+// The pixels are not really interpreted as pixels, but are interpreted as two
+// complex numbers with (real,imaginary) parts stored in (R,G) and (B,A).
+// On top of this two-way parallelism, many FFTs are done in parallel (see below).
+//
+// Implementing a high-performance FFT on the GPU is not easy, especially
+// within the demands of Movit filters. (This is one of the places where
+// using CUDA or D3D would be easier, as both ship with pre-made and highly
+// tuned FFTs.) We don't go to great lengths to get an optimal implementation,
+// but rather stay with someting simple. I'll conveniently enough refer to
+// my own report on this topic from 2007, namely
+//
+//    Steinar H. Gunderson: “GPUwave: An implementation of the split-step
+//    Fourier method for the GPU”, http://gpuwave.sesse.net/gpuwave.pdf
+//
+// Chapter 5 contains the details of the FFT. We follow this rather closely,
+// with the exception that in Movit, we only ever draw a single quad,
+// so the strategy used in GPUwave with drawing multiple quads with constant
+// twiddle factors on them will not be in use here. (It requires some
+// benchmarking to find the optimal crossover point anyway.)
+//
+// Also, we support doing many FFTs along the same axis, so e.g. if you
+// have a 128x128 image and ask for a horizontal FFT of size 64, you will
+// actually get 256 of them (two wide, 128 high). This is in contrast with
+// GPUwave, which only supports them one wide; in a picture setting,
+// moving blocks around to create only one block wide FFTs would rapidly
+// lead to way too slender textures to be practical (e.g., 1280x720
+// with an FFT of size 64 would be 64x14400 rearranged, and many GPUs
+// have limits of 8192 pixels or even 2048 along one dimension).
+//
+// Note that this effect produces an _unnormalized_ FFT, which means that a
+// FFT -> IFFT chain will end up not returning the original data (even modulo
+// precision errors) but rather the original data with each element multiplied
+// by N, the FFT size. As the FFT and IFFT contribute equally to this energy
+// gain, it is recommended that you do the division by N after the FFT but
+// before the IFFT. This way, you use the least range possible (for one
+// scaling), and as fp16 has quite limited range at times, this can be relevant
+// on some GPUs for larger sizes.
+
+#include <stdio.h>
+#include <GL/glew.h>
+#include <string>
+
+#include "effect.h"
+
+class FFTPassEffect : public Effect {
+public:
+       FFTPassEffect();
+       ~FFTPassEffect();
+       virtual std::string effect_type_id() const {
+               char buf[256];
+               if (inverse) {
+                       snprintf(buf, sizeof(buf), "IFFTPassEffect[%d]", (1 << pass_number));
+               } else {
+                       snprintf(buf, sizeof(buf), "FFTPassEffect[%d]", (1 << pass_number));
+               }
+               return buf;
+       }
+       std::string output_fragment_shader();
+
+       void set_gl_state(GLuint glsl_program_num, const std::string &prefix, unsigned *sampler_num);
+
+       // We don't actually change the output size, but this flag makes sure
+       // that no other effect is chained after us. This is important since
+       // we cannot deliver filtered results; any attempt at sampling in-between
+       // pixels would necessarily give garbage. In addition, we set our sampling
+       // mode to GL_NEAREST, which other effects are not ready for; so, the
+       // combination of these two flags guarantee that we're run entirely alone
+       // in our own phase, which is exactly what we want.
+       virtual bool needs_texture_bounce() const { return true; }
+       virtual bool changes_output_size() const { return true; }
+
+       virtual void inform_input_size(unsigned input_num, unsigned width, unsigned height)
+       {
+               assert(input_num == 0);
+               input_width = width;
+               input_height = height;
+       }
+       
+       virtual void get_output_size(unsigned *width, unsigned *height,
+                                    unsigned *virtual_width, unsigned *virtual_height) const {
+               *width = *virtual_width = input_width;
+               *height = *virtual_height = input_height;
+       }
+       
+       enum Direction { HORIZONTAL = 0, VERTICAL = 1 };
+
+private:
+       int input_width, input_height;
+       GLuint tex;
+       int fft_size;
+       Direction direction;
+       int pass_number;  // From 1..n.
+       int inverse;  // 0 = forward (FFT), 1 = reverse (IFFT).
+};
+
+#endif // !defined(_MOVIT_FFT_PASS_EFFECT_H)
diff --git a/fft_pass_effect_test.cpp b/fft_pass_effect_test.cpp
new file mode 100644 (file)
index 0000000..6a6406c
--- /dev/null
@@ -0,0 +1,332 @@
+// Unit tests for FFTPassEffect.
+
+#include <math.h>
+
+#include "effect_chain.h"
+#include "gtest/gtest.h"
+#include "image_format.h"
+#include "fft_pass_effect.h"
+#include "multiply_effect.h"
+#include "test_util.h"
+
+namespace {
+
+// Generate a random number uniformly distributed between [-1.0, 1.0].
+float uniform_random()
+{
+       return 2.0 * ((float)rand() / RAND_MAX - 0.5);
+}
+
+void setup_fft(EffectChain *chain, int fft_size, bool inverse,
+               bool add_normalizer = false,
+               FFTPassEffect::Direction direction = FFTPassEffect::HORIZONTAL)
+{
+       assert((fft_size & (fft_size - 1)) == 0);  // Must be power of two.
+       for (int i = 1, subsize = 2; subsize <= fft_size; ++i, subsize *= 2) {
+               Effect *fft_effect = chain->add_effect(new FFTPassEffect());
+               bool ok = fft_effect->set_int("fft_size", fft_size);
+               ok |= fft_effect->set_int("pass_number", i);
+               ok |= fft_effect->set_int("inverse", inverse);
+               ok |= fft_effect->set_int("direction", direction);
+               assert(ok);
+       }
+
+       if (add_normalizer) {
+               float factor[4] = { 1.0f / fft_size, 1.0f / fft_size, 1.0f / fft_size, 1.0f / fft_size };
+               Effect *multiply_effect = chain->add_effect(new MultiplyEffect());
+               bool ok = multiply_effect->set_vec4("factor", factor);
+               assert(ok);
+       }
+}
+
+void run_fft(const float *in, float *out, int fft_size, bool inverse,
+             bool add_normalizer = false,
+             FFTPassEffect::Direction direction = FFTPassEffect::HORIZONTAL)
+{
+       int width, height;
+       if (direction == FFTPassEffect::HORIZONTAL) {
+               width = fft_size;
+               height = 1;
+       } else {
+               width = 1;
+               height = fft_size;
+       }
+       EffectChainTester tester(in, width, height, FORMAT_RGBA_PREMULTIPLIED_ALPHA, COLORSPACE_sRGB, GAMMA_LINEAR);
+       setup_fft(tester.get_chain(), fft_size, inverse, add_normalizer, direction);
+       tester.run(out, GL_RGBA, COLORSPACE_sRGB, GAMMA_LINEAR, OUTPUT_ALPHA_FORMAT_PREMULTIPLIED);
+}
+
+}  // namespace
+
+TEST(FFTPassEffectTest, ZeroStaysZero) {
+       const int fft_size = 64;
+       float data[fft_size * 4] = { 0 };
+       float out_data[fft_size * 4];
+
+       run_fft(data, out_data, fft_size, false);
+       expect_equal(data, out_data, 4, fft_size);
+
+       run_fft(data, out_data, fft_size, true);
+       expect_equal(data, out_data, 4, fft_size);
+}
+
+TEST(FFTPassEffectTest, Impulse) {
+       const int fft_size = 64;
+       float data[fft_size * 4] = { 0 };
+       float expected_data[fft_size * 4], out_data[fft_size * 4];
+       data[0] = 1.0;
+       data[1] = 1.2;
+       data[2] = 1.4;
+       data[3] = 3.0;
+
+       for (int i = 0; i < fft_size; ++i) {
+               expected_data[i * 4 + 0] = data[0];
+               expected_data[i * 4 + 1] = data[1];
+               expected_data[i * 4 + 2] = data[2];
+               expected_data[i * 4 + 3] = data[3];
+       }
+
+       run_fft(data, out_data, fft_size, false);
+       expect_equal(expected_data, out_data, 4, fft_size);
+
+       run_fft(data, out_data, fft_size, true);
+       expect_equal(expected_data, out_data, 4, fft_size);
+}
+
+TEST(FFTPassEffectTest, SingleFrequency) {
+       const int fft_size = 16;
+       float data[fft_size * 4] = { 0 };
+       float expected_data[fft_size * 4], out_data[fft_size * 4];
+       for (int i = 0; i < fft_size; ++i) {
+               data[i * 4 + 0] = sin(2.0 * M_PI * (4.0 * i) / fft_size);
+               data[i * 4 + 1] = 0.0;
+               data[i * 4 + 2] = 0.0;
+               data[i * 4 + 3] = 0.0;
+       }
+       for (int i = 0; i < fft_size; ++i) {
+               expected_data[i * 4 + 0] = 0.0;
+               expected_data[i * 4 + 1] = 0.0;
+               expected_data[i * 4 + 2] = 0.0;
+               expected_data[i * 4 + 3] = 0.0;
+       }
+       expected_data[4 * 4 + 1] = -8.0;
+       expected_data[12 * 4 + 1] = 8.0;
+
+       run_fft(data, out_data, fft_size, false, false, FFTPassEffect::HORIZONTAL);
+       expect_equal(expected_data, out_data, 4, fft_size);
+
+       run_fft(data, out_data, fft_size, false, false, FFTPassEffect::VERTICAL);
+       expect_equal(expected_data, out_data, 4, fft_size);
+}
+
+TEST(FFTPassEffectTest, Repeat) {
+       const int fft_size = 64;
+       const int num_repeats = 31;  // Prime, to make things more challenging.
+       float data[num_repeats * fft_size * 4] = { 0 };
+       float expected_data[num_repeats * fft_size * 4], out_data[num_repeats * fft_size * 4];
+
+       srand(12345);
+       for (int i = 0; i < num_repeats * fft_size * 4; ++i) {
+               data[i] = uniform_random();
+       }
+
+       for (int i = 0; i < num_repeats; ++i) {
+               run_fft(data + i * fft_size * 4, expected_data + i * fft_size * 4, fft_size, false);
+       }
+
+       {
+               // Horizontal.
+               EffectChainTester tester(data, num_repeats * fft_size, 1, FORMAT_RGBA_PREMULTIPLIED_ALPHA, COLORSPACE_sRGB, GAMMA_LINEAR);
+               setup_fft(tester.get_chain(), fft_size, false);
+               tester.run(out_data, GL_RGBA, COLORSPACE_sRGB, GAMMA_LINEAR, OUTPUT_ALPHA_FORMAT_PREMULTIPLIED);
+
+               expect_equal(expected_data, out_data, 4, num_repeats * fft_size);
+       }
+       {
+               // Vertical.
+               EffectChainTester tester(data, 1, num_repeats * fft_size, FORMAT_RGBA_PREMULTIPLIED_ALPHA, COLORSPACE_sRGB, GAMMA_LINEAR);
+               setup_fft(tester.get_chain(), fft_size, false, false, FFTPassEffect::VERTICAL);
+               tester.run(out_data, GL_RGBA, COLORSPACE_sRGB, GAMMA_LINEAR, OUTPUT_ALPHA_FORMAT_PREMULTIPLIED);
+
+               expect_equal(expected_data, out_data, 4, num_repeats * fft_size);
+       }
+}
+
+TEST(FFTPassEffectTest, TwoDimensional) {  // Implicitly tests vertical.
+       srand(1234);
+       const int fft_size = 16;
+       float in[fft_size * fft_size * 4], out[fft_size * fft_size * 4], expected_out[fft_size * fft_size * 4];
+       for (int y = 0; y < fft_size; ++y) {
+               for (int x = 0; x < fft_size; ++x) {
+                       in[(y * fft_size + x) * 4 + 0] =
+                               sin(2.0 * M_PI * (2 * x + 3 * y) / fft_size);
+                       in[(y * fft_size + x) * 4 + 1] = 0.0;
+                       in[(y * fft_size + x) * 4 + 2] = 0.0;
+                       in[(y * fft_size + x) * 4 + 3] = 0.0;
+               }
+       }
+       memset(expected_out, 0, sizeof(expected_out));
+
+       // This result has been verified using the fft2() function in Octave,
+       // which uses FFTW.
+       expected_out[(3 * fft_size + 2) * 4 + 1] = -128.0;
+       expected_out[(13 * fft_size + 14) * 4 + 1] = 128.0;
+
+       EffectChainTester tester(in, fft_size, fft_size, FORMAT_RGBA_PREMULTIPLIED_ALPHA, COLORSPACE_sRGB, GAMMA_LINEAR);
+       setup_fft(tester.get_chain(), fft_size, false, false, FFTPassEffect::HORIZONTAL);
+       setup_fft(tester.get_chain(), fft_size, false, false, FFTPassEffect::VERTICAL);
+       tester.run(out, GL_RGBA, COLORSPACE_sRGB, GAMMA_LINEAR, OUTPUT_ALPHA_FORMAT_PREMULTIPLIED);
+
+       expect_equal(expected_out, out, 4 * fft_size, fft_size, 0.25, 0.0005);
+}
+
+// The classic paper for FFT correctness testing is Funda Ergün:
+// “Testing Multivariate Linear Functions: Overcoming the Generator Bottleneck”
+// (http://www.cs.sfu.ca/~funda/PUBLICATIONS/stoc95.ps), which proves that
+// testing three basic properties of FFTs guarantees that the function is
+// correct (at least under the assumption that errors are random).
+//
+// We don't follow the paper directly, though, for a few reasons: First,
+// Ergün's paper really considers _self-correcting_ systems, which may
+// be stochastically faulty, and thus uses various relatively complicated
+// bounds and tests we don't really need. Second, the FFTs it considers
+// are all about polynomials over finite fields, which means that results
+// are exact and thus easy to test; we work with floats (half-floats!),
+// and thus need some error tolerance.
+//
+// So instead, we follow the implementation of FFTW, which is really the
+// gold standard when it comes to FFTs these days. They hard-code 20
+// testing rounds as opposed to the more complicated bounds in the paper,
+// and have a simpler version of the third test.
+//
+// The error bounds are set somewhat empirically, but remember that these
+// inputs will give frequency values as large as ~16, where 0.025 is
+// within the 9th bit (of 11 total mantissa bits in fp16).
+const int ergun_rounds = 20;
+
+// Test 1: Test that FFT(a + b) = FFT(a) + FFT(b).
+TEST(FFTPassEffectTest, ErgunLinearityTest) {
+       srand(1234);
+       const int max_fft_size = 64;
+       float a[max_fft_size * 4], b[max_fft_size * 4], sum[max_fft_size * 4];
+       float a_out[max_fft_size * 4], b_out[max_fft_size * 4], sum_out[max_fft_size * 4], expected_sum_out[max_fft_size * 4];
+       for (int fft_size = 2; fft_size <= max_fft_size; fft_size *= 2) {
+               for (int inverse = 0; inverse <= 1; ++inverse) {
+                       for (int i = 0; i < ergun_rounds; ++i) {
+                               for (int j = 0; j < fft_size * 4; ++j) {
+                                       a[j] = uniform_random();
+                                       b[j] = uniform_random();
+                               }
+                               run_fft(a, a_out, fft_size, inverse);
+                               run_fft(b, b_out, fft_size, inverse);
+
+                               for (int j = 0; j < fft_size * 4; ++j) {
+                                       sum[j] = a[j] + b[j];
+                                       expected_sum_out[j] = a_out[j] + b_out[j];
+                               }
+
+                               run_fft(sum, sum_out, fft_size, inverse);
+                               expect_equal(expected_sum_out, sum_out, 4, fft_size, 0.03, 0.0005);
+                       }
+               }
+       }
+}
+
+// Test 2: Test that FFT(delta(i)) = 1  (where delta(i) = [1 0 0 0 ...]),
+// or more specifically, test that FFT(a + delta(i)) - FFT(a) = 1.
+TEST(FFTPassEffectTest, ErgunImpulseTransform) {
+       srand(1235);
+       const int max_fft_size = 64;
+       float a[max_fft_size * 4], b[max_fft_size * 4];
+       float a_out[max_fft_size * 4], b_out[max_fft_size * 4], sum_out[max_fft_size * 4], expected_sum_out[max_fft_size * 4];
+       for (int fft_size = 2; fft_size <= max_fft_size; fft_size *= 2) {
+               for (int inverse = 0; inverse <= 1; ++inverse) {
+                       for (int i = 0; i < ergun_rounds; ++i) {
+                               for (int j = 0; j < fft_size * 4; ++j) {
+                                       a[j] = uniform_random();
+
+                                       // Compute delta(j) - a.
+                                       if (j < 4) {
+                                               b[j] = 1.0 - a[j];
+                                       } else {
+                                               b[j] = -a[j];
+                                       }
+                               }
+                               run_fft(a, a_out, fft_size, inverse);
+                               run_fft(b, b_out, fft_size, inverse);
+
+                               for (int j = 0; j < fft_size * 4; ++j) {
+                                       sum_out[j] = a_out[j] + b_out[j];
+                                       expected_sum_out[j] = 1.0;
+                               }
+                               expect_equal(expected_sum_out, sum_out, 4, fft_size, 0.025, 0.0005);
+                       }
+               }
+       }
+}
+
+// Test 3: Test the time-shift property of the FFT, in that a circular left-shift
+// multiplies the result by e^(j 2pi k/N) (linear phase adjustment).
+// As fftw_test.c says, “The paper performs more tests, but this code should be
+// fine too”.
+TEST(FFTPassEffectTest, ErgunShiftProperty) {
+       srand(1236);
+       const int max_fft_size = 64;
+       float a[max_fft_size * 4], b[max_fft_size * 4];
+       float a_out[max_fft_size * 4], b_out[max_fft_size * 4], expected_a_out[max_fft_size * 4];
+       for (int fft_size = 2; fft_size <= max_fft_size; fft_size *= 2) {
+               for (int inverse = 0; inverse <= 1; ++inverse) {
+                       for (int direction = 0; direction <= 1; ++direction) {
+                               for (int i = 0; i < ergun_rounds; ++i) {
+                                       for (int j = 0; j < fft_size * 4; ++j) {
+                                               a[j] = uniform_random();
+                                       }
+
+                                       // Circular shift left by one step.
+                                       for (int j = 0; j < fft_size * 4; ++j) {
+                                               b[j] = a[(j + 4) % (fft_size * 4)];
+                                       }
+                                       run_fft(a, a_out, fft_size, inverse, false, FFTPassEffect::Direction(direction));
+                                       run_fft(b, b_out, fft_size, inverse, false, FFTPassEffect::Direction(direction));
+
+                                       for (int j = 0; j < fft_size; ++j) {
+                                               double s = -sin(j * 2.0 * M_PI / fft_size);
+                                               double c = cos(j * 2.0 * M_PI / fft_size);
+                                               if (inverse) {
+                                                       s = -s;
+                                               }
+
+                                               expected_a_out[j * 4 + 0] = b_out[j * 4 + 0] * c - b_out[j * 4 + 1] * s;
+                                               expected_a_out[j * 4 + 1] = b_out[j * 4 + 0] * s + b_out[j * 4 + 1] * c;
+
+                                               expected_a_out[j * 4 + 2] = b_out[j * 4 + 2] * c - b_out[j * 4 + 3] * s;
+                                               expected_a_out[j * 4 + 3] = b_out[j * 4 + 2] * s + b_out[j * 4 + 3] * c;
+                                       }
+                                       expect_equal(expected_a_out, a_out, 4, fft_size, 0.025, 0.0005);
+                               }
+                       }
+               }
+       }
+}
+
+TEST(FFTPassEffectTest, BigFFTAccuracy) {
+       srand(1234);
+       const int max_fft_size = 2048;
+       float in[max_fft_size * 4], out[max_fft_size * 4], out2[max_fft_size * 4];
+       for (int fft_size = 2; fft_size <= max_fft_size; fft_size *= 2) {
+               for (int j = 0; j < fft_size * 4; ++j) {
+                       in[j] = uniform_random();
+               }
+               run_fft(in, out, fft_size, false, true);  // Forward, with normalization.
+               run_fft(out, out2, fft_size, true);       // Reverse.
+
+               // These error bounds come from
+               // http://en.wikipedia.org/wiki/Fast_Fourier_transform#Accuracy_and_approximations,
+               // with empirically estimated epsilons. Note that the calculated
+               // rms in expect_equal() is divided by sqrt(N), so we compensate
+               // similarly here.
+               double max_error = 0.0009 * log2(fft_size);
+               double rms_limit = 0.0007 * sqrt(log2(fft_size)) / sqrt(fft_size);
+               expect_equal(in, out2, 4, fft_size, max_error, rms_limit);
+       }
+}