X-Git-Url: https://git.sesse.net/?p=stockfish;a=blobdiff_plain;f=src%2Fnnue%2Flayers%2Faffine_transform.h;h=adf152eea5b8894fcdf26cdfebffcdbd602b5d7f;hp=34777ef66e70351fefc6f5310f10d3f6983d3926;hb=303713b560e356a902c1830bce205716cef54a44;hpb=4d30438400932d18c095a8b85c3e51789d5f0feb diff --git a/src/nnue/layers/affine_transform.h b/src/nnue/layers/affine_transform.h index 34777ef6..adf152ee 100644 --- a/src/nnue/layers/affine_transform.h +++ b/src/nnue/layers/affine_transform.h @@ -301,20 +301,40 @@ namespace Eval::NNUE::Layers { } else if constexpr (kOutputDimensions == 1) { - constexpr IndexType kNumChunks = kPaddedInputDimensions / kSimdWidth; - - vec_t sum0 = vec_setzero(); - - const auto row0 = reinterpret_cast(&weights_[0]); - - for (int j = 0; j < (int)kNumChunks; ++j) +#if defined (USE_AVX512) + if constexpr (kPaddedInputDimensions % (kSimdWidth * 2) != 0) { - const vec_t in = input_vector[j]; - - vec_add_dpbusd_32(sum0, in, row0[j]); + constexpr IndexType kNumChunks = kPaddedInputDimensions / kSimdWidth; + const auto input_vector256 = reinterpret_cast(input); + + __m256i sum0 = _mm256_setzero_si256(); + const auto row0 = reinterpret_cast(&weights_[0]); + + for (int j = 0; j < (int)kNumChunks; ++j) + { + const __m256i in = input_vector256[j]; + m256_add_dpbusd_epi32(sum0, in, row0[j]); + } + output[0] = m256_hadd(sum0, biases_[0]); + } + else +#endif + { +#if defined (USE_AVX512) + constexpr IndexType kNumChunks = kPaddedInputDimensions / (kSimdWidth * 2); +#else + constexpr IndexType kNumChunks = kPaddedInputDimensions / kSimdWidth; +#endif + vec_t sum0 = vec_setzero(); + const auto row0 = reinterpret_cast(&weights_[0]); + + for (int j = 0; j < (int)kNumChunks; ++j) + { + const vec_t in = input_vector[j]; + vec_add_dpbusd_32(sum0, in, row0[j]); + } + output[0] = vec_hadd(sum0, biases_[0]); } - - output[0] = vec_hadd(sum0, biases_[0]); } #else