]> git.sesse.net Git - stockfish/blobdiff - src/nnue/layers/affine_transform.h
Support VNNI on 256bit vectors
[stockfish] / src / nnue / layers / affine_transform.h
index 985ee71a4193e571f9ecdddfc144ca4c2c571aea..94d0b5a9494644e574cd111104943d18667c9196 100644 (file)
@@ -62,11 +62,10 @@ namespace Eval::NNUE::Layers {
    // Read network parameters
     bool ReadParameters(std::istream& stream) {
       if (!previous_layer_.ReadParameters(stream)) return false;
-      stream.read(reinterpret_cast<char*>(biases_),
-                  kOutputDimensions * sizeof(BiasType));
-      stream.read(reinterpret_cast<char*>(weights_),
-                  kOutputDimensions * kPaddedInputDimensions *
-                  sizeof(WeightType));
+      for (std::size_t i = 0; i < kOutputDimensions; ++i)
+        biases_[i] = read_little_endian<BiasType>(stream);
+      for (std::size_t i = 0; i < kOutputDimensions * kPaddedInputDimensions; ++i)
+        weights_[i] = read_little_endian<WeightType>(stream);
       return !stream.fail();
     }
 
@@ -79,13 +78,17 @@ namespace Eval::NNUE::Layers {
 
   #if defined(USE_AVX512)
       constexpr IndexType kNumChunks = kPaddedInputDimensions / (kSimdWidth * 2);
-      const __m512i kOnes = _mm512_set1_epi16(1);
       const auto input_vector = reinterpret_cast<const __m512i*>(input);
+  #if !defined(USE_VNNI)
+      const __m512i kOnes = _mm512_set1_epi16(1);
+  #endif
 
   #elif defined(USE_AVX2)
       constexpr IndexType kNumChunks = kPaddedInputDimensions / kSimdWidth;
-      const __m256i kOnes = _mm256_set1_epi16(1);
       const auto input_vector = reinterpret_cast<const __m256i*>(input);
+  #if !defined(USE_VNNI)
+      const __m256i kOnes = _mm256_set1_epi16(1);
+  #endif
 
   #elif defined(USE_SSE2)
       constexpr IndexType kNumChunks = kPaddedInputDimensions / kSimdWidth;
@@ -113,9 +116,13 @@ namespace Eval::NNUE::Layers {
         __m512i sum = _mm512_setzero_si512();
         const auto row = reinterpret_cast<const __m512i*>(&weights_[offset]);
         for (IndexType j = 0; j < kNumChunks; ++j) {
+  #if defined(USE_VNNI)
+            sum = _mm512_dpbusd_epi32(sum, _mm512_loadA_si512(&input_vector[j]), _mm512_load_si512(&row[j]));
+  #else
             __m512i product = _mm512_maddubs_epi16(_mm512_loadA_si512(&input_vector[j]), _mm512_load_si512(&row[j]));
             product = _mm512_madd_epi16(product, kOnes);
             sum = _mm512_add_epi32(sum, product);
+  #endif
         }
 
         // Note: Changing kMaxSimdWidth from 32 to 64 breaks loading existing networks.
@@ -125,9 +132,14 @@ namespace Eval::NNUE::Layers {
         {
             const auto iv256  = reinterpret_cast<const __m256i*>(&input_vector[kNumChunks]);
             const auto row256 = reinterpret_cast<const __m256i*>(&row[kNumChunks]);
+  #if defined(USE_VNNI)
+            __m256i product256 = _mm256_dpbusd_epi32(
+                _mm512_castsi512_si256(sum), _mm256_loadA_si256(&iv256[0]), _mm256_load_si256(&row256[0]));
+            sum = _mm512_inserti32x8(sum, product256, 0);
+  #else
             __m256i product256 = _mm256_maddubs_epi16(_mm256_loadA_si256(&iv256[0]), _mm256_load_si256(&row256[0]));
-            product256 = _mm256_madd_epi16(product256, _mm256_set1_epi16(1));
-            sum = _mm512_add_epi32(sum, _mm512_zextsi256_si512(product256));
+            sum = _mm512_add_epi32(sum, _mm512_cvtepi16_epi32(product256));
+  #endif
         }
         output[i] = _mm512_reduce_add_epi32(sum) + biases_[i];
 
@@ -135,9 +147,13 @@ namespace Eval::NNUE::Layers {
         __m256i sum = _mm256_setzero_si256();
         const auto row = reinterpret_cast<const __m256i*>(&weights_[offset]);
         for (IndexType j = 0; j < kNumChunks; ++j) {
+  #if defined(USE_VNNI)
+          sum = _mm256_dpbusd_epi32(sum, _mm256_loadA_si256(&input_vector[j]), _mm256_load_si256(&row[j]));
+  #else
           __m256i product = _mm256_maddubs_epi16(_mm256_loadA_si256(&input_vector[j]), _mm256_load_si256(&row[j]));
           product = _mm256_madd_epi16(product, kOnes);
           sum = _mm256_add_epi32(sum, product);
+  #endif
         }
         __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(sum), _mm256_extracti128_si256(sum, 1));
         sum128 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, _MM_PERM_BADC));