Add an FFT convolution effect.
[movit] / fft_convolution_effect.cpp
1 #include <GL/glew.h>
2 #include <string.h>
3
4 #include "complex_modulate_effect.h"
5 #include "effect_chain.h"
6 #include "fft_convolution_effect.h"
7 #include "fft_input.h"
8 #include "fft_pass_effect.h"
9 #include "multiply_effect.h"
10 #include "padding_effect.h"
11 #include "slice_effect.h"
12 #include "util.h"
13
14 using namespace std;
15
16 namespace movit {
17
18 FFTConvolutionEffect::FFTConvolutionEffect(int input_width, int input_height, int convolve_width, int convolve_height)
19         : input_width(input_width),
20           input_height(input_height),
21           convolve_width(convolve_width),
22           convolve_height(convolve_height),
23           fft_input(new FFTInput(convolve_width, convolve_height)),
24           crop_effect(new PaddingEffect()),
25           owns_effects(true) {
26         CHECK(crop_effect->set_int("width", input_width));
27         CHECK(crop_effect->set_int("height", input_height));
28         CHECK(crop_effect->set_float("top", 0));
29         CHECK(crop_effect->set_float("left", 0));
30 }
31
32 FFTConvolutionEffect::~FFTConvolutionEffect()
33 {
34         if (owns_effects) {
35                 delete fft_input;
36                 delete crop_effect;
37         }
38 }
39
40 namespace {
41
42 // Returns the last Effect in the new chain.
43 Effect *add_overlap_and_fft(EffectChain *chain, Effect *last_effect, int fft_size, int pad_size, FFTPassEffect::Direction direction)
44 {
45         // Overlap.
46         {
47                 Effect *overlap_effect = chain->add_effect(new SliceEffect(), last_effect);
48                 CHECK(overlap_effect->set_int("input_slice_size", fft_size - pad_size));
49                 CHECK(overlap_effect->set_int("output_slice_size", fft_size));
50                 CHECK(overlap_effect->set_int("offset", -pad_size));
51                 if (direction == FFTPassEffect::HORIZONTAL) {
52                         CHECK(overlap_effect->set_int("direction", SliceEffect::HORIZONTAL));
53                 } else {
54                         assert(direction == FFTPassEffect::VERTICAL);
55                         CHECK(overlap_effect->set_int("direction", SliceEffect::VERTICAL));
56                 }
57
58                 last_effect = overlap_effect;
59         }
60
61         // FFT.
62         int num_passes = ffs(fft_size) - 1;
63         for (int i = 1; i <= num_passes; ++i) {
64                 Effect *fft_effect = chain->add_effect(new FFTPassEffect(), last_effect);
65                 CHECK(fft_effect->set_int("pass_number", i));
66                 CHECK(fft_effect->set_int("fft_size", fft_size));
67                 CHECK(fft_effect->set_int("direction", direction));
68                 CHECK(fft_effect->set_int("inverse", 0));
69
70                 last_effect = fft_effect;
71         }
72
73         return last_effect;
74 }
75
76 // Returns the last Effect in the new chain.
77 Effect *add_ifft_and_discard(EffectChain *chain, Effect *last_effect, int fft_size, int pad_size, FFTPassEffect::Direction direction)
78 {
79         // IFFT.
80         int num_passes = ffs(fft_size) - 1;
81         for (int i = 1; i <= num_passes; ++i) {
82                 Effect *fft_effect = chain->add_effect(new FFTPassEffect(), last_effect);
83                 CHECK(fft_effect->set_int("pass_number", i));
84                 CHECK(fft_effect->set_int("fft_size", fft_size));
85                 CHECK(fft_effect->set_int("direction", direction));
86                 CHECK(fft_effect->set_int("inverse", 1));
87
88                 last_effect = fft_effect;
89         }
90
91         // Discard.
92         {
93                 Effect *discard_effect = chain->add_effect(new SliceEffect(), last_effect);
94                 CHECK(discard_effect->set_int("input_slice_size", fft_size));
95                 CHECK(discard_effect->set_int("output_slice_size", fft_size - pad_size));
96                 if (direction == FFTPassEffect::HORIZONTAL) {
97                         CHECK(discard_effect->set_int("direction", SliceEffect::HORIZONTAL));
98                 } else {
99                         assert(direction == FFTPassEffect::VERTICAL);
100                         CHECK(discard_effect->set_int("direction", SliceEffect::VERTICAL));
101                 }
102                 CHECK(discard_effect->set_int("offset", pad_size));
103
104                 last_effect = discard_effect;
105         }
106
107         return last_effect;
108 }
109
110 }  // namespace
111
112 void FFTConvolutionEffect::rewrite_graph(EffectChain *chain, Node *self)
113 {
114         int pad_width = convolve_width - 1;
115         int pad_height = convolve_height - 1;
116
117         // Try all possible FFT widths and heights to see which one is the
118         // cheapest.  As a proxy for real performance, we use number of texel
119         // fetches; this isn't perfect by any means, but it's easy to work with
120         // and should be approximately correct.
121         int min_x = next_power_of_two(1 + pad_width);
122         int min_y = next_power_of_two(1 + pad_height);
123         int max_y = next_power_of_two(input_height + pad_width);
124         int max_x = next_power_of_two(input_width + pad_height);
125
126         size_t best_cost = numeric_limits<size_t>::max();
127         int best_x = -1, best_y = -1, best_x_before_y_fft = -1, best_x_before_y_ifft = -1;
128
129         // Try both
130         //
131         //   overlap(X), FFT(X), overlap(Y), FFT(Y), modulate, IFFT(Y), discard(Y), IFFT(X), discard(X) and
132         //   overlap(Y), FFT(Y), overlap(X), FFT(X), modulate, IFFT(X), discard(X), IFFT(Y), discard(Y)
133         //
134         // For simplicity, call them the XY-YX and YX-XY orders. In theory, we
135         // could have XY-XY and YX-YX orders as well, and I haven't found a
136         // convincing argument that they will never be optimal (although it
137         // sounds odd and should be rare), so we test all four possible ones.
138         //
139         // We assume that the kernel FFT is for free, since it is typically done
140         // only once and per frame.
141         for (int x_before_y_fft = 0; x_before_y_fft <= 1; ++x_before_y_fft) {
142                 for (int x_before_y_ifft = 0; x_before_y_ifft <= 1; ++x_before_y_ifft) {
143                         for (int y = min_y; y <= max_y; y *= 2) {
144                                 int y_pixels_per_block = y - pad_height;
145                                 int num_vertical_blocks = div_round_up(input_height, y_pixels_per_block);
146                                 size_t output_height = y * num_vertical_blocks;
147                                 for (int x = min_x; x <= max_x; x *= 2) {
148                                         int x_pixels_per_block = x - pad_width;
149                                         int num_horizontal_blocks = div_round_up(input_width, x_pixels_per_block);
150                                         size_t output_width = x * num_horizontal_blocks;
151
152                                         size_t cost = 0;
153
154                                         if (x_before_y_fft) {
155                                                 // First, the cost of the horizontal padding.
156                                                 cost = output_width * input_height;
157
158                                                 // log(X) FFT passes. Each pass reads two inputs per pixel,
159                                                 // plus the support texture.
160                                                 cost += (ffs(x) - 1) * 3 * output_width * input_height;
161
162                                                 // Now, horizontal padding.
163                                                 cost += output_width * output_height;
164
165                                                 // log(Y) FFT passes, now at full resolution.
166                                                 cost += (ffs(y) - 1) * 3 * output_width * output_height;
167                                         } else {
168                                                 // First, the cost of the vertical padding.
169                                                 cost = input_width * output_height;
170
171                                                 // log(Y) FFT passes. Each pass reads two inputs per pixel,
172                                                 // plus the support texture.
173                                                 cost += (ffs(y) - 1) * 3 * input_width * output_height;
174
175                                                 // Now, horizontal padding.
176                                                 cost += output_width * output_height;
177
178                                                 // log(X) FFT passes, now at full resolution.
179                                                 cost += (ffs(x) - 1) * 3 * output_width * output_height;
180                                         }
181
182                                         // The actual modulation. Reads one pixel each from two textures.
183                                         cost += 2 * output_width * output_height;
184
185                                         if (x_before_y_ifft) {
186                                                 // log(X) IFFT passes.
187                                                 cost += (ffs(x) - 1) * 3 * output_width * output_height;
188
189                                                 // Discard horizontally.
190                                                 cost += input_width * output_height;
191
192                                                 // log(Y) IFFT passes.
193                                                 cost += (ffs(y) - 1) * 3 * input_width * output_height;
194
195                                                 // Discard horizontally.
196                                                 cost += input_width * input_height;
197                                         } else {
198                                                 // log(Y) IFFT passes.
199                                                 cost += (ffs(y) - 1) * 3 * output_width * output_height;
200
201                                                 // Discard vertically.
202                                                 cost += output_width * input_height;
203
204                                                 // log(X) IFFT passes.
205                                                 cost += (ffs(x) - 1) * 3 * output_width * input_height;
206
207                                                 // Discard horizontally.
208                                                 cost += input_width * input_height;
209                                         }
210
211                                         if (cost < best_cost) {
212                                                 best_x = x;
213                                                 best_y = y;
214                                                 best_x_before_y_fft = x_before_y_fft;
215                                                 best_x_before_y_ifft = x_before_y_ifft;
216                                                 best_cost = cost;
217                                         }
218                                 }
219                         }
220                 }
221         }
222
223         const int fft_width = best_x, fft_height = best_y;
224
225         assert(self->incoming_links.size() == 1);
226         Node *last_node = self->incoming_links[0];
227         self->incoming_links.clear();
228         last_node->outgoing_links.clear();
229
230         // Do FFT.
231         Effect *last_effect = last_node->effect;
232         if (best_x_before_y_fft) {
233                 last_effect = add_overlap_and_fft(chain, last_effect, fft_width, pad_width, FFTPassEffect::HORIZONTAL);
234                 last_effect = add_overlap_and_fft(chain, last_effect, fft_height, pad_height, FFTPassEffect::VERTICAL);
235         } else {
236                 last_effect = add_overlap_and_fft(chain, last_effect, fft_height, pad_height, FFTPassEffect::VERTICAL);
237                 last_effect = add_overlap_and_fft(chain, last_effect, fft_width, pad_width, FFTPassEffect::HORIZONTAL);
238         }
239
240         // Normalizer.
241         Effect *multiply_effect;
242         float fft_size = fft_width * fft_height;
243         float factor[4] = { 1.0f / fft_size, 1.0f / fft_size, 1.0f / fft_size, 1.0f / fft_size };
244         last_effect = multiply_effect = chain->add_effect(new MultiplyEffect(), last_effect);
245         CHECK(multiply_effect->set_vec4("factor", factor));
246
247         // Multiply by the FFT of the convolution kernel.
248         CHECK(fft_input->set_int("fft_width", fft_width));
249         CHECK(fft_input->set_int("fft_height", fft_height));
250         chain->add_input(fft_input);
251         owns_effects = false;
252
253         Effect *modulate_effect = chain->add_effect(new ComplexModulateEffect(), multiply_effect, fft_input);
254         CHECK(modulate_effect->set_int("num_repeats_x", div_round_up(input_width, fft_width - pad_width)));
255         CHECK(modulate_effect->set_int("num_repeats_y", div_round_up(input_height, fft_height - pad_height)));
256         last_effect = modulate_effect;
257
258         // Finally, do IFFT.
259         if (best_x_before_y_ifft) {
260                 last_effect = add_ifft_and_discard(chain, last_effect, fft_width, pad_width, FFTPassEffect::HORIZONTAL);
261                 last_effect = add_ifft_and_discard(chain, last_effect, fft_height, pad_height, FFTPassEffect::VERTICAL);
262         } else {
263                 last_effect = add_ifft_and_discard(chain, last_effect, fft_height, pad_height, FFTPassEffect::VERTICAL);
264                 last_effect = add_ifft_and_discard(chain, last_effect, fft_width, pad_width, FFTPassEffect::HORIZONTAL);
265         }
266
267         // ...and crop away any extra padding we have have added.
268         last_effect = chain->add_effect(crop_effect);
269
270         chain->replace_sender(self, chain->find_node_for_effect(last_effect));
271         self->disabled = true;
272 }
273
274 }  // namespace movit