]> git.sesse.net Git - narabu/blob - decoder.shader
Switch to 64-bit rANS, although probably due for immediate revert (just want to prese...
[narabu] / decoder.shader
1 #version 440
2 #extension GL_ARB_shader_clock : enable
3
4 #define PARALLEL_SLICES 1
5
6 #define ENABLE_TIMING 0
7
8 layout(local_size_x = 64*PARALLEL_SLICES) in;
9 layout(r8ui) uniform restrict readonly uimage2D cum2sym_tex;
10 layout(rg16ui) uniform restrict readonly uimage2D dsyms_tex;
11 layout(r8) uniform restrict writeonly image2D out_tex;
12 layout(r32i) uniform restrict writeonly iimage2D coeff_tex;
13 layout(r32i) uniform restrict writeonly iimage2D coeff2_tex;
14 uniform int num_blocks;
15
16 const uint prob_bits = 12;
17 const uint prob_scale = 1 << prob_bits;
18 const uint NUM_SYMS = 256;
19 const uint ESCAPE_LIMIT = NUM_SYMS - 1;
20 const uint BLOCKS_PER_STREAM = 320;
21
22 // These need to be folded into quant_matrix.
23 const float dc_scalefac = 8.0;
24 const float quant_scalefac = 4.0;
25
26 const float quant_matrix[64] = {
27          8, 16, 19, 22, 26, 27, 29, 34,
28         16, 16, 22, 24, 27, 29, 34, 37,
29         19, 22, 26, 27, 29, 34, 34, 38,
30         22, 22, 26, 27, 29, 34, 37, 40,
31         22, 26, 27, 29, 32, 35, 40, 48,
32         26, 27, 29, 32, 35, 40, 48, 58,
33         26, 27, 29, 34, 38, 46, 56, 69,
34         27, 29, 35, 38, 46, 56, 69, 83
35 };
36 const uint ff_zigzag_direct[64] = {
37     0,   1,  8, 16,  9,  2,  3, 10,
38     17, 24, 32, 25, 18, 11,  4,  5,
39     12, 19, 26, 33, 40, 48, 41, 34,
40     27, 20, 13,  6,  7, 14, 21, 28,
41     35, 42, 49, 56, 57, 50, 43, 36,
42     29, 22, 15, 23, 30, 37, 44, 51,
43     58, 59, 52, 45, 38, 31, 39, 46,
44     53, 60, 61, 54, 47, 55, 62, 63
45 };
46 const uint stream_mapping[64] = {
47         0, 0, 1, 1, 2, 2, 3, 3,
48         0, 0, 1, 2, 2, 2, 3, 3,
49         1, 1, 2, 2, 2, 3, 3, 3,
50         1, 1, 2, 2, 2, 3, 3, 3,
51         1, 2, 2, 2, 2, 3, 3, 3,
52         2, 2, 2, 2, 3, 3, 3, 3,
53         2, 2, 3, 3, 3, 3, 3, 3,
54         3, 3, 3, 3, 3, 3, 3, 3,
55 };
56
57 layout(std430, binding = 9) buffer layoutName
58 {
59         uint data_SSBO[];
60 };
61 layout(std430, binding = 10) buffer layoutName2
62 {
63         uvec2 timing[10 * 64];
64 };
65
66 struct CoeffStream {
67         uint src_offset, src_len;
68 };
69 layout(std430, binding = 0) buffer whatever3
70 {
71         CoeffStream streams[];
72 };
73 uniform uint sign_bias_per_model[16];
74
75 struct myuint64 {
76         uint high, low;
77 };
78
79 const uint RANS64_L = (1u << 31);  // lower bound of our normalization interval
80
81 myuint64 RansDecInit(inout uint offset)
82 {
83         myuint64 x;
84         x.low  = data_SSBO[offset++];
85         x.high = data_SSBO[offset++];
86         return x;
87 }
88
89 uint RansDecGet(myuint64 r, uint scale_bits)
90 {
91         return r.low & ((1u << scale_bits) - 1);
92 }
93
94 void RansDecAdvance(inout myuint64 rans, inout uint offset, const uint start, const uint freq, uint prob_bits)
95 {
96         const uint mask = (1u << prob_bits) - 1;
97         const uint recovered_lowbits = (rans.low & mask) - start;
98
99         // rans >>= prob_bits;
100         rans.low = (rans.low >> prob_bits) | ((rans.high & mask) << (32 - prob_bits));
101         rans.high >>= prob_bits;
102
103         // rans *= freq;
104         uint h1, l1, h2, l2;
105         umulExtended(rans.low, freq, h1, l1);
106         umulExtended(rans.high, freq, h2, l2);
107         rans.low = l1;
108         rans.high = l2 + h1;
109
110         // rans += recovered_lowbits;
111         uint carry;
112         rans.low = uaddCarry(rans.low, recovered_lowbits, carry);
113         rans.high += carry;
114
115         // renormalize
116         if (rans.high == 0 && rans.low < RANS64_L) {
117                 rans.high = rans.low;
118                 rans.low = data_SSBO[offset++];
119         }
120 }
121
122 uint cum2sym(uint bits, uint table)
123 {
124         return imageLoad(cum2sym_tex, ivec2(bits, table)).x;
125 }
126
127 uvec2 get_dsym(uint k, uint table)
128 {
129         return imageLoad(dsyms_tex, ivec2(k, table)).xy;
130 }
131
132 void idct_1d(inout float y0, inout float y1, inout float y2, inout float y3, inout float y4, inout float y5, inout float y6, inout float y7)
133 {
134         const float a1 = 0.7071067811865474;   // sqrt(2)
135         const float a2 = 0.5411961001461971;   // cos(3/8 pi) * sqrt(2)
136         const float a4 = 1.3065629648763766;   // cos(pi/8) * sqrt(2)
137         // static const float a5 = 0.5 * (a4 - a2);
138         const float a5 = 0.3826834323650897;
139
140         // phase 2 (phase 1 is just moving around)
141         const float p2_4 = y5 - y3;
142         const float p2_5 = y1 + y7;
143         const float p2_6 = y1 - y7;
144         const float p2_7 = y5 + y3;
145
146         // phase 3
147         const float p3_2 = y2 - y6;
148         const float p3_3 = y2 + y6;
149         const float p3_5 = p2_5 - p2_7;
150         const float p3_7 = p2_5 + p2_7;
151
152         // phase 4
153         const float p4_2 = a1 * p3_2;
154         const float p4_4 = p2_4 * a2 + (p2_4 + p2_6) * a5;  // Inverted.
155         const float p4_5 = a1 * p3_5;
156         const float p4_6 = p2_6 * a4 - (p2_4 + p2_6) * a5;
157
158         // phase 5
159         const float p5_0 = y0 + y4;
160         const float p5_1 = y0 - y4;
161         const float p5_3 = p4_2 + p3_3;
162
163         // phase 6
164         const float p6_0 = p5_0 + p5_3;
165         const float p6_1 = p5_1 + p4_2;
166         const float p6_2 = p5_1 - p4_2;
167         const float p6_3 = p5_0 - p5_3;
168         const float p6_5 = p4_5 + p4_4;
169         const float p6_6 = p4_5 + p4_6;
170         const float p6_7 = p4_6 + p3_7;
171
172         // phase 7
173         y0 = p6_0 + p6_7;
174         y1 = p6_1 + p6_6;
175         y2 = p6_2 + p6_5;
176         y3 = p6_3 - p4_4;
177         y4 = p6_3 + p4_4;
178         y5 = p6_2 - p6_5;
179         y6 = p6_1 - p6_6;
180         y7 = p6_0 - p6_7;
181 }
182
183 shared float temp[64 * 8 * PARALLEL_SLICES];
184
185 void pick_timer(inout uvec2 start, inout uvec2 t)
186 {
187 #if ENABLE_TIMING
188         uvec2 now = clock2x32ARB();
189
190         uvec2 delta = now - start;
191         if (now.x < start.x) {
192                 --delta.y;
193         }
194
195         uvec2 new_t = t + delta;
196         if (new_t.x < t.x) {
197                 ++new_t.y;
198         }
199         t = new_t;
200
201         start = clock2x32ARB();
202 #endif
203 }
204
205 void main()
206 {
207         uvec2 local_timing[10];
208 #if ENABLE_TIMING
209         for (int timer_idx = 0; timer_idx < 10; ++timer_idx) {
210                 local_timing[timer_idx] = uvec2(0, 0);
211         }
212         uvec2 start = clock2x32ARB();
213 #else
214         uvec2 start = uvec2(0, 0);
215         local_timing[0] = start;
216 #endif
217
218         const uint blocks_per_row = (imageSize(out_tex).x + 7) / 8;
219
220         const uint local_x = gl_LocalInvocationID.x % 8;
221         const uint local_y = (gl_LocalInvocationID.x / 8) % 8;
222         const uint local_z = gl_LocalInvocationID.x / 64;
223
224         const uint slice_num = local_z;
225         const uint thread_num = local_y * 8 + local_x;
226
227         const uint block_row = gl_WorkGroupID.y * PARALLEL_SLICES + slice_num;
228         //const uint coeff_num = ff_zigzag_direct[thread_num];
229         const uint coeff_num = thread_num;
230         const uint stream_num = coeff_num * num_blocks + block_row;
231         const uint model_num = stream_mapping[coeff_num];
232         const uint sign_bias = sign_bias_per_model[model_num];
233
234         // Initialize rANS decoder.
235         uint offset = streams[stream_num].src_offset >> 2;
236         myuint64 rans = RansDecInit(offset);
237
238         float q = (coeff_num == 0) ? 1.0 : (quant_matrix[coeff_num] * quant_scalefac / 128.0 / sqrt(2.0));  // FIXME: fold
239         q *= (1.0 / 255.0);
240         //int w = (coeff_num == 0) ? 32 : int(quant_matrix[coeff_num]);
241         int last_k = 128;
242
243         pick_timer(start, local_timing[0]);
244
245         for (uint block_idx = BLOCKS_PER_STREAM / 8; block_idx --> 0; ) {
246                 pick_timer(start, local_timing[1]);
247
248                 // rANS decode one coefficient across eight blocks (so 64x8 coefficients).
249                 for (uint subblock_idx = 8; subblock_idx --> 0; ) {
250                         // Read a symbol.
251                         uint bottom_bits = RansDecGet(rans, prob_bits + 1);
252                         bool sign = false;
253                         if (bottom_bits >= sign_bias) {
254                                 bottom_bits -= sign_bias;
255                                 rans.low -= sign_bias;
256                                 sign = true;
257                         }
258                         int k = int(cum2sym(bottom_bits, model_num));  // Can go out-of-bounds; that will return zero.
259                         uvec2 sym = get_dsym(k, model_num);
260                         RansDecAdvance(rans, offset, sym.x, sym.y, prob_bits + 1);
261
262                         if (k == ESCAPE_LIMIT) {
263                                 k = int(RansDecGet(rans, prob_bits));
264                                 RansDecAdvance(rans, offset, k, 1, prob_bits);
265                         }
266                         if (sign) {
267                                 k = -k;
268                         }
269 #if 0
270                         if (coeff_num == 0) {
271                                 //imageStore(coeff_tex, ivec2((block_row * 40 + block_idx) * 8 + subblock_idx, 0), ivec4(k, 0,0,0));
272                                 imageStore(coeff_tex, ivec2((block_row * 40 + block_idx) * 8 + subblock_idx, 0), ivec4(rans.low, 0,0,0));
273                                 imageStore(coeff2_tex, ivec2((block_row * 40 + block_idx) * 8 + subblock_idx, 0), ivec4(rans.high, 0,0,0));
274                         }
275 #endif
276
277                         if (coeff_num == 0) {
278                                 k += last_k;
279                                 last_k = k;
280                         }
281
282
283                         temp[slice_num * 64 * 8 + subblock_idx * 64 + coeff_num] = k * q;
284                         //temp[subblock_idx * 64 + 8 * y + x] = (2 * k * w * 4) / 32;  // 100% matching unquant
285                 }
286
287                 pick_timer(start, local_timing[2]);
288
289                 memoryBarrierShared();
290                 barrier();
291
292                 pick_timer(start, local_timing[3]);
293
294                 // Horizontal DCT one row (so 64 rows).
295                 idct_1d(temp[slice_num * 64 * 8 + thread_num * 8 + 0],
296                         temp[slice_num * 64 * 8 + thread_num * 8 + 1],
297                         temp[slice_num * 64 * 8 + thread_num * 8 + 2],
298                         temp[slice_num * 64 * 8 + thread_num * 8 + 3],
299                         temp[slice_num * 64 * 8 + thread_num * 8 + 4],
300                         temp[slice_num * 64 * 8 + thread_num * 8 + 5],
301                         temp[slice_num * 64 * 8 + thread_num * 8 + 6],
302                         temp[slice_num * 64 * 8 + thread_num * 8 + 7]);
303
304                 pick_timer(start, local_timing[4]);
305
306                 memoryBarrierShared();
307                 barrier();
308
309                 pick_timer(start, local_timing[5]);
310
311                 // Vertical DCT one row (so 64 columns).
312                 uint row_offset = local_z * 64 * 8 + local_y * 64 + local_x;
313                 idct_1d(temp[row_offset + 0 * 8],
314                         temp[row_offset + 1 * 8],
315                         temp[row_offset + 2 * 8],
316                         temp[row_offset + 3 * 8],
317                         temp[row_offset + 4 * 8],
318                         temp[row_offset + 5 * 8],
319                         temp[row_offset + 6 * 8],
320                         temp[row_offset + 7 * 8]);
321
322                 pick_timer(start, local_timing[6]);
323
324                 uint global_block_idx = (block_row * 40 + block_idx) * 8 + local_y;
325                 uint block_x = global_block_idx % blocks_per_row;
326                 uint block_y = global_block_idx / blocks_per_row;
327
328                 uint y = block_y * 8;
329                 uint x = block_x * 8 + local_x;
330                 for (uint yl = 0; yl < 8; ++yl) {
331                         imageStore(out_tex, ivec2(x, yl + y), vec4(temp[row_offset + yl * 8], 0.0, 0.0, 1.0));
332                 }
333
334                 pick_timer(start, local_timing[7]);
335
336                 memoryBarrierShared();  // is this needed?
337                 barrier();
338
339                 pick_timer(start, local_timing[8]);
340                 pick_timer(start, local_timing[9]);  // should be nearly nothing
341         }
342
343 #if ENABLE_TIMING
344         for (int timer_idx = 0; timer_idx < 10; ++timer_idx) {
345                 uint global_idx = thread_num * 10 + timer_idx;
346
347                 uint old_val = atomicAdd(timing[global_idx].x, local_timing[timer_idx].x);
348                 if (old_val + local_timing[timer_idx].x < old_val) {
349                         ++local_timing[timer_idx].y;
350                 }
351                 atomicAdd(timing[global_idx].y, local_timing[timer_idx].y);
352         }
353 #endif
354 }