X-Git-Url: https://git.sesse.net/?a=blobdiff_plain;f=src%2Fnnue%2Flayers%2Faffine_transform.h;h=11038d69b1c7ff40b948bace675266d73af7b12d;hb=4766dfc3956f78d853c5e0c4636d6f90fd93df9a;hp=b28712780b2684868bc2c2937628e4112b72b69c;hpb=b82d93ece484f833c994b40d9eddd959ba20ef92;p=stockfish diff --git a/src/nnue/layers/affine_transform.h b/src/nnue/layers/affine_transform.h index b2871278..11038d69 100644 --- a/src/nnue/layers/affine_transform.h +++ b/src/nnue/layers/affine_transform.h @@ -75,8 +75,7 @@ namespace Stockfish::Eval::NNUE::Layers { const auto inputVector = reinterpret_cast(input); # elif defined(USE_NEON) - static_assert(PaddedInputDimensions % 16 == 0); - constexpr IndexType NumChunks = PaddedInputDimensions / 16; + constexpr IndexType NumChunks = (InputDimensions + 15) / 16; const auto inputVector = reinterpret_cast(input); # endif @@ -181,6 +180,9 @@ namespace Stockfish::Eval::NNUE::Layers { #elif defined (USE_SSSE3) static constexpr const IndexType InputSimdWidth = 16; static constexpr const IndexType MaxNumOutputRegs = 8; +#elif defined (USE_NEON) + static constexpr const IndexType InputSimdWidth = 8; + static constexpr const IndexType MaxNumOutputRegs = 8; #else // The fallback implementation will not have permuted weights. // We define these to avoid a lot of ifdefs later. @@ -270,52 +272,64 @@ namespace Stockfish::Eval::NNUE::Layers { OutputType* output = reinterpret_cast(buffer); #if defined (USE_AVX512) - using vec_t = __m512i; - #define vec_setzero _mm512_setzero_si512 - #define vec_set_32 _mm512_set1_epi32 - #define vec_add_dpbusd_32 Simd::m512_add_dpbusd_epi32 + using acc_vec_t = __m512i; + using bias_vec_t = __m128i; + using weight_vec_t = __m512i; + using in_vec_t = __m512i; + #define vec_zero _mm512_setzero_si512() #define vec_add_dpbusd_32x2 Simd::m512_add_dpbusd_epi32x2 #define vec_hadd Simd::m512_hadd #define vec_haddx4 Simd::m512_haddx4 #elif defined (USE_AVX2) - using vec_t = __m256i; - #define vec_setzero _mm256_setzero_si256 - #define vec_set_32 _mm256_set1_epi32 - #define vec_add_dpbusd_32 Simd::m256_add_dpbusd_epi32 + using acc_vec_t = __m256i; + using bias_vec_t = __m128i; + using weight_vec_t = __m256i; + using in_vec_t = __m256i; + #define vec_zero _mm256_setzero_si256() #define vec_add_dpbusd_32x2 Simd::m256_add_dpbusd_epi32x2 #define vec_hadd Simd::m256_hadd #define vec_haddx4 Simd::m256_haddx4 #elif defined (USE_SSSE3) - using vec_t = __m128i; - #define vec_setzero _mm_setzero_si128 - #define vec_set_32 _mm_set1_epi32 - #define vec_add_dpbusd_32 Simd::m128_add_dpbusd_epi32 + using acc_vec_t = __m128i; + using bias_vec_t = __m128i; + using weight_vec_t = __m128i; + using in_vec_t = __m128i; + #define vec_zero _mm_setzero_si128() #define vec_add_dpbusd_32x2 Simd::m128_add_dpbusd_epi32x2 #define vec_hadd Simd::m128_hadd #define vec_haddx4 Simd::m128_haddx4 +#elif defined (USE_NEON) + using acc_vec_t = int32x4_t; + using bias_vec_t = int32x4_t; + using weight_vec_t = int8x8_t; + using in_vec_t = int8x8_t; + #define vec_zero {0} + #define vec_add_dpbusd_32x2 Simd::neon_m128_add_dpbusd_epi32x2 + #define vec_hadd Simd::neon_m128_hadd + #define vec_haddx4 Simd::neon_m128_haddx4 #endif -#if defined (USE_SSSE3) - const vec_t* invec = reinterpret_cast(input); +#if defined (USE_SSSE3) || defined (USE_NEON) + const in_vec_t* invec = reinterpret_cast(input); // Perform accumulation to registers for each big block for (IndexType bigBlock = 0; bigBlock < NumBigBlocks; ++bigBlock) { - vec_t acc[NumOutputRegs] = { vec_setzero() }; + acc_vec_t acc[NumOutputRegs] = { vec_zero }; // Each big block has NumOutputRegs small blocks in each "row", one per register. // We process two small blocks at a time to save on one addition without VNNI. for (IndexType smallBlock = 0; smallBlock < NumSmallBlocksPerOutput; smallBlock += 2) { - const vec_t* weightvec = - reinterpret_cast( + const weight_vec_t* weightvec = + reinterpret_cast( weights + bigBlock * BigBlockSize + smallBlock * SmallBlockSize * NumOutputRegs); - const vec_t in0 = invec[smallBlock + 0]; - const vec_t in1 = invec[smallBlock + 1]; + const in_vec_t in0 = invec[smallBlock + 0]; + const in_vec_t in1 = invec[smallBlock + 1]; for (IndexType k = 0; k < NumOutputRegs; ++k) vec_add_dpbusd_32x2(acc[k], in0, weightvec[k], in1, weightvec[k + NumOutputRegs]); @@ -324,8 +338,8 @@ namespace Stockfish::Eval::NNUE::Layers { // Horizontally add all accumulators. if constexpr (NumOutputRegs % 4 == 0) { - __m128i* outputvec = reinterpret_cast<__m128i*>(output); - const __m128i* biasvec = reinterpret_cast(biases); + bias_vec_t* outputvec = reinterpret_cast(output); + const bias_vec_t* biasvec = reinterpret_cast(biases); for (IndexType k = 0; k < NumOutputRegs; k += 4) { @@ -343,9 +357,7 @@ namespace Stockfish::Eval::NNUE::Layers { } } -# undef vec_setzero -# undef vec_set_32 -# undef vec_add_dpbusd_32 +# undef vec_zero # undef vec_add_dpbusd_32x2 # undef vec_hadd # undef vec_haddx4