]> git.sesse.net Git - narabu/blob - ryg_rans/main_alias.cpp
More fixes of hard-coded values.
[narabu] / ryg_rans / main_alias.cpp
1 #include "platform.h"
2 #include <stdio.h>
3 #include <stdarg.h>
4 #include <stdlib.h>
5 #include <stdint.h>
6 #include <string.h>
7 #include <assert.h>
8
9 #include "rans_byte.h"
10
11 static void panic(const char *fmt, ...)
12 {
13     va_list arg;
14
15     va_start(arg, fmt);
16     fputs("Error: ", stderr);
17     vfprintf(stderr, fmt, arg);
18     va_end(arg);
19     fputs("\n", stderr);
20
21     exit(1);
22 }
23
24 static uint8_t* read_file(char const* filename, size_t* out_size)
25 {
26     FILE* f = fopen(filename, "rb");
27     if (!f)
28         panic("file not found: %s\n", filename);
29
30     fseek(f, 0, SEEK_END);
31     size_t size = ftell(f);
32     fseek(f, 0, SEEK_SET);
33
34     uint8_t* buf = new uint8_t[size];
35     if (fread(buf, size, 1, f) != 1)
36         panic("read failed\n");
37
38     fclose(f);
39     if (out_size)
40         *out_size = size;
41
42     return buf;
43 }
44
45 // ---- Stats
46
47 struct SymbolStats
48 {
49     static const int LOG2NSYMS = 8;
50     static const int NSYMS = 1 << LOG2NSYMS;
51
52     uint32_t freqs[NSYMS];
53     uint32_t cum_freqs[NSYMS + 1];
54
55     // alias table
56     uint32_t divider[NSYMS];
57     uint32_t slot_adjust[NSYMS*2];
58     uint32_t slot_freqs[NSYMS*2];
59     uint8_t sym_id[NSYMS*2];
60
61     // for encoder
62     uint32_t* alias_remap;
63
64     SymbolStats() : alias_remap(0) {}
65     ~SymbolStats() { delete[] alias_remap; }
66
67     void count_freqs(uint8_t const* in, size_t nbytes);
68     void calc_cum_freqs();
69     void normalize_freqs(uint32_t target_total);
70
71     void make_alias_table();
72 };
73
74 void SymbolStats::count_freqs(uint8_t const* in, size_t nbytes)
75 {
76     for (int i=0; i < NSYMS; i++)
77         freqs[i] = 0;
78
79     for (size_t i=0; i < nbytes; i++)
80         freqs[in[i]]++;
81 }
82
83 void SymbolStats::calc_cum_freqs()
84 {
85     cum_freqs[0] = 0;
86     for (int i=0; i < NSYMS; i++)
87         cum_freqs[i+1] = cum_freqs[i] + freqs[i];
88 }
89
90 void SymbolStats::normalize_freqs(uint32_t target_total)
91 {
92     assert(target_total >= NSYMS);
93     
94     calc_cum_freqs();
95     uint32_t cur_total = cum_freqs[NSYMS];
96     
97     // resample distribution based on cumulative freqs
98     for (int i = 1; i <= NSYMS; i++)
99         cum_freqs[i] = ((uint64_t)target_total * cum_freqs[i])/cur_total;
100
101     // if we nuked any non-0 frequency symbol to 0, we need to steal
102     // the range to make the frequency nonzero from elsewhere.
103     //
104     // this is not at all optimal, i'm just doing the first thing that comes to mind.
105     for (int i=0; i < NSYMS; i++) {
106         if (freqs[i] && cum_freqs[i+1] == cum_freqs[i]) {
107             // symbol i was set to zero freq
108
109             // find best symbol to steal frequency from (try to steal from low-freq ones)
110             uint32_t best_freq = ~0u;
111             int best_steal = -1;
112             for (int j=0; j < NSYMS; j++) {
113                 uint32_t freq = cum_freqs[j+1] - cum_freqs[j];
114                 if (freq > 1 && freq < best_freq) {
115                     best_freq = freq;
116                     best_steal = j;
117                 }
118             }
119             assert(best_steal != -1);
120
121             // and steal from it!
122             if (best_steal < i) {
123                 for (int j = best_steal + 1; j <= i; j++)
124                     cum_freqs[j]--;
125             } else {
126                 assert(best_steal > i);
127                 for (int j = i + 1; j <= best_steal; j++)
128                     cum_freqs[j]++;
129             }
130         }
131     }
132
133     // calculate updated freqs and make sure we didn't screw anything up
134     assert(cum_freqs[0] == 0 && cum_freqs[NSYMS] == target_total);
135     for (int i=0; i < NSYMS; i++) {
136         if (freqs[i] == 0)
137             assert(cum_freqs[i+1] == cum_freqs[i]);
138         else
139             assert(cum_freqs[i+1] > cum_freqs[i]);
140
141         // calc updated freq
142         freqs[i] = cum_freqs[i+1] - cum_freqs[i];
143     }
144 }
145
146 // Set up the alias table.
147 void SymbolStats::make_alias_table()
148 {
149     // verify that our distribution sum divides the number of buckets
150     uint32_t sum = cum_freqs[NSYMS];
151     assert(sum != 0 && (sum % NSYMS) == 0);
152     assert(sum >= NSYMS);
153
154     // target size in every bucket
155     uint32_t tgt_sum = sum / NSYMS;
156
157     // okay, prepare a sweep of vose's algorithm to distribute
158     // the symbols into buckets
159     uint32_t remaining[NSYMS];
160     for (int i=0; i < NSYMS; i++) {
161         remaining[i] = freqs[i];
162         divider[i] = tgt_sum;
163         sym_id[i*2 + 0] = i;
164         sym_id[i*2 + 1] = i;
165     }
166
167     // a "small" symbol is one with less than tgt_sum slots left to distribute
168     // a "large" symbol is one with >=tgt_sum slots.
169     // find initial small/large buckets
170     int cur_large = 0;
171     int cur_small = 0;
172     while (cur_large < NSYMS && remaining[cur_large] < tgt_sum)
173         cur_large++;
174     while (cur_small < NSYMS && remaining[cur_small] >= tgt_sum)
175         cur_small++;
176
177     // cur_small is definitely a small bucket
178     // next_small *might* be.
179     int next_small = cur_small + 1;
180
181     // top up small buckets from large buckets until we're done
182     // this might turn the large bucket we stole from into a small bucket itself.
183     while (cur_large < NSYMS && cur_small < NSYMS) {
184         // this bucket is split between cur_small and cur_large
185         sym_id[cur_small*2 + 0] = cur_large;
186         divider[cur_small] = remaining[cur_small];
187
188         // take the amount we took out of cur_large's bucket
189         remaining[cur_large] -= tgt_sum - divider[cur_small];
190
191         // if the large bucket is still large *or* we haven't processed it yet...
192         if (remaining[cur_large] >= tgt_sum || next_small <= cur_large) {
193             // find the next small bucket to process
194             cur_small = next_small;
195             while (cur_small < NSYMS && remaining[cur_small] >= tgt_sum)
196                 cur_small++;
197             next_small = cur_small + 1;
198         } else // the large bucket we just made small is behind us, need to back-track
199             cur_small = cur_large;
200
201         // if cur_large isn't large anymore, forward to a bucket that is
202         while (cur_large < NSYMS && remaining[cur_large] < tgt_sum)
203             cur_large++;
204     }
205
206     // okay, we now have our alias mapping; distribute the code slots in order
207     uint32_t assigned[NSYMS] = { 0 };
208     alias_remap = new uint32_t[sum];
209
210     for (int i=0; i < NSYMS; i++) {
211         int j = sym_id[i*2 + 0];
212         uint32_t sym0_height = divider[i];
213         uint32_t sym1_height = tgt_sum - divider[i];
214         uint32_t base0 = assigned[i];
215         uint32_t base1 = assigned[j];
216         uint32_t cbase0 = cum_freqs[i] + base0;
217         uint32_t cbase1 = cum_freqs[j] + base1;
218
219         divider[i] = i*tgt_sum + sym0_height;
220
221         slot_freqs[i*2 + 1] = freqs[i];
222         slot_freqs[i*2 + 0] = freqs[j];
223         slot_adjust[i*2 + 1] = i*tgt_sum - base0;
224         slot_adjust[i*2 + 0] = i*tgt_sum - (base1 - sym0_height);
225         for (uint32_t k=0; k < sym0_height; k++)
226             alias_remap[cbase0 + k] = k + i*tgt_sum;
227         for (uint32_t k=0; k < sym1_height; k++)
228             alias_remap[cbase1 + k] = (k + sym0_height) + i*tgt_sum;
229
230         assigned[i] += sym0_height;
231         assigned[j] += sym1_height;
232     }
233
234     // check that each symbol got the number of slots it needed
235     for (int i=0; i < NSYMS; i++)
236         assert(assigned[i] == freqs[i]);
237 }
238
239 // ---- rANS encoding/decoding with alias table
240
241 static inline void RansEncPutAlias(RansState* r, uint8_t** pptr, SymbolStats* const syms, int s, uint32_t scale_bits)
242 {
243     // renormalize
244     uint32_t freq = syms->freqs[s];
245     RansState x = RansEncRenorm(*r, pptr, freq, scale_bits);
246
247     // x = C(s,x)
248     // NOTE: alias_remap here could be replaced with e.g. a binary search.
249     *r = ((x / freq) << scale_bits) + syms->alias_remap[(x % freq) + syms->cum_freqs[s]];
250 }
251
252 static inline uint32_t RansDecGetAlias(RansState* r, SymbolStats* const syms, uint32_t scale_bits)
253 {
254     RansState x = *r;
255
256     // figure out symbol via alias table
257     uint32_t mask = (1u << scale_bits) - 1; // constant for fixed scale_bits!
258     uint32_t xm = x & mask;
259     uint32_t bucket_id = xm >> (scale_bits - SymbolStats::LOG2NSYMS);
260     uint32_t bucket2 = bucket_id * 2;
261     if (xm < syms->divider[bucket_id]) 
262         bucket2++;
263
264     // s, x = D(x)
265     *r = syms->slot_freqs[bucket2] * (x >> scale_bits) + xm - syms->slot_adjust[bucket2];
266     return syms->sym_id[bucket2];
267 }
268
269 // ----
270
271 int main()
272 {
273     size_t in_size;
274     uint8_t* in_bytes = read_file("book1", &in_size);
275
276     static const uint32_t prob_bits = 16;
277     static const uint32_t prob_scale = 1 << prob_bits;
278
279     SymbolStats stats;
280     stats.count_freqs(in_bytes, in_size);
281     stats.normalize_freqs(prob_scale);
282     stats.make_alias_table();
283
284     static size_t out_max_size = 32<<20; // 32MB
285     uint8_t* out_buf = new uint8_t[out_max_size];
286     uint8_t* dec_bytes = new uint8_t[in_size];
287
288     // try rANS encode
289     uint8_t *rans_begin;
290
291     // ---- regular rANS encode/decode. Typical usage.
292
293     memset(dec_bytes, 0xcc, in_size);
294
295     printf("rANS encode:\n");
296     for (int run=0; run < 5; run++) {
297         double start_time = timer();
298         uint64_t enc_start_time = __rdtsc();
299
300         RansState rans;
301         RansEncInit(&rans);
302
303         uint8_t* ptr = out_buf + out_max_size; // *end* of output buffer
304         for (size_t i=in_size; i > 0; i--) { // NB: working in reverse!
305             int s = in_bytes[i-1];
306             RansEncPutAlias(&rans, &ptr, &stats, s, prob_bits);
307         }
308         RansEncFlush(&rans, &ptr);
309         rans_begin = ptr;
310
311         uint64_t enc_clocks = __rdtsc() - enc_start_time;
312         double enc_time = timer() - start_time;
313         printf("%"PRIu64" clocks, %.1f clocks/symbol (%5.1fMiB/s)\n", enc_clocks, 1.0 * enc_clocks / in_size, 1.0 * in_size / (enc_time * 1048576.0));
314     }
315     printf("rANS: %d bytes\n", (int) (out_buf + out_max_size - rans_begin));
316
317     // try rANS decode
318     for (int run=0; run < 5; run++) {
319         double start_time = timer();
320         uint64_t dec_start_time = __rdtsc();
321
322         RansState rans;
323         uint8_t* ptr = rans_begin;
324         RansDecInit(&rans, &ptr);
325
326         for (size_t i=0; i < in_size; i++) {
327             uint32_t s = RansDecGetAlias(&rans, &stats, prob_bits);
328             dec_bytes[i] = (uint8_t) s;
329             RansDecRenorm(&rans, &ptr);
330         }
331
332         uint64_t dec_clocks = __rdtsc() - dec_start_time;
333         double dec_time = timer() - start_time;
334         printf("%"PRIu64" clocks, %.1f clocks/symbol (%5.1fMiB/s)\n", dec_clocks, 1.0 * dec_clocks / in_size, 1.0 * in_size / (dec_time * 1048576.0));
335     }
336
337     // check decode results
338     if (memcmp(in_bytes, dec_bytes, in_size) == 0)
339         printf("decode ok!\n");
340     else
341         printf("ERROR: bad decoder!\n");
342
343     // ---- interleaved rANS encode/decode. This is the kind of thing you might do to optimize critical paths.
344
345     memset(dec_bytes, 0xcc, in_size);
346
347     // try interleaved rANS encode
348     printf("\ninterleaved rANS encode:\n");
349     for (int run=0; run < 5; run++) {
350         double start_time = timer();
351         uint64_t enc_start_time = __rdtsc();
352
353         RansState rans0, rans1;
354         RansEncInit(&rans0);
355         RansEncInit(&rans1);
356
357         uint8_t* ptr = out_buf + out_max_size; // *end* of output buffer
358
359         // odd number of bytes?
360         if (in_size & 1) {
361             int s = in_bytes[in_size - 1];
362             RansEncPutAlias(&rans0, &ptr, &stats, s, prob_bits);
363         }
364
365         for (size_t i=(in_size & ~1); i > 0; i -= 2) { // NB: working in reverse!
366             int s1 = in_bytes[i-1];
367             int s0 = in_bytes[i-2];
368             RansEncPutAlias(&rans1, &ptr, &stats, s1, prob_bits);
369             RansEncPutAlias(&rans0, &ptr, &stats, s0, prob_bits);
370         }
371         RansEncFlush(&rans1, &ptr);
372         RansEncFlush(&rans0, &ptr);
373         rans_begin = ptr;
374
375         uint64_t enc_clocks = __rdtsc() - enc_start_time;
376         double enc_time = timer() - start_time;
377         printf("%"PRIu64" clocks, %.1f clocks/symbol (%5.1fMiB/s)\n", enc_clocks, 1.0 * enc_clocks / in_size, 1.0 * in_size / (enc_time * 1048576.0));
378     }
379     printf("interleaved rANS: %d bytes\n", (int) (out_buf + out_max_size - rans_begin));
380
381     // try interleaved rANS decode
382     for (int run=0; run < 5; run++) {
383         double start_time = timer();
384         uint64_t dec_start_time = __rdtsc();
385
386         RansState rans0, rans1;
387         uint8_t* ptr = rans_begin;
388         RansDecInit(&rans0, &ptr);
389         RansDecInit(&rans1, &ptr);
390
391         for (size_t i=0; i < (in_size & ~1); i += 2) {
392             uint32_t s0 = RansDecGetAlias(&rans0, &stats, prob_bits);
393             uint32_t s1 = RansDecGetAlias(&rans1, &stats, prob_bits);
394             dec_bytes[i+0] = (uint8_t) s0;
395             dec_bytes[i+1] = (uint8_t) s1;
396             RansDecRenorm(&rans0, &ptr);
397             RansDecRenorm(&rans1, &ptr);
398         }
399
400         // last byte, if number of bytes was odd
401         if (in_size & 1) {
402             uint32_t s0 = RansDecGetAlias(&rans0, &stats, prob_bits);
403             dec_bytes[in_size - 1] = (uint8_t) s0;
404             RansDecRenorm(&rans0, &ptr);
405         }
406
407         uint64_t dec_clocks = __rdtsc() - dec_start_time;
408         double dec_time = timer() - start_time;
409         printf("%"PRIu64" clocks, %.1f clocks/symbol (%5.1fMB/s)\n", dec_clocks, 1.0 * dec_clocks / in_size, 1.0 * in_size / (dec_time * 1048576.0));
410     }
411
412     // check decode results
413     if (memcmp(in_bytes, dec_bytes, in_size) == 0)
414         printf("decode ok!\n");
415     else
416         printf("ERROR: bad decoder!\n");
417
418     delete[] out_buf;
419     delete[] dec_bytes;
420     delete[] in_bytes;
421     return 0;
422 }