// Fatal-error helper: writes the prefix "Error: " to stderr, followed by the
// printf-style message built from `fmt` and the variadic arguments.
//
// NOTE(review): the va_list setup for `arg` and whatever runs after the message
// is printed are not visible in this excerpt. read_file() keeps using its FILE*
// right after calling panic() on a failed fopen, so panic() presumably does not
// return -- confirm.
11 static void panic(const char *fmt, ...)
16 fputs("Error: ", stderr);
// `fmt` is handed to vfprintf as the format string, so callers must pass a
// trusted format and matching arguments.
17 vfprintf(stderr, fmt, arg);
24 static uint8_t* read_file(char const* filename, size_t* out_size)
26 FILE* f = fopen(filename, "rb");
28 panic("file not found: %s\n", filename);
30 fseek(f, 0, SEEK_END);
31 size_t size = ftell(f);
32 fseek(f, 0, SEEK_SET);
34 uint8_t* buf = new uint8_t[size];
35 if (fread(buf, size, 1, f) != 1)
36 panic("read failed\n");
// Symbol statistics for a byte alphabet, plus the alias table derived from
// them.
//
// Typical use: count_freqs() -> normalize_freqs(total) -> make_alias_table().
//
// The object owns `alias_remap` (allocated by make_alias_table, released by the
// destructor), so copying is disabled to avoid a double delete[].
struct SymbolStats
{
    static const int LOG2NSYMS = 8;
    static const int NSYMS = 1 << LOG2NSYMS;

    uint32_t freqs[NSYMS];          // frequency of each symbol
    uint32_t cum_freqs[NSYMS + 1];  // cum_freqs[i] = sum of freqs[0..i-1]

    // alias table: one bucket per symbol, two slots per bucket.
    // Slot i*2+1 belongs to symbol i itself, slot i*2+0 to its alias.
    uint32_t divider[NSYMS];        // first code value in bucket i owned by the alias
    uint32_t slot_adjust[NSYMS*2];  // subtract from a code value to get the index within the symbol
    uint32_t slot_freqs[NSYMS*2];   // frequency of the symbol stored in each slot
    uint8_t sym_id[NSYMS*2];        // symbol stored in each slot

    // for encoder: maps (cum_freqs[s] + index within symbol s) -> code value
    uint32_t* alias_remap;

    SymbolStats() : alias_remap(0) {}
    ~SymbolStats() { delete[] alias_remap; }

    void count_freqs(uint8_t const* in, size_t nbytes);
    void calc_cum_freqs();
    void normalize_freqs(uint32_t target_total);

    void make_alias_table();

private:
    // Not copyable: alias_remap is an owning raw pointer. Declared but
    // intentionally left undefined.
    SymbolStats(const SymbolStats&);
    SymbolStats& operator=(const SymbolStats&);
};

// Fills freqs[] with the number of occurrences of every byte value in
// in[0..nbytes). Previous counts are discarded.
void SymbolStats::count_freqs(uint8_t const* in, size_t nbytes)
{
    for (int i=0; i < NSYMS; i++)
        freqs[i] = 0;

    for (size_t i=0; i < nbytes; i++)
        freqs[in[i]]++;
}

// Recomputes cum_freqs[] as the exclusive prefix sums of freqs[];
// cum_freqs[NSYMS] ends up holding the total.
void SymbolStats::calc_cum_freqs()
{
    cum_freqs[0] = 0;
    for (int i=0; i < NSYMS; i++)
        cum_freqs[i+1] = cum_freqs[i] + freqs[i];
}

