#elif defined (USE_SSSE3)
static constexpr const IndexType OutputSimdWidth = SimdWidth / 4;
#endif
+#if defined (USE_AVX512)
+ static constexpr const IndexType InputSimdWidth = SimdWidth * 2;
+#elif defined (USE_SSSE3)
+ static constexpr const IndexType InputSimdWidth = SimdWidth;
+#endif
// Size of forward propagation buffer used in this layer
static constexpr std::size_t SelfBufferSize =
for (std::size_t i = 0; i < OutputDimensions * PaddedInputDimensions; ++i)
#if !defined (USE_SSSE3)
weights[i] = read_little_endian<WeightType>(stream);
+#elif defined (USE_VNNI) || defined (USE_AVX512)
+ if constexpr (OutputDimensions <= 8 && OutputDimensions != 1)
+ weights[i] = read_little_endian<WeightType>(stream);
+ else
+ weights[
+ (i / 4) % (PaddedInputDimensions / 4) * OutputDimensions * 4 +
+ i / PaddedInputDimensions * 4 +
+ i % 4
+ ] = read_little_endian<WeightType>(stream);
#else
weights[
(i / 4) % (PaddedInputDimensions / 4) * OutputDimensions * 4 +
return !stream.fail();
}
-
// Forward propagation
const OutputType* propagate(
const TransformedFeatureType* transformedFeatures, char* buffer) const {
return _mm512_reduce_add_epi32(sum) + bias;
};
+ [[maybe_unused]] auto m512_hadd128x16_interleave = [](
+ __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3) -> __m512i {
+
+ __m512i sum01a = _mm512_unpacklo_epi32(sum0, sum1);
+ __m512i sum01b = _mm512_unpackhi_epi32(sum0, sum1);
+
+ __m512i sum23a = _mm512_unpacklo_epi32(sum2, sum3);
+ __m512i sum23b = _mm512_unpackhi_epi32(sum2, sum3);
+
+ __m512i sum01 = _mm512_add_epi32(sum01a, sum01b);
+ __m512i sum23 = _mm512_add_epi32(sum23a, sum23b);
+
+ __m512i sum0123a = _mm512_unpacklo_epi64(sum01, sum23);
+ __m512i sum0123b = _mm512_unpackhi_epi64(sum01, sum23);
+
+ return _mm512_add_epi32(sum0123a, sum0123b);
+ };
+
+ [[maybe_unused]] auto m512_haddx4 = [m512_hadd128x16_interleave](
+ __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3, __m128i bias) -> __m128i {
+
+ __m512i sum = m512_hadd128x16_interleave(sum0, sum1, sum2, sum3);
+
+ __m256i sum256lo = _mm512_castsi512_si256(sum);
+ __m256i sum256hi = _mm512_extracti64x4_epi64(sum, 1);
+
+ sum256lo = _mm256_add_epi32(sum256lo, sum256hi);
+
+ __m128i sum128lo = _mm256_castsi256_si128(sum256lo);
+ __m128i sum128hi = _mm256_extracti128_si256(sum256lo, 1);
+
+ return _mm_add_epi32(_mm_add_epi32(sum128lo, sum128hi), bias);
+ };
+
[[maybe_unused]] auto m512_add_dpbusd_epi32 = [=](__m512i& acc, __m512i a, __m512i b) {
#if defined (USE_VNNI)
acc = _mm512_dpbusd_epi32(acc, a, b);
#endif
};
+ [[maybe_unused]] auto m512_add_dpbusd_epi32x2 = [=](__m512i& acc, __m512i a0, __m512i b0, __m512i a1, __m512i b1) {
+#if defined (USE_VNNI)
+ acc = _mm512_dpbusd_epi32(acc, a0, b0);
+ acc = _mm512_dpbusd_epi32(acc, a1, b1);
+#else
+ __m512i product0 = _mm512_maddubs_epi16(a0, b0);
+ __m512i product1 = _mm512_maddubs_epi16(a1, b1);
+ product0 = _mm512_adds_epi16(product0, product1);
+ product0 = _mm512_madd_epi16(product0, Ones512);
+ acc = _mm512_add_epi32(acc, product0);
+#endif
+ };
+
[[maybe_unused]] auto m512_add_dpbusd_epi32x4 = [=](__m512i& acc, __m512i a0, __m512i b0, __m512i a1, __m512i b1,
__m512i a2, __m512i b2, __m512i a3, __m512i b3) {
#if defined (USE_VNNI)
return _mm_cvtsi128_si32(sum128) + bias;
};
+ [[maybe_unused]] auto m256_haddx4 = [](__m256i sum0, __m256i sum1, __m256i sum2, __m256i sum3, __m128i bias) -> __m128i {
+ sum0 = _mm256_hadd_epi32(sum0, sum1);
+ sum2 = _mm256_hadd_epi32(sum2, sum3);
+
+ sum0 = _mm256_hadd_epi32(sum0, sum2);
+
+ __m128i sum128lo = _mm256_castsi256_si128(sum0);
+ __m128i sum128hi = _mm256_extracti128_si256(sum0, 1);
+
+ return _mm_add_epi32(_mm_add_epi32(sum128lo, sum128hi), bias);
+ };
+
[[maybe_unused]] auto m256_add_dpbusd_epi32 = [=](__m256i& acc, __m256i a, __m256i b) {
#if defined (USE_VNNI)
acc = _mm256_dpbusd_epi32(acc, a, b);
#endif
};
+ [[maybe_unused]] auto m256_add_dpbusd_epi32x2 = [=](__m256i& acc, __m256i a0, __m256i b0, __m256i a1, __m256i b1) {
+#if defined (USE_VNNI)
+ acc = _mm256_dpbusd_epi32(acc, a0, b0);
+ acc = _mm256_dpbusd_epi32(acc, a1, b1);
+#else
+ __m256i product0 = _mm256_maddubs_epi16(a0, b0);
+ __m256i product1 = _mm256_maddubs_epi16(a1, b1);
+ product0 = _mm256_adds_epi16(product0, product1);
+ product0 = _mm256_madd_epi16(product0, Ones256);
+ acc = _mm256_add_epi32(acc, product0);
+#endif
+ };
+
[[maybe_unused]] auto m256_add_dpbusd_epi32x4 = [=](__m256i& acc, __m256i a0, __m256i b0, __m256i a1, __m256i b1,
__m256i a2, __m256i b2, __m256i a3, __m256i b3) {
#if defined (USE_VNNI)
return _mm_cvtsi128_si32(sum) + bias;
};
+ [[maybe_unused]] auto m128_haddx4 = [](__m128i sum0, __m128i sum1, __m128i sum2, __m128i sum3, __m128i bias) -> __m128i {
+ sum0 = _mm_hadd_epi32(sum0, sum1);
+ sum2 = _mm_hadd_epi32(sum2, sum3);
+ sum0 = _mm_hadd_epi32(sum0, sum2);
+ return _mm_add_epi32(sum0, bias);
+ };
+
[[maybe_unused]] auto m128_add_dpbusd_epi32 = [=](__m128i& acc, __m128i a, __m128i b) {
__m128i product0 = _mm_maddubs_epi16(a, b);
product0 = _mm_madd_epi16(product0, Ones128);
acc = _mm_add_epi32(acc, product0);
};
+ [[maybe_unused]] auto m128_add_dpbusd_epi32x2 = [=](__m128i& acc, __m128i a0, __m128i b0, __m128i a1, __m128i b1) {
+ __m128i product0 = _mm_maddubs_epi16(a0, b0);
+ __m128i product1 = _mm_maddubs_epi16(a1, b1);
+ product0 = _mm_adds_epi16(product0, product1);
+ product0 = _mm_madd_epi16(product0, Ones128);
+ acc = _mm_add_epi32(acc, product0);
+ };
+
[[maybe_unused]] auto m128_add_dpbusd_epi32x4 = [=](__m128i& acc, __m128i a0, __m128i b0, __m128i a1, __m128i b1,
__m128i a2, __m128i b2, __m128i a3, __m128i b3) {
__m128i product0 = _mm_maddubs_epi16(a0, b0);
using vec_t = __m512i;
#define vec_setzero _mm512_setzero_si512
#define vec_set_32 _mm512_set1_epi32
- auto& vec_add_dpbusd_32 = m512_add_dpbusd_epi32;
- auto& vec_add_dpbusd_32x4 = m512_add_dpbusd_epi32x4;
- auto& vec_hadd = m512_hadd;
+ [[maybe_unused]] auto& vec_add_dpbusd_32 = m512_add_dpbusd_epi32;
+ [[maybe_unused]] auto& vec_add_dpbusd_32x2 = m512_add_dpbusd_epi32x2;
+ [[maybe_unused]] auto& vec_add_dpbusd_32x4 = m512_add_dpbusd_epi32x4;
+ [[maybe_unused]] auto& vec_hadd = m512_hadd;
+ [[maybe_unused]] auto& vec_haddx4 = m512_haddx4;
#elif defined (USE_AVX2)
using vec_t = __m256i;
#define vec_setzero _mm256_setzero_si256
#define vec_set_32 _mm256_set1_epi32
- auto& vec_add_dpbusd_32 = m256_add_dpbusd_epi32;
- auto& vec_add_dpbusd_32x4 = m256_add_dpbusd_epi32x4;
- auto& vec_hadd = m256_hadd;
+ [[maybe_unused]] auto& vec_add_dpbusd_32 = m256_add_dpbusd_epi32;
+ [[maybe_unused]] auto& vec_add_dpbusd_32x2 = m256_add_dpbusd_epi32x2;
+ [[maybe_unused]] auto& vec_add_dpbusd_32x4 = m256_add_dpbusd_epi32x4;
+ [[maybe_unused]] auto& vec_hadd = m256_hadd;
+ [[maybe_unused]] auto& vec_haddx4 = m256_haddx4;
#elif defined (USE_SSSE3)
using vec_t = __m128i;
#define vec_setzero _mm_setzero_si128
#define vec_set_32 _mm_set1_epi32
- auto& vec_add_dpbusd_32 = m128_add_dpbusd_epi32;
- auto& vec_add_dpbusd_32x4 = m128_add_dpbusd_epi32x4;
- auto& vec_hadd = m128_hadd;
+ [[maybe_unused]] auto& vec_add_dpbusd_32 = m128_add_dpbusd_epi32;
+ [[maybe_unused]] auto& vec_add_dpbusd_32x2 = m128_add_dpbusd_epi32x2;
+ [[maybe_unused]] auto& vec_add_dpbusd_32x4 = m128_add_dpbusd_epi32x4;
+ [[maybe_unused]] auto& vec_hadd = m128_hadd;
+ [[maybe_unused]] auto& vec_haddx4 = m128_haddx4;
#endif
#if defined (USE_SSSE3)
const auto output = reinterpret_cast<OutputType*>(buffer);
const auto inputVector = reinterpret_cast<const vec_t*>(input);
+#endif
+
+#if defined (USE_VNNI) || defined (USE_AVX512)
- static_assert(OutputDimensions % OutputSimdWidth == 0 || OutputDimensions == 1);
+ static_assert(OutputDimensions == 1 || OutputDimensions % 4 == 0);
// OutputDimensions is either 1 or a multiple of SimdWidth
// because then it is also an input dimension.
+ if constexpr (OutputDimensions <= 8 && OutputDimensions != 1)
+ {
+ constexpr IndexType NumChunks = PaddedInputDimensions / InputSimdWidth;
+
+ static_assert(NumChunks % 2 == 0);
+
+ const auto input_vec = reinterpret_cast<const vec_t*>(input);
+ const auto bias_vec = reinterpret_cast<const __m128i*>(biases);
+ auto out_vec = reinterpret_cast<__m128i*>(output);
+
+ vec_t regs[OutputDimensions];
+ for (IndexType k = 0; k < OutputDimensions; ++k)
+ regs[k] = vec_setzero();
+
+ for (IndexType i = 0; i < NumChunks / 2; ++i)
+ {
+ const vec_t in0 = input_vec[i * 2 + 0];
+ const vec_t in1 = input_vec[i * 2 + 1];
+ for (IndexType k = 0; k < OutputDimensions; ++k)
+ {
+ const vec_t w0 = reinterpret_cast<const vec_t*>(&weights[k * PaddedInputDimensions])[i * 2 + 0];
+ const vec_t w1 = reinterpret_cast<const vec_t*>(&weights[k * PaddedInputDimensions])[i * 2 + 1];
+ vec_add_dpbusd_32(regs[k], in0, w0);
+ vec_add_dpbusd_32(regs[k], in1, w1);
+ }
+ }
+
+ for (IndexType k = 0; k < OutputDimensions / 4; ++k)
+ {
+ out_vec[k] = vec_haddx4(
+ regs[k * 4 + 0],
+ regs[k * 4 + 1],
+ regs[k * 4 + 2],
+ regs[k * 4 + 3],
+ bias_vec[k]
+ );
+ }
+ }
+ else if constexpr (InputDimensions == 8)
+ {
+ const auto input32 = reinterpret_cast<const std::int32_t*>(input);
+ __m256i* outptr = reinterpret_cast<__m256i*>(output);
+ std::memcpy(output, biases, OutputDimensions * sizeof(OutputType));
+
+ const __m256i in0 = _mm256_set1_epi32(input32[0]);
+ const __m256i in1 = _mm256_set1_epi32(input32[1]);
+ const auto col0 = reinterpret_cast<const __m256i*>(&weights[0]);
+ const auto col1 = reinterpret_cast<const __m256i*>(&weights[OutputDimensions * 4]);
+ for (IndexType j = 0; j * 8 < OutputDimensions; ++j)
+ m256_add_dpbusd_epi32x2(outptr[j], in0, col0[j], in1, col1[j]);
+ }
+ else
+
+#elif defined (USE_SSSE3)
+
+ if constexpr (OutputDimensions % OutputSimdWidth == 0 && InputDimensions == 8)
+ {
+ const auto input32 = reinterpret_cast<const std::int32_t*>(input);
+ vec_t* outptr = reinterpret_cast<vec_t*>(output);
+ std::memcpy(output, biases, OutputDimensions * sizeof(OutputType));
+
+ const vec_t in0 = vec_set_32(input32[0]);
+ const vec_t in1 = vec_set_32(input32[1]);
+ const auto col0 = reinterpret_cast<const vec_t*>(&weights[0]);
+ const auto col1 = reinterpret_cast<const vec_t*>(&weights[OutputDimensions * 4]);
+ for (IndexType j = 0; j * OutputSimdWidth < OutputDimensions; ++j)
+ vec_add_dpbusd_32x2(outptr[j], in0, col0[j], in1, col1[j]);
+ }
+ else
+
+#endif
+
+#if defined (USE_SSSE3)
+
if constexpr (OutputDimensions % OutputSimdWidth == 0)
{
static_assert(InputDimensions % 16 == 0);
#if defined(USE_SSE2)
// At least a multiple of 16, with SSE2.
- static_assert(InputDimensions % SimdWidth == 0);
- constexpr IndexType NumChunks = InputDimensions / SimdWidth;
+ static_assert(PaddedInputDimensions % SimdWidth == 0);
+ constexpr IndexType NumChunks = PaddedInputDimensions / SimdWidth;
const __m128i Zeros = _mm_setzero_si128();
const auto inputVector = reinterpret_cast<const __m128i*>(input);
const auto inputVector = reinterpret_cast<const __m64*>(input);
#elif defined(USE_NEON)
- static_assert(InputDimensions % SimdWidth == 0);
- constexpr IndexType NumChunks = InputDimensions / SimdWidth;
+ static_assert(PaddedInputDimensions % SimdWidth == 0);
+ constexpr IndexType NumChunks = PaddedInputDimensions / SimdWidth;
const auto inputVector = reinterpret_cast<const int8x8_t*>(input);
#endif
_mm_empty();
#endif
+#endif
+
+#if (!defined (USE_SSSE3) && defined (USE_SSE2)) || defined (USE_NEON)
+ static_assert(SimdWidth <= 16, "Otherwise we run outside of the padding for the output.");
+ if constexpr (SimdWidth > OutputDimensions && OutputDimensions != 1)
+ for (IndexType i = OutputDimensions; i < SimdWidth; ++i)
+ output[i] = 0;
#endif
return output;