]> git.sesse.net Git - stockfish/commitdiff
Simplify Square Clipped ReLU code.
authorGian-Carlo Pascutto <gcp@sjeng.org>
Mon, 14 Aug 2023 15:30:10 +0000 (17:30 +0200)
committerStéphane Nicolet <cassio@free.fr>
Tue, 22 Aug 2023 09:14:19 +0000 (11:14 +0200)
Squared numbers are never negative, so barring any wraparound there
is no need to clamp to 0. From reading the code, there's no obvious
way to get wraparound, so the entire operation can be simplified
away. Updated original truncated code comments to be sensible.

Verified by running ./stockfish bench 128 1 24 and by the following test:

STC: https://tests.stockfishchess.org/tests/view/64da4db95b17f7c21c0eabe7
LLR: 2.94 (-2.94,2.94) <-1.75,0.25>
Total: 60224 W: 15425 L: 15236 D: 29563
Ptnml(0-2): 195, 6576, 16382, 6763, 196

closes https://github.com/official-stockfish/Stockfish/pull/4751

No functional change

src/nnue/layers/sqr_clipped_relu.h

index 69bd51471d7fac875d6f7453f8f6b1f0cce3be9f..5c1b9e6cd060d642f16de063c097c0b3da61e505 100644 (file)
@@ -65,12 +65,6 @@ namespace Stockfish::Eval::NNUE::Layers {
   #if defined(USE_SSE2)
       constexpr IndexType NumChunks = InputDimensions / 16;
 
-  #ifdef USE_SSE41
-      const __m128i Zero = _mm_setzero_si128();
-  #else
-      const __m128i k0x80s = _mm_set1_epi8(-128);
-  #endif
-
       static_assert(WeightScaleBits == 6);
       const auto in = reinterpret_cast<const __m128i*>(input);
       const auto out = reinterpret_cast<__m128i*>(output);
@@ -82,21 +76,13 @@ namespace Stockfish::Eval::NNUE::Layers {
             _mm_load_si128(&in[i * 4 + 2]),
             _mm_load_si128(&in[i * 4 + 3]));
 
-        // Not sure if
+        // We shift by WeightScaleBits * 2 = 12 and divide by 128
+        // which is an additional shift-right of 7, meaning 19 in total.
+        // MulHi strips the lower 16 bits so we need to shift out 3 more to match.
         words0 = _mm_srli_epi16(_mm_mulhi_epi16(words0, words0), 3);
         words1 = _mm_srli_epi16(_mm_mulhi_epi16(words1, words1), 3);
 
-        const __m128i packedbytes = _mm_packs_epi16(words0, words1);
-
-        _mm_store_si128(&out[i],
-
-  #ifdef USE_SSE41
-          _mm_max_epi8(packedbytes, Zero)
-  #else
-          _mm_subs_epi8(_mm_adds_epi8(packedbytes, k0x80s), k0x80s)
-  #endif
-
-        );
+        _mm_store_si128(&out[i], _mm_packs_epi16(words0, words1));
       }
       constexpr IndexType Start = NumChunks * 16;
 
@@ -108,7 +94,7 @@ namespace Stockfish::Eval::NNUE::Layers {
         output[i] = static_cast<OutputType>(
             // really should be /127 but we need to make it fast
             // needs to be accounted for in the trainer
-            std::max(0ll, std::min(127ll, (((long long)input[i] * input[i]) >> (2 * WeightScaleBits)) / 128)));
+            std::min(127ll, (((long long)input[i] * input[i]) >> (2 * WeightScaleBits)) / 128));
       }
     }
   };