]> git.sesse.net Git - narabu/blob - tally.shader
Silence some Mesa warnings.
[narabu] / tally.shader
#version 440

// http://cbloomrants.blogspot.no/2014/02/02-11-14-understanding-ans-10.html

layout(local_size_x = 256) in;

// Four 256-entry frequency tables; main() processes the table at
// 256 * gl_WorkGroupID.x and overwrites it in place.
layout(std430, binding = 9) buffer layoutName
{
	uint dist[4 * 256];
};

const uint prob_bits = 12u;
const uint prob_scale = 1u << prob_bits;  // Target total of each normalized table.

// FIXME: should come through a uniform.
const uint sums[4] = { 57600u, 115200u, 302400u, 446400u };

shared uint new_dist[256];

// Scratch value used to broadcast/accumulate totals across the work group.
// Shared variables may not have initializers in GLSL (their contents are
// undefined at the start of execution); main() writes it before reading it.
shared uint shared_tmp;

// Multiple voting areas for the min/max reductions, so we don't need to clear them
// for every round.
shared uint voting_areas[256];
24
// Normalizes the 256-entry frequency table dist[256 * gl_WorkGroupID.x ...] so
// that its entries sum to exactly prob_scale (never turning a nonzero count into
// zero), doubles the zero symbol's entry afterwards, and finally overwrites the
// table with its inclusive prefix sum, with the zero symbol moved to the last slot.
//
// Each invocation owns one symbol (i = gl_LocalInvocationID.x). Every barrier()
// below is reached in dynamically uniform control flow: branch and loop
// conditions only depend on shared values that were read after a barrier.
void main()
{
	uint base = 256 * gl_WorkGroupID.x;
	uint sum = sums[gl_WorkGroupID.x];
	uint i = gl_LocalInvocationID.x;  // Which element we are working on.

	// Zero has no sign bit. Yes, this is trickery.
	// Halve the zero symbol's count and publish the reduced total to the group.
	if (i == 0) {
		uint old_zero_freq = dist[base];
		uint new_zero_freq = (old_zero_freq + 1) / 2;  // Rounding up keeps a nonzero frequency nonzero.
		uint zero_correction = old_zero_freq - new_zero_freq;
		shared_tmp = sum - zero_correction;
		dist[base] = new_zero_freq;
	}

	memoryBarrierShared();
	barrier();

	sum = shared_tmp;

	// Normalize the pdf, taking care to never make a nonzero frequency into a zero.
	// This rounding is presumably optimal according to cbloom.
	uint count = dist[base + i];
	float true_prob = float(count) / float(sum);
	float from_scaled = true_prob * float(prob_scale);
	float down = floor(from_scaled);
	uint new_val = (from_scaled * from_scaled <= down * (down + 1.0)) ? uint(down) : uint(down) + 1u;
	if (count > 0u) {
		new_val = max(new_val, 1u);
	}

	// Reset shared_tmp for the parallel sum. Every invocation must have read it
	// into `sum` before invocation 0 overwrites it (otherwise a slower subgroup
	// could read the reset value), and the reset must land before any atomicAdd.
	memoryBarrierShared();
	barrier();
	if (i == 0) {
		shared_tmp = 0;
	}
	memoryBarrierShared();
	barrier();

	// Parallel sum to find the total.
	atomicAdd(shared_tmp, new_val);

	memoryBarrierShared();
	barrier();

	uint actual_sum = shared_tmp;

	// Apply corrections one by one, greedily, until we are at the exact right sum.
	// Each round votes in a fresh voting_areas[vote_no] slot, so slots are only
	// initialized once up front. NOTE(review): this assumes fewer than 256 rounds
	// (|actual_sum - prob_scale| < 256), which should hold when `sum` is the true
	// total of the table, since each entry is rounded by less than one.
	if (actual_sum > prob_scale) {
		// Expected-code-length cost of lowering new_val by one, weighted by the
		// symbol's true probability: true_prob * log2(new_val / (new_val - 1)).
		// Computed in float (not uint division) and non-negative, so its bit
		// pattern orders like a uint. Only meaningful while new_val >= 2.
		float loss = (new_val > 1u) ? true_prob * log2(float(new_val) / float(new_val - 1u)) : 0.0;

		voting_areas[i] = 0xffffffffu;
		memoryBarrierShared();
		barrier();

		uint vote_no = 0;

		for ( ; actual_sum != prob_scale; --actual_sum, ++vote_no) {
			memoryBarrierShared();
			barrier();

			// Stick the thread ID in the lower mantissa bits so we never get a tie.
			uint my_vote = (floatBitsToUint(loss) & ~0xffu) | i;
			if (new_val <= 1u) {
				// We can't touch this one any more, but it needs to participate in the barriers,
				// so we can't break.
				my_vote = 0xffffffffu;
			}

			// Find out which thread has the one with the smallest loss.
			// (Positive floats compare like uints just fine.)
			// The atomic's return value is the *old* slot contents; it must not be
			// written back, or the slot would no longer hold the minimum.
			atomicMin(voting_areas[vote_no], my_vote);
			memoryBarrierShared();
			barrier();

			if (my_vote == voting_areas[vote_no]) {
				--new_val;
				loss = (new_val > 1u) ? true_prob * log2(float(new_val) / float(new_val - 1u)) : 0.0;
			}
		}
	} else {
		// Expected-code-length saving of raising new_val by one, weighted by the
		// symbol's true probability: true_prob * log2((new_val + 1) / new_val).
		// Computed in float and non-negative. Only meaningful while new_val >= 1.
		float benefit = (new_val > 0u) ? true_prob * log2(float(new_val + 1u) / float(new_val)) : 0.0;

		voting_areas[i] = 0u;
		memoryBarrierShared();
		barrier();

		uint vote_no = 0;

		for ( ; actual_sum != prob_scale; ++actual_sum, ++vote_no) {
			// Stick the thread ID in the lower mantissa bits so we never get a tie.
			uint my_vote = (floatBitsToUint(benefit) & ~0xffu) | i;
			if (new_val == 0u) {
				// It's meaningless to increase this, but it needs to participate in the barriers,
				// so we can't break.
				my_vote = 0u;
			}

			// Find out which thread has the one with the most benefit.
			// (Positive floats compare like uints just fine.)
			// As above, do not write the atomic's old-value result back.
			atomicMax(voting_areas[vote_no], my_vote);
			memoryBarrierShared();
			barrier();

			if (my_vote == voting_areas[vote_no]) {
				++new_val;
				benefit = true_prob * log2(float(new_val + 1u) / float(new_val));
			}
		}
	}

	if (i == 0) {
		// Undo what we did above.
		new_val *= 2;
	}

	// Parallel prefix sum (up-sweep, then down-sweep) giving an inclusive scan.
	new_dist[(i + 255) & 255u] = new_val;  // Move the zero symbol last.
	memoryBarrierShared();
	barrier();

	for (uint layer = 2; layer <= 256; layer *= 2) {
		if ((i & (layer - 1)) == layer - 1) {
			new_dist[i] += new_dist[i - (layer / 2)];
		}
		memoryBarrierShared();
		barrier();
	}
	for (uint layer = 128; layer >= 2; layer /= 2) {
		if ((i & (layer - 1)) == layer - 1 && i != 255) {
			new_dist[i + (layer / 2)] += new_dist[i];
		}
		memoryBarrierShared();
		barrier();
	}
	dist[base + i] = new_dist[i];
}
160         }
161         dist[base + i] = new_dist[i];
162 }