From c6f62363a657263a567a0cc9bae09f3c4016156d Mon Sep 17 00:00:00 2001 From: Gian-Carlo Pascutto Date: Mon, 14 Aug 2023 17:30:10 +0200 Subject: [PATCH] Simplify Square Clipped ReLU code. Squared numbers are never negative, so barring any wraparound there is no need to clamp to 0. From reading the code, there's no obvious way to get wraparound, so the entire operation can be simplified away. Updated original truncated code comments to be sensible. Verified by running ./stockfish bench 128 1 24 and by the following test: STC: https://tests.stockfishchess.org/tests/view/64da4db95b17f7c21c0eabe7 LLR: 2.94 (-2.94,2.94) <-1.75,0.25> Total: 60224 W: 15425 L: 15236 D: 29563 Ptnml(0-2): 195, 6576, 16382, 6763, 196 closes https://github.com/official-stockfish/Stockfish/pull/4751 No functional change --- src/nnue/layers/sqr_clipped_relu.h | 24 +++++------------------- 1 file changed, 5 insertions(+), 19 deletions(-) diff --git a/src/nnue/layers/sqr_clipped_relu.h b/src/nnue/layers/sqr_clipped_relu.h index 69bd5147..5c1b9e6c 100644 --- a/src/nnue/layers/sqr_clipped_relu.h +++ b/src/nnue/layers/sqr_clipped_relu.h @@ -65,12 +65,6 @@ namespace Stockfish::Eval::NNUE::Layers { #if defined(USE_SSE2) constexpr IndexType NumChunks = InputDimensions / 16; - #ifdef USE_SSE41 - const __m128i Zero = _mm_setzero_si128(); - #else - const __m128i k0x80s = _mm_set1_epi8(-128); - #endif - static_assert(WeightScaleBits == 6); const auto in = reinterpret_cast<const __m128i*>(input); const auto out = reinterpret_cast<__m128i*>(output); @@ -82,21 +76,13 @@ namespace Stockfish::Eval::NNUE::Layers { _mm_load_si128(&in[i * 4 + 2]), _mm_load_si128(&in[i * 4 + 3])); - // Not sure if + // We shift by WeightScaleBits * 2 = 12 and divide by 128 + // which is an additional shift-right of 7, meaning 19 in total. + // MulHi strips the lower 16 bits so we need to shift out 3 more to match. 
words0 = _mm_srli_epi16(_mm_mulhi_epi16(words0, words0), 3); words1 = _mm_srli_epi16(_mm_mulhi_epi16(words1, words1), 3); - const __m128i packedbytes = _mm_packs_epi16(words0, words1); - - _mm_store_si128(&out[i], - - #ifdef USE_SSE41 - _mm_max_epi8(packedbytes, Zero) - #else - _mm_subs_epi8(_mm_adds_epi8(packedbytes, k0x80s), k0x80s) - #endif - - ); + _mm_store_si128(&out[i], _mm_packs_epi16(words0, words1)); } constexpr IndexType Start = NumChunks * 16; @@ -108,7 +94,7 @@ namespace Stockfish::Eval::NNUE::Layers { output[i] = static_cast<OutputType>( // really should be /127 but we need to make it fast // needs to be accounted for in the trainer - std::max(0ll, std::min(127ll, (((long long)input[i] * input[i]) >> (2 * WeightScaleBits)) / 128))); + std::min(127ll, (((long long)input[i] * input[i]) >> (2 * WeightScaleBits)) / 128)); } } }; -- 2.39.2