--- /dev/null
+#version 440
+
+// http://cbloomrants.blogspot.no/2014/02/02-11-14-understanding-ans-10.html
+
+// 256 invocations per work group; main() uses the local invocation index as
+// the symbol index and the work group index to select the table.
+layout(local_size_x = 256) in;
+
+// Four tables of 256 entries each, stored back to back. main() reads the
+// entries of table gl_WorkGroupID.x and overwrites them in place.
+layout(std430, binding = 9) buffer layoutName
+{
+    uint dist[4 * 256];
+};
+
+// The normalized frequencies are scaled so that they sum to 1 << prob_bits.
+const uint prob_bits = 12u;
+const uint prob_scale = 1u << prob_bits;
+
+// Total count of each of the four tables, indexed by work group.
+// FIXME: should come through a uniform.
+const uint sums[4] = { 57600, 115200, 302400, 446400 };
+
+// Scratch space for the prefix sum at the end of main().
+shared uint new_dist[256];
+
+// Work-group-wide scratch value. GLSL does not allow initializers on shared
+// variables (their contents are undefined at shader start), so this is left
+// uninitialized here; main() always writes it before reading it.
+shared uint shared_tmp;
+
+// Multiple voting areas for the min/max reductions, so we don't need to clear them
+// for every round.
+shared uint voting_areas[256];
+
+// Cost of lowering a scaled frequency from val to val - 1, weighted by the
+// symbol's true probability. The result is never negative, so its bit pattern
+// sorts the same way as an unsigned integer.
+float decrement_loss(float true_prob, uint val)
+{
+    // A value of 0 or 1 can never be lowered; bail out instead of dividing by zero.
+    if (val <= 1u) {
+        return 0.0;
+    }
+    return true_prob * log2(float(val) / float(val - 1u));
+}
+
+// Gain of raising a scaled frequency from val to val + 1, weighted by the
+// symbol's true probability. The result is never negative, so its bit pattern
+// sorts the same way as an unsigned integer.
+float increment_benefit(float true_prob, uint val)
+{
+    // Raising a zero frequency is meaningless; bail out instead of dividing by zero.
+    if (val == 0u) {
+        return 0.0;
+    }
+    return true_prob * log2(float(val + 1u) / float(val));
+}
+
+// Normalizes the table selected by gl_WorkGroupID.x so that its scaled
+// frequencies sum to prob_scale, then replaces the table in dist[] with the
+// inclusive prefix sum of those frequencies (zero symbol moved to the end and
+// counted twice). One invocation per symbol.
+//
+// Every barrier() below is reached by all 256 invocations: the branches that
+// contain barriers depend only on values that are identical across the group.
+void main()
+{
+    uint base = 256u * gl_WorkGroupID.x;
+    uint i = gl_LocalInvocationID.x;  // Which element we are working on.
+
+    // Zero has no sign bit. Yes, this is trickery: halve the count of the
+    // zero symbol here, and double its normalized frequency again further down.
+    if (i == 0u) {
+        uint old_zero_freq = dist[base];
+        uint new_zero_freq = (old_zero_freq + 1u) / 2u;  // Round up, so a nonzero count stays nonzero.
+        uint zero_correction = old_zero_freq - new_zero_freq;
+        shared_tmp = sums[gl_WorkGroupID.x] - zero_correction;
+        dist[base] = new_zero_freq;
+    }
+
+    memoryBarrierShared();
+    barrier();
+
+    uint sum = shared_tmp;
+    uint freq = dist[base + i];
+
+    // Everybody must have picked up the sum before invocation 0 is allowed to
+    // reuse shared_tmp below; without this barrier a slow invocation could
+    // read the reset value instead.
+    barrier();
+
+    // Normalize the pdf, taking care to never make a nonzero frequency into a zero.
+    // This rounding is presumably optimal according to cbloom.
+    float true_prob = float(freq) / float(sum);
+    float from_scaled = true_prob * float(prob_scale);
+    float down = floor(from_scaled);
+    uint new_val = (from_scaled * from_scaled <= down * (down + 1.0)) ? uint(down) : uint(down) + 1u;
+    if (freq > 0u) {
+        new_val = max(new_val, 1u);
+    }
+
+    // Reuse shared_tmp as the accumulator for the total.
+    if (i == 0u) {
+        shared_tmp = 0u;
+    }
+
+    memoryBarrierShared();
+    barrier();
+
+    // Parallel sum to find the total.
+    atomicAdd(shared_tmp, new_val);
+
+    memoryBarrierShared();
+    barrier();
+
+    uint actual_sum = shared_tmp;
+
+    // Apply corrections one by one, greedily, until we are at the exact right sum.
+    // Each round uses its own slot in voting_areas, so the slots only need to be
+    // initialized once. NOTE(review): this needs at most 256 rounds, which holds
+    // when sums[] matches the real total of the table (every entry is then off by
+    // less than one); confirm once the sums come from a uniform.
+    if (actual_sum > prob_scale) {
+        float loss = decrement_loss(true_prob, new_val);
+
+        voting_areas[i] = 0xffffffffu;
+        memoryBarrierShared();
+        barrier();
+
+        for (uint vote_no = 0u; actual_sum != prob_scale; --actual_sum, ++vote_no) {
+            // Stick the thread ID in the lower mantissa bits so we never get a tie.
+            uint my_vote = (floatBitsToUint(loss) & ~0xffu) | i;
+            if (new_val <= 1u) {
+                // We can't touch this one any more, but it needs to participate in the barriers,
+                // so we can't break.
+                my_vote = 0xffffffffu;
+            }
+
+            // Find out which thread has the one with the smallest loss.
+            // (Non-negative floats compare like uints just fine.)
+            // The return value (the previous contents) is deliberately ignored;
+            // storing it back would race with the other invocations' atomics.
+            atomicMin(voting_areas[vote_no], my_vote);
+            memoryBarrierShared();
+            barrier();
+
+            // The extra new_val test keeps a sentinel vote from ever winning.
+            if (new_val > 1u && my_vote == voting_areas[vote_no]) {
+                --new_val;
+                loss = decrement_loss(true_prob, new_val);
+            }
+        }
+    } else {
+        float benefit = increment_benefit(true_prob, new_val);
+
+        voting_areas[i] = 0u;
+        memoryBarrierShared();
+        barrier();
+
+        for (uint vote_no = 0u; actual_sum != prob_scale; ++actual_sum, ++vote_no) {
+            // Stick the thread ID in the lower mantissa bits so we never get a tie.
+            uint my_vote = (floatBitsToUint(benefit) & ~0xffu) | i;
+            if (new_val == 0u) {
+                // It's meaningless to increase this, but it needs to participate in the barriers,
+                // so we can't break.
+                my_vote = 0u;
+            }
+
+            // Find out which thread has the one with the most benefit.
+            // (Non-negative floats compare like uints just fine.)
+            // As above, the previous contents returned by the atomic are ignored.
+            atomicMax(voting_areas[vote_no], my_vote);
+            memoryBarrierShared();
+            barrier();
+
+            // The extra new_val test keeps a sentinel vote from ever winning.
+            if (new_val != 0u && my_vote == voting_areas[vote_no]) {
+                ++new_val;
+                benefit = increment_benefit(true_prob, new_val);
+            }
+        }
+    }
+
+    if (i == 0u) {
+        // Undo the halving of the zero symbol that we did at the top.
+        new_val *= 2u;
+    }
+
+    // Parallel prefix sum.
+    new_dist[(i + 255u) & 255u] = new_val;  // Move the zero symbol last.
+    memoryBarrierShared();
+    barrier();
+
+    // Up-sweep: the last element of every block of size `layer` collects the
+    // sum of that block.
+    for (uint layer = 2u; layer <= 256u; layer *= 2u) {
+        if ((i & (layer - 1u)) == layer - 1u) {
+            new_dist[i] += new_dist[i - (layer / 2u)];
+        }
+        memoryBarrierShared();
+        barrier();
+    }
+    // Down-sweep: push the finished block totals into the middle of the
+    // following block. Element 255 has nothing after it, so it is skipped.
+    for (uint layer = 128u; layer >= 2u; layer /= 2u) {
+        if ((i & (layer - 1u)) == layer - 1u && i != 255u) {
+            new_dist[i + (layer / 2u)] += new_dist[i];
+        }
+        memoryBarrierShared();
+        barrier();
+    }
+    dist[base + i] = new_dist[i];
+}