- if constexpr (kPaddedInputDimensions % (kSimdWidth * 2) == 0)
- {
- __m512i sum0 = _mm512_setzero_si512();
-
- const auto row0 = reinterpret_cast<const __m512i*>(&weights_[0]);
-
- for (IndexType j = 0; j < kNumChunks512; ++j)
- {
- const __m512i in = input_vector512[j];
-
- m512_add_dpbusd_epi32(sum0, in, row0[j]);
- }
-
- output[0] = m512_hadd(sum0, biases_[0]);
- }
- else
- {
- __m256i sum0 = _mm256_setzero_si256();
-
- const auto row0 = reinterpret_cast<const __m256i*>(&weights_[0]);
-
- for (IndexType j = 0; j < kNumChunks256; ++j)
- {
- const __m256i in = input_vector256[j];
-
- m256_add_dpbusd_epi32(sum0, in, row0[j]);
- }
-
- output[0] = m256_hadd(sum0, biases_[0]);
- }
- }
- else
- {
- // This case can never happen because kOutputDimensions
- // is always 1 or a multiple of kSimdWidth.
- assert(false);
- }
-
-#elif defined (USE_AVX2)
-
- constexpr IndexType kNumChunks = kPaddedInputDimensions / kSimdWidth;
-
- const auto output = reinterpret_cast<OutputType*>(buffer);
- const auto input_vector = reinterpret_cast<const __m256i*>(input);
-
- // kOutputDimensions is either 1 or a multiple of kSimdWidth
- // because then it is also an input dimension.
- if constexpr (kOutputDimensions % 4 == 0)
- {
- for (IndexType i = 0; i < kOutputDimensions; i += 4)
- {
- const IndexType offset0 = (i + 0) * kPaddedInputDimensions;
- const IndexType offset1 = (i + 1) * kPaddedInputDimensions;
- const IndexType offset2 = (i + 2) * kPaddedInputDimensions;
- const IndexType offset3 = (i + 3) * kPaddedInputDimensions;
-
- const __m128i bias = *reinterpret_cast<const __m128i*>(&biases_[i]);
- __m128i* outptr = reinterpret_cast<__m128i*>(&output[i]);
-
- __m256i sum0 = _mm256_setzero_si256();
- __m256i sum1 = _mm256_setzero_si256();
- __m256i sum2 = _mm256_setzero_si256();
- __m256i sum3 = _mm256_setzero_si256();
-
- const auto row0 = reinterpret_cast<const __m256i*>(&weights_[offset0]);
- const auto row1 = reinterpret_cast<const __m256i*>(&weights_[offset1]);
- const auto row2 = reinterpret_cast<const __m256i*>(&weights_[offset2]);
- const auto row3 = reinterpret_cast<const __m256i*>(&weights_[offset3]);
-
- for (IndexType j = 0; j < kNumChunks; ++j)