]> git.sesse.net Git - narabu/blob - decoder.shader
Add the GPU decoder itself.
[narabu] / decoder.shader
1 #version 430
2 #extension GL_ARB_shader_clock : enable
3
4 #define ENABLE_TIMING 0
5
6 layout(local_size_x = 8, local_size_y = 8) in;
7 layout(r8ui) uniform restrict readonly uimage2D cum2sym_tex;
8 layout(rg16ui) uniform restrict readonly uimage2D dsyms_tex;
9 layout(r8) uniform restrict writeonly image2D out_tex;
10
11 const uint prob_bits = 12;
12 const uint prob_scale = 1 << prob_bits;
13 const uint NUM_SYMS = 256;
14 const uint ESCAPE_LIMIT = NUM_SYMS - 1;
15
16 // These need to be folded into quant_matrix.
17 const float dc_scalefac = 8.0;
18 const float quant_scalefac = 4.0;
19
20 const float quant_matrix[64] = {
21          8, 16, 19, 22, 26, 27, 29, 34,
22         16, 16, 22, 24, 27, 29, 34, 37,
23         19, 22, 26, 27, 29, 34, 34, 38,
24         22, 22, 26, 27, 29, 34, 37, 40,
25         22, 26, 27, 29, 32, 35, 40, 48,
26         26, 27, 29, 32, 35, 40, 48, 58,
27         26, 27, 29, 34, 38, 46, 56, 69,
28         27, 29, 35, 38, 46, 56, 69, 83
29 };
30 const uint ff_zigzag_direct[64] = {
31     0,   1,  8, 16,  9,  2,  3, 10,
32     17, 24, 32, 25, 18, 11,  4,  5,
33     12, 19, 26, 33, 40, 48, 41, 34,
34     27, 20, 13,  6,  7, 14, 21, 28,
35     35, 42, 49, 56, 57, 50, 43, 36,
36     29, 22, 15, 23, 30, 37, 44, 51,
37     58, 59, 52, 45, 38, 31, 39, 46,
38     53, 60, 61, 54, 47, 55, 62, 63
39 };
40
41 layout(std430, binding = 9) buffer layoutName
42 {
43         uint data_SSBO[];
44 };
45 layout(std430, binding = 10) buffer layoutName2
46 {
47         uvec2 timing[10 * 64];
48 };
49
50 struct CoeffStream {
51         uint src_offset, src_len, sign_offset, sign_len, extra_bits;
52 };
53 layout(std430, binding = 0) buffer whatever3
54 {
55         CoeffStream streams[];
56 };
57
58 uniform uint src_offset, src_len, sign_offset, sign_len, extra_bits;
59
60 const uint RANS_BYTE_L = (1u << 23);  // lower bound of our normalization interval
61
62 uint last_offset = -1, ransbuf;
63
64 uint get_rans_byte(uint offset)
65 {
66         if (last_offset != (offset >> 2)) {
67                 last_offset = offset >> 2;
68                 ransbuf = data_SSBO[offset >> 2];
69         }
70         return bitfieldExtract(ransbuf, 8 * int(offset & 3u), 8);
71
72         // We assume little endian.
73 //      return bitfieldExtract(data_SSBO[offset >> 2], 8 * int(offset & 3u), 8);
74 }
75
76 void RansDecInit(out uint r, inout uint offset)
77 {
78         uint x;
79
80         x  = get_rans_byte(offset);
81         x |= get_rans_byte(offset + 1) << 8;
82         x |= get_rans_byte(offset + 2) << 16;
83         x |= get_rans_byte(offset + 3) << 24;
84         offset += 4;
85
86         r = x;
87 }
88
89 uint RansDecGet(uint r, uint scale_bits)
90 {
91         return r & ((1u << scale_bits) - 1);
92 }
93
94 void RansDecAdvance(inout uint rans, inout uint offset, const uint start, const uint freq, uint prob_bits)
95 {
96         const uint mask = (1u << prob_bits) - 1;
97         rans = freq * (rans >> prob_bits) + (rans & mask) - start;
98         
99         // renormalize
100         while (rans < RANS_BYTE_L) {
101                 rans = (rans << 8) | get_rans_byte(offset++);
102         }
103 }
104
105 uint cum2sym(uint bits, uint table)
106 {
107         return imageLoad(cum2sym_tex, ivec2(bits, table)).x;
108 }
109
110 uvec2 get_dsym(uint k, uint table)
111 {
112         return imageLoad(dsyms_tex, ivec2(k, table)).xy;
113 }
114
115 void idct_1d(inout float y0, inout float y1, inout float y2, inout float y3, inout float y4, inout float y5, inout float y6, inout float y7)
116 {
117         const float a1 = 0.7071067811865474;   // sqrt(2)
118         const float a2 = 0.5411961001461971;   // cos(3/8 pi) * sqrt(2)
119         const float a4 = 1.3065629648763766;   // cos(pi/8) * sqrt(2)
120         // static const float a5 = 0.5 * (a4 - a2);
121         const float a5 = 0.3826834323650897;
122
123         // phase 2 (phase 1 is just moving around)
124         const float p2_4 = y5 - y3;
125         const float p2_5 = y1 + y7;
126         const float p2_6 = y1 - y7;
127         const float p2_7 = y5 + y3;
128
129         // phase 3
130         const float p3_2 = y2 - y6;
131         const float p3_3 = y2 + y6;
132         const float p3_5 = p2_5 - p2_7;
133         const float p3_7 = p2_5 + p2_7;
134
135         // phase 4
136         const float p4_2 = a1 * p3_2;
137         const float p4_4 = p2_4 * a2 + (p2_4 + p2_6) * a5;  // Inverted.
138         const float p4_5 = a1 * p3_5;
139         const float p4_6 = p2_6 * a4 - (p2_4 + p2_6) * a5;
140
141         // phase 5
142         const float p5_0 = y0 + y4;
143         const float p5_1 = y0 - y4;
144         const float p5_3 = p4_2 + p3_3;
145
146         // phase 6
147         const float p6_0 = p5_0 + p5_3;
148         const float p6_1 = p5_1 + p4_2;
149         const float p6_2 = p5_1 - p4_2;
150         const float p6_3 = p5_0 - p5_3;
151         const float p6_5 = p4_5 + p4_4;
152         const float p6_6 = p4_5 + p4_6;
153         const float p6_7 = p4_6 + p3_7;
154
155         // phase 7
156         y0 = p6_0 + p6_7;
157         y1 = p6_1 + p6_6;
158         y2 = p6_2 + p6_5;
159         y3 = p6_3 - p4_4;
160         y4 = p6_3 + p4_4;
161         y5 = p6_2 - p6_5;
162         y6 = p6_1 - p6_6;
163         y7 = p6_0 - p6_7;
164 }
165
166 shared float temp[64 * 8];
167
168 void pick_timer(inout uvec2 start, inout uvec2 t)
169 {
170 #if ENABLE_TIMING
171         uvec2 now = clock2x32ARB();
172
173         uvec2 delta = now - start;
174         if (now.x < start.x) {
175                 --delta.y;
176         }
177
178         uvec2 new_t = t + delta;
179         if (new_t.x < t.x) {
180                 ++new_t.y;
181         }
182         t = new_t;
183
184         start = clock2x32ARB();
185 #endif
186 }
187
188 void main()
189 {
190         uvec2 local_timing[10];
191 #if ENABLE_TIMING
192         for (int timer_idx = 0; timer_idx < 10; ++timer_idx) {
193                 local_timing[timer_idx] = uvec2(0, 0);
194         }
195         uvec2 start = clock2x32ARB();
196 #else
197         uvec2 start;
198 #endif
199
200         const uint num_blocks = 720 / 16;  // FIXME: make a uniform
201         const uint thread_num = gl_LocalInvocationID.y * 8 + gl_LocalInvocationID.x;
202
203         const uint block_row = gl_WorkGroupID.y;
204         //const uint coeff_num = ff_zigzag_direct[thread_num];
205         const uint coeff_num = thread_num;
206         const uint stream_num = coeff_num * num_blocks + block_row;
207         //const uint stream_num = block_row * num_blocks + coeff_num;  // HACK
208         const uint model_num = min((coeff_num % 8) + (coeff_num / 8), 7);
209
210         // Initialize rANS decoder.
211         uint offset = streams[stream_num].src_offset;
212         uint rans;
213         RansDecInit(rans, offset);
214
215         // Initialize sign bit decoder. TODO: this ought to be 32-bit-aligned instead!
216         uint soffset = streams[stream_num].sign_offset;
217         uint sign_buf = get_rans_byte(soffset++) >> streams[stream_num].extra_bits;
218         uint sign_bits_left = 8 - streams[stream_num].extra_bits;
219
220         float q = (coeff_num == 0) ? 1.0 : (quant_matrix[coeff_num] * quant_scalefac / 128.0 / sqrt(2.0));  // FIXME: fold
221         q *= (1.0 / 255.0);
222         //int w = (coeff_num == 0) ? 32 : int(quant_matrix[coeff_num]);
223         int last_k = 0;
224
225         pick_timer(start, local_timing[0]);
226
227         for (uint block_idx = 40; block_idx --> 0; ) {
228                 uint block_x = block_idx % 20;
229                 uint block_y = block_idx / 20;
230                 if (block_x == 19) last_k = 0;
231
232                 pick_timer(start, local_timing[1]);
233
234                 // rANS decode one coefficient across eight blocks (so 64x8 coefficients).
235                 for (uint subblock_idx = 8; subblock_idx --> 0; ) {
236                         // Read a symbol.
237                         int k = int(cum2sym(RansDecGet(rans, prob_bits), model_num));
238                         uvec2 sym = get_dsym(k, model_num);
239                         RansDecAdvance(rans, offset, sym.x, sym.y, prob_bits);
240
241                         if (k == ESCAPE_LIMIT) {
242                                 k = int(RansDecGet(rans, prob_bits));
243                                 RansDecAdvance(rans, offset, k, 1, prob_bits);
244                         }
245                         if (k != 0) {
246                                 if (sign_bits_left == 0) {
247                                         sign_buf = get_rans_byte(soffset++);
248                                         sign_bits_left = 8;
249                                 }
250                                 if ((sign_buf & 1u) == 1u) k = -k;
251                                 --sign_bits_left;
252                                 sign_buf >>= 1;
253                         }
254
255                         if (coeff_num == 0) {
256                                 k += last_k;
257                                 last_k = k;
258                         }
259
260                         temp[subblock_idx * 64 + coeff_num] = k * q;
261                         //temp[subblock_idx * 64 + 8 * y + x] = (2 * k * w * 4) / 32;  // 100% matching unquant
262                 }
263
264                 pick_timer(start, local_timing[2]);
265
266                 memoryBarrierShared();
267                 barrier();
268
269                 pick_timer(start, local_timing[3]);
270
271                 // Horizontal DCT one row (so 64 rows).
272                 idct_1d(temp[thread_num * 8 + 0],
273                         temp[thread_num * 8 + 1],
274                         temp[thread_num * 8 + 2],
275                         temp[thread_num * 8 + 3],
276                         temp[thread_num * 8 + 4],
277                         temp[thread_num * 8 + 5],
278                         temp[thread_num * 8 + 6],
279                         temp[thread_num * 8 + 7]);
280
281                 pick_timer(start, local_timing[4]);
282
283                 memoryBarrierShared();
284                 barrier();
285
286                 pick_timer(start, local_timing[5]);
287
288                 // Vertical DCT one row (so 64 columns).
289                 uint row_offset = gl_LocalInvocationID.y * 64 + gl_LocalInvocationID.x;
290                 idct_1d(temp[row_offset + 0 * 8],
291                         temp[row_offset + 1 * 8],
292                         temp[row_offset + 2 * 8],
293                         temp[row_offset + 3 * 8],
294                         temp[row_offset + 4 * 8],
295                         temp[row_offset + 5 * 8],
296                         temp[row_offset + 6 * 8],
297                         temp[row_offset + 7 * 8]);
298
299                 pick_timer(start, local_timing[6]);
300
301                 uint y = block_row * 16 + block_y * 8;
302                 uint x = block_x * 64 + gl_LocalInvocationID.y * 8 + gl_LocalInvocationID.x;
303                 for (uint yl = 0; yl < 8; ++yl) {
304                         imageStore(out_tex, ivec2(x, yl + y), vec4(temp[row_offset + yl * 8], 0.0, 0.0, 1.0));
305                 }
306
307                 pick_timer(start, local_timing[7]);
308
309                 memoryBarrierShared();  // is this needed?
310                 barrier();
311
312                 pick_timer(start, local_timing[8]);
313                 pick_timer(start, local_timing[9]);  // should be nearly nothing
314         }
315
316 #if ENABLE_TIMING
317         for (int timer_idx = 0; timer_idx < 10; ++timer_idx) {
318                 uint global_idx = thread_num * 10 + timer_idx;
319
320                 uint old_val = atomicAdd(timing[global_idx].x, local_timing[timer_idx].x);
321                 if (old_val + local_timing[timer_idx].x < old_val) {
322                         ++local_timing[timer_idx].y;
323                 }
324                 atomicAdd(timing[global_idx].y, local_timing[timer_idx].y);
325         }
326 #endif
327 }