X-Git-Url: https://git.sesse.net/?a=blobdiff_plain;f=decoder.shader;h=012cbe3bbac0ec13bb40b0cf42bd9ecc4bdcadf0;hb=3fb87c6b953be3382cd216c74ff6aa025c8eaa2a;hp=6d54e4dc5671173f540a1574666e51870a14b1c2;hpb=28409aed1a0cbf8d2e8d9d157d08c3f6d9a3f51a;p=narabu diff --git a/decoder.shader b/decoder.shader index 6d54e4d..012cbe3 100644 --- a/decoder.shader +++ b/decoder.shader @@ -1,17 +1,23 @@ -#version 430 +#version 440 #extension GL_ARB_shader_clock : enable +#define PARALLEL_SLICES 1 + #define ENABLE_TIMING 0 -layout(local_size_x = 8, local_size_y = 8) in; +layout(local_size_x = 64*PARALLEL_SLICES) in; layout(r8ui) uniform restrict readonly uimage2D cum2sym_tex; layout(rg16ui) uniform restrict readonly uimage2D dsyms_tex; layout(r8) uniform restrict writeonly image2D out_tex; +layout(r32i) uniform restrict writeonly iimage2D coeff_tex; +layout(r32i) uniform restrict writeonly iimage2D coeff2_tex; +uniform int num_blocks; const uint prob_bits = 12; const uint prob_scale = 1 << prob_bits; const uint NUM_SYMS = 256; const uint ESCAPE_LIMIT = NUM_SYMS - 1; +const uint BLOCKS_PER_STREAM = 320; // These need to be folded into quant_matrix. const float dc_scalefac = 8.0; @@ -37,6 +43,16 @@ const uint ff_zigzag_direct[64] = { 58, 59, 52, 45, 38, 31, 39, 46, 53, 60, 61, 54, 47, 55, 62, 63 }; +const uint stream_mapping[64] = { + 0, 0, 1, 1, 2, 2, 3, 3, + 0, 0, 1, 2, 2, 2, 3, 3, + 1, 1, 2, 2, 2, 3, 3, 3, + 1, 1, 2, 2, 2, 3, 3, 3, + 1, 2, 2, 2, 2, 3, 3, 3, + 2, 2, 2, 2, 3, 3, 3, 3, + 2, 2, 3, 3, 3, 3, 3, 3, + 3, 3, 3, 3, 3, 3, 3, 3, +}; layout(std430, binding = 9) buffer layoutName { @@ -48,57 +64,58 @@ layout(std430, binding = 10) buffer layoutName2 }; struct CoeffStream { - uint src_offset, src_len, sign_offset, sign_len, extra_bits; + uint src_offset, src_len; }; layout(std430, binding = 0) buffer whatever3 { CoeffStream streams[]; }; +uniform uint sign_bias_per_model[16]; -uniform uint src_offset, src_len, sign_offset, sign_len, extra_bits; - -const uint RANS_BYTE_L = (1u << 23); // lower bound of our normalization interval +struct myuint64 { + uint high, low; +}; -uint last_offset = -1, ransbuf; +const uint RANS64_L = (1u << 31); // lower bound of our normalization interval -uint get_rans_byte(uint offset) +myuint64 RansDecInit(inout uint offset) { - if (last_offset != (offset >> 2)) { - last_offset = offset >> 2; - ransbuf = data_SSBO[offset >> 2]; - } - return bitfieldExtract(ransbuf, 8 * int(offset & 3u), 8); + myuint64 x; + x.low = data_SSBO[offset++]; + x.high = data_SSBO[offset++]; + return x; +} - // We assume little endian. -// return bitfieldExtract(data_SSBO[offset >> 2], 8 * int(offset & 3u), 8); +uint RansDecGet(myuint64 r, uint scale_bits) +{ + return r.low & ((1u << scale_bits) - 1); } -void RansDecInit(out uint r, inout uint offset) +void RansDecAdvance(inout myuint64 rans, inout uint offset, const uint start, const uint freq, uint prob_bits) { - uint x; + const uint mask = (1u << prob_bits) - 1; + const uint recovered_lowbits = (rans.low & mask) - start; - x = get_rans_byte(offset); - x |= get_rans_byte(offset + 1) << 8; - x |= get_rans_byte(offset + 2) << 16; - x |= get_rans_byte(offset + 3) << 24; - offset += 4; + // rans >>= prob_bits; + rans.low = (rans.low >> prob_bits) | ((rans.high & mask) << (32 - prob_bits)); + rans.high >>= prob_bits; - r = x; -} + // rans *= freq; + uint h1, l1, h2, l2; + umulExtended(rans.low, freq, h1, l1); + umulExtended(rans.high, freq, h2, l2); + rans.low = l1; + rans.high = l2 + h1; -uint RansDecGet(uint r, uint scale_bits) -{ - return r & ((1u << scale_bits) - 1); -} + // rans += recovered_lowbits; + uint carry; + rans.low = uaddCarry(rans.low, recovered_lowbits, carry); + rans.high += carry; -void RansDecAdvance(inout uint rans, inout uint offset, const uint start, const uint freq, uint prob_bits) -{ - const uint mask = (1u << prob_bits) - 1; - rans = freq * (rans >> prob_bits) + (rans & mask) - start; - // renormalize - while (rans < RANS_BYTE_L) { - rans = (rans << 8) | get_rans_byte(offset++); + if (rans.high == 0 && rans.low < RANS64_L) { + rans.high = rans.low; + rans.low = data_SSBO[offset++]; } } @@ -163,7 +180,7 @@ void idct_1d(inout float y0, inout float y1, inout float y2, inout float y3, ino y7 = p6_0 - p6_7; } -shared float temp[64 * 8]; +shared float temp[64 * 8 * PARALLEL_SLICES]; void pick_timer(inout uvec2 start, inout uvec2 t) { @@ -194,70 +211,76 @@ void main() } uvec2 start = clock2x32ARB(); #else - uvec2 start; + uvec2 start = uvec2(0, 0); + local_timing[0] = start; #endif - const uint num_blocks = 720 / 16; // FIXME: make a uniform - const uint thread_num = gl_LocalInvocationID.y * 8 + gl_LocalInvocationID.x; + const uint blocks_per_row = (imageSize(out_tex).x + 7) / 8; + + const uint local_x = gl_LocalInvocationID.x % 8; + const uint local_y = (gl_LocalInvocationID.x / 8) % 8; + const uint local_z = gl_LocalInvocationID.x / 64; + + const uint slice_num = local_z; + const uint thread_num = local_y * 8 + local_x; - const uint block_row = gl_WorkGroupID.y; + const uint block_row = gl_WorkGroupID.y * PARALLEL_SLICES + slice_num; //const uint coeff_num = ff_zigzag_direct[thread_num]; const uint coeff_num = thread_num; const uint stream_num = coeff_num * num_blocks + block_row; - //const uint stream_num = block_row * num_blocks + coeff_num; // HACK - const uint model_num = min((coeff_num % 8) + (coeff_num / 8), 7); + const uint model_num = stream_mapping[coeff_num]; + const uint sign_bias = sign_bias_per_model[model_num]; // Initialize rANS decoder. - uint offset = streams[stream_num].src_offset; - uint rans; - RansDecInit(rans, offset); - - // Initialize sign bit decoder. TODO: this ought to be 32-bit-aligned instead! - uint soffset = streams[stream_num].sign_offset; - uint sign_buf = get_rans_byte(soffset++) >> streams[stream_num].extra_bits; - uint sign_bits_left = 8 - streams[stream_num].extra_bits; + uint offset = streams[stream_num].src_offset >> 2; + myuint64 rans = RansDecInit(offset); float q = (coeff_num == 0) ? 1.0 : (quant_matrix[coeff_num] * quant_scalefac / 128.0 / sqrt(2.0)); // FIXME: fold q *= (1.0 / 255.0); //int w = (coeff_num == 0) ? 32 : int(quant_matrix[coeff_num]); - int last_k = 0; + int last_k = 128; pick_timer(start, local_timing[0]); - for (uint block_idx = 40; block_idx --> 0; ) { - uint block_x = block_idx % 20; - uint block_y = block_idx / 20; - if (block_x == 19) last_k = 0; - + for (uint block_idx = BLOCKS_PER_STREAM / 8; block_idx --> 0; ) { pick_timer(start, local_timing[1]); // rANS decode one coefficient across eight blocks (so 64x8 coefficients). for (uint subblock_idx = 8; subblock_idx --> 0; ) { // Read a symbol. - int k = int(cum2sym(RansDecGet(rans, prob_bits), model_num)); + uint bottom_bits = RansDecGet(rans, prob_bits + 1); + bool sign = false; + if (bottom_bits >= sign_bias) { + bottom_bits -= sign_bias; + rans.low -= sign_bias; + sign = true; + } + int k = int(cum2sym(bottom_bits, model_num)); // Can go out-of-bounds; that will return zero. uvec2 sym = get_dsym(k, model_num); - RansDecAdvance(rans, offset, sym.x, sym.y, prob_bits); + RansDecAdvance(rans, offset, sym.x, sym.y, prob_bits + 1); if (k == ESCAPE_LIMIT) { k = int(RansDecGet(rans, prob_bits)); RansDecAdvance(rans, offset, k, 1, prob_bits); } - if (k != 0) { - if (sign_bits_left == 0) { - sign_buf = get_rans_byte(soffset++); - sign_bits_left = 8; - } - if ((sign_buf & 1u) == 1u) k = -k; - --sign_bits_left; - sign_buf >>= 1; + if (sign) { + k = -k; } +#if 0 + if (coeff_num == 0) { + //imageStore(coeff_tex, ivec2((block_row * 40 + block_idx) * 8 + subblock_idx, 0), ivec4(k, 0,0,0)); + imageStore(coeff_tex, ivec2((block_row * 40 + block_idx) * 8 + subblock_idx, 0), ivec4(rans.low, 0,0,0)); + imageStore(coeff2_tex, ivec2((block_row * 40 + block_idx) * 8 + subblock_idx, 0), ivec4(rans.high, 0,0,0)); + } +#endif if (coeff_num == 0) { k += last_k; last_k = k; } - temp[subblock_idx * 64 + coeff_num] = k * q; + + temp[slice_num * 64 * 8 + subblock_idx * 64 + coeff_num] = k * q; //temp[subblock_idx * 64 + 8 * y + x] = (2 * k * w * 4) / 32; // 100% matching unquant } @@ -269,14 +292,14 @@ void main() pick_timer(start, local_timing[3]); // Horizontal DCT one row (so 64 rows). - idct_1d(temp[thread_num * 8 + 0], - temp[thread_num * 8 + 1], - temp[thread_num * 8 + 2], - temp[thread_num * 8 + 3], - temp[thread_num * 8 + 4], - temp[thread_num * 8 + 5], - temp[thread_num * 8 + 6], - temp[thread_num * 8 + 7]); + idct_1d(temp[slice_num * 64 * 8 + thread_num * 8 + 0], + temp[slice_num * 64 * 8 + thread_num * 8 + 1], + temp[slice_num * 64 * 8 + thread_num * 8 + 2], + temp[slice_num * 64 * 8 + thread_num * 8 + 3], + temp[slice_num * 64 * 8 + thread_num * 8 + 4], + temp[slice_num * 64 * 8 + thread_num * 8 + 5], + temp[slice_num * 64 * 8 + thread_num * 8 + 6], + temp[slice_num * 64 * 8 + thread_num * 8 + 7]); pick_timer(start, local_timing[4]); @@ -286,7 +309,7 @@ void main() pick_timer(start, local_timing[5]); // Vertical DCT one row (so 64 columns). - uint row_offset = gl_LocalInvocationID.y * 64 + gl_LocalInvocationID.x; + uint row_offset = local_z * 64 * 8 + local_y * 64 + local_x; idct_1d(temp[row_offset + 0 * 8], temp[row_offset + 1 * 8], temp[row_offset + 2 * 8], @@ -298,8 +321,12 @@ void main() pick_timer(start, local_timing[6]); - uint y = block_row * 16 + block_y * 8; - uint x = block_x * 64 + gl_LocalInvocationID.y * 8 + gl_LocalInvocationID.x; + uint global_block_idx = (block_row * 40 + block_idx) * 8 + local_y; + uint block_x = global_block_idx % blocks_per_row; + uint block_y = global_block_idx / blocks_per_row; + + uint y = block_y * 8; + uint x = block_x * 8 + local_x; for (uint yl = 0; yl < 8; ++yl) { imageStore(out_tex, ivec2(x, yl + y), vec4(temp[row_offset + yl * 8], 0.0, 0.0, 1.0)); }