]> git.sesse.net Git - narabu/blob - decoder.shader
Prepare for more flexible slices.
[narabu] / decoder.shader
1 #version 440
2 #extension GL_ARB_shader_clock : enable
3
4 #define PARALLEL_SLICES 1
5
6 #define ENABLE_TIMING 0
7
8 layout(local_size_x = 64*PARALLEL_SLICES) in;
9 layout(r8ui) uniform restrict readonly uimage2D cum2sym_tex;
10 layout(rg16ui) uniform restrict readonly uimage2D dsyms_tex;
11 layout(r8) uniform restrict writeonly image2D out_tex;
12 layout(r16i) uniform restrict writeonly iimage2D coeff_tex;
13
14 const uint prob_bits = 12;
15 const uint prob_scale = 1 << prob_bits;
16 const uint NUM_SYMS = 256;
17 const uint ESCAPE_LIMIT = NUM_SYMS - 1;
18
19 // These need to be folded into quant_matrix.
20 const float dc_scalefac = 8.0;
21 const float quant_scalefac = 4.0;
22
23 const float quant_matrix[64] = {
24          8, 16, 19, 22, 26, 27, 29, 34,
25         16, 16, 22, 24, 27, 29, 34, 37,
26         19, 22, 26, 27, 29, 34, 34, 38,
27         22, 22, 26, 27, 29, 34, 37, 40,
28         22, 26, 27, 29, 32, 35, 40, 48,
29         26, 27, 29, 32, 35, 40, 48, 58,
30         26, 27, 29, 34, 38, 46, 56, 69,
31         27, 29, 35, 38, 46, 56, 69, 83
32 };
33 const uint ff_zigzag_direct[64] = {
34     0,   1,  8, 16,  9,  2,  3, 10,
35     17, 24, 32, 25, 18, 11,  4,  5,
36     12, 19, 26, 33, 40, 48, 41, 34,
37     27, 20, 13,  6,  7, 14, 21, 28,
38     35, 42, 49, 56, 57, 50, 43, 36,
39     29, 22, 15, 23, 30, 37, 44, 51,
40     58, 59, 52, 45, 38, 31, 39, 46,
41     53, 60, 61, 54, 47, 55, 62, 63
42 };
43 const uint stream_mapping[64] = {
44         0, 0, 1, 1, 2, 2, 3, 3,
45         0, 0, 1, 2, 2, 2, 3, 3,
46         1, 1, 2, 2, 2, 3, 3, 3,
47         1, 1, 2, 2, 2, 3, 3, 3,
48         1, 2, 2, 2, 2, 3, 3, 3,
49         2, 2, 2, 2, 3, 3, 3, 3,
50         2, 2, 3, 3, 3, 3, 3, 3,
51         3, 3, 3, 3, 3, 3, 3, 3,
52 };
53
54 layout(std430, binding = 9) buffer layoutName
55 {
56         uint data_SSBO[];
57 };
58 layout(std430, binding = 10) buffer layoutName2
59 {
60         uvec2 timing[10 * 64];
61 };
62
63 struct CoeffStream {
64         uint src_offset, src_len;
65 };
66 layout(std430, binding = 0) buffer whatever3
67 {
68         CoeffStream streams[];
69 };
70 uniform uint sign_bias_per_model[16];
71
72 const uint RANS_BYTE_L = (1u << 23);  // lower bound of our normalization interval
73
74 uint get_rans_byte(uint offset)
75 {
76         // We assume little endian.
77         return bitfieldExtract(data_SSBO[offset >> 2], 8 * int(offset & 3u), 8);
78 }
79
80 uint RansDecInit(inout uint offset)
81 {
82         uint x;
83
84         x  = get_rans_byte(offset);
85         x |= get_rans_byte(offset + 1) << 8;
86         x |= get_rans_byte(offset + 2) << 16;
87         x |= get_rans_byte(offset + 3) << 24;
88         offset += 4;
89
90         return x;
91 }
92
93 uint RansDecGet(uint r, uint scale_bits)
94 {
95         return r & ((1u << scale_bits) - 1);
96 }
97
98 void RansDecAdvance(inout uint rans, inout uint offset, const uint start, const uint freq, uint prob_bits)
99 {
100         const uint mask = (1u << prob_bits) - 1;
101         rans = freq * (rans >> prob_bits) + (rans & mask) - start;
102         
103         // renormalize
104         while (rans < RANS_BYTE_L) {
105                 rans = (rans << 8) | get_rans_byte(offset++);
106         }
107 }
108
109 uint cum2sym(uint bits, uint table)
110 {
111         return imageLoad(cum2sym_tex, ivec2(bits, table)).x;
112 }
113
114 uvec2 get_dsym(uint k, uint table)
115 {
116         return imageLoad(dsyms_tex, ivec2(k, table)).xy;
117 }
118
119 void idct_1d(inout float y0, inout float y1, inout float y2, inout float y3, inout float y4, inout float y5, inout float y6, inout float y7)
120 {
121         const float a1 = 0.7071067811865474;   // sqrt(2)
122         const float a2 = 0.5411961001461971;   // cos(3/8 pi) * sqrt(2)
123         const float a4 = 1.3065629648763766;   // cos(pi/8) * sqrt(2)
124         // static const float a5 = 0.5 * (a4 - a2);
125         const float a5 = 0.3826834323650897;
126
127         // phase 2 (phase 1 is just moving around)
128         const float p2_4 = y5 - y3;
129         const float p2_5 = y1 + y7;
130         const float p2_6 = y1 - y7;
131         const float p2_7 = y5 + y3;
132
133         // phase 3
134         const float p3_2 = y2 - y6;
135         const float p3_3 = y2 + y6;
136         const float p3_5 = p2_5 - p2_7;
137         const float p3_7 = p2_5 + p2_7;
138
139         // phase 4
140         const float p4_2 = a1 * p3_2;
141         const float p4_4 = p2_4 * a2 + (p2_4 + p2_6) * a5;  // Inverted.
142         const float p4_5 = a1 * p3_5;
143         const float p4_6 = p2_6 * a4 - (p2_4 + p2_6) * a5;
144
145         // phase 5
146         const float p5_0 = y0 + y4;
147         const float p5_1 = y0 - y4;
148         const float p5_3 = p4_2 + p3_3;
149
150         // phase 6
151         const float p6_0 = p5_0 + p5_3;
152         const float p6_1 = p5_1 + p4_2;
153         const float p6_2 = p5_1 - p4_2;
154         const float p6_3 = p5_0 - p5_3;
155         const float p6_5 = p4_5 + p4_4;
156         const float p6_6 = p4_5 + p4_6;
157         const float p6_7 = p4_6 + p3_7;
158
159         // phase 7
160         y0 = p6_0 + p6_7;
161         y1 = p6_1 + p6_6;
162         y2 = p6_2 + p6_5;
163         y3 = p6_3 - p4_4;
164         y4 = p6_3 + p4_4;
165         y5 = p6_2 - p6_5;
166         y6 = p6_1 - p6_6;
167         y7 = p6_0 - p6_7;
168 }
169
170 shared float temp[64 * 8 * PARALLEL_SLICES];
171
172 void pick_timer(inout uvec2 start, inout uvec2 t)
173 {
174 #if ENABLE_TIMING
175         uvec2 now = clock2x32ARB();
176
177         uvec2 delta = now - start;
178         if (now.x < start.x) {
179                 --delta.y;
180         }
181
182         uvec2 new_t = t + delta;
183         if (new_t.x < t.x) {
184                 ++new_t.y;
185         }
186         t = new_t;
187
188         start = clock2x32ARB();
189 #endif
190 }
191
192 void main()
193 {
194         uvec2 local_timing[10];
195 #if ENABLE_TIMING
196         for (int timer_idx = 0; timer_idx < 10; ++timer_idx) {
197                 local_timing[timer_idx] = uvec2(0, 0);
198         }
199         uvec2 start = clock2x32ARB();
200 #else
201         uvec2 start = uvec2(0, 0);
202         local_timing[0] = start;
203 #endif
204
205         const uint local_x = gl_LocalInvocationID.x % 8;
206         const uint local_y = (gl_LocalInvocationID.x / 8) % 8;
207         const uint local_z = gl_LocalInvocationID.x / 64;
208
209         const uint num_blocks = 720 / 16;  // FIXME: make a uniform
210         const uint slice_num = local_z;
211         const uint thread_num = local_y * 8 + local_x;
212
213         const uint block_row = gl_WorkGroupID.y * PARALLEL_SLICES + slice_num;
214         //const uint coeff_num = ff_zigzag_direct[thread_num];
215         const uint coeff_num = thread_num;
216         const uint stream_num = coeff_num * num_blocks + block_row;
217         const uint model_num = stream_mapping[coeff_num];
218         const uint sign_bias = sign_bias_per_model[model_num];
219
220         // Initialize rANS decoder.
221         uint offset = streams[stream_num].src_offset;
222         uint rans = RansDecInit(offset);
223
224         float q = (coeff_num == 0) ? 1.0 : (quant_matrix[coeff_num] * quant_scalefac / 128.0 / sqrt(2.0));  // FIXME: fold
225         q *= (1.0 / 255.0);
226         //int w = (coeff_num == 0) ? 32 : int(quant_matrix[coeff_num]);
227         int last_k = 0;
228
229         pick_timer(start, local_timing[0]);
230
231         for (uint block_idx = 40; block_idx --> 0; ) {
232                 uint block_x = block_idx % 20;
233                 uint block_y = block_idx / 20;
234                 if (block_x == 19) last_k = 0;
235
236                 pick_timer(start, local_timing[1]);
237
238                 // rANS decode one coefficient across eight blocks (so 64x8 coefficients).
239                 for (uint subblock_idx = 8; subblock_idx --> 0; ) {
240                         // Read a symbol.
241                         uint bottom_bits = RansDecGet(rans, prob_bits + 1);
242                         bool sign = false;
243                         if (bottom_bits >= sign_bias) {
244                                 bottom_bits -= sign_bias;
245                                 rans -= sign_bias;
246                                 sign = true;
247                         }
248                         int k = int(cum2sym(bottom_bits, model_num));  // Can go out-of-bounds; that will return zero.
249                         uvec2 sym = get_dsym(k, model_num);
250                         RansDecAdvance(rans, offset, sym.x, sym.y, prob_bits + 1);
251
252                         if (k == ESCAPE_LIMIT) {
253                                 k = int(RansDecGet(rans, prob_bits));
254                                 RansDecAdvance(rans, offset, k, 1, prob_bits);
255                         }
256                         if (sign) {
257                                 k = -k;
258                         }
259
260                         if (coeff_num == 0) {
261                                 k += last_k;
262                                 last_k = k;
263                         }
264
265 #if 0
266                         uint y = block_row * 16 + block_y * 8 + local_y;
267                         uint x = block_x * 64 + subblock_idx * 8 + local_x;
268                         imageStore(coeff_tex, ivec2(x, y), ivec4(k, 0,0,0));
269 #endif
270
271                         temp[slice_num * 64 * 8 + subblock_idx * 64 + coeff_num] = k * q;
272                         //temp[subblock_idx * 64 + 8 * y + x] = (2 * k * w * 4) / 32;  // 100% matching unquant
273                 }
274
275                 pick_timer(start, local_timing[2]);
276
277                 memoryBarrierShared();
278                 barrier();
279
280                 pick_timer(start, local_timing[3]);
281
282                 // Horizontal DCT one row (so 64 rows).
283                 idct_1d(temp[slice_num * 64 * 8 + thread_num * 8 + 0],
284                         temp[slice_num * 64 * 8 + thread_num * 8 + 1],
285                         temp[slice_num * 64 * 8 + thread_num * 8 + 2],
286                         temp[slice_num * 64 * 8 + thread_num * 8 + 3],
287                         temp[slice_num * 64 * 8 + thread_num * 8 + 4],
288                         temp[slice_num * 64 * 8 + thread_num * 8 + 5],
289                         temp[slice_num * 64 * 8 + thread_num * 8 + 6],
290                         temp[slice_num * 64 * 8 + thread_num * 8 + 7]);
291
292                 pick_timer(start, local_timing[4]);
293
294                 memoryBarrierShared();
295                 barrier();
296
297                 pick_timer(start, local_timing[5]);
298
299                 // Vertical DCT one row (so 64 columns).
300                 uint row_offset = local_z * 64 * 8 + local_y * 64 + local_x;
301                 idct_1d(temp[row_offset + 0 * 8],
302                         temp[row_offset + 1 * 8],
303                         temp[row_offset + 2 * 8],
304                         temp[row_offset + 3 * 8],
305                         temp[row_offset + 4 * 8],
306                         temp[row_offset + 5 * 8],
307                         temp[row_offset + 6 * 8],
308                         temp[row_offset + 7 * 8]);
309
310                 pick_timer(start, local_timing[6]);
311
312                 uint y = block_row * 16 + block_y * 8;
313                 uint x = block_x * 64 + local_y * 8 + local_x;
314                 for (uint yl = 0; yl < 8; ++yl) {
315                         imageStore(out_tex, ivec2(x, yl + y), vec4(temp[row_offset + yl * 8], 0.0, 0.0, 1.0));
316                 }
317
318                 pick_timer(start, local_timing[7]);
319
320                 memoryBarrierShared();  // is this needed?
321                 barrier();
322
323                 pick_timer(start, local_timing[8]);
324                 pick_timer(start, local_timing[9]);  // should be nearly nothing
325         }
326
327 #if ENABLE_TIMING
328         for (int timer_idx = 0; timer_idx < 10; ++timer_idx) {
329                 uint global_idx = thread_num * 10 + timer_idx;
330
331                 uint old_val = atomicAdd(timing[global_idx].x, local_timing[timer_idx].x);
332                 if (old_val + local_timing[timer_idx].x < old_val) {
333                         ++local_timing[timer_idx].y;
334                 }
335                 atomicAdd(timing[global_idx].y, local_timing[timer_idx].y);
336         }
337 #endif
338 }