X-Git-Url: https://git.sesse.net/?a=blobdiff_plain;f=src%2Fnnue%2Flayers%2Faffine_transform.h;h=fc65c34339ff35bb44f69dee3a85285de69eb51b;hb=8a912951de6d4bff78d3ff5258213a0c7e6f494e;hp=c936a83ed66278057934bde9f189b9938a07f49c;hpb=cb22520a9c7e1d716a56f08390aa638a23a597b2;p=stockfish

diff --git a/src/nnue/layers/affine_transform.h b/src/nnue/layers/affine_transform.h
index c936a83e..fc65c343 100644
--- a/src/nnue/layers/affine_transform.h
+++ b/src/nnue/layers/affine_transform.h
@@ -21,9 +21,9 @@
 #ifndef NNUE_LAYERS_AFFINE_TRANSFORM_H_INCLUDED
 #define NNUE_LAYERS_AFFINE_TRANSFORM_H_INCLUDED
 
+#include <cstring>
 #include <iostream>
-#include <algorithm>
-#include <type_traits>
+
 #include "../nnue_common.h"
 #include "simd.h"
 
@@ -45,17 +45,13 @@ namespace Stockfish::Eval::NNUE::Layers {
   template <IndexType InputDimensions, IndexType PaddedInputDimensions, IndexType OutputDimensions>
   static void affine_transform_non_ssse3(std::int32_t* output, const std::int8_t* weights, const std::int32_t* biases, const std::uint8_t* input)
   {
+# if defined(USE_SSE2) || defined(USE_NEON_DOTPROD) || defined(USE_NEON)
 # if defined(USE_SSE2)
     // At least a multiple of 16, with SSE2.
     constexpr IndexType NumChunks = ceil_to_multiple<IndexType>(InputDimensions, 16) / 16;
     const __m128i Zeros = _mm_setzero_si128();
     const auto inputVector = reinterpret_cast<const __m128i*>(input);
 
-# elif defined(USE_MMX)
-    constexpr IndexType NumChunks = ceil_to_multiple<IndexType>(InputDimensions, 8) / 8;
-    const __m64 Zeros = _mm_setzero_si64();
-    const auto inputVector = reinterpret_cast<const __m64*>(input);
-
 # elif defined(USE_NEON_DOTPROD)
     constexpr IndexType NumChunks = ceil_to_multiple<IndexType>(InputDimensions, 16) / 16;
     const auto inputVector = reinterpret_cast<const int8x16_t*>(input);
@@ -91,26 +87,6 @@ namespace Stockfish::Eval::NNUE::Layers {
       sum = _mm_add_epi32(sum, sum_second_32);
       output[i] = _mm_cvtsi128_si32(sum);
 
-# elif defined(USE_MMX)
-      __m64 sumLo = _mm_cvtsi32_si64(biases[i]);
-      __m64 sumHi = Zeros;
-      const auto row = reinterpret_cast<const __m64*>(&weights[offset]);
-      for (IndexType j = 0; j < NumChunks; ++j) {
-        __m64 row_j = row[j];
-        __m64 input_j = inputVector[j];
-        __m64 extendedRowLo = _mm_srai_pi16(_mm_unpacklo_pi8(row_j, row_j), 8);
-        __m64 extendedRowHi = _mm_srai_pi16(_mm_unpackhi_pi8(row_j, row_j), 8);
-        __m64 extendedInputLo = _mm_unpacklo_pi8(input_j, Zeros);
-        __m64 extendedInputHi = _mm_unpackhi_pi8(input_j, Zeros);
-        __m64 productLo = _mm_madd_pi16(extendedRowLo, extendedInputLo);
-        __m64 productHi = _mm_madd_pi16(extendedRowHi, extendedInputHi);
-        sumLo = _mm_add_pi32(sumLo, productLo);
-        sumHi = _mm_add_pi32(sumHi, productHi);
-      }
-      __m64 sum = _mm_add_pi32(sumLo, sumHi);
-      sum = _mm_add_pi32(sum, _mm_unpackhi_pi32(sum, sum));
-      output[i] = _mm_cvtsi64_si32(sum);
-
 # elif defined(USE_NEON_DOTPROD)
       int32x4_t sum = {biases[i]};
       const auto row = reinterpret_cast<const int8x16_t*>(&weights[offset]);
@@ -129,17 +105,19 @@ namespace Stockfish::Eval::NNUE::Layers {
       }
       output[i] = sum[0] + sum[1] + sum[2] + sum[3];
 
-# else
-      std::int32_t sum = biases[i];
-      for (IndexType j = 0; j < InputDimensions; ++j) {
-        sum += weights[offset + j] * input[j];
-      }
-      output[i] = sum;
 # endif
     }
-
-# if defined(USE_MMX)
-    _mm_empty();
+# else
+    std::memcpy(output, biases, sizeof(std::int32_t) * OutputDimensions);
+
+    // Traverse weights in transpose order to take advantage of input sparsity
+    for (IndexType i = 0; i < InputDimensions; ++i)
+      if (input[i]) {
+        const std::int8_t* w = &weights[i];
+        const int in = input[i];
+        for (IndexType j = 0; j < OutputDimensions; ++j)
+          output[j] += w[j * PaddedInputDimensions] * in;
+      }
 # endif
   }
 #endif
@@ -210,6 +188,11 @@ namespace Stockfish::Eval::NNUE::Layers {
     void propagate(
       const InputType* input, OutputType* output) const {
 
+#if defined (USE_SSSE3)
+
+      if constexpr (OutputDimensions > 1)
+      {
+
 #if defined (USE_AVX512)
       using vec_t = __m512i;
       #define vec_setzero _mm512_setzero_si512
       #define vec_set_32 _mm512_set1_epi32
       #define vec_add_dpbusd_32 Simd::m512_add_dpbusd_epi32
       #define vec_add_dpbusd_32x2 Simd::m512_add_dpbusd_epi32x2
       #define vec_hadd Simd::m512_hadd
 #elif defined (USE_AVX2)
       using vec_t = __m256i;
       #define vec_setzero _mm256_setzero_si256
       #define vec_set_32 _mm256_set1_epi32
       #define vec_add_dpbusd_32 Simd::m256_add_dpbusd_epi32
       #define vec_add_dpbusd_32x2 Simd::m256_add_dpbusd_epi32x2
       #define vec_hadd Simd::m256_hadd
 #elif defined (USE_SSSE3)
       using vec_t = __m128i;
       #define vec_setzero _mm_setzero_si128
       #define vec_set_32 _mm_set1_epi32
       #define vec_add_dpbusd_32 Simd::m128_add_dpbusd_epi32
       #define vec_add_dpbusd_32x2 Simd::m128_add_dpbusd_epi32x2
@@ -233,15 +216,10 @@ namespace Stockfish::Eval::NNUE::Layers {
       #define vec_hadd Simd::m128_hadd
 #endif
 
-#if defined (USE_SSSE3)
-      const auto inputVector = reinterpret_cast<const vec_t*>(input);
-
-      static constexpr IndexType OutputSimdWidth = sizeof(vec_t) / sizeof(OutputType);
+          static constexpr IndexType OutputSimdWidth = sizeof(vec_t) / sizeof(OutputType);
 
-      static_assert(OutputDimensions % OutputSimdWidth == 0 || OutputDimensions == 1);
+          static_assert(OutputDimensions % OutputSimdWidth == 0);
 
-      if constexpr (OutputDimensions % OutputSimdWidth == 0)
-      {
           constexpr IndexType NumChunks = ceil_to_multiple<IndexType>(InputDimensions, 8) / 4;
           constexpr IndexType NumRegs = OutputDimensions / OutputSimdWidth;
 
@@ -264,26 +242,58 @@ namespace Stockfish::Eval::NNUE::Layers {
           vec_t* outptr = reinterpret_cast<vec_t*>(output);
           for (IndexType k = 0; k < NumRegs; ++k)
             outptr[k] = acc[k];
+
+# undef vec_setzero
+# undef vec_set_32
+# undef vec_add_dpbusd_32
+# undef vec_add_dpbusd_32x2
+# undef vec_hadd
+
       }
       else if constexpr (OutputDimensions == 1)
       {
-          constexpr IndexType NumChunks = PaddedInputDimensions / SimdWidth;
+
+// We cannot use AVX512 for the last layer because there's only 32 inputs and the buffer is not padded to 64 elements.
+#if defined (USE_AVX2)
+          using vec_t = __m256i;
+          #define vec_setzero _mm256_setzero_si256
+          #define vec_set_32 _mm256_set1_epi32
+          #define vec_add_dpbusd_32 Simd::m256_add_dpbusd_epi32
+          #define vec_add_dpbusd_32x2 Simd::m256_add_dpbusd_epi32x2
+          #define vec_hadd Simd::m256_hadd
+#elif defined (USE_SSSE3)
+          using vec_t = __m128i;
+          #define vec_setzero _mm_setzero_si128
+          #define vec_set_32 _mm_set1_epi32
+          #define vec_add_dpbusd_32 Simd::m128_add_dpbusd_epi32
+          #define vec_add_dpbusd_32x2 Simd::m128_add_dpbusd_epi32x2
+          #define vec_hadd Simd::m128_hadd
+#endif
+
+          const auto inputVector = reinterpret_cast<const vec_t*>(input);
+
+          static constexpr IndexType InputSimdWidth = sizeof(vec_t) / sizeof(InputType);
+
+          static_assert(PaddedInputDimensions % InputSimdWidth == 0);
+
+          constexpr IndexType NumChunks = PaddedInputDimensions / InputSimdWidth;
 
           vec_t sum0 = vec_setzero();
           const auto row0 = reinterpret_cast<const vec_t*>(&weights[0]);
 
-          for (int j = 0; j < (int)NumChunks; ++j)
+          for (int j = 0; j < int(NumChunks); ++j)
           {
             const vec_t in = inputVector[j];
             vec_add_dpbusd_32(sum0, in, row0[j]);
           }
           output[0] = vec_hadd(sum0, biases[0]);
-      }
 # undef vec_setzero
 # undef vec_set_32
 # undef vec_add_dpbusd_32
 # undef vec_add_dpbusd_32x2
 # undef vec_hadd
+
+      }
 #else
       // Use old implementation for the other architectures.
       affine_transform_non_ssse3<
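
A note on the new fallback in affine_transform_non_ssse3: when no SIMD path is compiled in, the patch replaces the old per-output scalar loop with a sparsity-aware one. It first copies the biases into the output, then walks the weight matrix in transpose (column-major) order so that a zero activation skips an entire column of weight reads; per the comment in the patch, the inputs are sparse, which is what makes this ordering pay off. Below is a minimal self-contained sketch of that loop; the fixed dimensions and the main() driver are illustrative assumptions for this note, not code from the patch.

#include <cstdint>
#include <cstring>
#include <iostream>

// Model of the patch's non-SIMD fallback: copy biases, then accumulate
// only the columns whose input activation is nonzero.
constexpr int InputDimensions       = 32;
constexpr int PaddedInputDimensions = 32;   // row stride of the weight matrix
constexpr int OutputDimensions      = 4;

void affine_fallback(std::int32_t* output, const std::int8_t* weights,
                     const std::int32_t* biases, const std::uint8_t* input) {
  std::memcpy(output, biases, sizeof(std::int32_t) * OutputDimensions);

  // Traverse weights in transpose order to take advantage of input sparsity.
  for (int i = 0; i < InputDimensions; ++i)
    if (input[i]) {
      const std::int8_t* w = &weights[i];
      const int in = input[i];
      for (int j = 0; j < OutputDimensions; ++j)
        output[j] += w[j * PaddedInputDimensions] * in;
    }
}

int main() {
  std::int8_t  weights[OutputDimensions * PaddedInputDimensions] = {};
  std::int32_t biases[OutputDimensions] = {10, 20, 30, 40};
  std::uint8_t input[InputDimensions]   = {};
  input[3] = 2;                                  // single nonzero activation
  for (int j = 0; j < OutputDimensions; ++j)     // column 3 of each row
    weights[j * PaddedInputDimensions + 3] = static_cast<std::int8_t>(j + 1);

  std::int32_t output[OutputDimensions];
  affine_fallback(output, weights, biases, input);
  for (std::int32_t v : output)
    std::cout << v << ' ';                       // prints: 12 24 36 48
  std::cout << '\n';
}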
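
For reference, the vec_add_dpbusd_32 wrappers used throughout the SSSE3 path accumulate an unsigned-8-bit by signed-8-bit dot product, four byte pairs per 32-bit lane (VPDPBUSD-style semantics; without VNNI, the wrappers in simd.h build this out of multiply-add intrinsics, which can saturate intermediates at extreme magnitudes). The same byte accounting explains the comment capping the OutputDimensions == 1 branch at AVX2: 32 one-byte inputs fill a 32-byte __m256i exactly, but only half of a 64-byte __m512i. A scalar model of one lane, as an illustration only:

#include <cstdint>

// One 32-bit lane of a dpbusd-style accumulate:
// acc += sum over four byte pairs of (unsigned input byte) * (signed weight byte).
std::int32_t add_dpbusd_lane(std::int32_t acc,
                             const std::uint8_t in[4],
                             const std::int8_t row[4]) {
  for (int k = 0; k < 4; ++k)
    acc += static_cast<std::int32_t>(in[k]) * static_cast<std::int32_t>(row[k]);
  return acc;
}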