From: Tomasz Sobczyk
Date: Tue, 3 Nov 2020 21:49:10 +0000 (+0100)
Subject: AVX-512 for smaller affine and feature transforms.
X-Git-Url: https://git.sesse.net/?p=stockfish;a=commitdiff_plain;h=ba35c88ab84b959d41a67b3d8fcb40adc6537ec8

AVX-512 for smaller affine and feature transforms.

For the feature transformer the code is analogous to AVX2, since it adapts
easily to the wider SIMD registers.

For the smaller affine transforms with a 32-byte stride we keep two columns
in one zmm register. We also unroll more aggressively, so that in the end we
perform 16 parallel horizontal additions on ymm slices, each consisting of
four 32-bit integers. The slices are embedded in 8 zmm registers.

These changes provide about a 1.5% speedup for AVX-512 builds.

Closes https://github.com/official-stockfish/Stockfish/pull/3218

No functional change.
---

diff --git a/src/nnue/layers/affine_transform.h b/src/nnue/layers/affine_transform.h
index f0292e45..47c9c488 100644
--- a/src/nnue/layers/affine_transform.h
+++ b/src/nnue/layers/affine_transform.h
@@ -83,7 +83,21 @@ namespace Eval::NNUE::Layers {
         return _mm512_reduce_add_epi32(sum) + bias;
       };
 
-      [[maybe_unused]] auto m512_haddx4 = [](__m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3, __m128i bias) -> __m128i {
+      // This function takes
+      //   sum0 = [xmm0a, xmm0b, xmm0c, xmm0d]
+      //   sum1 = [xmm1a, xmm1b, xmm1c, xmm1d]
+      //   sum2 = [xmm2a, xmm2b, xmm2c, xmm2d]
+      //   sum3 = [xmm3a, xmm3b, xmm3c, xmm3d]
+      // and returns
+      //   ret = [
+      //     reduce_add_epi32(xmm0a), reduce_add_epi32(xmm1a), reduce_add_epi32(xmm2a), reduce_add_epi32(xmm3a),
+      //     reduce_add_epi32(xmm0b), reduce_add_epi32(xmm1b), reduce_add_epi32(xmm2b), reduce_add_epi32(xmm3b),
+      //     reduce_add_epi32(xmm0c), reduce_add_epi32(xmm1c), reduce_add_epi32(xmm2c), reduce_add_epi32(xmm3c),
+      //     reduce_add_epi32(xmm0d), reduce_add_epi32(xmm1d), reduce_add_epi32(xmm2d), reduce_add_epi32(xmm3d)
+      //   ]
+      [[maybe_unused]] auto m512_hadd128x16_interleave = [](
+          __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3) -> __m512i {
+
         __m512i sum01a = _mm512_unpacklo_epi32(sum0, sum1);
         __m512i sum01b = _mm512_unpackhi_epi32(sum0, sum1);
 
@@ -96,7 +110,13 @@ namespace Eval::NNUE::Layers {
         __m512i sum0123a = _mm512_unpacklo_epi64(sum01, sum23);
         __m512i sum0123b = _mm512_unpackhi_epi64(sum01, sum23);
 
-        __m512i sum = _mm512_add_epi32(sum0123a, sum0123b);
+        return _mm512_add_epi32(sum0123a, sum0123b);
+      };
+
+      [[maybe_unused]] auto m512_haddx4 = [m512_hadd128x16_interleave](
+          __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3, __m128i bias) -> __m128i {
+
+        __m512i sum = m512_hadd128x16_interleave(sum0, sum1, sum2, sum3);
 
         __m256i sum256lo = _mm512_castsi512_si256(sum);
         __m256i sum256hi = _mm512_extracti64x4_epi64(sum, 1);
@@ -109,6 +129,58 @@ namespace Eval::NNUE::Layers {
         return _mm_add_epi32(_mm_add_epi32(sum128lo, sum128hi), bias);
       };
 
+      [[maybe_unused]] auto m512_haddx8 = [m512_hadd128x16_interleave](
+          __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3,
+          __m512i sum4, __m512i sum5, __m512i sum6, __m512i sum7, __m256i bias) -> __m256i {
+
+        __m512i suma = m512_hadd128x16_interleave(sum0, sum1, sum2, sum3);
+        __m512i sumb = m512_hadd128x16_interleave(sum4, sum5, sum6, sum7);
+
+        __m512i indices0 = _mm512_setr_epi64(0, 1, 8, 9, 4, 5, 12, 13);
+        __m512i indices1 = _mm512_setr_epi64(2, 3, 10, 11, 6, 7, 14, 15);
+        __m512i x = _mm512_add_epi32(
+            _mm512_permutex2var_epi64(suma, indices0, sumb),
+            _mm512_permutex2var_epi64(suma, indices1, sumb));
+
+        __m256i sum256lo = _mm512_castsi512_si256(x);
+        __m256i sum256hi = _mm512_extracti64x4_epi64(x, 1);
+
+        return _mm256_add_epi32(_mm256_add_epi32(sum256lo, sum256hi), bias);
+      };
+
+      [[maybe_unused]] auto m512_hadd256x8 =[m512_hadd128x16_interleave](
+          __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3, __m256i bias) -> __m256i {
+
+        __m512i sum = m512_hadd128x16_interleave(sum0, sum1, sum2, sum3);
+
+        __m512i indices = _mm512_setr_epi32(
+            0, 4, 8, 12, 2, 6, 10, 14,
+            1, 5, 9, 13, 3, 7, 11, 15);
+        sum = _mm512_permutexvar_epi32(indices, sum);
+
+        __m256i sum256lo = _mm512_castsi512_si256(sum);
+        __m256i sum256hi = _mm512_extracti64x4_epi64(sum, 1);
+
+        return _mm256_add_epi32(_mm256_hadd_epi32(sum256lo, sum256hi), bias);
+      };
+
+      [[maybe_unused]] auto m512_hadd256x16 = [m512_hadd128x16_interleave](
+          __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3,
+          __m512i sum4, __m512i sum5, __m512i sum6, __m512i sum7, __m512i bias) -> __m512i {
+
+        __m512i suma = m512_hadd128x16_interleave(sum0, sum1, sum2, sum3);
+        __m512i sumb = m512_hadd128x16_interleave(sum4, sum5, sum6, sum7);
+
+        __m512i indices0 = _mm512_setr_epi64(0, 1, 8, 9, 4, 5, 12, 13);
+        __m512i indices1 = _mm512_setr_epi64(2, 3, 10, 11, 6, 7, 14, 15);
+        __m512i x = _mm512_add_epi32(
+            _mm512_permutex2var_epi64(suma, indices0, sumb),
+            _mm512_permutex2var_epi64(suma, indices1, sumb));
+
+        __m512i indices = _mm512_setr_epi32(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15);
+        return _mm512_add_epi32(_mm512_permutexvar_epi32(indices, x), bias);
+      };
+
       [[maybe_unused]] auto m512_add_dpbusd_epi32 = [=](__m512i& acc, __m512i a, __m512i b) {
 #if defined (USE_VNNI)
         acc = _mm512_dpbusd_epi32(acc, a, b);
@@ -205,7 +277,58 @@ namespace Eval::NNUE::Layers {
 
       // kOutputDimensions is either 1 or a multiple of kSimdWidth
       // because then it is also an input dimension.
-      if constexpr (kOutputDimensions % 4 == 0)
+      if constexpr (kOutputDimensions % 16 == 0 && kNumChunks256 == 1)
+      {
+        for (IndexType i = 0; i < kOutputDimensions; i += 16)
+        {
+          const IndexType offset01a = (i + 0) * kPaddedInputDimensions;
+          const IndexType offset23a = (i + 2) * kPaddedInputDimensions;
+          const IndexType offset45a = (i + 4) * kPaddedInputDimensions;
+          const IndexType offset67a = (i + 6) * kPaddedInputDimensions;
+          const IndexType offset01b = (i + 8) * kPaddedInputDimensions;
+          const IndexType offset23b = (i + 10) * kPaddedInputDimensions;
+          const IndexType offset45b = (i + 12) * kPaddedInputDimensions;
+          const IndexType offset67b = (i + 14) * kPaddedInputDimensions;
+
+          const __m512i bias = *reinterpret_cast<const __m512i*>(&biases_[i]);
+          __m512i* outptr = reinterpret_cast<__m512i*>(&output[i]);
+
+          __m512i sum01a = _mm512_setzero_si512();
+          __m512i sum23a = _mm512_setzero_si512();
+          __m512i sum45a = _mm512_setzero_si512();
+          __m512i sum67a = _mm512_setzero_si512();
+          __m512i sum01b = _mm512_setzero_si512();
+          __m512i sum23b = _mm512_setzero_si512();
+          __m512i sum45b = _mm512_setzero_si512();
+          __m512i sum67b = _mm512_setzero_si512();
+
+          const auto row01a = *reinterpret_cast<const __m512i*>(&weights_[offset01a]);
+          const auto row23a = *reinterpret_cast<const __m512i*>(&weights_[offset23a]);
+          const auto row45a = *reinterpret_cast<const __m512i*>(&weights_[offset45a]);
+          const auto row67a = *reinterpret_cast<const __m512i*>(&weights_[offset67a]);
+          const auto row01b = *reinterpret_cast<const __m512i*>(&weights_[offset01b]);
+          const auto row23b = *reinterpret_cast<const __m512i*>(&weights_[offset23b]);
+          const auto row45b = *reinterpret_cast<const __m512i*>(&weights_[offset45b]);
+          const auto row67b = *reinterpret_cast<const __m512i*>(&weights_[offset67b]);
+
+          const __m256i in256 = input_vector256[0];
+          const __m512i in = _mm512_inserti64x4(_mm512_castsi256_si512(in256), in256, 1);
+
+          m512_add_dpbusd_epi32(sum01a, in, row01a);
+          m512_add_dpbusd_epi32(sum23a, in, row23a);
+          m512_add_dpbusd_epi32(sum45a, in, row45a);
+          m512_add_dpbusd_epi32(sum67a, in, row67a);
+          m512_add_dpbusd_epi32(sum01b, in, row01b);
+          m512_add_dpbusd_epi32(sum23b, in, row23b);
+          m512_add_dpbusd_epi32(sum45b, in, row45b);
+          m512_add_dpbusd_epi32(sum67b, in, row67b);
+
+          *outptr = m512_hadd256x16(
+              sum01a, sum23a, sum45a, sum67a,
+              sum01b, sum23b, sum45b, sum67b, bias);
+        }
+      }
+      else if constexpr (kOutputDimensions % 4 == 0)
       {
         for (IndexType i = 0; i < kOutputDimensions; i += 4)
         {
diff --git a/src/nnue/nnue_feature_transformer.h b/src/nnue/nnue_feature_transformer.h
index c3f012e4..f49777b5 100644
--- a/src/nnue/nnue_feature_transformer.h
+++ b/src/nnue/nnue_feature_transformer.h
@@ -127,7 +127,13 @@ namespace Eval::NNUE {
 
       const auto& accumulation = pos.state()->accumulator.accumulation;
 
-  #if defined(USE_AVX2)
+  #if defined(USE_AVX512)
+      constexpr IndexType kNumChunks = kHalfDimensions / (kSimdWidth * 2);
+      static_assert(kHalfDimensions % (kSimdWidth * 2) == 0);
+      const __m512i kControl = _mm512_setr_epi64(0, 2, 4, 6, 1, 3, 5, 7);
+      const __m512i kZero = _mm512_setzero_si512();
+
+  #elif defined(USE_AVX2)
       constexpr IndexType kNumChunks = kHalfDimensions / kSimdWidth;
       constexpr int kControl = 0b11011000;
       const __m256i kZero = _mm256_setzero_si256();
@@ -154,13 +160,24 @@ namespace Eval::NNUE {
     for (IndexType p = 0; p < 2; ++p) {
       const IndexType offset = kHalfDimensions * p;
 
-  #if defined(USE_AVX2)
+  #if defined(USE_AVX512)
+      auto out = reinterpret_cast<__m512i*>(&output[offset]);
+      for (IndexType j = 0; j < kNumChunks; ++j) {
+        __m512i sum0 = _mm512_load_si512(
+            &reinterpret_cast<const __m512i*>(accumulation[perspectives[p]][0])[j * 2 + 0]);
+        __m512i sum1 = _mm512_load_si512(
+            &reinterpret_cast<const __m512i*>(accumulation[perspectives[p]][0])[j * 2 + 1]);
+        _mm512_store_si512(&out[j], _mm512_permutexvar_epi64(kControl,
+            _mm512_max_epi8(_mm512_packs_epi16(sum0, sum1), kZero)));
+      }
+
+  #elif defined(USE_AVX2)
       auto out = reinterpret_cast<__m256i*>(&output[offset]);
       for (IndexType j = 0; j < kNumChunks; ++j) {
         __m256i sum0 = _mm256_load_si256(
             &reinterpret_cast<const __m256i*>(accumulation[perspectives[p]][0])[j * 2 + 0]);
         __m256i sum1 = _mm256_load_si256(
-          &reinterpret_cast<const __m256i*>(accumulation[perspectives[p]][0])[j * 2 + 1]);
+            &reinterpret_cast<const __m256i*>(accumulation[perspectives[p]][0])[j * 2 + 1]);
         _mm256_store_si256(&out[j], _mm256_permute4x64_epi64(_mm256_max_epi8(
             _mm256_packs_epi16(sum0, sum1), kZero), kControl));
       }
@@ -177,9 +194,9 @@ namespace Eval::NNUE {
           _mm_store_si128(&out[j],
 
   #ifdef USE_SSE41
-              _mm_max_epi8(packedbytes, kZero)
+            _mm_max_epi8(packedbytes, kZero)
 #else
-              _mm_subs_epi8(_mm_adds_epi8(packedbytes, k0x80s), k0x80s)
+            _mm_subs_epi8(_mm_adds_epi8(packedbytes, k0x80s), k0x80s)
 #endif
           );
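
Note on the feature transformer change above: it only widens the existing AVX2
output transform. Per output byte the semantics are still out[i] =
clamp(accumulation[i], 0, 127); _mm512_packs_epi16 does the saturating int16 ->
int8 narrowing, _mm512_max_epi8 clamps at zero, and the kControl permute undoes
the per-128-bit-lane interleaving that the pack introduces. A minimal scalar
sketch of that semantics, with a hypothetical helper name not taken from the
patch:

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>

    // Hypothetical scalar reference for the output transform: saturate each
    // int16 accumulator value to [0, 127] and store it as one output byte.
    void clip_accumulator_reference(const std::int16_t* acc, std::uint8_t* out, std::size_t n) {
        for (std::size_t i = 0; i < n; ++i)
            out[i] = static_cast<std::uint8_t>(std::clamp<int>(acc[i], 0, 127));
    }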
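
For the new kOutputDimensions % 16 == 0 path in affine_transform.h, a minimal
scalar sketch of what it computes, assuming kPaddedInputDimensions == 32 (i.e.
kNumChunks256 == 1); the function and parameter names are hypothetical, and the
sketch ignores the intermediate 16-bit saturation of the non-VNNI
m512_add_dpbusd_epi32 fallback. Each of the 16 outputs is the dot product of
the 32 input bytes with one 32-byte weight row plus its bias; the AVX-512 code
keeps two such rows per zmm register, broadcasts the input to both 256-bit
halves, and reduces the eight accumulators with m512_hadd256x16:

    #include <cstdint>

    // Hypothetical scalar reference, not part of the patch.
    void affine_32in_16out_reference(const std::uint8_t* input,    // 32 clipped activations
                                     const std::int8_t*  weights,  // 16 rows of 32 bytes each
                                     const std::int32_t* biases,   // 16 biases
                                     std::int32_t*       output) { // 16 outputs
        for (int row = 0; row < 16; ++row) {
            std::int32_t sum = biases[row];
            for (int col = 0; col < 32; ++col)
                sum += static_cast<std::int32_t>(weights[row * 32 + col]) * input[col];
            output[row] = sum;
        }
    }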