]> git.sesse.net Git - narabu/blobdiff - tally.shader
Add rANS normalization to the encoder.
[narabu] / tally.shader
diff --git a/tally.shader b/tally.shader
new file mode 100644 (file)
index 0000000..c83bfe4
--- /dev/null
@@ -0,0 +1,162 @@
+#version 440
+
+// http://cbloomrants.blogspot.no/2014/02/02-11-14-understanding-ans-10.html
+
+layout(local_size_x = 256) in;
+
+layout(std430, binding = 9) buffer layoutName
+{
+       uint dist[4 * 256];
+};
+
+const uint prob_bits = 12;
+const uint prob_scale = 1 << prob_bits;
+
+// FIXME: should come through a uniform.
+const uint sums[4] = { 57600, 115200, 302400, 446400 };
+
+shared uint new_dist[256];
+shared uint shared_tmp = 0;
+
+// Multiple voting areas for the min/max reductions, so we don't need to clear them
+// for every round.
+shared uint voting_areas[256];
+
+void main()
+{
+       uint base = 256 * gl_WorkGroupID.x;
+       uint sum = sums[gl_WorkGroupID.x];
+       uint zero_correction;
+       uint i = gl_LocalInvocationID.x;  // Which element we are working on.
+
+       // Zero has no sign bit. Yes, this is trickery.
+       if (i == 0) {
+               uint old_zero_freq = dist[base];
+               uint new_zero_freq = (old_zero_freq + 1) / 2;  // Rounding up ensures we get a zero frequency.
+               zero_correction = old_zero_freq - new_zero_freq;
+               shared_tmp = sum - zero_correction;
+               dist[base] = new_zero_freq;
+       }
+
+       memoryBarrierShared();
+       barrier();
+
+       sum = shared_tmp;
+
+       // Normalize the pdf, taking care to never make a nonzero frequency into a zero.
+       // This rounding is presumably optimal according to cbloom.
+       float true_prob = float(dist[base + i]) / sum;
+       float from_scaled = true_prob * prob_scale;
+       float down = floor(from_scaled);
+       uint new_val = (from_scaled * from_scaled <= down * (down + 1)) ? uint(down) : uint(down) + 1;
+       if (dist[base + i] > 0) {
+               new_val = max(new_val, 1);
+       }
+
+       // Reset shared_tmp. We could do without this barrier if we wanted to,
+       // but meh. :-)
+       if (i == 0) {
+               shared_tmp = 0;
+       }
+
+       memoryBarrierShared();
+       barrier();
+
+       // Parallel sum to find the total.
+       atomicAdd(shared_tmp, new_val);
+
+       memoryBarrierShared();
+       barrier();
+
+       uint actual_sum = shared_tmp;
+
+       // Apply corrections one by one, greedily, until we are at the exact right sum.
+       if (actual_sum > prob_scale) {
+               float loss = -true_prob * log2(new_val / (new_val - 1));
+
+               voting_areas[i] = 0xffffffff;
+               memoryBarrierShared();
+               barrier();
+
+               uint vote_no = 0;
+
+               for ( ; actual_sum != prob_scale; --actual_sum, ++vote_no) {
+                       memoryBarrierShared();
+                       barrier();
+
+                       // Stick the thread ID in the lower mantissa bits so we never get a tie.
+                       uint my_vote = (floatBitsToUint(loss) & ~0xff) | gl_LocalInvocationID.x;
+                       if (new_val <= 1) {
+                               // We can't touch this one any more, but it needs to participate in the barriers,
+                               // so we can't break.
+                               my_vote = 0xffffffff;
+                       }
+
+                       // Find out which thread has the one with the smallest loss.
+                       // (Positive floats compare like uints just fine.)
+                       voting_areas[vote_no] = atomicMin(voting_areas[vote_no], my_vote);
+                       memoryBarrierShared();
+                       barrier();
+
+                       if (my_vote == voting_areas[vote_no]) {
+                               --new_val;
+                               loss = -true_prob * log2(new_val / (new_val - 1));
+                       }
+               }
+       } else {
+               float benefit = true_prob * log2(new_val / (new_val + 1));
+
+               voting_areas[i] = 0;
+               memoryBarrierShared();
+               barrier();
+
+               uint vote_no = 0;
+
+               for ( ; actual_sum != prob_scale; ++actual_sum, ++vote_no) {
+                       // Stick the thread ID in the lower mantissa bits so we never get a tie.
+                       uint my_vote = (floatBitsToUint(benefit) & ~0xff) | gl_LocalInvocationID.x;
+                       if (new_val == 0) {
+                               // It's meaningless to increase this, but it needs to participate in the barriers,
+                               // so we can't break.
+                               my_vote = 0;
+                       }
+
+                       // Find out which thread has the one with the most benefit.
+                       // (Positive floats compare like uints just fine.)
+                       voting_areas[vote_no] = atomicMax(voting_areas[vote_no], my_vote);
+                       memoryBarrierShared();
+                       barrier();
+
+                       if (my_vote == voting_areas[vote_no]) {
+                               ++new_val;
+                               benefit = true_prob * log2(new_val / (new_val + 1));
+                       }
+               }
+       }
+
+       if (i == 0) {
+               // Undo what we did above.
+               new_val *= 2;
+       }
+
+       // Parallel prefix sum.
+       new_dist[(i + 255) & 255] = new_val;  // Move the zero symbol last.
+       memoryBarrierShared();
+       barrier();
+
+       for (uint layer = 2; layer <= 256; layer *= 2) {
+               if ((i & (layer - 1)) == layer - 1) {
+                       new_dist[i] += new_dist[i - (layer / 2)];
+               }
+               memoryBarrierShared();
+               barrier();
+       }
+       for (uint layer = 128; layer >= 2; layer /= 2) {
+               if ((i & (layer - 1)) == layer - 1 && i != 255) {
+                       new_dist[i + (layer / 2)] += new_dist[i];
+               }
+               memoryBarrierShared();
+               barrier();
+       }
+       dist[base + i] = new_dist[i];
+}