]> git.sesse.net Git - stockfish/blobdiff - src/nnue/nnue_feature_transformer.h
Support VNNI on 256bit vectors
[stockfish] / src / nnue / nnue_feature_transformer.h
index cbcc26f3efae9f592eead48230d153c93ddd1301..437076102310bc3d4446f98d92372bcb9a063db0 100644 (file)
@@ -55,10 +55,10 @@ namespace Eval::NNUE {
 
     // Read network parameters
     bool ReadParameters(std::istream& stream) {
-      stream.read(reinterpret_cast<char*>(biases_),
-                  kHalfDimensions * sizeof(BiasType));
-      stream.read(reinterpret_cast<char*>(weights_),
-                  kHalfDimensions * kInputDimensions * sizeof(WeightType));
+      for (std::size_t i = 0; i < kHalfDimensions; ++i)
+        biases_[i] = read_little_endian<BiasType>(stream);
+      for (std::size_t i = 0; i < kHalfDimensions * kInputDimensions; ++i)
+        weights_[i] = read_little_endian<WeightType>(stream);
       return !stream.fail();
     }
 
@@ -88,7 +88,7 @@ namespace Eval::NNUE {
       constexpr int kControl = 0b11011000;
       const __m256i kZero = _mm256_setzero_si256();
 
-  #elif defined(USE_SSSE3)
+  #elif defined(USE_SSE2)
       constexpr IndexType kNumChunks = kHalfDimensions / kSimdWidth;
 
   #ifdef USE_SSE41
@@ -97,6 +97,10 @@ namespace Eval::NNUE {
       const __m128i k0x80s = _mm_set1_epi8(-128);
   #endif
 
+  #elif defined(USE_MMX)
+      constexpr IndexType kNumChunks = kHalfDimensions / kSimdWidth;
+      const __m64 k0x80s = _mm_set1_pi8(-128);
+
   #elif defined(USE_NEON)
       constexpr IndexType kNumChunks = kHalfDimensions / (kSimdWidth / 2);
       const int8x8_t kZero = {0};
@@ -117,7 +121,7 @@ namespace Eval::NNUE {
               _mm256_packs_epi16(sum0, sum1), kZero), kControl));
         }
 
