]> git.sesse.net Git - narabu/blob - encoder.shader
Make quant_matrix a bit more compact.
[narabu] / encoder.shader
1 #version 440
2 #extension GL_ARB_shader_clock : enable
3
4 layout(local_size_x = 8) in;
5
6 layout(r16ui) uniform restrict writeonly uimage2D dc_ac7_tex;
7 layout(r16ui) uniform restrict writeonly uimage2D ac1_ac6_tex;
8 layout(r16ui) uniform restrict writeonly uimage2D ac2_ac5_tex;
9 layout(r8i) uniform restrict writeonly iimage2D ac3_tex;
10 layout(r8i) uniform restrict writeonly iimage2D ac4_tex;
11 layout(r8ui) uniform restrict readonly uimage2D image_tex;
12
13 shared float temp[64];
14
15 layout(std430, binding = 9) buffer layoutName
16 {
17         uint dist[4][256];
18 };
19
20 #define MAPPING(s0, s1, s2, s3, s4, s5, s6, s7) ((s0) | (s1 << 2) | (s2 << 4) | (s3 << 6) | (s4 << 8) | (s5 << 10) | (s6 << 12) | (s7 << 14))
21
22 const uint luma_mapping[8] = {
23         MAPPING(0, 0, 1, 1, 2, 2, 3, 3),
24         MAPPING(0, 0, 1, 2, 2, 2, 3, 3),
25         MAPPING(1, 1, 2, 2, 2, 3, 3, 3),
26         MAPPING(1, 1, 2, 2, 2, 3, 3, 3),
27         MAPPING(1, 2, 2, 2, 2, 3, 3, 3),
28         MAPPING(2, 2, 2, 2, 3, 3, 3, 3),
29         MAPPING(2, 2, 3, 3, 3, 3, 3, 3),
30         MAPPING(3, 3, 3, 3, 3, 3, 3, 3),
31 };
32
33 // Scale factors; 1.0 / (sqrt(2.0) * cos(k * M_PI / 16.0)), except for the first which is 1.
34 const float sf[8] = {
35         1.0, 0.7209598220069479, 0.765366864730180, 0.8504300947672564,
36         1.0, 1.2727585805728336, 1.847759065022573, 3.6245097854115502
37 };
38
39 const float W[64] = {
40          8, 16, 19, 22, 26, 27, 29, 34,
41         16, 16, 22, 24, 27, 29, 34, 37,
42         19, 22, 26, 27, 29, 34, 34, 38,
43         22, 22, 26, 27, 29, 34, 37, 40,
44         22, 26, 27, 29, 32, 35, 40, 48,
45         26, 27, 29, 32, 35, 40, 48, 58,
46         26, 27, 29, 34, 38, 46, 56, 69,
47         27, 29, 35, 38, 46, 56, 69, 83
48 };
49 const float S = 4.0 * 0.5;  // whatever?
50
51 // NOTE: Contains factors to counteract the scaling in the DCT implementation.
52 #define QM(x, y) (sf[x] * sf[y] / (W[y*8 + x] * S))
53 const float quant_matrix[64] = {
54         1.0 / 64.0, QM(1, 0), QM(2, 0), QM(3, 0), QM(4, 0), QM(5, 0), QM(6, 0), QM(7, 0),
55         QM(0, 1),   QM(1, 1), QM(2, 1), QM(3, 1), QM(4, 1), QM(5, 1), QM(6, 1), QM(7, 1),
56         QM(0, 2),   QM(1, 2), QM(2, 2), QM(3, 2), QM(4, 2), QM(5, 2), QM(6, 2), QM(7, 2),
57         QM(0, 3),   QM(1, 3), QM(2, 3), QM(3, 3), QM(4, 3), QM(5, 3), QM(6, 3), QM(7, 3),
58         QM(0, 4),   QM(1, 4), QM(2, 4), QM(3, 4), QM(4, 4), QM(5, 4), QM(6, 4), QM(7, 4),
59         QM(0, 5),   QM(1, 5), QM(2, 5), QM(3, 5), QM(4, 5), QM(5, 5), QM(6, 5), QM(7, 5),
60         QM(0, 6),   QM(1, 6), QM(2, 6), QM(3, 6), QM(4, 6), QM(5, 6), QM(6, 6), QM(7, 6),
61         QM(0, 7),   QM(1, 7), QM(2, 7), QM(3, 7), QM(4, 7), QM(5, 7), QM(6, 7), QM(7, 7)
62 };
63
64 // Clamp and pack a 9-bit and a 7-bit signed value into a 16-bit word.
65 uint pack_9_7(int v9, int v7)
66 {
67         return (uint(clamp(v9, -256, 255)) & 0x1ffu) | ((uint(clamp(v7, -64, 63)) & 0x7fu) << 9);
68 }
69
70 // Scaled 1D DCT (AA&N). y0 is correctly scaled, all other y_k are scaled by sqrt(2) cos(k * Pi / 16).
71 void dct_1d(inout float y0, inout float y1, inout float y2, inout float y3, inout float y4, inout float y5, inout float y6, inout float y7)
72 {
73         const float a1 = 0.7071067811865474;   // sqrt(2)
74         const float a2 = 0.5411961001461971;   // cos(3/8 pi) * sqrt(2)
75         const float a4 = 1.3065629648763766;   // cos(pi/8) * sqrt(2)
76         // static const float a5 = 0.5 * (a4 - a2);
77         const float a5 = 0.3826834323650897;
78
79         // phase 1
80         const float p1_0 = y0 + y7;
81         const float p1_1 = y1 + y6;
82         const float p1_2 = y2 + y5;
83         const float p1_3 = y3 + y4;
84         const float p1_4 = y3 - y4;
85         const float p1_5 = y2 - y5;
86         const float p1_6 = y1 - y6;
87         const float p1_7 = y0 - y7;
88
89         // phase 2
90         const float p2_0 = p1_0 + p1_3;
91         const float p2_1 = p1_1 + p1_2;
92         const float p2_2 = p1_1 - p1_2;
93         const float p2_3 = p1_0 - p1_3;
94         const float p2_4 = p1_4 + p1_5;  // Inverted.
95         const float p2_5 = p1_5 + p1_6;
96         const float p2_6 = p1_6 + p1_7;
97
98         // phase 3
99         const float p3_0 = p2_0 + p2_1;
100         const float p3_1 = p2_0 - p2_1;
101         const float p3_2 = p2_2 + p2_3;
102         
103         // phase 4
104         const float p4_2 = p3_2 * a1;
105         const float p4_4 = p2_4 * a2 + (p2_4 - p2_6) * a5;
106         const float p4_5 = p2_5 * a1;
107         const float p4_6 = p2_6 * a4 + (p2_4 - p2_6) * a5;
108
109         // phase 5
110         const float p5_2 = p2_3 + p4_2;
111         const float p5_3 = p2_3 - p4_2;
112         const float p5_5 = p1_7 + p4_5;
113         const float p5_7 = p1_7 - p4_5;
114         
115         // phase 6
116         y0 = p3_0;
117         y4 = p3_1;
118         y2 = p5_2;
119         y6 = p5_3;
120         y5 = p4_4 + p5_7;
121         y1 = p5_5 + p4_6;
122         y7 = p5_5 - p4_6;
123         y3 = p5_7 - p4_4;
124 }
125
126 void main()
127 {
128         uint x = 8 * gl_WorkGroupID.x;
129         uint y = 8 * gl_WorkGroupID.y;
130         uint n = gl_LocalInvocationID.x;
131
132         // Load column.
133         float y0 = imageLoad(image_tex, ivec2(x + n, y + 0)).x;
134         float y1 = imageLoad(image_tex, ivec2(x + n, y + 1)).x;
135         float y2 = imageLoad(image_tex, ivec2(x + n, y + 2)).x;
136         float y3 = imageLoad(image_tex, ivec2(x + n, y + 3)).x;
137         float y4 = imageLoad(image_tex, ivec2(x + n, y + 4)).x;
138         float y5 = imageLoad(image_tex, ivec2(x + n, y + 5)).x;
139         float y6 = imageLoad(image_tex, ivec2(x + n, y + 6)).x;
140         float y7 = imageLoad(image_tex, ivec2(x + n, y + 7)).x;
141
142         // Vertical DCT.
143         dct_1d(y0, y1, y2, y3, y4, y5, y6, y7);
144
145         // Communicate with the other shaders in the group.
146         temp[n + 0 * 8] = y0;
147         temp[n + 1 * 8] = y1;
148         temp[n + 2 * 8] = y2;
149         temp[n + 3 * 8] = y3;
150         temp[n + 4 * 8] = y4;
151         temp[n + 5 * 8] = y5;
152         temp[n + 6 * 8] = y6;
153         temp[n + 7 * 8] = y7;
154
155         memoryBarrierShared();
156         barrier();
157
158         // Load row (so transpose, in a sense).
159         y0 = temp[n * 8 + 0];
160         y1 = temp[n * 8 + 1];
161         y2 = temp[n * 8 + 2];
162         y3 = temp[n * 8 + 3];
163         y4 = temp[n * 8 + 4];
164         y5 = temp[n * 8 + 5];
165         y6 = temp[n * 8 + 6];
166         y7 = temp[n * 8 + 7];
167
168         // Horizontal DCT.
169         dct_1d(y0, y1, y2, y3, y4, y5, y6, y7);
170
171         // Quantize.
172         int c0 = int(round(y0 * quant_matrix[n * 8 + 0]));
173         int c1 = int(round(y1 * quant_matrix[n * 8 + 1]));
174         int c2 = int(round(y2 * quant_matrix[n * 8 + 2]));
175         int c3 = int(round(y3 * quant_matrix[n * 8 + 3]));
176         int c4 = int(round(y4 * quant_matrix[n * 8 + 4]));
177         int c5 = int(round(y5 * quant_matrix[n * 8 + 5]));
178         int c6 = int(round(y6 * quant_matrix[n * 8 + 6]));
179         int c7 = int(round(y7 * quant_matrix[n * 8 + 7]));
180
181         // Clamp, pack and store.
182         uint sx = gl_WorkGroupID.x;
183         imageStore(dc_ac7_tex,  ivec2(sx, y + n), uvec4(pack_9_7(c0, c7), 0, 0, 0));
184         imageStore(ac1_ac6_tex, ivec2(sx, y + n), uvec4(pack_9_7(c1, c6), 0, 0, 0));
185         imageStore(ac2_ac5_tex, ivec2(sx, y + n), uvec4(pack_9_7(c2, c5), 0, 0, 0));
186         imageStore(ac3_tex,     ivec2(sx, y + n), ivec4(c3, 0, 0, 0));
187         imageStore(ac4_tex,     ivec2(sx, y + n), ivec4(c4, 0, 0, 0));
188
189         // Count frequencies, but only for every 8th block or so, randomly selected.
190         uint wg_index = gl_WorkGroupID.y * gl_WorkGroupSize.x + gl_WorkGroupID.x;
191         if ((wg_index * 0x9E3779B9u) >> 29 == 0) {  // Fibonacci hashing, essentially a PRNG in this context.
192                 c0 = min(abs(c0), 255);
193                 c1 = min(abs(c1), 255);
194                 c2 = min(abs(c2), 255);
195                 c3 = min(abs(c3), 255);
196                 c4 = min(abs(c4), 255);
197                 c5 = min(abs(c5), 255);
198                 c6 = min(abs(c6), 255);
199                 c7 = min(abs(c7), 255);
200
201                 // Spread out the most popular elements among the cache lines by reversing the bits
202                 // of the index, reducing false sharing.
203                 c0 = bitfieldReverse(c0) >> 24;
204                 c1 = bitfieldReverse(c1) >> 24;
205                 c2 = bitfieldReverse(c2) >> 24;
206                 c3 = bitfieldReverse(c3) >> 24;
207                 c4 = bitfieldReverse(c4) >> 24;
208                 c5 = bitfieldReverse(c5) >> 24;
209                 c6 = bitfieldReverse(c6) >> 24;
210                 c7 = bitfieldReverse(c7) >> 24;
211
212                 uint m = luma_mapping[n];
213                 atomicAdd(dist[bitfieldExtract(m,  0, 2)][c0], 1);
214                 atomicAdd(dist[bitfieldExtract(m,  2, 2)][c1], 1);
215                 atomicAdd(dist[bitfieldExtract(m,  4, 2)][c2], 1);
216                 atomicAdd(dist[bitfieldExtract(m,  6, 2)][c3], 1);
217                 atomicAdd(dist[bitfieldExtract(m,  8, 2)][c4], 1);
218                 atomicAdd(dist[bitfieldExtract(m, 10, 2)][c5], 1);
219                 atomicAdd(dist[bitfieldExtract(m, 12, 2)][c6], 1);
220                 atomicAdd(dist[bitfieldExtract(m, 14, 2)][c7], 1);
221         }
222 }
223