]> git.sesse.net Git - stockfish/blobdiff - src/nnue/nnue_feature_transformer.h
Provide vectorized NNUE code for SSE2 and MMX targets
[stockfish] / src / nnue / nnue_feature_transformer.h
index 3818e444b6af9710110dff8eba49b4148d55b53b..40f2603d9d5c8a8fe212a1a40b465876f822cf6e 100644 (file)
@@ -88,7 +88,7 @@ namespace Eval::NNUE {
       constexpr int kControl = 0b11011000;
       const __m256i kZero = _mm256_setzero_si256();
 
-  #elif defined(USE_SSSE3)
+  #elif defined(USE_SSE2)
       constexpr IndexType kNumChunks = kHalfDimensions / kSimdWidth;
 
   #ifdef USE_SSE41
@@ -97,6 +97,10 @@ namespace Eval::NNUE {
       const __m128i k0x80s = _mm_set1_epi8(-128);
   #endif
 
+  #elif defined(USE_MMX)
+      constexpr IndexType kNumChunks = kHalfDimensions / kSimdWidth;
+      const __m64 k0x80s = _mm_set1_pi8(-128);
+
   #elif defined(USE_NEON)
       constexpr IndexType kNumChunks = kHalfDimensions / (kSimdWidth / 2);
       const int8x8_t kZero = {0};
@@ -117,7 +121,7 @@ namespace Eval::NNUE {
               _mm256_packs_epi16(sum0, sum1), kZero), kControl));
         }
 
-  #elif defined(USE_SSSE3)
+  #elif defined(USE_SSE2)
         auto out = reinterpret_cast<__m128i*>(&output[offset]);
         for (IndexType j = 0; j < kNumChunks; ++j) {
           __m128i sum0 = _mm_load_si128(&reinterpret_cast<const __m128i*>(
@@ -137,6 +141,17 @@ namespace Eval::NNUE {
           );
         }
 
+  #elif defined(USE_MMX)
+        auto out = reinterpret_cast<__m64*>(&output[offset]);
+        for (IndexType j = 0; j < kNumChunks; ++j) {
+          __m64 sum0 = *(&reinterpret_cast<const __m64*>(
+              accumulation[perspectives[p]][0])[j * 2 + 0]);
+          __m64 sum1 = *(&reinterpret_cast<const __m64*>(
+              accumulation[perspectives[p]][0])[j * 2 + 1]);
+          const __m64 packedbytes = _mm_packs_pi16(sum0, sum1);
+          out[j] = _mm_subs_pi8(_mm_adds_pi8(packedbytes, k0x80s), k0x80s);
+        }
+
   #elif defined(USE_NEON)
         const auto out = reinterpret_cast<int8x8_t*>(&output[offset]);
         for (IndexType j = 0; j < kNumChunks; ++j) {
@@ -154,6 +169,9 @@ namespace Eval::NNUE {
   #endif
 
       }
+  #if defined(USE_MMX)
+      _mm_empty();
+  #endif
     }
 
    private:
@@ -193,6 +211,15 @@ namespace Eval::NNUE {
           for (IndexType j = 0; j < kNumChunks; ++j)
             accumulation[j] = _mm_add_epi16(accumulation[j], column[j]);
 
+  #elif defined(USE_MMX)
+          auto accumulation = reinterpret_cast<__m64*>(
+              &accumulator.accumulation[perspective][i][0]);
+          auto column = reinterpret_cast<const __m64*>(&weights_[offset]);
+          constexpr IndexType kNumChunks = kHalfDimensions / (kSimdWidth / 2);
+          for (IndexType j = 0; j < kNumChunks; ++j) {
+            accumulation[j] = _mm_add_pi16(accumulation[j], column[j]);
+          }
+
   #elif defined(USE_NEON)
           auto accumulation = reinterpret_cast<int16x8_t*>(
               &accumulator.accumulation[perspective][i][0]);
@@ -208,6 +235,9 @@ namespace Eval::NNUE {
 
         }
       }
+  #if defined(USE_MMX)
+      _mm_empty();
+  #endif
 
       accumulator.computed_accumulation = true;
       accumulator.computed_score = false;
@@ -234,6 +264,11 @@ namespace Eval::NNUE {
         auto accumulation = reinterpret_cast<__m128i*>(
             &accumulator.accumulation[perspective][i][0]);
 
+  #elif defined(USE_MMX)
+        constexpr IndexType kNumChunks = kHalfDimensions / (kSimdWidth / 2);
+        auto accumulation = reinterpret_cast<__m64*>(
+            &accumulator.accumulation[perspective][i][0]);
+
   #elif defined(USE_NEON)
         constexpr IndexType kNumChunks = kHalfDimensions / (kSimdWidth / 2);
         auto accumulation = reinterpret_cast<int16x8_t*>(
@@ -263,6 +298,12 @@ namespace Eval::NNUE {
               accumulation[j] = _mm_sub_epi16(accumulation[j], column[j]);
             }
 
+  #elif defined(USE_MMX)
+            auto column = reinterpret_cast<const __m64*>(&weights_[offset]);
+            for (IndexType j = 0; j < kNumChunks; ++j) {
+              accumulation[j] = _mm_sub_pi16(accumulation[j], column[j]);
+            }
+
   #elif defined(USE_NEON)
             auto column = reinterpret_cast<const int16x8_t*>(&weights_[offset]);
             for (IndexType j = 0; j < kNumChunks; ++j) {
@@ -294,6 +335,12 @@ namespace Eval::NNUE {
               accumulation[j] = _mm_add_epi16(accumulation[j], column[j]);
             }
 
+  #elif defined(USE_MMX)
+            auto column = reinterpret_cast<const __m64*>(&weights_[offset]);
+            for (IndexType j = 0; j < kNumChunks; ++j) {
+              accumulation[j] = _mm_add_pi16(accumulation[j], column[j]);
+            }
+
   #elif defined(USE_NEON)
             auto column = reinterpret_cast<const int16x8_t*>(&weights_[offset]);
             for (IndexType j = 0; j < kNumChunks; ++j) {
@@ -310,6 +357,9 @@ namespace Eval::NNUE {
           }
         }
       }
+  #if defined(USE_MMX)
+      _mm_empty();
+  #endif
 
       accumulator.computed_accumulation = true;
       accumulator.computed_score = false;