#include <assert.h>
#include <stdarg.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

#include "renormalize.h"
12 // This is just the sample program. All the meat is in rans_byte.h.
// Report a fatal error and terminate.
//
// Writes "Error: " followed by the printf-style message (fmt, ...) to
// stderr, then exits the process with status 1. Never returns, so callers
// may treat a call as the end of their control flow.
static void panic(const char *fmt, ...)
{
    va_list arg;

    va_start(arg, fmt);
    fputs("Error: ", stderr);
    vfprintf(stderr, fmt, arg);
    va_end(arg);

    exit(1);
}

// Load the entire contents of `filename` into memory.
//
// Returns a buffer allocated with new[] (the caller releases it with
// delete[]) holding the file's bytes. If out_size is non-null it receives
// the number of bytes read. An empty file yields a zero-length buffer and
// a size of 0. Any failure (the file cannot be opened, its size cannot be
// determined, or the read comes up short) is reported through panic(), so
// this function only returns on success.
static uint8_t* read_file(char const* filename, size_t* out_size)
{
    FILE* f = fopen(filename, "rb");
    if (!f)
        panic("file not found: %s\n", filename);

    // Size the file by seeking to its end. ftell signals failure with -1L,
    // which must not be converted to size_t (it would become a huge size).
    long end_pos = -1;
    if (fseek(f, 0, SEEK_END) == 0)
        end_pos = ftell(f);
    if (end_pos < 0 || fseek(f, 0, SEEK_SET) != 0)
        panic("could not determine size of: %s\n", filename);
    size_t size = static_cast<size_t>(end_pos);

    uint8_t* buf = new uint8_t[size];
    // Read `size` items of one byte each. The form fread(buf, size, 1, f)
    // returns 0 for an empty file (size == 0), which was mistaken for an error.
    if (fread(buf, 1, size, f) != size)
        panic("read failed\n");

    fclose(f);
    if (out_size)
        *out_size = size;
    return buf;
}
// Per-symbol statistics over the byte alphabet (256 symbols).
//
//   freqs[s]     - frequency of byte value s
//   cum_freqs[s] - sum of freqs[0..s-1]; cum_freqs[256] is the total, so
//                  symbol s owns the half-open range [cum_freqs[s], cum_freqs[s+1]).
//
// NOTE: counts are 32-bit, so inputs of 2^32 bytes or more would wrap.
struct SymbolStats
{
    uint32_t freqs[256];
    uint32_t cum_freqs[257];

    // Histogram the bytes in[0..nbytes-1] into freqs (previous counts are discarded).
    void count_freqs(uint8_t const* in, size_t nbytes);
    // Recompute cum_freqs from freqs as an exclusive prefix sum.
    void calc_cum_freqs();
    // Rescale freqs/cum_freqs so that the total equals target_total.
    void normalize_freqs(uint32_t target_total);
};

// Build a byte histogram of in[0..nbytes-1]. All 256 counts are reset first,
// so the object can be reused; `in` is not dereferenced when nbytes is 0.
void SymbolStats::count_freqs(uint8_t const* in, size_t nbytes)
{
    for (int i=0; i < 256; i++)
        freqs[i] = 0;

    for (size_t i=0; i < nbytes; i++)
        freqs[in[i]]++;
}

