]> git.sesse.net Git - stockfish/blobdiff - src/nnue/layers/affine_transform.h
Add support for VNNI
[stockfish] / src / nnue / layers / affine_transform.h
index 8d2acd1852eabfa9dce873b803040742a2da75f9..322e32402500029033caa5ddbb7123965adc45d7 100644 (file)
@@ -79,8 +79,10 @@ namespace Eval::NNUE::Layers {
 
   #if defined(USE_AVX512)
       constexpr IndexType kNumChunks = kPaddedInputDimensions / (kSimdWidth * 2);
 
   #if defined(USE_AVX512)
       constexpr IndexType kNumChunks = kPaddedInputDimensions / (kSimdWidth * 2);
-      const __m512i kOnes = _mm512_set1_epi16(1);
       const auto input_vector = reinterpret_cast<const __m512i*>(input);
       const auto input_vector = reinterpret_cast<const __m512i*>(input);
+  #if !defined(USE_VNNI)
+      const __m512i kOnes = _mm512_set1_epi16(1);
+  #endif
 
   #elif defined(USE_AVX2)
       constexpr IndexType kNumChunks = kPaddedInputDimensions / kSimdWidth;
 
   #elif defined(USE_AVX2)
       constexpr IndexType kNumChunks = kPaddedInputDimensions / kSimdWidth;
@@ -113,9 +115,13 @@ namespace Eval::NNUE::Layers {
         __m512i sum = _mm512_setzero_si512();
         const auto row = reinterpret_cast<const __m512i*>(&weights_[offset]);
         for (IndexType j = 0; j < kNumChunks; ++j) {
         __m512i sum = _mm512_setzero_si512();
         const auto row = reinterpret_cast<const __m512i*>(&weights_[offset]);
         for (IndexType j = 0; j < kNumChunks; ++j) {
+  #if defined(USE_VNNI)
+            sum = _mm512_dpbusd_epi32(sum, _mm512_loadA_si512(&input_vector[j]), _mm512_load_si512(&row[j]));
+  #else
             __m512i product = _mm512_maddubs_epi16(_mm512_loadA_si512(&input_vector[j]), _mm512_load_si512(&row[j]));
             product = _mm512_madd_epi16(product, kOnes);
             sum = _mm512_add_epi32(sum, product);
             __m512i product = _mm512_maddubs_epi16(_mm512_loadA_si512(&input_vector[j]), _mm512_load_si512(&row[j]));
             product = _mm512_madd_epi16(product, kOnes);
             sum = _mm512_add_epi32(sum, product);
+  #endif
         }
 
         // Note: Changing kMaxSimdWidth from 32 to 64 breaks loading existing networks.
         }
 
         // Note: Changing kMaxSimdWidth from 32 to 64 breaks loading existing networks.
@@ -125,8 +131,14 @@ namespace Eval::NNUE::Layers {
         {
             const auto iv256  = reinterpret_cast<const __m256i*>(&input_vector[kNumChunks]);
             const auto row256 = reinterpret_cast<const __m256i*>(&row[kNumChunks]);
         {
             const auto iv256  = reinterpret_cast<const __m256i*>(&input_vector[kNumChunks]);
             const auto row256 = reinterpret_cast<const __m256i*>(&row[kNumChunks]);
+  #if defined(USE_VNNI)
+            __m256i product256 = _mm256_dpbusd_epi32(
+                _mm512_castsi512_si256(sum), _mm256_loadA_si256(&iv256[0]), _mm256_load_si256(&row256[0]));
+            sum = _mm512_inserti32x8(sum, product256, 0);
+  #else
             __m256i product256 = _mm256_maddubs_epi16(_mm256_loadA_si256(&iv256[0]), _mm256_load_si256(&row256[0]));
             sum = _mm512_add_epi32(sum, _mm512_cvtepi16_epi32(product256));
             __m256i product256 = _mm256_maddubs_epi16(_mm256_loadA_si256(&iv256[0]), _mm256_load_si256(&row256[0]));
             sum = _mm512_add_epi32(sum, _mm512_cvtepi16_epi32(product256));
+  #endif
         }
         output[i] = _mm512_reduce_add_epi32(sum) + biases_[i];
 
         }
         output[i] = _mm512_reduce_add_epi32(sum) + biases_[i];