X-Git-Url: https://git.sesse.net/?p=stockfish;a=blobdiff_plain;f=src%2Fnnue%2Flayers%2Faffine_transform.h;fp=src%2Fnnue%2Flayers%2Faffine_transform.h;h=63b58af33c39777cca2d435f051d3950db3f8b42;hp=363b4916e37b40e80bbb719d9483f881929d75d2;hb=b4ad3a3c4b68f9c8736f444aeb3364f833247fdc;hpb=037ef3e18dc7f5455cc671995ae38d5b4d1fce4a diff --git a/src/nnue/layers/affine_transform.h b/src/nnue/layers/affine_transform.h index 363b4916..63b58af3 100644 --- a/src/nnue/layers/affine_transform.h +++ b/src/nnue/layers/affine_transform.h @@ -72,6 +72,10 @@ namespace Stockfish::Eval::NNUE::Layers { const __m64 Zeros = _mm_setzero_si64(); const auto inputVector = reinterpret_cast(input); +# elif defined(USE_NEON_DOTPROD) + constexpr IndexType NumChunks = ceil_to_multiple(InputDimensions, 16) / 16; + const auto inputVector = reinterpret_cast(input); + # elif defined(USE_NEON) constexpr IndexType NumChunks = ceil_to_multiple(InputDimensions, 16) / 16; const auto inputVector = reinterpret_cast(input); @@ -123,6 +127,14 @@ namespace Stockfish::Eval::NNUE::Layers { sum = _mm_add_pi32(sum, _mm_unpackhi_pi32(sum, sum)); output[i] = _mm_cvtsi64_si32(sum); +# elif defined(USE_NEON_DOTPROD) + int32x4_t sum = {biases[i]}; + const auto row = reinterpret_cast(&weights[offset]); + for (IndexType j = 0; j < NumChunks; ++j) { + sum = vdotq_s32(sum, inputVector[j], row[j]); + } + output[i] = vaddvq_s32(sum); + # elif defined(USE_NEON) int32x4_t sum = {biases[i]}; const auto row = reinterpret_cast(&weights[offset]); @@ -187,6 +199,9 @@ namespace Stockfish::Eval::NNUE::Layers { #elif defined (USE_SSSE3) static constexpr IndexType InputSimdWidth = 16; static constexpr IndexType MaxNumOutputRegs = 8; +#elif defined (USE_NEON_DOTPROD) + static constexpr IndexType InputSimdWidth = 16; + static constexpr IndexType MaxNumOutputRegs = 8; #elif defined (USE_NEON) static constexpr IndexType InputSimdWidth = 8; static constexpr IndexType MaxNumOutputRegs = 8; @@ -292,6 +307,15 @@ namespace Stockfish::Eval::NNUE::Layers { #define vec_add_dpbusd_32x2 Simd::m128_add_dpbusd_epi32x2 #define vec_hadd Simd::m128_hadd #define vec_haddx4 Simd::m128_haddx4 +#elif defined (USE_NEON_DOTPROD) + using acc_vec_t = int32x4_t; + using bias_vec_t = int32x4_t; + using weight_vec_t = int8x16_t; + using in_vec_t = int8x16_t; + #define vec_zero {0} + #define vec_add_dpbusd_32x2 Simd::dotprod_m128_add_dpbusd_epi32x2 + #define vec_hadd Simd::neon_m128_hadd + #define vec_haddx4 Simd::neon_m128_haddx4 #elif defined (USE_NEON) using acc_vec_t = int32x4_t; using bias_vec_t = int32x4_t;