git.sesse.net Git - narabu/blob - tally.shader
Make the encoder 100% GPU. Not working yet, though.
[narabu] / tally.shader
#version 440

// http://cbloomrants.blogspot.no/2014/02/02-11-14-understanding-ans-10.html
//
// Normalizes symbol histograms to a fixed probability scale (prob_scale) for
// rANS coding. main() indexes tables by gl_WorkGroupID.x (256 entries each)
// and symbols by gl_LocalInvocationID.x.

layout(local_size_x = 256) in;

// dist:     input histograms, 4 tables x 256 symbols. main() also overwrites
//           the zero symbol's entry of each table in place (halving it).
// ransdist: output (frequency, cumulative frequency) pairs, same layout.
layout(std430, binding = 9) buffer layoutName
{
	uint dist[4 * 256];
	uvec2 ransdist[4 * 256];
};

const uint prob_bits = 12;
const uint prob_scale = 1u << prob_bits;  // Normalized frequencies sum to this.

// Normalizing divisor for each table; presumably the total of that table's
// dist[] entries -- NOTE(review): confirm against the producer of dist.
// FIXME: should come through a uniform.
const uint sums[4] = { 57600u, 115200u, 302400u, 446400u };

// Scratch array for the parallel prefix sum over the normalized frequencies.
shared uint new_dist[256];

// Single-value scratch shared by the whole work group (broadcast, then an
// atomic accumulator). GLSL does not allow initializers on shared variables,
// so main() must write it before anyone reads it.
shared uint shared_tmp;

