9 #include "rans_word_sse41.h"
// This is just the sample program. All the meat is in rans_word_sse41.h.
// Reports a fatal error: writes "Error: " to stderr, followed by the message
// formatted from fmt and its variadic arguments via vfprintf.
// NOTE(review): read_file() keeps using its FILE* and buffer after calling
// panic(), so this function is relied on never to return (presumably it
// exits) — confirm.
13 static void panic(const char *fmt, ...)
18 fputs("Error: ", stderr);
19 vfprintf(stderr, fmt, arg);
26 static uint8_t* read_file(char const* filename, size_t* out_size)
28 FILE* f = fopen(filename, "rb");
30 panic("file not found: %s\n", filename);
32 fseek(f, 0, SEEK_END);
33 size_t size = ftell(f);
34 fseek(f, 0, SEEK_SET);
36 uint8_t* buf = new uint8_t[size];
37 if (fread(buf, size, 1, f) != 1)
38 panic("read failed\n");
// Symbol statistics over the byte alphabet: per-symbol frequencies and
// their cumulative form, plus normalization to a fixed total.
struct SymbolStats
{
    uint32_t freqs[256];      // frequency of each byte value
    uint32_t cum_freqs[257];  // cum_freqs[s] = freqs[0] + ... + freqs[s-1]; cum_freqs[256] = total

    void count_freqs(uint8_t const* in, size_t nbytes);
    void calc_cum_freqs();
    void normalize_freqs(uint32_t target_total);
};

// Histograms the first `nbytes` bytes of `in` into freqs[], resetting all
// counts first. `in` is not read when nbytes is 0.
void SymbolStats::count_freqs(uint8_t const* in, size_t nbytes)
{
    for (int i=0; i < 256; i++)
        freqs[i] = 0;

    for (size_t i=0; i < nbytes; i++)
        freqs[in[i]]++;
}

// Rebuilds cum_freqs[] as the exclusive prefix sum of freqs[].
void SymbolStats::calc_cum_freqs()
{
    cum_freqs[0] = 0;
    for (int i=0; i < 256; i++)
        cum_freqs[i+1] = cum_freqs[i] + freqs[i];
}