// Rescales freqs[] / cum_freqs[] so that the frequencies sum to exactly
// `target_total` (which must be >= NSYMS), while keeping every symbol that
// occurred at least once at a nonzero frequency.
//
// If no symbol was counted at all (empty input) there is nothing to rescale;
// the range is then split uniformly across all symbols instead of dividing by
// a zero total.
void SymbolStats::normalize_freqs(uint32_t target_total)
{
    assert(target_total >= NSYMS);

    calc_cum_freqs();
    uint32_t cur_total = cum_freqs[NSYMS];

    if (cur_total == 0) {
        // empty distribution: fall back to a uniform one
        for (int i = 0; i <= NSYMS; i++)
            cum_freqs[i] = (uint32_t) (((uint64_t)target_total * (uint32_t)i) / NSYMS);
        for (int i = 0; i < NSYMS; i++)
            freqs[i] = cum_freqs[i+1] - cum_freqs[i];
        return;
    }

    // resample distribution based on cumulative freqs
    for (int i = 1; i <= NSYMS; i++)
        cum_freqs[i] = ((uint64_t)target_total * cum_freqs[i])/cur_total;

    // if we nuked any non-0 frequency symbol to 0, we need to steal
    // the range to make the frequency nonzero from elsewhere.
    //
    // this is not at all optimal, it is just the simplest approach.
    for (int i=0; i < NSYMS; i++) {
        if (freqs[i] && cum_freqs[i+1] == cum_freqs[i]) {
            // symbol i was set to zero freq

            // find best symbol to steal frequency from (try to steal from low-freq ones)
            uint32_t best_freq = ~0u;
            int best_steal = -1;
            for (int j=0; j < NSYMS; j++) {
                uint32_t freq = cum_freqs[j+1] - cum_freqs[j];
                if (freq > 1 && freq < best_freq) {
                    best_freq = freq;
                    best_steal = j;
                }
            }
            assert(best_steal != -1);

            // and steal from it! Shifting every boundary between the donor and
            // symbol i by one moves a single unit of range from one to the
            // other and leaves the symbols in between unchanged.
            if (best_steal < i) {
                for (int j = best_steal + 1; j <= i; j++)
                    cum_freqs[j]--;
            } else {
                assert(best_steal > i);
                for (int j = i + 1; j <= best_steal; j++)
                    cum_freqs[j]++;
            }
        }
    }

    // calculate updated freqs and make sure we didn't screw anything up
    assert(cum_freqs[0] == 0 && cum_freqs[NSYMS] == target_total);
    for (int i=0; i < NSYMS; i++) {
        if (freqs[i] == 0) {
            assert(cum_freqs[i+1] == cum_freqs[i]);
        } else {
            assert(cum_freqs[i+1] > cum_freqs[i]);
        }

        // calc updated freq
        freqs[i] = cum_freqs[i+1] - cum_freqs[i];
    }
}

