X-Git-Url: https://git.sesse.net/?p=stockfish;a=blobdiff_plain;f=src%2Fnnue%2Flayers%2Faffine_transform.h;h=322e32402500029033caa5ddbb7123965adc45d7;hp=8d2acd1852eabfa9dce873b803040742a2da75f9;hb=dd63b98fb06e050aa961fbad6fd1f9316f2b17df;hpb=6bc0256292cf51d390fee0cb78963da884dc2677 diff --git a/src/nnue/layers/affine_transform.h b/src/nnue/layers/affine_transform.h index 8d2acd18..322e3240 100644 --- a/src/nnue/layers/affine_transform.h +++ b/src/nnue/layers/affine_transform.h @@ -79,8 +79,10 @@ namespace Eval::NNUE::Layers { #if defined(USE_AVX512) constexpr IndexType kNumChunks = kPaddedInputDimensions / (kSimdWidth * 2); - const __m512i kOnes = _mm512_set1_epi16(1); const auto input_vector = reinterpret_cast(input); + #if !defined(USE_VNNI) + const __m512i kOnes = _mm512_set1_epi16(1); + #endif #elif defined(USE_AVX2) constexpr IndexType kNumChunks = kPaddedInputDimensions / kSimdWidth; @@ -113,9 +115,13 @@ namespace Eval::NNUE::Layers { __m512i sum = _mm512_setzero_si512(); const auto row = reinterpret_cast(&weights_[offset]); for (IndexType j = 0; j < kNumChunks; ++j) { + #if defined(USE_VNNI) + sum = _mm512_dpbusd_epi32(sum, _mm512_loadA_si512(&input_vector[j]), _mm512_load_si512(&row[j])); + #else __m512i product = _mm512_maddubs_epi16(_mm512_loadA_si512(&input_vector[j]), _mm512_load_si512(&row[j])); product = _mm512_madd_epi16(product, kOnes); sum = _mm512_add_epi32(sum, product); + #endif } // Note: Changing kMaxSimdWidth from 32 to 64 breaks loading existing networks. @@ -125,8 +131,14 @@ namespace Eval::NNUE::Layers { { const auto iv256 = reinterpret_cast(&input_vector[kNumChunks]); const auto row256 = reinterpret_cast(&row[kNumChunks]); + #if defined(USE_VNNI) + __m256i product256 = _mm256_dpbusd_epi32( + _mm512_castsi512_si256(sum), _mm256_loadA_si256(&iv256[0]), _mm256_load_si256(&row256[0])); + sum = _mm512_inserti32x8(sum, product256, 0); + #else __m256i product256 = _mm256_maddubs_epi16(_mm256_loadA_si256(&iv256[0]), _mm256_load_si256(&row256[0])); sum = _mm512_add_epi32(sum, _mm512_cvtepi16_epi32(product256)); + #endif } output[i] = _mm512_reduce_add_epi32(sum) + biases_[i];