]> git.sesse.net Git - movit/blobdiff - fp16.cpp
Make all fp16 routines work with fp32 as input instead of fp64, since that is what...
[movit] / fp16.cpp
index 3738f5c4360bd055cdd5a39ed32d9aecdb8a3158..d83871bd358841efd564eaaed4ed2477245be239 100644 (file)
--- a/fp16.cpp
+++ b/fp16.cpp
@@ -3,23 +3,23 @@
 namespace movit {
 namespace {
 
-union fp64 {
-       double f;
-       unsigned long long ll;
+union fp32 {
+       float f;
+       unsigned int u;
 };
 
 template<class FP16_INT_T,
          int FP16_BIAS, int FP16_MANTISSA_BITS, int FP16_EXPONENT_BITS, int FP16_MAX_EXPONENT,
-         int FP64_BIAS, int FP64_MANTISSA_BITS, int FP64_EXPONENT_BITS, int FP64_MAX_EXPONENT>
-inline double fp_upconvert(FP16_INT_T x)
+         int FP32_BIAS, int FP32_MANTISSA_BITS, int FP32_EXPONENT_BITS, int FP32_MAX_EXPONENT>
+inline float fp_upconvert(FP16_INT_T x)
 {
-       int sign = x >> (FP16_MANTISSA_BITS + FP16_EXPONENT_BITS);
-       int exponent = (x & ((1ULL << (FP16_MANTISSA_BITS + FP16_EXPONENT_BITS)) - 1)) >> FP16_MANTISSA_BITS;
-       unsigned long long mantissa = x & ((1ULL << FP16_MANTISSA_BITS) - 1);
+       int sign = x.val >> (FP16_MANTISSA_BITS + FP16_EXPONENT_BITS);
+       int exponent = (x.val & ((1U << (FP16_MANTISSA_BITS + FP16_EXPONENT_BITS)) - 1)) >> FP16_MANTISSA_BITS;
+       unsigned int mantissa = x.val & ((1U << FP16_MANTISSA_BITS) - 1);
 
-       int sign64;
-       int exponent64;
-       unsigned long long mantissa64;
+       int sign32;
+       int exponent32;
+       unsigned int mantissa32;
 
        if (exponent == 0) {
                /* 
@@ -27,22 +27,22 @@ inline double fp_upconvert(FP16_INT_T x)
                 * ordinary numbers.
                 */
                if (mantissa == 0) {
-                       sign64 = sign;
-                       exponent64 = 0;
-                       mantissa64 = 0;
+                       sign32 = sign;
+                       exponent32 = 0;
+                       mantissa32 = 0;
                } else {
-                       sign64 = sign;
-                       exponent64 = FP64_BIAS - FP16_BIAS;
-                       mantissa64 = mantissa << (FP64_MANTISSA_BITS - FP16_MANTISSA_BITS + 1);
+                       sign32 = sign;
+                       exponent32 = FP32_BIAS - FP16_BIAS;
+                       mantissa32 = mantissa << (FP32_MANTISSA_BITS - FP16_MANTISSA_BITS + 1);
 
                        /* Normalize the number. */
-                       while ((mantissa64 & (1ULL << FP64_MANTISSA_BITS)) == 0) {
-                               --exponent64;
-                               mantissa64 <<= 1;
+                       while ((mantissa32 & (1U << FP32_MANTISSA_BITS)) == 0) {
+                               --exponent32;
+                               mantissa32 <<= 1;
                        }
 
                        /* Clear the now-implicit one-bit. */
-                       mantissa64 &= ~(1ULL << FP64_MANTISSA_BITS);
+                       mantissa32 &= ~(1U << FP32_MANTISSA_BITS);
                }
        } else if (exponent == FP16_MAX_EXPONENT) {
                /*
@@ -51,44 +51,44 @@ inline double fp_upconvert(FP16_INT_T x)
                 * keep the first bit (which signals signalling/non-signalling
                 * in many implementations).
                 */
-               sign64 = sign;
-               exponent64 = FP64_MAX_EXPONENT;
-               mantissa64 = mantissa << (FP64_MANTISSA_BITS - FP16_MANTISSA_BITS);
+               sign32 = sign;
+               exponent32 = FP32_MAX_EXPONENT;
+               mantissa32 = mantissa << (FP32_MANTISSA_BITS - FP16_MANTISSA_BITS);
        } else {
-               sign64 = sign;
+               sign32 = sign;
 
                /* Up-conversion is simple. Just re-bias the exponent... */
-               exponent64 = exponent + FP64_BIAS - FP16_BIAS;
+               exponent32 = exponent + FP32_BIAS - FP16_BIAS;
 
                /* ...and convert the mantissa. */
-               mantissa64 = mantissa << (FP64_MANTISSA_BITS - FP16_MANTISSA_BITS);
+               mantissa32 = mantissa << (FP32_MANTISSA_BITS - FP16_MANTISSA_BITS);
        }
 
-       union fp64 nx;
-       nx.ll = ((unsigned long long)sign64 << (FP64_MANTISSA_BITS + FP64_EXPONENT_BITS))
-           | ((unsigned long long)exponent64 << FP64_MANTISSA_BITS)
-           | mantissa64;
+       union fp32 nx;
+       nx.u = ((unsigned int)sign32 << (FP32_MANTISSA_BITS + FP32_EXPONENT_BITS))
+           | ((unsigned int)exponent32 << FP32_MANTISSA_BITS)
+           | mantissa32;
        return nx.f;
 }
-               
-unsigned long long shift_right_with_round(unsigned long long x, unsigned shift)
+
+unsigned int shift_right_with_round(unsigned int x, unsigned shift)
 {
-       /* shifts >= 64 need to be special-cased */
-       if (shift > 64) {
+       /* shifts >= 32 need to be special-cased */
+       if (shift > 32) {
                return 0;
-       } else if (shift == 64) {
-               if (x > (1ULL << 63)) {
+       } else if (shift == 32) {
+               if (x > (1U << 31)) {
                        return 1;
                } else {
                        return 0;
                }
        }
 
-       unsigned long long round_part = x & ((1ULL << shift) - 1);
-       if (round_part < (1ULL << (shift - 1))) {
+       unsigned int round_part = x & ((1U << shift) - 1);
+       if (round_part < (1U << (shift - 1))) {
                /* round down */
                x >>= shift;
-       } else if (round_part > (1ULL << (shift - 1))) {
+       } else if (round_part > (1U << (shift - 1))) {
                /* round up */
                x >>= shift;
                ++x;
@@ -104,23 +104,23 @@ unsigned long long shift_right_with_round(unsigned long long x, unsigned shift)
 
 template<class FP16_INT_T,
          int FP16_BIAS, int FP16_MANTISSA_BITS, int FP16_EXPONENT_BITS, int FP16_MAX_EXPONENT,
-         int FP64_BIAS, int FP64_MANTISSA_BITS, int FP64_EXPONENT_BITS, int FP64_MAX_EXPONENT>
-inline FP16_INT_T fp_downconvert(double x)
+         int FP32_BIAS, int FP32_MANTISSA_BITS, int FP32_EXPONENT_BITS, int FP32_MAX_EXPONENT>
+inline FP16_INT_T fp_downconvert(float x)
 {
-       union fp64 nx;
+       union fp32 nx;
        nx.f = x;
-       unsigned long long f = nx.ll;
-       int sign = f >> (FP64_MANTISSA_BITS + FP64_EXPONENT_BITS);
-       int exponent = (f & ((1ULL << (FP64_MANTISSA_BITS + FP64_EXPONENT_BITS)) - 1)) >> FP64_MANTISSA_BITS;
-       unsigned long long mantissa = f & ((1ULL << FP64_MANTISSA_BITS) - 1);
+       unsigned int f = nx.u;
+       int sign = f >> (FP32_MANTISSA_BITS + FP32_EXPONENT_BITS);
+       int exponent = (f & ((1U << (FP32_MANTISSA_BITS + FP32_EXPONENT_BITS)) - 1)) >> FP32_MANTISSA_BITS;
+       unsigned int mantissa = f & ((1U << FP32_MANTISSA_BITS) - 1);
 
        int sign16;
        int exponent16;
-       unsigned long long mantissa16;
+       unsigned int mantissa16;
 
        if (exponent == 0) {
                /*
-                * Denormals, or zero. The largest possible 64-bit
+                * Denormals, or zero. The largest possible 32-bit
                 * denormal is about +- 2^-1022, and the smallest possible
                 * 16-bit denormal is +- 2^-24. Thus, we can safely
                 * just set all of these to zero (but keep the sign bit).
@@ -128,7 +128,7 @@ inline FP16_INT_T fp_downconvert(double x)
                sign16 = sign;
                exponent16 = 0;
                mantissa16 = 0;
-       } else if (exponent == FP64_MAX_EXPONENT) {
+       } else if (exponent == FP32_MAX_EXPONENT) {
                /*
                 * Infinities or NaN (mantissa=0 => infinity, otherwise NaN).
                 * We don't care much about NaNs, so let us just keep the first
@@ -142,25 +142,25 @@ inline FP16_INT_T fp_downconvert(double x)
                } else {
                        sign16 = sign;  /* undefined */
                        exponent16 = FP16_MAX_EXPONENT;
-                       mantissa16 = mantissa >> (FP64_MANTISSA_BITS - FP16_MANTISSA_BITS);
+                       mantissa16 = mantissa >> (FP32_MANTISSA_BITS - FP16_MANTISSA_BITS);
                        if (mantissa16 == 0) {
                                mantissa16 = 1;
                        }
                }
        } else {
                /* Re-bias the exponent, and check if we will create a denormal. */
-               exponent16 = exponent + FP16_BIAS - FP64_BIAS;
+               exponent16 = exponent + FP16_BIAS - FP32_BIAS;
                if (exponent16 <= 0) {
-                       int shift_amount = FP64_MANTISSA_BITS - FP16_MANTISSA_BITS - exponent16 + 1;
+                       int shift_amount = FP32_MANTISSA_BITS - FP16_MANTISSA_BITS - exponent16 + 1;
                        sign16 = sign;
                        exponent16 = 0;
-                       mantissa16 = shift_right_with_round(mantissa | (1ULL << FP64_MANTISSA_BITS), shift_amount);
+                       mantissa16 = shift_right_with_round(mantissa | (1U << FP32_MANTISSA_BITS), shift_amount);
 
                        /*
                         * We could actually have rounded back into the lowest possible non-denormal
                         * here, so check for that.
                         */
-                       if (mantissa16 == (1ULL << FP16_MANTISSA_BITS)) {
+                       if (mantissa16 == (1U << FP16_MANTISSA_BITS)) {
                                exponent16 = 1;
                                mantissa16 = 0;
                        }
@@ -171,10 +171,10 @@ inline FP16_INT_T fp_downconvert(double x)
                         * mode.
                         */
                        sign16 = sign;
-                       mantissa16 = shift_right_with_round(mantissa, FP64_MANTISSA_BITS - FP16_MANTISSA_BITS);
+                       mantissa16 = shift_right_with_round(mantissa, FP32_MANTISSA_BITS - FP16_MANTISSA_BITS);
 
                        /* Check if we overflowed and need to increase the exponent. */
-                       if (mantissa16 == (1ULL << FP16_MANTISSA_BITS)) {
+                       if (mantissa16 == (1U << FP16_MANTISSA_BITS)) {
                                ++exponent16;
                                mantissa16 = 0;
                        }
@@ -187,9 +187,11 @@ inline FP16_INT_T fp_downconvert(double x)
                }
        }
 
-       return (sign16 << (FP16_MANTISSA_BITS + FP16_EXPONENT_BITS))
+       FP16_INT_T ret;
+       ret.val = (sign16 << (FP16_MANTISSA_BITS + FP16_EXPONENT_BITS))
            | (exponent16 << FP16_MANTISSA_BITS)
            | mantissa16;
+       return ret;
 }
 
 const int FP64_BIAS = 1023;
@@ -209,32 +211,22 @@ const int FP16_MAX_EXPONENT = (1 << FP16_EXPONENT_BITS) - 1;
 
 }  // namespace
 
-double fp16_to_fp64(fp16_int_t x)
+#ifndef __F16C__
+
+float fp16_to_fp32(fp16_int_t x)
 {
        return fp_upconvert<fp16_int_t,
               FP16_BIAS, FP16_MANTISSA_BITS, FP16_EXPONENT_BITS, FP16_MAX_EXPONENT,
-              FP64_BIAS, FP64_MANTISSA_BITS, FP64_EXPONENT_BITS, FP64_MAX_EXPONENT>(x);
+              FP32_BIAS, FP32_MANTISSA_BITS, FP32_EXPONENT_BITS, FP32_MAX_EXPONENT>(x);
 }
 
-fp16_int_t fp64_to_fp16(double x)
+fp16_int_t fp32_to_fp16(float x)
 {
        return fp_downconvert<fp16_int_t,
               FP16_BIAS, FP16_MANTISSA_BITS, FP16_EXPONENT_BITS, FP16_MAX_EXPONENT,
-              FP64_BIAS, FP64_MANTISSA_BITS, FP64_EXPONENT_BITS, FP64_MAX_EXPONENT>(x);
+              FP32_BIAS, FP32_MANTISSA_BITS, FP32_EXPONENT_BITS, FP32_MAX_EXPONENT>(x);
 }
 
-double fp32_to_fp64(fp32_int_t x)
-{
-       return fp_upconvert<fp32_int_t,
-              FP32_BIAS, FP32_MANTISSA_BITS, FP32_EXPONENT_BITS, FP32_MAX_EXPONENT,
-              FP64_BIAS, FP64_MANTISSA_BITS, FP64_EXPONENT_BITS, FP64_MAX_EXPONENT>(x);
-}
-
-fp32_int_t fp64_to_fp32(double x)
-{
-       return fp_downconvert<fp32_int_t,
-              FP32_BIAS, FP32_MANTISSA_BITS, FP32_EXPONENT_BITS, FP32_MAX_EXPONENT,
-              FP64_BIAS, FP64_MANTISSA_BITS, FP64_EXPONENT_BITS, FP64_MAX_EXPONENT>(x);
-}
+#endif
 
 }  // namespace