]> git.sesse.net Git - movit/blob - fp16.cpp
Make all fp16 routines work with fp32 as input instead of fp64, since that is what...
[movit] / fp16.cpp
1 #include "fp16.h"
2
3 namespace movit {
4 namespace {
5
6 union fp32 {
7         float f;
8         unsigned int u;
9 };
10
11 template<class FP16_INT_T,
12          int FP16_BIAS, int FP16_MANTISSA_BITS, int FP16_EXPONENT_BITS, int FP16_MAX_EXPONENT,
13          int FP32_BIAS, int FP32_MANTISSA_BITS, int FP32_EXPONENT_BITS, int FP32_MAX_EXPONENT>
14 inline float fp_upconvert(FP16_INT_T x)
15 {
16         int sign = x.val >> (FP16_MANTISSA_BITS + FP16_EXPONENT_BITS);
17         int exponent = (x.val & ((1U << (FP16_MANTISSA_BITS + FP16_EXPONENT_BITS)) - 1)) >> FP16_MANTISSA_BITS;
18         unsigned int mantissa = x.val & ((1U << FP16_MANTISSA_BITS) - 1);
19
20         int sign32;
21         int exponent32;
22         unsigned int mantissa32;
23
24         if (exponent == 0) {
25                 /* 
26                  * Denormals, or zero. Zero is still zero, denormals become
27                  * ordinary numbers.
28                  */
29                 if (mantissa == 0) {
30                         sign32 = sign;
31                         exponent32 = 0;
32                         mantissa32 = 0;
33                 } else {
34                         sign32 = sign;
35                         exponent32 = FP32_BIAS - FP16_BIAS;
36                         mantissa32 = mantissa << (FP32_MANTISSA_BITS - FP16_MANTISSA_BITS + 1);
37
38                         /* Normalize the number. */
39                         while ((mantissa32 & (1U << FP32_MANTISSA_BITS)) == 0) {
40                                 --exponent32;
41                                 mantissa32 <<= 1;
42                         }
43
44                         /* Clear the now-implicit one-bit. */
45                         mantissa32 &= ~(1U << FP32_MANTISSA_BITS);
46                 }
47         } else if (exponent == FP16_MAX_EXPONENT) {
48                 /*
49                  * Infinities or NaN (mantissa=0 => infinity, otherwise NaN).
50                  * We don't care much about NaNs, so let us just make sure we
51                  * keep the first bit (which signals signalling/non-signalling
52                  * in many implementations).
53                  */
54                 sign32 = sign;
55                 exponent32 = FP32_MAX_EXPONENT;
56                 mantissa32 = mantissa << (FP32_MANTISSA_BITS - FP16_MANTISSA_BITS);
57         } else {
58                 sign32 = sign;
59
60                 /* Up-conversion is simple. Just re-bias the exponent... */
61                 exponent32 = exponent + FP32_BIAS - FP16_BIAS;
62
63                 /* ...and convert the mantissa. */
64                 mantissa32 = mantissa << (FP32_MANTISSA_BITS - FP16_MANTISSA_BITS);
65         }
66
67         union fp32 nx;
68         nx.u = ((unsigned int)sign32 << (FP32_MANTISSA_BITS + FP32_EXPONENT_BITS))
69             | ((unsigned int)exponent32 << FP32_MANTISSA_BITS)
70             | mantissa32;
71         return nx.f;
72 }
73
74 unsigned int shift_right_with_round(unsigned int x, unsigned shift)
75 {
76         /* shifts >= 32 need to be special-cased */
77         if (shift > 32) {
78                 return 0;
79         } else if (shift == 32) {
80                 if (x > (1U << 31)) {
81                         return 1;
82                 } else {
83                         return 0;
84                 }
85         }
86
87         unsigned int round_part = x & ((1U << shift) - 1);
88         if (round_part < (1U << (shift - 1))) {
89                 /* round down */
90                 x >>= shift;
91         } else if (round_part > (1U << (shift - 1))) {
92                 /* round up */
93                 x >>= shift;
94                 ++x;
95         } else {
96                 /* round to nearest even number */
97                 x >>= shift;
98                 if ((x & 1) != 0) {
99                         ++x;
100                 }
101         }
102         return x;
103 }
104
105 template<class FP16_INT_T,
106          int FP16_BIAS, int FP16_MANTISSA_BITS, int FP16_EXPONENT_BITS, int FP16_MAX_EXPONENT,
107          int FP32_BIAS, int FP32_MANTISSA_BITS, int FP32_EXPONENT_BITS, int FP32_MAX_EXPONENT>
108 inline FP16_INT_T fp_downconvert(float x)
109 {
110         union fp32 nx;
111         nx.f = x;
112         unsigned int f = nx.u;
113         int sign = f >> (FP32_MANTISSA_BITS + FP32_EXPONENT_BITS);
114         int exponent = (f & ((1U << (FP32_MANTISSA_BITS + FP32_EXPONENT_BITS)) - 1)) >> FP32_MANTISSA_BITS;
115         unsigned int mantissa = f & ((1U << FP32_MANTISSA_BITS) - 1);
116
117         int sign16;
118         int exponent16;
119         unsigned int mantissa16;
120
121         if (exponent == 0) {
122                 /*
123                  * Denormals, or zero. The largest possible 32-bit
124                  * denormal is about +- 2^-1022, and the smallest possible
125                  * 16-bit denormal is +- 2^-24. Thus, we can safely
126                  * just set all of these to zero (but keep the sign bit).
127                  */
128                 sign16 = sign;
129                 exponent16 = 0;
130                 mantissa16 = 0;
131         } else if (exponent == FP32_MAX_EXPONENT) {
132                 /*
133                  * Infinities or NaN (mantissa=0 => infinity, otherwise NaN).
134                  * We don't care much about NaNs, so let us just keep the first
135                  * bit (which signals signalling/ non-signalling) and make sure 
136                  * that we don't coerce NaNs down to infinities.
137                  */
138                 if (mantissa == 0) {
139                         sign16 = sign;
140                         exponent16 = FP16_MAX_EXPONENT;
141                         mantissa16 = 0;
142                 } else {
143                         sign16 = sign;  /* undefined */
144                         exponent16 = FP16_MAX_EXPONENT;
145                         mantissa16 = mantissa >> (FP32_MANTISSA_BITS - FP16_MANTISSA_BITS);
146                         if (mantissa16 == 0) {
147                                 mantissa16 = 1;
148                         }
149                 }
150         } else {
151                 /* Re-bias the exponent, and check if we will create a denormal. */
152                 exponent16 = exponent + FP16_BIAS - FP32_BIAS;
153                 if (exponent16 <= 0) {
154                         int shift_amount = FP32_MANTISSA_BITS - FP16_MANTISSA_BITS - exponent16 + 1;
155                         sign16 = sign;
156                         exponent16 = 0;
157                         mantissa16 = shift_right_with_round(mantissa | (1U << FP32_MANTISSA_BITS), shift_amount);
158
159                         /*
160                          * We could actually have rounded back into the lowest possible non-denormal
161                          * here, so check for that.
162                          */
163                         if (mantissa16 == (1U << FP16_MANTISSA_BITS)) {
164                                 exponent16 = 1;
165                                 mantissa16 = 0;
166                         }
167                 } else {
168                         /*
169                          * First, round off the mantissa, since that could change
170                          * the exponent. We use standard IEEE 754r roundTiesToEven
171                          * mode.
172                          */
173                         sign16 = sign;
174                         mantissa16 = shift_right_with_round(mantissa, FP32_MANTISSA_BITS - FP16_MANTISSA_BITS);
175
176                         /* Check if we overflowed and need to increase the exponent. */
177                         if (mantissa16 == (1U << FP16_MANTISSA_BITS)) {
178                                 ++exponent16;
179                                 mantissa16 = 0;
180                         }
181
182                         /* Finally, check for overflow, and create +- inf if we need to. */
183                         if (exponent16 >= FP16_MAX_EXPONENT) {
184                                 exponent16 = FP16_MAX_EXPONENT;
185                                 mantissa16 = 0;
186                         }
187                 }
188         }
189
190         FP16_INT_T ret;
191         ret.val = (sign16 << (FP16_MANTISSA_BITS + FP16_EXPONENT_BITS))
192             | (exponent16 << FP16_MANTISSA_BITS)
193             | mantissa16;
194         return ret;
195 }
196
197 const int FP64_BIAS = 1023;
198 const int FP64_MANTISSA_BITS = 52;
199 const int FP64_EXPONENT_BITS = 11;
200 const int FP64_MAX_EXPONENT = (1 << FP64_EXPONENT_BITS) - 1;
201
202 const int FP32_BIAS = 127;
203 const int FP32_MANTISSA_BITS = 23;
204 const int FP32_EXPONENT_BITS = 8;
205 const int FP32_MAX_EXPONENT = (1 << FP32_EXPONENT_BITS) - 1;
206
207 const int FP16_BIAS = 15;
208 const int FP16_MANTISSA_BITS = 10;
209 const int FP16_EXPONENT_BITS = 5;
210 const int FP16_MAX_EXPONENT = (1 << FP16_EXPONENT_BITS) - 1;
211
212 }  // namespace
213
214 #ifndef __F16C__
215
216 float fp16_to_fp32(fp16_int_t x)
217 {
218         return fp_upconvert<fp16_int_t,
219                FP16_BIAS, FP16_MANTISSA_BITS, FP16_EXPONENT_BITS, FP16_MAX_EXPONENT,
220                FP32_BIAS, FP32_MANTISSA_BITS, FP32_EXPONENT_BITS, FP32_MAX_EXPONENT>(x);
221 }
222
223 fp16_int_t fp32_to_fp16(float x)
224 {
225         return fp_downconvert<fp16_int_t,
226                FP16_BIAS, FP16_MANTISSA_BITS, FP16_EXPONENT_BITS, FP16_MAX_EXPONENT,
227                FP32_BIAS, FP32_MANTISSA_BITS, FP32_EXPONENT_BITS, FP32_MAX_EXPONENT>(x);
228 }
229
230 #endif
231
232 }  // namespace