]> git.sesse.net Git - narabu/blobdiff - encoder.shader
Make quant_matrix a bit more compact.
[narabu] / encoder.shader
index 54103a093d374cca0fb2fcda37a7add308574268..163f8fa0110aed25ff9654156f115b54167d95e5 100644 (file)
@@ -12,6 +12,30 @@ layout(r8ui) uniform restrict readonly uimage2D image_tex;
 
 shared float temp[64];
 
+layout(std430, binding = 9) buffer layoutName
+{
+       uint dist[4][256];
+};
+
+#define MAPPING(s0, s1, s2, s3, s4, s5, s6, s7) ((s0) | (s1 << 2) | (s2 << 4) | (s3 << 6) | (s4 << 8) | (s5 << 10) | (s6 << 12) | (s7 << 14))
+
+const uint luma_mapping[8] = {
+       MAPPING(0, 0, 1, 1, 2, 2, 3, 3),
+       MAPPING(0, 0, 1, 2, 2, 2, 3, 3),
+       MAPPING(1, 1, 2, 2, 2, 3, 3, 3),
+       MAPPING(1, 1, 2, 2, 2, 3, 3, 3),
+       MAPPING(1, 2, 2, 2, 2, 3, 3, 3),
+       MAPPING(2, 2, 2, 2, 3, 3, 3, 3),
+       MAPPING(2, 2, 3, 3, 3, 3, 3, 3),
+       MAPPING(3, 3, 3, 3, 3, 3, 3, 3),
+};
+
+// Scale factors; 1.0 / (sqrt(2.0) * cos(k * M_PI / 16.0)), except for the first which is 1.
+const float sf[8] = {
+       1.0, 0.7209598220069479, 0.765366864730180, 0.8504300947672564,
+       1.0, 1.2727585805728336, 1.847759065022573, 3.6245097854115502
+};
+
 const float W[64] = {
          8, 16, 19, 22, 26, 27, 29, 34,
         16, 16, 22, 24, 27, 29, 34, 37,
@@ -22,18 +46,19 @@ const float W[64] = {
         26, 27, 29, 34, 38, 46, 56, 69,
         27, 29, 35, 38, 46, 56, 69, 83
 };
-const float S = 4.0;  // whatever?
+const float S = 4.0 * 0.5;  // whatever?
 
 // NOTE: Contains factors to counteract the scaling in the DCT implementation.
+#define QM(x, y) (sf[x] * sf[y] / (W[y*8 + x] * S))
 const float quant_matrix[64] = {
-        1.0 / 64.0,         1.0 / (W[ 1] * S),  1.0 / (W[ 2] * S),  1.0 / (W[ 3] * S),  1.0 / (W[ 4] * S),  1.0 / (W[ 5] * S),  1.0 / (W[ 6] * S),  1.0 / (W[ 7] * S),
-        1.0 / (W[ 8] * S),  2.0 / (W[ 9] * S),  2.0 / (W[10] * S),  2.0 / (W[11] * S),  2.0 / (W[12] * S),  2.0 / (W[13] * S),  2.0 / (W[14] * S),  2.0 / (W[15] * S),
-        1.0 / (W[16] * S),  2.0 / (W[17] * S),  2.0 / (W[18] * S),  2.0 / (W[19] * S),  2.0 / (W[20] * S),  2.0 / (W[21] * S),  2.0 / (W[22] * S),  2.0 / (W[23] * S),
-        1.0 / (W[24] * S),  2.0 / (W[25] * S),  2.0 / (W[26] * S),  2.0 / (W[27] * S),  2.0 / (W[28] * S),  2.0 / (W[29] * S),  2.0 / (W[30] * S),  2.0 / (W[31] * S),
-        1.0 / (W[32] * S),  2.0 / (W[33] * S),  2.0 / (W[34] * S),  2.0 / (W[35] * S),  2.0 / (W[36] * S),  2.0 / (W[37] * S),  2.0 / (W[38] * S),  2.0 / (W[39] * S),
-        1.0 / (W[40] * S),  2.0 / (W[41] * S),  2.0 / (W[42] * S),  2.0 / (W[43] * S),  2.0 / (W[44] * S),  2.0 / (W[45] * S),  2.0 / (W[46] * S),  2.0 / (W[47] * S),
-        1.0 / (W[48] * S),  2.0 / (W[49] * S),  2.0 / (W[50] * S),  2.0 / (W[51] * S),  2.0 / (W[52] * S),  2.0 / (W[53] * S),  2.0 / (W[54] * S),  2.0 / (W[55] * S),
-        1.0 / (W[56] * S),  2.0 / (W[57] * S),  2.0 / (W[58] * S),  2.0 / (W[59] * S),  2.0 / (W[60] * S),  2.0 / (W[61] * S),  2.0 / (W[62] * S),  2.0 / (W[63] * S)
+        1.0 / 64.0, QM(1, 0), QM(2, 0), QM(3, 0), QM(4, 0), QM(5, 0), QM(6, 0), QM(7, 0),
+        QM(0, 1),   QM(1, 1), QM(2, 1), QM(3, 1), QM(4, 1), QM(5, 1), QM(6, 1), QM(7, 1),
+        QM(0, 2),   QM(1, 2), QM(2, 2), QM(3, 2), QM(4, 2), QM(5, 2), QM(6, 2), QM(7, 2),
+        QM(0, 3),   QM(1, 3), QM(2, 3), QM(3, 3), QM(4, 3), QM(5, 3), QM(6, 3), QM(7, 3),
+        QM(0, 4),   QM(1, 4), QM(2, 4), QM(3, 4), QM(4, 4), QM(5, 4), QM(6, 4), QM(7, 4),
+        QM(0, 5),   QM(1, 5), QM(2, 5), QM(3, 5), QM(4, 5), QM(5, 5), QM(6, 5), QM(7, 5),
+        QM(0, 6),   QM(1, 6), QM(2, 6), QM(3, 6), QM(4, 6), QM(5, 6), QM(6, 6), QM(7, 6),
+        QM(0, 7),   QM(1, 7), QM(2, 7), QM(3, 7), QM(4, 7), QM(5, 7), QM(6, 7), QM(7, 7)
 };
 
 // Clamp and pack a 9-bit and a 7-bit signed value into a 16-bit word.
@@ -42,7 +67,7 @@ uint pack_9_7(int v9, int v7)
        return (uint(clamp(v9, -256, 255)) & 0x1ffu) | ((uint(clamp(v7, -64, 63)) & 0x7fu) << 9);
 }
 
