11 // This is just the sample program. All the meat is in rans_byte.h.
13 static void panic(const char *fmt, ...)
18 fputs("Error: ", stderr);
19 vfprintf(stderr, fmt, arg);
26 static uint8_t* read_file(char const* filename, size_t* out_size)
28 FILE* f = fopen(filename, "rb");
30 panic("file not found: %s\n", filename);
32 fseek(f, 0, SEEK_END);
33 size_t size = ftell(f);
34 fseek(f, 0, SEEK_SET);
36 uint8_t* buf = new uint8_t[size];
37 if (fread(buf, size, 1, f) != 1)
38 panic("read failed\n");
52 uint32_t cum_freqs[257];
54 void count_freqs(uint8_t const* in, size_t nbytes);
55 void calc_cum_freqs();
56 void normalize_freqs(uint32_t target_total);
59 void SymbolStats::count_freqs(uint8_t const* in, size_t nbytes)
61 for (int i=0; i < 256; i++)
64 for (size_t i=0; i < nbytes; i++)
68 void SymbolStats::calc_cum_freqs()
71 for (int i=0; i < 256; i++)
72 cum_freqs[i+1] = cum_freqs[i] + freqs[i];
75 void SymbolStats::normalize_freqs(uint32_t target_total)
77 assert(target_total >= 256);
80 uint32_t cur_total = cum_freqs[256];
82 // resample distribution based on cumulative freqs
83 for (int i = 1; i <= 256; i++)
84 cum_freqs[i] = ((uint64_t)target_total * cum_freqs[i])/cur_total;
86 // if we nuked any non-0 frequency symbol to 0, we need to steal
87 // the range to make the frequency nonzero from elsewhere.
89 // this is not at all optimal, i'm just doing the first thing that comes to mind.
90 for (int i=0; i < 256; i++) {
91 if (freqs[i] && cum_freqs[i+1] == cum_freqs[i]) {
92 // symbol i was set to zero freq
94 // find best symbol to steal frequency from (try to steal from low-freq ones)
95 uint32_t best_freq = ~0u;
97 for (int j=0; j < 256; j++) {
98 uint32_t freq = cum_freqs[j+1] - cum_freqs[j];
99 if (freq > 1 && freq < best_freq) {
104 assert(best_steal != -1);
106 // and steal from it!
107 if (best_steal < i) {
108 for (int j = best_steal + 1; j <= i; j++)
111 assert(best_steal > i);
112 for (int j = i + 1; j <= best_steal; j++)
118 // calculate updated freqs and make sure we didn't screw anything up
119 assert(cum_freqs[0] == 0 && cum_freqs[256] == target_total);
120 for (int i=0; i < 256; i++) {
122 assert(cum_freqs[i+1] == cum_freqs[i]);
124 assert(cum_freqs[i+1] > cum_freqs[i]);
127 freqs[i] = cum_freqs[i+1] - cum_freqs[i];
134 uint8_t* in_bytes = read_file("book1", &in_size);
136 static const uint32_t prob_bits = 14;
137 static const uint32_t prob_scale = 1 << prob_bits;
140 stats.count_freqs(in_bytes, in_size);
141 stats.normalize_freqs(prob_scale);
143 // cumlative->symbol table
144 // this is super brute force
145 uint8_t cum2sym[prob_scale];
146 for (int s=0; s < 256; s++)
147 for (uint32_t i=stats.cum_freqs[s]; i < stats.cum_freqs[s+1]; i++)
150 static size_t out_max_size = 32<<20; // 32MB
151 uint8_t* out_buf = new uint8_t[out_max_size];
152 uint8_t* dec_bytes = new uint8_t[in_size];
156 RansEncSymbol esyms[256];
157 RansDecSymbol dsyms[256];
159 for (int i=0; i < 256; i++) {
160 RansEncSymbolInit(&esyms[i], stats.cum_freqs[i], stats.freqs[i], prob_bits);
161 RansDecSymbolInit(&dsyms[i], stats.cum_freqs[i], stats.freqs[i]);
164 // ---- regular rANS encode/decode. Typical usage.
166 memset(dec_bytes, 0xcc, in_size);
168 printf("rANS encode:\n");
169 for (int run=0; run < 5; run++) {
170 double start_time = timer();
171 uint64_t enc_start_time = __rdtsc();
176 uint8_t* ptr = out_buf + out_max_size; // *end* of output buffer
177 for (size_t i=in_size; i > 0; i--) { // NB: working in reverse!
178 int s = in_bytes[i-1];
179 RansEncPutSymbol(&rans, &ptr, &esyms[s]);
181 RansEncFlush(&rans, &ptr);
184 uint64_t enc_clocks = __rdtsc() - enc_start_time;
185 double enc_time = timer() - start_time;
186 printf("%"PRIu64" clocks, %.1f clocks/symbol (%5.1fMiB/s)\n", enc_clocks, 1.0 * enc_clocks / in_size, 1.0 * in_size / (enc_time * 1048576.0));
188 printf("rANS: %d bytes\n", (int) (out_buf + out_max_size - rans_begin));
191 for (int run=0; run < 5; run++) {
192 double start_time = timer();
193 uint64_t dec_start_time = __rdtsc();
196 uint8_t* ptr = rans_begin;
197 RansDecInit(&rans, &ptr);
199 for (size_t i=0; i < in_size; i++) {
200 uint32_t s = cum2sym[RansDecGet(&rans, prob_bits)];
201 dec_bytes[i] = (uint8_t) s;
202 RansDecAdvanceSymbol(&rans, &ptr, &dsyms[s], prob_bits);
205 uint64_t dec_clocks = __rdtsc() - dec_start_time;
206 double dec_time = timer() - start_time;
207 printf("%"PRIu64" clocks, %.1f clocks/symbol (%5.1fMiB/s)\n", dec_clocks, 1.0 * dec_clocks / in_size, 1.0 * in_size / (dec_time * 1048576.0));
210 // check decode results
211 if (memcmp(in_bytes, dec_bytes, in_size) == 0)
212 printf("decode ok!\n");
214 printf("ERROR: bad decoder!\n");
216 // ---- interleaved rANS encode/decode. This is the kind of thing you might do to optimize critical paths.
218 memset(dec_bytes, 0xcc, in_size);
220 // try interleaved rANS encode
221 printf("\ninterleaved rANS encode:\n");
222 for (int run=0; run < 5; run++) {
223 double start_time = timer();
224 uint64_t enc_start_time = __rdtsc();
226 RansState rans0, rans1;
230 uint8_t* ptr = out_buf + out_max_size; // *end* of output buffer
232 // odd number of bytes?
234 int s = in_bytes[in_size - 1];
235 RansEncPutSymbol(&rans0, &ptr, &esyms[s]);
238 for (size_t i=(in_size & ~1); i > 0; i -= 2) { // NB: working in reverse!
239 int s1 = in_bytes[i-1];
240 int s0 = in_bytes[i-2];
241 RansEncPutSymbol(&rans1, &ptr, &esyms[s1]);
242 RansEncPutSymbol(&rans0, &ptr, &esyms[s0]);
244 RansEncFlush(&rans1, &ptr);
245 RansEncFlush(&rans0, &ptr);
248 uint64_t enc_clocks = __rdtsc() - enc_start_time;
249 double enc_time = timer() - start_time;
250 printf("%"PRIu64" clocks, %.1f clocks/symbol (%5.1fMiB/s)\n", enc_clocks, 1.0 * enc_clocks / in_size, 1.0 * in_size / (enc_time * 1048576.0));
252 printf("interleaved rANS: %d bytes\n", (int) (out_buf + out_max_size - rans_begin));
254 // try interleaved rANS decode
255 for (int run=0; run < 5; run++) {
256 double start_time = timer();
257 uint64_t dec_start_time = __rdtsc();
259 RansState rans0, rans1;
260 uint8_t* ptr = rans_begin;
261 RansDecInit(&rans0, &ptr);
262 RansDecInit(&rans1, &ptr);
264 for (size_t i=0; i < (in_size & ~1); i += 2) {
265 uint32_t s0 = cum2sym[RansDecGet(&rans0, prob_bits)];
266 uint32_t s1 = cum2sym[RansDecGet(&rans1, prob_bits)];
267 dec_bytes[i+0] = (uint8_t) s0;
268 dec_bytes[i+1] = (uint8_t) s1;
269 RansDecAdvanceSymbolStep(&rans0, &dsyms[s0], prob_bits);
270 RansDecAdvanceSymbolStep(&rans1, &dsyms[s1], prob_bits);
271 RansDecRenorm(&rans0, &ptr);
272 RansDecRenorm(&rans1, &ptr);
275 // last byte, if number of bytes was odd
277 uint32_t s0 = cum2sym[RansDecGet(&rans0, prob_bits)];
278 dec_bytes[in_size - 1] = (uint8_t) s0;
279 RansDecAdvanceSymbol(&rans0, &ptr, &dsyms[s0], prob_bits);
282 uint64_t dec_clocks = __rdtsc() - dec_start_time;
283 double dec_time = timer() - start_time;
284 printf("%"PRIu64" clocks, %.1f clocks/symbol (%5.1fMB/s)\n", dec_clocks, 1.0 * dec_clocks / in_size, 1.0 * in_size / (dec_time * 1048576.0));
287 // check decode results
288 if (memcmp(in_bytes, dec_bytes, in_size) == 0)
289 printf("decode ok!\n");
291 printf("ERROR: bad decoder!\n");