- [[maybe_unused]] const __m256i Ones256 = _mm256_set1_epi16(1);
-
- [[maybe_unused]] auto m256_hadd = [](__m256i sum, int bias) -> int {
- __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(sum), _mm256_extracti128_si256(sum, 1));
- sum128 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, _MM_PERM_BADC));
- sum128 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, _MM_PERM_CDAB));
- return _mm_cvtsi128_si32(sum128) + bias;
- };
-
- [[maybe_unused]] auto m256_haddx4 = [](__m256i sum0, __m256i sum1, __m256i sum2, __m256i sum3, __m128i bias) -> __m128i {
- sum0 = _mm256_hadd_epi32(sum0, sum1);
- sum2 = _mm256_hadd_epi32(sum2, sum3);
-
- sum0 = _mm256_hadd_epi32(sum0, sum2);
-
- __m128i sum128lo = _mm256_castsi256_si128(sum0);
- __m128i sum128hi = _mm256_extracti128_si256(sum0, 1);
-
- return _mm_add_epi32(_mm_add_epi32(sum128lo, sum128hi), bias);
- };
-
- [[maybe_unused]] auto m256_add_dpbusd_epi32 = [=](__m256i& acc, __m256i a, __m256i b) {
-#if defined (USE_VNNI)
- acc = _mm256_dpbusd_epi32(acc, a, b);
-#else
- __m256i product0 = _mm256_maddubs_epi16(a, b);
- product0 = _mm256_madd_epi16(product0, Ones256);
- acc = _mm256_add_epi32(acc, product0);
-#endif
- };
-
- [[maybe_unused]] auto m256_add_dpbusd_epi32x2 = [=](__m256i& acc, __m256i a0, __m256i b0, __m256i a1, __m256i b1) {
-#if defined (USE_VNNI)
- acc = _mm256_dpbusd_epi32(acc, a0, b0);
- acc = _mm256_dpbusd_epi32(acc, a1, b1);
-#else
- __m256i product0 = _mm256_maddubs_epi16(a0, b0);
- __m256i product1 = _mm256_maddubs_epi16(a1, b1);
- product0 = _mm256_adds_epi16(product0, product1);
- product0 = _mm256_madd_epi16(product0, Ones256);
- acc = _mm256_add_epi32(acc, product0);
-#endif
- };
-
- [[maybe_unused]] auto m256_add_dpbusd_epi32x4 = [=](__m256i& acc, __m256i a0, __m256i b0, __m256i a1, __m256i b1,
- __m256i a2, __m256i b2, __m256i a3, __m256i b3) {
-#if defined (USE_VNNI)
- acc = _mm256_dpbusd_epi32(acc, a0, b0);
- acc = _mm256_dpbusd_epi32(acc, a1, b1);
- acc = _mm256_dpbusd_epi32(acc, a2, b2);
- acc = _mm256_dpbusd_epi32(acc, a3, b3);
-#else
- __m256i product0 = _mm256_maddubs_epi16(a0, b0);
- __m256i product1 = _mm256_maddubs_epi16(a1, b1);
- __m256i product2 = _mm256_maddubs_epi16(a2, b2);
- __m256i product3 = _mm256_maddubs_epi16(a3, b3);
- product0 = _mm256_adds_epi16(product0, product1);
- product0 = _mm256_madd_epi16(product0, Ones256);
- product2 = _mm256_adds_epi16(product2, product3);
- product2 = _mm256_madd_epi16(product2, Ones256);
- acc = _mm256_add_epi32(acc, _mm256_add_epi32(product0, product2));
-#endif
- };
-
-#endif
-#if defined (USE_SSSE3)
-
- [[maybe_unused]] const __m128i Ones128 = _mm_set1_epi16(1);
-
- [[maybe_unused]] auto m128_hadd = [](__m128i sum, int bias) -> int {
- sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, 0x4E)); //_MM_PERM_BADC
- sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, 0xB1)); //_MM_PERM_CDAB
- return _mm_cvtsi128_si32(sum) + bias;
- };
-
- [[maybe_unused]] auto m128_haddx4 = [](__m128i sum0, __m128i sum1, __m128i sum2, __m128i sum3, __m128i bias) -> __m128i {
- sum0 = _mm_hadd_epi32(sum0, sum1);
- sum2 = _mm_hadd_epi32(sum2, sum3);
- sum0 = _mm_hadd_epi32(sum0, sum2);
- return _mm_add_epi32(sum0, bias);
- };
-
- [[maybe_unused]] auto m128_add_dpbusd_epi32 = [=](__m128i& acc, __m128i a, __m128i b) {
- __m128i product0 = _mm_maddubs_epi16(a, b);
- product0 = _mm_madd_epi16(product0, Ones128);
- acc = _mm_add_epi32(acc, product0);
- };
-
- [[maybe_unused]] auto m128_add_dpbusd_epi32x2 = [=](__m128i& acc, __m128i a0, __m128i b0, __m128i a1, __m128i b1) {
- __m128i product0 = _mm_maddubs_epi16(a0, b0);
- __m128i product1 = _mm_maddubs_epi16(a1, b1);
- product0 = _mm_adds_epi16(product0, product1);
- product0 = _mm_madd_epi16(product0, Ones128);
- acc = _mm_add_epi32(acc, product0);
- };
-
- [[maybe_unused]] auto m128_add_dpbusd_epi32x4 = [=](__m128i& acc, __m128i a0, __m128i b0, __m128i a1, __m128i b1,
- __m128i a2, __m128i b2, __m128i a3, __m128i b3) {
- __m128i product0 = _mm_maddubs_epi16(a0, b0);
- __m128i product1 = _mm_maddubs_epi16(a1, b1);
- __m128i product2 = _mm_maddubs_epi16(a2, b2);
- __m128i product3 = _mm_maddubs_epi16(a3, b3);
- product0 = _mm_adds_epi16(product0, product1);
- product0 = _mm_madd_epi16(product0, Ones128);
- product2 = _mm_adds_epi16(product2, product3);
- product2 = _mm_madd_epi16(product2, Ones128);
- acc = _mm_add_epi32(acc, _mm_add_epi32(product0, product2));
- };
-
-#endif
-
-#if defined (USE_AVX512)
- using vec_t = __m512i;
- #define vec_setzero _mm512_setzero_si512
- #define vec_set_32 _mm512_set1_epi32
- [[maybe_unused]] auto& vec_add_dpbusd_32 = m512_add_dpbusd_epi32;
- [[maybe_unused]] auto& vec_add_dpbusd_32x2 = m512_add_dpbusd_epi32x2;
- [[maybe_unused]] auto& vec_add_dpbusd_32x4 = m512_add_dpbusd_epi32x4;
- [[maybe_unused]] auto& vec_hadd = m512_hadd;
- [[maybe_unused]] auto& vec_haddx4 = m512_haddx4;
-#elif defined (USE_AVX2)
- using vec_t = __m256i;
- #define vec_setzero _mm256_setzero_si256
- #define vec_set_32 _mm256_set1_epi32
- [[maybe_unused]] auto& vec_add_dpbusd_32 = m256_add_dpbusd_epi32;
- [[maybe_unused]] auto& vec_add_dpbusd_32x2 = m256_add_dpbusd_epi32x2;
- [[maybe_unused]] auto& vec_add_dpbusd_32x4 = m256_add_dpbusd_epi32x4;
- [[maybe_unused]] auto& vec_hadd = m256_hadd;
- [[maybe_unused]] auto& vec_haddx4 = m256_haddx4;
-#elif defined (USE_SSSE3)
- using vec_t = __m128i;
- #define vec_setzero _mm_setzero_si128
- #define vec_set_32 _mm_set1_epi32
- [[maybe_unused]] auto& vec_add_dpbusd_32 = m128_add_dpbusd_epi32;
- [[maybe_unused]] auto& vec_add_dpbusd_32x2 = m128_add_dpbusd_epi32x2;
- [[maybe_unused]] auto& vec_add_dpbusd_32x4 = m128_add_dpbusd_epi32x4;
- [[maybe_unused]] auto& vec_hadd = m128_hadd;
- [[maybe_unused]] auto& vec_haddx4 = m128_haddx4;
-#endif
-
-#if defined (USE_SSSE3)
- const auto output = reinterpret_cast<OutputType*>(buffer);
- const auto inputVector = reinterpret_cast<const vec_t*>(input);
-#endif
-
-#if defined (USE_VNNI) || defined (USE_AVX512)
-
- static_assert(OutputDimensions == 1 || OutputDimensions % 4 == 0);
-
- // OutputDimensions is either 1 or a multiple of SimdWidth
- // because then it is also an input dimension.
- if constexpr (OutputDimensions <= 8 && OutputDimensions != 1)
- {
- constexpr IndexType NumChunks = PaddedInputDimensions / InputSimdWidth;
-
- static_assert(NumChunks % 2 == 0);
-
- const auto input_vec = reinterpret_cast<const vec_t*>(input);
- const auto bias_vec = reinterpret_cast<const __m128i*>(biases);
- auto out_vec = reinterpret_cast<__m128i*>(output);
-
- vec_t regs[OutputDimensions];
- for (IndexType k = 0; k < OutputDimensions; ++k)
- regs[k] = vec_setzero();
-
- for (IndexType i = 0; i < NumChunks / 2; ++i)
- {
- const vec_t in0 = input_vec[i * 2 + 0];
- const vec_t in1 = input_vec[i * 2 + 1];
- for (IndexType k = 0; k < OutputDimensions; ++k)
- {
- const vec_t w0 = reinterpret_cast<const vec_t*>(&weights[k * PaddedInputDimensions])[i * 2 + 0];
- const vec_t w1 = reinterpret_cast<const vec_t*>(&weights[k * PaddedInputDimensions])[i * 2 + 1];
- vec_add_dpbusd_32(regs[k], in0, w0);
- vec_add_dpbusd_32(regs[k], in1, w1);
- }
- }
-
- for (IndexType k = 0; k < OutputDimensions / 4; ++k)
- {
- out_vec[k] = vec_haddx4(
- regs[k * 4 + 0],
- regs[k * 4 + 1],
- regs[k * 4 + 2],
- regs[k * 4 + 3],
- bias_vec[k]
- );
- }
- }
- else if constexpr (InputDimensions == 8)
- {
- const auto input32 = reinterpret_cast<const std::int32_t*>(input);
- __m256i* outptr = reinterpret_cast<__m256i*>(output);
- std::memcpy(output, biases, OutputDimensions * sizeof(OutputType));
-
- const __m256i in0 = _mm256_set1_epi32(input32[0]);
- const __m256i in1 = _mm256_set1_epi32(input32[1]);
- const auto col0 = reinterpret_cast<const __m256i*>(&weights[0]);
- const auto col1 = reinterpret_cast<const __m256i*>(&weights[OutputDimensions * 4]);
- for (IndexType j = 0; j * 8 < OutputDimensions; ++j)
- m256_add_dpbusd_epi32x2(outptr[j], in0, col0[j], in1, col1[j]);
- }
- else
-
-#elif defined (USE_SSSE3)
-
- if constexpr (OutputDimensions % OutputSimdWidth == 0 && InputDimensions == 8)
- {
- const auto input32 = reinterpret_cast<const std::int32_t*>(input);
- vec_t* outptr = reinterpret_cast<vec_t*>(output);
- std::memcpy(output, biases, OutputDimensions * sizeof(OutputType));
-
- const vec_t in0 = vec_set_32(input32[0]);
- const vec_t in1 = vec_set_32(input32[1]);
- const auto col0 = reinterpret_cast<const vec_t*>(&weights[0]);
- const auto col1 = reinterpret_cast<const vec_t*>(&weights[OutputDimensions * 4]);
- for (IndexType j = 0; j * OutputSimdWidth < OutputDimensions; ++j)
- vec_add_dpbusd_32x2(outptr[j], in0, col0[j], in1, col1[j]);
- }
- else
-
-#endif
-
-#if defined (USE_SSSE3)
-
- if constexpr (OutputDimensions % OutputSimdWidth == 0)
- {
- static_assert(InputDimensions % 16 == 0);
-
- constexpr IndexType NumChunks = InputDimensions / 4;
- constexpr IndexType NumRegs = OutputDimensions / OutputSimdWidth;
-
- const auto input32 = reinterpret_cast<const std::int32_t*>(input);
- const vec_t* biasvec = reinterpret_cast<const vec_t*>(biases);
- vec_t outs[NumRegs];
- for (IndexType k = 0; k < NumRegs; ++k)
- outs[k] = biasvec[k];
-
- for (IndexType i = 0; i < NumChunks; i += 4)
- {
- const vec_t in0 = vec_set_32(input32[i + 0]);
- const vec_t in1 = vec_set_32(input32[i + 1]);
- const vec_t in2 = vec_set_32(input32[i + 2]);
- const vec_t in3 = vec_set_32(input32[i + 3]);
- const auto col0 = reinterpret_cast<const vec_t*>(&weights[(i + 0) * OutputDimensions * 4]);
- const auto col1 = reinterpret_cast<const vec_t*>(&weights[(i + 1) * OutputDimensions * 4]);
- const auto col2 = reinterpret_cast<const vec_t*>(&weights[(i + 2) * OutputDimensions * 4]);
- const auto col3 = reinterpret_cast<const vec_t*>(&weights[(i + 3) * OutputDimensions * 4]);
- for (IndexType k = 0; k < NumRegs; ++k)
- vec_add_dpbusd_32x4(outs[k], in0, col0[k], in1, col1[k], in2, col2[k], in3, col3[k]);
- }
-
- vec_t* outptr = reinterpret_cast<vec_t*>(output);
- for (IndexType k = 0; k < NumRegs; ++k)
- outptr[k] = outs[k];
- }
- else if constexpr (OutputDimensions == 1)
- {
- static_assert(InputDimensions % 4 == 0);
-
-#if defined (USE_AVX512)
- if constexpr (PaddedInputDimensions % (SimdWidth * 2) != 0)
- {
- constexpr IndexType NumChunks = PaddedInputDimensions / SimdWidth;
- const auto inputVector256 = reinterpret_cast<const __m256i*>(input);
-
- __m256i sum0 = _mm256_setzero_si256();
- const auto row0 = reinterpret_cast<const __m256i*>(&weights[0]);
-
- for (int j = 0; j < (int)NumChunks; ++j)
- {
- const __m256i in = inputVector256[j];
- m256_add_dpbusd_epi32(sum0, in, row0[j]);
- }
- output[0] = m256_hadd(sum0, biases[0]);
- }
- else
-#endif
- {
-#if defined (USE_AVX512)
- constexpr IndexType NumChunks = PaddedInputDimensions / (SimdWidth * 2);
-#else
- constexpr IndexType NumChunks = PaddedInputDimensions / SimdWidth;
-#endif
- vec_t sum0 = vec_setzero();
- const auto row0 = reinterpret_cast<const vec_t*>(&weights[0]);
-
- for (int j = 0; j < (int)NumChunks; ++j)
- {
- const vec_t in = inputVector[j];
- vec_add_dpbusd_32(sum0, in, row0[j]);
- }
- output[0] = vec_hadd(sum0, biases[0]);
- }
- }
-
-#else
-
-// Use old implementation for the other architectures.
-
- auto output = reinterpret_cast<OutputType*>(buffer);
-
-#if defined(USE_SSE2)
- // At least a multiple of 16, with SSE2.
- static_assert(PaddedInputDimensions % SimdWidth == 0);
- constexpr IndexType NumChunks = PaddedInputDimensions / SimdWidth;
- const __m128i Zeros = _mm_setzero_si128();
- const auto inputVector = reinterpret_cast<const __m128i*>(input);