3 // http://cbloomrants.blogspot.no/2014/02/02-11-14-understanding-ans-10.html
// Normalizes each workgroup's 256-entry symbol histogram so that it sums exactly to
// prob_scale, then overwrites the histogram with its cumulative (prefix-summed)
// frequencies. Presumably these are the frequency tables for the ANS coder described
// at the link above.
// One invocation per histogram entry; one workgroup per histogram.
5 layout(local_size_x = 256) in;
// NOTE(review): the members of this block are not visible in this listing; the `dist`
// array used below is presumably declared inside it.
7 layout(std430, binding = 9) buffer layoutName
// Target precision: the normalized frequencies of one histogram sum to prob_scale (4096).
12 const uint prob_bits = 12;
13 const uint prob_scale = 1 << prob_bits;
15 // FIXME: should come through a uniform.
// Unnormalized total of each histogram, indexed by gl_WorkGroupID.x. With four entries,
// only four workgroups can be dispatched.
16 const uint sums[4] = { 57600, 115200, 302400, 446400 };
// Scratch space for the parallel prefix sum at the end.
18 shared uint new_dist[256];
// Scalar shared by the workgroup: first the corrected total, later the atomic accumulator
// for the sum of the rounded frequencies.
19 shared uint shared_tmp = 0;
21 // Multiple voting areas for the min/max reductions, so we don't need to clear them
// One slot per correction round (indexed by vote_no in the greedy loops).
23 shared uint voting_areas[256];
// Offset of this workgroup's histogram within dist[].
27 uint base = 256 * gl_WorkGroupID.x;
// Unnormalized total of this histogram (hard-coded table; see sums[]).
28 uint sum = sums[gl_WorkGroupID.x];
30 uint i = gl_LocalInvocationID.x; // Which element we are working on.
32 // Zero has no sign bit. Yes, this is trickery.
// The zero symbol's count is halved, rounding up so a nonzero count stays nonzero.
// The amount removed (zero_correction) is subtracted from the total kept in shared_tmp.
// NOTE(review): the guard around this section (presumably `if (i == 0)`) and the
// declaration of zero_correction are not visible in this listing.
34 uint old_zero_freq = dist[base];
35 uint new_zero_freq = (old_zero_freq + 1) / 2; // Rounding up ensures we get a zero frequency.
36 zero_correction = old_zero_freq - new_zero_freq;
37 shared_tmp = sum - zero_correction;
38 dist[base] = new_zero_freq;
// memoryBarrierShared() only orders shared-memory accesses; it does not make other
// invocations wait. NOTE(review): a barrier() is also needed before other invocations
// read shared_tmp or dist[base] — presumably on a line not visible here.
41 memoryBarrierShared();
46 // Normalize the pdf, taking care to never make a nonzero frequency into a zero.
47 // This rounding is presumably optimal according to cbloom.
// NOTE(review): this divides by `sum`, the uncorrected total, while the halved zero
// count makes the histogram total sum - zero_correction. Verify whether `sum` is
// reassigned (e.g. from shared_tmp) on a line not visible here.
48 float true_prob = float(dist[base + i]) / sum;
49 float from_scaled = true_prob * prob_scale;
50 float down = floor(from_scaled);
// Round down iff from_scaled <= sqrt(down * (down + 1)), i.e. round at the geometric
// mean of the two candidates. Both sides are squared to avoid the sqrt.
51 uint new_val = (from_scaled * from_scaled <= down * (down + 1)) ? uint(down) : uint(down) + 1;
// A symbol that occurs at all keeps a frequency of at least 1.
52 if (dist[base + i] > 0) {
53 new_val = max(new_val, 1);
56 // Reset shared_tmp. We could do without this barrier if we wanted to,
// NOTE(review): the reset itself and its barrier() are on lines not visible here.
62 memoryBarrierShared();
65 // Parallel sum to find the total.
// Each invocation adds its rounded frequency into the shared accumulator.
66 atomicAdd(shared_tmp, new_val);
68 memoryBarrierShared();
// After synchronization (presumably a barrier() not visible here), actual_sum is the
// total of all rounded frequencies in this workgroup.
71 uint actual_sum = shared_tmp;
73 // Apply corrections one by one, greedily, until we are at the exact right sum.
// Case 1: the rounded frequencies overshoot prob_scale. In each round, the symbol whose
// decrement costs the fewest extra bits is lowered by one.
74 if (actual_sum > prob_scale) {
// Cost of lowering new_val by one is true_prob * log2(new_val / (new_val - 1)) bits per
// symbol. It is stored negated (never positive); see the vote comment below.
// The division must be done in float. With uint operands, new_val / (new_val - 1)
// truncates to 1 for every new_val >= 3, which makes every loss zero, and it divides by
// zero when new_val == 1.
75 float loss = -true_prob * log2(float(new_val) / float(new_val - 1));
77 voting_areas[i] = 0xffffffff;
78 memoryBarrierShared();
83 for ( ; actual_sum != prob_scale; --actual_sum, ++vote_no) {
84 memoryBarrierShared();
87 // Stick the thread ID in the lower mantissa bits so we never get a tie.
88 uint my_vote = (floatBitsToUint(loss) & ~0xff) | gl_LocalInvocationID.x;
90 // We can't touch this one any more, but it needs to participate in the barriers,
95 // Find out which thread has the one with the smallest loss.
96 // loss is never positive (for new_val >= 2), so every vote has the sign bit set. For such
// bit patterns, a larger magnitude gives a larger uint, so the uint minimum is the loss
// with the smallest magnitude, i.e. the cheapest decrement.
// atomicMin returns the value *before* the update. Storing that back would overwrite the
// minimum that other invocations just wrote, so the return value is discarded.
97 atomicMin(voting_areas[vote_no], my_vote);
98 memoryBarrierShared();
101 if (my_vote == voting_areas[vote_no]) {
103 loss = -true_prob * log2(float(new_val) / float(new_val - 1));
// Case 2 (presumably the branch taken when the rounded sum undershoots prob_scale; its
// opening line is not visible here). In each round, the symbol whose increment saves the
// most bits is raised by one.
// Saving from raising new_val by one is true_prob * log2((new_val + 1) / new_val) bits per
// symbol. It is stored as the non-positive value true_prob * log2(new_val / (new_val + 1)).
// The division must be done in float. In uint arithmetic new_val / (new_val + 1) is always
// 0, so log2 would give -inf for every symbol.
107 float benefit = true_prob * log2(float(new_val) / float(new_val + 1));
110 memoryBarrierShared();
115 for ( ; actual_sum != prob_scale; ++actual_sum, ++vote_no) {
116 // Stick the thread ID in the lower mantissa bits so we never get a tie.
117 uint my_vote = (floatBitsToUint(benefit) & ~0xff) | gl_LocalInvocationID.x;
119 // It's meaningless to increase this, but it needs to participate in the barriers,
120 // so we can't break.
124 // Find out which thread has the one with the most benefit.
125 // benefit is never positive, so every vote has the sign bit set. For such bit patterns,
// a larger magnitude gives a larger uint, so the uint maximum is the largest saving.
// atomicMax returns the value *before* the update. Storing that back would overwrite the
// maximum that other invocations just wrote, so the return value is discarded.
126 atomicMax(voting_areas[vote_no], my_vote);
127 memoryBarrierShared();
130 if (my_vote == voting_areas[vote_no]) {
132 benefit = true_prob * log2(float(new_val) / float(new_val + 1));
138 // Undo what we did above.
// NOTE(review): the code this comment refers to is not visible in this listing; confirm
// what it restores against the full file.
142 // Parallel prefix sum.
// Store the frequencies rotated down by one slot: symbol k (k >= 1) goes to index k - 1,
// and symbol 0 goes to index 255.
143 new_dist[(i + 255) & 255] = new_val; // Move the zero symbol last.
144 memoryBarrierShared();
// Up-sweep: at each level, an invocation whose index ends in log2(layer) one-bits adds
// in the partial sum from layer/2 slots below.
// NOTE(review): memoryBarrierShared() does not make invocations wait for each other; each
// level also needs barrier(), presumably on lines not visible here.
147 for (uint layer = 2; layer <= 256; layer *= 2) {
148 if ((i & (layer - 1)) == layer - 1) {
149 new_dist[i] += new_dist[i - (layer / 2)];
151 memoryBarrierShared();
// Down-sweep: push partial sums into the remaining slots. Afterwards new_dist holds an
// inclusive prefix sum of the rotated frequencies.
154 for (uint layer = 128; layer >= 2; layer /= 2) {
155 if ((i & (layer - 1)) == layer - 1 && i != 255) {
156 new_dist[i + (layer / 2)] += new_dist[i];
158 memoryBarrierShared();
// Overwrite the input histogram with the cumulative frequencies (in rotated symbol order;
// entry 255 holds the total).
161 dist[base + i] = new_dist[i];