]> git.sesse.net Git - narabu/blob - ryg_rans/rans64.h
More fixes of hard-coded values.
[narabu] / ryg_rans / rans64.h
1 // 64-bit rANS encoder/decoder - public domain - Fabian 'ryg' Giesen 2014
2 //
3 // This uses 64-bit states (63-bit actually) which allows renormalizing
4 // by writing out a whole 32 bits at a time (b=2^32) while still
5 // retaining good precision and allowing for high probability resolution.
6 //
7 // The only caveat is that this version requires 64-bit arithmetic; in
8 // particular, the encoder approximation in the bottom half requires a
9 // fast way to obtain the top 64 bits of an unsigned 64*64 bit product.
10 // 
11 // In short, as written, this code works on 64-bit targets only!
12
13 #ifndef RANS64_HEADER
14 #define RANS64_HEADER
15
16 #include <stdint.h>
17
18 #ifdef assert
19 #define Rans64Assert assert
20 #else
21 #define Rans64Assert(x)
22 #endif
23
24 // --------------------------------------------------------------------------
25
26 // This code needs support for 64-bit long multiplies with 128-bit result
27 // (or more precisely, the top 64 bits of a 128-bit result). This is not
28 // really portable functionality, so we need some compiler-specific hacks
29 // here.
30
31 #if defined(_MSC_VER)
32
33 #include <intrin.h>
34
35 static inline uint64_t Rans64MulHi(uint64_t a, uint64_t b)
36 {
37     return __umulh(a, b);
38 }
39
40 #elif defined(__GNUC__)
41
42 static inline uint64_t Rans64MulHi(uint64_t a, uint64_t b)
43 {
44     return (uint64_t) (((unsigned __int128)a * b) >> 64);
45 }
46
47 #else
48
49 #error Unknown/unsupported compiler!
50
51 #endif
52
53 // --------------------------------------------------------------------------
54
55 // L ('l' in the paper) is the lower bound of our normalization interval.
56 // Between this and our 32-bit-aligned emission, we use 63 (not 64!) bits.
57 // This is done intentionally because exact reciprocals for 63-bit uints
58 // fit in 64-bit uints: this permits some optimizations during encoding.
59 #define RANS64_L (1ull << 31)  // lower bound of our normalization interval
60
61 // State for a rANS encoder. Yep, that's all there is to it.
62 typedef uint64_t Rans64State;
63
64 // Initialize a rANS encoder.
65 static inline void Rans64EncInit(Rans64State* r)
66 {
67     *r = RANS64_L;
68 }
69
70 // Encodes a single symbol with range start "start" and frequency "freq".
71 // All frequencies are assumed to sum to "1 << scale_bits", and the
72 // resulting bytes get written to ptr (which is updated).
73 //
74 // NOTE: With rANS, you need to encode symbols in *reverse order*, i.e. from
75 // beginning to end! Likewise, the output bytestream is written *backwards*:
76 // ptr starts pointing at the end of the output buffer and keeps decrementing.
77 static inline void Rans64EncPut(Rans64State* r, uint32_t** pptr, uint32_t start, uint32_t freq, uint32_t scale_bits)
78 {
79     Rans64Assert(freq != 0);
80
81     // renormalize (never needs to loop)
82     uint64_t x = *r;
83     uint64_t x_max = ((RANS64_L >> scale_bits) << 32) * freq; // this turns into a shift.
84     if (x >= x_max) {
85         *pptr -= 1;
86         **pptr = (uint32_t) x;
87         x >>= 32;
88         Rans64Assert(x < x_max);
89     }
90
91     // x = C(s,x)
92     *r = ((x / freq) << scale_bits) + (x % freq) + start;
93 }
94
95 // Flushes the rANS encoder.
96 static inline void Rans64EncFlush(Rans64State* r, uint32_t** pptr)
97 {
98     uint64_t x = *r;
99
100     *pptr -= 2;
101     (*pptr)[0] = (uint32_t) (x >> 0);
102     (*pptr)[1] = (uint32_t) (x >> 32);
103 }
104
105 // Initializes a rANS decoder.
106 // Unlike the encoder, the decoder works forwards as you'd expect.
107 static inline void Rans64DecInit(Rans64State* r, uint32_t** pptr)
108 {
109     uint64_t x;
110
111     x  = (uint64_t) ((*pptr)[0]) << 0;
112     x |= (uint64_t) ((*pptr)[1]) << 32;
113     *pptr += 2;
114     *r = x;
115 }
116
117 // Returns the current cumulative frequency (map it to a symbol yourself!)
118 static inline uint32_t Rans64DecGet(Rans64State* r, uint32_t scale_bits)
119 {
120     return *r & ((1u << scale_bits) - 1);
121 }
122
123 // Advances in the bit stream by "popping" a single symbol with range start
124 // "start" and frequency "freq". All frequencies are assumed to sum to "1 << scale_bits",
125 // and the resulting bytes get written to ptr (which is updated).
126 static inline void Rans64DecAdvance(Rans64State* r, uint32_t** pptr, uint32_t start, uint32_t freq, uint32_t scale_bits)
127 {
128     uint64_t mask = (1ull << scale_bits) - 1;
129
130     // s, x = D(x)
131     uint64_t x = *r;
132     x = freq * (x >> scale_bits) + (x & mask) - start;
133
134     // renormalize
135     if (x < RANS64_L) {
136         x = (x << 32) | **pptr;
137         *pptr += 1;
138         Rans64Assert(x >= RANS64_L);
139     }
140
141     *r = x;
142 }
143
144 // --------------------------------------------------------------------------
145
146 // That's all you need for a full encoder; below here are some utility
147 // functions with extra convenience or optimizations.
148
149 // Encoder symbol description
150 // This (admittedly odd) selection of parameters was chosen to make
151 // RansEncPutSymbol as cheap as possible.
152 typedef struct {
153     uint64_t rcp_freq;  // Fixed-point reciprocal frequency
154     uint32_t freq;      // Symbol frequency
155     uint32_t bias;      // Bias
156     uint32_t cmpl_freq; // Complement of frequency: (1 << scale_bits) - freq
157     uint32_t rcp_shift; // Reciprocal shift
158 } Rans64EncSymbol;
159
160 // Decoder symbols are straightforward.
161 typedef struct {
162     uint32_t start;     // Start of range.
163     uint32_t freq;      // Symbol frequency.
164 } Rans64DecSymbol;
165
166 // Initializes an encoder symbol to start "start" and frequency "freq"
167 static inline void Rans64EncSymbolInit(Rans64EncSymbol* s, uint32_t start, uint32_t freq, uint32_t scale_bits)
168 {
169     Rans64Assert(scale_bits <= 31);
170     Rans64Assert(start <= (1u << scale_bits));
171     Rans64Assert(freq <= (1u << scale_bits) - start);
172
173     // Say M := 1 << scale_bits.
174     //
175     // The original encoder does:
176     //   x_new = (x/freq)*M + start + (x%freq)
177     //
178     // The fast encoder does (schematically):
179     //   q     = mul_hi(x, rcp_freq) >> rcp_shift   (division)
180     //   r     = x - q*freq                         (remainder)
181     //   x_new = q*M + bias + r                     (new x)
182     // plugging in r into x_new yields:
183     //   x_new = bias + x + q*(M - freq)
184     //        =: bias + x + q*cmpl_freq             (*)
185     //
186     // and we can just precompute cmpl_freq. Now we just need to
187     // set up our parameters such that the original encoder and
188     // the fast encoder agree.
189
190     s->freq = freq;
191     s->cmpl_freq = ((1 << scale_bits) - freq);
192     if (freq < 2) {
193         // freq=0 symbols are never valid to encode, so it doesn't matter what
194         // we set our values to.
195         //
196         // freq=1 is tricky, since the reciprocal of 1 is 1; unfortunately,
197         // our fixed-point reciprocal approximation can only multiply by values
198         // smaller than 1.
199         //
200         // So we use the "next best thing": rcp_freq=~0, rcp_shift=0.
201         // This gives:
202         //   q = mul_hi(x, rcp_freq) >> rcp_shift
203         //     = mul_hi(x, (1<<64) - 1)) >> 0
204         //     = floor(x - x/(2^64))
205         //     = x - 1 if 1 <= x < 2^64
206         // and we know that x>0 (x=0 is never in a valid normalization interval).
207         //
208         // So we now need to choose the other parameters such that
209         //   x_new = x*M + start
210         // plug it in:
211         //     x*M + start                   (desired result)
212         //   = bias + x + q*cmpl_freq        (*)
213         //   = bias + x + (x - 1)*(M - 1)    (plug in q=x-1, cmpl_freq)
214         //   = bias + 1 + (x - 1)*M
215         //   = x*M + (bias + 1 - M)
216         //
217         // so we have start = bias + 1 - M, or equivalently
218         //   bias = start + M - 1.
219         s->rcp_freq = ~0ull;
220         s->rcp_shift = 0;
221         s->bias = start + (1 << scale_bits) - 1;
222     } else {
223         // Alverson, "Integer Division using reciprocals"
224         // shift=ceil(log2(freq))
225         uint32_t shift = 0;
226         uint64_t x0, x1, t0, t1;
227         while (freq > (1u << shift))
228             shift++;
229
230         // long divide ((uint128) (1 << (shift + 63)) + freq-1) / freq
231         // by splitting it into two 64:64 bit divides (this works because
232         // the dividend has a simple form.)
233         x0 = freq - 1;
234         x1 = 1ull << (shift + 31);
235
236         t1 = x1 / freq;
237         x0 += (x1 % freq) << 32;
238         t0 = x0 / freq;
239
240         s->rcp_freq = t0 + (t1 << 32);
241         s->rcp_shift = shift - 1;
242
243         // With these values, 'q' is the correct quotient, so we
244         // have bias=start.
245         s->bias = start;
246     }
247 }
248
249 // Initialize a decoder symbol to start "start" and frequency "freq"
250 static inline void Rans64DecSymbolInit(Rans64DecSymbol* s, uint32_t start, uint32_t freq)
251 {
252     Rans64Assert(start <= (1 << 31));
253     Rans64Assert(freq <= (1 << 31) - start);
254     s->start = start;
255     s->freq = freq;
256 }
257
258 // Encodes a given symbol. This is faster than straight RansEnc since we can do
259 // multiplications instead of a divide.
260 //
261 // See RansEncSymbolInit for a description of how this works.
262 static inline void Rans64EncPutSymbol(Rans64State* r, uint32_t** pptr, Rans64EncSymbol const* sym, uint32_t scale_bits)
263 {
264     Rans64Assert(sym->freq != 0); // can't encode symbol with freq=0
265
266     // renormalize
267     uint64_t x = *r;
268     uint64_t x_max = ((RANS64_L >> scale_bits) << 32) * sym->freq; // turns into a shift
269     if (x >= x_max) {
270         *pptr -= 1;
271         **pptr = (uint32_t) x;
272         x >>= 32;
273     }
274
275     // x = C(s,x)
276     uint64_t q = Rans64MulHi(x, sym->rcp_freq) >> sym->rcp_shift;
277     *r = x + sym->bias + q * sym->cmpl_freq;
278 }
279
280 // Equivalent to RansDecAdvance that takes a symbol.
281 static inline void Rans64DecAdvanceSymbol(Rans64State* r, uint32_t** pptr, Rans64DecSymbol const* sym, uint32_t scale_bits)
282 {
283     Rans64DecAdvance(r, pptr, sym->start, sym->freq, scale_bits);
284 }
285
286 // Advances in the bit stream by "popping" a single symbol with range start
287 // "start" and frequency "freq". All frequencies are assumed to sum to "1 << scale_bits".
288 // No renormalization or output happens.
289 static inline void Rans64DecAdvanceStep(Rans64State* r, uint32_t start, uint32_t freq, uint32_t scale_bits)
290 {
291     uint64_t mask = (1u << scale_bits) - 1;
292
293     // s, x = D(x)
294     uint64_t x = *r;
295     *r = freq * (x >> scale_bits) + (x & mask) - start;
296 }
297
298 // Equivalent to RansDecAdvanceStep that takes a symbol.
299 static inline void Rans64DecAdvanceSymbolStep(Rans64State* r, Rans64DecSymbol const* sym, uint32_t scale_bits)
300 {
301     Rans64DecAdvanceStep(r, sym->start, sym->freq, scale_bits);
302 }
303
304 // Renormalize.
305 static inline void Rans64DecRenorm(Rans64State* r, uint32_t** pptr)
306 {
307     // renormalize
308     uint64_t x = *r;
309     if (x < RANS64_L) {
310         x = (x << 32) | **pptr;
311         *pptr += 1;
312         Rans64Assert(x >= RANS64_L);
313     }
314
315     *r = x;
316 }
317
318 #endif // RANS64_HEADER