]> git.sesse.net Git - narabu/blob - encoder.shader
da9e6c07b2c730274513087a6ffaa5239146b29e
[narabu] / encoder.shader
1 #version 440
2 #extension GL_ARB_shader_clock : enable
3
4 layout(local_size_x = 8) in;
5
6 layout(r16ui) uniform restrict writeonly uimage2D dc_ac7_tex;
7 layout(r16ui) uniform restrict writeonly uimage2D ac1_ac6_tex;
8 layout(r16ui) uniform restrict writeonly uimage2D ac2_ac5_tex;
9 layout(r8i) uniform restrict writeonly iimage2D ac3_tex;
10 layout(r8i) uniform restrict writeonly iimage2D ac4_tex;
11 layout(r8ui) uniform restrict readonly uimage2D image_tex;
12
13 shared float temp[64];
14
15 layout(std430, binding = 9) buffer layoutName
16 {
17         uint dist[4][256];
18 };
19
20 #define MAPPING(s0, s1, s2, s3, s4, s5, s6, s7) ((s0) | (s1 << 2) | (s2 << 4) | (s3 << 6) | (s4 << 8) | (s5 << 10) | (s6 << 12) | (s7 << 14))
21
22 const uint luma_mapping[8] = {
23         MAPPING(0, 0, 1, 1, 2, 2, 3, 3),
24         MAPPING(0, 0, 1, 2, 2, 2, 3, 3),
25         MAPPING(1, 1, 2, 2, 2, 3, 3, 3),
26         MAPPING(1, 1, 2, 2, 2, 3, 3, 3),
27         MAPPING(1, 2, 2, 2, 2, 3, 3, 3),
28         MAPPING(2, 2, 2, 2, 3, 3, 3, 3),
29         MAPPING(2, 2, 3, 3, 3, 3, 3, 3),
30         MAPPING(3, 3, 3, 3, 3, 3, 3, 3),
31 };
32
33 // Scale factors; 1.0 / (sqrt(2.0) * cos(k * M_PI / 16.0)), except for the first which is 1.
34 const float sf[8] = {
35         1.0, 0.7209598220069479, 0.765366864730180, 0.8504300947672564,
36         1.0, 1.2727585805728336, 1.847759065022573, 3.6245097854115502
37 };
38
39 const float W[64] = {
40          8, 16, 19, 22, 26, 27, 29, 34,
41         16, 16, 22, 24, 27, 29, 34, 37,
42         19, 22, 26, 27, 29, 34, 34, 38,
43         22, 22, 26, 27, 29, 34, 37, 40,
44         22, 26, 27, 29, 32, 35, 40, 48,
45         26, 27, 29, 32, 35, 40, 48, 58,
46         26, 27, 29, 34, 38, 46, 56, 69,
47         27, 29, 35, 38, 46, 56, 69, 83
48 };
49 const float S = 4.0 * 0.5;  // whatever?
50
51 // NOTE: Contains factors to counteract the scaling in the DCT implementation.
52 const float quant_matrix[64] = {
53         sf[0] * sf[0] / 64.0,         sf[1] * sf[0] / (W[ 1] * S),  sf[2] * sf[0] / (W[ 2] * S),  sf[3] * sf[0] / (W[ 3] * S),  sf[4] * sf[0] / (W[ 4] * S),  sf[5] * sf[0] / (W[ 5] * S),  sf[6] * sf[0] / (W[ 6] * S),  sf[7] * sf[0] / (W[ 7] * S),
54         sf[0] * sf[1] / (W[ 8] * S),  sf[1] * sf[1] / (W[ 9] * S),  sf[2] * sf[1] / (W[10] * S),  sf[3] * sf[1] / (W[11] * S),  sf[4] * sf[1] / (W[12] * S),  sf[5] * sf[1] / (W[13] * S),  sf[6] * sf[1] / (W[14] * S),  sf[7] * sf[1] / (W[15] * S),
55         sf[0] * sf[2] / (W[16] * S),  sf[1] * sf[2] / (W[17] * S),  sf[2] * sf[2] / (W[18] * S),  sf[3] * sf[2] / (W[19] * S),  sf[4] * sf[2] / (W[20] * S),  sf[5] * sf[2] / (W[21] * S),  sf[6] * sf[2] / (W[22] * S),  sf[7] * sf[2] / (W[23] * S),
56         sf[0] * sf[3] / (W[24] * S),  sf[1] * sf[3] / (W[25] * S),  sf[2] * sf[3] / (W[26] * S),  sf[3] * sf[3] / (W[27] * S),  sf[4] * sf[3] / (W[28] * S),  sf[5] * sf[3] / (W[29] * S),  sf[6] * sf[3] / (W[30] * S),  sf[7] * sf[3] / (W[31] * S),
57         sf[0] * sf[4] / (W[32] * S),  sf[1] * sf[4] / (W[33] * S),  sf[2] * sf[4] / (W[34] * S),  sf[3] * sf[4] / (W[35] * S),  sf[4] * sf[4] / (W[36] * S),  sf[5] * sf[4] / (W[37] * S),  sf[6] * sf[4] / (W[38] * S),  sf[7] * sf[4] / (W[39] * S),
58         sf[0] * sf[5] / (W[40] * S),  sf[1] * sf[5] / (W[41] * S),  sf[2] * sf[5] / (W[42] * S),  sf[3] * sf[5] / (W[43] * S),  sf[4] * sf[5] / (W[44] * S),  sf[5] * sf[5] / (W[45] * S),  sf[6] * sf[5] / (W[46] * S),  sf[7] * sf[5] / (W[47] * S),
59         sf[0] * sf[6] / (W[48] * S),  sf[1] * sf[6] / (W[49] * S),  sf[2] * sf[6] / (W[50] * S),  sf[3] * sf[6] / (W[51] * S),  sf[4] * sf[6] / (W[52] * S),  sf[5] * sf[6] / (W[53] * S),  sf[6] * sf[6] / (W[54] * S),  sf[7] * sf[6] / (W[55] * S),
60         sf[0] * sf[7] / (W[56] * S),  sf[1] * sf[7] / (W[57] * S),  sf[2] * sf[7] / (W[58] * S),  sf[3] * sf[7] / (W[59] * S),  sf[4] * sf[7] / (W[60] * S),  sf[5] * sf[7] / (W[61] * S),  sf[6] * sf[7] / (W[62] * S),  sf[7] * sf[7] / (W[63] * S)
61 };
62
63 // Clamp and pack a 9-bit and a 7-bit signed value into a 16-bit word.
64 uint pack_9_7(int v9, int v7)
65 {
66         return (uint(clamp(v9, -256, 255)) & 0x1ffu) | ((uint(clamp(v7, -64, 63)) & 0x7fu) << 9);
67 }
68
69 // Scaled 1D DCT (AA&N). y0 is correctly scaled, all other y_k are scaled by sqrt(2) cos(k * Pi / 16).
70 void dct_1d(inout float y0, inout float y1, inout float y2, inout float y3, inout float y4, inout float y5, inout float y6, inout float y7)
71 {
72         const float a1 = 0.7071067811865474;   // sqrt(2)
73         const float a2 = 0.5411961001461971;   // cos(3/8 pi) * sqrt(2)
74         const float a4 = 1.3065629648763766;   // cos(pi/8) * sqrt(2)
75         // static const float a5 = 0.5 * (a4 - a2);
76         const float a5 = 0.3826834323650897;
77
78         // phase 1
79         const float p1_0 = y0 + y7;
80         const float p1_1 = y1 + y6;
81         const float p1_2 = y2 + y5;
82         const float p1_3 = y3 + y4;
83         const float p1_4 = y3 - y4;
84         const float p1_5 = y2 - y5;
85         const float p1_6 = y1 - y6;
86         const float p1_7 = y0 - y7;
87
88         // phase 2
89         const float p2_0 = p1_0 + p1_3;
90         const float p2_1 = p1_1 + p1_2;
91         const float p2_2 = p1_1 - p1_2;
92         const float p2_3 = p1_0 - p1_3;
93         const float p2_4 = p1_4 + p1_5;  // Inverted.
94         const float p2_5 = p1_5 + p1_6;
95         const float p2_6 = p1_6 + p1_7;
96
97         // phase 3
98         const float p3_0 = p2_0 + p2_1;
99         const float p3_1 = p2_0 - p2_1;
100         const float p3_2 = p2_2 + p2_3;
101         
102         // phase 4
103         const float p4_2 = p3_2 * a1;
104         const float p4_4 = p2_4 * a2 + (p2_4 - p2_6) * a5;
105         const float p4_5 = p2_5 * a1;
106         const float p4_6 = p2_6 * a4 + (p2_4 - p2_6) * a5;
107
108         // phase 5
109         const float p5_2 = p2_3 + p4_2;
110         const float p5_3 = p2_3 - p4_2;
111         const float p5_5 = p1_7 + p4_5;
112         const float p5_7 = p1_7 - p4_5;
113         
114         // phase 6
115         y0 = p3_0;
116         y4 = p3_1;
117         y2 = p5_2;
118         y6 = p5_3;
119         y5 = p4_4 + p5_7;
120         y1 = p5_5 + p4_6;
121         y7 = p5_5 - p4_6;
122         y3 = p5_7 - p4_4;
123 }
124
125 void main()
126 {
127         uint x = 8 * gl_WorkGroupID.x;
128         uint y = 8 * gl_WorkGroupID.y;
129         uint n = gl_LocalInvocationID.x;
130
131         // Load column.
132         float y0 = imageLoad(image_tex, ivec2(x + n, y + 0)).x;
133         float y1 = imageLoad(image_tex, ivec2(x + n, y + 1)).x;
134         float y2 = imageLoad(image_tex, ivec2(x + n, y + 2)).x;
135         float y3 = imageLoad(image_tex, ivec2(x + n, y + 3)).x;
136         float y4 = imageLoad(image_tex, ivec2(x + n, y + 4)).x;
137         float y5 = imageLoad(image_tex, ivec2(x + n, y + 5)).x;
138         float y6 = imageLoad(image_tex, ivec2(x + n, y + 6)).x;
139         float y7 = imageLoad(image_tex, ivec2(x + n, y + 7)).x;
140
141         // Vertical DCT.
142         dct_1d(y0, y1, y2, y3, y4, y5, y6, y7);
143
144         // Communicate with the other shaders in the group.
145         temp[n + 0 * 8] = y0;
146         temp[n + 1 * 8] = y1;
147         temp[n + 2 * 8] = y2;
148         temp[n + 3 * 8] = y3;
149         temp[n + 4 * 8] = y4;
150         temp[n + 5 * 8] = y5;
151         temp[n + 6 * 8] = y6;
152         temp[n + 7 * 8] = y7;
153
154         memoryBarrierShared();
155         barrier();
156
157         // Load row (so transpose, in a sense).
158         y0 = temp[n * 8 + 0];
159         y1 = temp[n * 8 + 1];
160         y2 = temp[n * 8 + 2];
161         y3 = temp[n * 8 + 3];
162         y4 = temp[n * 8 + 4];
163         y5 = temp[n * 8 + 5];
164         y6 = temp[n * 8 + 6];
165         y7 = temp[n * 8 + 7];
166
167         // Horizontal DCT.
168         dct_1d(y0, y1, y2, y3, y4, y5, y6, y7);
169
170         // Quantize.
171         int c0 = int(round(y0 * quant_matrix[n * 8 + 0]));
172         int c1 = int(round(y1 * quant_matrix[n * 8 + 1]));
173         int c2 = int(round(y2 * quant_matrix[n * 8 + 2]));
174         int c3 = int(round(y3 * quant_matrix[n * 8 + 3]));
175         int c4 = int(round(y4 * quant_matrix[n * 8 + 4]));
176         int c5 = int(round(y5 * quant_matrix[n * 8 + 5]));
177         int c6 = int(round(y6 * quant_matrix[n * 8 + 6]));
178         int c7 = int(round(y7 * quant_matrix[n * 8 + 7]));
179
180         // Clamp, pack and store.
181         uint sx = gl_WorkGroupID.x;
182         imageStore(dc_ac7_tex,  ivec2(sx, y + n), uvec4(pack_9_7(c0, c7), 0, 0, 0));
183         imageStore(ac1_ac6_tex, ivec2(sx, y + n), uvec4(pack_9_7(c1, c6), 0, 0, 0));
184         imageStore(ac2_ac5_tex, ivec2(sx, y + n), uvec4(pack_9_7(c2, c5), 0, 0, 0));
185         imageStore(ac3_tex,     ivec2(sx, y + n), ivec4(c3, 0, 0, 0));
186         imageStore(ac4_tex,     ivec2(sx, y + n), ivec4(c4, 0, 0, 0));
187
188         // Count frequencies, but only for every 8th block or so, randomly selected.
189         uint wg_index = gl_WorkGroupID.y * gl_WorkGroupSize.x + gl_WorkGroupID.x;
190         if ((wg_index * 0x9E3779B9u) >> 29 == 0) {  // Fibonacci hashing, essentially a PRNG in this context.
191                 c0 = min(abs(c0), 255);
192                 c1 = min(abs(c1), 255);
193                 c2 = min(abs(c2), 255);
194                 c3 = min(abs(c3), 255);
195                 c4 = min(abs(c4), 255);
196                 c5 = min(abs(c5), 255);
197                 c6 = min(abs(c6), 255);
198                 c7 = min(abs(c7), 255);
199
200                 // Spread out the most popular elements among the cache lines by reversing the bits
201                 // of the index, reducing false sharing.
202                 c0 = bitfieldReverse(c0) >> 24;
203                 c1 = bitfieldReverse(c1) >> 24;
204                 c2 = bitfieldReverse(c2) >> 24;
205                 c3 = bitfieldReverse(c3) >> 24;
206                 c4 = bitfieldReverse(c4) >> 24;
207                 c5 = bitfieldReverse(c5) >> 24;
208                 c6 = bitfieldReverse(c6) >> 24;
209                 c7 = bitfieldReverse(c7) >> 24;
210
211                 uint m = luma_mapping[n];
212                 atomicAdd(dist[bitfieldExtract(m,  0, 2)][c0], 1);
213                 atomicAdd(dist[bitfieldExtract(m,  2, 2)][c1], 1);
214                 atomicAdd(dist[bitfieldExtract(m,  4, 2)][c2], 1);
215                 atomicAdd(dist[bitfieldExtract(m,  6, 2)][c3], 1);
216                 atomicAdd(dist[bitfieldExtract(m,  8, 2)][c4], 1);
217                 atomicAdd(dist[bitfieldExtract(m, 10, 2)][c5], 1);
218                 atomicAdd(dist[bitfieldExtract(m, 12, 2)][c6], 1);
219                 atomicAdd(dist[bitfieldExtract(m, 14, 2)][c7], 1);
220         }
221 }
222