// Rescales the distribution so the frequencies sum to exactly target_total,
// keeping every symbol that occurred at a nonzero frequency.
//
// The cumulative boundaries are scaled by target_total / cur_total. Any
// occurring symbol that rounds down to zero then takes one slot from the
// smallest symbol that can spare one (frequency > 1). Afterwards freqs[] and
// cum_freqs[] both describe the normalized distribution.
//
// If nothing was counted (total 0), there is no distribution to rescale and
// the division would be by zero, so all entries are left at zero.
//
// Requires target_total >= 256 (asserted).
void SymbolStats::normalize_freqs(uint32_t target_total)
{
    assert(target_total >= 256);

    calc_cum_freqs();
    uint32_t cur_total = cum_freqs[256];

    // Empty histogram: freqs[] and cum_freqs[] are already all zero and the
    // resampling below would divide by zero.
    if (cur_total == 0)
        return;

    // resample distribution based on cumulative freqs
    // (64-bit product: target_total * cum_freqs[i] can exceed 32 bits)
    for (int i = 1; i <= 256; i++)
        cum_freqs[i] = ((uint64_t)target_total * cum_freqs[i])/cur_total;

    // if we nuked any non-0 frequency symbol to 0, we need to steal
    // the range to make the frequency nonzero from elsewhere.
    //
    // this is not at all optimal, i'm just doing the first thing that comes to mind.
    for (int i=0; i < 256; i++) {
        if (freqs[i] && cum_freqs[i+1] == cum_freqs[i]) {
            // symbol i was set to zero freq

            // find best symbol to steal frequency from (try to steal from low-freq ones)
            uint32_t best_freq = ~0u;
            int best_steal = -1;
            for (int j=0; j < 256; j++) {
                uint32_t freq = cum_freqs[j+1] - cum_freqs[j];
                if (freq > 1 && freq < best_freq) {
                    best_freq = freq;
                    best_steal = j;
                }
            }
            assert(best_steal != -1);

            // and steal from it! Shifting every boundary between the donor and
            // symbol i by one moves one slot from the donor to i; the symbols
            // in between keep their widths.
            if (best_steal < i) {
                for (int j = best_steal + 1; j <= i; j++)
                    cum_freqs[j]--;
            } else {
                assert(best_steal > i);
                for (int j = i + 1; j <= best_steal; j++)
                    cum_freqs[j]++;
            }
        }
    }

    // calculate updated freqs and make sure we didn't screw anything up
    assert(cum_freqs[0] == 0 && cum_freqs[256] == target_total);
    for (int i=0; i < 256; i++) {
        if (freqs[i] == 0)
            assert(cum_freqs[i+1] == cum_freqs[i]);
        else
            assert(cum_freqs[i+1] > cum_freqs[i]);

        // calc updated freq
        freqs[i] = cum_freqs[i+1] - cum_freqs[i];
    }
}
// Benchmark driver: loads the test file "book1", builds a frequency table
// normalized to RANS_WORD_M and the decoder tables from it, then benchmarks
// three codecs: plain rANS, two-way interleaved rANS, and an 8-stream encoder
// decoded by two 4-lane SIMD decoders. Each encode/decode loop runs 5 times
// and reports cycle counts (__rdtsc) and throughput. Every decode is checked
// against the input with memcmp.
134 uint8_t* in_bytes = read_file("book1", &in_size);
137 stats.count_freqs(in_bytes, in_size);
138 stats.normalize_freqs(RANS_WORD_M);
140 // init decoding tables
142 for (int s=0; s < 256; s++)
143 RansWordTablesInitSymbol(&tab, (uint8_t)s, stats.cum_freqs[s], stats.freqs[s]);
// Output capacity: input size plus 1/8 plus 128 bytes of headroom.
// NOTE(review): presumably meant to cover worst-case expansion of poorly
// compressible input — confirm against the coder's bound.
145 size_t out_max_size = in_size + (in_size >> 3) + 128;
146 uint8_t* out_buf = new uint8_t[out_max_size + 16]; // extra bytes at end
147 uint8_t* dec_bytes = new uint8_t[in_size];
// Start of the most recently encoded stream. Encoders fill out_buf backwards
// from its end, so this is the lowest written address (presumably assigned
// from ptr after each encode's flush — confirm).
150 uint16_t *rans_begin;
152 // ---- regular rANS encode/decode. Typical usage.
// Pre-fill with a marker pattern so that a decoder which skips bytes fails
// the memcmp check below.
154 memset(dec_bytes, 0xcc, in_size);
156 printf("rANS encode:\n");
157 for (int run=0; run < 5; run++) {
158 double start_time = timer();
159 uint64_t enc_start_time = __rdtsc();
161 RansWordEnc rans = RansWordEncInit();
163 uint16_t* ptr = (uint16_t *) (out_buf + out_max_size); // *end* of output buffer
// Symbols are encoded last-to-first so that the decoder below can produce
// them in forward order.
164 for (size_t i=in_size; i > 0; i--) { // NB: working in reverse!
165 int s = in_bytes[i-1];
166 RansWordEncPut(&rans, &ptr, stats.cum_freqs[s], stats.freqs[s]);
168 RansWordEncFlush(&rans, &ptr);
171 uint64_t enc_clocks = __rdtsc() - enc_start_time;
172 double enc_time = timer() - start_time;
// NOTE(review): `"%"PRIu64"` has no whitespace between the string literals and
// the macro. Since C++11 that token sequence is lexed as a user-defined-literal
// suffix: GCC warns (-Wliteral-suffix) and Clang reports an error by default.
// `"%" PRIu64 " clocks, ..."` is the portable spelling. The same pattern
// appears in every timing printf below.
173 printf("%"PRIu64" clocks, %.1f clocks/symbol (%5.1fMiB/s)\n", enc_clocks, 1.0 * enc_clocks / in_size, 1.0 * in_size / (enc_time * 1048576.0));
175 printf("rANS: %d bytes\n", (int) (out_buf + out_max_size - (uint8_t *)rans_begin));
178 for (int run=0; run < 5; run++) {
179 double start_time = timer();
180 uint64_t dec_start_time = __rdtsc();
183 uint16_t* ptr = rans_begin;
184 RansWordDecInit(&rans, &ptr);
186 for (size_t i=0; i < in_size; i++) {
187 uint8_t s = RansWordDecSym(&rans, &tab);
188 dec_bytes[i] = (uint8_t) s;
189 RansWordDecRenorm(&rans, &ptr);
192 uint64_t dec_clocks = __rdtsc() - dec_start_time;
193 double dec_time = timer() - start_time;
194 printf("%"PRIu64" clocks, %.1f clocks/symbol (%5.1fMiB/s)\n", dec_clocks, 1.0 * dec_clocks / in_size, 1.0 * in_size / (dec_time * 1048576.0));
197 // check decode results
198 if (memcmp(in_bytes, dec_bytes, in_size) == 0)
199 printf("decode ok!\n");
201 printf("ERROR: bad decoder!\n");
203 // ---- interleaved rANS encode/decode. This is the kind of thing you might do to optimize critical paths.
205 memset(dec_bytes, 0xcc, in_size);
207 // try interleaved rANS encode
208 printf("\ninterleaved rANS encode:\n");
209 for (int run=0; run < 5; run++) {
210 double start_time = timer();
211 uint64_t enc_start_time = __rdtsc();
213 RansWordEnc rans0 = RansWordEncInit();
214 RansWordEnc rans1 = RansWordEncInit();
216 uint16_t* ptr = (uint16_t *)(out_buf + out_max_size); // *end* of output buffer
218 // odd number of bytes?
220 int s = in_bytes[in_size - 1];
221 RansWordEncPut(&rans0, &ptr, stats.cum_freqs[s], stats.freqs[s]);
// Even-indexed bytes go to rans0 and odd-indexed bytes to rans1. rans1 is
// flushed before rans0, so rans0's state lands first in the backwards-written
// buffer, and the decoder initializes rans0 first.
224 for (size_t i=(in_size & ~1); i > 0; i -= 2) { // NB: working in reverse!
225 int s1 = in_bytes[i-1];
226 int s0 = in_bytes[i-2];
227 RansWordEncPut(&rans1, &ptr, stats.cum_freqs[s1], stats.freqs[s1]);
228 RansWordEncPut(&rans0, &ptr, stats.cum_freqs[s0], stats.freqs[s0]);
230 RansWordEncFlush(&rans1, &ptr);
231 RansWordEncFlush(&rans0, &ptr);
234 uint64_t enc_clocks = __rdtsc() - enc_start_time;
235 double enc_time = timer() - start_time;
236 printf("%"PRIu64" clocks, %.1f clocks/symbol (%5.1fMiB/s)\n", enc_clocks, 1.0 * enc_clocks / in_size, 1.0 * in_size / (enc_time * 1048576.0));
238 printf("interleaved rANS: %d bytes\n", (int) (out_buf + out_max_size - (uint8_t*)rans_begin));
240 // try interleaved rANS decode
241 for (int run=0; run < 5; run++) {
242 double start_time = timer();
243 uint64_t dec_start_time = __rdtsc();
245 RansWordDec rans0, rans1;
246 uint16_t* ptr = rans_begin;
247 RansWordDecInit(&rans0, &ptr);
248 RansWordDecInit(&rans1, &ptr);
250 for (size_t i=0; i < (in_size & ~1); i += 2) {
251 uint8_t s0 = RansWordDecSym(&rans0, &tab);
252 uint8_t s1 = RansWordDecSym(&rans1, &tab);
253 dec_bytes[i+0] = (uint8_t) s0;
254 dec_bytes[i+1] = (uint8_t) s1;
255 RansWordDecRenorm(&rans0, &ptr);
256 RansWordDecRenorm(&rans1, &ptr);
259 // last byte, if number of bytes was odd
261 uint8_t s0 = RansWordDecSym(&rans0, &tab);
262 dec_bytes[in_size - 1] = (uint8_t) s0;
265 uint64_t dec_clocks = __rdtsc() - dec_start_time;
266 double dec_time = timer() - start_time;
// NOTE(review): the rate is divided by 1048576 (MiB) like the other reports,
// but the label says "MB/s"; "MiB/s" would be consistent.
267 printf("%"PRIu64" clocks, %.1f clocks/symbol (%5.1fMB/s)\n", dec_clocks, 1.0 * dec_clocks / in_size, 1.0 * in_size / (dec_time * 1048576.0));
270 // check decode results
271 if (memcmp(in_bytes, dec_bytes, in_size) == 0)
272 printf("decode ok!\n");
274 printf("ERROR: bad decoder!\n");
276 // ---- SIMD interleaved rANS encode/decode.
278 memset(dec_bytes, 0xcc, in_size);
280 // try SIMD rANS encode
281 // this is written for clarity not speed.
282 printf("\ninterleaved SIMD rANS encode: (encode itself isn't SIMD)\n");
283 for (int run=0; run < 5; run++) {
284 double start_time = timer();
285 uint64_t enc_start_time = __rdtsc();
288 for (int i=0; i < 8; i++)
289 rans[i] = RansWordEncInit();
291 uint16_t* ptr = (uint16_t *)(out_buf + out_max_size); // *end* of output buffer
// Byte i is coded by stream (i & 7). Streams are flushed from 7 down to 0,
// so stream 0's state ends up first in the buffer.
294 for (size_t i=in_size; i > 0; i--) { // NB: working in reverse
295 int s = in_bytes[i - 1];
296 RansWordEncPut(&rans[(i - 1) & 7], &ptr, stats.cum_freqs[s], stats.freqs[s]);
298 for (int i=8; i > 0; i--)
299 RansWordEncFlush(&rans[i - 1], &ptr);
302 uint64_t enc_clocks = __rdtsc() - enc_start_time;
303 double enc_time = timer() - start_time;
304 printf("%"PRIu64" clocks, %.1f clocks/symbol (%5.1fMiB/s)\n", enc_clocks, 1.0 * enc_clocks / in_size, 1.0 * in_size / (enc_time * 1048576.0));
306 printf("SIMD rANS: %d bytes\n", (int) (out_buf + out_max_size - (uint8_t*)rans_begin));
308 // try SIMD rANS decode
309 for (int run=0; run < 5; run++) {
310 double start_time = timer();
311 uint64_t dec_start_time = __rdtsc();
// Two 4-lane decoders: rans0 carries streams 0-3 and rans1 streams 4-7,
// matching the lane selection in the tail loop below.
313 RansSimdDec rans0, rans1;
314 uint16_t* ptr = rans_begin;
315 RansSimdDecInit(&rans0, &ptr);
316 RansSimdDecInit(&rans1, &ptr);
318 for (size_t i=0; i < (in_size & ~7); i += 8) {
// Each call appears to return the four decoded bytes of its lanes packed
// into a uint32_t — confirm against RansSimdDecSym.
319 uint32_t s03 = RansSimdDecSym(&rans0, &tab);
320 uint32_t s47 = RansSimdDecSym(&rans1, &tab);
// NOTE(review): these stores write through a uint32_t* into a uint8_t array,
// which is formally a strict-aliasing violation; memcpy(dec_bytes + i, &s03, 4)
// is the portable equivalent. The byte order of the packed result presumably
// assumes a little-endian target — confirm.
321 *(uint32_t *)(dec_bytes + i) = s03;
322 *(uint32_t *)(dec_bytes + i + 4) = s47;
323 RansSimdDecRenorm(&rans0, &ptr);
324 RansSimdDecRenorm(&rans1, &ptr);
// Remaining in_size % 8 bytes: decode one at a time with the scalar decoder
// applied to the matching SIMD lane.
328 for (size_t i=(in_size & ~7); i < in_size; i++) {
329 RansSimdDec* which = (i & 4) != 0 ? &rans1 : &rans0;
330 uint8_t s = RansWordDecSym(&which->lane[i & 3], &tab);
334 uint64_t dec_clocks = __rdtsc() - dec_start_time;
335 double dec_time = timer() - start_time;
// NOTE(review): computed in MiB/s (1048576) but labelled "MB/s".
336 printf("%"PRIu64" clocks, %.1f clocks/symbol (%5.1fMB/s)\n", dec_clocks, 1.0 * dec_clocks / in_size, 1.0 * in_size / (dec_time * 1048576.0));
339 // check decode results
340 if (memcmp(in_bytes, dec_bytes, in_size) == 0)
341 printf("decode ok!\n");
343 printf("ERROR: bad decoder!\n");