]> git.sesse.net Git - narabu/blobdiff - encoder.shader
More fixes of hard-coded values.
[narabu] / encoder.shader
index da9e6c07b2c730274513087a6ffaa5239146b29e..8d4bcc2e5b1e46e70cbe460daa93eaf1c3c07a9b 100644 (file)
@@ -1,7 +1,11 @@
 #version 440
 #extension GL_ARB_shader_clock : enable
 
-layout(local_size_x = 8) in;
+// Do sixteen 8x8 blocks in a local group, because that matches up perfectly
+// with needing 1024 coefficients for our four histograms (of 256 bins each).
+#define NUM_Z 16
+
+layout(local_size_x = 8, local_size_z = NUM_Z) in;
 
 layout(r16ui) uniform restrict writeonly uimage2D dc_ac7_tex;
 layout(r16ui) uniform restrict writeonly uimage2D ac1_ac6_tex;
@@ -10,11 +14,11 @@ layout(r8i) uniform restrict writeonly iimage2D ac3_tex;
 layout(r8i) uniform restrict writeonly iimage2D ac4_tex;
 layout(r8ui) uniform restrict readonly uimage2D image_tex;
 
-shared float temp[64];
+shared uint temp[64 * NUM_Z];
 
 layout(std430, binding = 9) buffer layoutName
 {
-       uint dist[4][256];
+       uint dist[4 * 256];
 };
 
 #define MAPPING(s0, s1, s2, s3, s4, s5, s6, s7) ((s0) | (s1 << 2) | (s2 << 4) | (s3 << 6) | (s4 << 8) | (s5 << 10) | (s6 << 12) | (s7 << 14))
