]> git.sesse.net Git - narabu/blob - ryg_rans/rans_byte.h
More fixes of hard-coded values.
[narabu] / ryg_rans / rans_byte.h
1 // Simple byte-aligned rANS encoder/decoder - public domain - Fabian 'ryg' Giesen 2014
2 //
3 // Not intended to be "industrial strength"; just meant to illustrate the general
4 // idea.
5
6 #ifndef RANS_BYTE_HEADER
7 #define RANS_BYTE_HEADER
8
9 #include <stdint.h>
10
11 #ifdef assert
12 #define RansAssert assert
13 #else
14 #define RansAssert(x)
15 #endif
16
17 // READ ME FIRST:
18 //
19 // This is designed like a typical arithmetic coder API, but there's three
20 // twists you absolutely should be aware of before you start hacking:
21 //
22 // 1. You need to encode data in *reverse* - last symbol first. rANS works
23 //    like a stack: last in, first out.
24 // 2. Likewise, the encoder outputs bytes *in reverse* - that is, you give
25 //    it a pointer to the *end* of your buffer (exclusive), and it will
26 //    slowly move towards the beginning as more bytes are emitted.
27 // 3. Unlike basically any other entropy coder implementation you might
28 //    have used, you can interleave data from multiple independent rANS
29 //    encoders into the same bytestream without any extra signaling;
30 //    you can also just write some bytes by yourself in the middle if
31 //    you want to. This is in addition to the usual arithmetic encoder
32 //    property of being able to switch models on the fly. Writing raw
33 //    bytes can be useful when you have some data that you know is
34 //    incompressible, and is cheaper than going through the rANS encode
35 //    function. Using multiple rANS coders on the same byte stream wastes
36 //    a few bytes compared to using just one, but execution of two
37 //    independent encoders can happen in parallel on superscalar and
38 //    Out-of-Order CPUs, so this can be *much* faster in tight decoding
39 //    loops.
40 //
41 //    This is why all the rANS functions take the write pointer as an
42 //    argument instead of just storing it in some context struct.
43
44 // --------------------------------------------------------------------------
45
46 // L ('l' in the paper) is the lower bound of our normalization interval.
47 // Between this and our byte-aligned emission, we use 31 (not 32!) bits.
48 // This is done intentionally because exact reciprocals for 31-bit uints
49 // fit in 32-bit uints: this permits some optimizations during encoding.
50 #define RANS_BYTE_L (1u << 23)  // lower bound of our normalization interval
51
52 // State for a rANS encoder. Yep, that's all there is to it.
53 typedef uint32_t RansState;
54
55 // Initialize a rANS encoder.
56 static inline void RansEncInit(RansState* r)
57 {
58     *r = RANS_BYTE_L;
59 }
60
61 // Renormalize the encoder. Internal function.
62 static inline RansState RansEncRenorm(RansState x, uint8_t** pptr, uint32_t freq, uint32_t scale_bits)
63 {
64     uint32_t x_max = ((RANS_BYTE_L >> scale_bits) << 8) * freq; // this turns into a shift.
65     if (x >= x_max) {
66         uint8_t* ptr = *pptr;
67         do {
68             *--ptr = (uint8_t) (x & 0xff);
69             x >>= 8;
70         } while (x >= x_max);
71         *pptr = ptr;
72     }
73     return x;
74 }
75
76 // Encodes a single symbol with range start "start" and frequency "freq".
77 // All frequencies are assumed to sum to "1 << scale_bits", and the
78 // resulting bytes get written to ptr (which is updated).
79 //
80 // NOTE: With rANS, you need to encode symbols in *reverse order*, i.e. from
81 // beginning to end! Likewise, the output bytestream is written *backwards*:
82 // ptr starts pointing at the end of the output buffer and keeps decrementing.
83 static inline void RansEncPut(RansState* r, uint8_t** pptr, uint32_t start, uint32_t freq, uint32_t scale_bits)
84 {
85     // renormalize
86     RansState x = RansEncRenorm(*r, pptr, freq, scale_bits);
87
88     // x = C(s,x)
89     *r = ((x / freq) << scale_bits) + (x % freq) + start;
90 }
91
92 // Flushes the rANS encoder.
93 static inline void RansEncFlush(RansState* r, uint8_t** pptr)
94 {
95     uint32_t x = *r;
96     uint8_t* ptr = *pptr;
97
98     ptr -= 4;
99     ptr[0] = (uint8_t) (x >> 0);
100     ptr[1] = (uint8_t) (x >> 8);
101     ptr[2] = (uint8_t) (x >> 16);
102     ptr[3] = (uint8_t) (x >> 24);
103
104     *pptr = ptr;
105 }
106
107 // Initializes a rANS decoder.
108 // Unlike the encoder, the decoder works forwards as you'd expect.
109 static inline void RansDecInit(RansState* r, uint8_t** pptr)
110 {
111     uint32_t x;
112     uint8_t* ptr = *pptr;
113
114     x  = ptr[0] << 0;
115     x |= ptr[1] << 8;
116     x |= ptr[2] << 16;
117     x |= ptr[3] << 24;
118     ptr += 4;
119
120     *pptr = ptr;
121     *r = x;
122 }
123
124 // Returns the current cumulative frequency (map it to a symbol yourself!)
125 static inline uint32_t RansDecGet(RansState* r, uint32_t scale_bits)
126 {
127     return *r & ((1u << scale_bits) - 1);
128 }
129
130 // Advances in the bit stream by "popping" a single symbol with range start
131 // "start" and frequency "freq". All frequencies are assumed to sum to "1 << scale_bits",
132 // and the resulting bytes get written to ptr (which is updated).
133 static inline void RansDecAdvance(RansState* r, uint8_t** pptr, uint32_t start, uint32_t freq, uint32_t scale_bits)
134 {
135     uint32_t mask = (1u << scale_bits) - 1;
136
137     // s, x = D(x)
138     uint32_t x = *r;
139     x = freq * (x >> scale_bits) + (x & mask) - start;
140
141     // renormalize
142     if (x < RANS_BYTE_L) {
143         uint8_t* ptr = *pptr;
144         do x = (x << 8) | *ptr++; while (x < RANS_BYTE_L);
145         *pptr = ptr;
146     }
147
148     *r = x;
149 }
150
151 // --------------------------------------------------------------------------
152
153 // That's all you need for a full encoder; below here are some utility
154 // functions with extra convenience or optimizations.
155
156 // Encoder symbol description
157 // This (admittedly odd) selection of parameters was chosen to make
158 // RansEncPutSymbol as cheap as possible.
159 typedef struct {
160     uint32_t x_max;     // (Exclusive) upper bound of pre-normalization interval
161     uint32_t rcp_freq;  // Fixed-point reciprocal frequency
162     uint32_t bias;      // Bias
163     uint16_t cmpl_freq; // Complement of frequency: (1 << scale_bits) - freq
164     uint16_t rcp_shift; // Reciprocal shift
165 } RansEncSymbol;
166
167 // Decoder symbols are straightforward.
168 typedef struct {
169     uint16_t start;     // Start of range.
170     uint16_t freq;      // Symbol frequency.
171 } RansDecSymbol;
172
173 // Initializes an encoder symbol to start "start" and frequency "freq"
174 static inline void RansEncSymbolInit(RansEncSymbol* s, uint32_t start, uint32_t freq, uint32_t scale_bits)
175 {
176     RansAssert(scale_bits <= 16);
177     RansAssert(start <= (1u << scale_bits));
178     RansAssert(freq <= (1u << scale_bits) - start);
179
180     // Say M := 1 << scale_bits.
181     //
182     // The original encoder does:
183     //   x_new = (x/freq)*M + start + (x%freq)
184     //
185     // The fast encoder does (schematically):
186     //   q     = mul_hi(x, rcp_freq) >> rcp_shift   (division)
187     //   r     = x - q*freq                         (remainder)
188     //   x_new = q*M + bias + r                     (new x)
189     // plugging in r into x_new yields:
190     //   x_new = bias + x + q*(M - freq)
191     //        =: bias + x + q*cmpl_freq             (*)
192     //
193     // and we can just precompute cmpl_freq. Now we just need to
194     // set up our parameters such that the original encoder and
195     // the fast encoder agree.
196
197     s->x_max = ((RANS_BYTE_L >> scale_bits) << 8) * freq;
198     s->cmpl_freq = (uint16_t) ((1 << scale_bits) - freq);
199     if (freq < 2) {
200         // freq=0 symbols are never valid to encode, so it doesn't matter what
201         // we set our values to.
202         //
203         // freq=1 is tricky, since the reciprocal of 1 is 1; unfortunately,
204         // our fixed-point reciprocal approximation can only multiply by values
205         // smaller than 1.
206         //
207         // So we use the "next best thing": rcp_freq=0xffffffff, rcp_shift=0.
208         // This gives:
209         //   q = mul_hi(x, rcp_freq) >> rcp_shift
210         //     = mul_hi(x, (1<<32) - 1)) >> 0
211         //     = floor(x - x/(2^32))
212         //     = x - 1 if 1 <= x < 2^32
213         // and we know that x>0 (x=0 is never in a valid normalization interval).
214         //
215         // So we now need to choose the other parameters such that
216         //   x_new = x*M + start
217         // plug it in:
218         //     x*M + start                   (desired result)
219         //   = bias + x + q*cmpl_freq        (*)
220         //   = bias + x + (x - 1)*(M - 1)    (plug in q=x-1, cmpl_freq)
221         //   = bias + 1 + (x - 1)*M
222         //   = x*M + (bias + 1 - M)
223         //
224         // so we have start = bias + 1 - M, or equivalently
225         //   bias = start + M - 1.
226         s->rcp_freq = ~0u;
227         s->rcp_shift = 0;
228         s->bias = start + (1 << scale_bits) - 1;
229     } else {
230         // Alverson, "Integer Division using reciprocals"
231         // shift=ceil(log2(freq))
232         uint32_t shift = 0;
233         while (freq > (1u << shift))
234             shift++;
235
236         s->rcp_freq = (uint32_t) (((1ull << (shift + 31)) + freq-1) / freq);
237         s->rcp_shift = shift - 1;
238
239         // With these values, 'q' is the correct quotient, so we
240         // have bias=start.
241         s->bias = start;
242     }
243 }
244
245 // Initialize a decoder symbol to start "start" and frequency "freq"
246 static inline void RansDecSymbolInit(RansDecSymbol* s, uint32_t start, uint32_t freq)
247 {
248     RansAssert(start <= (1 << 16));
249     RansAssert(freq <= (1 << 16) - start);
250     s->start = (uint16_t) start;
251     s->freq = (uint16_t) freq;
252 }
253
254 // Encodes a given symbol. This is faster than straight RansEnc since we can do
255 // multiplications instead of a divide.
256 //
257 // See RansEncSymbolInit for a description of how this works.
258 static inline void RansEncPutSymbol(RansState* r, uint8_t** pptr, RansEncSymbol const* sym)
259 {
260     RansAssert(sym->x_max != 0); // can't encode symbol with freq=0
261
262     // renormalize
263     uint32_t x = *r;
264     uint32_t x_max = sym->x_max;
265     if (x >= x_max) {
266         uint8_t* ptr = *pptr;
267         do {
268             *--ptr = (uint8_t) (x & 0xff);
269             x >>= 8;
270         } while (x >= x_max);
271         *pptr = ptr;
272     }
273
274     // x = C(s,x)
275     // NOTE: written this way so we get a 32-bit "multiply high" when
276     // available. If you're on a 64-bit platform with cheap multiplies
277     // (e.g. x64), just bake the +32 into rcp_shift.
278     uint32_t q = (uint32_t) (((uint64_t)x * sym->rcp_freq) >> 32) >> sym->rcp_shift;
279     *r = x + sym->bias + q * sym->cmpl_freq;
280 }
281
282 // Equivalent to RansDecAdvance that takes a symbol.
283 static inline void RansDecAdvanceSymbol(RansState* r, uint8_t** pptr, RansDecSymbol const* sym, uint32_t scale_bits)
284 {
285     RansDecAdvance(r, pptr, sym->start, sym->freq, scale_bits);
286 }
287
288 // Advances in the bit stream by "popping" a single symbol with range start
289 // "start" and frequency "freq". All frequencies are assumed to sum to "1 << scale_bits".
290 // No renormalization or output happens.
291 static inline void RansDecAdvanceStep(RansState* r, uint32_t start, uint32_t freq, uint32_t scale_bits)
292 {
293     uint32_t mask = (1u << scale_bits) - 1;
294
295     // s, x = D(x)
296     uint32_t x = *r;
297     *r = freq * (x >> scale_bits) + (x & mask) - start;
298 }
299
300 // Equivalent to RansDecAdvanceStep that takes a symbol.
301 static inline void RansDecAdvanceSymbolStep(RansState* r, RansDecSymbol const* sym, uint32_t scale_bits)
302 {
303     RansDecAdvanceStep(r, sym->start, sym->freq, scale_bits);
304 }
305
306 // Renormalize.
307 static inline void RansDecRenorm(RansState* r, uint8_t** pptr)
308 {
309     // renormalize
310     uint32_t x = *r;
311     if (x < RANS_BYTE_L) {
312         uint8_t* ptr = *pptr;
313         do x = (x << 8) | *ptr++; while (x < RANS_BYTE_L);
314         *pptr = ptr;
315     }
316
317     *r = x;
318 }
319
320 #endif // RANS_BYTE_HEADER