- };
-
- [[maybe_unused]] auto m512_add_dpbusd_epi32x4 = [=](__m512i& acc, __m512i a0, __m512i b0, __m512i a1, __m512i b1,
- __m512i a2, __m512i b2, __m512i a3, __m512i b3) {
-#if defined (USE_VNNI)
- acc = _mm512_dpbusd_epi32(acc, a0, b0);
- acc = _mm512_dpbusd_epi32(acc, a1, b1);
- acc = _mm512_dpbusd_epi32(acc, a2, b2);
- acc = _mm512_dpbusd_epi32(acc, a3, b3);
-#else
- __m512i product0 = _mm512_maddubs_epi16(a0, b0);
- __m512i product1 = _mm512_maddubs_epi16(a1, b1);
- __m512i product2 = _mm512_maddubs_epi16(a2, b2);
- __m512i product3 = _mm512_maddubs_epi16(a3, b3);
- product0 = _mm512_add_epi16(product0, product1);
- product2 = _mm512_add_epi16(product2, product3);
- product0 = _mm512_add_epi16(product0, product2);
- product0 = _mm512_madd_epi16(product0, kOnes512);
- acc = _mm512_add_epi32(acc, product0);
-#endif
- };
-
-#endif
-#if defined (USE_AVX2)
-
- [[maybe_unused]] const __m256i kOnes256 = _mm256_set1_epi16(1);
-
- [[maybe_unused]] auto m256_hadd = [](__m256i sum, int bias) -> int {
- __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(sum), _mm256_extracti128_si256(sum, 1));
- sum128 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, _MM_PERM_BADC));
- sum128 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, _MM_PERM_CDAB));
- return _mm_cvtsi128_si32(sum128) + bias;
- };
-
- [[maybe_unused]] auto m256_add_dpbusd_epi32 = [=](__m256i& acc, __m256i a, __m256i b) {
-#if defined (USE_VNNI)
- acc = _mm256_dpbusd_epi32(acc, a, b);
-#else
- __m256i product0 = _mm256_maddubs_epi16(a, b);
- product0 = _mm256_madd_epi16(product0, kOnes256);
- acc = _mm256_add_epi32(acc, product0);
-#endif
- };
-
- [[maybe_unused]] auto m256_add_dpbusd_epi32x4 = [=](__m256i& acc, __m256i a0, __m256i b0, __m256i a1, __m256i b1,
- __m256i a2, __m256i b2, __m256i a3, __m256i b3) {
-#if defined (USE_VNNI)
- acc = _mm256_dpbusd_epi32(acc, a0, b0);
- acc = _mm256_dpbusd_epi32(acc, a1, b1);
- acc = _mm256_dpbusd_epi32(acc, a2, b2);
- acc = _mm256_dpbusd_epi32(acc, a3, b3);
-#else
- __m256i product0 = _mm256_maddubs_epi16(a0, b0);
- __m256i product1 = _mm256_maddubs_epi16(a1, b1);
- __m256i product2 = _mm256_maddubs_epi16(a2, b2);
- __m256i product3 = _mm256_maddubs_epi16(a3, b3);
- product0 = _mm256_add_epi16(product0, product1);
- product2 = _mm256_add_epi16(product2, product3);
- product0 = _mm256_add_epi16(product0, product2);
- product0 = _mm256_madd_epi16(product0, kOnes256);
- acc = _mm256_add_epi32(acc, product0);
-#endif
- };
-
-#endif
-#if defined (USE_SSSE3)
-
- [[maybe_unused]] const __m128i kOnes128 = _mm_set1_epi16(1);
-
- [[maybe_unused]] auto m128_hadd = [](__m128i sum, int bias) -> int {
- sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, 0x4E)); //_MM_PERM_BADC
- sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, 0xB1)); //_MM_PERM_CDAB
- return _mm_cvtsi128_si32(sum) + bias;
- };
-
- [[maybe_unused]] auto m128_add_dpbusd_epi32 = [=](__m128i& acc, __m128i a, __m128i b) {
- __m128i product0 = _mm_maddubs_epi16(a, b);
- product0 = _mm_madd_epi16(product0, kOnes128);
- acc = _mm_add_epi32(acc, product0);
- };
-
- [[maybe_unused]] auto m128_add_dpbusd_epi32x4 = [=](__m128i& acc, __m128i a0, __m128i b0, __m128i a1, __m128i b1,
- __m128i a2, __m128i b2, __m128i a3, __m128i b3) {
- __m128i product0 = _mm_maddubs_epi16(a0, b0);
- __m128i product1 = _mm_maddubs_epi16(a1, b1);
- __m128i product2 = _mm_maddubs_epi16(a2, b2);
- __m128i product3 = _mm_maddubs_epi16(a3, b3);
- product0 = _mm_add_epi16(product0, product1);
- product2 = _mm_add_epi16(product2, product3);
- product0 = _mm_add_epi16(product0, product2);
- product0 = _mm_madd_epi16(product0, kOnes128);
- acc = _mm_add_epi32(acc, product0);
- };
-
-#endif
-
    // Width-generic aliases: select the widest SIMD flavour enabled at build
    // time and bind the width-specific helpers defined above to the generic
    // names (vec_t, vec_setzero, vec_set_32, vec_add_dpbusd_32[x4], vec_hadd)
    // used by the propagation code below. vec_setzero/vec_set_32 are macros
    // because they must name intrinsics usable in constant-like expressions;
    // the lambdas are bound by reference to avoid copies.
#if defined (USE_AVX512)
    using vec_t = __m512i;
    #define vec_setzero _mm512_setzero_si512
    #define vec_set_32 _mm512_set1_epi32
    auto& vec_add_dpbusd_32 = m512_add_dpbusd_epi32;
    auto& vec_add_dpbusd_32x4 = m512_add_dpbusd_epi32x4;
    auto& vec_hadd = m512_hadd;
