#version 440
#extension GL_ARB_shader_clock : enable
-layout(local_size_x = 8) in;
+// Do sixteen 8x8 blocks in a local group: 16 * 64 = 1024 shared temporaries,
+// which is exactly the 4 * 256 bins we need for our four histograms.
+#define NUM_Z 16
+
+layout(local_size_x = 8, local_size_z = NUM_Z) in;
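+// One invocation per column of a block, so 8 * 16 = 128 invocations per work group.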
layout(r16ui) uniform restrict writeonly uimage2D dc_ac7_tex;
layout(r16ui) uniform restrict writeonly uimage2D ac1_ac6_tex;
layout(r8i) uniform restrict writeonly iimage2D ac4_tex;
layout(r8ui) uniform restrict readonly uimage2D image_tex;
-shared float temp[64];
+shared uint temp[64 * NUM_Z];
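+// Used first as per-block transpose scratch, then reused for the four 256-bin histograms.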
layout(std430, binding = 9) buffer layoutName
{
- uint dist[4][256];
+ uint dist[4 * 256];
};
#define MAPPING(s0, s1, s2, s3, s4, s5, s6, s7) ((s0) | (s1 << 2) | (s2 << 4) | (s3 << 6) | (s4 << 8) | (s5 << 10) | (s6 << 12) | (s7 << 14))
void main()
{
- uint x = 8 * gl_WorkGroupID.x;
+ uint sx = gl_WorkGroupID.x * NUM_Z + gl_LocalInvocationID.z;
+ uint x = 8 * sx;
uint y = 8 * gl_WorkGroupID.y;
uint n = gl_LocalInvocationID.x;
+ uint z = gl_LocalInvocationID.z;
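+ // Each z slice processes its own 8x8 block; n is the column within that block.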
// Load column.
float y0 = imageLoad(image_tex, ivec2(x + n, y + 0)).x;
dct_1d(y0, y1, y2, y3, y4, y5, y6, y7);
 // Share the column results with the other invocations in the work group.
- temp[n + 0 * 8] = y0;
- temp[n + 1 * 8] = y1;
- temp[n + 2 * 8] = y2;
- temp[n + 3 * 8] = y3;
- temp[n + 4 * 8] = y4;
- temp[n + 5 * 8] = y5;
- temp[n + 6 * 8] = y6;
- temp[n + 7 * 8] = y7;
+ uint base_idx = 64 * z;
+ temp[base_idx + 0 * 8 + n] = floatBitsToUint(y0);
+ temp[base_idx + 1 * 8 + n] = floatBitsToUint(y1);
+ temp[base_idx + 2 * 8 + n] = floatBitsToUint(y2);
+ temp[base_idx + 3 * 8 + n] = floatBitsToUint(y3);
+ temp[base_idx + 4 * 8 + n] = floatBitsToUint(y4);
+ temp[base_idx + 5 * 8 + n] = floatBitsToUint(y5);
+ temp[base_idx + 6 * 8 + n] = floatBitsToUint(y6);
+ temp[base_idx + 7 * 8 + n] = floatBitsToUint(y7);
memoryBarrierShared();
barrier();
 // Load row (in effect, a transpose).
- y0 = temp[n * 8 + 0];
- y1 = temp[n * 8 + 1];
- y2 = temp[n * 8 + 2];
- y3 = temp[n * 8 + 3];
- y4 = temp[n * 8 + 4];
- y5 = temp[n * 8 + 5];
- y6 = temp[n * 8 + 6];
- y7 = temp[n * 8 + 7];
+ y0 = uintBitsToFloat(temp[base_idx + n * 8 + 0]);
+ y1 = uintBitsToFloat(temp[base_idx + n * 8 + 1]);
+ y2 = uintBitsToFloat(temp[base_idx + n * 8 + 2]);
+ y3 = uintBitsToFloat(temp[base_idx + n * 8 + 3]);
+ y4 = uintBitsToFloat(temp[base_idx + n * 8 + 4]);
+ y5 = uintBitsToFloat(temp[base_idx + n * 8 + 5]);
+ y6 = uintBitsToFloat(temp[base_idx + n * 8 + 6]);
+ y7 = uintBitsToFloat(temp[base_idx + n * 8 + 7]);
// Horizontal DCT.
dct_1d(y0, y1, y2, y3, y4, y5, y6, y7);
int c7 = int(round(y7 * quant_matrix[n * 8 + 7]));
// Clamp, pack and store.
- uint sx = gl_WorkGroupID.x;
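+ // pack_9_7() presumably packs the two coefficients as a 9-bit and a 7-bit field
+ // of the single 16-bit texel.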
imageStore(dc_ac7_tex, ivec2(sx, y + n), uvec4(pack_9_7(c0, c7), 0, 0, 0));
imageStore(ac1_ac6_tex, ivec2(sx, y + n), uvec4(pack_9_7(c1, c6), 0, 0, 0));
imageStore(ac2_ac5_tex, ivec2(sx, y + n), uvec4(pack_9_7(c2, c5), 0, 0, 0));
imageStore(ac3_tex, ivec2(sx, y + n), ivec4(c3, 0, 0, 0));
imageStore(ac4_tex, ivec2(sx, y + n), ivec4(c4, 0, 0, 0));
- // Count frequencies, but only for every 8th block or so, randomly selected.
- uint wg_index = gl_WorkGroupID.y * gl_WorkGroupSize.x + gl_WorkGroupID.x;
- if ((wg_index * 0x9E3779B9u) >> 29 == 0) { // Fibonacci hashing, essentially a PRNG in this context.
- c0 = min(abs(c0), 255);
- c1 = min(abs(c1), 255);
- c2 = min(abs(c2), 255);
- c3 = min(abs(c3), 255);
- c4 = min(abs(c4), 255);
- c5 = min(abs(c5), 255);
- c6 = min(abs(c6), 255);
- c7 = min(abs(c7), 255);
-
- // Spread out the most popular elements among the cache lines by reversing the bits
- // of the index, reducing false sharing.
- c0 = bitfieldReverse(c0) >> 24;
- c1 = bitfieldReverse(c1) >> 24;
- c2 = bitfieldReverse(c2) >> 24;
- c3 = bitfieldReverse(c3) >> 24;
- c4 = bitfieldReverse(c4) >> 24;
- c5 = bitfieldReverse(c5) >> 24;
- c6 = bitfieldReverse(c6) >> 24;
- c7 = bitfieldReverse(c7) >> 24;
-
- uint m = luma_mapping[n];
- atomicAdd(dist[bitfieldExtract(m, 0, 2)][c0], 1);
- atomicAdd(dist[bitfieldExtract(m, 2, 2)][c1], 1);
- atomicAdd(dist[bitfieldExtract(m, 4, 2)][c2], 1);
- atomicAdd(dist[bitfieldExtract(m, 6, 2)][c3], 1);
- atomicAdd(dist[bitfieldExtract(m, 8, 2)][c4], 1);
- atomicAdd(dist[bitfieldExtract(m, 10, 2)][c5], 1);
- atomicAdd(dist[bitfieldExtract(m, 12, 2)][c6], 1);
- atomicAdd(dist[bitfieldExtract(m, 14, 2)][c7], 1);
- }
-}
+ // Zero out the temporary area in preparation for counting up the histograms.
+ base_idx += 8 * n;
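+ // Each of the 8 * 16 = 128 invocations clears eight consecutive bins, covering all 1024.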
+ temp[base_idx + 0] = 0;
+ temp[base_idx + 1] = 0;
+ temp[base_idx + 2] = 0;
+ temp[base_idx + 3] = 0;
+ temp[base_idx + 4] = 0;
+ temp[base_idx + 5] = 0;
+ temp[base_idx + 6] = 0;
+ temp[base_idx + 7] = 0;
+
+ memoryBarrierShared();
+ barrier();
+ // Count frequencies into four histograms. We accumulate in shared (local) memory
+ // first, since shared-memory atomics are _much_ faster than global ones; then we
+ // do one global atomic add per nonzero bin.
+
+ // First take the absolute value (signs are encoded differently) and clamp,
+ // as any value over 255 is going to be encoded as an escape.
+ c0 = min(abs(c0), 255);
+ c1 = min(abs(c1), 255);
+ c2 = min(abs(c2), 255);
+ c3 = min(abs(c3), 255);
+ c4 = min(abs(c4), 255);
+ c5 = min(abs(c5), 255);
+ c6 = min(abs(c6), 255);
+ c7 = min(abs(c7), 255);
+
+ // Add up in local memory.
+ uint m = luma_mapping[n];
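+ // m holds eight 2-bit histogram selectors (built with the MAPPING macro above),
+ // one per coefficient of this row.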
+ atomicAdd(temp[bitfieldExtract(m, 0, 2) * 256 + c0], 1);
+ atomicAdd(temp[bitfieldExtract(m, 2, 2) * 256 + c1], 1);
+ atomicAdd(temp[bitfieldExtract(m, 4, 2) * 256 + c2], 1);
+ atomicAdd(temp[bitfieldExtract(m, 6, 2) * 256 + c3], 1);
+ atomicAdd(temp[bitfieldExtract(m, 8, 2) * 256 + c4], 1);
+ atomicAdd(temp[bitfieldExtract(m, 10, 2) * 256 + c5], 1);
+ atomicAdd(temp[bitfieldExtract(m, 12, 2) * 256 + c6], 1);
+ atomicAdd(temp[bitfieldExtract(m, 14, 2) * 256 + c7], 1);
+
+ memoryBarrierShared();
+ barrier();
+
+ // Add from local memory to global memory.
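+ // base_idx = 64 * z + 8 * n, so the 128 invocations together cover every bin
+ // exactly once; the != 0 test skips global atomics for empty bins.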
+ if (temp[base_idx + 0] != 0) atomicAdd(dist[base_idx + 0], temp[base_idx + 0]);
+ if (temp[base_idx + 1] != 0) atomicAdd(dist[base_idx + 1], temp[base_idx + 1]);
+ if (temp[base_idx + 2] != 0) atomicAdd(dist[base_idx + 2], temp[base_idx + 2]);
+ if (temp[base_idx + 3] != 0) atomicAdd(dist[base_idx + 3], temp[base_idx + 3]);
+ if (temp[base_idx + 4] != 0) atomicAdd(dist[base_idx + 4], temp[base_idx + 4]);
+ if (temp[base_idx + 5] != 0) atomicAdd(dist[base_idx + 5], temp[base_idx + 5]);
+ if (temp[base_idx + 6] != 0) atomicAdd(dist[base_idx + 6], temp[base_idx + 6]);
+ if (temp[base_idx + 7] != 0) atomicAdd(dist[base_idx + 7], temp[base_idx + 7]);
+}