in affine transform for AVX512/AVX2/SSSE3
The idea is to initialize sum with the first element instead of zero.
Reduce one add_epi32 and one set_zero SIMD instructions for each output dimension.
sum = 0; for i = 1 to n sum += a[i] ->
sum = a[1]; for i = 2 to n sum += a[i]
STC:
LLR: 2.95 (-2.94,2.94) {-0.25,1.25}
Total: 69048 W: 7024 L: 6799 D: 55225
Ptnml(0-2): 260, 5175, 23458, 5342, 289
https://tests.stockfishchess.org/tests/view/
5faf2cf467cbf42301d6aa06
closes https://github.com/official-stockfish/Stockfish/pull/3227
No functional change.
marotear
Matthew Lai (matthewlai)
Matthew Sullivan (Matt14916)
marotear
Matthew Lai (matthewlai)
Matthew Sullivan (Matt14916)
Michael An (man)
Michael Byrne (MichaelB7)
Michael Chaly (Vizvezdenec)
Michael An (man)
Michael Byrne (MichaelB7)
Michael Chaly (Vizvezdenec)
return _mm512_add_epi32(_mm512_permutexvar_epi32(indices, x), bias);
};
return _mm512_add_epi32(_mm512_permutexvar_epi32(indices, x), bias);
};
- [[maybe_unused]] auto m512_add_dpbusd_epi32 = [=](__m512i& acc, __m512i a, __m512i b) {
+ [[maybe_unused]] auto m512_add_dpbusd_epi32 = [=](__m512i& acc, __m512i a, __m512i b) {
acc = _mm512_dpbusd_epi32(acc, a, b);
#else
acc = _mm512_dpbusd_epi32(acc, a, b);
#else
+ [[maybe_unused]] auto m512_dpbusd_epi32 = [=](__m512i a, __m512i b) -> __m512i {
__m512i product0 = _mm512_maddubs_epi16(a, b);
__m512i product0 = _mm512_maddubs_epi16(a, b);
- product0 = _mm512_madd_epi16(product0, kOnes512);
- acc = _mm512_add_epi32(acc, product0);
+ return _mm512_madd_epi16(product0, kOnes512);
return _mm_add_epi32(_mm_add_epi32(sum128lo, sum128hi), bias);
};
return _mm_add_epi32(_mm_add_epi32(sum128lo, sum128hi), bias);
};
-
- [[maybe_unused]] auto m256_add_dpbusd_epi32 = [=](__m256i& acc, __m256i a, __m256i b) {
+ [[maybe_unused]] auto m256_add_dpbusd_epi32 = [=](__m256i& acc, __m256i a, __m256i b) {
acc = _mm256_dpbusd_epi32(acc, a, b);
#else
acc = _mm256_dpbusd_epi32(acc, a, b);
#else
+ [[maybe_unused]] auto m256_dpbusd_epi32 = [=](__m256i a, __m256i b) -> __m256i {
__m256i product0 = _mm256_maddubs_epi16(a, b);
__m256i product0 = _mm256_maddubs_epi16(a, b);
- product0 = _mm256_madd_epi16(product0, kOnes256);
- acc = _mm256_add_epi32(acc, product0);
+ return _mm256_madd_epi16(product0, kOnes256);
return _mm_add_epi32(sum0, bias);
};
return _mm_add_epi32(sum0, bias);
};
- [[maybe_unused]] auto m128_add_dpbusd_epi32 = [=](__m128i& acc, __m128i a, __m128i b) {
+ [[maybe_unused]] auto m128_dpbusd_epi32 = [=](__m128i a, __m128i b) -> __m128i {
__m128i product0 = _mm_maddubs_epi16(a, b);
__m128i product0 = _mm_maddubs_epi16(a, b);
- product0 = _mm_madd_epi16(product0, kOnes128);
- acc = _mm_add_epi32(acc, product0);
+ return _mm_madd_epi16(product0, kOnes128);
const __m512i bias = *reinterpret_cast<const __m512i*>(&biases_[i]);
__m512i* outptr = reinterpret_cast<__m512i*>(&output[i]);
const __m512i bias = *reinterpret_cast<const __m512i*>(&biases_[i]);
__m512i* outptr = reinterpret_cast<__m512i*>(&output[i]);
- __m512i sum01a = _mm512_setzero_si512();
- __m512i sum23a = _mm512_setzero_si512();
- __m512i sum45a = _mm512_setzero_si512();
- __m512i sum67a = _mm512_setzero_si512();
- __m512i sum01b = _mm512_setzero_si512();
- __m512i sum23b = _mm512_setzero_si512();
- __m512i sum45b = _mm512_setzero_si512();
- __m512i sum67b = _mm512_setzero_si512();
-
const auto row01a = *reinterpret_cast<const __m512i*>(&weights_[offset01a]);
const auto row23a = *reinterpret_cast<const __m512i*>(&weights_[offset23a]);
const auto row45a = *reinterpret_cast<const __m512i*>(&weights_[offset45a]);
const auto row01a = *reinterpret_cast<const __m512i*>(&weights_[offset01a]);
const auto row23a = *reinterpret_cast<const __m512i*>(&weights_[offset23a]);
const auto row45a = *reinterpret_cast<const __m512i*>(&weights_[offset45a]);
const __m256i in256 = input_vector256[0];
const __m512i in = _mm512_inserti64x4(_mm512_castsi256_si512(in256), in256, 1);
const __m256i in256 = input_vector256[0];
const __m512i in = _mm512_inserti64x4(_mm512_castsi256_si512(in256), in256, 1);
+#if defined (USE_VNNI)
+ __m512i sum01a = _mm512_setzero_si512();
+ __m512i sum23a = _mm512_setzero_si512();
+ __m512i sum45a = _mm512_setzero_si512();
+ __m512i sum67a = _mm512_setzero_si512();
+ __m512i sum01b = _mm512_setzero_si512();
+ __m512i sum23b = _mm512_setzero_si512();
+ __m512i sum45b = _mm512_setzero_si512();
+ __m512i sum67b = _mm512_setzero_si512();
+
m512_add_dpbusd_epi32(sum01a, in, row01a);
m512_add_dpbusd_epi32(sum23a, in, row23a);
m512_add_dpbusd_epi32(sum45a, in, row45a);
m512_add_dpbusd_epi32(sum01a, in, row01a);
m512_add_dpbusd_epi32(sum23a, in, row23a);
m512_add_dpbusd_epi32(sum45a, in, row45a);
m512_add_dpbusd_epi32(sum23b, in, row23b);
m512_add_dpbusd_epi32(sum45b, in, row45b);
m512_add_dpbusd_epi32(sum67b, in, row67b);
m512_add_dpbusd_epi32(sum23b, in, row23b);
m512_add_dpbusd_epi32(sum45b, in, row45b);
m512_add_dpbusd_epi32(sum67b, in, row67b);
+#else
+ __m512i sum01a = m512_dpbusd_epi32(in, row01a);
+ __m512i sum23a = m512_dpbusd_epi32(in, row23a);
+ __m512i sum45a = m512_dpbusd_epi32(in, row45a);
+ __m512i sum67a = m512_dpbusd_epi32(in, row67a);
+ __m512i sum01b = m512_dpbusd_epi32(in, row01b);
+ __m512i sum23b = m512_dpbusd_epi32(in, row23b);
+ __m512i sum45b = m512_dpbusd_epi32(in, row45b);
+ __m512i sum67b = m512_dpbusd_epi32(in, row67b);
+#endif
*outptr = m512_hadd256x16(
sum01a, sum23a, sum45a, sum67a,
*outptr = m512_hadd256x16(
sum01a, sum23a, sum45a, sum67a,
if constexpr (kPaddedInputDimensions % (kSimdWidth * 2) == 0)
{
if constexpr (kPaddedInputDimensions % (kSimdWidth * 2) == 0)
{
- __m512i sum0 = _mm512_setzero_si512();
- __m512i sum1 = _mm512_setzero_si512();
- __m512i sum2 = _mm512_setzero_si512();
- __m512i sum3 = _mm512_setzero_si512();
-
const auto row0 = reinterpret_cast<const __m512i*>(&weights_[offset0]);
const auto row1 = reinterpret_cast<const __m512i*>(&weights_[offset1]);
const auto row2 = reinterpret_cast<const __m512i*>(&weights_[offset2]);
const auto row3 = reinterpret_cast<const __m512i*>(&weights_[offset3]);
const auto row0 = reinterpret_cast<const __m512i*>(&weights_[offset0]);
const auto row1 = reinterpret_cast<const __m512i*>(&weights_[offset1]);
const auto row2 = reinterpret_cast<const __m512i*>(&weights_[offset2]);
const auto row3 = reinterpret_cast<const __m512i*>(&weights_[offset3]);
- for (IndexType j = 0; j < kNumChunks512; ++j)
+#if defined (USE_VNNI)
+ __m512i sum0 = _mm512_setzero_si512();
+ __m512i sum1 = _mm512_setzero_si512();
+ __m512i sum2 = _mm512_setzero_si512();
+ __m512i sum3 = _mm512_setzero_si512();
+ const IndexType kStart = 0;
+#else
+ __m512i sum0 = m512_dpbusd_epi32(input_vector512[0], row0[0]);
+ __m512i sum1 = m512_dpbusd_epi32(input_vector512[0], row1[0]);
+ __m512i sum2 = m512_dpbusd_epi32(input_vector512[0], row2[0]);
+ __m512i sum3 = m512_dpbusd_epi32(input_vector512[0], row3[0]);
+ const IndexType kStart = 1;
+#endif
+
+ for (IndexType j = kStart; j < kNumChunks512; ++j)
{
const __m512i in = input_vector512[j];
{
const __m512i in = input_vector512[j];
m512_add_dpbusd_epi32(sum0, in, row0[j]);
m512_add_dpbusd_epi32(sum1, in, row1[j]);
m512_add_dpbusd_epi32(sum2, in, row2[j]);
m512_add_dpbusd_epi32(sum3, in, row3[j]);
m512_add_dpbusd_epi32(sum0, in, row0[j]);
m512_add_dpbusd_epi32(sum1, in, row1[j]);
m512_add_dpbusd_epi32(sum2, in, row2[j]);
m512_add_dpbusd_epi32(sum3, in, row3[j]);
+#else
+ sum0 = _mm512_add_epi32(sum0, m512_dpbusd_epi32(in, row0[j]));
+ sum1 = _mm512_add_epi32(sum1, m512_dpbusd_epi32(in, row1[j]));
+ sum2 = _mm512_add_epi32(sum2, m512_dpbusd_epi32(in, row2[j]));
+ sum3 = _mm512_add_epi32(sum3, m512_dpbusd_epi32(in, row3[j]));
+#endif
}
*outptr = m512_haddx4(sum0, sum1, sum2, sum3, bias);
}
else
{
}
*outptr = m512_haddx4(sum0, sum1, sum2, sum3, bias);
}
else
{
- __m256i sum0 = _mm256_setzero_si256();
- __m256i sum1 = _mm256_setzero_si256();
- __m256i sum2 = _mm256_setzero_si256();
- __m256i sum3 = _mm256_setzero_si256();
-
const auto row0 = reinterpret_cast<const __m256i*>(&weights_[offset0]);
const auto row1 = reinterpret_cast<const __m256i*>(&weights_[offset1]);
const auto row2 = reinterpret_cast<const __m256i*>(&weights_[offset2]);
const auto row3 = reinterpret_cast<const __m256i*>(&weights_[offset3]);
const auto row0 = reinterpret_cast<const __m256i*>(&weights_[offset0]);
const auto row1 = reinterpret_cast<const __m256i*>(&weights_[offset1]);
const auto row2 = reinterpret_cast<const __m256i*>(&weights_[offset2]);
const auto row3 = reinterpret_cast<const __m256i*>(&weights_[offset3]);
- for (IndexType j = 0; j < kNumChunks256; ++j)
+#if defined (USE_VNNI)
+ __m256i sum0 = _mm256_setzero_si256();
+ __m256i sum1 = _mm256_setzero_si256();
+ __m256i sum2 = _mm256_setzero_si256();
+ __m256i sum3 = _mm256_setzero_si256();
+ const IndexType kStart = 0;
+#else
+ __m256i sum0 = m256_dpbusd_epi32(input_vector256[0], row0[0]);
+ __m256i sum1 = m256_dpbusd_epi32(input_vector256[0], row1[0]);
+ __m256i sum2 = m256_dpbusd_epi32(input_vector256[0], row2[0]);
+ __m256i sum3 = m256_dpbusd_epi32(input_vector256[0], row3[0]);
+ const IndexType kStart = 1;
+#endif
+
+ for (IndexType j = kStart; j < kNumChunks256; ++j)
{
const __m256i in = input_vector256[j];
{
const __m256i in = input_vector256[j];
m256_add_dpbusd_epi32(sum0, in, row0[j]);
m256_add_dpbusd_epi32(sum1, in, row1[j]);
m256_add_dpbusd_epi32(sum2, in, row2[j]);
m256_add_dpbusd_epi32(sum3, in, row3[j]);
m256_add_dpbusd_epi32(sum0, in, row0[j]);
m256_add_dpbusd_epi32(sum1, in, row1[j]);
m256_add_dpbusd_epi32(sum2, in, row2[j]);
m256_add_dpbusd_epi32(sum3, in, row3[j]);
+#else
+ sum0 = _mm256_add_epi32(sum0, m256_dpbusd_epi32(in, row0[j]));
+ sum1 = _mm256_add_epi32(sum1, m256_dpbusd_epi32(in, row1[j]));
+ sum2 = _mm256_add_epi32(sum2, m256_dpbusd_epi32(in, row2[j]));
+ sum3 = _mm256_add_epi32(sum3, m256_dpbusd_epi32(in, row3[j]));
+#endif
}
*outptr = m256_haddx4(sum0, sum1, sum2, sum3, bias);
}
*outptr = m256_haddx4(sum0, sum1, sum2, sum3, bias);
{
if constexpr (kPaddedInputDimensions % (kSimdWidth * 2) == 0)
{
{
if constexpr (kPaddedInputDimensions % (kSimdWidth * 2) == 0)
{
- __m512i sum0 = _mm512_setzero_si512();
-
const auto row0 = reinterpret_cast<const __m512i*>(&weights_[0]);
const auto row0 = reinterpret_cast<const __m512i*>(&weights_[0]);
- for (IndexType j = 0; j < kNumChunks512; ++j)
+#if defined (USE_VNNI)
+ __m512i sum0 = _mm512_setzero_si512();
+ const IndexType kStart = 0;
+#else
+ __m512i sum0 = m512_dpbusd_epi32(input_vector512[0], row0[0]);
+ const IndexType kStart = 1;
+#endif
+
+ for (IndexType j = kStart; j < kNumChunks512; ++j)
{
const __m512i in = input_vector512[j];
{
const __m512i in = input_vector512[j];
m512_add_dpbusd_epi32(sum0, in, row0[j]);
m512_add_dpbusd_epi32(sum0, in, row0[j]);
+#else
+ sum0 = _mm512_add_epi32(sum0, m512_dpbusd_epi32(in, row0[j]));
+#endif
}
output[0] = m512_hadd(sum0, biases_[0]);
}
else
{
}
output[0] = m512_hadd(sum0, biases_[0]);
}
else
{
- __m256i sum0 = _mm256_setzero_si256();
-
const auto row0 = reinterpret_cast<const __m256i*>(&weights_[0]);
const auto row0 = reinterpret_cast<const __m256i*>(&weights_[0]);
- for (IndexType j = 0; j < kNumChunks256; ++j)
+#if defined (USE_VNNI)
+ __m256i sum0 = _mm256_setzero_si256();
+ const IndexType kStart = 0;
+#else
+ __m256i sum0 = m256_dpbusd_epi32(input_vector256[0], row0[0]);
+ const IndexType kStart = 1;
+#endif
+
+ for (IndexType j = kStart; j < kNumChunks256; ++j)
{
const __m256i in = input_vector256[j];
{
const __m256i in = input_vector256[j];
m256_add_dpbusd_epi32(sum0, in, row0[j]);
m256_add_dpbusd_epi32(sum0, in, row0[j]);
+#else
+ sum0 = _mm256_add_epi32(sum0, m256_dpbusd_epi32(in, row0[j]));
+#endif
}
output[0] = m256_hadd(sum0, biases_[0]);
}
output[0] = m256_hadd(sum0, biases_[0]);
const __m128i bias = *reinterpret_cast<const __m128i*>(&biases_[i]);
__m128i* outptr = reinterpret_cast<__m128i*>(&output[i]);
const __m128i bias = *reinterpret_cast<const __m128i*>(&biases_[i]);
__m128i* outptr = reinterpret_cast<__m128i*>(&output[i]);
- __m256i sum0 = _mm256_setzero_si256();
- __m256i sum1 = _mm256_setzero_si256();
- __m256i sum2 = _mm256_setzero_si256();
- __m256i sum3 = _mm256_setzero_si256();
-
const auto row0 = reinterpret_cast<const __m256i*>(&weights_[offset0]);
const auto row1 = reinterpret_cast<const __m256i*>(&weights_[offset1]);
const auto row2 = reinterpret_cast<const __m256i*>(&weights_[offset2]);
const auto row3 = reinterpret_cast<const __m256i*>(&weights_[offset3]);
const auto row0 = reinterpret_cast<const __m256i*>(&weights_[offset0]);
const auto row1 = reinterpret_cast<const __m256i*>(&weights_[offset1]);
const auto row2 = reinterpret_cast<const __m256i*>(&weights_[offset2]);
const auto row3 = reinterpret_cast<const __m256i*>(&weights_[offset3]);
- for (IndexType j = 0; j < kNumChunks; ++j)
+#if defined (USE_VNNI)
+ __m256i sum0 = _mm256_setzero_si256();
+ __m256i sum1 = _mm256_setzero_si256();
+ __m256i sum2 = _mm256_setzero_si256();
+ __m256i sum3 = _mm256_setzero_si256();
+ const IndexType kStart = 0;
+#else
+ __m256i sum0 = m256_dpbusd_epi32(input_vector[0], row0[0]);
+ __m256i sum1 = m256_dpbusd_epi32(input_vector[0], row1[0]);
+ __m256i sum2 = m256_dpbusd_epi32(input_vector[0], row2[0]);
+ __m256i sum3 = m256_dpbusd_epi32(input_vector[0], row3[0]);
+ const IndexType kStart = 1;
+#endif
+
+ for (IndexType j = kStart; j < kNumChunks; ++j)
{
const __m256i in = input_vector[j];
{
const __m256i in = input_vector[j];
m256_add_dpbusd_epi32(sum0, in, row0[j]);
m256_add_dpbusd_epi32(sum1, in, row1[j]);
m256_add_dpbusd_epi32(sum2, in, row2[j]);
m256_add_dpbusd_epi32(sum3, in, row3[j]);
m256_add_dpbusd_epi32(sum0, in, row0[j]);
m256_add_dpbusd_epi32(sum1, in, row1[j]);
m256_add_dpbusd_epi32(sum2, in, row2[j]);
m256_add_dpbusd_epi32(sum3, in, row3[j]);
+#else
+ sum0 = _mm256_add_epi32(sum0, m256_dpbusd_epi32(in, row0[j]));
+ sum1 = _mm256_add_epi32(sum1, m256_dpbusd_epi32(in, row1[j]));
+ sum2 = _mm256_add_epi32(sum2, m256_dpbusd_epi32(in, row2[j]));
+ sum3 = _mm256_add_epi32(sum3, m256_dpbusd_epi32(in, row3[j]));
+#endif
}
*outptr = m256_haddx4(sum0, sum1, sum2, sum3, bias);
}
*outptr = m256_haddx4(sum0, sum1, sum2, sum3, bias);
}
else if constexpr (kOutputDimensions == 1)
{
}
else if constexpr (kOutputDimensions == 1)
{
- __m256i sum0 = _mm256_setzero_si256();
-
const auto row0 = reinterpret_cast<const __m256i*>(&weights_[0]);
const auto row0 = reinterpret_cast<const __m256i*>(&weights_[0]);
- for (IndexType j = 0; j < kNumChunks; ++j)
+#if defined (USE_VNNI)
+ __m256i sum0 = _mm256_setzero_si256();
+ const IndexType kStart = 0;
+#else
+ __m256i sum0 = m256_dpbusd_epi32(input_vector[0], row0[0]);
+ const IndexType kStart = 1;
+#endif
+
+ for (IndexType j = kStart; j < kNumChunks; ++j)
{
const __m256i in = input_vector[j];
{
const __m256i in = input_vector[j];
- m256_add_dpbusd_epi32(sum0, in, row0[j]);
+#if defined (USE_VNNI)
+ m256_add_dpbusd_epi32(sum0, in, row0[j]);
+#else
+ sum0 = _mm256_add_epi32(sum0, m256_dpbusd_epi32(in, row0[j]));
+#endif
}
output[0] = m256_hadd(sum0, biases_[0]);
}
output[0] = m256_hadd(sum0, biases_[0]);
const __m128i bias = *reinterpret_cast<const __m128i*>(&biases_[i]);
__m128i* outptr = reinterpret_cast<__m128i*>(&output[i]);
const __m128i bias = *reinterpret_cast<const __m128i*>(&biases_[i]);
__m128i* outptr = reinterpret_cast<__m128i*>(&output[i]);
- __m128i sum0 = _mm_setzero_si128();
- __m128i sum1 = _mm_setzero_si128();
- __m128i sum2 = _mm_setzero_si128();
- __m128i sum3 = _mm_setzero_si128();
-
const auto row0 = reinterpret_cast<const __m128i*>(&weights_[offset0]);
const auto row1 = reinterpret_cast<const __m128i*>(&weights_[offset1]);
const auto row2 = reinterpret_cast<const __m128i*>(&weights_[offset2]);
const auto row3 = reinterpret_cast<const __m128i*>(&weights_[offset3]);
const auto row0 = reinterpret_cast<const __m128i*>(&weights_[offset0]);
const auto row1 = reinterpret_cast<const __m128i*>(&weights_[offset1]);
const auto row2 = reinterpret_cast<const __m128i*>(&weights_[offset2]);
const auto row3 = reinterpret_cast<const __m128i*>(&weights_[offset3]);
- for (int j = 0; j < (int)kNumChunks; j += 1)
+ __m128i sum0 = m128_dpbusd_epi32(input_vector[0], row0[0]);
+ __m128i sum1 = m128_dpbusd_epi32(input_vector[0], row1[0]);
+ __m128i sum2 = m128_dpbusd_epi32(input_vector[0], row2[0]);
+ __m128i sum3 = m128_dpbusd_epi32(input_vector[0], row3[0]);
+
+ for (int j = 1; j < (int)kNumChunks; ++j)
{
const __m128i in = input_vector[j];
{
const __m128i in = input_vector[j];
- m128_add_dpbusd_epi32(sum0, in, row0[j]);
- m128_add_dpbusd_epi32(sum1, in, row1[j]);
- m128_add_dpbusd_epi32(sum2, in, row2[j]);
- m128_add_dpbusd_epi32(sum3, in, row3[j]);
+ sum0 = _mm_add_epi32(sum0, m128_dpbusd_epi32(in, row0[j]));
+ sum1 = _mm_add_epi32(sum1, m128_dpbusd_epi32(in, row1[j]));
+ sum2 = _mm_add_epi32(sum2, m128_dpbusd_epi32(in, row2[j]));
+ sum3 = _mm_add_epi32(sum3, m128_dpbusd_epi32(in, row3[j]));
}
*outptr = m128_haddx4(sum0, sum1, sum2, sum3, bias);
}
*outptr = m128_haddx4(sum0, sum1, sum2, sum3, bias);
}
else if constexpr (kOutputDimensions == 1)
{
}
else if constexpr (kOutputDimensions == 1)
{
- __m128i sum0 = _mm_setzero_si128();
-
const auto row0 = reinterpret_cast<const __m128i*>(&weights_[0]);
const auto row0 = reinterpret_cast<const __m128i*>(&weights_[0]);
- for (int j = 0; j < (int)kNumChunks; j += 1)
- {
- const __m128i in = input_vector[j];
+ __m128i sum0 = m128_dpbusd_epi32(input_vector[0], row0[0]);
- m128_add_dpbusd_epi32(sum0, in, row0[j]);
- }
+ for (int j = 1; j < (int)kNumChunks; ++j)
+ sum0 = _mm_add_epi32(sum0, m128_dpbusd_epi32(input_vector[j], row0[j]));
output[0] = m128_hadd(sum0, biases_[0]);
}
output[0] = m128_hadd(sum0, biases_[0]);
}