]> git.sesse.net Git - narabu/blob - ryg_rans/main64.cpp
More fixes of hard-coded values.
[narabu] / ryg_rans / main64.cpp
1 #include "platform.h"
2 #include <stdio.h>
3 #include <stdarg.h>
4 #include <stdlib.h>
5 #include <stdint.h>
6 #include <string.h>
7 #include <assert.h>
8
9 #include "rans64.h"
10 #include "renormalize.h"
11
12 // This is just the sample program. All the meat is in rans_byte.h.
13
14 static void panic(const char *fmt, ...)
15 {
16     va_list arg;
17
18     va_start(arg, fmt);
19     fputs("Error: ", stderr);
20     vfprintf(stderr, fmt, arg);
21     va_end(arg);
22     fputs("\n", stderr);
23
24     exit(1);
25 }
26
27 static uint8_t* read_file(char const* filename, size_t* out_size)
28 {
29     FILE* f = fopen(filename, "rb");
30     if (!f)
31         panic("file not found: %s\n", filename);
32
33     fseek(f, 0, SEEK_END);
34     size_t size = ftell(f);
35     fseek(f, 0, SEEK_SET);
36
37     uint8_t* buf = new uint8_t[size];
38     if (fread(buf, size, 1, f) != 1)
39         panic("read failed\n");
40
41     fclose(f);
42     if (out_size)
43         *out_size = size;
44
45     return buf;
46 }
47
48 // ---- Stats
49
50 struct SymbolStats
51 {
52     uint32_t freqs[256];
53     uint32_t cum_freqs[257];
54
55     void count_freqs(uint8_t const* in, size_t nbytes);
56     void calc_cum_freqs();
57     void normalize_freqs(uint32_t target_total);
58 };
59
60 void SymbolStats::count_freqs(uint8_t const* in, size_t nbytes)
61 {
62     for (int i=0; i < 256; i++)
63         freqs[i] = 0;
64
65     for (size_t i=0; i < nbytes; i++)
66         freqs[in[i]]++;
67 }
68
69 void SymbolStats::calc_cum_freqs()
70 {
71     cum_freqs[0] = 0;
72     for (int i=0; i < 256; i++)
73         cum_freqs[i+1] = cum_freqs[i] + freqs[i];
74 }
75
76 void SymbolStats::normalize_freqs(uint32_t target_total)
77 {
78     assert(target_total >= 256);
79     
80     calc_cum_freqs();
81
82     OptimalRenormalize(cum_freqs, 256, target_total);
83
84     // calculate updated freqs and make sure we didn't screw anything up
85     assert(cum_freqs[0] == 0 && cum_freqs[256] == target_total);
86     for (int i=0; i < 256; i++) {
87         if (freqs[i] == 0)
88             assert(cum_freqs[i+1] == cum_freqs[i]);
89         else
90             assert(cum_freqs[i+1] > cum_freqs[i]);
91
92         // calc updated freq
93         freqs[i] = cum_freqs[i+1] - cum_freqs[i];
94     }
95 }
96
97 int main()
98 {
99     size_t in_size;
100     uint8_t* in_bytes = read_file("book1", &in_size);
101
102     static const uint32_t prob_bits = 14;
103     static const uint32_t prob_scale = 1 << prob_bits;
104
105     SymbolStats stats;
106     stats.count_freqs(in_bytes, in_size);
107     stats.normalize_freqs(prob_scale);
108
109     // cumlative->symbol table
110     // this is super brute force
111     uint8_t cum2sym[prob_scale];
112     for (int s=0; s < 256; s++)
113         for (uint32_t i=stats.cum_freqs[s]; i < stats.cum_freqs[s+1]; i++)
114             cum2sym[i] = s;
115
116     static size_t out_max_size = 32<<20; // 32MB
117     static size_t out_max_elems = out_max_size / sizeof(uint32_t);
118     uint32_t* out_buf = new uint32_t[out_max_elems];
119     uint32_t* out_end = out_buf + out_max_elems;
120     uint8_t* dec_bytes = new uint8_t[in_size];
121
122     // try rANS encode
123     uint32_t *rans_begin;
124     Rans64EncSymbol esyms[256];
125     Rans64DecSymbol dsyms[256];
126
127     for (int i=0; i < 256; i++) {
128         Rans64EncSymbolInit(&esyms[i], stats.cum_freqs[i], stats.freqs[i], prob_bits);
129         Rans64DecSymbolInit(&dsyms[i], stats.cum_freqs[i], stats.freqs[i]);
130     }
131
132     // ---- regular rANS encode/decode. Typical usage.
133
134     memset(dec_bytes, 0xcc, in_size);
135
136     printf("rANS encode:\n");
137     for (int run=0; run < 5; run++) {
138         double start_time = timer();
139         uint64_t enc_start_time = __rdtsc();
140
141         Rans64State rans;
142         Rans64EncInit(&rans);
143
144         uint32_t* ptr = out_end; // *end* of output buffer
145         for (size_t i=in_size; i > 0; i--) { // NB: working in reverse!
146             int s = in_bytes[i-1];
147             Rans64EncPutSymbol(&rans, &ptr, &esyms[s], prob_bits);
148         }
149         Rans64EncFlush(&rans, &ptr);
150         rans_begin = ptr;
151
152         uint64_t enc_clocks = __rdtsc() - enc_start_time;
153         double enc_time = timer() - start_time;
154         printf("%"PRIu64" clocks, %.1f clocks/symbol (%5.1fMiB/s)\n", enc_clocks, 1.0 * enc_clocks / in_size, 1.0 * in_size / (enc_time * 1048576.0));
155     }
156     printf("rANS: %d bytes\n", (int) ((out_end - rans_begin) * sizeof(uint32_t)));
157
158     // try rANS decode
159     for (int run=0; run < 5; run++) {
160         double start_time = timer();
161         uint64_t dec_start_time = __rdtsc();
162
163         Rans64State rans;
164         uint32_t* ptr = rans_begin;
165         Rans64DecInit(&rans, &ptr);
166
167         for (size_t i=0; i < in_size; i++) {
168             uint32_t s = cum2sym[Rans64DecGet(&rans, prob_bits)];
169             dec_bytes[i] = (uint8_t) s;
170             Rans64DecAdvanceSymbol(&rans, &ptr, &dsyms[s], prob_bits);
171         }
172
173         uint64_t dec_clocks = __rdtsc() - dec_start_time;
174         double dec_time = timer() - start_time;
175         printf("%"PRIu64" clocks, %.1f clocks/symbol (%5.1fMiB/s)\n", dec_clocks, 1.0 * dec_clocks / in_size, 1.0 * in_size / (dec_time * 1048576.0));
176     }
177
178     // check decode results
179     if (memcmp(in_bytes, dec_bytes, in_size) == 0)
180         printf("decode ok!\n");
181     else
182         printf("ERROR: bad decoder!\n");
183
184     // ---- interleaved rANS encode/decode. This is the kind of thing you might do to optimize critical paths.
185
186     memset(dec_bytes, 0xcc, in_size);
187
188     // try interleaved rANS encode
189     printf("\ninterleaved rANS encode:\n");
190     for (int run=0; run < 5; run++) {
191         double start_time = timer();
192         uint64_t enc_start_time = __rdtsc();
193
194         Rans64State rans0, rans1;
195         Rans64EncInit(&rans0);
196         Rans64EncInit(&rans1);
197
198         uint32_t* ptr = out_end;
199
200         // odd number of bytes?
201         if (in_size & 1) {
202             int s = in_bytes[in_size - 1];
203             Rans64EncPutSymbol(&rans0, &ptr, &esyms[s], prob_bits);
204         }
205
206         for (size_t i=(in_size & ~1); i > 0; i -= 2) { // NB: working in reverse!
207             int s1 = in_bytes[i-1];
208             int s0 = in_bytes[i-2];
209             Rans64EncPutSymbol(&rans1, &ptr, &esyms[s1], prob_bits);
210             Rans64EncPutSymbol(&rans0, &ptr, &esyms[s0], prob_bits);
211         }
212         Rans64EncFlush(&rans1, &ptr);
213         Rans64EncFlush(&rans0, &ptr);
214         rans_begin = ptr;
215
216         uint64_t enc_clocks = __rdtsc() - enc_start_time;
217         double enc_time = timer() - start_time;
218         printf("%"PRIu64" clocks, %.1f clocks/symbol (%5.1fMiB/s)\n", enc_clocks, 1.0 * enc_clocks / in_size, 1.0 * in_size / (enc_time * 1048576.0));
219     }
220     printf("interleaved rANS: %d bytes\n", (int) ((out_end - rans_begin) * sizeof(uint32_t)));
221
222     // try interleaved rANS decode
223     for (int run=0; run < 5; run++) {
224         double start_time = timer();
225         uint64_t dec_start_time = __rdtsc();
226
227         Rans64State rans0, rans1;
228         uint32_t* ptr = rans_begin;
229         Rans64DecInit(&rans0, &ptr);
230         Rans64DecInit(&rans1, &ptr);
231
232         for (size_t i=0; i < (in_size & ~1); i += 2) {
233             uint32_t s0 = cum2sym[Rans64DecGet(&rans0, prob_bits)];
234             uint32_t s1 = cum2sym[Rans64DecGet(&rans1, prob_bits)];
235             dec_bytes[i+0] = (uint8_t) s0;
236             dec_bytes[i+1] = (uint8_t) s1;
237             Rans64DecAdvanceSymbolStep(&rans0, &dsyms[s0], prob_bits);
238             Rans64DecAdvanceSymbolStep(&rans1, &dsyms[s1], prob_bits);
239             Rans64DecRenorm(&rans0, &ptr);
240             Rans64DecRenorm(&rans1, &ptr);
241         }
242
243         // last byte, if number of bytes was odd
244         if (in_size & 1) {
245             uint32_t s0 = cum2sym[Rans64DecGet(&rans0, prob_bits)];
246             dec_bytes[in_size - 1] = (uint8_t) s0;
247             Rans64DecAdvanceSymbol(&rans0, &ptr, &dsyms[s0], prob_bits);
248         }
249
250         uint64_t dec_clocks = __rdtsc() - dec_start_time;
251         double dec_time = timer() - start_time;
252         printf("%"PRIu64" clocks, %.1f clocks/symbol (%5.1fMB/s)\n", dec_clocks, 1.0 * dec_clocks / in_size, 1.0 * in_size / (dec_time * 1048576.0));
253     }
254
255     // check decode results
256     if (memcmp(in_bytes, dec_bytes, in_size) == 0)
257         printf("decode ok!\n");
258     else
259         printf("ERROR: bad decoder!\n");
260
261     delete[] out_buf;
262     delete[] dec_bytes;
263     delete[] in_bytes;
264     return 0;
265 }