- __m512i sum0 = _mm512_setzero_si512();
-
- const auto row0 = reinterpret_cast<const __m512i*>(&weights_[0]);
-
- for (IndexType j = 0; j < kNumChunks512; ++j)
- {
- const __m512i in = input_vector512[j];
-
- m512_add_dpbusd_epi32(sum0, in, row0[j]);
- }
-
- output[0] = m512_hadd(sum0, biases_[0]);
- }
- else
- {
- __m256i sum0 = _mm256_setzero_si256();
-
- const auto row0 = reinterpret_cast<const __m256i*>(&weights_[0]);
-
- for (IndexType j = 0; j < kNumChunks256; ++j)
- {
- const __m256i in = input_vector256[j];
-
- m256_add_dpbusd_epi32(sum0, in, row0[j]);
- }
-
- output[0] = m256_hadd(sum0, biases_[0]);
- }
- }
- else
- {
- // This case can never happen because kOutputDimensions
- // is always 1 or a multiple of kSimdWidth.
- assert(false);
- }
-
-#elif defined (USE_AVX2)
-
- constexpr IndexType kNumChunks = kPaddedInputDimensions / kSimdWidth;
-
- const auto output = reinterpret_cast<OutputType*>(buffer);
- const auto input_vector = reinterpret_cast<const __m256i*>(input);
-
- // kOutputDimensions is either 1 or a multiple of kSimdWidth
- // because then it is also an input dimension.
- if constexpr (kOutputDimensions % 4 == 0)
- {
- for (IndexType i = 0; i < kOutputDimensions; i += 4)
- {
- const IndexType offset0 = (i + 0) * kPaddedInputDimensions;
- const IndexType offset1 = (i + 1) * kPaddedInputDimensions;
- const IndexType offset2 = (i + 2) * kPaddedInputDimensions;
- const IndexType offset3 = (i + 3) * kPaddedInputDimensions;
-
- const __m128i bias = *reinterpret_cast<const __m128i*>(&biases_[i]);
- __m128i* outptr = reinterpret_cast<__m128i*>(&output[i]);
-
- __m256i sum0 = _mm256_setzero_si256();
- __m256i sum1 = _mm256_setzero_si256();
- __m256i sum2 = _mm256_setzero_si256();
- __m256i sum3 = _mm256_setzero_si256();
-
- const auto row0 = reinterpret_cast<const __m256i*>(&weights_[offset0]);
- const auto row1 = reinterpret_cast<const __m256i*>(&weights_[offset1]);
- const auto row2 = reinterpret_cast<const __m256i*>(&weights_[offset2]);
- const auto row3 = reinterpret_cast<const __m256i*>(&weights_[offset3]);
-
- int j = 0;
- if (!canSaturate16x4[i / 4])
- {
- for (; j < (int)kNumChunks - 1; j += 2)
- {
- const __m256i in0 = input_vector[j];
- const __m256i in1 = input_vector[j + 1];
-
- m256_add_dpbusd_epi32x2(sum0, in0, row0[j], in1, row0[j + 1]);
- m256_add_dpbusd_epi32x2(sum1, in0, row1[j], in1, row1[j + 1]);
- m256_add_dpbusd_epi32x2(sum2, in0, row2[j], in1, row2[j + 1]);
- m256_add_dpbusd_epi32x2(sum3, in0, row3[j], in1, row3[j + 1]);
- }
- }
- for (; j < (int)kNumChunks; ++j)
- {
- const __m256i in = input_vector[j];
-
- m256_add_dpbusd_epi32(sum0, in, row0[j]);
- m256_add_dpbusd_epi32(sum1, in, row1[j]);
- m256_add_dpbusd_epi32(sum2, in, row2[j]);
- m256_add_dpbusd_epi32(sum3, in, row3[j]);
- }
-
- *outptr = m256_haddx4(sum0, sum1, sum2, sum3, bias);
- }
- }
- else if constexpr (kOutputDimensions == 1)
- {
- __m256i sum0 = _mm256_setzero_si256();
-
- const auto row0 = reinterpret_cast<const __m256i*>(&weights_[0]);
-
- for (IndexType j = 0; j < kNumChunks; ++j)
- {
- const __m256i in = input_vector[j];
-
- m256_add_dpbusd_epi32(sum0, in, row0[j]);
- }
-
- output[0] = m256_hadd(sum0, biases_[0]);
- }
- else
- {
- // This case can never happen because kOutputDimensions
- // is always 1 or a multiple of kSimdWidth.
- assert(false);
- }
-
-#elif defined (USE_SSSE3)
-
- constexpr IndexType kNumChunks = kPaddedInputDimensions / kSimdWidth;
-
- auto output = reinterpret_cast<OutputType*>(buffer);
- const auto input_vector = reinterpret_cast<const __m128i*>(input);
-
- // kOutputDimensions is either 1 or a multiple of kSimdWidth
- // because then it is also an input dimension.
- if constexpr (kOutputDimensions % 4 == 0)
- {
- for (IndexType i = 0; i < kOutputDimensions; i += 4)
- {
- const IndexType offset0 = (i + 0) * kPaddedInputDimensions;
- const IndexType offset1 = (i + 1) * kPaddedInputDimensions;
- const IndexType offset2 = (i + 2) * kPaddedInputDimensions;
- const IndexType offset3 = (i + 3) * kPaddedInputDimensions;
-
- const __m128i bias = *reinterpret_cast<const __m128i*>(&biases_[i]);
- __m128i* outptr = reinterpret_cast<__m128i*>(&output[i]);
-
- __m128i sum0 = _mm_setzero_si128();
- __m128i sum1 = _mm_setzero_si128();
- __m128i sum2 = _mm_setzero_si128();
- __m128i sum3 = _mm_setzero_si128();
-
- const auto row0 = reinterpret_cast<const __m128i*>(&weights_[offset0]);
- const auto row1 = reinterpret_cast<const __m128i*>(&weights_[offset1]);
- const auto row2 = reinterpret_cast<const __m128i*>(&weights_[offset2]);
- const auto row3 = reinterpret_cast<const __m128i*>(&weights_[offset3]);
-
- int j = 0;
- if (!canSaturate16x4[i / 4])
- {
- for (; j < (int)kNumChunks - 1; j += 2)
- {
- const __m128i in0 = input_vector[j];
- const __m128i in1 = input_vector[j + 1];
-
- m128_add_dpbusd_epi32x2(sum0, in0, row0[j], in1, row0[j + 1]);
- m128_add_dpbusd_epi32x2(sum1, in0, row1[j], in1, row1[j + 1]);
- m128_add_dpbusd_epi32x2(sum2, in0, row2[j], in1, row2[j + 1]);
- m128_add_dpbusd_epi32x2(sum3, in0, row3[j], in1, row3[j + 1]);
- }
- }
- for (; j < (int)kNumChunks; ++j)
- {
- const __m128i in = input_vector[j];
-
- m128_add_dpbusd_epi32(sum0, in, row0[j]);
- m128_add_dpbusd_epi32(sum1, in, row1[j]);
- m128_add_dpbusd_epi32(sum2, in, row2[j]);
- m128_add_dpbusd_epi32(sum3, in, row3[j]);
- }
-
- *outptr = m128_haddx4(sum0, sum1, sum2, sum3, bias);
- }
- }
- else if constexpr (kOutputDimensions == 1)
- {
- __m128i sum0 = _mm_setzero_si128();
-
- const auto row0 = reinterpret_cast<const __m128i*>(&weights_[0]);
-
- for (int j = 0; j < (int)kNumChunks; ++j)
- {
- const __m128i in = input_vector[j];
-
- m128_add_dpbusd_epi32(sum0, in, row0[j]);
- }
-
- output[0] = m128_hadd(sum0, biases_[0]);
- }
- else
- {
- // This case can never happen because kOutputDimensions
- // is always 1 or a multiple of kSimdWidth.
- assert(false);
- }
-
-#else
-
-// Use old implementation for the other architectures.
-
- auto output = reinterpret_cast<OutputType*>(buffer);
-
-#if defined(USE_SSE2)
- constexpr IndexType kNumChunks = kPaddedInputDimensions / kSimdWidth;
-#ifndef USE_SSSE3
- const __m128i kZeros = _mm_setzero_si128();
-#else
- const __m128i kOnes = _mm_set1_epi16(1);
-#endif
- const auto input_vector = reinterpret_cast<const __m128i*>(input);
-
-#elif defined(USE_MMX)
- constexpr IndexType kNumChunks = kPaddedInputDimensions / kSimdWidth;
- const __m64 kZeros = _mm_setzero_si64();
- const auto input_vector = reinterpret_cast<const __m64*>(input);
-
-#elif defined(USE_NEON)
- constexpr IndexType kNumChunks = kPaddedInputDimensions / kSimdWidth;
- const auto input_vector = reinterpret_cast<const int8x8_t*>(input);
-#endif
+ // We cannot use AVX512 for the last layer because there are only 32 inputs
+ // and the buffer is not padded to 64 elements.
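+ // (With 8-bit inputs a __m512i chunk needs 64 elements, while the 32
+ // available inputs fill exactly one __m256i.)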
+ #if defined(USE_AVX2)
+ using vec_t = __m256i;
+ #define vec_setzero() _mm256_setzero_si256()
+ #define vec_set_32 _mm256_set1_epi32
+ #define vec_add_dpbusd_32 Simd::m256_add_dpbusd_epi32
+ #define vec_hadd Simd::m256_hadd
+ #elif defined(USE_SSSE3)
+ using vec_t = __m128i;
+ #define vec_setzero() _mm_setzero_si128()
+ #define vec_set_32 _mm_set1_epi32
+ #define vec_add_dpbusd_32 Simd::m128_add_dpbusd_epi32
+ #define vec_hadd Simd::m128_hadd
+ #elif defined(USE_NEON_DOTPROD)
+ using vec_t = int32x4_t;
+ #define vec_setzero() vdupq_n_s32(0)
+ #define vec_set_32 vdupq_n_s32
+ #define vec_add_dpbusd_32(acc, a, b) \
+     Simd::dotprod_m128_add_dpbusd_epi32(acc, vreinterpretq_s8_s32(a), \
+                                         vreinterpretq_s8_s32(b))
+ #define vec_hadd Simd::neon_m128_hadd
+ #endif
+
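+ // From here on the code is ISA-agnostic: vec_t and the vec_* macros
+ // above supply the matching intrinsics for the selected target.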
+ const auto inputVector = reinterpret_cast<const vec_t*>(input);
+
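+ // Number of 8-bit input elements that fit in one SIMD register
+ // (32 for __m256i, 16 for __m128i and int32x4_t).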
+ static constexpr IndexType InputSimdWidth = sizeof(vec_t) / sizeof(InputType);
+
+ static_assert(PaddedInputDimensions % InputSimdWidth == 0);
+
+ constexpr IndexType NumChunks = PaddedInputDimensions / InputSimdWidth;
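+ // A single output neuron means a single dot product, so one
+ // accumulator register carries all the partial sums.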
+ vec_t sum0 = vec_setzero();
+ const auto row0 = reinterpret_cast<const vec_t*>(&weights[0]);
+
+ for (int j = 0; j < int(NumChunks); ++j)
+ {
+ const vec_t in = inputVector[j];
+ vec_add_dpbusd_32(sum0, in, row0[j]);
+ }
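+ // Fold the 32-bit lanes of sum0 into one scalar and add the bias.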
+ output[0] = vec_hadd(sum0, biases[0]);
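
For reference, the loop above is an ordinary dot product: vec_add_dpbusd_32 multiplies unsigned 8-bit inputs by signed 8-bit weights and accumulates the products into 32-bit lanes, and vec_hadd folds those lanes together with the bias. Below is a minimal scalar sketch of the same computation; the function name and exact pointer types are illustrative (not part of the patch), and it ignores the intermediate saturation that the maddubs-based SSSE3 emulation can exhibit, which the removed canSaturate16x4 path worked around.

    #include <cstddef>
    #include <cstdint>

    // Scalar equivalent of the single-output SIMD path above: one dot
    // product over the padded input row, plus the bias.
    std::int32_t affine_single_output_ref(const std::uint8_t* input,
                                          const std::int8_t*  weights,
                                          std::int32_t        bias,
                                          std::size_t         paddedInputDims)
    {
        std::int32_t sum = bias;
        for (std::size_t k = 0; k < paddedInputDims; ++k)
            sum += std::int32_t(input[k]) * std::int32_t(weights[k]);
        return sum;
    }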