X-Git-Url: https://git.sesse.net/?a=blobdiff_plain;f=tally.shader;h=fba526b96b6c29be33db626098a63a50b9ddab22;hb=4e20a14f8ca0bc3259fa2be5bbbd4057080ce62c;hp=c83bfe4a850044811885b885563c6f2ba8671f30;hpb=fd6116de8d7253bed230222bf277a7c8aaa3b8ff;p=narabu diff --git a/tally.shader b/tally.shader index c83bfe4..fba526b 100644 --- a/tally.shader +++ b/tally.shader @@ -7,6 +7,7 @@ layout(local_size_x = 256) in; layout(std430, binding = 9) buffer layoutName { uint dist[4 * 256]; + uvec2 ransdist[4 * 256]; }; const uint prob_bits = 12; @@ -72,7 +73,7 @@ void main() // Apply corrections one by one, greedily, until we are at the exact right sum. if (actual_sum > prob_scale) { - float loss = -true_prob * log2(new_val / (new_val - 1)); + float loss = true_prob * log2(new_val / float(new_val - 1)); voting_areas[i] = 0xffffffff; memoryBarrierShared(); @@ -85,7 +86,7 @@ void main() barrier(); // Stick the thread ID in the lower mantissa bits so we never get a tie. - uint my_vote = (floatBitsToUint(loss) & ~0xff) | gl_LocalInvocationID.x; + uint my_vote = (floatBitsToUint(loss) & ~0xffu) | gl_LocalInvocationID.x; if (new_val <= 1) { // We can't touch this one any more, but it needs to participate in the barriers, // so we can't break. @@ -100,11 +101,11 @@ void main() if (my_vote == voting_areas[vote_no]) { --new_val; - loss = -true_prob * log2(new_val / (new_val - 1)); + loss = true_prob * log2(new_val / float(new_val - 1)); } } } else { - float benefit = true_prob * log2(new_val / (new_val + 1)); + float benefit = -true_prob * log2(new_val / float(new_val + 1)); voting_areas[i] = 0; memoryBarrierShared(); @@ -114,7 +115,7 @@ void main() for ( ; actual_sum != prob_scale; ++actual_sum, ++vote_no) { // Stick the thread ID in the lower mantissa bits so we never get a tie. - uint my_vote = (floatBitsToUint(benefit) & ~0xff) | gl_LocalInvocationID.x; + uint my_vote = (floatBitsToUint(benefit) & ~0xffu) | gl_LocalInvocationID.x; if (new_val == 0) { // It's meaningless to increase this, but it needs to participate in the barriers, // so we can't break. @@ -129,7 +130,7 @@ void main() if (my_vote == voting_areas[vote_no]) { ++new_val; - benefit = true_prob * log2(new_val / (new_val + 1)); + benefit = -true_prob * log2(new_val / float(new_val + 1)); } } } @@ -140,10 +141,12 @@ void main() } // Parallel prefix sum. - new_dist[(i + 255) & 255] = new_val; // Move the zero symbol last. + new_dist[(i + 255) & 255u] = new_val; // Move the zero symbol last. memoryBarrierShared(); barrier(); + new_val = new_dist[i]; + for (uint layer = 2; layer <= 256; layer *= 2) { if ((i & (layer - 1)) == layer - 1) { new_dist[i] += new_dist[i - (layer / 2)]; @@ -158,5 +161,5 @@ void main() memoryBarrierShared(); barrier(); } - dist[base + i] = new_dist[i]; + ransdist[base + i] = uvec2(new_dist[i] - new_val, new_val); }