-  #elif defined(USE_SSSE3)
+  #elif defined(USE_SSE2)
         auto out = reinterpret_cast<__m128i*>(&output[offset]);
         for (IndexType j = 0; j < kNumChunks; ++j) {
           __m128i sum0 = _mm_load_si128(&reinterpret_cast<const __m128i*>(
@@ -137,6 +141,17 @@ namespace Eval::NNUE {
           );
         }
 
+  #elif defined(USE_MMX)
+        auto out = reinterpret_cast<__m64*>(&output[offset]);
+        for (IndexType j = 0; j < kNumChunks; ++j) {
+          __m64 sum0 = *(&reinterpret_cast<const __m64*>(
+              accumulation[perspectives[p]][0])[j * 2 + 0]);
+          __m64 sum1 = *(&reinterpret_cast<const __m64*>(
+              accumulation[perspectives[p]][0])[j * 2 + 1]);
+          const __m64 packedbytes = _mm_packs_pi16(sum0, sum1);
+          out[j] = _mm_subs_pi8(_mm_adds_pi8(packedbytes, k0x80s), k0x80s);
+        }
+
   #elif defined(USE_NEON)
         const auto out = reinterpret_cast<int8x8_t*>(&output[offset]);
         for (IndexType j = 0; j < kNumChunks; ++j) {
@@ -154,6 +169,9 @@ namespace Eval::NNUE {
   #endif
 
       }
+  #if defined(USE_MMX)
+      _mm_empty();
+  #endif
     }
 
    private:
@@ -169,23 +187,37 @@ namespace Eval::NNUE {
                    kHalfDimensions * sizeof(BiasType));
         for (const auto index : active_indices[perspective]) {
           const IndexType offset = kHalfDimensions * index;
+  #if defined(USE_AVX512)
+          auto accumulation = reinterpret_cast<__m512i*>(
+              &accumulator.accumulation[perspective][i][0]);
+          auto column = reinterpret_cast<const __m512i*>(&weights_[offset]);
+          constexpr IndexType kNumChunks = kHalfDimensions / kSimdWidth;
+          for (IndexType j = 0; j < kNumChunks; ++j)
+            _mm512_storeA_si512(&accumulation[j], _mm512_add_epi16(_mm512_loadA_si512(&accumulation[j]), column[j]));
 
-  #if defined(USE_AVX2)
+  #elif defined(USE_AVX2)
           auto accumulation = reinterpret_cast<__m256i*>(
               &accumulator.accumulation[perspective][i][0]);
           auto column = reinterpret_cast<const __m256i*>(&weights_[offset]);
           constexpr IndexType kNumChunks = kHalfDimensions / (kSimdWidth / 2);
-          for (IndexType j = 0; j < kNumChunks; ++j) {
+          for (IndexType j = 0; j < kNumChunks; ++j)
             _mm256_storeA_si256(&accumulation[j], _mm256_add_epi16(_mm256_loadA_si256(&accumulation[j]), column[j]));
-          }
 
   #elif defined(USE_SSE2)
           auto accumulation = reinterpret_cast<__m128i*>(
               &accumulator.accumulation[perspective][i][0]);
           auto column = reinterpret_cast<const __m128i*>(&weights_[offset]);
           constexpr IndexType kNumChunks = kHalfDimensions / (kSimdWidth / 2);
-          for (IndexType j = 0; j < kNumChunks; ++j) {
+          for (IndexType j = 0; j < kNumChunks; ++j)
             accumulation[j] = _mm_add_epi16(accumulation[j], column[j]);
+
+  #elif defined(USE_MMX)
+          auto accumulation = reinterpret_cast<__m64*>(
+              &accumulator.accumulation[perspective][i][0]);
+          auto column = reinterpret_cast<const __m64*>(&weights_[offset]);
+          constexpr IndexType kNumChunks = kHalfDimensions / (kSimdWidth / 2);
+          for (IndexType j = 0; j < kNumChunks; ++j) {
+            accumulation[j] = _mm_add_pi16(accumulation[j], column[j]);
           }
 
   #elif defined(USE_NEON)
@@ -193,18 +225,19 @@ namespace Eval::NNUE {
               &accumulator.accumulation[perspective][i][0]);
           auto column = reinterpret_cast<const int16x8_t*>(&weights_[offset]);
           constexpr IndexType kNumChunks = kHalfDimensions / (kSimdWidth / 2);
-          for (IndexType j = 0; j < kNumChunks; ++j) {
+          for (IndexType j = 0; j < kNumChunks; ++j)
             accumulation[j] = vaddq_s16(accumulation[j], column[j]);
-          }
 
   #else
-          for (IndexType j = 0; j < kHalfDimensions; ++j) {
+          for (IndexType j = 0; j < kHalfDimensions; ++j)
             accumulator.accumulation[perspective][i][j] += weights_[offset + j];
-          }
   #endif
 
         }
       }
+  #if defined(USE_MMX)
+      _mm_empty();
+  #endif
 
       accumulator.computed_accumulation = true;
       accumulator.computed_score = false;
@@ -231,6 +264,11 @@ namespace Eval::NNUE {
         auto accumulation = reinterpret_cast<__m128i*>(
             &accumulator.accumulation[perspective][i][0]);
 
+  #elif defined(USE_MMX)
+        constexpr IndexType kNumChunks = kHalfDimensions / (kSimdWidth / 2);
+        auto accumulation = reinterpret_cast<__m64*>(
+            &accumulator.accumulation[perspective][i][0]);
+
   #elif defined(USE_NEON)
         constexpr IndexType kNumChunks = kHalfDimensions / (kSimdWidth / 2);
         auto accumulation = reinterpret_cast<int16x8_t*>(
@@ -260,6 +298,12 @@ namespace Eval::NNUE {
               accumulation[j] = _mm_sub_epi16(accumulation[j], column[j]);
             }
 
+  #elif defined(USE_MMX)
+            auto column = reinterpret_cast<const __m64*>(&weights_[offset]);
+            for (IndexType j = 0; j < kNumChunks; ++j) {
+              accumulation[j] = _mm_sub_pi16(accumulation[j], column[j]);
+            }
+
   #elif defined(USE_NEON)
             auto column = reinterpret_cast<const int16x8_t*>(&weights_[offset]);
             for (IndexType j = 0; j < kNumChunks; ++j) {
@@ -291,6 +335,12 @@ namespace Eval::NNUE {
               accumulation[j] = _mm_add_epi16(accumulation[j], column[j]);
             }
 
+  #elif defined(USE_MMX)
+            auto column = reinterpret_cast<const __m64*>(&weights_[offset]);
+            for (IndexType j = 0; j < kNumChunks; ++j) {
+              accumulation[j] = _mm_add_pi16(accumulation[j], column[j]);
+            }
+
   #elif defined(USE_NEON)
             auto column = reinterpret_cast<const int16x8_t*>(&weights_[offset]);
             for (IndexType j = 0; j < kNumChunks; ++j) {
@@ -307,6 +357,9 @@ namespace Eval::NNUE {
           }
         }
       }
+  #if defined(USE_MMX)
+      _mm_empty();
+  #endif
 
       accumulator.computed_accumulation = true;
       accumulator.computed_score = false;