]> git.sesse.net Git - stockfish/blobdiff - src/nnue/layers/clipped_relu.h
Provide vectorized NNUE code for SSE2 and MMX targets
[stockfish] / src / nnue / layers / clipped_relu.h
index 13196ec28b49d133afeb0c0f704e644b86583b8d..44d8a7def4c92ff55f7fcf1150b9553b37a90458 100644 (file)
@@ -84,7 +84,7 @@ namespace Eval::NNUE::Layers {
       }
       constexpr IndexType kStart = kNumChunks * kSimdWidth;
 
-  #elif defined(USE_SSSE3)
+  #elif defined(USE_SSE2)
       constexpr IndexType kNumChunks = kInputDimensions / kSimdWidth;
 
   #ifdef USE_SSE41
@@ -115,6 +115,24 @@ namespace Eval::NNUE::Layers {
       }
       constexpr IndexType kStart = kNumChunks * kSimdWidth;
 
+  #elif defined(USE_MMX)
+      constexpr IndexType kNumChunks = kInputDimensions / kSimdWidth;
+      const __m64 k0x80s = _mm_set1_pi8(-128);
+      const auto in = reinterpret_cast<const __m64*>(input);
+      const auto out = reinterpret_cast<__m64*>(output);
+      for (IndexType i = 0; i < kNumChunks; ++i) {
+        const __m64 words0 = _mm_srai_pi16(
+            _mm_packs_pi32(in[i * 4 + 0], in[i * 4 + 1]),
+            kWeightScaleBits);
+        const __m64 words1 = _mm_srai_pi16(
+            _mm_packs_pi32(in[i * 4 + 2], in[i * 4 + 3]),
+            kWeightScaleBits);
+        const __m64 packedbytes = _mm_packs_pi16(words0, words1);
+        out[i] = _mm_subs_pi8(_mm_adds_pi8(packedbytes, k0x80s), k0x80s);
+      }
+      _mm_empty();
+      constexpr IndexType kStart = kNumChunks * kSimdWidth;
+
   #elif defined(USE_NEON)
       constexpr IndexType kNumChunks = kInputDimensions / (kSimdWidth / 2);
       const int8x8_t kZero = {0};