]> git.sesse.net Git - narabu/blobdiff - tally.shader
More fixes of hard-coded values.
[narabu] / tally.shader
index c83bfe4a850044811885b885563c6f2ba8671f30..43ccdb3b03721ed5777b3f3d6b1ea54f44d2f258 100644 (file)
@@ -1,4 +1,5 @@
 #version 440
+#extension GL_ARB_gpu_shader_int64 : enable
 
 // http://cbloomrants.blogspot.no/2014/02/02-11-14-understanding-ans-10.html
 
@@ -7,10 +8,18 @@ layout(local_size_x = 256) in;
 layout(std430, binding = 9) buffer layoutName
 {
        uint dist[4 * 256];
+       uint ransfreq[4 * 256];
+};
+
+layout(std140, binding = 12) buffer distBlock  // Will become an UBO to rans.shader, thus layout std140.
+{
+       uvec4 ransdist[4 * 256];
+       uint sign_biases[4];
 };
 
 const uint prob_bits = 12;
 const uint prob_scale = 1 << prob_bits;
+const uint RANS_BYTE_L = (1u << 23);
 
 // FIXME: should come through a uniform.
 const uint sums[4] = { 57600, 115200, 302400, 446400 };
@@ -72,7 +81,7 @@ void main()
 
        // Apply corrections one by one, greedily, until we are at the exact right sum.
        if (actual_sum > prob_scale) {
-               float loss = -true_prob * log2(new_val / (new_val - 1));
+               float loss = true_prob * log2(new_val / float(new_val - 1));
 
                voting_areas[i] = 0xffffffff;
                memoryBarrierShared();
@@ -85,7 +94,7 @@ void main()
                        barrier();
 
                        // Stick the thread ID in the lower mantissa bits so we never get a tie.
-                       uint my_vote = (floatBitsToUint(loss) & ~0xff) | gl_LocalInvocationID.x;
+                       uint my_vote = (floatBitsToUint(loss) & ~0xffu) | gl_LocalInvocationID.x;
                        if (new_val <= 1) {
                                // We can't touch this one any more, but it needs to participate in the barriers,
                                // so we can't break.
@@ -100,11 +109,11 @@ void main()
 
                        if (my_vote == voting_areas[vote_no]) {
                                --new_val;
-                               loss = -true_prob * log2(new_val / (new_val - 1));
+                               loss = true_prob * log2(new_val / float(new_val - 1));
                        }
                }
        } else {
-               float benefit = true_prob * log2(new_val / (new_val + 1));
+               float benefit = -true_prob * log2(new_val / float(new_val + 1));
 
                voting_areas[i] = 0;
                memoryBarrierShared();
@@ -114,7 +123,7 @@ void main()
 
                for ( ; actual_sum != prob_scale; ++actual_sum, ++vote_no) {
                        // Stick the thread ID in the lower mantissa bits so we never get a tie.
-                       uint my_vote = (floatBitsToUint(benefit) & ~0xff) | gl_LocalInvocationID.x;
+                       uint my_vote = (floatBitsToUint(benefit) & ~0xffu) | gl_LocalInvocationID.x;
                        if (new_val == 0) {
                                // It's meaningless to increase this, but it needs to participate in the barriers,
                                // so we can't break.
@@ -129,7 +138,7 @@ void main()
 
                        if (my_vote == voting_areas[vote_no]) {
                                ++new_val;
-                               benefit = true_prob * log2(new_val / (new_val + 1));
+                               benefit = -true_prob * log2(new_val / float(new_val + 1));
                        }
                }
        }
@@ -140,7 +149,13 @@ void main()
        }
 
        // Parallel prefix sum.
-       new_dist[(i + 255) & 255] = new_val;  // Move the zero symbol last.
+       new_dist[(i + 255) & 255u] = new_val;  // Move the zero symbol last.
+       memoryBarrierShared();
+       barrier();
+
+       new_val = new_dist[i];
+
+       // TODO: Why do we need this next barrier? It makes no sense.
        memoryBarrierShared();
        barrier();
 
@@ -158,5 +173,32 @@ void main()
                memoryBarrierShared();
                barrier();
        }
-       dist[base + i] = new_dist[i];
+
+       uint start = new_dist[i] - new_val;
+       uint freq = new_val;
+
+       uint x_max = ((RANS_BYTE_L >> (prob_bits + 1)) << 8) * freq;
+       uint cmpl_freq = ((1 << (prob_bits + 1)) - freq);
+       uint rcp_freq, rcp_shift, bias;
+       if (freq < 2) {
+               rcp_freq = ~0u;
+               rcp_shift = 0;
+               bias = start + (1 << (prob_bits + 1)) - 1;
+       } else {
+               uint shift = 0;
+               while (freq > (1u << shift)) {
+                       shift++;
+               }
+
+               rcp_freq = uint(((uint64_t(1) << (shift + 31)) + freq-1) / freq);
+               rcp_shift = shift - 1;
+               bias = start;
+       }
+
+       ransfreq[base + i] = freq;
+       ransdist[base + i] = uvec4(x_max, rcp_freq, bias, (cmpl_freq << 16) | rcp_shift);
+
+       if (i == 255) {
+               sign_biases[gl_WorkGroupID.x] = new_dist[i];
+       }
 }