X-Git-Url: https://git.sesse.net/?a=blobdiff_plain;f=src%2Fnnue%2Fnnue_feature_transformer.h;h=85bc2bc8e66cb5014cfb779e7f25e3482d1aa8a9;hb=be7a03a957d5c2590a329f8f47acea8af2305adf;hp=f145c848f96fb82349be3d78530f76b97918ab04;hpb=2046d5da30b2cd505b69bddb40062b0d37b43bc7;p=stockfish

diff --git a/src/nnue/nnue_feature_transformer.h b/src/nnue/nnue_feature_transformer.h
index f145c848..85bc2bc8 100644
--- a/src/nnue/nnue_feature_transformer.h
+++ b/src/nnue/nnue_feature_transformer.h
@@ -36,16 +36,16 @@ namespace Eval::NNUE {
 
   #ifdef USE_AVX512
   typedef __m512i vec_t;
-  #define vec_load(a) _mm512_loadA_si512(a)
-  #define vec_store(a,b) _mm512_storeA_si512(a,b)
+  #define vec_load(a) _mm512_load_si512(a)
+  #define vec_store(a,b) _mm512_store_si512(a,b)
   #define vec_add_16(a,b) _mm512_add_epi16(a,b)
   #define vec_sub_16(a,b) _mm512_sub_epi16(a,b)
   static constexpr IndexType kNumRegs = 8; // only 8 are needed
 
   #elif USE_AVX2
   typedef __m256i vec_t;
-  #define vec_load(a) _mm256_loadA_si256(a)
-  #define vec_store(a,b) _mm256_storeA_si256(a,b)
+  #define vec_load(a) _mm256_load_si256(a)
+  #define vec_store(a,b) _mm256_store_si256(a,b)
   #define vec_add_16(a,b) _mm256_add_epi16(a,b)
   #define vec_sub_16(a,b) _mm256_sub_epi16(a,b)
   static constexpr IndexType kNumRegs = 16;
@@ -127,7 +127,13 @@ namespace Eval::NNUE {
 
       const auto& accumulation = pos.state()->accumulator.accumulation;
 
-  #if defined(USE_AVX2)
+  #if defined(USE_AVX512)
+      constexpr IndexType kNumChunks = kHalfDimensions / (kSimdWidth * 2);
+      static_assert(kHalfDimensions % (kSimdWidth * 2) == 0);
+      const __m512i kControl = _mm512_setr_epi64(0, 2, 4, 6, 1, 3, 5, 7);
+      const __m512i kZero = _mm512_setzero_si512();
+
+  #elif defined(USE_AVX2)
       constexpr IndexType kNumChunks = kHalfDimensions / kSimdWidth;
       constexpr int kControl = 0b11011000;
       const __m256i kZero = _mm256_setzero_si256();
@@ -154,14 +160,25 @@ namespace Eval::NNUE {
       for (IndexType p = 0; p < 2; ++p) {
         const IndexType offset = kHalfDimensions * p;
 
-  #if defined(USE_AVX2)
+  #if defined(USE_AVX512)
+        auto out = reinterpret_cast<__m512i*>(&output[offset]);
+        for (IndexType j = 0; j < kNumChunks; ++j) {
+          __m512i sum0 = _mm512_load_si512(
+              &reinterpret_cast<const __m512i*>(accumulation[perspectives[p]][0])[j * 2 + 0]);
+          __m512i sum1 = _mm512_load_si512(
+              &reinterpret_cast<const __m512i*>(accumulation[perspectives[p]][0])[j * 2 + 1]);
+          _mm512_store_si512(&out[j], _mm512_permutexvar_epi64(kControl,
+              _mm512_max_epi8(_mm512_packs_epi16(sum0, sum1), kZero)));
+        }
+
+  #elif defined(USE_AVX2)
         auto out = reinterpret_cast<__m256i*>(&output[offset]);
         for (IndexType j = 0; j < kNumChunks; ++j) {
-          __m256i sum0 = _mm256_loadA_si256(
+          __m256i sum0 = _mm256_load_si256(
              &reinterpret_cast<const __m256i*>(accumulation[perspectives[p]][0])[j * 2 + 0]);
-          __m256i sum1 = _mm256_loadA_si256(
-             &reinterpret_cast<const __m256i*>(accumulation[perspectives[p]][0])[j * 2 + 1]);
-          _mm256_storeA_si256(&out[j], _mm256_permute4x64_epi64(_mm256_max_epi8(
+          __m256i sum1 = _mm256_load_si256(
+             &reinterpret_cast<const __m256i*>(accumulation[perspectives[p]][0])[j * 2 + 1]);
+          _mm256_store_si256(&out[j], _mm256_permute4x64_epi64(_mm256_max_epi8(
              _mm256_packs_epi16(sum0, sum1), kZero), kControl));
         }
 
@@ -177,9 +194,9 @@ namespace Eval::NNUE {
           _mm_store_si128(&out[j],
 
   #ifdef USE_SSE41
-          _mm_max_epi8(packedbytes, kZero)
+            _mm_max_epi8(packedbytes, kZero)
   #else
-          _mm_subs_epi8(_mm_adds_epi8(packedbytes, k0x80s), k0x80s)
+            _mm_subs_epi8(_mm_adds_epi8(packedbytes, k0x80s), k0x80s)
   #endif
 
           );
@@ -230,7 +247,7 @@ namespace Eval::NNUE {
       // Look for a usable accumulator of an earlier position. We keep track
       // of the estimated gain in terms of features to be added/subtracted.
      StateInfo *st = pos.state(), *next = nullptr;
-      int gain = popcount(pos.pieces()) - 2;
+      int gain = pos.count<ALL_PIECES>() - 2;
       while (st->accumulator.state[c] == EMPTY) {
        auto& dp = st->dirtyPiece;
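
A note on the new AVX-512 output path, not part of the patch: _mm512_packs_epi16 saturates and packs within each 128-bit lane, so the bytes of sum0 and sum1 come out interleaved lane by lane, and _mm512_permutexvar_epi64 with the qword index pattern 0, 2, 4, 6, 1, 3, 5, 7 restores contiguous order (all of sum0's clamped bytes, then all of sum1's). The scalar sketch below models only that reordering (the max-with-zero clamp is left out); model_pack_permute and the test values are illustrative, not code from the repository. It compiles with any C++17 compiler and needs no AVX-512 hardware.

#include <algorithm>
#include <cassert>
#include <cstdint>

// Scalar model of _mm512_packs_epi16 followed by
// _mm512_permutexvar_epi64(_mm512_setr_epi64(0,2,4,6,1,3,5,7), ...).
// a and b each hold 32 int16 accumulator sums; out receives 64 bytes.
void model_pack_permute(const int16_t a[32], const int16_t b[32], int8_t out[64]) {
    auto sat = [](int16_t v) {
        return static_cast<int8_t>(std::clamp<int16_t>(v, -128, 127));
    };
    int8_t packed[64];
    // packs_epi16: per 128-bit lane, 8 saturated bytes from a, then 8 from b.
    for (int lane = 0; lane < 4; ++lane)
        for (int i = 0; i < 8; ++i) {
            packed[lane * 16 + i]     = sat(a[lane * 8 + i]);
            packed[lane * 16 + 8 + i] = sat(b[lane * 8 + i]);
        }
    // permutexvar_epi64 with indices {0,2,4,6,1,3,5,7}: result qword i is
    // qword idx[i] of the packed vector, which de-interleaves the lanes.
    const int idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
    for (int q = 0; q < 8; ++q)
        for (int i = 0; i < 8; ++i)
            out[q * 8 + i] = packed[idx[q] * 8 + i];
}

int main() {
    int16_t a[32], b[32];
    for (int i = 0; i < 32; ++i) { a[i] = int16_t(i); b[i] = int16_t(200 + i); }  // 200+ saturates to 127
    int8_t out[64];
    model_pack_permute(a, b, out);
    // Contiguous order: a's 32 saturated values first, then b's 32.
    for (int i = 0; i < 32; ++i) { assert(out[i] == int8_t(i)); assert(out[32 + i] == 127); }
}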