]> git.sesse.net Git - narabu/blob - tally.shader
More fixes of hard-coded values.
[narabu] / tally.shader
1 #version 440
2 #extension GL_ARB_gpu_shader_int64 : enable
3
4 // http://cbloomrants.blogspot.no/2014/02/02-11-14-understanding-ans-10.html
5
6 layout(local_size_x = 256) in;
7
8 layout(std430, binding = 9) buffer layoutName
9 {
10         uint dist[4 * 256];
11         uint ransfreq[4 * 256];
12 };
13
14 layout(std140, binding = 12) buffer distBlock  // Will become an UBO to rans.shader, thus layout std140.
15 {
16         uvec4 ransdist[4 * 256];
17         uint sign_biases[4];
18 };
19
20 const uint prob_bits = 12;
21 const uint prob_scale = 1 << prob_bits;
22 const uint RANS_BYTE_L = (1u << 23);
23
24 // FIXME: should come through a uniform.
25 const uint sums[4] = { 57600, 115200, 302400, 446400 };
26
27 shared uint new_dist[256];
28 shared uint shared_tmp = 0;
29
30 // Multiple voting areas for the min/max reductions, so we don't need to clear them
31 // for every round.
32 shared uint voting_areas[256];
33
34 void main()
35 {
36         uint base = 256 * gl_WorkGroupID.x;
37         uint sum = sums[gl_WorkGroupID.x];
38         uint zero_correction;
39         uint i = gl_LocalInvocationID.x;  // Which element we are working on.
40
41         // Zero has no sign bit. Yes, this is trickery.
42         if (i == 0) {
43                 uint old_zero_freq = dist[base];
44                 uint new_zero_freq = (old_zero_freq + 1) / 2;  // Rounding up ensures we get a zero frequency.
45                 zero_correction = old_zero_freq - new_zero_freq;
46                 shared_tmp = sum - zero_correction;
47                 dist[base] = new_zero_freq;
48         }
49
50         memoryBarrierShared();
51         barrier();
52
53         sum = shared_tmp;
54
55         // Normalize the pdf, taking care to never make a nonzero frequency into a zero.
56         // This rounding is presumably optimal according to cbloom.
57         float true_prob = float(dist[base + i]) / sum;
58         float from_scaled = true_prob * prob_scale;
59         float down = floor(from_scaled);
60         uint new_val = (from_scaled * from_scaled <= down * (down + 1)) ? uint(down) : uint(down) + 1;
61         if (dist[base + i] > 0) {
62                 new_val = max(new_val, 1);
63         }
64
65         // Reset shared_tmp. We could do without this barrier if we wanted to,
66         // but meh. :-)
67         if (i == 0) {
68                 shared_tmp = 0;
69         }
70
71         memoryBarrierShared();
72         barrier();
73
74         // Parallel sum to find the total.
75         atomicAdd(shared_tmp, new_val);
76
77         memoryBarrierShared();
78         barrier();
79
80         uint actual_sum = shared_tmp;
81
82         // Apply corrections one by one, greedily, until we are at the exact right sum.
83         if (actual_sum > prob_scale) {
84                 float loss = true_prob * log2(new_val / float(new_val - 1));
85
86                 voting_areas[i] = 0xffffffff;
87                 memoryBarrierShared();
88                 barrier();
89
90                 uint vote_no = 0;
91
92                 for ( ; actual_sum != prob_scale; --actual_sum, ++vote_no) {
93                         memoryBarrierShared();
94                         barrier();
95
96                         // Stick the thread ID in the lower mantissa bits so we never get a tie.
97                         uint my_vote = (floatBitsToUint(loss) & ~0xffu) | gl_LocalInvocationID.x;
98                         if (new_val <= 1) {
99                                 // We can't touch this one any more, but it needs to participate in the barriers,
100                                 // so we can't break.
101                                 my_vote = 0xffffffff;
102                         }
103
104                         // Find out which thread has the one with the smallest loss.
105                         // (Positive floats compare like uints just fine.)
106                         voting_areas[vote_no] = atomicMin(voting_areas[vote_no], my_vote);
107                         memoryBarrierShared();
108                         barrier();
109
110                         if (my_vote == voting_areas[vote_no]) {
111                                 --new_val;
112                                 loss = true_prob * log2(new_val / float(new_val - 1));
113                         }
114                 }
115         } else {
116                 float benefit = -true_prob * log2(new_val / float(new_val + 1));
117
118                 voting_areas[i] = 0;
119                 memoryBarrierShared();
120                 barrier();
121
122                 uint vote_no = 0;
123
124                 for ( ; actual_sum != prob_scale; ++actual_sum, ++vote_no) {
125                         // Stick the thread ID in the lower mantissa bits so we never get a tie.
126                         uint my_vote = (floatBitsToUint(benefit) & ~0xffu) | gl_LocalInvocationID.x;
127                         if (new_val == 0) {
128                                 // It's meaningless to increase this, but it needs to participate in the barriers,
129                                 // so we can't break.
130                                 my_vote = 0;
131                         }
132
133                         // Find out which thread has the one with the most benefit.
134                         // (Positive floats compare like uints just fine.)
135                         voting_areas[vote_no] = atomicMax(voting_areas[vote_no], my_vote);
136                         memoryBarrierShared();
137                         barrier();
138
139                         if (my_vote == voting_areas[vote_no]) {
140                                 ++new_val;
141                                 benefit = -true_prob * log2(new_val / float(new_val + 1));
142                         }
143                 }
144         }
145
146         if (i == 0) {
147                 // Undo what we did above.
148                 new_val *= 2;
149         }
150
151         // Parallel prefix sum.
152         new_dist[(i + 255) & 255u] = new_val;  // Move the zero symbol last.
153         memoryBarrierShared();
154         barrier();
155
156         new_val = new_dist[i];
157
158         // TODO: Why do we need this next barrier? It makes no sense.
159         memoryBarrierShared();
160         barrier();
161
162         for (uint layer = 2; layer <= 256; layer *= 2) {
163                 if ((i & (layer - 1)) == layer - 1) {
164                         new_dist[i] += new_dist[i - (layer / 2)];
165                 }
166                 memoryBarrierShared();
167                 barrier();
168         }
169         for (uint layer = 128; layer >= 2; layer /= 2) {
170                 if ((i & (layer - 1)) == layer - 1 && i != 255) {
171                         new_dist[i + (layer / 2)] += new_dist[i];
172                 }
173                 memoryBarrierShared();
174                 barrier();
175         }
176
177         uint start = new_dist[i] - new_val;
178         uint freq = new_val;
179
180         uint x_max = ((RANS_BYTE_L >> (prob_bits + 1)) << 8) * freq;
181         uint cmpl_freq = ((1 << (prob_bits + 1)) - freq);
182         uint rcp_freq, rcp_shift, bias;
183         if (freq < 2) {
184                 rcp_freq = ~0u;
185                 rcp_shift = 0;
186                 bias = start + (1 << (prob_bits + 1)) - 1;
187         } else {
188                 uint shift = 0;
189                 while (freq > (1u << shift)) {
190                         shift++;
191                 }
192
193                 rcp_freq = uint(((uint64_t(1) << (shift + 31)) + freq-1) / freq);
194                 rcp_shift = shift - 1;
195                 bias = start;
196         }
197
198         ransfreq[base + i] = freq;
199         ransdist[base + i] = uvec4(x_max, rcp_freq, bias, (cmpl_freq << 16) | rcp_shift);
200
201         if (i == 255) {
202                 sign_biases[gl_WorkGroupID.x] = new_dist[i];
203         }
204 }