]> git.sesse.net Git - narabu/blob - encoder.shader
More fixes of hard-coded values.
[narabu] / encoder.shader
1 #version 440
2 #extension GL_ARB_shader_clock : enable
3
4 // Do sixteen 8x8 blocks in a local group, because that matches up perfectly
5 // with needing 1024 coefficients for our four histograms (of 256 bins each).
6 #define NUM_Z 16
7
8 layout(local_size_x = 8, local_size_z = NUM_Z) in;
9
10 layout(r16ui) uniform restrict writeonly uimage2D dc_ac7_tex;
11 layout(r16ui) uniform restrict writeonly uimage2D ac1_ac6_tex;
12 layout(r16ui) uniform restrict writeonly uimage2D ac2_ac5_tex;
13 layout(r8i) uniform restrict writeonly iimage2D ac3_tex;
14 layout(r8i) uniform restrict writeonly iimage2D ac4_tex;
15 layout(r8ui) uniform restrict readonly uimage2D image_tex;
16
17 shared uint temp[64 * NUM_Z];
18
19 layout(std430, binding = 9) buffer layoutName
20 {
21         uint dist[4 * 256];
22 };
23
24 #define MAPPING(s0, s1, s2, s3, s4, s5, s6, s7) ((s0) | (s1 << 2) | (s2 << 4) | (s3 << 6) | (s4 << 8) | (s5 << 10) | (s6 << 12) | (s7 << 14))
25
26 const uint luma_mapping[8] = {
27         MAPPING(0, 0, 1, 1, 2, 2, 3, 3),
28         MAPPING(0, 0, 1, 2, 2, 2, 3, 3),
29         MAPPING(1, 1, 2, 2, 2, 3, 3, 3),
30         MAPPING(1, 1, 2, 2, 2, 3, 3, 3),
31         MAPPING(1, 2, 2, 2, 2, 3, 3, 3),
32         MAPPING(2, 2, 2, 2, 3, 3, 3, 3),
33         MAPPING(2, 2, 3, 3, 3, 3, 3, 3),
34         MAPPING(3, 3, 3, 3, 3, 3, 3, 3),
35 };
36
37 // Scale factors; 1.0 / (sqrt(2.0) * cos(k * M_PI / 16.0)), except for the first which is 1.
38 const float sf[8] = {
39         1.0, 0.7209598220069479, 0.765366864730180, 0.8504300947672564,
40         1.0, 1.2727585805728336, 1.847759065022573, 3.6245097854115502
41 };
42
43 const float W[64] = {
44          8, 16, 19, 22, 26, 27, 29, 34,
45         16, 16, 22, 24, 27, 29, 34, 37,
46         19, 22, 26, 27, 29, 34, 34, 38,
47         22, 22, 26, 27, 29, 34, 37, 40,
48         22, 26, 27, 29, 32, 35, 40, 48,
49         26, 27, 29, 32, 35, 40, 48, 58,
50         26, 27, 29, 34, 38, 46, 56, 69,
51         27, 29, 35, 38, 46, 56, 69, 83
52 };
53 const float S = 4.0 * 0.5;  // whatever?
54
55 // NOTE: Contains factors to counteract the scaling in the DCT implementation.
56 #define QM(x, y) (sf[x] * sf[y] / (W[y*8 + x] * S))
57 const float quant_matrix[64] = {
58         1.0 / 64.0, QM(1, 0), QM(2, 0), QM(3, 0), QM(4, 0), QM(5, 0), QM(6, 0), QM(7, 0),
59         QM(0, 1),   QM(1, 1), QM(2, 1), QM(3, 1), QM(4, 1), QM(5, 1), QM(6, 1), QM(7, 1),
60         QM(0, 2),   QM(1, 2), QM(2, 2), QM(3, 2), QM(4, 2), QM(5, 2), QM(6, 2), QM(7, 2),
61         QM(0, 3),   QM(1, 3), QM(2, 3), QM(3, 3), QM(4, 3), QM(5, 3), QM(6, 3), QM(7, 3),
62         QM(0, 4),   QM(1, 4), QM(2, 4), QM(3, 4), QM(4, 4), QM(5, 4), QM(6, 4), QM(7, 4),
63         QM(0, 5),   QM(1, 5), QM(2, 5), QM(3, 5), QM(4, 5), QM(5, 5), QM(6, 5), QM(7, 5),
64         QM(0, 6),   QM(1, 6), QM(2, 6), QM(3, 6), QM(4, 6), QM(5, 6), QM(6, 6), QM(7, 6),
65         QM(0, 7),   QM(1, 7), QM(2, 7), QM(3, 7), QM(4, 7), QM(5, 7), QM(6, 7), QM(7, 7)
66 };
67
68 // Clamp and pack a 9-bit and a 7-bit signed value into a 16-bit word.
69 uint pack_9_7(int v9, int v7)
70 {
71         return (uint(clamp(v9, -256, 255)) & 0x1ffu) | ((uint(clamp(v7, -64, 63)) & 0x7fu) << 9);
72 }
73
74 // Scaled 1D DCT (AA&N). y0 is correctly scaled, all other y_k are scaled by sqrt(2) cos(k * Pi / 16).
75 void dct_1d(inout float y0, inout float y1, inout float y2, inout float y3, inout float y4, inout float y5, inout float y6, inout float y7)
76 {
77         const float a1 = 0.7071067811865474;   // sqrt(2)
78         const float a2 = 0.5411961001461971;   // cos(3/8 pi) * sqrt(2)
79         const float a4 = 1.3065629648763766;   // cos(pi/8) * sqrt(2)
80         // static const float a5 = 0.5 * (a4 - a2);
81         const float a5 = 0.3826834323650897;
82
83         // phase 1
84         const float p1_0 = y0 + y7;
85         const float p1_1 = y1 + y6;
86         const float p1_2 = y2 + y5;
87         const float p1_3 = y3 + y4;
88         const float p1_4 = y3 - y4;
89         const float p1_5 = y2 - y5;
90         const float p1_6 = y1 - y6;
91         const float p1_7 = y0 - y7;
92
93         // phase 2
94         const float p2_0 = p1_0 + p1_3;
95         const float p2_1 = p1_1 + p1_2;
96         const float p2_2 = p1_1 - p1_2;
97         const float p2_3 = p1_0 - p1_3;
98         const float p2_4 = p1_4 + p1_5;  // Inverted.
99         const float p2_5 = p1_5 + p1_6;
100         const float p2_6 = p1_6 + p1_7;
101
102         // phase 3
103         const float p3_0 = p2_0 + p2_1;
104         const float p3_1 = p2_0 - p2_1;
105         const float p3_2 = p2_2 + p2_3;
106         
107         // phase 4
108         const float p4_2 = p3_2 * a1;
109         const float p4_4 = p2_4 * a2 + (p2_4 - p2_6) * a5;
110         const float p4_5 = p2_5 * a1;
111         const float p4_6 = p2_6 * a4 + (p2_4 - p2_6) * a5;
112
113         // phase 5
114         const float p5_2 = p2_3 + p4_2;
115         const float p5_3 = p2_3 - p4_2;
116         const float p5_5 = p1_7 + p4_5;
117         const float p5_7 = p1_7 - p4_5;
118         
119         // phase 6
120         y0 = p3_0;
121         y4 = p3_1;
122         y2 = p5_2;
123         y6 = p5_3;
124         y5 = p4_4 + p5_7;
125         y1 = p5_5 + p4_6;
126         y7 = p5_5 - p4_6;
127         y3 = p5_7 - p4_4;
128 }
129
130 void main()
131 {
132         uint sx = gl_WorkGroupID.x * NUM_Z + gl_LocalInvocationID.z;
133         uint x = 8 * sx;
134         uint y = 8 * gl_WorkGroupID.y;
135         uint n = gl_LocalInvocationID.x;
136         uint z = gl_LocalInvocationID.z;
137
138         // Load column.
139         float y0 = imageLoad(image_tex, ivec2(x + n, y + 0)).x;
140         float y1 = imageLoad(image_tex, ivec2(x + n, y + 1)).x;
141         float y2 = imageLoad(image_tex, ivec2(x + n, y + 2)).x;
142         float y3 = imageLoad(image_tex, ivec2(x + n, y + 3)).x;
143         float y4 = imageLoad(image_tex, ivec2(x + n, y + 4)).x;
144         float y5 = imageLoad(image_tex, ivec2(x + n, y + 5)).x;
145         float y6 = imageLoad(image_tex, ivec2(x + n, y + 6)).x;
146         float y7 = imageLoad(image_tex, ivec2(x + n, y + 7)).x;
147
148         // Vertical DCT.
149         dct_1d(y0, y1, y2, y3, y4, y5, y6, y7);
150
151         // Communicate with the other shaders in the group.
152         uint base_idx = 64 * z;
153         temp[base_idx + 0 * 8 + n] = floatBitsToUint(y0);
154         temp[base_idx + 1 * 8 + n] = floatBitsToUint(y1);
155         temp[base_idx + 2 * 8 + n] = floatBitsToUint(y2);
156         temp[base_idx + 3 * 8 + n] = floatBitsToUint(y3);
157         temp[base_idx + 4 * 8 + n] = floatBitsToUint(y4);
158         temp[base_idx + 5 * 8 + n] = floatBitsToUint(y5);
159         temp[base_idx + 6 * 8 + n] = floatBitsToUint(y6);
160         temp[base_idx + 7 * 8 + n] = floatBitsToUint(y7);
161
162         memoryBarrierShared();
163         barrier();
164
165         // Load row (so transpose, in a sense).
166         y0 = uintBitsToFloat(temp[base_idx + n * 8 + 0]);
167         y1 = uintBitsToFloat(temp[base_idx + n * 8 + 1]);
168         y2 = uintBitsToFloat(temp[base_idx + n * 8 + 2]);
169         y3 = uintBitsToFloat(temp[base_idx + n * 8 + 3]);
170         y4 = uintBitsToFloat(temp[base_idx + n * 8 + 4]);
171         y5 = uintBitsToFloat(temp[base_idx + n * 8 + 5]);
172         y6 = uintBitsToFloat(temp[base_idx + n * 8 + 6]);
173         y7 = uintBitsToFloat(temp[base_idx + n * 8 + 7]);
174
175         // Horizontal DCT.
176         dct_1d(y0, y1, y2, y3, y4, y5, y6, y7);
177
178         // Quantize.
179         int c0 = int(round(y0 * quant_matrix[n * 8 + 0]));
180         int c1 = int(round(y1 * quant_matrix[n * 8 + 1]));
181         int c2 = int(round(y2 * quant_matrix[n * 8 + 2]));
182         int c3 = int(round(y3 * quant_matrix[n * 8 + 3]));
183         int c4 = int(round(y4 * quant_matrix[n * 8 + 4]));
184         int c5 = int(round(y5 * quant_matrix[n * 8 + 5]));
185         int c6 = int(round(y6 * quant_matrix[n * 8 + 6]));
186         int c7 = int(round(y7 * quant_matrix[n * 8 + 7]));
187
188         // Clamp, pack and store.
189         imageStore(dc_ac7_tex,  ivec2(sx, y + n), uvec4(pack_9_7(c0, c7), 0, 0, 0));
190         imageStore(ac1_ac6_tex, ivec2(sx, y + n), uvec4(pack_9_7(c1, c6), 0, 0, 0));
191         imageStore(ac2_ac5_tex, ivec2(sx, y + n), uvec4(pack_9_7(c2, c5), 0, 0, 0));
192         imageStore(ac3_tex,     ivec2(sx, y + n), ivec4(c3, 0, 0, 0));
193         imageStore(ac4_tex,     ivec2(sx, y + n), ivec4(c4, 0, 0, 0));
194
195         // Zero out the temporary area in preparation for counting up the histograms.
196         base_idx += 8 * n;
197         temp[base_idx + 0] = 0;
198         temp[base_idx + 1] = 0;
199         temp[base_idx + 2] = 0;
200         temp[base_idx + 3] = 0;
201         temp[base_idx + 4] = 0;
202         temp[base_idx + 5] = 0;
203         temp[base_idx + 6] = 0;
204         temp[base_idx + 7] = 0;
205
206         memoryBarrierShared();
207         barrier();
208
209         // Count frequencies into four histograms. We do this to local memory first,
210         // because this is _much_ faster; then we do global atomic adds for the nonzero
211         // members.
212
213         // First take the absolute value (signs are encoded differently) and clamp,
214         // as any value over 255 is going to be encoded as an escape.
215         c0 = min(abs(c0), 255);
216         c1 = min(abs(c1), 255);
217         c2 = min(abs(c2), 255);
218         c3 = min(abs(c3), 255);
219         c4 = min(abs(c4), 255);
220         c5 = min(abs(c5), 255);
221         c6 = min(abs(c6), 255);
222         c7 = min(abs(c7), 255);
223
224         // Add up in local memory.
225         uint m = luma_mapping[n];
226         atomicAdd(temp[bitfieldExtract(m,  0, 2) * 256 + c0], 1);
227         atomicAdd(temp[bitfieldExtract(m,  2, 2) * 256 + c1], 1);
228         atomicAdd(temp[bitfieldExtract(m,  4, 2) * 256 + c2], 1);
229         atomicAdd(temp[bitfieldExtract(m,  6, 2) * 256 + c3], 1);
230         atomicAdd(temp[bitfieldExtract(m,  8, 2) * 256 + c4], 1);
231         atomicAdd(temp[bitfieldExtract(m, 10, 2) * 256 + c5], 1);
232         atomicAdd(temp[bitfieldExtract(m, 12, 2) * 256 + c6], 1);
233         atomicAdd(temp[bitfieldExtract(m, 14, 2) * 256 + c7], 1);
234
235         memoryBarrierShared();
236         barrier();
237
238         // Add from local memory to global memory.
239         if (temp[base_idx + 0] != 0) atomicAdd(dist[base_idx + 0], temp[base_idx + 0]);
240         if (temp[base_idx + 1] != 0) atomicAdd(dist[base_idx + 1], temp[base_idx + 1]);
241         if (temp[base_idx + 2] != 0) atomicAdd(dist[base_idx + 2], temp[base_idx + 2]);
242         if (temp[base_idx + 3] != 0) atomicAdd(dist[base_idx + 3], temp[base_idx + 3]);
243         if (temp[base_idx + 4] != 0) atomicAdd(dist[base_idx + 4], temp[base_idx + 4]);
244         if (temp[base_idx + 5] != 0) atomicAdd(dist[base_idx + 5], temp[base_idx + 5]);
245         if (temp[base_idx + 6] != 0) atomicAdd(dist[base_idx + 6], temp[base_idx + 6]);
246         if (temp[base_idx + 7] != 0) atomicAdd(dist[base_idx + 7], temp[base_idx + 7]);
247 }