#elif defined (USE_AVX2)
    using vec_t = __m256i;
    #define vec_setzero _mm256_setzero_si256
    #define vec_set_32 _mm256_set1_epi32
    auto& vec_add_dpbusd_32 = m256_add_dpbusd_epi32;
    auto& vec_add_dpbusd_32x4 = m256_add_dpbusd_epi32x4;
    auto& vec_hadd = m256_hadd;
#elif defined (USE_SSSE3)
    using vec_t = __m128i;
    #define vec_setzero _mm_setzero_si128
    #define vec_set_32 _mm_set1_epi32
    auto& vec_add_dpbusd_32 = m128_add_dpbusd_epi32;
    auto& vec_add_dpbusd_32x4 = m128_add_dpbusd_epi32x4;
    auto& vec_hadd = m128_hadd;
#endif
-
-#if defined (USE_SSSE3)
-
    // SIMD forward propagation: output = biases_ + weights_ * input,
    // computed with whichever vector width was selected above.
    const auto output = reinterpret_cast<OutputType*>(buffer);
    const auto input_vector = reinterpret_cast<const vec_t*>(input);

    static_assert(kOutputDimensions % kOutputSimdWidth == 0 || kOutputDimensions == 1);

    // kOutputDimensions is either 1 or a multiple of kSimdWidth
    // because then it is also an input dimension.
    if constexpr (kOutputDimensions % kOutputSimdWidth == 0)
    {
      // Wide-output path: process the input 4 bytes (one int32) at a time,
      // broadcasting each group against a column of weights.
      constexpr IndexType kNumChunks = kPaddedInputDimensions / 4;

      const auto input32 = reinterpret_cast<const std::int32_t*>(input);
      vec_t* outptr = reinterpret_cast<vec_t*>(output);
      // Seed the 32-bit accumulators with the biases.
      std::memcpy(output, biases_, kOutputDimensions * sizeof(OutputType));

      // Four chunks per iteration; kNumChunks is assumed divisible by 4
      // (input padding) — otherwise trailing chunks would be skipped.
      // NOTE(review): weights_ appears to be stored so the 4 bytes of
      // input chunk i for every output sit contiguously at
      // i * kOutputDimensions * 4 — presumably permuted at load time;
      // confirm against the weight-initialization code.
      for (int i = 0; i < (int)kNumChunks - 3; i += 4)
      {
        const vec_t in0 = vec_set_32(input32[i + 0]);
        const vec_t in1 = vec_set_32(input32[i + 1]);
        const vec_t in2 = vec_set_32(input32[i + 2]);
        const vec_t in3 = vec_set_32(input32[i + 3]);
        const auto col0 = reinterpret_cast<const vec_t*>(&weights_[(i + 0) * kOutputDimensions * 4]);
        const auto col1 = reinterpret_cast<const vec_t*>(&weights_[(i + 1) * kOutputDimensions * 4]);
        const auto col2 = reinterpret_cast<const vec_t*>(&weights_[(i + 2) * kOutputDimensions * 4]);
        const auto col3 = reinterpret_cast<const vec_t*>(&weights_[(i + 3) * kOutputDimensions * 4]);
        for (int j = 0; j * kOutputSimdWidth < kOutputDimensions; ++j)
          vec_add_dpbusd_32x4(outptr[j], in0, col0[j], in1, col1[j], in2, col2[j], in3, col3[j]);
      }
      // Scalar corrections for (in, out) weight entries whose 16-bit
      // intermediate could saturate in the maddubs emulation.
      // NOTE(review): canSaturate16 is built elsewhere (not in this view).
      for (int i = 0; i < canSaturate16.count; ++i)
        output[canSaturate16.ids[i].out] += input[canSaturate16.ids[i].in] * canSaturate16.ids[i].w;
    }
    else if constexpr (kOutputDimensions == 1)
    {
      // Single-output path: one dot product over the whole input row.
#if defined (USE_AVX512)
      // If the padded input is not a multiple of two 512-bit vectors, fall
      // back to 256-bit vectors so no partial chunk is left over.
      if constexpr (kPaddedInputDimensions % (kSimdWidth * 2) != 0)
      {
        constexpr IndexType kNumChunks = kPaddedInputDimensions / kSimdWidth;
        const auto input_vector256 = reinterpret_cast<const __m256i*>(input);

        __m256i sum0 = _mm256_setzero_si256();
        const auto row0 = reinterpret_cast<const __m256i*>(&weights_[0]);

        for (int j = 0; j < (int)kNumChunks; ++j)
        {
          const __m256i in = input_vector256[j];
          m256_add_dpbusd_epi32(sum0, in, row0[j]);
        }
        output[0] = m256_hadd(sum0, biases_[0]);
      }
      else
#endif
      {
#if defined (USE_AVX512)
        constexpr IndexType kNumChunks = kPaddedInputDimensions / (kSimdWidth * 2);
#else
        constexpr IndexType kNumChunks = kPaddedInputDimensions / kSimdWidth;
#endif
        vec_t sum0 = vec_setzero();
        const auto row0 = reinterpret_cast<const vec_t*>(&weights_[0]);

        // Accumulate the dot product chunk by chunk, then reduce
        // horizontally and add the bias.
        for (int j = 0; j < (int)kNumChunks; ++j)
        {
          const vec_t in = input_vector[j];
          vec_add_dpbusd_32(sum0, in, row0[j]);
        }
        output[0] = vec_hadd(sum0, biases_[0]);
      }
    }
-
-#else
-
-// Use old implementation for the other architectures.