X-Git-Url: https://git.sesse.net/?p=stockfish;a=blobdiff_plain;f=src%2Fnnue%2Flayers%2Fsqr_clipped_relu.h;fp=src%2Fnnue%2Flayers%2Fsqr_clipped_relu.h;h=5c1b9e6cd060d642f16de063c097c0b3da61e505;hp=69bd51471d7fac875d6f7453f8f6b1f0cce3be9f;hb=c6f62363a657263a567a0cc9bae09f3c4016156d;hpb=4c5919fa95d543b1cd5d0403f3a89e10a2bdd10c diff --git a/src/nnue/layers/sqr_clipped_relu.h b/src/nnue/layers/sqr_clipped_relu.h index 69bd5147..5c1b9e6c 100644 --- a/src/nnue/layers/sqr_clipped_relu.h +++ b/src/nnue/layers/sqr_clipped_relu.h @@ -65,12 +65,6 @@ namespace Stockfish::Eval::NNUE::Layers { #if defined(USE_SSE2) constexpr IndexType NumChunks = InputDimensions / 16; - #ifdef USE_SSE41 - const __m128i Zero = _mm_setzero_si128(); - #else - const __m128i k0x80s = _mm_set1_epi8(-128); - #endif - static_assert(WeightScaleBits == 6); const auto in = reinterpret_cast(input); const auto out = reinterpret_cast<__m128i*>(output); @@ -82,21 +76,13 @@ namespace Stockfish::Eval::NNUE::Layers { _mm_load_si128(&in[i * 4 + 2]), _mm_load_si128(&in[i * 4 + 3])); - // Not sure if + // We shift by WeightScaleBits * 2 = 12 and divide by 128 + // which is an additional shift-right of 7, meaning 19 in total. + // MulHi strips the lower 16 bits so we need to shift out 3 more to match. words0 = _mm_srli_epi16(_mm_mulhi_epi16(words0, words0), 3); words1 = _mm_srli_epi16(_mm_mulhi_epi16(words1, words1), 3); - const __m128i packedbytes = _mm_packs_epi16(words0, words1); - - _mm_store_si128(&out[i], - - #ifdef USE_SSE41 - _mm_max_epi8(packedbytes, Zero) - #else - _mm_subs_epi8(_mm_adds_epi8(packedbytes, k0x80s), k0x80s) - #endif - - ); + _mm_store_si128(&out[i], _mm_packs_epi16(words0, words1)); } constexpr IndexType Start = NumChunks * 16; @@ -108,7 +94,7 @@ namespace Stockfish::Eval::NNUE::Layers { output[i] = static_cast( // really should be /127 but we need to make it fast // needs to be accounted for in the trainer - std::max(0ll, std::min(127ll, (((long long)input[i] * input[i]) >> (2 * WeightScaleBits)) / 128))); + std::min(127ll, (((long long)input[i] * input[i]) >> (2 * WeightScaleBits)) / 128)); } } };