]> git.sesse.net Git - narabu/blobdiff - tally.shader
Fix some off-by-ones in the tally shader.
[narabu] / tally.shader
index c83bfe4a850044811885b885563c6f2ba8671f30..fba526b96b6c29be33db626098a63a50b9ddab22 100644 (file)
@@ -7,6 +7,7 @@ layout(local_size_x = 256) in;
 layout(std430, binding = 9) buffer layoutName
 {
        uint dist[4 * 256];
+       uvec2 ransdist[4 * 256];
 };
 
 const uint prob_bits = 12;
@@ -72,7 +73,7 @@ void main()
 
        // Apply corrections one by one, greedily, until we are at the exact right sum.
        if (actual_sum > prob_scale) {
-               float loss = -true_prob * log2(new_val / (new_val - 1));
+               float loss = true_prob * log2(new_val / float(new_val - 1));
 
                voting_areas[i] = 0xffffffff;
                memoryBarrierShared();
@@ -85,7 +86,7 @@ void main()
                        barrier();
 
                        // Stick the thread ID in the lower mantissa bits so we never get a tie.
-                       uint my_vote = (floatBitsToUint(loss) & ~0xff) | gl_LocalInvocationID.x;
+                       uint my_vote = (floatBitsToUint(loss) & ~0xffu) | gl_LocalInvocationID.x;
                        if (new_val <= 1) {
                                // We can't touch this one any more, but it needs to participate in the barriers,
                                // so we can't break.
@@ -100,11 +101,11 @@ void main()
 
                        if (my_vote == voting_areas[vote_no]) {
                                --new_val;
-                               loss = -true_prob * log2(new_val / (new_val - 1));
+                               loss = true_prob * log2(new_val / float(new_val - 1));
                        }
                }
        } else {
-               float benefit = true_prob * log2(new_val / (new_val + 1));
+               float benefit = -true_prob * log2(new_val / float(new_val + 1));
 
                voting_areas[i] = 0;
                memoryBarrierShared();
@@ -114,7 +115,7 @@ void main()
 
                for ( ; actual_sum != prob_scale; ++actual_sum, ++vote_no) {
                        // Stick the thread ID in the lower mantissa bits so we never get a tie.
-                       uint my_vote = (floatBitsToUint(benefit) & ~0xff) | gl_LocalInvocationID.x;
+                       uint my_vote = (floatBitsToUint(benefit) & ~0xffu) | gl_LocalInvocationID.x;
                        if (new_val == 0) {
                                // It's meaningless to increase this, but it needs to participate in the barriers,
                                // so we can't break.
@@ -129,7 +130,7 @@ void main()
 
                        if (my_vote == voting_areas[vote_no]) {
                                ++new_val;
-                               benefit = true_prob * log2(new_val / (new_val + 1));
+                               benefit = -true_prob * log2(new_val / float(new_val + 1));
                        }
                }
        }
@@ -140,10 +141,12 @@ void main()
        }
 
        // Parallel prefix sum.
-       new_dist[(i + 255) & 255] = new_val;  // Move the zero symbol last.
+       new_dist[(i + 255) & 255u] = new_val;  // Move the zero symbol last.
        memoryBarrierShared();
        barrier();
 
+       new_val = new_dist[i];
+
        for (uint layer = 2; layer <= 256; layer *= 2) {
                if ((i & (layer - 1)) == layer - 1) {
                        new_dist[i] += new_dist[i - (layer / 2)];
@@ -158,5 +161,5 @@ void main()
                memoryBarrierShared();
                barrier();
        }
-       dist[base + i] = new_dist[i];
+       ransdist[base + i] = uvec2(new_dist[i] - new_val, new_val);
 }