]> git.sesse.net Git - narabu/blob - rans.shader
More fixes of hard-coded values.
[narabu] / rans.shader
1 #version 440
2
3 layout(local_size_x = 1) in;
4
5 const uint prob_bits = 12;
6 const uint prob_scale = 1 << prob_bits;
7 const uint RANS_BYTE_L = (1u << 23);
8 const uint BLOCKS_PER_STREAM = 320;
9 const uint STREAM_BUF_SIZE = 256;  // In uint32s. 1 kB per stream ought to be enough for everyone :-)
10 const uint NUM_SYMS = 256;
11 const uint ESCAPE_LIMIT = NUM_SYMS - 1;
12 const uint WIDTH_BLOCKS = 160;  // TODO: send in as a uniform.
13
14 #define MAPPING(s0, s1, s2, s3, s4, s5, s6, s7) ((s0) | (s1 << 2) | (s2 << 4) | (s3 << 6) | (s4 << 8) | (s5 << 10) | (s6 << 12) | (s7 << 14))
15
16 const uint luma_mapping[8] = {
17         MAPPING(0, 0, 1, 1, 2, 2, 3, 3),
18         MAPPING(0, 0, 1, 2, 2, 2, 3, 3),
19         MAPPING(1, 1, 2, 2, 2, 3, 3, 3),
20         MAPPING(1, 1, 2, 2, 2, 3, 3, 3),
21         MAPPING(1, 2, 2, 2, 2, 3, 3, 3),
22         MAPPING(2, 2, 2, 2, 3, 3, 3, 3),
23         MAPPING(2, 2, 3, 3, 3, 3, 3, 3),
24         MAPPING(3, 3, 3, 3, 3, 3, 3, 3),
25 };
26
27 layout(std430, binding = 10) buffer outputBuf
28 {
29         uint rans_output[];
30 };
31
32 layout(std430, binding = 11) buffer outputBuf2
33 {
34         uint rans_bytes_written[];
35 };
36
37 layout(std140, binding = 13) uniform DistBlock
38 {
39         uvec4 ransdist[4 * 256];
40         uint sign_biases[4];
41 };
42
43 struct RansEncoder {
44         uint stream_num;         // const
45         uint lut_base;           // const
46         uint sign_bias;          // const
47         uint rans_start_offset;  // const
48         uint rans_offset;
49         uint rans;
50         uint bit_buffer;
51         uint bytes_in_buffer;
52 };
53
54 layout(r16ui) uniform restrict readonly uimage2D dc_ac7_tex;
55 layout(r16ui) uniform restrict readonly uimage2D ac1_ac6_tex;
56 layout(r16ui) uniform restrict readonly uimage2D ac2_ac5_tex;
57 layout(r8i) uniform restrict readonly iimage2D ac3_tex;
58 layout(r8i) uniform restrict readonly iimage2D ac4_tex;
59
60 void RansPutByte(uint x, inout RansEncoder enc)
61 {
62         enc.bit_buffer = (enc.bit_buffer << 8) | x;
63         if (++enc.bytes_in_buffer == 4) {
64                 rans_output[--enc.rans_offset] = enc.bit_buffer;
65                 enc.bytes_in_buffer = 0;
66         }
67 }
68
69 void RansEncInit(uint streamgroup_num, uint coeff_row, uint coeff_col, uint dist_num, out RansEncoder enc)
70 {
71         enc.stream_num = streamgroup_num * 64 + coeff_row * 8 + coeff_col;
72         enc.lut_base = dist_num * 256;
73         enc.sign_bias = sign_biases[dist_num];
74         enc.rans_offset = enc.stream_num * STREAM_BUF_SIZE + STREAM_BUF_SIZE;  // Starts at the end.
75         enc.rans_start_offset = enc.rans_offset;
76         enc.rans = RANS_BYTE_L;
77         enc.bit_buffer = 0;
78         enc.bytes_in_buffer = 0;
79 }
80
81 void RansEncRenorm(inout RansEncoder enc, uint freq, uint prob_bits)
82 {
83         uint x_max = ((RANS_BYTE_L >> prob_bits) << 8) * freq; // this turns into a shift.
84         while (enc.rans >= x_max) {
85                 RansPutByte(enc.rans & 0xffu, enc);
86                 enc.rans >>= 8;
87         }
88 }
89
90 void RansEncPut(inout RansEncoder enc, uint start, uint freq, uint prob_bits)
91 {
92         RansEncRenorm(enc, freq, prob_bits);
93         enc.rans = ((enc.rans / freq) << prob_bits) + (enc.rans % freq) + start;
94 }
95
96 void RansEncPutSymbol(inout RansEncoder enc, uvec4 sym)
97 {
98         uint x_max = sym.x;
99         uint rcp_freq = sym.y;
100         uint bias = sym.z;
101         uint rcp_shift = (sym.w & 0xffffu);
102         uint cmpl_freq = (sym.w >> 16);
103
104         // renormalize
105         while (enc.rans >= x_max) {
106                 RansPutByte(enc.rans & 0xffu, enc);
107                 enc.rans >>= 8;
108         }
109
110         uint q, unused;
111         umulExtended(enc.rans, rcp_freq, q, unused);
112         enc.rans += bias + (q >> rcp_shift) * cmpl_freq;
113 }
114
115 uint RansEncFlush(inout RansEncoder enc)
116 {
117         RansPutByte(enc.rans >> 24, enc);
118         RansPutByte(enc.rans >> 16, enc);
119         RansPutByte(enc.rans >> 8, enc);
120         RansPutByte(enc.rans >> 0, enc);
121
122         uint num_bytes_written = (enc.rans_start_offset - enc.rans_offset) * 4 + enc.bytes_in_buffer;
123
124         // Make sure there's nothing left in the buffer.
125         RansPutByte(0, enc);
126         RansPutByte(0, enc);
127         RansPutByte(0, enc);
128
129         return num_bytes_written;
130 }
131
132 int sign_extend(uint coeff, uint bits)
133 {
134         return int(coeff << (32 - bits)) >> (32 - bits);
135 }
136
137 void encode_coeff(int signed_k, inout RansEncoder enc)
138 {
139         uint k = abs(signed_k);
140
141         if (k >= ESCAPE_LIMIT) {
142                 // Put the coefficient as a 1/(2^12) symbol _before_
143                 // the 255 coefficient, since the decoder will read the
144                 // 255 coefficient first.
145                 RansEncPut(enc, k, 1, prob_bits);
146                 k = ESCAPE_LIMIT;
147         }
148
149         uvec4 sym = ransdist[enc.lut_base + ((k - 1) & (NUM_SYMS - 1))];
150         RansEncPutSymbol(enc, sym);
151         
152         if (signed_k < 0) {
153                 enc.rans += enc.sign_bias;
154         }
155 }
156
157 void encode_end(inout RansEncoder enc)
158 {
159         uint bytes_written = RansEncFlush(enc);
160         rans_bytes_written[enc.stream_num] = bytes_written;
161 }
162
163 void encode_9_7(uint streamgroup_num, uint coeff_row, layout(r16ui) restrict readonly uimage2D tex, uint col1, uint col2, uint dist1, uint dist2)
164 {
165         RansEncoder enc1, enc2;
166         RansEncInit(streamgroup_num, coeff_row, col1, dist1, enc1);
167         RansEncInit(streamgroup_num, coeff_row, col2, dist2, enc2);
168
169         for (uint subblock_idx = 0; subblock_idx < BLOCKS_PER_STREAM; ++subblock_idx) {
170                 // TODO: Use SSBOs instead of a texture?
171                 uint x = (streamgroup_num * BLOCKS_PER_STREAM + subblock_idx) % WIDTH_BLOCKS;
172                 uint y = (streamgroup_num * BLOCKS_PER_STREAM + subblock_idx) / WIDTH_BLOCKS;
173                 uint f = imageLoad(tex, ivec2(x, y * 8 + coeff_row)).x;
174
175                 encode_coeff(sign_extend(f & 0x1ffu, 9), enc1);
176                 encode_coeff(sign_extend(f >> 9, 7), enc2);
177         }
178
179         encode_end(enc1);
180         encode_end(enc2);
181 }
182
183 void encode_8(uint streamgroup_num, uint coeff_row, layout(r8i) restrict readonly iimage2D tex, uint col, uint dist)
184 {
185         RansEncoder enc;
186         RansEncInit(streamgroup_num, coeff_row, col, dist, enc);
187
188         for (uint subblock_idx = 0; subblock_idx < BLOCKS_PER_STREAM; ++subblock_idx) {
189                 // TODO: Use SSBOs instead of a texture?
190                 uint x = (streamgroup_num * BLOCKS_PER_STREAM + subblock_idx) % WIDTH_BLOCKS;
191                 uint y = (streamgroup_num * BLOCKS_PER_STREAM + subblock_idx) / WIDTH_BLOCKS;
192                 int f = imageLoad(tex, ivec2(x, y * 8 + coeff_row)).x;
193
194                 encode_coeff(f, enc);
195         }
196
197         encode_end(enc);
198 }
199
200 void main()
201 {
202         uint streamgroup_num = gl_WorkGroupID.x;
203         uint coeff_row = gl_WorkGroupID.y;    // 0..7
204         uint coeff_colset = gl_WorkGroupID.z;   // 0 = dc+ac7, 1 = ac1+ac6, 2 = ac2+ac5, 3 = ac3, 4 = ac5
205         uint m = luma_mapping[coeff_row];
206
207         // TODO: DC coeff pred
208
209         if (coeff_colset == 0) {
210                 uint dist_dc = bitfieldExtract(m, 0, 2);
211                 uint dist_ac7 = bitfieldExtract(m, 14, 2);
212                 encode_9_7(streamgroup_num, coeff_row, dc_ac7_tex, 0, 7, dist_dc, dist_ac7);
213         } else if (coeff_colset == 1) {
214                 uint dist_ac1 = bitfieldExtract(m, 2, 2);
215                 uint dist_ac6 = bitfieldExtract(m, 12, 2);
216                 encode_9_7(streamgroup_num, coeff_row, ac1_ac6_tex, 1, 6, dist_ac1, dist_ac6);
217         } else if (coeff_colset == 2) {
218                 uint dist_ac2 = bitfieldExtract(m, 4, 2);
219                 uint dist_ac5 = bitfieldExtract(m, 10, 2);
220                 encode_9_7(streamgroup_num, coeff_row, ac2_ac5_tex, 2, 5, dist_ac2, dist_ac5);
221         } else if (coeff_colset == 3) {
222                 uint dist_ac3 = bitfieldExtract(m, 6, 2);
223                 encode_8(streamgroup_num, coeff_row, ac3_tex, 3, dist_ac3);
224         } else {
225                 uint dist_ac4 = bitfieldExtract(m, 8, 2);
226                 encode_8(streamgroup_num, coeff_row, ac4_tex, 4, dist_ac4);
227         }
228 }