// Multiple voting areas for the min/max reductions, so we don't need to clear them
// for every round.
shared uint voting_areas[256];
25
// Normalizes one table's histogram (table = work group, symbol = invocation)
// so that the frequencies sum exactly to prob_scale, never turning a nonzero
// count into a zero frequency. It then computes cumulative frequencies and
// writes (frequency, cumulative frequency) pairs to ransdist.
void main()
{
	uint base = 256u * gl_WorkGroupID.x;
	uint sum = sums[gl_WorkGroupID.x];
	uint i = gl_LocalInvocationID.x;  // Which element we are working on.

	// Zero has no sign bit. Yes, this is trickery.
	// Halve the zero symbol's count and shrink the total to match. Its
	// frequency is doubled again after normalization, further below.
	if (i == 0u) {
		uint old_zero_freq = dist[base];
		uint new_zero_freq = (old_zero_freq + 1u) / 2u;  // Rounding up keeps a nonzero count nonzero.
		uint zero_correction = old_zero_freq - new_zero_freq;
		shared_tmp = sum - zero_correction;
		dist[base] = new_zero_freq;
	}

	memoryBarrierShared();
	barrier();

	sum = shared_tmp;

	// Normalize the pdf, taking care to never make a nonzero frequency into a zero.
	// This rounding is presumably optimal according to cbloom: round down iff
	// from_scaled is at most the geometric mean of down and down + 1.
	float true_prob = float(dist[base + i]) / float(sum);
	float from_scaled = true_prob * float(prob_scale);
	float down = floor(from_scaled);
	uint new_val = (from_scaled * from_scaled <= down * (down + 1.0)) ? uint(down) : uint(down) + 1u;
	if (dist[base + i] > 0u) {
		new_val = max(new_val, 1u);
	}

	// Every invocation must have copied shared_tmp into `sum` before it is
	// reset. Without this barrier, invocation 0 could zero it while others
	// still have to read it, and those would then divide by zero.
	barrier();
	if (i == 0u) {
		shared_tmp = 0u;
	}
	memoryBarrierShared();
	barrier();

	// Parallel sum to find the total.
	atomicAdd(shared_tmp, new_val);

	memoryBarrierShared();
	barrier();

	uint actual_sum = shared_tmp;

	// Apply corrections one by one, greedily, until we are at the exact right sum.
	//
	// Each round, every invocation votes in a fresh voting_areas slot via an
	// atomic min/max, and the single winner adjusts its frequency by one.
	// Votes are positive probability-weighted bit costs, so their IEEE bit
	// patterns order like the floats themselves. The thread ID is stuck in
	// the low mantissa bits so there is never a tie. Invocations that may not
	// change keep looping with a sentinel vote, so that all of them reach
	// every barrier(). Control flow stays uniform because actual_sum is
	// identical in all invocations.
	//
	// NOTE(review): each round uses one of the 256 voting_areas slots, so
	// this assumes fewer than 256 corrections are ever needed.
	if (actual_sum > prob_scale) {
		voting_areas[i] = 0xffffffffu;
		memoryBarrierShared();
		barrier();

		for (uint vote_no = 0u; actual_sum != prob_scale; --actual_sum, ++vote_no) {
			// Frequencies of 0 or 1 must not be decremented.
			bool can_vote = (new_val > 1u);
			uint my_vote = 0xffffffffu;
			if (can_vote) {
				// Extra bits spent by going from new_val to new_val - 1.
				float loss = true_prob * log2(float(new_val) / float(new_val - 1u));
				my_vote = (floatBitsToUint(loss) & ~0xffu) | i;
			}

			// Find out which thread has the one with the smallest loss.
			// (atomicMin returns the *old* value; it must not be stored back.)
			atomicMin(voting_areas[vote_no], my_vote);
			memoryBarrierShared();
			barrier();

			if (can_vote && my_vote == voting_areas[vote_no]) {
				--new_val;
			}
		}
	} else {
		voting_areas[i] = 0u;
		memoryBarrierShared();
		barrier();

		for (uint vote_no = 0u; actual_sum != prob_scale; ++actual_sum, ++vote_no) {
			// It's meaningless to increase a zero frequency.
			bool can_vote = (new_val != 0u);
			uint my_vote = 0u;
			if (can_vote) {
				// Bits saved by going from new_val to new_val + 1.
				float benefit = true_prob * log2(float(new_val + 1u) / float(new_val));
				my_vote = (floatBitsToUint(benefit) & ~0xffu) | i;
			}

			// Find out which thread has the one with the most benefit.
			// (atomicMax returns the *old* value; it must not be stored back.)
			atomicMax(voting_areas[vote_no], my_vote);
			memoryBarrierShared();
			barrier();

			if (can_vote && my_vote == voting_areas[vote_no]) {
				++new_val;
			}
		}
	}

	if (i == 0u) {
		// Undo the halving of the zero symbol from the top of main().
		// NOTE(review): after this, the frequencies total prob_scale plus the
		// halved zero frequency, not prob_scale -- confirm this is intended.
		new_val *= 2u;
	}

	// Parallel inclusive prefix sum over new_dist. First an up-sweep builds
	// partial sums at power-of-two boundaries, then a down-sweep propagates
	// them to the remaining entries.
	new_dist[(i + 255u) & 255u] = new_val;  // Move the zero symbol last.
	memoryBarrierShared();
	barrier();

	for (uint layer = 2u; layer <= 256u; layer *= 2u) {
		if ((i & (layer - 1u)) == layer - 1u) {
			new_dist[i] += new_dist[i - (layer / 2u)];
		}
		memoryBarrierShared();
		barrier();
	}
	for (uint layer = 128u; layer >= 2u; layer /= 2u) {
		if ((i & (layer - 1u)) == layer - 1u && i != 255u) {
			new_dist[i + (layer / 2u)] += new_dist[i];
		}
		memoryBarrierShared();
		barrier();
	}

	// Symbol i was stored in slot (i + 255) & 255, not slot i, so its
	// cumulative value must be read from there. The value is inclusive: it
	// marks the end of symbol i's range, and the start is that minus new_val.
	// NOTE(review): confirm whether the consumer of ransdist expects the
	// start or the end of the range.
	ransdist[base + i] = uvec2(new_val, new_dist[(i + 255u) & 255u]);
}