]> git.sesse.net Git - narabu/blobdiff - decoder.shader
Switch to 64-bit rANS, although probably due for immediate revert (just want to prese...
[narabu] / decoder.shader
index 6d54e4dc5671173f540a1574666e51870a14b1c2..012cbe3bbac0ec13bb40b0cf42bd9ecc4bdcadf0 100644 (file)
@@ -1,17 +1,23 @@
-#version 430
+#version 440
 #extension GL_ARB_shader_clock : enable
 
+#define PARALLEL_SLICES 1
+
 #define ENABLE_TIMING 0
 
-layout(local_size_x = 8, local_size_y = 8) in;
+layout(local_size_x = 64*PARALLEL_SLICES) in;
 layout(r8ui) uniform restrict readonly uimage2D cum2sym_tex;
 layout(rg16ui) uniform restrict readonly uimage2D dsyms_tex;
 layout(r8) uniform restrict writeonly image2D out_tex;
+layout(r32i) uniform restrict writeonly iimage2D coeff_tex;
+layout(r32i) uniform restrict writeonly iimage2D coeff2_tex;
+uniform int num_blocks;
 
 const uint prob_bits = 12;
 const uint prob_scale = 1 << prob_bits;
 const uint NUM_SYMS = 256;
 const uint ESCAPE_LIMIT = NUM_SYMS - 1;
+const uint BLOCKS_PER_STREAM = 320;
 
 // These need to be folded into quant_matrix.
 const float dc_scalefac = 8.0;
@@ -37,6 +43,16 @@ const uint ff_zigzag_direct[64] = {
     58, 59, 52, 45, 38, 31, 39, 46,
     53, 60, 61, 54, 47, 55, 62, 63
 };
+const uint stream_mapping[64] = {
+       0, 0, 1, 1, 2, 2, 3, 3,
+       0, 0, 1, 2, 2, 2, 3, 3,
+       1, 1, 2, 2, 2, 3, 3, 3,
+       1, 1, 2, 2, 2, 3, 3, 3,
+       1, 2, 2, 2, 2, 3, 3, 3,
+       2, 2, 2, 2, 3, 3, 3, 3,
+       2, 2, 3, 3, 3, 3, 3, 3,
+       3, 3, 3, 3, 3, 3, 3, 3,
+};
 
 layout(std430, binding = 9) buffer layoutName
 {
@@ -48,57 +64,58 @@ layout(std430, binding = 10) buffer layoutName2
 };
 
 struct CoeffStream {
-       uint src_offset, src_len, sign_offset, sign_len, extra_bits;
+       uint src_offset, src_len;
 };
 layout(std430, binding = 0) buffer whatever3
 {
        CoeffStream streams[];
 };
+uniform uint sign_bias_per_model[16];
 
-uniform uint src_offset, src_len, sign_offset, sign_len, extra_bits;
-
-const uint RANS_BYTE_L = (1u << 23);  // lower bound of our normalization interval
+struct myuint64 {
+       uint high, low;
+};
 
-uint last_offset = -1, ransbuf;
+const uint RANS64_L = (1u << 31);  // lower bound of our normalization interval
 
-uint get_rans_byte(uint offset)
+myuint64 RansDecInit(inout uint offset)
 {
-       if (last_offset != (offset >> 2)) {
-               last_offset = offset >> 2;
-               ransbuf = data_SSBO[offset >> 2];
-       }
-       return bitfieldExtract(ransbuf, 8 * int(offset & 3u), 8);
+       myuint64 x;
+       x.low  = data_SSBO[offset++];
+       x.high = data_SSBO[offset++];
+       return x;
+}
 
-       // We assume little endian.
-//     return bitfieldExtract(data_SSBO[offset >> 2], 8 * int(offset & 3u), 8);
+uint RansDecGet(myuint64 r, uint scale_bits)
+{
+       return r.low & ((1u << scale_bits) - 1);
 }
 
-void RansDecInit(out uint r, inout uint offset)
+void RansDecAdvance(inout myuint64 rans, inout uint offset, const uint start, const uint freq, uint prob_bits)
 {
-       uint x;
+       const uint mask = (1u << prob_bits) - 1;
+       const uint recovered_lowbits = (rans.low & mask) - start;
 
-       x  = get_rans_byte(offset);
-       x |= get_rans_byte(offset + 1) << 8;
-       x |= get_rans_byte(offset + 2) << 16;
-       x |= get_rans_byte(offset + 3) << 24;
-       offset += 4;
+       // rans >>= prob_bits;
+       rans.low = (rans.low >> prob_bits) | ((rans.high & mask) << (32 - prob_bits));
+       rans.high >>= prob_bits;
 
-       r = x;
-}
+       // rans *= freq;
+       uint h1, l1, h2, l2;
+       umulExtended(rans.low, freq, h1, l1);
+       umulExtended(rans.high, freq, h2, l2);
+       rans.low = l1;
+       rans.high = l2 + h1;
 
-uint RansDecGet(uint r, uint scale_bits)
-{
-       return r & ((1u << scale_bits) - 1);
-}
+       // rans += recovered_lowbits;
+       uint carry;
+       rans.low = uaddCarry(rans.low, recovered_lowbits, carry);
+       rans.high += carry;
 
-void RansDecAdvance(inout uint rans, inout uint offset, const uint start, const uint freq, uint prob_bits)
-{
-       const uint mask = (1u << prob_bits) - 1;
-       rans = freq * (rans >> prob_bits) + (rans & mask) - start;
-       
        // renormalize
-       while (rans < RANS_BYTE_L) {
-               rans = (rans << 8) | get_rans_byte(offset++);
+       if (rans.high == 0 && rans.low < RANS64_L) {
+               rans.high = rans.low;
+               rans.low = data_SSBO[offset++];
        }
 }
 
@@ -163,7 +180,7 @@ void idct_1d(inout float y0, inout float y1, inout float y2, inout float y3, ino
        y7 = p6_0 - p6_7;
 }
 
-shared float temp[64 * 8];
+shared float temp[64 * 8 * PARALLEL_SLICES];
 
 void pick_timer(inout uvec2 start, inout uvec2 t)
 {
@@ -194,70 +211,76 @@ void main()
        }
        uvec2 start = clock2x32ARB();
 #else
-       uvec2 start;
+       uvec2 start = uvec2(0, 0);
+       local_timing[0] = start;
 #endif
 
-       const uint num_blocks = 720 / 16;  // FIXME: make a uniform
-       const uint thread_num = gl_LocalInvocationID.y * 8 + gl_LocalInvocationID.x;
+       const uint blocks_per_row = (imageSize(out_tex).x + 7) / 8;
+
+       const uint local_x = gl_LocalInvocationID.x % 8;
+       const uint local_y = (gl_LocalInvocationID.x / 8) % 8;
+       const uint local_z = gl_LocalInvocationID.x / 64;
+
+       const uint slice_num = local_z;
+       const uint thread_num = local_y * 8 + local_x;
 
-       const uint block_row = gl_WorkGroupID.y;
+       const uint block_row = gl_WorkGroupID.y * PARALLEL_SLICES + slice_num;
        //const uint coeff_num = ff_zigzag_direct[thread_num];
        const uint coeff_num = thread_num;
        const uint stream_num = coeff_num * num_blocks + block_row;
-       //const uint stream_num = block_row * num_blocks + coeff_num;  // HACK
-       const uint model_num = min((coeff_num % 8) + (coeff_num / 8), 7);
+       const uint model_num = stream_mapping[coeff_num];
+       const uint sign_bias = sign_bias_per_model[model_num];
 
        // Initialize rANS decoder.
-       uint offset = streams[stream_num].src_offset;
-       uint rans;
-       RansDecInit(rans, offset);
-
-       // Initialize sign bit decoder. TODO: this ought to be 32-bit-aligned instead!
-       uint soffset = streams[stream_num].sign_offset;
-       uint sign_buf = get_rans_byte(soffset++) >> streams[stream_num].extra_bits;
-       uint sign_bits_left = 8 - streams[stream_num].extra_bits;
+       uint offset = streams[stream_num].src_offset >> 2;
+       myuint64 rans = RansDecInit(offset);
 
        float q = (coeff_num == 0) ? 1.0 : (quant_matrix[coeff_num] * quant_scalefac / 128.0 / sqrt(2.0));  // FIXME: fold
        q *= (1.0 / 255.0);
        //int w = (coeff_num == 0) ? 32 : int(quant_matrix[coeff_num]);
-       int last_k = 0;
+       int last_k = 128;
 
        pick_timer(start, local_timing[0]);
 
-       for (uint block_idx = 40; block_idx --> 0; ) {
-               uint block_x = block_idx % 20;
-               uint block_y = block_idx / 20;
-               if (block_x == 19) last_k = 0;
-
+       for (uint block_idx = BLOCKS_PER_STREAM / 8; block_idx --> 0; ) {
                pick_timer(start, local_timing[1]);
 
                // rANS decode one coefficient across eight blocks (so 64x8 coefficients).
                for (uint subblock_idx = 8; subblock_idx --> 0; ) {
                        // Read a symbol.
-                       int k = int(cum2sym(RansDecGet(rans, prob_bits), model_num));
+                       uint bottom_bits = RansDecGet(rans, prob_bits + 1);
+                       bool sign = false;
+                       if (bottom_bits >= sign_bias) {
+                               bottom_bits -= sign_bias;
+                               rans.low -= sign_bias;
+                               sign = true;
+                       }
+                       int k = int(cum2sym(bottom_bits, model_num));  // Can go out-of-bounds; that will return zero.
                        uvec2 sym = get_dsym(k, model_num);
-                       RansDecAdvance(rans, offset, sym.x, sym.y, prob_bits);
+                       RansDecAdvance(rans, offset, sym.x, sym.y, prob_bits + 1);
 
                        if (k == ESCAPE_LIMIT) {
                                k = int(RansDecGet(rans, prob_bits));
                                RansDecAdvance(rans, offset, k, 1, prob_bits);
                        }
-                       if (k != 0) {
-                               if (sign_bits_left == 0) {
-                                       sign_buf = get_rans_byte(soffset++);
-                                       sign_bits_left = 8;
-                               }
-                               if ((sign_buf & 1u) == 1u) k = -k;
-                               --sign_bits_left;
-                               sign_buf >>= 1;
+                       if (sign) {
+                               k = -k;
                        }
+#if 0
+                       if (coeff_num == 0) {
+                               //imageStore(coeff_tex, ivec2((block_row * 40 + block_idx) * 8 + subblock_idx, 0), ivec4(k, 0,0,0));
+                               imageStore(coeff_tex, ivec2((block_row * 40 + block_idx) * 8 + subblock_idx, 0), ivec4(rans.low, 0,0,0));
+                               imageStore(coeff2_tex, ivec2((block_row * 40 + block_idx) * 8 + subblock_idx, 0), ivec4(rans.high, 0,0,0));
+                       }
+#endif
 
                        if (coeff_num == 0) {
                                k += last_k;
                                last_k = k;
                        }
 
-                       temp[subblock_idx * 64 + coeff_num] = k * q;
+
+                       temp[slice_num * 64 * 8 + subblock_idx * 64 + coeff_num] = k * q;
                        //temp[subblock_idx * 64 + 8 * y + x] = (2 * k * w * 4) / 32;  // 100% matching unquant
                }
 
@@ -269,14 +292,14 @@ void main()
                pick_timer(start, local_timing[3]);
 
                // Horizontal DCT one row (so 64 rows).
-               idct_1d(temp[thread_num * 8 + 0],
-                       temp[thread_num * 8 + 1],
-                       temp[thread_num * 8 + 2],
-                       temp[thread_num * 8 + 3],
-                       temp[thread_num * 8 + 4],
-                       temp[thread_num * 8 + 5],
-                       temp[thread_num * 8 + 6],
-                       temp[thread_num * 8 + 7]);
+               idct_1d(temp[slice_num * 64 * 8 + thread_num * 8 + 0],
+                       temp[slice_num * 64 * 8 + thread_num * 8 + 1],
+                       temp[slice_num * 64 * 8 + thread_num * 8 + 2],
+                       temp[slice_num * 64 * 8 + thread_num * 8 + 3],
+                       temp[slice_num * 64 * 8 + thread_num * 8 + 4],
+                       temp[slice_num * 64 * 8 + thread_num * 8 + 5],
+                       temp[slice_num * 64 * 8 + thread_num * 8 + 6],
+                       temp[slice_num * 64 * 8 + thread_num * 8 + 7]);
 
                pick_timer(start, local_timing[4]);
 
@@ -286,7 +309,7 @@ void main()
                pick_timer(start, local_timing[5]);
 
                // Vertical DCT one row (so 64 columns).
-               uint row_offset = gl_LocalInvocationID.y * 64 + gl_LocalInvocationID.x;
+               uint row_offset = local_z * 64 * 8 + local_y * 64 + local_x;
                idct_1d(temp[row_offset + 0 * 8],
                        temp[row_offset + 1 * 8],
                        temp[row_offset + 2 * 8],
@@ -298,8 +321,12 @@ void main()
 
                pick_timer(start, local_timing[6]);
 
-               uint y = block_row * 16 + block_y * 8;
-               uint x = block_x * 64 + gl_LocalInvocationID.y * 8 + gl_LocalInvocationID.x;
+               uint global_block_idx = (block_row * 40 + block_idx) * 8 + local_y;
+               uint block_x = global_block_idx % blocks_per_row;
+               uint block_y = global_block_idx / blocks_per_row;
+
+               uint y = block_y * 8;
+               uint x = block_x * 8 + local_x;
                for (uint yl = 0; yl < 8; ++yl) {
                        imageStore(out_tex, ivec2(x, yl + y), vec4(temp[row_offset + yl * 8], 0.0, 0.0, 1.0));
                }