// rANS coefficient decoder + 8x8 inverse DCT, as a GLSL compute shader.
// NOTE(review): this listing carries embedded source line numbers and is
// missing lines (blank lines, braces, #if/#endif, some declarations and
// statements); it does not compile as-is.
// GL_ARB_shader_clock provides clock2x32ARB(), used for optional profiling
// in pick_timer() and main().
2 #extension GL_ARB_shader_clock : enable
// Number of 64-invocation slices per workgroup; each slice decodes its own
// block row (main(): block_row = gl_WorkGroupID.y * PARALLEL_SLICES + slice).
4 #define PARALLEL_SLICES 1
// Presumably gates the clock-based timing code (the #if lines are not
// visible in this listing) — confirm.
6 #define ENABLE_TIMING 0
// One invocation per coefficient of an 8x8 block, per slice.
8 layout(local_size_x = 64*PARALLEL_SLICES) in;
// Symbol lookup by (cumulative-frequency slot, probability model); see cum2sym().
9 layout(r8ui) uniform restrict readonly uimage2D cum2sym_tex;
// Per-symbol (start, freq) pairs by (symbol, probability model); see get_dsym().
10 layout(rg16ui) uniform restrict readonly uimage2D dsyms_tex;
// Reconstructed pixel values written after the inverse DCT.
11 layout(r8) uniform restrict writeonly image2D out_tex;
// Raw decoded coefficient values, written by main().
12 layout(r16i) uniform restrict writeonly iimage2D coeff_tex;
// Stride between coefficients in streams[]:
// stream index = coeff_num * num_blocks + block_row.
13 uniform int num_blocks;
// rANS probability precision in bits. The main symbol model is decoded with
// prob_bits + 1 bits; escaped values use prob_bits (see main()).
15 const uint prob_bits = 12;
// NOTE(review): prob_scale is not referenced in the visible code.
16 const uint prob_scale = 1 << prob_bits;
17 const uint NUM_SYMS = 256;
// Symbol value that signals an escape: the actual value follows, coded with
// a flat distribution (freq 1) over 2^prob_bits.
18 const uint ESCAPE_LIMIT = NUM_SYMS - 1;
// Blocks decoded per stream; main() processes them in groups of 8.
19 const uint BLOCKS_PER_STREAM = 320;
21 // These need to be folded into quant_matrix.
// NOTE(review): dc_scalefac is not referenced in the visible code.
22 const float dc_scalefac = 8.0;
23 const float quant_scalefac = 4.0;
// Dequantization weights, indexed by coeff_num (raster order within the 8x8
// block) in main(). Coefficient 0 uses a scale of 1.0 instead of this table.
// NOTE(review): the closing `};` of each table below is not present in this listing.
25 const float quant_matrix[64] = {
26 8, 16, 19, 22, 26, 27, 29, 34,
27 16, 16, 22, 24, 27, 29, 34, 37,
28 19, 22, 26, 27, 29, 34, 34, 38,
29 22, 22, 26, 27, 29, 34, 37, 40,
30 22, 26, 27, 29, 32, 35, 40, 48,
31 26, 27, 29, 32, 35, 40, 48, 58,
32 26, 27, 29, 34, 38, 46, 56, 69,
33 27, 29, 35, 38, 46, 56, 69, 83
// Zigzag scan order for an 8x8 block (scan index -> raster index).
// NOTE(review): only referenced by a commented-out line in main().
35 const uint ff_zigzag_direct[64] = {
36 0, 1, 8, 16, 9, 2, 3, 10,
37 17, 24, 32, 25, 18, 11, 4, 5,
38 12, 19, 26, 33, 40, 48, 41, 34,
39 27, 20, 13, 6, 7, 14, 21, 28,
40 35, 42, 49, 56, 57, 50, 43, 36,
41 29, 22, 15, 23, 30, 37, 44, 51,
42 58, 59, 52, 45, 38, 31, 39, 46,
43 53, 60, 61, 54, 47, 55, 62, 63
// Probability model (0-3) for each coefficient index. The model selects the
// row of cum2sym_tex / dsyms_tex and the sign_bias_per_model entry in main().
45 const uint stream_mapping[64] = {
46 0, 0, 1, 1, 2, 2, 3, 3,
47 0, 0, 1, 2, 2, 2, 3, 3,
48 1, 1, 2, 2, 2, 3, 3, 3,
49 1, 1, 2, 2, 2, 3, 3, 3,
50 1, 2, 2, 2, 2, 3, 3, 3,
51 2, 2, 2, 2, 3, 3, 3, 3,
52 2, 2, 3, 3, 3, 3, 3, 3,
53 3, 3, 3, 3, 3, 3, 3, 3,
// Compressed byte data. NOTE(review): the member declaration is missing from
// this listing; get_rans_byte() reads a uint array named data_SSBO,
// presumably declared in this block — confirm.
56 layout(std430, binding = 9) buffer layoutName
// Profiling accumulators: 10 timers per thread_num, stored as
// uvec2(low, high) 64-bit sums at index thread_num * 10 + timer (see main()).
60 layout(std430, binding = 10) buffer layoutName2
62 uvec2 timing[10 * 64];
// Member of the per-stream descriptor (its `struct CoeffStream` header is
// not in this listing). main() starts decoding at src_offset, a byte offset
// into data_SSBO. NOTE(review): src_len is not referenced in the visible code.
66 uint src_offset, src_len;
// One descriptor per (coefficient, block row) stream, indexed by stream_num in main().
68 layout(std430, binding = 0) buffer whatever3
70 CoeffStream streams[];
// Per-model threshold compared against the decoded slot in main().
// Only models 0-3 are selected by stream_mapping.
72 uniform uint sign_bias_per_model[16];
// Renormalization bound: RansDecAdvance() shifts in bytes while the state is below this.
74 const uint RANS_BYTE_L = (1u << 23); // lower bound of our normalization interval
// Returns byte number `offset` of the compressed stream. Bytes are packed
// four per 32-bit word, with byte 0 in the least significant 8 bits
// (little-endian packing).
uint get_rans_byte(uint offset)
{
	uint word = data_SSBO[offset / 4u];
	uint shift = (offset & 3u) * 8u;
	return (word >> shift) & 0xffu;
}
// Loads the initial rANS decoder state from the four bytes starting at byte
// `offset`, assembled little-endian (first byte = least significant).
// NOTE(review): the declaration of `x` and the tail of this function are not
// in this listing. Since `offset` is inout and main() continues decoding from
// it, the tail presumably advances offset by 4 and returns x — confirm.
82 uint RansDecInit(inout uint offset)
86 x = get_rans_byte(offset);
87 x |= get_rans_byte(offset + 1) << 8;
88 x |= get_rans_byte(offset + 2) << 16;
89 x |= get_rans_byte(offset + 3) << 24;
// Returns the cumulative-frequency slot encoded in rANS state `r`, i.e. its
// low `scale_bits` bits. scale_bits must be < 32.
uint RansDecGet(uint r, uint scale_bits)
{
	uint low_mask = ~(~0u << scale_bits);
	return r & low_mask;
}
// Removes one symbol, with cumulative start `start` and frequency `freq` out
// of a total of 2^prob_bits, from the rANS state `rans`. It then renormalizes
// by shifting in bytes read from `offset` (post-incremented) until
// rans >= RANS_BYTE_L.
// The `prob_bits` parameter shadows the file-level constant; main() passes
// either prob_bits or prob_bits + 1.
// NOTE(review): the braces and two lines between the state update and the
// loop are not present in this listing.
100 void RansDecAdvance(inout uint rans, inout uint offset, const uint start, const uint freq, uint prob_bits)
102 const uint mask = (1u << prob_bits) - 1;
103 rans = freq * (rans >> prob_bits) + (rans & mask) - start;
// Renormalize one byte at a time.
106 while (rans < RANS_BYTE_L) {
107 rans = (rans << 8) | get_rans_byte(offset++);
// Maps a cumulative-frequency slot `bits` to its symbol using the lookup row
// for probability model `table`.
uint cum2sym(uint bits, uint table)
{
	ivec2 texel = ivec2(int(bits), int(table));
	uvec4 entry = imageLoad(cum2sym_tex, texel);
	return entry.r;
}
// Fetches the decode parameters of symbol `k` for probability model `table`.
// main() passes .x as the cumulative start and .y as the frequency to
// RansDecAdvance().
uvec2 get_dsym(uint k, uint table)
{
	ivec2 texel = ivec2(int(k), int(table));
	return imageLoad(dsyms_tex, texel).rg;
}
// In-place 8-point 1D inverse DCT of (y0..y7), computed as a fixed sequence
// of butterfly phases. main() applies it first to the rows and then to the
// columns of each 8x8 block held in temp[].
// NOTE(review): the final phase that writes the results back into y0..y7,
// and the closing brace, are not present in this listing. The p6_*
// intermediates below are presumably combined there.
121 void idct_1d(inout float y0, inout float y1, inout float y2, inout float y3, inout float y4, inout float y5, inout float y6, inout float y7)
123 const float a1 = 0.7071067811865474; // 1/sqrt(2) = cos(pi/4)
124 const float a2 = 0.5411961001461971; // cos(3/8 pi) * sqrt(2)
125 const float a4 = 1.3065629648763766; // cos(pi/8) * sqrt(2)
126 // static const float a5 = 0.5 * (a4 - a2);
127 const float a5 = 0.3826834323650897; // cos(3/8 pi), equal to 0.5 * (a4 - a2)
129 // phase 2 (phase 1 is just moving around)
130 const float p2_4 = y5 - y3;
131 const float p2_5 = y1 + y7;
132 const float p2_6 = y1 - y7;
133 const float p2_7 = y5 + y3;
// phase 3
136 const float p3_2 = y2 - y6;
137 const float p3_3 = y2 + y6;
138 const float p3_5 = p2_5 - p2_7;
139 const float p3_7 = p2_5 + p2_7;
// phase 4: multiplications by the rotation constants
142 const float p4_2 = a1 * p3_2;
143 const float p4_4 = p2_4 * a2 + (p2_4 + p2_6) * a5; // Inverted.
144 const float p4_5 = a1 * p3_5;
145 const float p4_6 = p2_6 * a4 - (p2_4 + p2_6) * a5;
// phase 5
148 const float p5_0 = y0 + y4;
149 const float p5_1 = y0 - y4;
150 const float p5_3 = p4_2 + p3_3;
// phase 6
153 const float p6_0 = p5_0 + p5_3;
154 const float p6_1 = p5_1 + p4_2;
155 const float p6_2 = p5_1 - p4_2;
156 const float p6_3 = p5_0 - p5_3;
157 const float p6_5 = p4_5 + p4_4;
158 const float p6_6 = p4_5 + p4_6;
159 const float p6_7 = p4_6 + p3_7;
// Workgroup-shared scratch for the 2D inverse DCT. Per slice it holds
// 8 blocks x 64 coefficients, indexed as
// slice * 512 + subblock * 64 + coefficient (see main()).
172 shared float temp[64 * 8 * PARALLEL_SLICES];
// Profiling helper. It adds the time elapsed since `start` to the
// accumulator `t`, then re-reads the clock into `start`. Both values are
// 64-bit counts stored as uvec2(low, high), the layout clock2x32ARB() returns.
// NOTE(review): the braces, the presumed ENABLE_TIMING guard lines and the
// bodies of the borrow/carry fix-ups are not present in this listing.
174 void pick_timer(inout uvec2 start, inout uvec2 t)
177 uvec2 now = clock2x32ARB();
// uvec2 subtraction is per-component, so a borrow out of the low word is
// detected here and presumably taken from the high word in the missing lines.
179 uvec2 delta = now - start;
180 if (now.x < start.x) {
// Same per-component issue for the addition: a low-word carry must be
// propagated into .y (handled in lines not shown).
184 uvec2 new_t = t + delta;
// Start the next interval from a fresh clock read taken after this bookkeeping.
190 start = clock2x32ARB();
// Body of the compute shader's entry point.
// NOTE(review): the `void main()` header, the presumed ENABLE_TIMING
// #if/#else/#endif lines, several statements and all closing braces are
// missing from this listing. The comments below describe only visible lines.
//
// Each invocation owns one coefficient (coeff_num) of one stream row
// (block_row) and decodes that coefficient's rANS stream for every block of
// the row. The workgroup then runs the 2D inverse DCT cooperatively through
// the shared temp[] array.
196 uvec2 local_timing[10];
198 for (int timer_idx = 0; timer_idx < 10; ++timer_idx) {
199 local_timing[timer_idx] = uvec2(0, 0);
// NOTE(review): the two declarations of `start` below are presumably the
// timing and non-timing branches of a #if whose lines are not shown.
201 uvec2 start = clock2x32ARB();
203 uvec2 start = uvec2(0, 0);
204 local_timing[0] = start;
// Number of 8x8 blocks across the output image, rounding up.
207 const uint blocks_per_row = (imageSize(out_tex).x + 7) / 8;
// Split the flat invocation index: local_x / local_y in 0..7, local_z = slice.
209 const uint local_x = gl_LocalInvocationID.x % 8;
210 const uint local_y = (gl_LocalInvocationID.x / 8) % 8;
211 const uint local_z = gl_LocalInvocationID.x / 64;
213 const uint slice_num = local_z;
214 const uint thread_num = local_y * 8 + local_x;
216 const uint block_row = gl_WorkGroupID.y * PARALLEL_SLICES + slice_num;
// Coefficients are handled in raster order (the zigzag mapping is disabled).
217 //const uint coeff_num = ff_zigzag_direct[thread_num];
218 const uint coeff_num = thread_num;
// Streams are laid out coefficient-major: every block row of coefficient 0,
// then every block row of coefficient 1, and so on.
219 const uint stream_num = coeff_num * num_blocks + block_row;
220 const uint model_num = stream_mapping[coeff_num];
221 const uint sign_bias = sign_bias_per_model[model_num];
223 // Initialize rANS decoder.
224 uint offset = streams[stream_num].src_offset;
225 uint rans = RansDecInit(offset);
// Dequantization scale for this coefficient; coefficient 0 is left unscaled here.
227 float q = (coeff_num == 0) ? 1.0 : (quant_matrix[coeff_num] * quant_scalefac / 128.0 / sqrt(2.0)); // FIXME: fold
229 //int w = (coeff_num == 0) ? 32 : int(quant_matrix[coeff_num]);
232 pick_timer(start, local_timing[0]);
// Process BLOCKS_PER_STREAM / 8 groups of 8 blocks. `block_idx --> 0` means
// `block_idx-- > 0`, so block_idx runs from BLOCKS_PER_STREAM / 8 - 1 down
// to 0. This is presumably the reverse of the encoder's order, as rANS
// decoding requires — confirm against the encoder.
234 for (uint block_idx = BLOCKS_PER_STREAM / 8; block_idx --> 0; ) {
235 pick_timer(start, local_timing[1]);
237 // rANS decode one coefficient across eight blocks (so 64x8 coefficients).
238 for (uint subblock_idx = 8; subblock_idx --> 0; ) {
// The main model uses prob_bits + 1 bits. Slots at or above sign_bias are
// folded back below it; the related sign handling is in lines not shown.
240 uint bottom_bits = RansDecGet(rans, prob_bits + 1);
242 if (bottom_bits >= sign_bias) {
243 bottom_bits -= sign_bias;
// Slot -> symbol, then consume that symbol's (start, freq) from the state.
247 int k = int(cum2sym(bottom_bits, model_num)); // Can go out-of-bounds; that will return zero.
248 uvec2 sym = get_dsym(k, model_num);
249 RansDecAdvance(rans, offset, sym.x, sym.y, prob_bits + 1);
// Escape: the actual value follows, coded with freq 1 over 2^prob_bits.
251 if (k == ESCAPE_LIMIT) {
252 k = int(RansDecGet(rans, prob_bits));
253 RansDecAdvance(rans, offset, k, 1, prob_bits);
// NOTE(review): block_x / block_y used here are declared in lines missing
// from this listing.
259 if (coeff_num == 0) {
265 uint y = block_row * 16 + block_y * 8 + local_y;
266 uint x = block_x * 64 + subblock_idx * 8 + local_x;
267 imageStore(coeff_tex, ivec2(x, y), ivec4(k, 0,0,0));
// Store the dequantized coefficient at temp[slice][subblock][coefficient].
270 temp[slice_num * 64 * 8 + subblock_idx * 64 + coeff_num] = k * q;
271 //temp[subblock_idx * 64 + 8 * y + x] = (2 * k * w * 4) / 32; // 100% matching unquant
274 pick_timer(start, local_timing[2]);
// NOTE(review): memoryBarrierShared() only orders this invocation's own
// shared-memory accesses; it does not make other invocations wait. The row
// pass below reads temp[] entries written by other invocations, so barrier()
// is required here. Confirm that the missing line(s) after this call contain it.
276 memoryBarrierShared();
279 pick_timer(start, local_timing[3]);
281 // Horizontal inverse DCT on one row (so 64 rows).
282 idct_1d(temp[slice_num * 64 * 8 + thread_num * 8 + 0],
283 temp[slice_num * 64 * 8 + thread_num * 8 + 1],
284 temp[slice_num * 64 * 8 + thread_num * 8 + 2],
285 temp[slice_num * 64 * 8 + thread_num * 8 + 3],
286 temp[slice_num * 64 * 8 + thread_num * 8 + 4],
287 temp[slice_num * 64 * 8 + thread_num * 8 + 5],
288 temp[slice_num * 64 * 8 + thread_num * 8 + 6],
289 temp[slice_num * 64 * 8 + thread_num * 8 + 7]);
291 pick_timer(start, local_timing[4]);
// NOTE(review): same as above — the column pass reads rows written by other
// invocations, so barrier() (not only memoryBarrierShared()) is needed here.
293 memoryBarrierShared();
296 pick_timer(start, local_timing[5]);
298 // Vertical inverse DCT on one column (so 64 columns).
// local_y selects the block within the group of 8 and local_x the column.
299 uint row_offset = local_z * 64 * 8 + local_y * 64 + local_x;
300 idct_1d(temp[row_offset + 0 * 8],
301 temp[row_offset + 1 * 8],
302 temp[row_offset + 2 * 8],
303 temp[row_offset + 3 * 8],
304 temp[row_offset + 4 * 8],
305 temp[row_offset + 5 * 8],
306 temp[row_offset + 6 * 8],
307 temp[row_offset + 7 * 8]);
309 pick_timer(start, local_timing[6]);
// Blocks are numbered linearly in image raster order. Each block_row (stream
// row) covers BLOCKS_PER_STREAM consecutive blocks.
// NOTE(review): 40 is hard-coded and equals BLOCKS_PER_STREAM / 8.
311 uint global_block_idx = (block_row * 40 + block_idx) * 8 + local_y;
312 uint block_x = global_block_idx % blocks_per_row;
313 uint block_y = global_block_idx / blocks_per_row;
315 uint y = block_y * 8;
316 uint x = block_x * 8 + local_x;
// Each invocation writes one 8-pixel column of its block.
317 for (uint yl = 0; yl < 8; ++yl) {
318 imageStore(out_tex, ivec2(x, yl + y), vec4(temp[row_offset + yl * 8], 0.0, 0.0, 1.0));
321 pick_timer(start, local_timing[7]);
// NOTE(review): if this is still inside the block loop, synchronization is
// needed. The next iteration overwrites temp[] entries that other
// invocations may still be reading in the column pass above (a
// write-after-read hazard). memoryBarrierShared() alone does not provide
// that; barrier() does. Confirm whether the missing following line contains barrier().
323 memoryBarrierShared(); // is this needed?
326 pick_timer(start, local_timing[8]);
327 pick_timer(start, local_timing[9]); // should be nearly nothing
// Accumulate this invocation's timers into timing[] as 64-bit (low, high)
// sums, using two 32-bit atomicAdds. A wrap of the low word (old + add < old)
// is carried into the high word.
331 for (int timer_idx = 0; timer_idx < 10; ++timer_idx) {
332 uint global_idx = thread_num * 10 + timer_idx;
334 uint old_val = atomicAdd(timing[global_idx].x, local_timing[timer_idx].x);
335 if (old_val + local_timing[timer_idx].x < old_val) {
336 ++local_timing[timer_idx].y;
338 atomicAdd(timing[global_idx].y, local_timing[timer_idx].y);