From: mstembera Date: Sat, 12 Dec 2020 22:18:38 +0000 (-0800) Subject: AVX512, AVX2 and SSSE3 speedups X-Git-Url: https://git.sesse.net/?p=stockfish;a=commitdiff_plain;h=d862ba40692797031ec5b0d95e46bcfc5a80f06c AVX512, AVX2 and SSSE3 speedups Improves throughput by summing 2 intermediate dot products using 16 bit addition before upconverting to 32 bit. Potential saturation is detected and the code-path is avoided in this case. The saturation can't happen with the current nets, but nets can be constructed that trigger this check. STC https://tests.stockfishchess.org/tests/view/5fd40a861ac1691201888479 LLR: 2.94 (-2.94,2.94) {-0.25,1.25} Total: 25544 W: 2451 L: 2296 D: 20797 Ptnml(0-2): 92, 1761, 8925, 1888, 106 about 5% speedup closes https://github.com/official-stockfish/Stockfish/pull/3261 No functional change --- diff --git a/src/nnue/layers/affine_transform.h b/src/nnue/layers/affine_transform.h index 0e0515f9..a715ca85 100644 --- a/src/nnue/layers/affine_transform.h +++ b/src/nnue/layers/affine_transform.h @@ -66,6 +66,53 @@ namespace Eval::NNUE::Layers { biases_[i] = read_little_endian(stream); for (std::size_t i = 0; i < kOutputDimensions * kPaddedInputDimensions; ++i) weights_[i] = read_little_endian(stream); + +#if defined (USE_SSSE3) + // Determine if quadruplets of weight and input products can be summed using 16bits + // without saturation. We assume worst case combinations of 0 and 127 for all inputs. + if (!stream.fail()) + { + auto can_saturate = [](const WeightType* w, int idx[4]) { + int pSum = 0, nSum = 0; + for (int p = 0; p < 4; ++p) + if (w[idx[p]] > 0) + pSum += w[idx[p]]; + else + nSum += w[idx[p]]; + + return pSum > 258 || nSum < -258; + }; + + for (IndexType i = 0; i < kOutputDimensions; ++i) + { + canSaturate16[i] = false; + const WeightType* w = &weights_[i * kPaddedInputDimensions]; +#if defined (USE_AVX512) + for (IndexType j = 0; j < (kPaddedInputDimensions & ~127) && !canSaturate16[i]; j += 128) + for (int k = 0; k < 64 && !canSaturate16[i]; k += 2) + { + int spacing[4] = { 0, 1, 64, 65 }; + canSaturate16[i] = can_saturate(&w[j + k], spacing); + } +#elif defined (USE_AVX2) + for (IndexType j = 0; j < (kPaddedInputDimensions & ~63) && !canSaturate16[i]; j += 64) + for (int k = 0; k < 32 && !canSaturate16[i]; k += 2) + { + int spacing[4] = { 0, 1, 32, 33 }; + canSaturate16[i] = can_saturate(&w[j + k], spacing); + } +#elif defined (USE_SSSE3) + for (IndexType j = 0; j < (kPaddedInputDimensions & ~31) && !canSaturate16[i]; j += 32) + for (int k = 0; k < 16 && !canSaturate16[i]; k += 2) + { + int spacing[4] = { 0, 1, 16, 17 }; + canSaturate16[i] = can_saturate(&w[j + k], spacing); + } +#endif + } + } +#endif + return !stream.fail(); } @@ -181,13 +228,26 @@ namespace Eval::NNUE::Layers { return _mm512_add_epi32(_mm512_permutexvar_epi32(indices, x), bias); }; -#if defined (USE_VNNI) [[maybe_unused]] auto m512_add_dpbusd_epi32 = [=](__m512i& acc, __m512i a, __m512i b) { +#if defined (USE_VNNI) acc = _mm512_dpbusd_epi32(acc, a, b); #else - [[maybe_unused]] auto m512_dpbusd_epi32 = [=](__m512i a, __m512i b) -> __m512i { __m512i product0 = _mm512_maddubs_epi16(a, b); - return _mm512_madd_epi16(product0, kOnes512); + product0 = _mm512_madd_epi16(product0, kOnes512); + acc = _mm512_add_epi32(acc, product0); +#endif + }; + + [[maybe_unused]] auto m512_add_dpbusd_epi32x2 = [=](__m512i& acc, __m512i a0, __m512i b0, __m512i a1, __m512i b1) { +#if defined (USE_VNNI) + acc = _mm512_dpbusd_epi32(acc, a0, b0); + acc = _mm512_dpbusd_epi32(acc, a1, b1); +#else + __m512i product0 = _mm512_maddubs_epi16(a0, b0); + __m512i product1 = _mm512_maddubs_epi16(a1, b1); + product0 = _mm512_adds_epi16(product0, product1); + product0 = _mm512_madd_epi16(product0, kOnes512); + acc = _mm512_add_epi32(acc, product0); #endif }; @@ -214,13 +274,27 @@ namespace Eval::NNUE::Layers { return _mm_add_epi32(_mm_add_epi32(sum128lo, sum128hi), bias); }; -#if defined (USE_VNNI) + [[maybe_unused]] auto m256_add_dpbusd_epi32 = [=](__m256i& acc, __m256i a, __m256i b) { +#if defined (USE_VNNI) acc = _mm256_dpbusd_epi32(acc, a, b); #else - [[maybe_unused]] auto m256_dpbusd_epi32 = [=](__m256i a, __m256i b) -> __m256i { __m256i product0 = _mm256_maddubs_epi16(a, b); - return _mm256_madd_epi16(product0, kOnes256); + product0 = _mm256_madd_epi16(product0, kOnes256); + acc = _mm256_add_epi32(acc, product0); +#endif + }; + + [[maybe_unused]] auto m256_add_dpbusd_epi32x2 = [=](__m256i& acc, __m256i a0, __m256i b0, __m256i a1, __m256i b1) { +#if defined (USE_VNNI) + acc = _mm256_dpbusd_epi32(acc, a0, b0); + acc = _mm256_dpbusd_epi32(acc, a1, b1); +#else + __m256i product0 = _mm256_maddubs_epi16(a0, b0); + __m256i product1 = _mm256_maddubs_epi16(a1, b1); + product0 = _mm256_adds_epi16(product0, product1); + product0 = _mm256_madd_epi16(product0, kOnes256); + acc = _mm256_add_epi32(acc, product0); #endif }; @@ -245,9 +319,18 @@ namespace Eval::NNUE::Layers { return _mm_add_epi32(sum0, bias); }; - [[maybe_unused]] auto m128_dpbusd_epi32 = [=](__m128i a, __m128i b) -> __m128i { + [[maybe_unused]] auto m128_add_dpbusd_epi32 = [=](__m128i& acc, __m128i a, __m128i b) { __m128i product0 = _mm_maddubs_epi16(a, b); - return _mm_madd_epi16(product0, kOnes128); + product0 = _mm_madd_epi16(product0, kOnes128); + acc = _mm_add_epi32(acc, product0); + }; + + [[maybe_unused]] auto m128_add_dpbusd_epi32x2 = [=](__m128i& acc, __m128i a0, __m128i b0, __m128i a1, __m128i b1) { + __m128i product0 = _mm_maddubs_epi16(a0, b0); + __m128i product1 = _mm_maddubs_epi16(a1, b1); + product0 = _mm_adds_epi16(product0, product1); + product0 = _mm_madd_epi16(product0, kOnes128); + acc = _mm_add_epi32(acc, product0); }; #endif @@ -291,6 +374,15 @@ namespace Eval::NNUE::Layers { const __m512i bias = *reinterpret_cast(&biases_[i]); __m512i* outptr = reinterpret_cast<__m512i*>(&output[i]); + __m512i sum01a = _mm512_setzero_si512(); + __m512i sum23a = _mm512_setzero_si512(); + __m512i sum45a = _mm512_setzero_si512(); + __m512i sum67a = _mm512_setzero_si512(); + __m512i sum01b = _mm512_setzero_si512(); + __m512i sum23b = _mm512_setzero_si512(); + __m512i sum45b = _mm512_setzero_si512(); + __m512i sum67b = _mm512_setzero_si512(); + const auto row01a = *reinterpret_cast(&weights_[offset01a]); const auto row23a = *reinterpret_cast(&weights_[offset23a]); const auto row45a = *reinterpret_cast(&weights_[offset45a]); @@ -303,16 +395,6 @@ namespace Eval::NNUE::Layers { const __m256i in256 = input_vector256[0]; const __m512i in = _mm512_inserti64x4(_mm512_castsi256_si512(in256), in256, 1); -#if defined (USE_VNNI) - __m512i sum01a = _mm512_setzero_si512(); - __m512i sum23a = _mm512_setzero_si512(); - __m512i sum45a = _mm512_setzero_si512(); - __m512i sum67a = _mm512_setzero_si512(); - __m512i sum01b = _mm512_setzero_si512(); - __m512i sum23b = _mm512_setzero_si512(); - __m512i sum45b = _mm512_setzero_si512(); - __m512i sum67b = _mm512_setzero_si512(); - m512_add_dpbusd_epi32(sum01a, in, row01a); m512_add_dpbusd_epi32(sum23a, in, row23a); m512_add_dpbusd_epi32(sum45a, in, row45a); @@ -321,16 +403,6 @@ namespace Eval::NNUE::Layers { m512_add_dpbusd_epi32(sum23b, in, row23b); m512_add_dpbusd_epi32(sum45b, in, row45b); m512_add_dpbusd_epi32(sum67b, in, row67b); -#else - __m512i sum01a = m512_dpbusd_epi32(in, row01a); - __m512i sum23a = m512_dpbusd_epi32(in, row23a); - __m512i sum45a = m512_dpbusd_epi32(in, row45a); - __m512i sum67a = m512_dpbusd_epi32(in, row67a); - __m512i sum01b = m512_dpbusd_epi32(in, row01b); - __m512i sum23b = m512_dpbusd_epi32(in, row23b); - __m512i sum45b = m512_dpbusd_epi32(in, row45b); - __m512i sum67b = m512_dpbusd_epi32(in, row67b); -#endif *outptr = m512_hadd256x16( sum01a, sum23a, sum45a, sum67a, @@ -351,80 +423,62 @@ namespace Eval::NNUE::Layers { if constexpr (kPaddedInputDimensions % (kSimdWidth * 2) == 0) { - const auto row0 = reinterpret_cast(&weights_[offset0]); - const auto row1 = reinterpret_cast(&weights_[offset1]); - const auto row2 = reinterpret_cast(&weights_[offset2]); - const auto row3 = reinterpret_cast(&weights_[offset3]); - -#if defined (USE_VNNI) __m512i sum0 = _mm512_setzero_si512(); __m512i sum1 = _mm512_setzero_si512(); __m512i sum2 = _mm512_setzero_si512(); __m512i sum3 = _mm512_setzero_si512(); - const IndexType kStart = 0; -#else - __m512i sum0 = m512_dpbusd_epi32(input_vector512[0], row0[0]); - __m512i sum1 = m512_dpbusd_epi32(input_vector512[0], row1[0]); - __m512i sum2 = m512_dpbusd_epi32(input_vector512[0], row2[0]); - __m512i sum3 = m512_dpbusd_epi32(input_vector512[0], row3[0]); - const IndexType kStart = 1; -#endif - for (IndexType j = kStart; j < kNumChunks512; ++j) + const auto row0 = reinterpret_cast(&weights_[offset0]); + const auto row1 = reinterpret_cast(&weights_[offset1]); + const auto row2 = reinterpret_cast(&weights_[offset2]); + const auto row3 = reinterpret_cast(&weights_[offset3]); + + int j = 0; + if (!canSaturate16x4[i / 4]) + { + for (; j < (int)kNumChunks512 - 1; j += 2) + { + const __m512i in0 = input_vector512[j]; + const __m512i in1 = input_vector512[j + 1]; + + m512_add_dpbusd_epi32x2(sum0, in0, row0[j], in1, row0[j + 1]); + m512_add_dpbusd_epi32x2(sum1, in0, row1[j], in1, row1[j + 1]); + m512_add_dpbusd_epi32x2(sum2, in0, row2[j], in1, row2[j + 1]); + m512_add_dpbusd_epi32x2(sum3, in0, row3[j], in1, row3[j + 1]); + } + } + for (; j < (int)kNumChunks512; ++j) { const __m512i in = input_vector512[j]; -#if defined (USE_VNNI) m512_add_dpbusd_epi32(sum0, in, row0[j]); m512_add_dpbusd_epi32(sum1, in, row1[j]); m512_add_dpbusd_epi32(sum2, in, row2[j]); m512_add_dpbusd_epi32(sum3, in, row3[j]); -#else - sum0 = _mm512_add_epi32(sum0, m512_dpbusd_epi32(in, row0[j])); - sum1 = _mm512_add_epi32(sum1, m512_dpbusd_epi32(in, row1[j])); - sum2 = _mm512_add_epi32(sum2, m512_dpbusd_epi32(in, row2[j])); - sum3 = _mm512_add_epi32(sum3, m512_dpbusd_epi32(in, row3[j])); -#endif } *outptr = m512_haddx4(sum0, sum1, sum2, sum3, bias); } else { - const auto row0 = reinterpret_cast(&weights_[offset0]); - const auto row1 = reinterpret_cast(&weights_[offset1]); - const auto row2 = reinterpret_cast(&weights_[offset2]); - const auto row3 = reinterpret_cast(&weights_[offset3]); - -#if defined (USE_VNNI) __m256i sum0 = _mm256_setzero_si256(); __m256i sum1 = _mm256_setzero_si256(); __m256i sum2 = _mm256_setzero_si256(); __m256i sum3 = _mm256_setzero_si256(); - const IndexType kStart = 0; -#else - __m256i sum0 = m256_dpbusd_epi32(input_vector256[0], row0[0]); - __m256i sum1 = m256_dpbusd_epi32(input_vector256[0], row1[0]); - __m256i sum2 = m256_dpbusd_epi32(input_vector256[0], row2[0]); - __m256i sum3 = m256_dpbusd_epi32(input_vector256[0], row3[0]); - const IndexType kStart = 1; -#endif - for (IndexType j = kStart; j < kNumChunks256; ++j) + const auto row0 = reinterpret_cast(&weights_[offset0]); + const auto row1 = reinterpret_cast(&weights_[offset1]); + const auto row2 = reinterpret_cast(&weights_[offset2]); + const auto row3 = reinterpret_cast(&weights_[offset3]); + + for (IndexType j = 0; j < kNumChunks256; ++j) { const __m256i in = input_vector256[j]; -#if defined (USE_VNNI) m256_add_dpbusd_epi32(sum0, in, row0[j]); m256_add_dpbusd_epi32(sum1, in, row1[j]); m256_add_dpbusd_epi32(sum2, in, row2[j]); m256_add_dpbusd_epi32(sum3, in, row3[j]); -#else - sum0 = _mm256_add_epi32(sum0, m256_dpbusd_epi32(in, row0[j])); - sum1 = _mm256_add_epi32(sum1, m256_dpbusd_epi32(in, row1[j])); - sum2 = _mm256_add_epi32(sum2, m256_dpbusd_epi32(in, row2[j])); - sum3 = _mm256_add_epi32(sum3, m256_dpbusd_epi32(in, row3[j])); -#endif } *outptr = m256_haddx4(sum0, sum1, sum2, sum3, bias); @@ -435,50 +489,30 @@ namespace Eval::NNUE::Layers { { if constexpr (kPaddedInputDimensions % (kSimdWidth * 2) == 0) { - const auto row0 = reinterpret_cast(&weights_[0]); - -#if defined (USE_VNNI) __m512i sum0 = _mm512_setzero_si512(); - const IndexType kStart = 0; -#else - __m512i sum0 = m512_dpbusd_epi32(input_vector512[0], row0[0]); - const IndexType kStart = 1; -#endif - for (IndexType j = kStart; j < kNumChunks512; ++j) + const auto row0 = reinterpret_cast(&weights_[0]); + + for (IndexType j = 0; j < kNumChunks512; ++j) { const __m512i in = input_vector512[j]; -#if defined (USE_VNNI) m512_add_dpbusd_epi32(sum0, in, row0[j]); -#else - sum0 = _mm512_add_epi32(sum0, m512_dpbusd_epi32(in, row0[j])); -#endif } output[0] = m512_hadd(sum0, biases_[0]); } else { - const auto row0 = reinterpret_cast(&weights_[0]); - -#if defined (USE_VNNI) __m256i sum0 = _mm256_setzero_si256(); - const IndexType kStart = 0; -#else - __m256i sum0 = m256_dpbusd_epi32(input_vector256[0], row0[0]); - const IndexType kStart = 1; -#endif - for (IndexType j = kStart; j < kNumChunks256; ++j) + const auto row0 = reinterpret_cast(&weights_[0]); + + for (IndexType j = 0; j < kNumChunks256; ++j) { const __m256i in = input_vector256[j]; -#if defined (USE_VNNI) m256_add_dpbusd_epi32(sum0, in, row0[j]); -#else - sum0 = _mm256_add_epi32(sum0, m256_dpbusd_epi32(in, row0[j])); -#endif } output[0] = m256_hadd(sum0, biases_[0]); @@ -512,40 +546,38 @@ namespace Eval::NNUE::Layers { const __m128i bias = *reinterpret_cast(&biases_[i]); __m128i* outptr = reinterpret_cast<__m128i*>(&output[i]); - const auto row0 = reinterpret_cast(&weights_[offset0]); - const auto row1 = reinterpret_cast(&weights_[offset1]); - const auto row2 = reinterpret_cast(&weights_[offset2]); - const auto row3 = reinterpret_cast(&weights_[offset3]); - -#if defined (USE_VNNI) __m256i sum0 = _mm256_setzero_si256(); __m256i sum1 = _mm256_setzero_si256(); __m256i sum2 = _mm256_setzero_si256(); __m256i sum3 = _mm256_setzero_si256(); - const IndexType kStart = 0; -#else - __m256i sum0 = m256_dpbusd_epi32(input_vector[0], row0[0]); - __m256i sum1 = m256_dpbusd_epi32(input_vector[0], row1[0]); - __m256i sum2 = m256_dpbusd_epi32(input_vector[0], row2[0]); - __m256i sum3 = m256_dpbusd_epi32(input_vector[0], row3[0]); - const IndexType kStart = 1; -#endif - for (IndexType j = kStart; j < kNumChunks; ++j) + const auto row0 = reinterpret_cast(&weights_[offset0]); + const auto row1 = reinterpret_cast(&weights_[offset1]); + const auto row2 = reinterpret_cast(&weights_[offset2]); + const auto row3 = reinterpret_cast(&weights_[offset3]); + + int j = 0; + if (!canSaturate16x4[i / 4]) { - const __m256i in = input_vector[j]; + for (; j < (int)kNumChunks - 1; j += 2) + { + const __m256i in0 = input_vector[j]; + const __m256i in1 = input_vector[j + 1]; + + m256_add_dpbusd_epi32x2(sum0, in0, row0[j], in1, row0[j + 1]); + m256_add_dpbusd_epi32x2(sum1, in0, row1[j], in1, row1[j + 1]); + m256_add_dpbusd_epi32x2(sum2, in0, row2[j], in1, row2[j + 1]); + m256_add_dpbusd_epi32x2(sum3, in0, row3[j], in1, row3[j + 1]); + } + } + for (; j < (int)kNumChunks; ++j) + { + const __m256i in = input_vector[j]; -#if defined (USE_VNNI) - m256_add_dpbusd_epi32(sum0, in, row0[j]); - m256_add_dpbusd_epi32(sum1, in, row1[j]); - m256_add_dpbusd_epi32(sum2, in, row2[j]); - m256_add_dpbusd_epi32(sum3, in, row3[j]); -#else - sum0 = _mm256_add_epi32(sum0, m256_dpbusd_epi32(in, row0[j])); - sum1 = _mm256_add_epi32(sum1, m256_dpbusd_epi32(in, row1[j])); - sum2 = _mm256_add_epi32(sum2, m256_dpbusd_epi32(in, row2[j])); - sum3 = _mm256_add_epi32(sum3, m256_dpbusd_epi32(in, row3[j])); -#endif + m256_add_dpbusd_epi32(sum0, in, row0[j]); + m256_add_dpbusd_epi32(sum1, in, row1[j]); + m256_add_dpbusd_epi32(sum2, in, row2[j]); + m256_add_dpbusd_epi32(sum3, in, row3[j]); } *outptr = m256_haddx4(sum0, sum1, sum2, sum3, bias); @@ -553,25 +585,15 @@ namespace Eval::NNUE::Layers { } else if constexpr (kOutputDimensions == 1) { - const auto row0 = reinterpret_cast(&weights_[0]); - -#if defined (USE_VNNI) __m256i sum0 = _mm256_setzero_si256(); - const IndexType kStart = 0; -#else - __m256i sum0 = m256_dpbusd_epi32(input_vector[0], row0[0]); - const IndexType kStart = 1; -#endif - for (IndexType j = kStart; j < kNumChunks; ++j) + const auto row0 = reinterpret_cast(&weights_[0]); + + for (IndexType j = 0; j < kNumChunks; ++j) { - const __m256i in = input_vector[j]; + const __m256i in = input_vector[j]; -#if defined (USE_VNNI) - m256_add_dpbusd_epi32(sum0, in, row0[j]); -#else - sum0 = _mm256_add_epi32(sum0, m256_dpbusd_epi32(in, row0[j])); -#endif + m256_add_dpbusd_epi32(sum0, in, row0[j]); } output[0] = m256_hadd(sum0, biases_[0]); @@ -604,24 +626,38 @@ namespace Eval::NNUE::Layers { const __m128i bias = *reinterpret_cast(&biases_[i]); __m128i* outptr = reinterpret_cast<__m128i*>(&output[i]); + __m128i sum0 = _mm_setzero_si128(); + __m128i sum1 = _mm_setzero_si128(); + __m128i sum2 = _mm_setzero_si128(); + __m128i sum3 = _mm_setzero_si128(); + const auto row0 = reinterpret_cast(&weights_[offset0]); const auto row1 = reinterpret_cast(&weights_[offset1]); const auto row2 = reinterpret_cast(&weights_[offset2]); const auto row3 = reinterpret_cast(&weights_[offset3]); - __m128i sum0 = m128_dpbusd_epi32(input_vector[0], row0[0]); - __m128i sum1 = m128_dpbusd_epi32(input_vector[0], row1[0]); - __m128i sum2 = m128_dpbusd_epi32(input_vector[0], row2[0]); - __m128i sum3 = m128_dpbusd_epi32(input_vector[0], row3[0]); - - for (int j = 1; j < (int)kNumChunks; ++j) + int j = 0; + if (!canSaturate16x4[i / 4]) + { + for (; j < (int)kNumChunks - 1; j += 2) + { + const __m128i in0 = input_vector[j]; + const __m128i in1 = input_vector[j + 1]; + + m128_add_dpbusd_epi32x2(sum0, in0, row0[j], in1, row0[j + 1]); + m128_add_dpbusd_epi32x2(sum1, in0, row1[j], in1, row1[j + 1]); + m128_add_dpbusd_epi32x2(sum2, in0, row2[j], in1, row2[j + 1]); + m128_add_dpbusd_epi32x2(sum3, in0, row3[j], in1, row3[j + 1]); + } + } + for (; j < (int)kNumChunks; ++j) { - const __m128i in = input_vector[j]; + const __m128i in = input_vector[j]; - sum0 = _mm_add_epi32(sum0, m128_dpbusd_epi32(in, row0[j])); - sum1 = _mm_add_epi32(sum1, m128_dpbusd_epi32(in, row1[j])); - sum2 = _mm_add_epi32(sum2, m128_dpbusd_epi32(in, row2[j])); - sum3 = _mm_add_epi32(sum3, m128_dpbusd_epi32(in, row3[j])); + m128_add_dpbusd_epi32(sum0, in, row0[j]); + m128_add_dpbusd_epi32(sum1, in, row1[j]); + m128_add_dpbusd_epi32(sum2, in, row2[j]); + m128_add_dpbusd_epi32(sum3, in, row3[j]); } *outptr = m128_haddx4(sum0, sum1, sum2, sum3, bias); @@ -629,12 +665,16 @@ namespace Eval::NNUE::Layers { } else if constexpr (kOutputDimensions == 1) { + __m128i sum0 = _mm_setzero_si128(); + const auto row0 = reinterpret_cast(&weights_[0]); - __m128i sum0 = m128_dpbusd_epi32(input_vector[0], row0[0]); + for (int j = 0; j < (int)kNumChunks; ++j) + { + const __m128i in = input_vector[j]; - for (int j = 1; j < (int)kNumChunks; ++j) - sum0 = _mm_add_epi32(sum0, m128_dpbusd_epi32(input_vector[j], row0[j])); + m128_add_dpbusd_epi32(sum0, in, row0[j]); + } output[0] = m128_hadd(sum0, biases_[0]); } @@ -751,8 +791,11 @@ namespace Eval::NNUE::Layers { PreviousLayer previous_layer_; alignas(kCacheLineSize) BiasType biases_[kOutputDimensions]; - alignas(kCacheLineSize) - WeightType weights_[kOutputDimensions * kPaddedInputDimensions]; + alignas(kCacheLineSize) WeightType weights_[kOutputDimensions * kPaddedInputDimensions]; + union { + uint32_t canSaturate16x4[(kOutputDimensions + 3) / 4]; + bool canSaturate16[kOutputDimensions]; + }; }; } // namespace Eval::NNUE::Layers