4 #include "complex_modulate_effect.h"
5 #include "effect_chain.h"
6 #include "fft_convolution_effect.h"
8 #include "fft_pass_effect.h"
9 #include "multiply_effect.h"
10 #include "padding_effect.h"
11 #include "slice_effect.h"
// Constructor. Stores the input size and the convolution kernel size, creates
// an FFTInput constructed with the kernel dimensions (used in rewrite_graph()
// as the FFT of the convolution kernel) and a PaddingEffect kept as crop_effect.
//
// NOTE(review): this listing skips original lines (e.g. line 25), so the end
// of the initializer list and the braces of the body are not visible here.
18 FFTConvolutionEffect::FFTConvolutionEffect(int input_width, int input_height, int convolve_width, int convolve_height)
19 : input_width(input_width),
20 input_height(input_height),
21 convolve_width(convolve_width),
22 convolve_height(convolve_height),
23 fft_input(new FFTInput(convolve_width, convolve_height)),
24 crop_effect(new PaddingEffect()),
// Configure crop_effect for an input_width x input_height output with its
// "top" and "left" parameters at 0. rewrite_graph() appends it as the last
// stage to crop away the extra padding added by the FFT blocks.
26 CHECK(crop_effect->set_int("width", input_width));
27 CHECK(crop_effect->set_int("height", input_height));
28 CHECK(crop_effect->set_float("top", 0));
29 CHECK(crop_effect->set_float("left", 0));
// Destructor. Its body is not part of this listing.
// NOTE(review): rewrite_graph() sets owns_effects = false right after it adds
// fft_input to the chain, so presumably this destructor only frees
// fft_input / crop_effect while owns_effects is still true -- confirm against
// the full file.
32 FFTConvolutionEffect::~FFTConvolutionEffect()
// Appends, after last_effect, an overlap stage followed by the forward FFT
// passes along one axis.
//
// chain:       the EffectChain the new effects are added to.
// last_effect: the effect whose output feeds the first new effect.
// fft_size:    FFT block size along the axis; callers in this file pass a
//              power of two, which is what makes ffs(fft_size) - 1 equal
//              log2(fft_size).
// pad_size:    number of overlapping samples per block.
// direction:   FFTPassEffect::HORIZONTAL or FFTPassEffect::VERTICAL.
//
// NOTE(review): this listing omits some original lines of this function
// (braces and other lines between the numbered ones).
42 // Returns the last Effect in the new chain.
43 Effect *add_overlap_and_fft(EffectChain *chain, Effect *last_effect, int fft_size, int pad_size, FFTPassEffect::Direction direction)
// Overlap: every input slice of (fft_size - pad_size) samples becomes an
// output slice of fft_size samples, with the read offset moved back by
// pad_size.
47 Effect *overlap_effect = chain->add_effect(new SliceEffect(), last_effect);
48 CHECK(overlap_effect->set_int("input_slice_size", fft_size - pad_size));
49 CHECK(overlap_effect->set_int("output_slice_size", fft_size));
50 CHECK(overlap_effect->set_int("offset", -pad_size));
// Translate the FFTPassEffect direction into the matching SliceEffect one.
51 if (direction == FFTPassEffect::HORIZONTAL) {
52 CHECK(overlap_effect->set_int("direction", SliceEffect::HORIZONTAL));
54 assert(direction == FFTPassEffect::VERTICAL);
55 CHECK(overlap_effect->set_int("direction", SliceEffect::VERTICAL));
58 last_effect = overlap_effect;
// One FFTPassEffect per pass, numbered 1..num_passes, each chained onto the
// previous one; "inverse" = 0 selects the forward transform.
62 int num_passes = ffs(fft_size) - 1;
63 for (int i = 1; i <= num_passes; ++i) {
64 Effect *fft_effect = chain->add_effect(new FFTPassEffect(), last_effect);
65 CHECK(fft_effect->set_int("pass_number", i));
66 CHECK(fft_effect->set_int("fft_size", fft_size));
67 CHECK(fft_effect->set_int("direction", direction));
68 CHECK(fft_effect->set_int("inverse", 0));
70 last_effect = fft_effect;
// Appends, after last_effect, the inverse FFT passes along one axis followed
// by a discard stage that drops the overlapped samples again. This is the
// counterpart of add_overlap_and_fft().
//
// chain:       the EffectChain the new effects are added to.
// last_effect: the effect whose output feeds the first new effect.
// fft_size:    FFT block size along the axis; callers in this file pass a
//              power of two, which is what makes ffs(fft_size) - 1 equal
//              log2(fft_size).
// pad_size:    number of samples per block to throw away after the IFFT.
// direction:   FFTPassEffect::HORIZONTAL or FFTPassEffect::VERTICAL.
//
// NOTE(review): this listing omits some original lines of this function
// (braces and other lines between the numbered ones).
76 // Returns the last Effect in the new chain.
77 Effect *add_ifft_and_discard(EffectChain *chain, Effect *last_effect, int fft_size, int pad_size, FFTPassEffect::Direction direction)
// One FFTPassEffect per pass, numbered 1..num_passes, each chained onto the
// previous one; "inverse" = 1 selects the inverse transform.
80 int num_passes = ffs(fft_size) - 1;
81 for (int i = 1; i <= num_passes; ++i) {
82 Effect *fft_effect = chain->add_effect(new FFTPassEffect(), last_effect);
83 CHECK(fft_effect->set_int("pass_number", i));
84 CHECK(fft_effect->set_int("fft_size", fft_size));
85 CHECK(fft_effect->set_int("direction", direction));
86 CHECK(fft_effect->set_int("inverse", 1));
88 last_effect = fft_effect;
// Discard: every fft_size-sized slice shrinks to (fft_size - pad_size)
// samples, reading from offset pad_size, i.e. the first pad_size samples of
// each block are dropped.
93 Effect *discard_effect = chain->add_effect(new SliceEffect(), last_effect);
94 CHECK(discard_effect->set_int("input_slice_size", fft_size));
95 CHECK(discard_effect->set_int("output_slice_size", fft_size - pad_size));
// Translate the FFTPassEffect direction into the matching SliceEffect one.
96 if (direction == FFTPassEffect::HORIZONTAL) {
97 CHECK(discard_effect->set_int("direction", SliceEffect::HORIZONTAL));
99 assert(direction == FFTPassEffect::VERTICAL);
100 CHECK(discard_effect->set_int("direction", SliceEffect::VERTICAL));
102 CHECK(discard_effect->set_int("offset", pad_size));
104 last_effect = discard_effect;
// Replaces this effect's node in the chain with an explicit FFT convolution
// pipeline:
//
//   overlap + forward FFT along both axes,
//   normalization by 1 / (fft_width * fft_height),
//   complex modulation with fft_input (the FFT of the convolution kernel),
//   inverse FFT + discard along both axes,
//   a final crop through crop_effect.
//
// The FFT block size and the axis order of both the forward and the inverse
// transform are chosen by minimizing an estimated texel-fetch cost.
//
// chain: the EffectChain that receives the new effects.
// self:  this effect's node; it must have exactly one incoming link. It is
//        unlinked from its input, handed to chain->replace_sender() together
//        with the node of the last new effect (presumably so that consumers
//        of `self` now read from the new pipeline), and marked disabled.
//
// NOTE(review): this listing omits some original lines of this function
// (braces, else lines and a few statements between the numbered ones).
112 void FFTConvolutionEffect::rewrite_graph(EffectChain *chain, Node *self)
// Each block overlaps its neighbour by (kernel size - 1) samples per axis.
114 int pad_width = convolve_width - 1;
115 int pad_height = convolve_height - 1;
117 // Try all possible FFT widths and heights to see which one is the
118 // cheapest. As a proxy for real performance, we use number of texel
119 // fetches; this isn't perfect by any means, but it's easy to work with
120 // and should be approximately correct.
// The smallest usable FFT size must hold the padding plus at least one
// payload sample.
121 int min_x = next_power_of_two(1 + pad_width);
122 int min_y = next_power_of_two(1 + pad_height);
// The largest useful FFT size covers the whole input plus the padding of
// that same axis. (The two pads used to be swapped between the axes here;
// with a small input and a strongly non-square kernel that could yield
// max < min, so no candidate was evaluated and best_x / best_y stayed -1.)
123 int max_y = next_power_of_two(input_height + pad_height);
124 int max_x = next_power_of_two(input_width + pad_width);
126 size_t best_cost = numeric_limits<size_t>::max();
127 int best_x = -1, best_y = -1, best_x_before_y_fft = -1, best_x_before_y_ifft = -1;
// The two obvious orderings of the stages are:
131 // overlap(X), FFT(X), overlap(Y), FFT(Y), modulate, IFFT(Y), discard(Y), IFFT(X), discard(X) and
132 // overlap(Y), FFT(Y), overlap(X), FFT(X), modulate, IFFT(X), discard(X), IFFT(Y), discard(Y)
134 // For simplicity, call them the XY-YX and YX-XY orders. In theory, we
135 // could have XY-XY and YX-YX orders as well, and I haven't found a
136 // convincing argument that they will never be optimal (although it
137 // sounds odd and should be rare), so we test all four possible ones.
139 // We assume that the kernel FFT is for free, since it is typically done
140 // only once and not per frame.
141 for (int x_before_y_fft = 0; x_before_y_fft <= 1; ++x_before_y_fft) {
142 for (int x_before_y_ifft = 0; x_before_y_ifft <= 1; ++x_before_y_ifft) {
// Candidate FFT sizes are the powers of two in [min, max] on each axis.
143 for (int y = min_y; y <= max_y; y *= 2) {
// Each block of size y carries (y - pad_height) payload rows; output_height
// is the total height once every block has been expanded to y rows.
144 int y_pixels_per_block = y - pad_height;
145 int num_vertical_blocks = div_round_up(input_height, y_pixels_per_block);
146 size_t output_height = y * num_vertical_blocks;
147 for (int x = min_x; x <= max_x; x *= 2) {
// Same bookkeeping for the horizontal axis.
148 int x_pixels_per_block = x - pad_width;
149 int num_horizontal_blocks = div_round_up(input_width, x_pixels_per_block);
150 size_t output_width = x * num_horizontal_blocks;
// Forward transform, X first.
154 if (x_before_y_fft) {
155 // First, the cost of the horizontal padding.
156 cost = output_width * input_height;
158 // log(X) FFT passes. Each pass reads two inputs per pixel,
159 // plus the support texture.
160 cost += (ffs(x) - 1) * 3 * output_width * input_height;
162 // Now, vertical padding.
163 cost += output_width * output_height;
165 // log(Y) FFT passes, now at full resolution.
166 cost += (ffs(y) - 1) * 3 * output_width * output_height;
// Forward transform, Y first.
168 // First, the cost of the vertical padding.
169 cost = input_width * output_height;
171 // log(Y) FFT passes. Each pass reads two inputs per pixel,
172 // plus the support texture.
173 cost += (ffs(y) - 1) * 3 * input_width * output_height;
175 // Now, horizontal padding.
176 cost += output_width * output_height;
178 // log(X) FFT passes, now at full resolution.
179 cost += (ffs(x) - 1) * 3 * output_width * output_height;
182 // The actual modulation. Reads one pixel each from two textures.
183 cost += 2 * output_width * output_height;
// Inverse transform, X first.
185 if (x_before_y_ifft) {
186 // log(X) IFFT passes.
187 cost += (ffs(x) - 1) * 3 * output_width * output_height;
189 // Discard horizontally.
190 cost += input_width * output_height;
192 // log(Y) IFFT passes.
193 cost += (ffs(y) - 1) * 3 * input_width * output_height;
195 // Discard vertically.
196 cost += input_width * input_height;
// Inverse transform, Y first.
198 // log(Y) IFFT passes.
199 cost += (ffs(y) - 1) * 3 * output_width * output_height;
201 // Discard vertically.
202 cost += output_width * input_height;
204 // log(X) IFFT passes.
205 cost += (ffs(x) - 1) * 3 * output_width * input_height;
207 // Discard horizontally.
208 cost += input_width * input_height;
// Remember the cheapest configuration seen so far.
211 if (cost < best_cost) {
214 best_x_before_y_fft = x_before_y_fft;
215 best_x_before_y_ifft = x_before_y_ifft;
223 const int fft_width = best_x, fft_height = best_y;
// Detach this node from its single input; the new pipeline is attached to
// that input's effect instead.
225 assert(self->incoming_links.size() == 1);
226 Node *last_node = self->incoming_links[0];
227 self->incoming_links.clear();
228 last_node->outgoing_links.clear();
// Forward FFT along both axes, in the order picked above.
231 Effect *last_effect = last_node->effect;
232 if (best_x_before_y_fft) {
233 last_effect = add_overlap_and_fft(chain, last_effect, fft_width, pad_width, FFTPassEffect::HORIZONTAL);
234 last_effect = add_overlap_and_fft(chain, last_effect, fft_height, pad_height, FFTPassEffect::VERTICAL);
236 last_effect = add_overlap_and_fft(chain, last_effect, fft_height, pad_height, FFTPassEffect::VERTICAL);
237 last_effect = add_overlap_and_fft(chain, last_effect, fft_width, pad_width, FFTPassEffect::HORIZONTAL);
// Scale all four channels by 1 / (fft_width * fft_height).
241 Effect *multiply_effect;
242 float fft_size = fft_width * fft_height;
243 float factor[4] = { 1.0f / fft_size, 1.0f / fft_size, 1.0f / fft_size, 1.0f / fft_size };
244 last_effect = multiply_effect = chain->add_effect(new MultiplyEffect(), last_effect);
245 CHECK(multiply_effect->set_vec4("factor", factor));
247 // Multiply by the FFT of the convolution kernel.
248 CHECK(fft_input->set_int("fft_width", fft_width));
249 CHECK(fft_input->set_int("fft_height", fft_height));
250 chain->add_input(fft_input);
// fft_input has just been added to the chain and crop_effect is added below,
// so this object stops regarding itself as their owner.
251 owns_effects = false;
// The kernel FFT covers a single block, so it is repeated once per block in
// each direction.
253 Effect *modulate_effect = chain->add_effect(new ComplexModulateEffect(), multiply_effect, fft_input);
254 CHECK(modulate_effect->set_int("num_repeats_x", div_round_up(input_width, fft_width - pad_width)));
255 CHECK(modulate_effect->set_int("num_repeats_y", div_round_up(input_height, fft_height - pad_height)));
256 last_effect = modulate_effect;
// Inverse FFT along both axes, in the order picked above.
259 if (best_x_before_y_ifft) {
260 last_effect = add_ifft_and_discard(chain, last_effect, fft_width, pad_width, FFTPassEffect::HORIZONTAL);
261 last_effect = add_ifft_and_discard(chain, last_effect, fft_height, pad_height, FFTPassEffect::VERTICAL);
263 last_effect = add_ifft_and_discard(chain, last_effect, fft_height, pad_height, FFTPassEffect::VERTICAL);
264 last_effect = add_ifft_and_discard(chain, last_effect, fft_width, pad_width, FFTPassEffect::HORIZONTAL);
267 // ...and crop away any extra padding we have added.
268 last_effect = chain->add_effect(crop_effect);
// Hand the end of the new pipeline to replace_sender() in place of `self`,
// then disable `self`.
270 chain->replace_sender(self, chain->find_node_for_effect(last_effect));
271 self->disabled = true;