-// Scaled 1D DCT. y0 output is scaled by 8, everything else is scaled by 16.
+// Scaled 1D DCT (AA&N). y0 is correctly scaled, all other y_k are scaled by sqrt(2) cos(k * Pi / 16).
 void dct_1d(inout float y0, inout float y1, inout float y2, inout float y3, inout float y4, inout float y5, inout float y6, inout float y7)
 {
        const float a1 = 0.7071067811865474;   // sqrt(2)
@@ -97,6 +122,7 @@ void dct_1d(inout float y0, inout float y1, inout float y2, inout float y3, inou
        y7 = p5_5 - p4_6;
        y3 = p5_7 - p4_4;
 }
+
 void main()
 {
        uint x = 8 * gl_WorkGroupID.x;
@@ -159,5 +185,39 @@ void main()
        imageStore(ac2_ac5_tex, ivec2(sx, y + n), uvec4(pack_9_7(c2, c5), 0, 0, 0));
        imageStore(ac3_tex,     ivec2(sx, y + n), ivec4(c3, 0, 0, 0));
        imageStore(ac4_tex,     ivec2(sx, y + n), ivec4(c4, 0, 0, 0));
+
+       // Count frequencies, but only for every 8th block or so, randomly selected.
+       uint wg_index = gl_WorkGroupID.y * gl_WorkGroupSize.x + gl_WorkGroupID.x;
+       if ((wg_index * 0x9E3779B9u) >> 29 == 0) {  // Fibonacci hashing, essentially a PRNG in this context.
+               c0 = min(abs(c0), 255);
+               c1 = min(abs(c1), 255);
+               c2 = min(abs(c2), 255);
+               c3 = min(abs(c3), 255);
+               c4 = min(abs(c4), 255);
+               c5 = min(abs(c5), 255);
+               c6 = min(abs(c6), 255);
+               c7 = min(abs(c7), 255);
+
+               // Spread out the most popular elements among the cache lines by reversing the bits
+               // of the index, reducing false sharing.
+               c0 = bitfieldReverse(c0) >> 24;
+               c1 = bitfieldReverse(c1) >> 24;
+               c2 = bitfieldReverse(c2) >> 24;
+               c3 = bitfieldReverse(c3) >> 24;
+               c4 = bitfieldReverse(c4) >> 24;
+               c5 = bitfieldReverse(c5) >> 24;
+               c6 = bitfieldReverse(c6) >> 24;
+               c7 = bitfieldReverse(c7) >> 24;
+
+               uint m = luma_mapping[n];
+               atomicAdd(dist[bitfieldExtract(m,  0, 2)][c0], 1);
+               atomicAdd(dist[bitfieldExtract(m,  2, 2)][c1], 1);
+               atomicAdd(dist[bitfieldExtract(m,  4, 2)][c2], 1);
+               atomicAdd(dist[bitfieldExtract(m,  6, 2)][c3], 1);
+               atomicAdd(dist[bitfieldExtract(m,  8, 2)][c4], 1);
+               atomicAdd(dist[bitfieldExtract(m, 10, 2)][c5], 1);
+               atomicAdd(dist[bitfieldExtract(m, 12, 2)][c6], 1);
+               atomicAdd(dist[bitfieldExtract(m, 14, 2)][c7], 1);
+       }
 }