3 // http://cbloomrants.blogspot.no/2014/02/02-11-14-understanding-ans-10.html
// Work-group size: 256 invocations along x. The code below uses
// gl_LocalInvocationID.x to index the 256-entry shared arrays, one entry per invocation.
5 layout(local_size_x = 256) in;
// Shader storage block bound at binding point 9, using std430 layout.
// NOTE(review): the braces of this block and any members other than ransdist are not
// visible in this extract. The code below also reads and writes an array named dist,
// presumably declared here as well -- confirm against the full file.
7 layout(std430, binding = 9) buffer layoutName
// Output table, 256 entries per work group (4 groups). Each entry is written at the end
// of the shader as uvec2(running total before this entry, normalized frequency).
10 uvec2 ransdist[4 * 256];
// Target precision of the normalized frequency table: the correction step below adjusts
// the frequencies until their total equals prob_scale, i.e. 1 << 12 = 4096.
13 const uint prob_bits = 12;
14 const uint prob_scale = 1 << prob_bits;
16 // FIXME: should come through a uniform.
// One total per work group, selected with gl_WorkGroupID.x and used as the denominator
// when turning raw counts from dist into probabilities.
// NOTE(review): presumably the number of symbols counted into each 256-entry histogram;
// confirm against whatever fills dist.
17 const uint sums[4] = { 57600, 115200, 302400, 446400 };
// Work-group scratch array: receives the normalized frequencies (rotated by one slot so
// that entry 0 ends up last) and is then scanned in place into running totals.
19 shared uint new_dist[256];
// Work-group-wide scratch value. It is first assigned (sum - zero_correction), later
// accumulated into with atomicAdd, and then read back as actual_sum.
// Declared without an initializer on purpose: GLSL does not allow initializers on shared
// variables and their contents are undefined when the shader starts, so every use must be
// preceded by an explicit write.
20 shared uint shared_tmp;
22 // Multiple voting areas for the min/max reductions, so we don't need to clear them
// between rounds: each correction round uses its own slot, selected by vote_no, and the
// slots are pre-filled once before the loop starts.
24 shared uint voting_areas[256];
// Kernel body, part 1 (as visible in this extract): each work group processes the 256
// counts starting at dist[base]; invocation i handles entry i.
// base : offset of this work group's 256-entry slice.
// sum  : this work group's total from the sums table.
28 uint base = 256 * gl_WorkGroupID.x;
29 uint sum = sums[gl_WorkGroupID.x];
31 uint i = gl_LocalInvocationID.x; // Which element we are working on.
33 // Zero has no sign bit. Yes, this is trickery.
// Halve the count of entry 0 (integer division after adding 1, so a count of 1 stays 1),
// write the halved count back, and publish the total reduced by the removed amount in
// shared_tmp.
// NOTE(review): the declaration of zero_correction and any guard restricting this to a
// single invocation are not visible in this extract.
35 uint old_zero_freq = dist[base];
36 uint new_zero_freq = (old_zero_freq + 1) / 2; // Rounding up ensures we get a zero frequency.
37 zero_correction = old_zero_freq - new_zero_freq;
38 shared_tmp = sum - zero_correction;
39 dist[base] = new_zero_freq;
42 memoryBarrierShared();
47 // Normalize the pdf, taking care to never make a nonzero frequency into a zero.
48 // This rounding is presumably optimal according to cbloom.
// NOTE(review): sum is presumably reloaded from shared_tmp on a line not visible here.
49 float true_prob = float(dist[base + i]) / sum;
50 float from_scaled = true_prob * prob_scale;
51 float down = floor(from_scaled);
// Round down when from_scaled^2 <= down * (down + 1), i.e. when from_scaled is at or below
// the geometric mean of the two neighbouring integers; otherwise round up.
52 uint new_val = (from_scaled * from_scaled <= down * (down + 1)) ? uint(down) : uint(down) + 1;
// A count that was nonzero must keep a frequency of at least 1.
53 if (dist[base + i] > 0) {
54 new_val = max(new_val, 1);
57 // Reset shared_tmp. We could do without this barrier if we wanted to,
63 memoryBarrierShared();
66 // Parallel sum to find the total.
// Every invocation adds its own normalized frequency into the shared accumulator.
67 atomicAdd(shared_tmp, new_val);
69 memoryBarrierShared();
// Total of all normalized frequencies in this work group; compared against prob_scale next.
72 uint actual_sum = shared_tmp;
// Kernel body, part 2 (as visible in this extract): greedy correction of the rounded
// frequencies. While the total differs from prob_scale, every round elects exactly one
// invocation through an atomic min/max on a per-round voting slot; the loop counters move
// actual_sum one step towards prob_scale per round.
// NOTE(review): the declaration of vote_no, the statements that change new_val for the
// winner, and the handling of entries that may not be changed are not visible here.
74 // Apply corrections one by one, greedily, until we are at the exact right sum.
75 if (actual_sum > prob_scale) {
// Total is too high: loss is the probability-weighted log2 ratio between the current
// frequency and the frequency one lower, i.e. the cost of decrementing this entry.
76 float loss = true_prob * log2(new_val / float(new_val - 1));
// Pre-fill this invocation's voting slot with the identity for a min reduction.
78 voting_areas[i] = 0xffffffff;
79 memoryBarrierShared();
84 for ( ; actual_sum != prob_scale; --actual_sum, ++vote_no) {
85 memoryBarrierShared();
88 // Stick the thread ID in the lower mantissa bits so we never get a tie.
89 uint my_vote = (floatBitsToUint(loss) & ~0xffu) | gl_LocalInvocationID.x;
91 // We can't touch this one any more, but it needs to participate in the barriers,
// so we can't break.
96 // Find out which thread has the one with the smallest loss.
97 // (Positive floats compare like uints just fine.)
// atomicMin already leaves the minimum in the slot and returns the slot's contents from
// *before* the operation. Assigning that return value back would overwrite the reduced
// minimum with a stale value through a racy plain store, so the result is discarded.
98 atomicMin(voting_areas[vote_no], my_vote);
99 memoryBarrierShared();
// Only the invocation whose vote survived the reduction is the winner of this round.
102 if (my_vote == voting_areas[vote_no]) {
// Recompute the cost for the winner's (changed) frequency.
104 loss = true_prob * log2(new_val / float(new_val - 1));
// Total is too low: benefit is the probability-weighted log2 ratio gained by raising this
// entry's frequency by one.
108 float benefit = -true_prob * log2(new_val / float(new_val + 1));
111 memoryBarrierShared();
116 for ( ; actual_sum != prob_scale; ++actual_sum, ++vote_no) {
117 // Stick the thread ID in the lower mantissa bits so we never get a tie.
118 uint my_vote = (floatBitsToUint(benefit) & ~0xffu) | gl_LocalInvocationID.x;
120 // It's meaningless to increase this, but it needs to participate in the barriers,
121 // so we can't break.
125 // Find out which thread has the one with the most benefit.
126 // (Positive floats compare like uints just fine.)
// As with atomicMin above: atomicMax stores the maximum itself and returns the previous
// contents, so the return value must not be written back into the slot.
127 atomicMax(voting_areas[vote_no], my_vote);
128 memoryBarrierShared();
// Only the invocation whose vote survived the reduction is the winner of this round.
131 if (my_vote == voting_areas[vote_no]) {
// Recompute the gain for the winner's (changed) frequency.
133 benefit = -true_prob * log2(new_val / float(new_val + 1));
// Kernel body, part 3 (as visible in this extract): turn the final frequencies into
// (start, frequency) pairs via an in-place scan over the shared array.
// NOTE(review): the statements belonging to the next comment are not visible here.
139 // Undo what we did above.
143 // Parallel prefix sum.
// (i + 255) & 255 equals (i - 1) mod 256: entry 0 goes to slot 255, entry k to slot k - 1.
144 new_dist[(i + 255) & 255u] = new_val; // Move the zero symbol last.
145 memoryBarrierShared();
// From here on new_val is the frequency stored in slot i (after the rotation above).
148 new_val = new_dist[i];
// Up-sweep: for block sizes 2, 4, ..., 256, the last slot of each block adds the last slot
// of the block's left half, so it ends up holding the sum of the whole block.
150 for (uint layer = 2; layer <= 256; layer *= 2) {
151 if ((i & (layer - 1)) == layer - 1) {
152 new_dist[i] += new_dist[i - (layer / 2)];
154 memoryBarrierShared();
// Down-sweep: for block sizes 128, 64, ..., 2, the last slot of each block (except slot
// 255) adds its running total into the midpoint of the following block. Afterwards
// new_dist[i] holds the running total up to and including slot i.
157 for (uint layer = 128; layer >= 2; layer /= 2) {
158 if ((i & (layer - 1)) == layer - 1 && i != 255) {
159 new_dist[i + (layer / 2)] += new_dist[i];
161 memoryBarrierShared();
// Emit (running total before slot i, frequency of slot i) for this work group's slice.
164 ransdist[base + i] = uvec2(new_dist[i] - new_val, new_val);