- constexpr IndexType kNumChunks512 = kPaddedInputDimensions / (kSimdWidth * 2);
- constexpr IndexType kNumChunks256 = kPaddedInputDimensions / kSimdWidth;
-
- const auto output = reinterpret_cast<OutputType*>(buffer);
-
- // Since it takes 64 bytes to saturate a zmm register, we
- // cannot use AVX512 for the smaller affine transforms.
- // Instead we fall back to an AVX2 implementation if
- // kInputDimensions isn't a multiple of 64.
- // Note that this means that, for example, with
- // kInputDimensions of 96 we fall back to AVX2 even though
- // the first 64 elements could be processed with AVX512.
- // Handling that case better would require mixing __m256 and
- // __m512 variables, and handling more cases statically so as
- // not to lose performance.
- // This should be revisited if such input dimensions are to be considered.
- [[maybe_unused]] const auto input_vector512 = reinterpret_cast<const __m512i*>(input);
- [[maybe_unused]] const auto input_vector256 = reinterpret_cast<const __m256i*>(input);
-
- // kOutputDimensions is either 1 or a multiple of kSimdWidth
- // because then it is also an input dimension.
- if constexpr (kOutputDimensions % 16 == 0 && kNumChunks256 == 1)
- {
- for (IndexType i = 0; i < kOutputDimensions; i += 16)
- {
- const IndexType offset01a = (i + 0) * kPaddedInputDimensions;
- const IndexType offset23a = (i + 2) * kPaddedInputDimensions;
- const IndexType offset45a = (i + 4) * kPaddedInputDimensions;
- const IndexType offset67a = (i + 6) * kPaddedInputDimensions;
- const IndexType offset01b = (i + 8) * kPaddedInputDimensions;
- const IndexType offset23b = (i + 10) * kPaddedInputDimensions;
- const IndexType offset45b = (i + 12) * kPaddedInputDimensions;
- const IndexType offset67b = (i + 14) * kPaddedInputDimensions;
-
- const __m512i bias = *reinterpret_cast<const __m512i*>(&biases_[i]);
- __m512i* outptr = reinterpret_cast<__m512i*>(&output[i]);
-
- __m512i sum01a = _mm512_setzero_si512();
- __m512i sum23a = _mm512_setzero_si512();
- __m512i sum45a = _mm512_setzero_si512();
- __m512i sum67a = _mm512_setzero_si512();
- __m512i sum01b = _mm512_setzero_si512();
- __m512i sum23b = _mm512_setzero_si512();
- __m512i sum45b = _mm512_setzero_si512();
- __m512i sum67b = _mm512_setzero_si512();
-
- const auto row01a = *reinterpret_cast<const __m512i*>(&weights_[offset01a]);
- const auto row23a = *reinterpret_cast<const __m512i*>(&weights_[offset23a]);
- const auto row45a = *reinterpret_cast<const __m512i*>(&weights_[offset45a]);
- const auto row67a = *reinterpret_cast<const __m512i*>(&weights_[offset67a]);
- const auto row01b = *reinterpret_cast<const __m512i*>(&weights_[offset01b]);
- const auto row23b = *reinterpret_cast<const __m512i*>(&weights_[offset23b]);
- const auto row45b = *reinterpret_cast<const __m512i*>(&weights_[offset45b]);
- const auto row67b = *reinterpret_cast<const __m512i*>(&weights_[offset67b]);
-
- const __m256i in256 = input_vector256[0];
- const __m512i in = _mm512_inserti64x4(_mm512_castsi256_si512(in256), in256, 1);
-
- m512_add_dpbusd_epi32(sum01a, in, row01a);
- m512_add_dpbusd_epi32(sum23a, in, row23a);
- m512_add_dpbusd_epi32(sum45a, in, row45a);
- m512_add_dpbusd_epi32(sum67a, in, row67a);
- m512_add_dpbusd_epi32(sum01b, in, row01b);
- m512_add_dpbusd_epi32(sum23b, in, row23b);
- m512_add_dpbusd_epi32(sum45b, in, row45b);
- m512_add_dpbusd_epi32(sum67b, in, row67b);
-
- *outptr = m512_hadd256x16(
- sum01a, sum23a, sum45a, sum67a,
- sum01b, sum23b, sum45b, sum67b, bias);
- }
- }
- else if constexpr (kOutputDimensions % 4 == 0)
- {
- for (IndexType i = 0; i < kOutputDimensions; i += 4)
- {
- const IndexType offset0 = (i + 0) * kPaddedInputDimensions;
- const IndexType offset1 = (i + 1) * kPaddedInputDimensions;
- const IndexType offset2 = (i + 2) * kPaddedInputDimensions;
- const IndexType offset3 = (i + 3) * kPaddedInputDimensions;
-
- const __m128i bias = *reinterpret_cast<const __m128i*>(&biases_[i]);
- __m128i* outptr = reinterpret_cast<__m128i*>(&output[i]);
-
- if constexpr (kPaddedInputDimensions % (kSimdWidth * 2) == 0)
- {
- __m512i sum0 = _mm512_setzero_si512();
- __m512i sum1 = _mm512_setzero_si512();
- __m512i sum2 = _mm512_setzero_si512();
- __m512i sum3 = _mm512_setzero_si512();
-
- const auto row0 = reinterpret_cast<const __m512i*>(&weights_[offset0]);
- const auto row1 = reinterpret_cast<const __m512i*>(&weights_[offset1]);
- const auto row2 = reinterpret_cast<const __m512i*>(&weights_[offset2]);
- const auto row3 = reinterpret_cast<const __m512i*>(&weights_[offset3]);
-
- int j = 0;
- if (!canSaturate16x4[i / 4])
- {
- for (; j < (int)kNumChunks512 - 1; j += 2)
- {
- const __m512i in0 = input_vector512[j];
- const __m512i in1 = input_vector512[j + 1];
-
- m512_add_dpbusd_epi32x2(sum0, in0, row0[j], in1, row0[j + 1]);
- m512_add_dpbusd_epi32x2(sum1, in0, row1[j], in1, row1[j + 1]);
- m512_add_dpbusd_epi32x2(sum2, in0, row2[j], in1, row2[j + 1]);
- m512_add_dpbusd_epi32x2(sum3, in0, row3[j], in1, row3[j + 1]);
- }
- }
- for (; j < (int)kNumChunks512; ++j)
- {
- const __m512i in = input_vector512[j];
-
- m512_add_dpbusd_epi32(sum0, in, row0[j]);
- m512_add_dpbusd_epi32(sum1, in, row1[j]);
- m512_add_dpbusd_epi32(sum2, in, row2[j]);
- m512_add_dpbusd_epi32(sum3, in, row3[j]);
- }
-
- *outptr = m512_haddx4(sum0, sum1, sum2, sum3, bias);
- }
- else
- {
- __m256i sum0 = _mm256_setzero_si256();
- __m256i sum1 = _mm256_setzero_si256();
- __m256i sum2 = _mm256_setzero_si256();
- __m256i sum3 = _mm256_setzero_si256();
-
- const auto row0 = reinterpret_cast<const __m256i*>(&weights_[offset0]);
- const auto row1 = reinterpret_cast<const __m256i*>(&weights_[offset1]);
- const auto row2 = reinterpret_cast<const __m256i*>(&weights_[offset2]);
- const auto row3 = reinterpret_cast<const __m256i*>(&weights_[offset3]);
-
- for (IndexType j = 0; j < kNumChunks256; ++j)
- {
- const __m256i in = input_vector256[j];
-
- m256_add_dpbusd_epi32(sum0, in, row0[j]);
- m256_add_dpbusd_epi32(sum1, in, row1[j]);
- m256_add_dpbusd_epi32(sum2, in, row2[j]);
- m256_add_dpbusd_epi32(sum3, in, row3[j]);
- }
-
- *outptr = m256_haddx4(sum0, sum1, sum2, sum3, bias);
- }
- }
- }
- else if constexpr (kOutputDimensions == 1)
- {
- if constexpr (kPaddedInputDimensions % (kSimdWidth * 2) == 0)
- {
- __m512i sum0 = _mm512_setzero_si512();