]> git.sesse.net Git - narabu/blob - rans.shader
Make rans.shader write uint32s, shedding the GL_NV_gpu_shader5 demand.
[narabu] / rans.shader
1 #version 440
2
3 layout(local_size_x = 1) in;
4
5 const uint prob_bits = 12;
6 const uint prob_scale = 1 << prob_bits;
7 const uint RANS_BYTE_L = (1u << 23);
8 const uint BLOCKS_PER_STREAM = 320;
9 const uint STREAM_BUF_SIZE = 256;  // In uint32s. 1 kB per stream ought to be enough for everyone :-)
10 const uint NUM_SYMS = 256;
11 const uint ESCAPE_LIMIT = NUM_SYMS - 1;
12
13 #define MAPPING(s0, s1, s2, s3, s4, s5, s6, s7) ((s0) | (s1 << 2) | (s2 << 4) | (s3 << 6) | (s4 << 8) | (s5 << 10) | (s6 << 12) | (s7 << 14))
14
15 const uint luma_mapping[8] = {
16         MAPPING(0, 0, 1, 1, 2, 2, 3, 3),
17         MAPPING(0, 0, 1, 2, 2, 2, 3, 3),
18         MAPPING(1, 1, 2, 2, 2, 3, 3, 3),
19         MAPPING(1, 1, 2, 2, 2, 3, 3, 3),
20         MAPPING(1, 2, 2, 2, 2, 3, 3, 3),
21         MAPPING(2, 2, 2, 2, 3, 3, 3, 3),
22         MAPPING(2, 2, 3, 3, 3, 3, 3, 3),
23         MAPPING(3, 3, 3, 3, 3, 3, 3, 3),
24 };
25
26 layout(std430, binding = 10) buffer outputBuf
27 {
28         uint rans_output[];
29 };
30
31 layout(std430, binding = 11) buffer outputBuf2
32 {
33         uint rans_bytes_written[];
34 };
35
36 layout(std140, binding = 13) uniform DistBlock
37 {
38         uvec4 ransdist[4 * 256];
39         uint sign_biases[4];
40 };
41
42 struct RansEncoder {
43         uint stream_num;         // const
44         uint lut_base;           // const
45         uint sign_bias;          // const
46         uint rans_start_offset;  // const
47         uint rans_offset;
48         uint rans;
49         uint bit_buffer;
50         uint bytes_in_buffer;
51 };
52
53 layout(r16ui) uniform restrict readonly uimage2D dc_ac7_tex;
54 layout(r16ui) uniform restrict readonly uimage2D ac1_ac6_tex;
55 layout(r16ui) uniform restrict readonly uimage2D ac2_ac5_tex;
56 layout(r8i) uniform restrict readonly iimage2D ac3_tex;
57 layout(r8i) uniform restrict readonly iimage2D ac4_tex;
58
59 void RansPutByte(uint x, inout RansEncoder enc)
60 {
61         enc.bit_buffer = (enc.bit_buffer << 8) | x;
62         if (++enc.bytes_in_buffer == 4) {
63                 rans_output[--enc.rans_offset] = enc.bit_buffer;
64                 enc.bytes_in_buffer = 0;
65         }
66 }
67
68 void RansEncInit(uint streamgroup_num, uint coeff_row, uint coeff_col, uint dist_num, out RansEncoder enc)
69 {
70         enc.stream_num = streamgroup_num * 64 + coeff_row * 8 + coeff_col;
71         enc.lut_base = dist_num * 256;
72         enc.sign_bias = sign_biases[dist_num];
73         enc.rans_offset = enc.stream_num * STREAM_BUF_SIZE + STREAM_BUF_SIZE;  // Starts at the end.
74         enc.rans_start_offset = enc.rans_offset;
75         enc.rans = RANS_BYTE_L;
76         enc.bit_buffer = 0;
77         enc.bytes_in_buffer = 0;
78 }
79
80 void RansEncRenorm(inout RansEncoder enc, uint freq, uint prob_bits)
81 {
82         uint x_max = ((RANS_BYTE_L >> prob_bits) << 8) * freq; // this turns into a shift.
83         while (enc.rans >= x_max) {
84                 RansPutByte(enc.rans & 0xffu, enc);
85                 enc.rans >>= 8;
86         }
87 }
88
89 void RansEncPut(inout RansEncoder enc, uint start, uint freq, uint prob_bits)
90 {
91         RansEncRenorm(enc, freq, prob_bits);
92         enc.rans = ((enc.rans / freq) << prob_bits) + (enc.rans % freq) + start;
93 }
94
95 void RansEncPutSymbol(inout RansEncoder enc, uvec4 sym)
96 {
97         uint x_max = sym.x;
98         uint rcp_freq = sym.y;
99         uint bias = sym.z;
100         uint rcp_shift = (sym.w & 0xffffu);
101         uint cmpl_freq = (sym.w >> 16);
102
103         // renormalize
104         while (enc.rans >= x_max) {
105                 RansPutByte(enc.rans & 0xffu, enc);
106                 enc.rans >>= 8;
107         }
108
109         uint q, unused;
110         umulExtended(enc.rans, rcp_freq, q, unused);
111         enc.rans += bias + (q >> rcp_shift) * cmpl_freq;
112 }
113
114 uint RansEncFlush(inout RansEncoder enc)
115 {
116         RansPutByte(enc.rans >> 24, enc);
117         RansPutByte(enc.rans >> 16, enc);
118         RansPutByte(enc.rans >> 8, enc);
119         RansPutByte(enc.rans >> 0, enc);
120
121         uint num_bytes_written = (enc.rans_start_offset - enc.rans_offset) * 4 + enc.bytes_in_buffer;
122
123         // Make sure there's nothing left in the buffer.
124         RansPutByte(0, enc);
125         RansPutByte(0, enc);
126         RansPutByte(0, enc);
127
128         return num_bytes_written;
129 }
130
131 int sign_extend(uint coeff, uint bits)
132 {
133         return int(coeff << (32 - bits)) >> (32 - bits);
134 }
135
136 void encode_coeff(int signed_k, inout RansEncoder enc)
137 {
138         uint k = abs(signed_k);
139
140         if (k >= ESCAPE_LIMIT) {
141                 // Put the coefficient as a 1/(2^12) symbol _before_
142                 // the 255 coefficient, since the decoder will read the
143                 // 255 coefficient first.
144                 RansEncPut(enc, k, 1, prob_bits);
145                 k = ESCAPE_LIMIT;
146         }
147
148         uvec4 sym = ransdist[enc.lut_base + ((k - 1) & (NUM_SYMS - 1))];
149         RansEncPutSymbol(enc, sym);
150         
151         if (signed_k < 0) {
152                 enc.rans += enc.sign_bias;
153         }
154 }
155
156 void encode_end(inout RansEncoder enc)
157 {
158         uint bytes_written = RansEncFlush(enc);
159         rans_bytes_written[enc.stream_num] = bytes_written;
160 }
161
162 void encode_9_7(uint streamgroup_num, uint coeff_row, layout(r16ui) restrict readonly uimage2D tex, uint col1, uint col2, uint dist1, uint dist2)
163 {
164         RansEncoder enc1, enc2;
165         RansEncInit(streamgroup_num, coeff_row, col1, dist1, enc1);
166         RansEncInit(streamgroup_num, coeff_row, col2, dist2, enc2);
167
168         for (uint subblock_idx = 0; subblock_idx < BLOCKS_PER_STREAM; ++subblock_idx) {
169                 // TODO: Use SSBOs instead of a texture?
170                 uint x = (streamgroup_num * BLOCKS_PER_STREAM + subblock_idx) % 160;
171                 uint y = (streamgroup_num * BLOCKS_PER_STREAM + subblock_idx) / 160;
172                 uint f = imageLoad(tex, ivec2(x, y * 8 + coeff_row)).x;
173
174                 encode_coeff(sign_extend(f & 0x1ffu, 9), enc1);
175                 encode_coeff(sign_extend(f >> 9, 7), enc2);
176         }
177
178         encode_end(enc1);
179         encode_end(enc2);
180 }
181
182 void encode_8(uint streamgroup_num, uint coeff_row, layout(r8i) restrict readonly iimage2D tex, uint col, uint dist)
183 {
184         RansEncoder enc;
185         RansEncInit(streamgroup_num, coeff_row, col, dist, enc);
186
187         for (uint subblock_idx = 0; subblock_idx < BLOCKS_PER_STREAM; ++subblock_idx) {
188                 // TODO: Use SSBOs instead of a texture?
189                 uint x = (streamgroup_num * BLOCKS_PER_STREAM + subblock_idx) % 160;
190                 uint y = (streamgroup_num * BLOCKS_PER_STREAM + subblock_idx) / 160;
191                 int f = imageLoad(tex, ivec2(x, y * 8 + coeff_row)).x;
192
193                 encode_coeff(f, enc);
194         }
195
196         encode_end(enc);
197 }
198
199 void main()
200 {
201         uint streamgroup_num = gl_WorkGroupID.x;
202         uint coeff_row = gl_WorkGroupID.y;    // 0..7
203         uint coeff_colset = gl_WorkGroupID.z;   // 0 = dc+ac7, 1 = ac1+ac6, 2 = ac2+ac5, 3 = ac3, 4 = ac5
204         uint m = luma_mapping[coeff_row];
205
206         // TODO: DC coeff pred
207
208         if (coeff_colset == 0) {
209                 uint dist_dc = bitfieldExtract(m, 0, 2);
210                 uint dist_ac7 = bitfieldExtract(m, 14, 2);
211                 encode_9_7(streamgroup_num, coeff_row, dc_ac7_tex, 0, 7, dist_dc, dist_ac7);
212         } else if (coeff_colset == 1) {
213                 uint dist_ac1 = bitfieldExtract(m, 2, 2);
214                 uint dist_ac6 = bitfieldExtract(m, 12, 2);
215                 encode_9_7(streamgroup_num, coeff_row, ac1_ac6_tex, 1, 6, dist_ac1, dist_ac6);
216         } else if (coeff_colset == 2) {
217                 uint dist_ac2 = bitfieldExtract(m, 4, 2);
218                 uint dist_ac5 = bitfieldExtract(m, 10, 2);
219                 encode_9_7(streamgroup_num, coeff_row, ac2_ac5_tex, 2, 5, dist_ac2, dist_ac5);
220         } else if (coeff_colset == 3) {
221                 uint dist_ac3 = bitfieldExtract(m, 6, 2);
222                 encode_8(streamgroup_num, coeff_row, ac3_tex, 3, dist_ac3);
223         } else {
224                 uint dist_ac4 = bitfieldExtract(m, 8, 2);
225                 encode_8(streamgroup_num, coeff_row, ac4_tex, 4, dist_ac4);
226         }
227 }