10 #include "renormalize.h"
12 // This is just the sample program. All the meat is in rans_byte.h.
// Error reporter taking a printf-style format string plus variadic arguments.
// The statements shown write the literal prefix "Error: " and then the
// formatted message to stderr.
// NOTE(review): the va_list setup and whatever follows the message (process
// exit, cleanup, ...) are not shown here -- confirm before relying on either.
14 static void panic(const char *fmt, ...)
19 fputs("Error: ", stderr);
20 vfprintf(stderr, fmt, arg);
// Reads the file named `filename` into a buffer allocated with new[].
//   filename  path passed to fopen() in binary read mode ("rb")
//   out_size  size out-parameter. NOTE(review): presumably receives the byte
//             count -- the store is not shown here, confirm.
// Open and read problems are reported through panic().
27 static uint8_t* read_file(char const* filename, size_t* out_size)
29 FILE* f = fopen(filename, "rb");
// NOTE(review): the condition guarding this call is not shown here.
31 panic("file not found: %s\n", filename);
// Determine the file length: seek to the end, take the offset, then rewind.
33 fseek(f, 0, SEEK_END);
// NOTE(review): ftell() returns long and reports failure as -1; the result is
// stored into a size_t without being checked.
34 size_t size = ftell(f);
35 fseek(f, 0, SEEK_SET);
// NOTE(review): no matching delete[] for this buffer is visible here.
37 uint8_t* buf = new uint8_t[size];
// The file is requested as a single item of `size` bytes, so any result other
// than exactly one item counts as a failed read.
// NOTE(review): for a zero-length file fread() returns 0, which takes this
// failure path -- confirm whether that is intended.
38 if (fread(buf, size, 1, f) != 1)
39 panic("read failed\n");
// Cumulative frequency table. It has 257 entries, one more than the 256
// symbol slots the member functions loop over, so cum_freqs[i+1] is valid for
// every i in [0, 256).
53 uint32_t cum_freqs[257];
// Takes a pointer to the input bytes and their count.
55 void count_freqs(uint8_t const* in, size_t nbytes);
// Fills cum_freqs with running sums of freqs.
56 void calc_cum_freqs();
// Rescales the tables so that cum_freqs[256] equals target_total.
57 void normalize_freqs(uint32_t target_total);
// Takes `nbytes` bytes starting at `in`.
// NOTE(review): both loop bodies are missing here; presumably the first
// clears the 256 counters and the second tallies each input byte -- confirm.
60 void SymbolStats::count_freqs(uint8_t const* in, size_t nbytes)
// One pass over the 256 symbol slots...
62 for (int i=0; i < 256; i++)
// ...then one pass over every input byte.
65 for (size_t i=0; i < nbytes; i++)
// Builds cum_freqs as running sums of freqs: entry i+1 is entry i plus
// freqs[i], for i in [0, 256).
// NOTE(review): the initialisation of cum_freqs[0] is not shown here --
// confirm it is set before the loop runs.
69 void SymbolStats::calc_cum_freqs()
72 for (int i=0; i < 256; i++)
73 cum_freqs[i+1] = cum_freqs[i] + freqs[i];
// Rescales the frequency tables so that they total `target_total`.
//   target_total  desired total; asserted to be at least 256
// The rescaling is delegated to OptimalRenormalize(). The result is then
// checked with asserts, and freqs is rebuilt from adjacent cum_freqs
// differences.
76 void SymbolStats::normalize_freqs(uint32_t target_total)
78 assert(target_total >= 256);
// NOTE(review): OptimalRenormalize is declared elsewhere; presumably it
// rewrites cum_freqs in place -- confirm against its header.
82 OptimalRenormalize(cum_freqs, 256, target_total);
84 // recompute freqs and verify that the renormalized table is consistent
85 assert(cum_freqs[0] == 0 && cum_freqs[256] == target_total);
86 for (int i=0; i < 256; i++) {
// NOTE(review): the conditions that select between the next two asserts are
// not shown here. One assert requires an empty range, the other a non-empty
// one.
88 assert(cum_freqs[i+1] == cum_freqs[i]);
90 assert(cum_freqs[i+1] > cum_freqs[i]);
// The frequency of symbol i is the width of its cumulative range.
93 freqs[i] = cum_freqs[i+1] - cum_freqs[i];
// Load the input file "book1" into memory; read_file is also given &in_size.
100 uint8_t* in_bytes = read_file("book1", &in_size);
// Probability resolution: frequencies are normalized to a total of 1 << 14.
102 static const uint32_t prob_bits = 14;
103 static const uint32_t prob_scale = 1 << prob_bits;
// Gather symbol statistics from the input and rescale them to prob_scale.
106 stats.count_freqs(in_bytes, in_size);
107 stats.normalize_freqs(prob_scale);
109 // cumulative->symbol table
110 // this is super brute force
// The table has one entry per slot in [0, prob_scale). The loops visit every
// slot of each symbol's range [cum_freqs[s], cum_freqs[s+1]).
// NOTE(review): the inner loop body is not shown here; presumably it stores s
// into cum2sym[i] -- confirm.
111 uint8_t cum2sym[prob_scale];
112 for (int s=0; s < 256; s++)
113 for (uint32_t i=stats.cum_freqs[s]; i < stats.cum_freqs[s+1]; i++)
// Output buffer sized in bytes, then converted to a count of 32-bit elements.
// out_end points one past the last element.
116 static size_t out_max_size = 32<<20; // 32 MiB
117 static size_t out_max_elems = out_max_size / sizeof(uint32_t);
118 uint32_t* out_buf = new uint32_t[out_max_elems];
119 uint32_t* out_end = out_buf + out_max_elems;
// Destination for decoded bytes, the same length as the input.
120 uint8_t* dec_bytes = new uint8_t[in_size];
// Start of the encoded stream; the decode loops read from it and the size
// reports measure from it to out_end.
123 uint32_t *rans_begin;
// Per-symbol encoder and decoder tables, one entry for each of the 256 slots.
124 Rans64EncSymbol esyms[256];
125 Rans64DecSymbol dsyms[256];
// Each entry is initialised from that symbol's cumulative start and frequency.
127 for (int i=0; i < 256; i++) {
128 Rans64EncSymbolInit(&esyms[i], stats.cum_freqs[i], stats.freqs[i], prob_bits);
129 Rans64DecSymbolInit(&dsyms[i], stats.cum_freqs[i], stats.freqs[i]);
132 // ---- regular rANS encode/decode. Typical usage.
// Fill the decode buffer with 0xcc before any decoding happens.
// NOTE(review): presumably so stale buffer contents cannot satisfy the later
// memcmp against the input -- confirm.
134 memset(dec_bytes, 0xcc, in_size);
// Encode benchmark: five timed runs. Each run reports the cycle count, cycles
// per input byte, and throughput labelled MiB/s.
136 printf("rANS encode:\n");
137 for (int run=0; run < 5; run++) {
// Both timer() and the __rdtsc() cycle counter are sampled.
138 double start_time = timer();
139 uint64_t enc_start_time = __rdtsc();
142 Rans64EncInit(&rans);
// ptr starts at the end of the output buffer and the input is walked from the
// last byte to the first.
144 uint32_t* ptr = out_end; // *end* of output buffer
145 for (size_t i=in_size; i > 0; i--) { // NB: working in reverse!
146 int s = in_bytes[i-1];
147 Rans64EncPutSymbol(&rans, &ptr, &esyms[s], prob_bits);
149 Rans64EncFlush(&rans, &ptr);
152 uint64_t enc_clocks = __rdtsc() - enc_start_time;
153 double enc_time = timer() - start_time;
// PRIu64 is separated from the neighbouring literals by spaces: since C++11 a
// string literal directly followed by an identifier is lexed as a
// user-defined literal suffix, not as a macro to expand.
154 printf("%" PRIu64 " clocks, %.1f clocks/symbol (%5.1fMiB/s)\n", enc_clocks, 1.0 * enc_clocks / in_size, 1.0 * in_size / (enc_time * 1048576.0));
// Encoded size in bytes: the number of 32-bit words between rans_begin and the
// end of the output buffer, times the word size.
156 printf("rANS: %d bytes\n", (int) ((out_end - rans_begin) * sizeof(uint32_t)));
// Decode benchmark: five timed runs over the stream that starts at rans_begin.
159 for (int run=0; run < 5; run++) {
160 double start_time = timer();
161 uint64_t dec_start_time = __rdtsc();
164 uint32_t* ptr = rans_begin;
165 Rans64DecInit(&rans, &ptr);
// Per byte: look up the symbol through cum2sym, store it, then advance the
// decoder state with that symbol's table entry.
167 for (size_t i=0; i < in_size; i++) {
168 uint32_t s = cum2sym[Rans64DecGet(&rans, prob_bits)];
169 dec_bytes[i] = (uint8_t) s;
170 Rans64DecAdvanceSymbol(&rans, &ptr, &dsyms[s], prob_bits);
173 uint64_t dec_clocks = __rdtsc() - dec_start_time;
174 double dec_time = timer() - start_time;
// PRIu64 is separated from the neighbouring literals by spaces: since C++11 a
// string literal directly followed by an identifier is lexed as a
// user-defined literal suffix, not as a macro to expand.
175 printf("%" PRIu64 " clocks, %.1f clocks/symbol (%5.1fMiB/s)\n", dec_clocks, 1.0 * dec_clocks / in_size, 1.0 * in_size / (dec_time * 1048576.0));
178 // check decode results
// Success means the decoded buffer matches the input byte for byte.
179 if (memcmp(in_bytes, dec_bytes, in_size) == 0)
180 printf("decode ok!\n");
182 printf("ERROR: bad decoder!\n");
184 // ---- interleaved rANS encode/decode. This is the kind of thing you might do to optimize critical paths.
// Refill the decode buffer with 0xcc before the interleaved round.
186 memset(dec_bytes, 0xcc, in_size);
188 // try interleaved rANS encode
189 printf("\ninterleaved rANS encode:\n");
190 for (int run=0; run < 5; run++) {
191 double start_time = timer();
192 uint64_t enc_start_time = __rdtsc();
// Two encoder states share a single output pointer.
194 Rans64State rans0, rans1;
195 Rans64EncInit(&rans0);
196 Rans64EncInit(&rans1);
198 uint32_t* ptr = out_end;
200 // odd number of bytes?
// NOTE(review): the condition guarding the next two lines is not shown here.
// They encode the final input byte with rans0.
202 int s = in_bytes[in_size - 1];
203 Rans64EncPutSymbol(&rans0, &ptr, &esyms[s], prob_bits);
// Walk the even-length prefix backwards two bytes at a time. The byte at the
// odd index goes to rans1 and the byte at the even index goes to rans0.
206 for (size_t i=(in_size & ~1); i > 0; i -= 2) { // NB: working in reverse!
207 int s1 = in_bytes[i-1];
208 int s0 = in_bytes[i-2];
209 Rans64EncPutSymbol(&rans1, &ptr, &esyms[s1], prob_bits);
210 Rans64EncPutSymbol(&rans0, &ptr, &esyms[s0], prob_bits);
// rans1 is flushed first and rans0 last.
212 Rans64EncFlush(&rans1, &ptr);
213 Rans64EncFlush(&rans0, &ptr);
216 uint64_t enc_clocks = __rdtsc() - enc_start_time;
217 double enc_time = timer() - start_time;
// PRIu64 is separated from the neighbouring literals by spaces: since C++11 a
// string literal directly followed by an identifier is lexed as a
// user-defined literal suffix, not as a macro to expand.
218 printf("%" PRIu64 " clocks, %.1f clocks/symbol (%5.1fMiB/s)\n", enc_clocks, 1.0 * enc_clocks / in_size, 1.0 * in_size / (enc_time * 1048576.0));
// Encoded size in bytes, measured from rans_begin to the end of the buffer.
220 printf("interleaved rANS: %d bytes\n", (int) ((out_end - rans_begin) * sizeof(uint32_t)));
222 // try interleaved rANS decode
223 for (int run=0; run < 5; run++) {
224 double start_time = timer();
225 uint64_t dec_start_time = __rdtsc();
// The states are initialised rans0 first, then rans1. That is the reverse of
// the order in which the interleaved encoder flushed them.
227 Rans64State rans0, rans1;
228 uint32_t* ptr = rans_begin;
229 Rans64DecInit(&rans0, &ptr);
230 Rans64DecInit(&rans1, &ptr);
// Decode the even-length prefix two bytes per iteration: rans0 yields the byte
// at the even index, rans1 the byte at the odd index. Both states take their
// symbol step before either is renormalized from the shared pointer.
232 for (size_t i=0; i < (in_size & ~1); i += 2) {
233 uint32_t s0 = cum2sym[Rans64DecGet(&rans0, prob_bits)];
234 uint32_t s1 = cum2sym[Rans64DecGet(&rans1, prob_bits)];
235 dec_bytes[i+0] = (uint8_t) s0;
236 dec_bytes[i+1] = (uint8_t) s1;
237 Rans64DecAdvanceSymbolStep(&rans0, &dsyms[s0], prob_bits);
238 Rans64DecAdvanceSymbolStep(&rans1, &dsyms[s1], prob_bits);
239 Rans64DecRenorm(&rans0, &ptr);
240 Rans64DecRenorm(&rans1, &ptr);
243 // last byte, if number of bytes was odd
// NOTE(review): the condition guarding the next three lines is not shown here.
245 uint32_t s0 = cum2sym[Rans64DecGet(&rans0, prob_bits)];
246 dec_bytes[in_size - 1] = (uint8_t) s0;
247 Rans64DecAdvanceSymbol(&rans0, &ptr, &dsyms[s0], prob_bits);
250 uint64_t dec_clocks = __rdtsc() - dec_start_time;
251 double dec_time = timer() - start_time;
// PRIu64 is separated from the neighbouring literals by spaces: since C++11 a
// string literal directly followed by an identifier is lexed as a
// user-defined literal suffix, not as a macro to expand.
// The unit label is MiB/s because the divisor is 1048576.0 (2^20); this also
// matches the other three timing reports in this program.
252 printf("%" PRIu64 " clocks, %.1f clocks/symbol (%5.1fMiB/s)\n", dec_clocks, 1.0 * dec_clocks / in_size, 1.0 * in_size / (dec_time * 1048576.0));
255 // check decode results
// Success means the decoded buffer matches the input byte for byte.
256 if (memcmp(in_bytes, dec_bytes, in_size) == 0)
257 printf("decode ok!\n");
259 printf("ERROR: bad decoder!\n");