]> git.sesse.net Git - narabu/blobdiff - tally.shader
Make the encoder 100% GPU. Not working yet, though.
[narabu] / tally.shader
index c83bfe4a850044811885b885563c6f2ba8671f30..1623a1dd06e845af873ba62adcdaa177ffe87c4e 100644 (file)
@@ -7,6 +7,7 @@ layout(local_size_x = 256) in;
 layout(std430, binding = 9) buffer layoutName
 {
        uint dist[4 * 256];
+       uvec2 ransdist[4 * 256];
 };
 
 const uint prob_bits = 12;
@@ -85,7 +86,7 @@ void main()
                        barrier();
 
                        // Stick the thread ID in the lower mantissa bits so we never get a tie.
-                       uint my_vote = (floatBitsToUint(loss) & ~0xff) | gl_LocalInvocationID.x;
+                       uint my_vote = (floatBitsToUint(loss) & ~0xffu) | gl_LocalInvocationID.x;
                        if (new_val <= 1) {
                                // We can't touch this one any more, but it needs to participate in the barriers,
                                // so we can't break.
@@ -114,7 +115,7 @@ void main()
 
                for ( ; actual_sum != prob_scale; ++actual_sum, ++vote_no) {
                        // Stick the thread ID in the lower mantissa bits so we never get a tie.
-                       uint my_vote = (floatBitsToUint(benefit) & ~0xff) | gl_LocalInvocationID.x;
+                       uint my_vote = (floatBitsToUint(benefit) & ~0xffu) | gl_LocalInvocationID.x;
                        if (new_val == 0) {
                                // It's meaningless to increase this, but it needs to participate in the barriers,
                                // so we can't break.
@@ -140,7 +141,7 @@ void main()
        }
 
        // Parallel prefix sum.
-       new_dist[(i + 255) & 255] = new_val;  // Move the zero symbol last.
+       new_dist[(i + 255) & 255u] = new_val;  // Move the zero symbol last.
        memoryBarrierShared();
        barrier();
 
@@ -158,5 +159,5 @@ void main()
                memoryBarrierShared();
                barrier();
        }
-       dist[base + i] = new_dist[i];
+       ransdist[base + i] = uvec2(new_val, new_dist[i]);
 }