git.sesse.net Git - stockfish/blobdiff - src/nnue/layers/sqr_clipped_relu.h
Cleanup includes
[stockfish] / src / nnue / layers / sqr_clipped_relu.h
index 3fbb243cfd6c38371183e8ebec459f822fa01cb6..503b283b25e53b7ef1590b526c50d21b0582a224 100644 (file)
 #ifndef NNUE_LAYERS_SQR_CLIPPED_RELU_H_INCLUDED
 #define NNUE_LAYERS_SQR_CLIPPED_RELU_H_INCLUDED
 
+#include <algorithm>
+#include <cstdint>
+#include <iosfwd>
+
 #include "../nnue_common.h"
 
 namespace Stockfish::Eval::NNUE::Layers {
@@ -59,18 +63,12 @@ namespace Stockfish::Eval::NNUE::Layers {
     }
 
     // Forward propagation
-    const OutputType* propagate(
+    void propagate(
         const InputType* input, OutputType* output) const {
 
   #if defined(USE_SSE2)
       constexpr IndexType NumChunks = InputDimensions / 16;
 
-  #ifdef USE_SSE41
-      const __m128i Zero = _mm_setzero_si128();
-  #else
-      const __m128i k0x80s = _mm_set1_epi8(-128);
-  #endif
-
       static_assert(WeightScaleBits == 6);
       const auto in = reinterpret_cast<const __m128i*>(input);
       const auto out = reinterpret_cast<__m128i*>(output);
@@ -82,21 +80,13 @@ namespace Stockfish::Eval::NNUE::Layers {
             _mm_load_si128(&in[i * 4 + 2]),
             _mm_load_si128(&in[i * 4 + 3]));
 
-        // Not sure if
+        // We shift by WeightScaleBits * 2 = 12 and divide by 128
+        // which is an additional shift-right of 7, meaning 19 in total.
+        // MulHi strips the lower 16 bits so we need to shift out 3 more to match.
         words0 = _mm_srli_epi16(_mm_mulhi_epi16(words0, words0), 3);
         words1 = _mm_srli_epi16(_mm_mulhi_epi16(words1, words1), 3);
 
-        const __m128i packedbytes = _mm_packs_epi16(words0, words1);
-
-        _mm_store_si128(&out[i],
-
-  #ifdef USE_SSE41
-          _mm_max_epi8(packedbytes, Zero)
-  #else
-          _mm_subs_epi8(_mm_adds_epi8(packedbytes, k0x80s), k0x80s)
-  #endif
-
-        );
+        _mm_store_si128(&out[i], _mm_packs_epi16(words0, words1));
       }
       constexpr IndexType Start = NumChunks * 16;
 
@@ -108,10 +98,8 @@ namespace Stockfish::Eval::NNUE::Layers {
         output[i] = static_cast<OutputType>(
             // really should be /127 but we need to make it fast
             // needs to be accounted for in the trainer
-            std::max(0ll, std::min(127ll, (((long long)input[i] * input[i]) >> (2 * WeightScaleBits)) / 128)));
+            std::min(127ll, (((long long)input[i] * input[i]) >> (2 * WeightScaleBits)) / 128));
       }
-
-      return output;
     }
   };