]> git.sesse.net Git - narabu/blob - tally.shader
351d3fb54e7aeb22dad59ec10019199382799663
[narabu] / tally.shader
1 #version 440
2 #extension GL_NV_gpu_shader5 : enable
3
4 // http://cbloomrants.blogspot.no/2014/02/02-11-14-understanding-ans-10.html
5
6 layout(local_size_x = 256) in;
7
8 layout(std430, binding = 9) buffer layoutName
9 {
10         uint dist[4 * 256];
11         uint ransfreq[4 * 256];
12         uvec4 ransdist[4 * 256];
13         uint sign_biases[4];
14 };
15
16 const uint prob_bits = 12;
17 const uint prob_scale = 1 << prob_bits;
18 const uint RANS_BYTE_L = (1u << 23);
19
20 // FIXME: should come through a uniform.
21 const uint sums[4] = { 57600, 115200, 302400, 446400 };
22
23 shared uint new_dist[256];
24 shared uint shared_tmp = 0;
25
26 // Multiple voting areas for the min/max reductions, so we don't need to clear them
27 // for every round.
28 shared uint voting_areas[256];
29
30 void main()
31 {
32         uint base = 256 * gl_WorkGroupID.x;
33         uint sum = sums[gl_WorkGroupID.x];
34         uint zero_correction;
35         uint i = gl_LocalInvocationID.x;  // Which element we are working on.
36
37         // Zero has no sign bit. Yes, this is trickery.
38         if (i == 0) {
39                 uint old_zero_freq = dist[base];
40                 uint new_zero_freq = (old_zero_freq + 1) / 2;  // Rounding up ensures we get a zero frequency.
41                 zero_correction = old_zero_freq - new_zero_freq;
42                 shared_tmp = sum - zero_correction;
43                 dist[base] = new_zero_freq;
44         }
45
46         memoryBarrierShared();
47         barrier();
48
49         sum = shared_tmp;
50
51         // Normalize the pdf, taking care to never make a nonzero frequency into a zero.
52         // This rounding is presumably optimal according to cbloom.
53         float true_prob = float(dist[base + i]) / sum;
54         float from_scaled = true_prob * prob_scale;
55         float down = floor(from_scaled);
56         uint new_val = (from_scaled * from_scaled <= down * (down + 1)) ? uint(down) : uint(down) + 1;
57         if (dist[base + i] > 0) {
58                 new_val = max(new_val, 1);
59         }
60
61         // Reset shared_tmp. We could do without this barrier if we wanted to,
62         // but meh. :-)
63         if (i == 0) {
64                 shared_tmp = 0;
65         }
66
67         memoryBarrierShared();
68         barrier();
69
70         // Parallel sum to find the total.
71         atomicAdd(shared_tmp, new_val);
72
73         memoryBarrierShared();
74         barrier();
75
76         uint actual_sum = shared_tmp;
77
78         // Apply corrections one by one, greedily, until we are at the exact right sum.
79         if (actual_sum > prob_scale) {
80                 float loss = true_prob * log2(new_val / float(new_val - 1));
81
82                 voting_areas[i] = 0xffffffff;
83                 memoryBarrierShared();
84                 barrier();
85
86                 uint vote_no = 0;
87
88                 for ( ; actual_sum != prob_scale; --actual_sum, ++vote_no) {
89                         memoryBarrierShared();
90                         barrier();
91
92                         // Stick the thread ID in the lower mantissa bits so we never get a tie.
93                         uint my_vote = (floatBitsToUint(loss) & ~0xffu) | gl_LocalInvocationID.x;
94                         if (new_val <= 1) {
95                                 // We can't touch this one any more, but it needs to participate in the barriers,
96                                 // so we can't break.
97                                 my_vote = 0xffffffff;
98                         }
99
100                         // Find out which thread has the one with the smallest loss.
101                         // (Positive floats compare like uints just fine.)
102                         voting_areas[vote_no] = atomicMin(voting_areas[vote_no], my_vote);
103                         memoryBarrierShared();
104                         barrier();
105
106                         if (my_vote == voting_areas[vote_no]) {
107                                 --new_val;
108                                 loss = true_prob * log2(new_val / float(new_val - 1));
109                         }
110                 }
111         } else {
112                 float benefit = -true_prob * log2(new_val / float(new_val + 1));
113
114                 voting_areas[i] = 0;
115                 memoryBarrierShared();
116                 barrier();
117
118                 uint vote_no = 0;
119
120                 for ( ; actual_sum != prob_scale; ++actual_sum, ++vote_no) {
121                         // Stick the thread ID in the lower mantissa bits so we never get a tie.
122                         uint my_vote = (floatBitsToUint(benefit) & ~0xffu) | gl_LocalInvocationID.x;
123                         if (new_val == 0) {
124                                 // It's meaningless to increase this, but it needs to participate in the barriers,
125                                 // so we can't break.
126                                 my_vote = 0;
127                         }
128
129                         // Find out which thread has the one with the most benefit.
130                         // (Positive floats compare like uints just fine.)
131                         voting_areas[vote_no] = atomicMax(voting_areas[vote_no], my_vote);
132                         memoryBarrierShared();
133                         barrier();
134
135                         if (my_vote == voting_areas[vote_no]) {
136                                 ++new_val;
137                                 benefit = -true_prob * log2(new_val / float(new_val + 1));
138                         }
139                 }
140         }
141
142         if (i == 0) {
143                 // Undo what we did above.
144                 new_val *= 2;
145         }
146
147         // Parallel prefix sum.
148         new_dist[(i + 255) & 255u] = new_val;  // Move the zero symbol last.
149         memoryBarrierShared();
150         barrier();
151
152         new_val = new_dist[i];
153
154         // TODO: Why do we need this next barrier? It makes no sense.
155         memoryBarrierShared();
156         barrier();
157
158         for (uint layer = 2; layer <= 256; layer *= 2) {
159                 if ((i & (layer - 1)) == layer - 1) {
160                         new_dist[i] += new_dist[i - (layer / 2)];
161                 }
162                 memoryBarrierShared();
163                 barrier();
164         }
165         for (uint layer = 128; layer >= 2; layer /= 2) {
166                 if ((i & (layer - 1)) == layer - 1 && i != 255) {
167                         new_dist[i + (layer / 2)] += new_dist[i];
168                 }
169                 memoryBarrierShared();
170                 barrier();
171         }
172
173         uint start = new_dist[i] - new_val;
174         uint freq = new_val;
175
176         uint x_max = ((RANS_BYTE_L >> (prob_bits + 1)) << 8) * freq;
177         uint cmpl_freq = ((1 << (prob_bits + 1)) - freq);
178         uint rcp_freq, rcp_shift, bias;
179         if (freq < 2) {
180                 rcp_freq = ~0u;
181                 rcp_shift = 0;
182                 bias = start + (1 << (prob_bits + 1)) - 1;
183         } else {
184                 uint shift = 0;
185                 while (freq > (1u << shift)) {
186                         shift++;
187                 }
188
189                 rcp_freq = uint(((uint64_t(1) << (shift + 31)) + freq-1) / freq);
190                 rcp_shift = shift - 1;
191                 bias = start;
192         }
193
194         ransfreq[base + i] = freq;
195         ransdist[base + i] = uvec4(x_max, rcp_freq, bias, (cmpl_freq << 16) | rcp_shift);
196
197         if (i == 255) {
198                 sign_biases[gl_WorkGroupID.x] = new_dist[i];
199         }
200 }