]> git.sesse.net Git - narabu/blob - ryg_rans/main_simd.cpp
More fixes of hard-coded values.
[narabu] / ryg_rans / main_simd.cpp
1 #include "platform.h"
2 #include <stdio.h>
3 #include <stdarg.h>
4 #include <stdlib.h>
5 #include <stdint.h>
6 #include <string.h>
7 #include <assert.h>
8
9 #include "rans_word_sse41.h"
10 #include "renormalize.h"
11
12 // This is just the sample program. All the meat is in rans_byte.h.
13
14 static void panic(const char *fmt, ...)
15 {
16     va_list arg;
17
18     va_start(arg, fmt);
19     fputs("Error: ", stderr);
20     vfprintf(stderr, fmt, arg);
21     va_end(arg);
22     fputs("\n", stderr);
23
24     exit(1);
25 }
26
27 static uint8_t* read_file(char const* filename, size_t* out_size)
28 {
29     FILE* f = fopen(filename, "rb");
30     if (!f)
31         panic("file not found: %s\n", filename);
32
33     fseek(f, 0, SEEK_END);
34     size_t size = ftell(f);
35     fseek(f, 0, SEEK_SET);
36
37     uint8_t* buf = new uint8_t[size];
38     if (fread(buf, size, 1, f) != 1)
39         panic("read failed\n");
40
41     fclose(f);
42     if (out_size)
43         *out_size = size;
44
45     return buf;
46 }
47
48 // ---- Stats
49
50 struct SymbolStats
51 {
52     uint32_t freqs[256];
53     uint32_t cum_freqs[257];
54
55     void count_freqs(uint8_t const* in, size_t nbytes);
56     void calc_cum_freqs();
57     void normalize_freqs(uint32_t target_total);
58 };
59
60 void SymbolStats::count_freqs(uint8_t const* in, size_t nbytes)
61 {
62     for (int i=0; i < 256; i++)
63         freqs[i] = 0;
64
65     for (size_t i=0; i < nbytes; i++)
66         freqs[in[i]]++;
67 }
68
69 void SymbolStats::calc_cum_freqs()
70 {
71     cum_freqs[0] = 0;
72     for (int i=0; i < 256; i++)
73         cum_freqs[i+1] = cum_freqs[i] + freqs[i];
74 }
75
76 void SymbolStats::normalize_freqs(uint32_t target_total)
77 {
78     assert(target_total >= 256);
79     
80     calc_cum_freqs();
81     
82     OptimalRenormalize(cum_freqs, 256, target_total);
83
84     // calculate updated freqs and make sure we didn't screw anything up
85     assert(cum_freqs[0] == 0 && cum_freqs[256] == target_total);
86     for (int i=0; i < 256; i++) {
87         if (freqs[i] == 0)
88             assert(cum_freqs[i+1] == cum_freqs[i]);
89         else
90             assert(cum_freqs[i+1] > cum_freqs[i]);
91
92         // calc updated freq
93         freqs[i] = cum_freqs[i+1] - cum_freqs[i];
94     }
95 }
96
97 int main()
98 {
99     size_t in_size;
100     uint8_t* in_bytes = read_file("book1", &in_size);
101
102     SymbolStats stats;
103     stats.count_freqs(in_bytes, in_size);
104     stats.normalize_freqs(RANS_WORD_M);
105
106     // init decoding tables
107     RansWordTables tab;
108     for (int s=0; s < 256; s++)
109         RansWordTablesInitSymbol(&tab, (uint8_t)s, stats.cum_freqs[s], stats.freqs[s]);
110
111     size_t out_max_size = in_size + (in_size >> 3) + 128;
112     uint8_t* out_buf = new uint8_t[out_max_size + 16]; // extra bytes at end
113     uint8_t* dec_bytes = new uint8_t[in_size];
114
115     // try rANS encode
116     uint16_t *rans_begin;
117
118     // ---- regular rANS encode/decode. Typical usage.
119
120     memset(dec_bytes, 0xcc, in_size);
121
122     printf("rANS encode:\n");
123     for (int run=0; run < 5; run++) {
124         double start_time = timer();
125         uint64_t enc_start_time = __rdtsc();
126
127         RansWordEnc rans = RansWordEncInit();
128
129         uint16_t* ptr = (uint16_t *) (out_buf + out_max_size); // *end* of output buffer
130         for (size_t i=in_size; i > 0; i--) { // NB: working in reverse!
131             int s = in_bytes[i-1];
132             RansWordEncPut(&rans, &ptr, stats.cum_freqs[s], stats.freqs[s]);
133         }
134         RansWordEncFlush(&rans, &ptr);
135         rans_begin = ptr;
136
137         uint64_t enc_clocks = __rdtsc() - enc_start_time;
138         double enc_time = timer() - start_time;
139         printf("%"PRIu64" clocks, %.1f clocks/symbol (%5.1fMiB/s)\n", enc_clocks, 1.0 * enc_clocks / in_size, 1.0 * in_size / (enc_time * 1048576.0));
140     }
141     printf("rANS: %d bytes\n", (int) (out_buf + out_max_size - (uint8_t *)rans_begin));
142
143     // try rANS decode
144     for (int run=0; run < 5; run++) {
145         double start_time = timer();
146         uint64_t dec_start_time = __rdtsc();
147
148         RansWordDec rans;
149         uint16_t* ptr = rans_begin;
150         RansWordDecInit(&rans, &ptr);
151
152         for (size_t i=0; i < in_size; i++) {
153             uint8_t s = RansWordDecSym(&rans, &tab);
154             dec_bytes[i] = (uint8_t) s;
155             RansWordDecRenorm(&rans, &ptr);
156         }
157
158         uint64_t dec_clocks = __rdtsc() - dec_start_time;
159         double dec_time = timer() - start_time;
160         printf("%"PRIu64" clocks, %.1f clocks/symbol (%5.1fMiB/s)\n", dec_clocks, 1.0 * dec_clocks / in_size, 1.0 * in_size / (dec_time * 1048576.0));
161     }
162
163     // check decode results
164     if (memcmp(in_bytes, dec_bytes, in_size) == 0)
165         printf("decode ok!\n");
166     else
167         printf("ERROR: bad decoder!\n");
168
169     // ---- interleaved rANS encode/decode. This is the kind of thing you might do to optimize critical paths.
170
171     memset(dec_bytes, 0xcc, in_size);
172
173     // try interleaved rANS encode
174     printf("\ninterleaved rANS encode:\n");
175     for (int run=0; run < 5; run++) {
176         double start_time = timer();
177         uint64_t enc_start_time = __rdtsc();
178
179         RansWordEnc rans0 = RansWordEncInit();
180         RansWordEnc rans1 = RansWordEncInit();
181
182         uint16_t* ptr = (uint16_t *)(out_buf + out_max_size); // *end* of output buffer
183
184         // odd number of bytes?
185         if (in_size & 1) {
186             int s = in_bytes[in_size - 1];
187             RansWordEncPut(&rans0, &ptr, stats.cum_freqs[s], stats.freqs[s]);
188         }
189
190         for (size_t i=(in_size & ~1); i > 0; i -= 2) { // NB: working in reverse!
191             int s1 = in_bytes[i-1];
192             int s0 = in_bytes[i-2];
193             RansWordEncPut(&rans1, &ptr, stats.cum_freqs[s1], stats.freqs[s1]);
194             RansWordEncPut(&rans0, &ptr, stats.cum_freqs[s0], stats.freqs[s0]);
195         }
196         RansWordEncFlush(&rans1, &ptr);
197         RansWordEncFlush(&rans0, &ptr);
198         rans_begin = ptr;
199
200         uint64_t enc_clocks = __rdtsc() - enc_start_time;
201         double enc_time = timer() - start_time;
202         printf("%"PRIu64" clocks, %.1f clocks/symbol (%5.1fMiB/s)\n", enc_clocks, 1.0 * enc_clocks / in_size, 1.0 * in_size / (enc_time * 1048576.0));
203     }
204     printf("interleaved rANS: %d bytes\n", (int) (out_buf + out_max_size - (uint8_t*)rans_begin));
205
206     // try interleaved rANS decode
207     for (int run=0; run < 5; run++) {
208         double start_time = timer();
209         uint64_t dec_start_time = __rdtsc();
210
211         RansWordDec rans0, rans1;
212         uint16_t* ptr = rans_begin;
213         RansWordDecInit(&rans0, &ptr);
214         RansWordDecInit(&rans1, &ptr);
215
216         for (size_t i=0; i < (in_size & ~1); i += 2) {
217             uint8_t s0 = RansWordDecSym(&rans0, &tab);
218             uint8_t s1 = RansWordDecSym(&rans1, &tab);
219             dec_bytes[i+0] = (uint8_t) s0;
220             dec_bytes[i+1] = (uint8_t) s1;
221             RansWordDecRenorm(&rans0, &ptr);
222             RansWordDecRenorm(&rans1, &ptr);
223         }
224
225         // last byte, if number of bytes was odd
226         if (in_size & 1) {
227             uint8_t s0 = RansWordDecSym(&rans0, &tab);
228             dec_bytes[in_size - 1] = (uint8_t) s0;
229         }
230
231         uint64_t dec_clocks = __rdtsc() - dec_start_time;
232         double dec_time = timer() - start_time;
233         printf("%"PRIu64" clocks, %.1f clocks/symbol (%5.1fMB/s)\n", dec_clocks, 1.0 * dec_clocks / in_size, 1.0 * in_size / (dec_time * 1048576.0));
234     }
235
236     // check decode results
237     if (memcmp(in_bytes, dec_bytes, in_size) == 0)
238         printf("decode ok!\n");
239     else
240         printf("ERROR: bad decoder!\n");
241
242     // ---- SIMD interleaved rANS encode/decode.
243
244     memset(dec_bytes, 0xcc, in_size);
245
246     // try SIMD rANS encode
247     // this is written for clarity not speed.
248     printf("\ninterleaved SIMD rANS encode: (encode itself isn't SIMD)\n");
249     for (int run=0; run < 5; run++) {
250         double start_time = timer();
251         uint64_t enc_start_time = __rdtsc();
252
253         RansWordEnc rans[8];
254         for (int i=0; i < 8; i++)
255             rans[i] = RansWordEncInit();
256
257         uint16_t* ptr = (uint16_t *)(out_buf + out_max_size); // *end* of output buffer
258
259         // last few bytes
260         for (size_t i=in_size; i > 0; i--) { // NB: working in reverse
261             int s = in_bytes[i - 1];
262             RansWordEncPut(&rans[(i - 1) & 7], &ptr, stats.cum_freqs[s], stats.freqs[s]);
263         }
264         for (int i=8; i > 0; i--)
265             RansWordEncFlush(&rans[i - 1], &ptr);
266         rans_begin = ptr;
267
268         uint64_t enc_clocks = __rdtsc() - enc_start_time;
269         double enc_time = timer() - start_time;
270         printf("%"PRIu64" clocks, %.1f clocks/symbol (%5.1fMiB/s)\n", enc_clocks, 1.0 * enc_clocks / in_size, 1.0 * in_size / (enc_time * 1048576.0));
271     }
272     printf("SIMD rANS: %d bytes\n", (int) (out_buf + out_max_size - (uint8_t*)rans_begin));
273
274     // try SIMD rANS decode
275     for (int run=0; run < 5; run++) {
276         double start_time = timer();
277         uint64_t dec_start_time = __rdtsc();
278
279         RansSimdDec rans0, rans1;
280         uint16_t* ptr = rans_begin;
281         RansSimdDecInit(&rans0, &ptr);
282         RansSimdDecInit(&rans1, &ptr);
283
284         for (size_t i=0; i < (in_size & ~7); i += 8) {
285             uint32_t s03 = RansSimdDecSym(&rans0, &tab);
286             uint32_t s47 = RansSimdDecSym(&rans1, &tab);
287             *(uint32_t *)(dec_bytes + i) = s03;
288             *(uint32_t *)(dec_bytes + i + 4) = s47;
289             RansSimdDecRenorm(&rans0, &ptr);
290             RansSimdDecRenorm(&rans1, &ptr);
291         }
292
293         // last few bytes
294         for (size_t i=(in_size & ~7); i < in_size; i++) {
295             RansSimdDec* which = (i & 4) != 0 ? &rans1 : &rans0;
296             uint8_t s = RansWordDecSym(&which->lane[i & 3], &tab);
297             dec_bytes[i] = s;
298         }
299
300         uint64_t dec_clocks = __rdtsc() - dec_start_time;
301         double dec_time = timer() - start_time;
302         printf("%"PRIu64" clocks, %.1f clocks/symbol (%5.1fMB/s)\n", dec_clocks, 1.0 * dec_clocks / in_size, 1.0 * in_size / (dec_time * 1048576.0));
303     }
304
305     // check decode results
306     if (memcmp(in_bytes, dec_bytes, in_size) == 0)
307         printf("decode ok!\n");
308     else
309         printf("ERROR: bad decoder!\n");
310
311     delete[] out_buf;
312     delete[] dec_bytes;
313     delete[] in_bytes;
314     return 0;
315 }