// Set up the alias table.
//
// Requires normalized statistics whose total (cum_freqs[NSYMS]) is a nonzero
// multiple of NSYMS. Fills divider / slot_adjust / slot_freqs / sym_id for the
// decoder and (re)allocates alias_remap for the encoder. Calling it again
// replaces the previous table.
void SymbolStats::make_alias_table()
{
    // verify that our distribution sum divides the number of buckets
    uint32_t sum = cum_freqs[NSYMS];
    assert(sum != 0 && (sum % NSYMS) == 0);
    assert(sum >= NSYMS);

    // target size in every bucket
    uint32_t tgt_sum = sum / NSYMS;

    // okay, prepare a sweep of vose's algorithm to distribute
    // the symbols into buckets
    uint32_t remaining[NSYMS];
    for (int i=0; i < NSYMS; i++) {
        remaining[i] = freqs[i];
        divider[i] = tgt_sum;
        sym_id[i*2 + 0] = i;
        sym_id[i*2 + 1] = i;
    }

    // a "small" symbol is one with less than tgt_sum slots left to distribute
    // a "large" symbol is one with >=tgt_sum slots.
    // find initial small/large buckets
    int cur_large = 0;
    int cur_small = 0;
    while (cur_large < NSYMS && remaining[cur_large] < tgt_sum)
        cur_large++;
    while (cur_small < NSYMS && remaining[cur_small] >= tgt_sum)
        cur_small++;

    // cur_small is definitely a small bucket
    // next_small *might* be.
    int next_small = cur_small + 1;

    // top up small buckets from large buckets until we're done
    // this might turn the large bucket we stole from into a small bucket itself.
    while (cur_large < NSYMS && cur_small < NSYMS) {
        // this bucket is split between cur_small and cur_large
        sym_id[cur_small*2 + 0] = cur_large;
        divider[cur_small] = remaining[cur_small];

        // take the amount we took out of cur_large's bucket
        remaining[cur_large] -= tgt_sum - divider[cur_small];

        // if the large bucket is still large *or* we haven't processed it yet...
        if (remaining[cur_large] >= tgt_sum || next_small <= cur_large) {
            // find the next small bucket to process
            cur_small = next_small;
            while (cur_small < NSYMS && remaining[cur_small] >= tgt_sum)
                cur_small++;
            next_small = cur_small + 1;
        } else // the large bucket we just made small is behind us, need to back-track
            cur_small = cur_large;

        // if cur_large isn't large anymore, forward to a bucket that is
        while (cur_large < NSYMS && remaining[cur_large] < tgt_sum)
            cur_large++;
    }

    // okay, we now have our alias mapping; distribute the code slots in order
    uint32_t assigned[NSYMS] = { 0 };

    // release the table of any previous call before building the new one
    delete[] alias_remap;
    alias_remap = new uint32_t[sum];

    for (int i=0; i < NSYMS; i++) {
        int j = sym_id[i*2 + 0];
        uint32_t sym0_height = divider[i];
        uint32_t sym1_height = tgt_sum - divider[i];
        uint32_t base0 = assigned[i];
        uint32_t base1 = assigned[j];
        uint32_t cbase0 = cum_freqs[i] + base0;
        uint32_t cbase1 = cum_freqs[j] + base1;

        // from here on divider[i] is an absolute code value, not a height
        divider[i] = i*tgt_sum + sym0_height;

        slot_freqs[i*2 + 1] = freqs[i];
        slot_freqs[i*2 + 0] = freqs[j];
        slot_adjust[i*2 + 1] = i*tgt_sum - base0;
        slot_adjust[i*2 + 0] = i*tgt_sum - (base1 - sym0_height);
        for (uint32_t k=0; k < sym0_height; k++)
            alias_remap[cbase0 + k] = k + i*tgt_sum;

        for (uint32_t k=0; k < sym1_height; k++)
            alias_remap[cbase1 + k] = (k + sym0_height) + i*tgt_sum;

        assigned[i] += sym0_height;
        assigned[j] += sym1_height;
    }

    // check that each symbol got the number of slots it needed
    for (int i=0; i < NSYMS; i++)
        assert(assigned[i] == freqs[i]);
}
239 // ---- rANS encoding/decoding with alias table
// Encodes symbol `s` into the rANS state `*r` using the statistics in `syms`.
//
//   r          - encoder state, updated in place.
//   pptr       - passed through to RansEncRenorm (declared elsewhere).
//   syms       - supplies freqs[], cum_freqs[] and the encoder table alias_remap.
//   s          - symbol index; used directly to index freqs[] and cum_freqs[].
//   scale_bits - shift applied to the quotient below.
//
// NOTE(review): freqs[s] is used as a divisor, so `s` must have a nonzero
// frequency; nothing here checks that.
241 static inline void RansEncPutAlias(RansState* r, uint8_t** pptr, SymbolStats* const syms, int s, uint32_t scale_bits)
244 uint32_t freq = syms->freqs[s];
245 RansState x = RansEncRenorm(*r, pptr, freq, scale_bits);
248 // NOTE: alias_remap here could be replaced with e.g. a binary search.
// New state: the quotient x / freq moves into the bits above scale_bits; the
// low part is the alias_remap entry for index (x % freq) within symbol s.
249 *r = ((x / freq) << scale_bits) + syms->alias_remap[(x % freq) + syms->cum_freqs[s]];
// Decodes one symbol from the rANS state `*r` via the alias table in `syms`,
// stores the advanced state back into `*r`, and returns the symbol id.
//
// NOTE(review): the declaration of `x` is not visible in this excerpt;
// presumably it is a copy of *r -- confirm.
// NOTE(review): the shifts below are only well defined when
// SymbolStats::LOG2NSYMS <= scale_bits < 32; nothing here checks that.
252 static inline uint32_t RansDecGetAlias(RansState* r, SymbolStats* const syms, uint32_t scale_bits)
256 // figure out symbol via alias table
257 uint32_t mask = (1u << scale_bits) - 1; // constant for fixed scale_bits!
// xm: the low scale_bits bits of the state.
258 uint32_t xm = x & mask;
// The top LOG2NSYMS bits of xm select the bucket; each bucket owns two
// consecutive slots starting at bucket_id * 2.
259 uint32_t bucket_id = xm >> (scale_bits - SymbolStats::LOG2NSYMS);
260 uint32_t bucket2 = bucket_id * 2;
// Values below the bucket's divider are treated differently from the rest.
// NOTE(review): the statement controlled by this `if` is not visible here;
// presumably it selects the bucket's second slot -- confirm.
261 if (xm < syms->divider[bucket_id])
// Advance the state with the chosen slot's frequency and adjustment, then
// report the symbol stored in that slot.
265 *r = syms->slot_freqs[bucket2] * (x >> scale_bits) + xm - syms->slot_adjust[bucket2];
266 return syms->sym_id[bucket2];
// Benchmark driver body. It loads the file "book1", builds the symbol
// statistics and alias table, then times plain and two-way interleaved rANS
// encoding/decoding (5 runs each) and checks each decode against the input.
//
// NOTE(review): the enclosing function header and the declarations of
// `in_size`, `stats`, `rans` and `rans_begin` are not visible in this excerpt.
274 uint8_t* in_bytes = read_file("book1", &in_size);
// Probabilities are normalized to sum to prob_scale = 2^16.
276 static const uint32_t prob_bits = 16;
277 static const uint32_t prob_scale = 1 << prob_bits;
280 stats.count_freqs(in_bytes, in_size);
281 stats.normalize_freqs(prob_scale);
282 stats.make_alias_table();
// NOTE(review): no visible code checks that the encoded stream fits into
// out_max_size bytes.
284 static size_t out_max_size = 32<<20; // 32MB
285 uint8_t* out_buf = new uint8_t[out_max_size];
286 uint8_t* dec_bytes = new uint8_t[in_size];
291 // ---- regular rANS encode/decode. Typical usage.
// Pre-fill the decode buffer with 0xcc so stale data cannot pass the memcmp.
293 memset(dec_bytes, 0xcc, in_size);
295 printf("rANS encode:\n");
296 for (int run=0; run < 5; run++) {
297 double start_time = timer();
298 uint64_t enc_start_time = __rdtsc();
// The write pointer starts at the end of the buffer and the input is consumed
// back to front.
303 uint8_t* ptr = out_buf + out_max_size; // *end* of output buffer
304 for (size_t i=in_size; i > 0; i--) { // NB: working in reverse!
305 int s = in_bytes[i-1];
306 RansEncPutAlias(&rans, &ptr, &stats, s, prob_bits);
308 RansEncFlush(&rans, &ptr);
311 uint64_t enc_clocks = __rdtsc() - enc_start_time;
312 double enc_time = timer() - start_time;
// NOTE(review): "%"PRIu64" has no whitespace between the string literal and the
// macro. C++11 and later parse that as a literal suffix, so it needs to be
// written as "%" PRIu64 " clocks ...". The same applies to the three other
// timing printf calls below.
313 printf("%"PRIu64" clocks, %.1f clocks/symbol (%5.1fMiB/s)\n", enc_clocks, 1.0 * enc_clocks / in_size, 1.0 * in_size / (enc_time * 1048576.0));
// Encoded size = distance from the stream start to the end of the buffer.
// NOTE(review): `rans_begin` is assigned on a line not shown here.
315 printf("rANS: %d bytes\n", (int) (out_buf + out_max_size - rans_begin));
318 for (int run=0; run < 5; run++) {
319 double start_time = timer();
320 uint64_t dec_start_time = __rdtsc();
323 uint8_t* ptr = rans_begin;
324 RansDecInit(&rans, &ptr);
// Decoding walks forward through the stream and yields the bytes in their
// original order.
326 for (size_t i=0; i < in_size; i++) {
327 uint32_t s = RansDecGetAlias(&rans, &stats, prob_bits);
328 dec_bytes[i] = (uint8_t) s;
329 RansDecRenorm(&rans, &ptr);
332 uint64_t dec_clocks = __rdtsc() - dec_start_time;
333 double dec_time = timer() - start_time;
334 printf("%"PRIu64" clocks, %.1f clocks/symbol (%5.1fMiB/s)\n", dec_clocks, 1.0 * dec_clocks / in_size, 1.0 * in_size / (dec_time * 1048576.0));
337 // check decode results
338 if (memcmp(in_bytes, dec_bytes, in_size) == 0)
339 printf("decode ok!\n");
341 printf("ERROR: bad decoder!\n");
343 // ---- interleaved rANS encode/decode. This is the kind of thing you might do to optimize critical paths.
345 memset(dec_bytes, 0xcc, in_size);
347 // try interleaved rANS encode
348 printf("\ninterleaved rANS encode:\n");
349 for (int run=0; run < 5; run++) {
350 double start_time = timer();
351 uint64_t enc_start_time = __rdtsc();
// Two independent encoder states share one output stream.
353 RansState rans0, rans1;
357 uint8_t* ptr = out_buf + out_max_size; // *end* of output buffer
359 // odd number of bytes?
// The trailing unpaired byte goes to rans0 first.
// NOTE(review): the condition guarding this is not visible in this excerpt.
361 int s = in_bytes[in_size - 1];
362 RansEncPutAlias(&rans0, &ptr, &stats, s, prob_bits);
// (in_size & ~1) rounds in_size down to an even count. Within each pair the
// odd-index byte goes to rans1 and the even-index byte to rans0.
365 for (size_t i=(in_size & ~1); i > 0; i -= 2) { // NB: working in reverse!
366 int s1 = in_bytes[i-1];
367 int s0 = in_bytes[i-2];
368 RansEncPutAlias(&rans1, &ptr, &stats, s1, prob_bits);
369 RansEncPutAlias(&rans0, &ptr, &stats, s0, prob_bits);
// Flushed as rans1 then rans0; the decoder below initializes rans0 then rans1.
371 RansEncFlush(&rans1, &ptr);
372 RansEncFlush(&rans0, &ptr);
375 uint64_t enc_clocks = __rdtsc() - enc_start_time;
376 double enc_time = timer() - start_time;
377 printf("%"PRIu64" clocks, %.1f clocks/symbol (%5.1fMiB/s)\n", enc_clocks, 1.0 * enc_clocks / in_size, 1.0 * in_size / (enc_time * 1048576.0));
379 printf("interleaved rANS: %d bytes\n", (int) (out_buf + out_max_size - rans_begin));
381 // try interleaved rANS decode
382 for (int run=0; run < 5; run++) {
383 double start_time = timer();
384 uint64_t dec_start_time = __rdtsc();
386 RansState rans0, rans1;
387 uint8_t* ptr = rans_begin;
388 RansDecInit(&rans0, &ptr);
389 RansDecInit(&rans1, &ptr);
// rans0 yields the even-index bytes and rans1 the odd-index bytes.
391 for (size_t i=0; i < (in_size & ~1); i += 2) {
392 uint32_t s0 = RansDecGetAlias(&rans0, &stats, prob_bits);
393 uint32_t s1 = RansDecGetAlias(&rans1, &stats, prob_bits);
394 dec_bytes[i+0] = (uint8_t) s0;
395 dec_bytes[i+1] = (uint8_t) s1;
396 RansDecRenorm(&rans0, &ptr);
397 RansDecRenorm(&rans1, &ptr);
400 // last byte, if number of bytes was odd
402 uint32_t s0 = RansDecGetAlias(&rans0, &stats, prob_bits);
403 dec_bytes[in_size - 1] = (uint8_t) s0;
404 RansDecRenorm(&rans0, &ptr);
407 uint64_t dec_clocks = __rdtsc() - dec_start_time;
408 double dec_time = timer() - start_time;
// NOTE(review): this line labels the rate "MB/s" while the other three timing
// lines say "MiB/s"; all four divide by 1048576.0, so "MiB/s" is the matching
// label.
409 printf("%"PRIu64" clocks, %.1f clocks/symbol (%5.1fMB/s)\n", dec_clocks, 1.0 * dec_clocks / in_size, 1.0 * in_size / (dec_time * 1048576.0));
412 // check decode results
413 if (memcmp(in_bytes, dec_bytes, in_size) == 0)
414 printf("decode ok!\n");
416 printf("ERROR: bad decoder!\n");