]> git.sesse.net Git - movit/blob - fp16.cpp
Add more unit tests for fp16.
[movit] / fp16.cpp
1 #include "fp16.h"
2
3 namespace movit {
4 namespace {
5
6 union fp64 {
7         double f;
8         unsigned long long ll;
9 };
10
11 template<class FP16_INT_T,
12          int FP16_BIAS, int FP16_MANTISSA_BITS, int FP16_EXPONENT_BITS, int FP16_MAX_EXPONENT,
13          int FP64_BIAS, int FP64_MANTISSA_BITS, int FP64_EXPONENT_BITS, int FP64_MAX_EXPONENT>
14 inline double fp_upconvert(FP16_INT_T x)
15 {
16         int sign = x >> (FP16_MANTISSA_BITS + FP16_EXPONENT_BITS);
17         int exponent = (x & ((1ULL << (FP16_MANTISSA_BITS + FP16_EXPONENT_BITS)) - 1)) >> FP16_MANTISSA_BITS;
18         unsigned long long mantissa = x & ((1ULL << FP16_MANTISSA_BITS) - 1);
19
20         int sign64;
21         int exponent64;
22         unsigned long long mantissa64;
23
24         if (exponent == 0) {
25                 /* 
26                  * Denormals, or zero. Zero is still zero, denormals become
27                  * ordinary numbers.
28                  */
29                 if (mantissa == 0) {
30                         sign64 = sign;
31                         exponent64 = 0;
32                         mantissa64 = 0;
33                 } else {
34                         sign64 = sign;
35                         exponent64 = FP64_BIAS - FP16_BIAS;
36                         mantissa64 = mantissa << (FP64_MANTISSA_BITS - FP16_MANTISSA_BITS + 1);
37
38                         /* Normalize the number. */
39                         while ((mantissa64 & (1ULL << FP64_MANTISSA_BITS)) == 0) {
40                                 --exponent64;
41                                 mantissa64 <<= 1;
42                         }
43
44                         /* Clear the now-implicit one-bit. */
45                         mantissa64 &= ~(1ULL << FP64_MANTISSA_BITS);
46                 }
47         } else if (exponent == FP16_MAX_EXPONENT) {
48                 /*
49                  * Infinities or NaN (mantissa=0 => infinity, otherwise NaN).
50                  * We don't care much about NaNs, so let us just make sure we
51                  * keep the first bit (which signals signalling/non-signalling
52                  * in many implementations).
53                  */
54                 sign64 = sign;
55                 exponent64 = FP64_MAX_EXPONENT;
56                 mantissa64 = mantissa << (FP64_MANTISSA_BITS - FP16_MANTISSA_BITS);
57         } else {
58                 sign64 = sign;
59
60                 /* Up-conversion is simple. Just re-bias the exponent... */
61                 exponent64 = exponent + FP64_BIAS - FP16_BIAS;
62
63                 /* ...and convert the mantissa. */
64                 mantissa64 = mantissa << (FP64_MANTISSA_BITS - FP16_MANTISSA_BITS);
65         }
66
67         union fp64 nx;
68         nx.ll = ((unsigned long long)sign64 << (FP64_MANTISSA_BITS + FP64_EXPONENT_BITS))
69             | ((unsigned long long)exponent64 << FP64_MANTISSA_BITS)
70             | mantissa64;
71         return nx.f;
72 }
73                 
74 unsigned long long shift_right_with_round(unsigned long long x, unsigned shift)
75 {
76         /* shifts >= 64 need to be special-cased */
77         if (shift > 64) {
78                 return 0;
79         } else if (shift == 64) {
80                 if (x > (1ULL << 63)) {
81                         return 1;
82                 } else {
83                         return 0;
84                 }
85         }
86
87         unsigned long long round_part = x & ((1ULL << shift) - 1);
88         if (round_part < (1ULL << (shift - 1))) {
89                 /* round down */
90                 x >>= shift;
91         } else if (round_part > (1ULL << (shift - 1))) {
92                 /* round up */
93                 x >>= shift;
94                 ++x;
95         } else {
96                 /* round to nearest even number */
97                 x >>= shift;
98                 if ((x & 1) != 0) {
99                         ++x;
100                 }
101         }
102         return x;
103 }
104
105 template<class FP16_INT_T,
106          int FP16_BIAS, int FP16_MANTISSA_BITS, int FP16_EXPONENT_BITS, int FP16_MAX_EXPONENT,
107          int FP64_BIAS, int FP64_MANTISSA_BITS, int FP64_EXPONENT_BITS, int FP64_MAX_EXPONENT>
108 inline FP16_INT_T fp_downconvert(double x)
109 {
110         union fp64 nx;
111         nx.f = x;
112         unsigned long long f = nx.ll;
113         int sign = f >> (FP64_MANTISSA_BITS + FP64_EXPONENT_BITS);
114         int exponent = (f & ((1ULL << (FP64_MANTISSA_BITS + FP64_EXPONENT_BITS)) - 1)) >> FP64_MANTISSA_BITS;
115         unsigned long long mantissa = f & ((1ULL << FP64_MANTISSA_BITS) - 1);
116
117         int sign16;
118         int exponent16;
119         unsigned long long mantissa16;
120
121         if (exponent == 0) {
122                 /*
123                  * Denormals, or zero. The largest possible 64-bit
124                  * denormal is about +- 2^-1022, and the smallest possible
125                  * 16-bit denormal is +- 2^-24. Thus, we can safely
126                  * just set all of these to zero (but keep the sign bit).
127                  */
128                 sign16 = sign;
129                 exponent16 = 0;
130                 mantissa16 = 0;
131         } else if (exponent == FP64_MAX_EXPONENT) {
132                 /*
133                  * Infinities or NaN (mantissa=0 => infinity, otherwise NaN).
134                  * We don't care much about NaNs, so let us just keep the first
135                  * bit (which signals signalling/ non-signalling) and make sure 
136                  * that we don't coerce NaNs down to infinities.
137                  */
138                 if (mantissa == 0) {
139                         sign16 = sign;
140                         exponent16 = FP16_MAX_EXPONENT;
141                         mantissa16 = 0;
142                 } else {
143                         sign16 = sign;  /* undefined */
144                         exponent16 = FP16_MAX_EXPONENT;
145                         mantissa16 = mantissa >> (FP64_MANTISSA_BITS - FP16_MANTISSA_BITS);
146                         if (mantissa16 == 0) {
147                                 mantissa16 = 1;
148                         }
149                 }
150         } else {
151                 /* Re-bias the exponent, and check if we will create a denormal. */
152                 exponent16 = exponent + FP16_BIAS - FP64_BIAS;
153                 if (exponent16 <= 0) {
154                         int shift_amount = FP64_MANTISSA_BITS - FP16_MANTISSA_BITS - exponent16 + 1;
155                         sign16 = sign;
156                         exponent16 = 0;
157                         mantissa16 = shift_right_with_round(mantissa | (1ULL << FP64_MANTISSA_BITS), shift_amount);
158
159                         /*
160                          * We could actually have rounded back into the lowest possible non-denormal
161                          * here, so check for that.
162                          */
163                         if (mantissa16 == (1ULL << FP16_MANTISSA_BITS)) {
164                                 exponent16 = 1;
165                                 mantissa16 = 0;
166                         }
167                 } else {
168                         /*
169                          * First, round off the mantissa, since that could change
170                          * the exponent. We use standard IEEE 754r roundTiesToEven
171                          * mode.
172                          */
173                         sign16 = sign;
174                         mantissa16 = shift_right_with_round(mantissa, FP64_MANTISSA_BITS - FP16_MANTISSA_BITS);
175
176                         /* Check if we overflowed and need to increase the exponent. */
177                         if (mantissa16 == (1ULL << FP16_MANTISSA_BITS)) {
178                                 ++exponent16;
179                                 mantissa16 = 0;
180                         }
181
182                         /* Finally, check for overflow, and create +- inf if we need to. */
183                         if (exponent16 >= FP16_MAX_EXPONENT) {
184                                 exponent16 = FP16_MAX_EXPONENT;
185                                 mantissa16 = 0;
186                         }
187                 }
188         }
189
190         return (sign16 << (FP16_MANTISSA_BITS + FP16_EXPONENT_BITS))
191             | (exponent16 << FP16_MANTISSA_BITS)
192             | mantissa16;
193 }
194
195 const int FP64_BIAS = 1023;
196 const int FP64_MANTISSA_BITS = 52;
197 const int FP64_EXPONENT_BITS = 11;
198 const int FP64_MAX_EXPONENT = (1 << FP64_EXPONENT_BITS) - 1;
199
200 const int FP32_BIAS = 127;
201 const int FP32_MANTISSA_BITS = 23;
202 const int FP32_EXPONENT_BITS = 8;
203 const int FP32_MAX_EXPONENT = (1 << FP32_EXPONENT_BITS) - 1;
204
205 const int FP16_BIAS = 15;
206 const int FP16_MANTISSA_BITS = 10;
207 const int FP16_EXPONENT_BITS = 5;
208 const int FP16_MAX_EXPONENT = (1 << FP16_EXPONENT_BITS) - 1;
209
210 }  // namespace
211
212 double fp16_to_fp64(fp16_int_t x)
213 {
214         return fp_upconvert<fp16_int_t,
215                FP16_BIAS, FP16_MANTISSA_BITS, FP16_EXPONENT_BITS, FP16_MAX_EXPONENT,
216                FP64_BIAS, FP64_MANTISSA_BITS, FP64_EXPONENT_BITS, FP64_MAX_EXPONENT>(x);
217 }
218
219 fp16_int_t fp64_to_fp16(double x)
220 {
221         return fp_downconvert<fp16_int_t,
222                FP16_BIAS, FP16_MANTISSA_BITS, FP16_EXPONENT_BITS, FP16_MAX_EXPONENT,
223                FP64_BIAS, FP64_MANTISSA_BITS, FP64_EXPONENT_BITS, FP64_MAX_EXPONENT>(x);
224 }
225
226 double fp32_to_fp64(fp32_int_t x)
227 {
228         return fp_upconvert<fp32_int_t,
229                FP32_BIAS, FP32_MANTISSA_BITS, FP32_EXPONENT_BITS, FP32_MAX_EXPONENT,
230                FP64_BIAS, FP64_MANTISSA_BITS, FP64_EXPONENT_BITS, FP64_MAX_EXPONENT>(x);
231 }
232
233 fp32_int_t fp64_to_fp32(double x)
234 {
235         return fp_downconvert<fp32_int_t,
236                FP32_BIAS, FP32_MANTISSA_BITS, FP32_EXPONENT_BITS, FP32_MAX_EXPONENT,
237                FP64_BIAS, FP64_MANTISSA_BITS, FP64_EXPONENT_BITS, FP64_MAX_EXPONENT>(x);
238 }
239
240 }  // namespace