const auto inputVector = reinterpret_cast<const __m64*>(input);
# elif defined(USE_NEON)
- static_assert(PaddedInputDimensions % 16 == 0);
- constexpr IndexType NumChunks = PaddedInputDimensions / 16;
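+ // Round up so that only chunks containing real inputs are processed;
+ // chunks consisting purely of padding are now skipped.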
+ constexpr IndexType NumChunks = (InputDimensions + 15) / 16;
const auto inputVector = reinterpret_cast<const int8x8_t*>(input);
# endif
#elif defined (USE_SSSE3)
static constexpr const IndexType InputSimdWidth = 16;
static constexpr const IndexType MaxNumOutputRegs = 8;
+#elif defined (USE_NEON)
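+ // int8x8_t loads process 8 input bytes at a time; the output register
+ // budget matches the SSSE3 case.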
+ static constexpr const IndexType InputSimdWidth = 8;
+ static constexpr const IndexType MaxNumOutputRegs = 8;
#else
// The fallback implementation will not have permuted weights.
// We define these to avoid a lot of ifdefs later.
OutputType* output = reinterpret_cast<OutputType*>(buffer);
#if defined (USE_AVX512)
- using vec_t = __m512i;
- #define vec_setzero _mm512_setzero_si512
- #define vec_set_32 _mm512_set1_epi32
- #define vec_add_dpbusd_32 Simd::m512_add_dpbusd_epi32
+ using acc_vec_t = __m512i;
+ using bias_vec_t = __m128i;
+ using weight_vec_t = __m512i;
+ using in_vec_t = __m512i;
+ #define vec_zero _mm512_setzero_si512()
#define vec_add_dpbusd_32x2 Simd::m512_add_dpbusd_epi32x2
#define vec_hadd Simd::m512_hadd
#define vec_haddx4 Simd::m512_haddx4
#elif defined (USE_AVX2)
- using vec_t = __m256i;
- #define vec_setzero _mm256_setzero_si256
- #define vec_set_32 _mm256_set1_epi32
- #define vec_add_dpbusd_32 Simd::m256_add_dpbusd_epi32
+ using acc_vec_t = __m256i;
+ using bias_vec_t = __m128i;
+ using weight_vec_t = __m256i;
+ using in_vec_t = __m256i;
+ #define vec_zero _mm256_setzero_si256()
#define vec_add_dpbusd_32x2 Simd::m256_add_dpbusd_epi32x2
#define vec_hadd Simd::m256_hadd
#define vec_haddx4 Simd::m256_haddx4
#elif defined (USE_SSSE3)
- using vec_t = __m128i;
- #define vec_setzero _mm_setzero_si128
- #define vec_set_32 _mm_set1_epi32
- #define vec_add_dpbusd_32 Simd::m128_add_dpbusd_epi32
+ using acc_vec_t = __m128i;
+ using bias_vec_t = __m128i;
+ using weight_vec_t = __m128i;
+ using in_vec_t = __m128i;
+ #define vec_zero _mm_setzero_si128()
#define vec_add_dpbusd_32x2 Simd::m128_add_dpbusd_epi32x2
#define vec_hadd Simd::m128_hadd
#define vec_haddx4 Simd::m128_haddx4
+#elif defined (USE_NEON)
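+ // NEON counterparts of the x86 helpers above. Inputs and weights are
+ // consumed 8 bytes at a time (int8x8_t) while accumulation stays in
+ // 32-bit lanes (int32x4_t).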
+ using acc_vec_t = int32x4_t;
+ using bias_vec_t = int32x4_t;
+ using weight_vec_t = int8x8_t;
+ using in_vec_t = int8x8_t;
+ #define vec_zero {0}
+ #define vec_add_dpbusd_32x2 Simd::neon_m128_add_dpbusd_epi32x2
+ #define vec_hadd Simd::neon_m128_hadd
+ #define vec_haddx4 Simd::neon_m128_haddx4
#endif
-#if defined (USE_SSSE3)
- const vec_t* invec = reinterpret_cast<const vec_t*>(input);
+#if defined (USE_SSSE3) || defined (USE_NEON)
+ const in_vec_t* invec = reinterpret_cast<const in_vec_t*>(input);
// Perform accumulation to registers for each big block
for (IndexType bigBlock = 0; bigBlock < NumBigBlocks; ++bigBlock)
{
- vec_t acc[NumOutputRegs] = { vec_setzero() };
+ acc_vec_t acc[NumOutputRegs] = { vec_zero };
// Each big block has NumOutputRegs small blocks in each "row", one per register.
// We process two small blocks at a time to save on one addition without VNNI.
for (IndexType smallBlock = 0; smallBlock < NumSmallBlocksPerOutput; smallBlock += 2)
{
- const vec_t* weightvec =
- reinterpret_cast<const vec_t*>(
+ const weight_vec_t* weightvec =
+ reinterpret_cast<const weight_vec_t*>(
weights
+ bigBlock * BigBlockSize
+ smallBlock * SmallBlockSize * NumOutputRegs);
- const vec_t in0 = invec[smallBlock + 0];
- const vec_t in1 = invec[smallBlock + 1];
+ const in_vec_t in0 = invec[smallBlock + 0];
+ const in_vec_t in1 = invec[smallBlock + 1];
for (IndexType k = 0; k < NumOutputRegs; ++k)
vec_add_dpbusd_32x2(acc[k], in0, weightvec[k], in1, weightvec[k + NumOutputRegs]);
}

// Horizontally add all accumulators.
if constexpr (NumOutputRegs % 4 == 0)
{
- __m128i* outputvec = reinterpret_cast<__m128i*>(output);
- const __m128i* biasvec = reinterpret_cast<const __m128i*>(biases);
+ bias_vec_t* outputvec = reinterpret_cast<bias_vec_t*>(output);
+ const bias_vec_t* biasvec = reinterpret_cast<const bias_vec_t*>(biases);
for (IndexType k = 0; k < NumOutputRegs; k += 4)
{
    const IndexType idx = (bigBlock * NumOutputRegs + k) / 4;
    outputvec[idx] = vec_haddx4(acc[k+0], acc[k+1], acc[k+2], acc[k+3], biasvec[idx]);
}
}
-# undef vec_setzero
-# undef vec_set_32
-# undef vec_add_dpbusd_32
+# undef vec_zero
# undef vec_add_dpbusd_32x2
# undef vec_hadd
# undef vec_haddx4
#endif
+#if defined (USE_NEON)
+
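+ // Sums the four int32 lanes of s. vaddvq_s32 is an AArch64 (ARMv8)
+ // horizontal add; older NEON targets fall back to scalar lane adds.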
+ [[maybe_unused]] static int neon_m128_reduce_add_epi32(int32x4_t s) {
+# if USE_NEON >= 8
+ return vaddvq_s32(s);
+# else
+ return s[0] + s[1] + s[2] + s[3];
+# endif
+ }
+
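+ // Horizontal add of a single accumulator plus bias, with the same
+ // contract as the x86 m128/m256/m512_hadd helpers.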
+ [[maybe_unused]] static int neon_m128_hadd(int32x4_t sum, int bias) {
+ return neon_m128_reduce_add_epi32(sum) + bias;
+ }
+
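+ // Reduces four accumulators into one int32x4_t and adds the biases
+ // with a single vector add, mirroring the x86 haddx4 helpers.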
+ [[maybe_unused]] static int32x4_t neon_m128_haddx4(
+ int32x4_t sum0, int32x4_t sum1, int32x4_t sum2, int32x4_t sum3,
+ int32x4_t bias) {
+
+ int32x4_t hsums {
+ neon_m128_reduce_add_epi32(sum0),
+ neon_m128_reduce_add_epi32(sum1),
+ neon_m128_reduce_add_epi32(sum2),
+ neon_m128_reduce_add_epi32(sum3)
+ };
+ return vaddq_s32(hsums, bias);
+ }
+
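+ // NEON stand-in for the SSSE3/VNNI "dpbusd x2" step: widen two int8x8
+ // lane-wise products to int16x8 (vmull_s8, then vmlal_s8 for the second
+ // pair), then pairwise-accumulate adjacent int16 lanes into the int32
+ // lanes of acc (vpadalq_s16). Two products per int16 lane cannot
+ // overflow as long as activations stay in [0, 127] after clipped ReLU.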
+ [[maybe_unused]] static void neon_m128_add_dpbusd_epi32x2(
+ int32x4_t& acc,
+ int8x8_t a0, int8x8_t b0,
+ int8x8_t a1, int8x8_t b1) {
+
+ int16x8_t product = vmull_s8(a0, b0);
+ product = vmlal_s8(product, a1, b1);
+ acc = vpadalq_s16(acc, product);
+ }
+
+#endif
+
}
#endif // STOCKFISH_SIMD_H_INCLUDED
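
For reference, here is a scalar model of what neon_m128_add_dpbusd_epi32x2 computes; a minimal sketch for testing, not part of the patch, and the function name below is illustrative:

    #include <cstdint>

    // Each int32 lane k of acc absorbs two adjacent int8 products from
    // each of the two input/weight pairs (bytes 2k and 2k+1), matching
    // vmull_s8 + vmlal_s8 + vpadalq_s16 above. This grouping differs
    // from x86 dpbusd, which sums quads of bytes per lane; presumably
    // this is why the non-fallback paths permute the weights per
    // architecture.
    static void neon_add_dpbusd_epi32x2_ref(int32_t acc[4],
                                            const int8_t a0[8], const int8_t b0[8],
                                            const int8_t a1[8], const int8_t b1[8]) {
        for (int k = 0; k < 4; ++k)
            for (int j = 0; j < 2; ++j)
                acc[k] += int32_t(a0[2 * k + j]) * b0[2 * k + j]
                        + int32_t(a1[2 * k + j]) * b1[2 * k + j];
    }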