2 #extension GL_NV_gpu_shader5 : enable
4 // http://cbloomrants.blogspot.no/2014/02/02-11-14-understanding-ans-10.html
6 layout(local_size_x = 256) in;
// Shader storage block using std430 layout at binding point 9.
// NOTE(review): only these two array members are visible in this view; other
// names used by the shader (dist, sign_biases) are presumably declared in the
// same block -- confirm.
8 layout(std430, binding = 9) buffer layoutName
// One uint per entry (4 * 256 entries); written at index base + i.
11 uint ransfreq[4 * 256];
// One uvec4 per entry, written at index base + i as
// (x_max, rcp_freq, bias, (cmpl_freq << 16) | rcp_shift).
12 uvec4 ransdist[4 * 256];
// Raw counts are rescaled so that the frequencies add up to
// prob_scale (= 1 << prob_bits = 4096).
16 const uint prob_bits = 12;
17 const uint prob_scale = 1 << prob_bits;
// Used when deriving the per-symbol x_max value.
18 const uint RANS_BYTE_L = (1u << 23);
// Total count per table, indexed by gl_WorkGroupID.x; used as the divisor
// when turning a raw count into a probability.
20 // FIXME: should come through a uniform.
21 const uint sums[4] = { 57600, 115200, 302400, 446400 };
// Rescaled frequency of every symbol (one slot per invocation); also the
// in-place workspace of the prefix sum.
23 shared uint new_dist[256];
// Scratch word: first receives the zero-corrected total, later accumulates the
// rescaled frequencies through atomicAdd().
// NOTE(review): the GLSL specification does not permit initializers on shared
// variables and leaves their initial contents undefined; some drivers accept
// this anyway. Presumably the value is assigned explicitly before use -- confirm.
24 shared uint shared_tmp = 0;
26 // Multiple voting areas for the min/max reductions, so we don't need to clear them
// Each correction round uses its own slot, voting_areas[vote_no].
28 shared uint voting_areas[256];
// Per-invocation setup: each work group handles the 256-entry table that
// starts at `base`, with `sum` as that table's total count.
32 uint base = 256 * gl_WorkGroupID.x;
33 uint sum = sums[gl_WorkGroupID.x];
35 uint i = gl_LocalInvocationID.x; // Which element we are working on.
37 // Zero has no sign bit. Yes, this is trickery.
// The count of symbol zero (dist[base]) is halved, rounding up; the amount
// removed is also taken off the total, and the reduced total goes to shared_tmp.
// NOTE(review): the declaration of zero_correction, and any guard limiting
// these writes to a single invocation, are not visible in this view -- confirm.
39 uint old_zero_freq = dist[base];
40 uint new_zero_freq = (old_zero_freq + 1) / 2; // Rounding up: a nonzero count never becomes zero.
41 zero_correction = old_zero_freq - new_zero_freq;
42 shared_tmp = sum - zero_correction;
43 dist[base] = new_zero_freq;
46 memoryBarrierShared();
51 // Normalize the pdf, taking care to never make a nonzero frequency into a zero.
52 // This rounding is presumably optimal according to cbloom.
// true_prob is this symbol's share of the total; from_scaled is that share
// expressed on the prob_scale grid.
53 float true_prob = float(dist[base + i]) / sum;
54 float from_scaled = true_prob * prob_scale;
55 float down = floor(from_scaled);
// Round down while from_scaled^2 <= down * (down + 1), i.e. while from_scaled
// is at or below the geometric mean of the two neighboring integers;
// otherwise round up.
56 uint new_val = (from_scaled * from_scaled <= down * (down + 1)) ? uint(down) : uint(down) + 1;
// A symbol that occurred at least once keeps a frequency of at least 1.
57 if (dist[base + i] > 0) {
58 new_val = max(new_val, 1);
61 // Reset shared_tmp. We could do without this barrier if we wanted to,
67 memoryBarrierShared();
70 // Parallel sum to find the total.
// Every invocation adds its own rescaled frequency to the shared accumulator.
71 atomicAdd(shared_tmp, new_val);
// NOTE(review): memoryBarrierShared() only orders memory accesses; it does not
// make invocations wait for each other. A barrier() is presumably issued on a
// line not visible in this view -- confirm.
73 memoryBarrierShared();
// Total of all rescaled frequencies; compared against prob_scale afterwards.
76 uint actual_sum = shared_tmp;
78 // Apply corrections one by one, greedily, until we are at the exact right sum.
// Case: the total is too large. Each loop round lowers actual_sum by one and
// picks the invocation with the smallest vote.
// NOTE(review): the declaration of vote_no and the statement that actually
// decrements the winner's new_val are not visible in this view -- confirm.
79 if (actual_sum > prob_scale) {
// loss = true_prob * log2(new_val / (new_val - 1)).
80 float loss = true_prob * log2(new_val / float(new_val - 1));
// Start every slot at the largest uint so that any vote is smaller.
82 voting_areas[i] = 0xffffffff;
83 memoryBarrierShared();
88 for ( ; actual_sum != prob_scale; --actual_sum, ++vote_no) {
89 memoryBarrierShared();
92 // Stick the thread ID in the lower mantissa bits so we never get a tie.
// The low 8 bits of the float's bit pattern are cleared and replaced with the
// local invocation index (0..255 for a work group of 256).
93 uint my_vote = (floatBitsToUint(loss) & ~0xffu) | gl_LocalInvocationID.x;
95 // We can't touch this one any more, but it needs to participate in the barriers,
100 // Find out which thread has the one with the smallest loss.
101 // (Positive floats compare like uints just fine.)
// NOTE(review): atomicMin() returns the value the slot held BEFORE the
// operation. Assigning that return value back is a plain, non-atomic store
// that can overwrite the minimum just computed; calling atomicMin() without
// the assignment looks like the intent -- confirm.
102 voting_areas[vote_no] = atomicMin(voting_areas[vote_no], my_vote);
103 memoryBarrierShared();
// Only the invocation whose vote is in the slot continues; it recomputes its
// loss from its current new_val.
106 if (my_vote == voting_areas[vote_no]) {
108 loss = true_prob * log2(new_val / float(new_val - 1));
// Case: the total is too small. Each loop round raises actual_sum by one and
// picks the invocation with the largest vote.
// benefit = -true_prob * log2(new_val / (new_val + 1)), which is non-negative.
// NOTE(review): the branch keyword, the initialization of the voting slots and
// of vote_no, and the statement that increments the winner's new_val are not
// visible in this view -- confirm.
112 float benefit = -true_prob * log2(new_val / float(new_val + 1));
115 memoryBarrierShared();
120 for ( ; actual_sum != prob_scale; ++actual_sum, ++vote_no) {
121 // Stick the thread ID in the lower mantissa bits so we never get a tie.
// The low 8 bits of the float's bit pattern are cleared and replaced with the
// local invocation index (0..255 for a work group of 256).
122 uint my_vote = (floatBitsToUint(benefit) & ~0xffu) | gl_LocalInvocationID.x;
124 // It's meaningless to increase this, but it needs to participate in the barriers,
125 // so we can't break.
129 // Find out which thread has the one with the most benefit.
130 // (Positive floats compare like uints just fine.)
// NOTE(review): atomicMax() returns the value the slot held BEFORE the
// operation. Assigning that return value back is a plain, non-atomic store
// that can overwrite the maximum just computed; calling atomicMax() without
// the assignment looks like the intent -- confirm.
131 voting_areas[vote_no] = atomicMax(voting_areas[vote_no], my_vote);
132 memoryBarrierShared();
// Only the invocation whose vote is in the slot continues; it recomputes its
// benefit from its current new_val.
135 if (my_vote == voting_areas[vote_no]) {
137 benefit = -true_prob * log2(new_val / float(new_val + 1));
143 // Undo what we did above.
147 // Parallel prefix sum.
// Rotate by one slot while storing: invocation 0 writes slot 255 and
// invocation i > 0 writes slot i - 1.
148 new_dist[(i + 255) & 255u] = new_val; // Move the zero symbol last.
149 memoryBarrierShared();
// From here on new_val is the frequency stored in this invocation's own slot
// (after the rotation), not the one it computed itself.
152 new_val = new_dist[i];
154 // TODO: Why do we need this next barrier? It makes no sense.
155 memoryBarrierShared();
// Up-sweep: for layer = 2, 4, ..., 256, the last slot of each layer-sized
// group adds in the value from half a layer below it.
158 for (uint layer = 2; layer <= 256; layer *= 2) {
159 if ((i & (layer - 1)) == layer - 1) {
160 new_dist[i] += new_dist[i - (layer / 2)];
162 memoryBarrierShared();
// Down-sweep: for layer = 128, 64, ..., 2, the last slot of each group (except
// slot 255) pushes its value half a layer upwards.
165 for (uint layer = 128; layer >= 2; layer /= 2) {
166 if ((i & (layer - 1)) == layer - 1 && i != 255) {
167 new_dist[i + (layer / 2)] += new_dist[i];
169 memoryBarrierShared();
// new_dist[i] is now the running total up to and including slot i; removing
// this slot's own frequency gives the cumulative start of the symbol.
173 uint start = new_dist[i] - new_val;
// Derive the per-symbol rANS table entry and store it.
// NOTE(review): freq and shift are declared on lines not visible in this view;
// freq is presumably this slot's final frequency -- confirm.
// All quantities here use a scale of 1 << (prob_bits + 1), one bit more than
// prob_scale.
176 uint x_max = ((RANS_BYTE_L >> (prob_bits + 1)) << 8) * freq;
// Complement of freq with respect to 1 << (prob_bits + 1).
177 uint cmpl_freq = ((1 << (prob_bits + 1)) - freq);
178 uint rcp_freq, rcp_shift, bias;
182 bias = start + (1 << (prob_bits + 1)) - 1;
// Loop runs while freq exceeds 1 << shift (the loop body is not visible here).
185 while (freq > (1u << shift)) {
// Rounded-up quotient of 2^(shift + 31) by freq, computed with a 64-bit
// intermediate and then narrowed to uint.
189 rcp_freq = uint(((uint64_t(1) << (shift + 31)) + freq-1) / freq);
190 rcp_shift = shift - 1;
// Output: plain frequency, plus the packed entry with cmpl_freq in the upper
// 16 bits of the last component and rcp_shift in the lower bits.
194 ransfreq[base + i] = freq;
195 ransdist[base + i] = uvec4(x_max, rcp_freq, bias, (cmpl_freq << 16) | rcp_shift);
// NOTE(review): which invocation performs this store is decided by a condition
// that is not visible in this view -- confirm.
198 sign_biases[gl_WorkGroupID.x] = new_dist[i];