]> git.sesse.net Git - narabu/blobdiff - decoder.shader
Switch to 64-bit rANS, although probably due for immediate revert (just want to prese...
[narabu] / decoder.shader
index 3b721264b4026f5a6b9cbb6e9c6a5749b15bd30e..012cbe3bbac0ec13bb40b0cf42bd9ecc4bdcadf0 100644 (file)
@@ -9,7 +9,8 @@ layout(local_size_x = 64*PARALLEL_SLICES) in;
 layout(r8ui) uniform restrict readonly uimage2D cum2sym_tex;
 layout(rg16ui) uniform restrict readonly uimage2D dsyms_tex;
 layout(r8) uniform restrict writeonly image2D out_tex;
-layout(r16i) uniform restrict writeonly iimage2D coeff_tex;
+layout(r32i) uniform restrict writeonly iimage2D coeff_tex;
+layout(r32i) uniform restrict writeonly iimage2D coeff2_tex;
 uniform int num_blocks;
 
 const uint prob_bits = 12;
@@ -71,40 +72,50 @@ layout(std430, binding = 0) buffer whatever3
 };
 uniform uint sign_bias_per_model[16];
 
-const uint RANS_BYTE_L = (1u << 23);  // lower bound of our normalization interval
+struct myuint64 {
+       uint high, low;
+};
 
-uint get_rans_byte(uint offset)
-{
-       // We assume little endian.
-       return bitfieldExtract(data_SSBO[offset >> 2], 8 * int(offset & 3u), 8);
-}
+const uint RANS64_L = (1u << 31);  // lower bound of our normalization interval
 
-uint RansDecInit(inout uint offset)
+myuint64 RansDecInit(inout uint offset)
 {
-       uint x;
-
-       x  = get_rans_byte(offset);
-       x |= get_rans_byte(offset + 1) << 8;
-       x |= get_rans_byte(offset + 2) << 16;
-       x |= get_rans_byte(offset + 3) << 24;
-       offset += 4;
-
+       myuint64 x;
+       x.low  = data_SSBO[offset++];
+       x.high = data_SSBO[offset++];
        return x;
 }
 
-uint RansDecGet(uint r, uint scale_bits)
+uint RansDecGet(myuint64 r, uint scale_bits)
 {
-       return r & ((1u << scale_bits) - 1);
+       return r.low & ((1u << scale_bits) - 1);
 }
 
-void RansDecAdvance(inout uint rans, inout uint offset, const uint start, const uint freq, uint prob_bits)
+void RansDecAdvance(inout myuint64 rans, inout uint offset, const uint start, const uint freq, uint prob_bits)
 {
        const uint mask = (1u << prob_bits) - 1;
-       rans = freq * (rans >> prob_bits) + (rans & mask) - start;
-       
+       const uint recovered_lowbits = (rans.low & mask) - start;
+
+       // rans >>= prob_bits;
+       rans.low = (rans.low >> prob_bits) | ((rans.high & mask) << (32 - prob_bits));
+       rans.high >>= prob_bits;
+
+       // rans *= freq;
+       uint h1, l1, h2, l2;
+       umulExtended(rans.low, freq, h1, l1);
+       umulExtended(rans.high, freq, h2, l2);
+       rans.low = l1;
+       rans.high = l2 + h1;
+
+       // rans += recovered_lowbits;
+       uint carry;
+       rans.low = uaddCarry(rans.low, recovered_lowbits, carry);
+       rans.high += carry;
+
        // renormalize
-       while (rans < RANS_BYTE_L) {
-               rans = (rans << 8) | get_rans_byte(offset++);
+       if (rans.high == 0 && rans.low < RANS64_L) {
+               rans.high = rans.low;
+               rans.low = data_SSBO[offset++];
        }
 }
 
@@ -221,8 +232,8 @@ void main()
        const uint sign_bias = sign_bias_per_model[model_num];
 
        // Initialize rANS decoder.
-       uint offset = streams[stream_num].src_offset;
-       uint rans = RansDecInit(offset);
+       uint offset = streams[stream_num].src_offset >> 2;
+       myuint64 rans = RansDecInit(offset);
 
        float q = (coeff_num == 0) ? 1.0 : (quant_matrix[coeff_num] * quant_scalefac / 128.0 / sqrt(2.0));  // FIXME: fold
        q *= (1.0 / 255.0);
@@ -241,7 +252,7 @@ void main()
                        bool sign = false;
                        if (bottom_bits >= sign_bias) {
                                bottom_bits -= sign_bias;
-                               rans -= sign_bias;
+                               rans.low -= sign_bias;
                                sign = true;
                        }
                        int k = int(cum2sym(bottom_bits, model_num));  // Can go out-of-bounds; that will return zero.
@@ -255,17 +266,19 @@ void main()
                        if (sign) {
                                k = -k;
                        }
+#if 0
+                       if (coeff_num == 0) {
+                               //imageStore(coeff_tex, ivec2((block_row * 40 + block_idx) * 8 + subblock_idx, 0), ivec4(k, 0,0,0));
+                               imageStore(coeff_tex, ivec2((block_row * 40 + block_idx) * 8 + subblock_idx, 0), ivec4(rans.low, 0,0,0));
+                               imageStore(coeff2_tex, ivec2((block_row * 40 + block_idx) * 8 + subblock_idx, 0), ivec4(rans.high, 0,0,0));
+                       }
+#endif
 
                        if (coeff_num == 0) {
                                k += last_k;
                                last_k = k;
                        }
 
-#if 0
-                       uint y = block_row * 16 + block_y * 8 + local_y;
-                       uint x = block_x * 64 + subblock_idx * 8 + local_x;
-                       imageStore(coeff_tex, ivec2(x, y), ivec4(k, 0,0,0));
-#endif
 
                        temp[slice_num * 64 * 8 + subblock_idx * 64 + coeff_num] = k * q;
                        //temp[subblock_idx * 64 + 8 * y + x] = (2 * k * w * 4) / 32;  // 100% matching unquant