- const OutputType* propagate(
- const InputType* input, OutputType* output) const {
-
-#if defined (USE_SSSE3)
-#if defined (USE_AVX512)
- using vec_t = __m512i;
- #define vec_setzero _mm512_setzero_si512
- #define vec_set_32 _mm512_set1_epi32
- #define vec_add_dpbusd_32 Simd::m512_add_dpbusd_epi32
-#elif defined (USE_AVX2)
- using vec_t = __m256i;
- #define vec_setzero _mm256_setzero_si256
- #define vec_set_32 _mm256_set1_epi32
- #define vec_add_dpbusd_32 Simd::m256_add_dpbusd_epi32
-#elif defined (USE_SSSE3)
- using vec_t = __m128i;
- #define vec_setzero _mm_setzero_si128
- #define vec_set_32 _mm_set1_epi32
- #define vec_add_dpbusd_32 Simd::m128_add_dpbusd_epi32
-#endif
- static constexpr IndexType OutputSimdWidth = sizeof(vec_t) / sizeof(OutputType);
-
- constexpr IndexType NumChunks = ceil_to_multiple<IndexType>(InputDimensions, 8) / ChunkSize;
- constexpr IndexType NumRegs = OutputDimensions / OutputSimdWidth;
- std::uint16_t nnz[NumChunks];
- IndexType count;
-
- const auto input32 = reinterpret_cast<const std::int32_t*>(input);
-
- // Find indices of nonzero 32bit blocks
- find_nnz<NumChunks>(input32, nnz, count);
-
- const vec_t* biasvec = reinterpret_cast<const vec_t*>(biases);
- vec_t acc[NumRegs];
- for (IndexType k = 0; k < NumRegs; ++k)
- acc[k] = biasvec[k];
-
- for (IndexType j = 0; j < count; ++j)
- {
- const auto i = nnz[j];
- const vec_t in = vec_set_32(input32[i]);
- const auto col = reinterpret_cast<const vec_t*>(&weights[i * OutputDimensions * ChunkSize]);
+ void propagate(const InputType* input, OutputType* output) const {
+
+#if (USE_SSSE3 | (USE_NEON >= 8))
+ #if defined(USE_AVX512)
+ using invec_t = __m512i;
+ using outvec_t = __m512i;
+ #define vec_set_32 _mm512_set1_epi32
+ #define vec_add_dpbusd_32 Simd::m512_add_dpbusd_epi32
+ #elif defined(USE_AVX2)
+ using invec_t = __m256i;
+ using outvec_t = __m256i;
+ #define vec_set_32 _mm256_set1_epi32
+ #define vec_add_dpbusd_32 Simd::m256_add_dpbusd_epi32
+ #elif defined(USE_SSSE3)
+ using invec_t = __m128i;
+ using outvec_t = __m128i;
+ #define vec_set_32 _mm_set1_epi32
+ #define vec_add_dpbusd_32 Simd::m128_add_dpbusd_epi32
+ #elif defined(USE_NEON_DOTPROD)
+ using invec_t = int8x16_t;
+ using outvec_t = int32x4_t;
+ #define vec_set_32(a) vreinterpretq_s8_u32(vdupq_n_u32(a))
+ #define vec_add_dpbusd_32 Simd::dotprod_m128_add_dpbusd_epi32
+ #elif defined(USE_NEON)
+ using invec_t = int8x16_t;
+ using outvec_t = int32x4_t;
+ #define vec_set_32(a) vreinterpretq_s8_u32(vdupq_n_u32(a))
+ #define vec_add_dpbusd_32 Simd::neon_m128_add_dpbusd_epi32
+ #endif
+ static constexpr IndexType OutputSimdWidth = sizeof(outvec_t) / sizeof(OutputType);
+
+ constexpr IndexType NumChunks = ceil_to_multiple<IndexType>(InputDimensions, 8) / ChunkSize;
+ constexpr IndexType NumRegs = OutputDimensions / OutputSimdWidth;
+ std::uint16_t nnz[NumChunks];
+ IndexType count;
+
+ const auto input32 = reinterpret_cast<const std::int32_t*>(input);
+
+ // Find indices of nonzero 32-bit blocks
+ find_nnz<NumChunks>(input32, nnz, count);
+
+ const outvec_t* biasvec = reinterpret_cast<const outvec_t*>(biases);
+ outvec_t acc[NumRegs];