From: Steinar H. Gunderson Date: Tue, 10 Oct 2017 20:49:21 +0000 (+0200) Subject: Speed up the histogram counting immensely by adding via local memory. X-Git-Url: https://git.sesse.net/?p=narabu;a=commitdiff_plain;h=57fdfa782c418299639dcc670a194716c0657cab Speed up the histogram counting immensely by adding via local memory. --- diff --git a/encoder.shader b/encoder.shader index 163f8fa..8d4bcc2 100644 --- a/encoder.shader +++ b/encoder.shader @@ -1,7 +1,11 @@ #version 440 #extension GL_ARB_shader_clock : enable -layout(local_size_x = 8) in; +// Do sixteen 8x8 blocks in a local group, because that matches up perfectly +// with needing 1024 coefficients for our four histograms (of 256 bins each). +#define NUM_Z 16 + +layout(local_size_x = 8, local_size_z = NUM_Z) in; layout(r16ui) uniform restrict writeonly uimage2D dc_ac7_tex; layout(r16ui) uniform restrict writeonly uimage2D ac1_ac6_tex; @@ -10,11 +14,11 @@ layout(r8i) uniform restrict writeonly iimage2D ac3_tex; layout(r8i) uniform restrict writeonly iimage2D ac4_tex; layout(r8ui) uniform restrict readonly uimage2D image_tex; -shared float temp[64]; +shared uint temp[64 * NUM_Z]; layout(std430, binding = 9) buffer layoutName { - uint dist[4][256]; + uint dist[4 * 256]; }; #define MAPPING(s0, s1, s2, s3, s4, s5, s6, s7) ((s0) | (s1 << 2) | (s2 << 4) | (s3 << 6) | (s4 << 8) | (s5 << 10) | (s6 << 12) | (s7 << 14)) @@ -125,9 +129,11 @@ void dct_1d(inout float y0, inout float y1, inout float y2, inout float y3, inou void main() { - uint x = 8 * gl_WorkGroupID.x; + uint sx = gl_WorkGroupID.x * NUM_Z + gl_LocalInvocationID.z; + uint x = 8 * sx; uint y = 8 * gl_WorkGroupID.y; uint n = gl_LocalInvocationID.x; + uint z = gl_LocalInvocationID.z; // Load column. float y0 = imageLoad(image_tex, ivec2(x + n, y + 0)).x; @@ -143,27 +149,28 @@ void main() dct_1d(y0, y1, y2, y3, y4, y5, y6, y7); // Communicate with the other shaders in the group. - temp[n + 0 * 8] = y0; - temp[n + 1 * 8] = y1; - temp[n + 2 * 8] = y2; - temp[n + 3 * 8] = y3; - temp[n + 4 * 8] = y4; - temp[n + 5 * 8] = y5; - temp[n + 6 * 8] = y6; - temp[n + 7 * 8] = y7; + uint base_idx = 64 * z; + temp[base_idx + 0 * 8 + n] = floatBitsToUint(y0); + temp[base_idx + 1 * 8 + n] = floatBitsToUint(y1); + temp[base_idx + 2 * 8 + n] = floatBitsToUint(y2); + temp[base_idx + 3 * 8 + n] = floatBitsToUint(y3); + temp[base_idx + 4 * 8 + n] = floatBitsToUint(y4); + temp[base_idx + 5 * 8 + n] = floatBitsToUint(y5); + temp[base_idx + 6 * 8 + n] = floatBitsToUint(y6); + temp[base_idx + 7 * 8 + n] = floatBitsToUint(y7); memoryBarrierShared(); barrier(); // Load row (so transpose, in a sense). - y0 = temp[n * 8 + 0]; - y1 = temp[n * 8 + 1]; - y2 = temp[n * 8 + 2]; - y3 = temp[n * 8 + 3]; - y4 = temp[n * 8 + 4]; - y5 = temp[n * 8 + 5]; - y6 = temp[n * 8 + 6]; - y7 = temp[n * 8 + 7]; + y0 = uintBitsToFloat(temp[base_idx + n * 8 + 0]); + y1 = uintBitsToFloat(temp[base_idx + n * 8 + 1]); + y2 = uintBitsToFloat(temp[base_idx + n * 8 + 2]); + y3 = uintBitsToFloat(temp[base_idx + n * 8 + 3]); + y4 = uintBitsToFloat(temp[base_idx + n * 8 + 4]); + y5 = uintBitsToFloat(temp[base_idx + n * 8 + 5]); + y6 = uintBitsToFloat(temp[base_idx + n * 8 + 6]); + y7 = uintBitsToFloat(temp[base_idx + n * 8 + 7]); // Horizontal DCT. dct_1d(y0, y1, y2, y3, y4, y5, y6, y7); @@ -179,45 +186,62 @@ void main() int c7 = int(round(y7 * quant_matrix[n * 8 + 7])); // Clamp, pack and store. - uint sx = gl_WorkGroupID.x; imageStore(dc_ac7_tex, ivec2(sx, y + n), uvec4(pack_9_7(c0, c7), 0, 0, 0)); imageStore(ac1_ac6_tex, ivec2(sx, y + n), uvec4(pack_9_7(c1, c6), 0, 0, 0)); imageStore(ac2_ac5_tex, ivec2(sx, y + n), uvec4(pack_9_7(c2, c5), 0, 0, 0)); imageStore(ac3_tex, ivec2(sx, y + n), ivec4(c3, 0, 0, 0)); imageStore(ac4_tex, ivec2(sx, y + n), ivec4(c4, 0, 0, 0)); - // Count frequencies, but only for every 8th block or so, randomly selected. - uint wg_index = gl_WorkGroupID.y * gl_WorkGroupSize.x + gl_WorkGroupID.x; - if ((wg_index * 0x9E3779B9u) >> 29 == 0) { // Fibonacci hashing, essentially a PRNG in this context. - c0 = min(abs(c0), 255); - c1 = min(abs(c1), 255); - c2 = min(abs(c2), 255); - c3 = min(abs(c3), 255); - c4 = min(abs(c4), 255); - c5 = min(abs(c5), 255); - c6 = min(abs(c6), 255); - c7 = min(abs(c7), 255); - - // Spread out the most popular elements among the cache lines by reversing the bits - // of the index, reducing false sharing. - c0 = bitfieldReverse(c0) >> 24; - c1 = bitfieldReverse(c1) >> 24; - c2 = bitfieldReverse(c2) >> 24; - c3 = bitfieldReverse(c3) >> 24; - c4 = bitfieldReverse(c4) >> 24; - c5 = bitfieldReverse(c5) >> 24; - c6 = bitfieldReverse(c6) >> 24; - c7 = bitfieldReverse(c7) >> 24; - - uint m = luma_mapping[n]; - atomicAdd(dist[bitfieldExtract(m, 0, 2)][c0], 1); - atomicAdd(dist[bitfieldExtract(m, 2, 2)][c1], 1); - atomicAdd(dist[bitfieldExtract(m, 4, 2)][c2], 1); - atomicAdd(dist[bitfieldExtract(m, 6, 2)][c3], 1); - atomicAdd(dist[bitfieldExtract(m, 8, 2)][c4], 1); - atomicAdd(dist[bitfieldExtract(m, 10, 2)][c5], 1); - atomicAdd(dist[bitfieldExtract(m, 12, 2)][c6], 1); - atomicAdd(dist[bitfieldExtract(m, 14, 2)][c7], 1); - } -} + // Zero out the temporary area in preparation for counting up the histograms. + base_idx += 8 * n; + temp[base_idx + 0] = 0; + temp[base_idx + 1] = 0; + temp[base_idx + 2] = 0; + temp[base_idx + 3] = 0; + temp[base_idx + 4] = 0; + temp[base_idx + 5] = 0; + temp[base_idx + 6] = 0; + temp[base_idx + 7] = 0; + + memoryBarrierShared(); + barrier(); + // Count frequencies into four histograms. We do this to local memory first, + // because this is _much_ faster; then we do global atomic adds for the nonzero + // members. + + // First take the absolute value (signs are encoded differently) and clamp, + // as any value over 255 is going to be encoded as an escape. + c0 = min(abs(c0), 255); + c1 = min(abs(c1), 255); + c2 = min(abs(c2), 255); + c3 = min(abs(c3), 255); + c4 = min(abs(c4), 255); + c5 = min(abs(c5), 255); + c6 = min(abs(c6), 255); + c7 = min(abs(c7), 255); + + // Add up in local memory. + uint m = luma_mapping[n]; + atomicAdd(temp[bitfieldExtract(m, 0, 2) * 256 + c0], 1); + atomicAdd(temp[bitfieldExtract(m, 2, 2) * 256 + c1], 1); + atomicAdd(temp[bitfieldExtract(m, 4, 2) * 256 + c2], 1); + atomicAdd(temp[bitfieldExtract(m, 6, 2) * 256 + c3], 1); + atomicAdd(temp[bitfieldExtract(m, 8, 2) * 256 + c4], 1); + atomicAdd(temp[bitfieldExtract(m, 10, 2) * 256 + c5], 1); + atomicAdd(temp[bitfieldExtract(m, 12, 2) * 256 + c6], 1); + atomicAdd(temp[bitfieldExtract(m, 14, 2) * 256 + c7], 1); + + memoryBarrierShared(); + barrier(); + + // Add from local memory to global memory. + if (temp[base_idx + 0] != 0) atomicAdd(dist[base_idx + 0], temp[base_idx + 0]); + if (temp[base_idx + 1] != 0) atomicAdd(dist[base_idx + 1], temp[base_idx + 1]); + if (temp[base_idx + 2] != 0) atomicAdd(dist[base_idx + 2], temp[base_idx + 2]); + if (temp[base_idx + 3] != 0) atomicAdd(dist[base_idx + 3], temp[base_idx + 3]); + if (temp[base_idx + 4] != 0) atomicAdd(dist[base_idx + 4], temp[base_idx + 4]); + if (temp[base_idx + 5] != 0) atomicAdd(dist[base_idx + 5], temp[base_idx + 5]); + if (temp[base_idx + 6] != 0) atomicAdd(dist[base_idx + 6], temp[base_idx + 6]); + if (temp[base_idx + 7] != 0) atomicAdd(dist[base_idx + 7], temp[base_idx + 7]); +} diff --git a/narabu-encoder.cpp b/narabu-encoder.cpp index e8d20e0..7301359 100644 --- a/narabu-encoder.cpp +++ b/narabu-encoder.cpp @@ -401,7 +401,7 @@ int main(int argc, char **argv) steady_clock::time_point start = steady_clock::now(); unsigned num_iterations = 100; for (unsigned i = 0; i < num_iterations; ++i) { - glDispatchCompute(WIDTH_BLOCKS, HEIGHT_BLOCKS, 1); + glDispatchCompute(WIDTH_BLOCKS / 16, HEIGHT_BLOCKS, 1); } check_error(); glFinish();