// Exclusive prefix sum of freqs: cum_freqs[0] = 0 and
// cum_freqs[i+1] = cum_freqs[i] + freqs[i], so cum_freqs[256] is the total.
void SymbolStats::calc_cum_freqs()
{
    cum_freqs[0] = 0;
    for (int i=0; i < 256; i++)
        cum_freqs[i+1] = cum_freqs[i] + freqs[i];
}
76 void SymbolStats::normalize_freqs(uint32_t target_total)
78 assert(target_total >= 256);
82 OptimalRenormalize(cum_freqs, 256, target_total);
84 // calculate updated freqs and make sure we didn't screw anything up
85 assert(cum_freqs[0] == 0 && cum_freqs[256] == target_total);
86 for (int i=0; i < 256; i++) {
88 assert(cum_freqs[i+1] == cum_freqs[i]);
90 assert(cum_freqs[i+1] > cum_freqs[i]);
93 freqs[i] = cum_freqs[i+1] - cum_freqs[i];
// Benchmark driver (excerpt of main()). It loads "book1" and builds a byte
// model normalized to prob_scale. It then runs 5 timed passes each of plain
// rANS encode/decode and 2-way interleaved rANS encode/decode, and checks
// after each decode that it reproduced the input.
//
// NOTE(review): this listing is incomplete. The function's opening and end
// lie outside this view. Names used below (in_size, stats, rans, rans_begin,
// timer) are declared elsewhere. Several statements are also absent here
// (a loop body, branch guards, else branches, closing braces); see the
// inline notes, and confirm against the full file before relying on this excerpt.
100 uint8_t* in_bytes = read_file("book1", &in_size);
// Probability precision: normalized frequencies sum to 1 << 14 = 16384.
102 static const uint32_t prob_bits = 14;
103 static const uint32_t prob_scale = 1 << prob_bits;
106 stats.count_freqs(in_bytes, in_size);
107 stats.normalize_freqs(prob_scale);
109 // cumulative->symbol table
110 // this is super brute force
// cum2sym has one entry per probability slot (prob_scale entries). Slot x
// should map to the symbol s with cum_freqs[s] <= x < cum_freqs[s+1].
// NOTE(review): the body of the inner loop below (presumably
// `cum2sym[i] = s;`) is not present in this listing.
111 uint8_t cum2sym[prob_scale];
112 for (int s=0; s < 256; s++)
113 for (uint32_t i=stats.cum_freqs[s]; i < stats.cum_freqs[s+1]; i++)
// Shared 32 MiB output buffer; the encoders below fill it backwards from its end.
// NOTE(review): nothing visible here checks that the encoded size stays
// within out_max_size. Confirm that inputs cannot exceed it.
116 static size_t out_max_size = 32<<20; // 32MB
117 uint8_t* out_buf = new uint8_t[out_max_size];
118 uint8_t* dec_bytes = new uint8_t[in_size];
// Precompute per-symbol encoder/decoder tables from the normalized model.
122 RansEncSymbol esyms[256];
123 RansDecSymbol dsyms[256];
125 for (int i=0; i < 256; i++) {
126 RansEncSymbolInit(&esyms[i], stats.cum_freqs[i], stats.freqs[i], prob_bits);
127 RansDecSymbolInit(&dsyms[i], stats.cum_freqs[i], stats.freqs[i]);
130 // ---- regular rANS encode/decode. Typical usage.
// Poison the decode buffer so the memcmp check below cannot pass on leftover data.
132 memset(dec_bytes, 0xcc, in_size);
134 printf("rANS encode:\n");
135 for (int run=0; run < 5; run++) {
136 double start_time = timer();
137 uint64_t enc_start_time = __rdtsc();
// rANS is last-in-first-out: encoding from the last byte to the first lets
// the decode loop further down run forward.
142 uint8_t* ptr = out_buf + out_max_size; // *end* of output buffer
143 for (size_t i=in_size; i > 0; i--) { // NB: working in reverse!
144 int s = in_bytes[i-1];
145 RansEncPutSymbol(&rans, &ptr, &esyms[s]);
147 RansEncFlush(&rans, &ptr);
// NOTE(review): rans_begin (read after this loop) is not assigned in the
// visible lines; presumably it is set to ptr after the flush.
150 uint64_t enc_clocks = __rdtsc() - enc_start_time;
151 double enc_time = timer() - start_time;
// Throughput is reported in MiB/s (bytes / (seconds * 1048576)).
// NOTE(review): since C++11, a string literal directly followed by an
// identifier ("%"PRIu64) is lexed as a user-defined-literal suffix. Some
// compilers reject it and others warn. Write "%" PRIu64 " clocks..." with spaces.
152 printf("%"PRIu64" clocks, %.1f clocks/symbol (%5.1fMiB/s)\n", enc_clocks, 1.0 * enc_clocks / in_size, 1.0 * in_size / (enc_time * 1048576.0));
// Compressed size = distance from the start of the encoded data to the buffer end.
154 printf("rANS: %d bytes\n", (int) (out_buf + out_max_size - rans_begin));
157 for (int run=0; run < 5; run++) {
158 double start_time = timer();
159 uint64_t dec_start_time = __rdtsc();
162 uint8_t* ptr = rans_begin;
163 RansDecInit(&rans, &ptr);
// Decode forward; cum2sym maps the slot value from RansDecGet back to its symbol.
165 for (size_t i=0; i < in_size; i++) {
166 uint32_t s = cum2sym[RansDecGet(&rans, prob_bits)];
167 dec_bytes[i] = (uint8_t) s;
168 RansDecAdvanceSymbol(&rans, &ptr, &dsyms[s], prob_bits);
171 uint64_t dec_clocks = __rdtsc() - dec_start_time;
172 double dec_time = timer() - start_time;
173 printf("%"PRIu64" clocks, %.1f clocks/symbol (%5.1fMiB/s)\n", dec_clocks, 1.0 * dec_clocks / in_size, 1.0 * in_size / (dec_time * 1048576.0));
176 // check decode results
177 if (memcmp(in_bytes, dec_bytes, in_size) == 0)
178 printf("decode ok!\n");
// NOTE(review): the `else` that should precede the next line is not
// present in this listing; as shown, the error message would always print.
180 printf("ERROR: bad decoder!\n");
182 // ---- interleaved rANS encode/decode. This is the kind of thing you might do to optimize critical paths.
184 memset(dec_bytes, 0xcc, in_size);
186 // try interleaved rANS encode
187 printf("\ninterleaved rANS encode:\n");
188 for (int run=0; run < 5; run++) {
189 double start_time = timer();
190 uint64_t enc_start_time = __rdtsc();
// Two independent coder states sharing one output stream: even-indexed
// bytes go to rans0, odd-indexed bytes to rans1.
192 RansState rans0, rans1;
196 uint8_t* ptr = out_buf + out_max_size; // *end* of output buffer
198 // odd number of bytes?
// NOTE(review): the `if (in_size & 1)` guard implied by the comment above is
// not present in this listing; the next two lines must only run for odd in_size.
// The trailing byte (even index in_size-1) is encoded first into rans0, so it is decoded last.
200 int s = in_bytes[in_size - 1];
201 RansEncPutSymbol(&rans0, &ptr, &esyms[s]);
// (in_size & ~1) rounds the count down to even; pairs are encoded back to front.
204 for (size_t i=(in_size & ~1); i > 0; i -= 2) { // NB: working in reverse!
205 int s1 = in_bytes[i-1];
206 int s0 = in_bytes[i-2];
207 RansEncPutSymbol(&rans1, &ptr, &esyms[s1]);
208 RansEncPutSymbol(&rans0, &ptr, &esyms[s0]);
// rans1 is flushed before rans0. Since output grows downward, rans0's final
// state presumably lands first in memory, matching the decoder, which
// initializes rans0 before rans1.
210 RansEncFlush(&rans1, &ptr);
211 RansEncFlush(&rans0, &ptr);
214 uint64_t enc_clocks = __rdtsc() - enc_start_time;
215 double enc_time = timer() - start_time;
216 printf("%"PRIu64" clocks, %.1f clocks/symbol (%5.1fMiB/s)\n", enc_clocks, 1.0 * enc_clocks / in_size, 1.0 * in_size / (enc_time * 1048576.0));
218 printf("interleaved rANS: %d bytes\n", (int) (out_buf + out_max_size - rans_begin));
220 // try interleaved rANS decode
221 for (int run=0; run < 5; run++) {
222 double start_time = timer();
223 uint64_t dec_start_time = __rdtsc();
225 RansState rans0, rans1;
226 uint8_t* ptr = rans_begin;
227 RansDecInit(&rans0, &ptr);
228 RansDecInit(&rans1, &ptr);
// Both states advance (Step) before either renormalizes from the shared
// stream. This is presumably to let the two independent updates overlap;
// the renorm order (rans0 then rans1) mirrors the encoder's put order.
230 for (size_t i=0; i < (in_size & ~1); i += 2) {
231 uint32_t s0 = cum2sym[RansDecGet(&rans0, prob_bits)];
232 uint32_t s1 = cum2sym[RansDecGet(&rans1, prob_bits)];
233 dec_bytes[i+0] = (uint8_t) s0;
234 dec_bytes[i+1] = (uint8_t) s1;
235 RansDecAdvanceSymbolStep(&rans0, &dsyms[s0], prob_bits);
236 RansDecAdvanceSymbolStep(&rans1, &dsyms[s1], prob_bits);
237 RansDecRenorm(&rans0, &ptr);
238 RansDecRenorm(&rans1, &ptr);
241 // last byte, if number of bytes was odd
// NOTE(review): as with the encoder, the `if (in_size & 1)` guard is not
// present in this listing; the next three lines must only run for odd in_size.
243 uint32_t s0 = cum2sym[RansDecGet(&rans0, prob_bits)];
244 dec_bytes[in_size - 1] = (uint8_t) s0;
245 RansDecAdvanceSymbol(&rans0, &ptr, &dsyms[s0], prob_bits);
248 uint64_t dec_clocks = __rdtsc() - dec_start_time;
249 double dec_time = timer() - start_time;
// NOTE(review): this label says "MB/s", but the value is divided by
// 1048576 (MiB) and the other reports say "MiB/s".
250 printf("%"PRIu64" clocks, %.1f clocks/symbol (%5.1fMB/s)\n", dec_clocks, 1.0 * dec_clocks / in_size, 1.0 * in_size / (dec_time * 1048576.0));
253 // check decode results
254 if (memcmp(in_bytes, dec_bytes, in_size) == 0)
255 printf("decode ok!\n");
// NOTE(review): the `else` before the next line is not present in this listing.
257 printf("ERROR: bad decoder!\n");