@@ -49,15 +53,16 @@ const float W[64] = {
 const float S = 4.0 * 0.5;  // whatever?
 
 // NOTE: Contains factors to counteract the scaling in the DCT implementation.
+#define QM(x, y) (sf[x] * sf[y] / (W[y*8 + x] * S))
 const float quant_matrix[64] = {
-        sf[0] * sf[0] / 64.0,         sf[1] * sf[0] / (W[ 1] * S),  sf[2] * sf[0] / (W[ 2] * S),  sf[3] * sf[0] / (W[ 3] * S),  sf[4] * sf[0] / (W[ 4] * S),  sf[5] * sf[0] / (W[ 5] * S),  sf[6] * sf[0] / (W[ 6] * S),  sf[7] * sf[0] / (W[ 7] * S),
-        sf[0] * sf[1] / (W[ 8] * S),  sf[1] * sf[1] / (W[ 9] * S),  sf[2] * sf[1] / (W[10] * S),  sf[3] * sf[1] / (W[11] * S),  sf[4] * sf[1] / (W[12] * S),  sf[5] * sf[1] / (W[13] * S),  sf[6] * sf[1] / (W[14] * S),  sf[7] * sf[1] / (W[15] * S),
-        sf[0] * sf[2] / (W[16] * S),  sf[1] * sf[2] / (W[17] * S),  sf[2] * sf[2] / (W[18] * S),  sf[3] * sf[2] / (W[19] * S),  sf[4] * sf[2] / (W[20] * S),  sf[5] * sf[2] / (W[21] * S),  sf[6] * sf[2] / (W[22] * S),  sf[7] * sf[2] / (W[23] * S),
-        sf[0] * sf[3] / (W[24] * S),  sf[1] * sf[3] / (W[25] * S),  sf[2] * sf[3] / (W[26] * S),  sf[3] * sf[3] / (W[27] * S),  sf[4] * sf[3] / (W[28] * S),  sf[5] * sf[3] / (W[29] * S),  sf[6] * sf[3] / (W[30] * S),  sf[7] * sf[3] / (W[31] * S),
-        sf[0] * sf[4] / (W[32] * S),  sf[1] * sf[4] / (W[33] * S),  sf[2] * sf[4] / (W[34] * S),  sf[3] * sf[4] / (W[35] * S),  sf[4] * sf[4] / (W[36] * S),  sf[5] * sf[4] / (W[37] * S),  sf[6] * sf[4] / (W[38] * S),  sf[7] * sf[4] / (W[39] * S),
-        sf[0] * sf[5] / (W[40] * S),  sf[1] * sf[5] / (W[41] * S),  sf[2] * sf[5] / (W[42] * S),  sf[3] * sf[5] / (W[43] * S),  sf[4] * sf[5] / (W[44] * S),  sf[5] * sf[5] / (W[45] * S),  sf[6] * sf[5] / (W[46] * S),  sf[7] * sf[5] / (W[47] * S),
-        sf[0] * sf[6] / (W[48] * S),  sf[1] * sf[6] / (W[49] * S),  sf[2] * sf[6] / (W[50] * S),  sf[3] * sf[6] / (W[51] * S),  sf[4] * sf[6] / (W[52] * S),  sf[5] * sf[6] / (W[53] * S),  sf[6] * sf[6] / (W[54] * S),  sf[7] * sf[6] / (W[55] * S),
-        sf[0] * sf[7] / (W[56] * S),  sf[1] * sf[7] / (W[57] * S),  sf[2] * sf[7] / (W[58] * S),  sf[3] * sf[7] / (W[59] * S),  sf[4] * sf[7] / (W[60] * S),  sf[5] * sf[7] / (W[61] * S),  sf[6] * sf[7] / (W[62] * S),  sf[7] * sf[7] / (W[63] * S)
+        1.0 / 64.0, QM(1, 0), QM(2, 0), QM(3, 0), QM(4, 0), QM(5, 0), QM(6, 0), QM(7, 0),
+        QM(0, 1),   QM(1, 1), QM(2, 1), QM(3, 1), QM(4, 1), QM(5, 1), QM(6, 1), QM(7, 1),
+        QM(0, 2),   QM(1, 2), QM(2, 2), QM(3, 2), QM(4, 2), QM(5, 2), QM(6, 2), QM(7, 2),
+        QM(0, 3),   QM(1, 3), QM(2, 3), QM(3, 3), QM(4, 3), QM(5, 3), QM(6, 3), QM(7, 3),
+        QM(0, 4),   QM(1, 4), QM(2, 4), QM(3, 4), QM(4, 4), QM(5, 4), QM(6, 4), QM(7, 4),
+        QM(0, 5),   QM(1, 5), QM(2, 5), QM(3, 5), QM(4, 5), QM(5, 5), QM(6, 5), QM(7, 5),
+        QM(0, 6),   QM(1, 6), QM(2, 6), QM(3, 6), QM(4, 6), QM(5, 6), QM(6, 6), QM(7, 6),
+        QM(0, 7),   QM(1, 7), QM(2, 7), QM(3, 7), QM(4, 7), QM(5, 7), QM(6, 7), QM(7, 7)
 };
 
 // Clamp and pack a 9-bit and a 7-bit signed value into a 16-bit word.
@@ -124,9 +129,11 @@ void dct_1d(inout float y0, inout float y1, inout float y2, inout float y3, inou
 
 void main()
 {
-       uint x = 8 * gl_WorkGroupID.x;
+       uint sx = gl_WorkGroupID.x * NUM_Z + gl_LocalInvocationID.z;
+       uint x = 8 * sx;
        uint y = 8 * gl_WorkGroupID.y;
        uint n = gl_LocalInvocationID.x;
+       uint z = gl_LocalInvocationID.z;
 
        // Load column.
        float y0 = imageLoad(image_tex, ivec2(x + n, y + 0)).x;
@@ -142,27 +149,28 @@ void main()
        dct_1d(y0, y1, y2, y3, y4, y5, y6, y7);
 
        // Communicate with the other shaders in the group.
-       temp[n + 0 * 8] = y0;
-       temp[n + 1 * 8] = y1;
-       temp[n + 2 * 8] = y2;
-       temp[n + 3 * 8] = y3;
-       temp[n + 4 * 8] = y4;
-       temp[n + 5 * 8] = y5;
-       temp[n + 6 * 8] = y6;
-       temp[n + 7 * 8] = y7;
+       uint base_idx = 64 * z;
+       temp[base_idx + 0 * 8 + n] = floatBitsToUint(y0);
+       temp[base_idx + 1 * 8 + n] = floatBitsToUint(y1);
+       temp[base_idx + 2 * 8 + n] = floatBitsToUint(y2);
+       temp[base_idx + 3 * 8 + n] = floatBitsToUint(y3);
+       temp[base_idx + 4 * 8 + n] = floatBitsToUint(y4);
+       temp[base_idx + 5 * 8 + n] = floatBitsToUint(y5);
+       temp[base_idx + 6 * 8 + n] = floatBitsToUint(y6);
+       temp[base_idx + 7 * 8 + n] = floatBitsToUint(y7);
 
        memoryBarrierShared();
        barrier();
 
        // Load row (so transpose, in a sense).
-       y0 = temp[n * 8 + 0];
-       y1 = temp[n * 8 + 1];
-       y2 = temp[n * 8 + 2];
-       y3 = temp[n * 8 + 3];
-       y4 = temp[n * 8 + 4];
-       y5 = temp[n * 8 + 5];
-       y6 = temp[n * 8 + 6];
-       y7 = temp[n * 8 + 7];
+       y0 = uintBitsToFloat(temp[base_idx + n * 8 + 0]);
+       y1 = uintBitsToFloat(temp[base_idx + n * 8 + 1]);
+       y2 = uintBitsToFloat(temp[base_idx + n * 8 + 2]);
+       y3 = uintBitsToFloat(temp[base_idx + n * 8 + 3]);
+       y4 = uintBitsToFloat(temp[base_idx + n * 8 + 4]);
+       y5 = uintBitsToFloat(temp[base_idx + n * 8 + 5]);
+       y6 = uintBitsToFloat(temp[base_idx + n * 8 + 6]);
+       y7 = uintBitsToFloat(temp[base_idx + n * 8 + 7]);
 
        // Horizontal DCT.
        dct_1d(y0, y1, y2, y3, y4, y5, y6, y7);
@@ -178,45 +186,62 @@ void main()
        int c7 = int(round(y7 * quant_matrix[n * 8 + 7]));
 
        // Clamp, pack and store.
-       uint sx = gl_WorkGroupID.x;
        imageStore(dc_ac7_tex,  ivec2(sx, y + n), uvec4(pack_9_7(c0, c7), 0, 0, 0));
        imageStore(ac1_ac6_tex, ivec2(sx, y + n), uvec4(pack_9_7(c1, c6), 0, 0, 0));
        imageStore(ac2_ac5_tex, ivec2(sx, y + n), uvec4(pack_9_7(c2, c5), 0, 0, 0));
        imageStore(ac3_tex,     ivec2(sx, y + n), ivec4(c3, 0, 0, 0));
        imageStore(ac4_tex,     ivec2(sx, y + n), ivec4(c4, 0, 0, 0));
 
-       // Count frequencies, but only for every 8th block or so, randomly selected.
-       uint wg_index = gl_WorkGroupID.y * gl_WorkGroupSize.x + gl_WorkGroupID.x;
-       if ((wg_index * 0x9E3779B9u) >> 29 == 0) {  // Fibonacci hashing, essentially a PRNG in this context.
-               c0 = min(abs(c0), 255);
-               c1 = min(abs(c1), 255);
-               c2 = min(abs(c2), 255);
-               c3 = min(abs(c3), 255);
-               c4 = min(abs(c4), 255);
-               c5 = min(abs(c5), 255);
-               c6 = min(abs(c6), 255);
-               c7 = min(abs(c7), 255);
-
-               // Spread out the most popular elements among the cache lines by reversing the bits
-               // of the index, reducing false sharing.
-               c0 = bitfieldReverse(c0) >> 24;
-               c1 = bitfieldReverse(c1) >> 24;
-               c2 = bitfieldReverse(c2) >> 24;
-               c3 = bitfieldReverse(c3) >> 24;
-               c4 = bitfieldReverse(c4) >> 24;
-               c5 = bitfieldReverse(c5) >> 24;
-               c6 = bitfieldReverse(c6) >> 24;
-               c7 = bitfieldReverse(c7) >> 24;
-
-               uint m = luma_mapping[n];
-               atomicAdd(dist[bitfieldExtract(m,  0, 2)][c0], 1);
-               atomicAdd(dist[bitfieldExtract(m,  2, 2)][c1], 1);
-               atomicAdd(dist[bitfieldExtract(m,  4, 2)][c2], 1);
-               atomicAdd(dist[bitfieldExtract(m,  6, 2)][c3], 1);
-               atomicAdd(dist[bitfieldExtract(m,  8, 2)][c4], 1);
-               atomicAdd(dist[bitfieldExtract(m, 10, 2)][c5], 1);
-               atomicAdd(dist[bitfieldExtract(m, 12, 2)][c6], 1);
-               atomicAdd(dist[bitfieldExtract(m, 14, 2)][c7], 1);
-       }
-}
+       // Zero out the temporary area in preparation for counting up the histograms.
+       base_idx += 8 * n;
+       temp[base_idx + 0] = 0;
+       temp[base_idx + 1] = 0;
+       temp[base_idx + 2] = 0;
+       temp[base_idx + 3] = 0;
+       temp[base_idx + 4] = 0;
+       temp[base_idx + 5] = 0;
+       temp[base_idx + 6] = 0;
+       temp[base_idx + 7] = 0;
+
+       memoryBarrierShared();
+       barrier();
 
+       // Count frequencies into four histograms. We do this to local memory first,
+       // because this is _much_ faster; then we do global atomic adds for the nonzero
+       // members.
+
+       // First take the absolute value (signs are encoded differently) and clamp,
+       // as any value over 255 is going to be encoded as an escape.
+       c0 = min(abs(c0), 255);
+       c1 = min(abs(c1), 255);
+       c2 = min(abs(c2), 255);
+       c3 = min(abs(c3), 255);
+       c4 = min(abs(c4), 255);
+       c5 = min(abs(c5), 255);
+       c6 = min(abs(c6), 255);
+       c7 = min(abs(c7), 255);
+
+       // Add up in local memory.
+       uint m = luma_mapping[n];
+       atomicAdd(temp[bitfieldExtract(m,  0, 2) * 256 + c0], 1);
+       atomicAdd(temp[bitfieldExtract(m,  2, 2) * 256 + c1], 1);
+       atomicAdd(temp[bitfieldExtract(m,  4, 2) * 256 + c2], 1);
+       atomicAdd(temp[bitfieldExtract(m,  6, 2) * 256 + c3], 1);
+       atomicAdd(temp[bitfieldExtract(m,  8, 2) * 256 + c4], 1);
+       atomicAdd(temp[bitfieldExtract(m, 10, 2) * 256 + c5], 1);
+       atomicAdd(temp[bitfieldExtract(m, 12, 2) * 256 + c6], 1);
+       atomicAdd(temp[bitfieldExtract(m, 14, 2) * 256 + c7], 1);
+
+       memoryBarrierShared();
+       barrier();
+
+       // Add from local memory to global memory.
+       if (temp[base_idx + 0] != 0) atomicAdd(dist[base_idx + 0], temp[base_idx + 0]);
+       if (temp[base_idx + 1] != 0) atomicAdd(dist[base_idx + 1], temp[base_idx + 1]);
+       if (temp[base_idx + 2] != 0) atomicAdd(dist[base_idx + 2], temp[base_idx + 2]);
+       if (temp[base_idx + 3] != 0) atomicAdd(dist[base_idx + 3], temp[base_idx + 3]);
+       if (temp[base_idx + 4] != 0) atomicAdd(dist[base_idx + 4], temp[base_idx + 4]);
+       if (temp[base_idx + 5] != 0) atomicAdd(dist[base_idx + 5], temp[base_idx + 5]);
+       if (temp[base_idx + 6] != 0) atomicAdd(dist[base_idx + 6], temp[base_idx + 6]);
+       if (temp[base_idx + 7] != 0) atomicAdd(dist[base_idx + 7], temp[base_idx + 7]);
+}