]> git.sesse.net Git - narabu/blob - ryg_rans/main_simd.cpp
9fe54ccd7fb951403a3da57035256b3fac34faa0
[narabu] / ryg_rans / main_simd.cpp
1 #include "platform.h"
2 #include <stdio.h>
3 #include <stdarg.h>
4 #include <stdlib.h>
5 #include <stdint.h>
6 #include <string.h>
7 #include <assert.h>
8
9 #include "rans_word_sse41.h"
10
11 // This is just the sample program. All the meat is in rans_byte.h.
12
13 static void panic(const char *fmt, ...)
14 {
15     va_list arg;
16
17     va_start(arg, fmt);
18     fputs("Error: ", stderr);
19     vfprintf(stderr, fmt, arg);
20     va_end(arg);
21     fputs("\n", stderr);
22
23     exit(1);
24 }
25
26 static uint8_t* read_file(char const* filename, size_t* out_size)
27 {
28     FILE* f = fopen(filename, "rb");
29     if (!f)
30         panic("file not found: %s\n", filename);
31
32     fseek(f, 0, SEEK_END);
33     size_t size = ftell(f);
34     fseek(f, 0, SEEK_SET);
35
36     uint8_t* buf = new uint8_t[size];
37     if (fread(buf, size, 1, f) != 1)
38         panic("read failed\n");
39
40     fclose(f);
41     if (out_size)
42         *out_size = size;
43
44     return buf;
45 }
46
47 // ---- Stats
48
49 struct SymbolStats
50 {
51     uint32_t freqs[256];
52     uint32_t cum_freqs[257];
53
54     void count_freqs(uint8_t const* in, size_t nbytes);
55     void calc_cum_freqs();
56     void normalize_freqs(uint32_t target_total);
57 };
58
59 void SymbolStats::count_freqs(uint8_t const* in, size_t nbytes)
60 {
61     for (int i=0; i < 256; i++)
62         freqs[i] = 0;
63
64     for (size_t i=0; i < nbytes; i++)
65         freqs[in[i]]++;
66 }
67
68 void SymbolStats::calc_cum_freqs()
69 {
70     cum_freqs[0] = 0;
71     for (int i=0; i < 256; i++)
72         cum_freqs[i+1] = cum_freqs[i] + freqs[i];
73 }
74
75 void SymbolStats::normalize_freqs(uint32_t target_total)
76 {
77     assert(target_total >= 256);
78     
79     calc_cum_freqs();
80     uint32_t cur_total = cum_freqs[256];
81     
82     // resample distribution based on cumulative freqs
83     for (int i = 1; i <= 256; i++)
84         cum_freqs[i] = ((uint64_t)target_total * cum_freqs[i])/cur_total;
85
86     // if we nuked any non-0 frequency symbol to 0, we need to steal
87     // the range to make the frequency nonzero from elsewhere.
88     //
89     // this is not at all optimal, i'm just doing the first thing that comes to mind.
90     for (int i=0; i < 256; i++) {
91         if (freqs[i] && cum_freqs[i+1] == cum_freqs[i]) {
92             // symbol i was set to zero freq
93
94             // find best symbol to steal frequency from (try to steal from low-freq ones)
95             uint32_t best_freq = ~0u;
96             int best_steal = -1;
97             for (int j=0; j < 256; j++) {
98                 uint32_t freq = cum_freqs[j+1] - cum_freqs[j];
99                 if (freq > 1 && freq < best_freq) {
100                     best_freq = freq;
101                     best_steal = j;
102                 }
103             }
104             assert(best_steal != -1);
105
106             // and steal from it!
107             if (best_steal < i) {
108                 for (int j = best_steal + 1; j <= i; j++)
109                     cum_freqs[j]--;
110             } else {
111                 assert(best_steal > i);
112                 for (int j = i + 1; j <= best_steal; j++)
113                     cum_freqs[j]++;
114             }
115         }
116     }
117
118     // calculate updated freqs and make sure we didn't screw anything up
119     assert(cum_freqs[0] == 0 && cum_freqs[256] == target_total);
120     for (int i=0; i < 256; i++) {
121         if (freqs[i] == 0)
122             assert(cum_freqs[i+1] == cum_freqs[i]);
123         else
124             assert(cum_freqs[i+1] > cum_freqs[i]);
125
126         // calc updated freq
127         freqs[i] = cum_freqs[i+1] - cum_freqs[i];
128     }
129 }
130
131 int main()
132 {
133     size_t in_size;
134     uint8_t* in_bytes = read_file("book1", &in_size);
135
136     SymbolStats stats;
137     stats.count_freqs(in_bytes, in_size);
138     stats.normalize_freqs(RANS_WORD_M);
139
140     // init decoding tables
141     RansWordTables tab;
142     for (int s=0; s < 256; s++)
143         RansWordTablesInitSymbol(&tab, (uint8_t)s, stats.cum_freqs[s], stats.freqs[s]);
144
145     size_t out_max_size = in_size + (in_size >> 3) + 128;
146     uint8_t* out_buf = new uint8_t[out_max_size + 16]; // extra bytes at end
147     uint8_t* dec_bytes = new uint8_t[in_size];
148
149     // try rANS encode
150     uint16_t *rans_begin;
151
152     // ---- regular rANS encode/decode. Typical usage.
153
154     memset(dec_bytes, 0xcc, in_size);
155
156     printf("rANS encode:\n");
157     for (int run=0; run < 5; run++) {
158         double start_time = timer();
159         uint64_t enc_start_time = __rdtsc();
160
161         RansWordEnc rans = RansWordEncInit();
162
163         uint16_t* ptr = (uint16_t *) (out_buf + out_max_size); // *end* of output buffer
164         for (size_t i=in_size; i > 0; i--) { // NB: working in reverse!
165             int s = in_bytes[i-1];
166             RansWordEncPut(&rans, &ptr, stats.cum_freqs[s], stats.freqs[s]);
167         }
168         RansWordEncFlush(&rans, &ptr);
169         rans_begin = ptr;
170
171         uint64_t enc_clocks = __rdtsc() - enc_start_time;
172         double enc_time = timer() - start_time;
173         printf("%"PRIu64" clocks, %.1f clocks/symbol (%5.1fMiB/s)\n", enc_clocks, 1.0 * enc_clocks / in_size, 1.0 * in_size / (enc_time * 1048576.0));
174     }
175     printf("rANS: %d bytes\n", (int) (out_buf + out_max_size - (uint8_t *)rans_begin));
176
177     // try rANS decode
178     for (int run=0; run < 5; run++) {
179         double start_time = timer();
180         uint64_t dec_start_time = __rdtsc();
181
182         RansWordDec rans;
183         uint16_t* ptr = rans_begin;
184         RansWordDecInit(&rans, &ptr);
185
186         for (size_t i=0; i < in_size; i++) {
187             uint8_t s = RansWordDecSym(&rans, &tab);
188             dec_bytes[i] = (uint8_t) s;
189             RansWordDecRenorm(&rans, &ptr);
190         }
191
192         uint64_t dec_clocks = __rdtsc() - dec_start_time;
193         double dec_time = timer() - start_time;
194         printf("%"PRIu64" clocks, %.1f clocks/symbol (%5.1fMiB/s)\n", dec_clocks, 1.0 * dec_clocks / in_size, 1.0 * in_size / (dec_time * 1048576.0));
195     }
196
197     // check decode results
198     if (memcmp(in_bytes, dec_bytes, in_size) == 0)
199         printf("decode ok!\n");
200     else
201         printf("ERROR: bad decoder!\n");
202
203     // ---- interleaved rANS encode/decode. This is the kind of thing you might do to optimize critical paths.
204
205     memset(dec_bytes, 0xcc, in_size);
206
207     // try interleaved rANS encode
208     printf("\ninterleaved rANS encode:\n");
209     for (int run=0; run < 5; run++) {
210         double start_time = timer();
211         uint64_t enc_start_time = __rdtsc();
212
213         RansWordEnc rans0 = RansWordEncInit();
214         RansWordEnc rans1 = RansWordEncInit();
215
216         uint16_t* ptr = (uint16_t *)(out_buf + out_max_size); // *end* of output buffer
217
218         // odd number of bytes?
219         if (in_size & 1) {
220             int s = in_bytes[in_size - 1];
221             RansWordEncPut(&rans0, &ptr, stats.cum_freqs[s], stats.freqs[s]);
222         }
223
224         for (size_t i=(in_size & ~1); i > 0; i -= 2) { // NB: working in reverse!
225             int s1 = in_bytes[i-1];
226             int s0 = in_bytes[i-2];
227             RansWordEncPut(&rans1, &ptr, stats.cum_freqs[s1], stats.freqs[s1]);
228             RansWordEncPut(&rans0, &ptr, stats.cum_freqs[s0], stats.freqs[s0]);
229         }
230         RansWordEncFlush(&rans1, &ptr);
231         RansWordEncFlush(&rans0, &ptr);
232         rans_begin = ptr;
233
234         uint64_t enc_clocks = __rdtsc() - enc_start_time;
235         double enc_time = timer() - start_time;
236         printf("%"PRIu64" clocks, %.1f clocks/symbol (%5.1fMiB/s)\n", enc_clocks, 1.0 * enc_clocks / in_size, 1.0 * in_size / (enc_time * 1048576.0));
237     }
238     printf("interleaved rANS: %d bytes\n", (int) (out_buf + out_max_size - (uint8_t*)rans_begin));
239
240     // try interleaved rANS decode
241     for (int run=0; run < 5; run++) {
242         double start_time = timer();
243         uint64_t dec_start_time = __rdtsc();
244
245         RansWordDec rans0, rans1;
246         uint16_t* ptr = rans_begin;
247         RansWordDecInit(&rans0, &ptr);
248         RansWordDecInit(&rans1, &ptr);
249
250         for (size_t i=0; i < (in_size & ~1); i += 2) {
251             uint8_t s0 = RansWordDecSym(&rans0, &tab);
252             uint8_t s1 = RansWordDecSym(&rans1, &tab);
253             dec_bytes[i+0] = (uint8_t) s0;
254             dec_bytes[i+1] = (uint8_t) s1;
255             RansWordDecRenorm(&rans0, &ptr);
256             RansWordDecRenorm(&rans1, &ptr);
257         }
258
259         // last byte, if number of bytes was odd
260         if (in_size & 1) {
261             uint8_t s0 = RansWordDecSym(&rans0, &tab);
262             dec_bytes[in_size - 1] = (uint8_t) s0;
263         }
264
265         uint64_t dec_clocks = __rdtsc() - dec_start_time;
266         double dec_time = timer() - start_time;
267         printf("%"PRIu64" clocks, %.1f clocks/symbol (%5.1fMB/s)\n", dec_clocks, 1.0 * dec_clocks / in_size, 1.0 * in_size / (dec_time * 1048576.0));
268     }
269
270     // check decode results
271     if (memcmp(in_bytes, dec_bytes, in_size) == 0)
272         printf("decode ok!\n");
273     else
274         printf("ERROR: bad decoder!\n");
275
276     // ---- SIMD interleaved rANS encode/decode.
277
278     memset(dec_bytes, 0xcc, in_size);
279
280     // try SIMD rANS encode
281     // this is written for clarity not speed.
282     printf("\ninterleaved SIMD rANS encode: (encode itself isn't SIMD)\n");
283     for (int run=0; run < 5; run++) {
284         double start_time = timer();
285         uint64_t enc_start_time = __rdtsc();
286
287         RansWordEnc rans[8];
288         for (int i=0; i < 8; i++)
289             rans[i] = RansWordEncInit();
290
291         uint16_t* ptr = (uint16_t *)(out_buf + out_max_size); // *end* of output buffer
292
293         // last few bytes
294         for (size_t i=in_size; i > 0; i--) { // NB: working in reverse
295             int s = in_bytes[i - 1];
296             RansWordEncPut(&rans[(i - 1) & 7], &ptr, stats.cum_freqs[s], stats.freqs[s]);
297         }
298         for (int i=8; i > 0; i--)
299             RansWordEncFlush(&rans[i - 1], &ptr);
300         rans_begin = ptr;
301
302         uint64_t enc_clocks = __rdtsc() - enc_start_time;
303         double enc_time = timer() - start_time;
304         printf("%"PRIu64" clocks, %.1f clocks/symbol (%5.1fMiB/s)\n", enc_clocks, 1.0 * enc_clocks / in_size, 1.0 * in_size / (enc_time * 1048576.0));
305     }
306     printf("SIMD rANS: %d bytes\n", (int) (out_buf + out_max_size - (uint8_t*)rans_begin));
307
308     // try SIMD rANS decode
309     for (int run=0; run < 5; run++) {
310         double start_time = timer();
311         uint64_t dec_start_time = __rdtsc();
312
313         RansSimdDec rans0, rans1;
314         uint16_t* ptr = rans_begin;
315         RansSimdDecInit(&rans0, &ptr);
316         RansSimdDecInit(&rans1, &ptr);
317
318         for (size_t i=0; i < (in_size & ~7); i += 8) {
319             uint32_t s03 = RansSimdDecSym(&rans0, &tab);
320             uint32_t s47 = RansSimdDecSym(&rans1, &tab);
321             *(uint32_t *)(dec_bytes + i) = s03;
322             *(uint32_t *)(dec_bytes + i + 4) = s47;
323             RansSimdDecRenorm(&rans0, &ptr);
324             RansSimdDecRenorm(&rans1, &ptr);
325         }
326
327         // last few bytes
328         for (size_t i=(in_size & ~7); i < in_size; i++) {
329             RansSimdDec* which = (i & 4) != 0 ? &rans1 : &rans0;
330             uint8_t s = RansWordDecSym(&which->lane[i & 3], &tab);
331             dec_bytes[i] = s;
332         }
333
334         uint64_t dec_clocks = __rdtsc() - dec_start_time;
335         double dec_time = timer() - start_time;
336         printf("%"PRIu64" clocks, %.1f clocks/symbol (%5.1fMB/s)\n", dec_clocks, 1.0 * dec_clocks / in_size, 1.0 * in_size / (dec_time * 1048576.0));
337     }
338
339     // check decode results
340     if (memcmp(in_bytes, dec_bytes, in_size) == 0)
341         printf("decode ok!\n");
342     else
343         printf("ERROR: bad decoder!\n");
344
345     delete[] out_buf;
346     delete[] dec_bytes;
347     delete[] in_bytes;
348     return 0;
349 }