From: Steinar H. Gunderson Date: Wed, 23 Sep 2015 23:59:47 +0000 (+0200) Subject: Fix a bug where combined fp16 weights would be horribly wrong. X-Git-Tag: 1.2.0~2 X-Git-Url: https://git.sesse.net/?p=movit;a=commitdiff_plain;h=6c954b4f0bff0743e13ce6ddcee8bda15b3af234 Fix a bug where combined fp16 weights would be horribly wrong. Seemingly, weights were always returned as float, and then cast to fp16_int_t -- without proper conversion! And sum_sq_error would be calculated based on the correct value, not the incorrectly cast one. It's a small miracle the unit tests didn't catch this; they didn't until I started introducing small errors for another reason. Most real-world testing seems to have hit fp32, and thus this wasn't caught there either. Also make fp16_int_t and fp32_int_t structs so that they are not implicitly convertible to/from numeric types, so this can never happen again. --- diff --git a/fp16.cpp b/fp16.cpp index fc5800e..e8993f9 100644 --- a/fp16.cpp +++ b/fp16.cpp @@ -13,9 +13,9 @@ template inline double fp_upconvert(FP16_INT_T x) { - int sign = x >> (FP16_MANTISSA_BITS + FP16_EXPONENT_BITS); - int exponent = (x & ((1ULL << (FP16_MANTISSA_BITS + FP16_EXPONENT_BITS)) - 1)) >> FP16_MANTISSA_BITS; - unsigned long long mantissa = x & ((1ULL << FP16_MANTISSA_BITS) - 1); + int sign = x.val >> (FP16_MANTISSA_BITS + FP16_EXPONENT_BITS); + int exponent = (x.val & ((1ULL << (FP16_MANTISSA_BITS + FP16_EXPONENT_BITS)) - 1)) >> FP16_MANTISSA_BITS; + unsigned long long mantissa = x.val & ((1ULL << FP16_MANTISSA_BITS) - 1); int sign64; int exponent64; @@ -187,9 +187,11 @@ inline FP16_INT_T fp_downconvert(double x) } } - return (sign16 << (FP16_MANTISSA_BITS + FP16_EXPONENT_BITS)) + FP16_INT_T ret; + ret.val = (sign16 << (FP16_MANTISSA_BITS + FP16_EXPONENT_BITS)) | (exponent16 << FP16_MANTISSA_BITS) | mantissa16; + return ret; } const int FP64_BIAS = 1023; diff --git a/fp16.h b/fp16.h index 5417e02..c21153b 100644 --- a/fp16.h +++ b/fp16.h @@ -14,8 +14,13 @@ namespace movit { 
-typedef unsigned int fp32_int_t; -typedef unsigned short fp16_int_t; +// structs instead of ints, so that they are not implicitly convertible. +struct fp32_int_t { + unsigned int val; +}; +struct fp16_int_t { + unsigned short val; +}; #ifdef __F16C__ @@ -23,14 +28,16 @@ typedef unsigned short fp16_int_t; // are at compile time). static inline double fp16_to_fp64(fp16_int_t x) { - return _cvtsh_ss(x); + return _cvtsh_ss(x.val); } static inline fp16_int_t fp64_to_fp16(double x) { // NOTE: Strictly speaking, there are some select values where this isn't correct, // since we first round to fp32 and then to fp16. - return _cvtss_sh(x, 0); + fp16_int_t ret; + ret.val = _cvtss_sh(x, 0); + return ret; } #else diff --git a/fp16_test.cpp b/fp16_test.cpp index bb8b182..e0920e9 100644 --- a/fp16_test.cpp +++ b/fp16_test.cpp @@ -4,31 +4,48 @@ #include namespace movit { +namespace { + +fp16_int_t make_fp16(unsigned short x) +{ + fp16_int_t ret; + ret.val = x; + return ret; +} + +fp32_int_t make_fp32(unsigned int x) +{ + fp32_int_t ret; + ret.val = x; + return ret; +} + +} // namespace TEST(FP16Test, Simple) { - EXPECT_EQ(0x0000, fp64_to_fp16(0.0)); - EXPECT_DOUBLE_EQ(0.0, fp16_to_fp64(0x0000)); + EXPECT_EQ(0x0000, fp64_to_fp16(0.0).val); + EXPECT_DOUBLE_EQ(0.0, fp16_to_fp64(make_fp16(0x0000))); - EXPECT_EQ(0x3c00, fp64_to_fp16(1.0)); - EXPECT_DOUBLE_EQ(1.0, fp16_to_fp64(0x3c00)); + EXPECT_EQ(0x3c00, fp64_to_fp16(1.0).val); + EXPECT_DOUBLE_EQ(1.0, fp16_to_fp64(make_fp16(0x3c00))); - EXPECT_EQ(0x3555, fp64_to_fp16(1.0 / 3.0)); - EXPECT_DOUBLE_EQ(0.333251953125, fp16_to_fp64(0x3555)); + EXPECT_EQ(0x3555, fp64_to_fp16(1.0 / 3.0).val); + EXPECT_DOUBLE_EQ(0.333251953125, fp16_to_fp64(make_fp16(0x3555))); } TEST(FP16Test, RoundToNearestEven) { - ASSERT_DOUBLE_EQ(1.0, fp16_to_fp64(0x3c00)); - - double x0 = fp16_to_fp64(0x3c00); - double x1 = fp16_to_fp64(0x3c01); - double x2 = fp16_to_fp64(0x3c02); - double x3 = fp16_to_fp64(0x3c03); - double x4 = fp16_to_fp64(0x3c04); - - 
EXPECT_EQ(0x3c00, fp64_to_fp16(0.5 * (x0 + x1))); - EXPECT_EQ(0x3c02, fp64_to_fp16(0.5 * (x1 + x2))); - EXPECT_EQ(0x3c02, fp64_to_fp16(0.5 * (x2 + x3))); - EXPECT_EQ(0x3c04, fp64_to_fp16(0.5 * (x3 + x4))); + ASSERT_DOUBLE_EQ(1.0, fp16_to_fp64(make_fp16(0x3c00))); + + double x0 = fp16_to_fp64(make_fp16(0x3c00)); + double x1 = fp16_to_fp64(make_fp16(0x3c01)); + double x2 = fp16_to_fp64(make_fp16(0x3c02)); + double x3 = fp16_to_fp64(make_fp16(0x3c03)); + double x4 = fp16_to_fp64(make_fp16(0x3c04)); + + EXPECT_EQ(0x3c00, fp64_to_fp16(0.5 * (x0 + x1)).val); + EXPECT_EQ(0x3c02, fp64_to_fp16(0.5 * (x1 + x2)).val); + EXPECT_EQ(0x3c02, fp64_to_fp16(0.5 * (x2 + x3)).val); + EXPECT_EQ(0x3c04, fp64_to_fp16(0.5 * (x3 + x4)).val); } union fp64 { @@ -42,8 +59,8 @@ union fp32 { TEST(FP16Test, NaN) { // Ignore the sign bit. - EXPECT_EQ(0x7e00, fp64_to_fp16(0.0 / 0.0) & 0x7fff); - EXPECT_TRUE(isnan(fp16_to_fp64(0xfe00))); + EXPECT_EQ(0x7e00, fp64_to_fp16(0.0 / 0.0).val & 0x7fff); + EXPECT_TRUE(isnan(fp16_to_fp64(make_fp16(0xfe00)))); fp64 borderline_inf; borderline_inf.ll = 0x7ff0000000000000ull; @@ -68,15 +85,15 @@ TEST(FP16Test, NaN) { TEST(FP16Test, Denormals) { const double smallest_fp16_denormal = 5.9604644775390625e-08; - EXPECT_EQ(0x0001, fp64_to_fp16(smallest_fp16_denormal)); - EXPECT_EQ(0x0000, fp64_to_fp16(0.5 * smallest_fp16_denormal)); // Round-to-even. - EXPECT_EQ(0x0001, fp64_to_fp16(0.51 * smallest_fp16_denormal)); - EXPECT_EQ(0x0002, fp64_to_fp16(1.5 * smallest_fp16_denormal)); + EXPECT_EQ(0x0001, fp64_to_fp16(smallest_fp16_denormal).val); + EXPECT_EQ(0x0000, fp64_to_fp16(0.5 * smallest_fp16_denormal).val); // Round-to-even. 
+ EXPECT_EQ(0x0001, fp64_to_fp16(0.51 * smallest_fp16_denormal).val); + EXPECT_EQ(0x0002, fp64_to_fp16(1.5 * smallest_fp16_denormal).val); const double smallest_fp16_non_denormal = 6.103515625e-05; - EXPECT_EQ(0x0400, fp64_to_fp16(smallest_fp16_non_denormal)); - EXPECT_EQ(0x0400, fp64_to_fp16(smallest_fp16_non_denormal - 0.5 * smallest_fp16_denormal)); // Round-to-even. - EXPECT_EQ(0x03ff, fp64_to_fp16(smallest_fp16_non_denormal - smallest_fp16_denormal)); + EXPECT_EQ(0x0400, fp64_to_fp16(smallest_fp16_non_denormal).val); + EXPECT_EQ(0x0400, fp64_to_fp16(smallest_fp16_non_denormal - 0.5 * smallest_fp16_denormal).val); // Round-to-even. + EXPECT_EQ(0x03ff, fp64_to_fp16(smallest_fp16_non_denormal - smallest_fp16_denormal).val); } // Randomly test a large number of fp64 -> fp32 conversions, comparing @@ -93,7 +110,7 @@ TEST(FP16Test, FP32ReferenceDownconvert) { src.ll = (((unsigned long long)r1) << 33) ^ ((unsigned long long)r2 << 16) ^ r3; reference.f = float(src.f); - result.u = fp64_to_fp32(src.f); + result.u = fp64_to_fp32(src.f).val; EXPECT_EQ(isnan(result.f), isnan(reference.f)); if (!isnan(result.f)) { @@ -116,7 +133,7 @@ TEST(FP16Test, FP32ReferenceUpconvert) { src.u = ((unsigned long long)r1 << 16) ^ r2; reference.f = double(src.f); - result.f = fp32_to_fp64(src.u); + result.f = fp32_to_fp64(make_fp32(src.u)); EXPECT_EQ(isnan(result.f), isnan(reference.f)); if (!isnan(result.f)) { diff --git a/resample_effect.cpp b/resample_effect.cpp index 156098e..244a3e2 100644 --- a/resample_effect.cpp +++ b/resample_effect.cpp @@ -107,7 +107,7 @@ unsigned combine_samples(const Tap *src, Tap *dst, float num_s float pos2 = src[i + 1].pos; assert(pos2 > pos1); - fp16_int_t pos, total_weight; + DestFloat pos, total_weight; float sum_sq_error; combine_two_samples(w1, w2, pos1, pos2, num_subtexels, inv_num_subtexels, &pos, &total_weight, &sum_sq_error); diff --git a/util.cpp b/util.cpp index da6057e..59f1dcd 100644 --- a/util.cpp +++ b/util.cpp @@ -253,11 +253,11 @@ void 
combine_two_samples(float w1, float w2, float pos1, float pos2, float num_s // w = (a(1-z) + bz) / ((1-z)² + z²) // // If z had infinite precision, this would simply reduce to w = w1 + w2. - *total_weight = (w1 + z * (w2 - w1)) / (z * z + (1 - z) * (1 - z)); + *total_weight = from_fp64((w1 + z * (w2 - w1)) / (z * z + (1 - z) * (1 - z))); if (sum_sq_error != NULL) { - float err1 = *total_weight * (1 - z) - w1; - float err2 = *total_weight * z - w2; + float err1 = to_fp64(*total_weight) * (1 - z) - w1; + float err2 = to_fp64(*total_weight) * z - w2; *sum_sq_error = err1 * err1 + err2 * err2; } }