]> git.sesse.net Git - stockfish/blobdiff - src/nnue/layers/sqr_clipped_relu.h
Simplify Square Clipped ReLU code.
[stockfish] / src / nnue / layers / sqr_clipped_relu.h
index 69bd51471d7fac875d6f7453f8f6b1f0cce3be9f..5c1b9e6cd060d642f16de063c097c0b3da61e505 100644 (file)
@@ -65,12 +65,6 @@ namespace Stockfish::Eval::NNUE::Layers {
   #if defined(USE_SSE2)
       constexpr IndexType NumChunks = InputDimensions / 16;
 
-  #ifdef USE_SSE41
-      const __m128i Zero = _mm_setzero_si128();
-  #else
-      const __m128i k0x80s = _mm_set1_epi8(-128);
-  #endif
-
       static_assert(WeightScaleBits == 6);
       const auto in = reinterpret_cast<const __m128i*>(input);
       const auto out = reinterpret_cast<__m128i*>(output);
@@ -82,21 +76,13 @@ namespace Stockfish::Eval::NNUE::Layers {
             _mm_load_si128(&in[i * 4 + 2]),
             _mm_load_si128(&in[i * 4 + 3]));
 
-        // Not sure if
+        // We shift by WeightScaleBits * 2 = 12 and divide by 128
+        // which is an additional shift-right of 7, meaning 19 in total.
+        // MulHi strips the lower 16 bits so we need to shift out 3 more to match.
         words0 = _mm_srli_epi16(_mm_mulhi_epi16(words0, words0), 3);
         words1 = _mm_srli_epi16(_mm_mulhi_epi16(words1, words1), 3);
 
-        const __m128i packedbytes = _mm_packs_epi16(words0, words1);
-
-        _mm_store_si128(&out[i],
-
-  #ifdef USE_SSE41
-          _mm_max_epi8(packedbytes, Zero)
-  #else
-          _mm_subs_epi8(_mm_adds_epi8(packedbytes, k0x80s), k0x80s)
-  #endif
-
-        );
+        _mm_store_si128(&out[i], _mm_packs_epi16(words0, words1));
       }
       constexpr IndexType Start = NumChunks * 16;
 
@@ -108,7 +94,7 @@ namespace Stockfish::Eval::NNUE::Layers {
         output[i] = static_cast<OutputType>(
             // really should be /127 but we need to make it fast
             // needs to be accounted for in the trainer
-            std::max(0ll, std::min(127ll, (((long long)input[i] * input[i]) >> (2 * WeightScaleBits)) / 128)));
+            std::min(127ll, (((long long)input[i] * input[i]) >> (2 * WeightScaleBits)) / 128));
       }
     }
   };