Affine transform refactoring.
authorMaximMolchanov <maksym.n.molchanov@gmail.com>
Wed, 6 Jan 2021 03:29:32 +0000 (05:29 +0200)
committerJoost VandeVondele <Joost.VandeVondele@gmail.com>
Fri, 8 Jan 2021 15:35:44 +0000 (16:35 +0100)
Reordered weights in such a way that accumulated sum fits to output.
Weights are grouped in blocks of four elements because four
int8 (weight type) corresponds to one int32 (output type).
No horizontal additions.
Grouped AVX512, AVX2 and SSSE3 implementations.
Repeated code was removed.

An earlier version passed STC:

LLR: 2.97 (-2.94,2.94) {-0.25,1.25}
Total: 15336 W: 1495 L: 1355 D: 12486
Ptnml(0-2): 44, 1054, 5350, 1158, 62
https://tests.stockfishchess.org/tests/view/5ff60e106019e097de3eefd5

Speedup depends on the architecture, up to 4% measured on a NNUE only bench.

closes https://github.com/official-stockfish/Stockfish/pull/3287

No functional change

src/nnue/layers/affine_transform.h

index a715ca85090b8d5c3d530152768810fdd2c94da5..ab2beab7168ebe2e92d6e6be560903eccab1a055 100644 (file)
@@ -41,6 +41,11 @@ namespace Eval::NNUE::Layers {
     static constexpr IndexType kOutputDimensions = OutputDimensions;
     static constexpr IndexType kPaddedInputDimensions =
         CeilToMultiple<IndexType>(kInputDimensions, kMaxSimdWidth);
+#if defined (USE_AVX512)
+    static constexpr const IndexType kOutputSimdWidth = kSimdWidth / 2;
+#elif defined (USE_SSSE3)
+    static constexpr const IndexType kOutputSimdWidth = kSimdWidth / 4;
+#endif
 
     // Size of forward propagation buffer used in this layer
     static constexpr std::size_t kSelfBufferSize =
@@ -65,51 +70,55 @@ namespace Eval::NNUE::Layers {
       for (std::size_t i = 0; i < kOutputDimensions; ++i)
         biases_[i] = read_little_endian<BiasType>(stream);
       for (std::size_t i = 0; i < kOutputDimensions * kPaddedInputDimensions; ++i)
+#if !defined (USE_SSSE3)
         weights_[i] = read_little_endian<WeightType>(stream);
+#else
+        weights_[
+          (i / 4) % (kPaddedInputDimensions / 4) * kOutputDimensions * 4 +
+          i / kPaddedInputDimensions * 4 +
+          i % 4
+        ] = read_little_endian<WeightType>(stream);
 
-#if defined (USE_SSSE3)
-      // Determine if quadruplets of weight and input products can be summed using 16bits
+      // Determine if eights of weight and input products can be summed using 16bits
       // without saturation. We assume worst case combinations of 0 and 127 for all inputs.
-      if (!stream.fail())
+      if (kOutputDimensions > 1 && !stream.fail())
       {
-          auto can_saturate = [](const WeightType* w, int idx[4]) {
-              int pSum = 0, nSum = 0;
-              for (int p = 0; p < 4; ++p)
-                  if (w[idx[p]] > 0)
-                      pSum += w[idx[p]];
-                  else
-                      nSum += w[idx[p]];
-
-              return pSum > 258 || nSum < -258;
-          };
-
-          for (IndexType i = 0; i < kOutputDimensions; ++i)
-          {
-              canSaturate16[i] = false;
-              const WeightType* w = &weights_[i * kPaddedInputDimensions];
-#if defined (USE_AVX512)
-              for (IndexType j = 0; j < (kPaddedInputDimensions & ~127) && !canSaturate16[i]; j += 128)
-                  for (int k = 0; k < 64 && !canSaturate16[i]; k += 2)
+          canSaturate16.count = 0;
+#if !defined(USE_VNNI)
+          for (IndexType i = 0; i < kPaddedInputDimensions; i += 16)
+              for (IndexType j = 0; j < kOutputDimensions; ++j)
+                  for (int x = 0; x < 2; ++x)
                   {
-                      int spacing[4] = { 0, 1, 64, 65 };
-                      canSaturate16[i] = can_saturate(&w[j + k], spacing);
-                  }
-#elif defined (USE_AVX2)
-              for (IndexType j = 0; j < (kPaddedInputDimensions & ~63) && !canSaturate16[i]; j += 64)
-                  for (int k = 0; k < 32 && !canSaturate16[i]; k += 2)
-                  {
-                      int spacing[4] = { 0, 1, 32, 33 };
-                      canSaturate16[i] = can_saturate(&w[j + k], spacing);
-                  }
-#elif defined (USE_SSSE3)
-              for (IndexType j = 0; j < (kPaddedInputDimensions & ~31) && !canSaturate16[i]; j += 32)
-                  for (int k = 0; k < 16 && !canSaturate16[i]; k += 2)
-                  {
-                      int spacing[4] = { 0, 1, 16, 17 };
-                      canSaturate16[i] = can_saturate(&w[j + k], spacing);
+                      WeightType* w = &weights_[i * kOutputDimensions + j * 4 + x * 2];
+                      int sum[2] = {0, 0};
+                      for (int k = 0; k < 8; ++k)
+                      {
+                          IndexType idx = k / 2 * kOutputDimensions * 4 + k % 2;
+                          sum[w[idx] < 0] += w[idx];
+                      }
+                      for (int sign : {-1, 1})
+                          while (sign * sum[sign == -1] > 258)
+                          {
+                              int maxK = 0, maxW = 0;
+                              for (int k = 0; k < 8; ++k)
+                              {
+                                  IndexType idx = k / 2 * kOutputDimensions * 4 + k % 2;
+                                  if (maxW < sign * w[idx])
+                                      maxK = k, maxW = sign * w[idx];
+                              }
+
+                              IndexType idx = maxK / 2 * kOutputDimensions * 4 + maxK % 2;
+                              sum[sign == -1] -= w[idx];
+                              canSaturate16.add(j, i + maxK / 2 * 4 + maxK % 2 + x * 2, w[idx]);
+                              w[idx] = 0;
+                          }
                   }
+
+          // Non functional optimization for faster more linear access
+          std::sort(canSaturate16.ids, canSaturate16.ids + canSaturate16.count,
+                    [](const typename CanSaturate::Entry& e1, const typename CanSaturate::Entry& e2)
+                    { return e1.in == e2.in ? e1.out < e2.out : e1.in < e2.in; });
 #endif
-          }
       }
 #endif
 
@@ -130,104 +139,6 @@ namespace Eval::NNUE::Layers {
         return _mm512_reduce_add_epi32(sum) + bias;
       };
 
-      // This function takes
-      //   sum0 = [xmm0a, xmm0b, xmm0c, xmm0d]
-      //   sum1 = [xmm1a, xmm1b, xmm1c, xmm1d]
-      //   sum2 = [xmm2a, xmm2b, xmm2c, xmm2d]
-      //   sum3 = [xmm3a, xmm3b, xmm3c, xmm3d]
-      // and returns
-      //   ret = [
-      //     reduce_add_epi32(xmm0a), reduce_add_epi32(xmm1a), reduce_add_epi32(xmm2a), reduce_add_epi32(xmm3a),
-      //     reduce_add_epi32(xmm0b), reduce_add_epi32(xmm1b), reduce_add_epi32(xmm2b), reduce_add_epi32(xmm3b),
-      //     reduce_add_epi32(xmm0c), reduce_add_epi32(xmm1c), reduce_add_epi32(xmm2c), reduce_add_epi32(xmm3c),
-      //     reduce_add_epi32(xmm0d), reduce_add_epi32(xmm1d), reduce_add_epi32(xmm2d), reduce_add_epi32(xmm3d)
-      //   ]
-      [[maybe_unused]] auto m512_hadd128x16_interleave = [](
-        __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3) -> __m512i {
-
-        __m512i sum01a = _mm512_unpacklo_epi32(sum0, sum1);
-        __m512i sum01b = _mm512_unpackhi_epi32(sum0, sum1);
-
-        __m512i sum23a = _mm512_unpacklo_epi32(sum2, sum3);
-        __m512i sum23b = _mm512_unpackhi_epi32(sum2, sum3);
-
-        __m512i sum01 = _mm512_add_epi32(sum01a, sum01b);
-        __m512i sum23 = _mm512_add_epi32(sum23a, sum23b);
-
-        __m512i sum0123a = _mm512_unpacklo_epi64(sum01, sum23);
-        __m512i sum0123b = _mm512_unpackhi_epi64(sum01, sum23);
-
-        return _mm512_add_epi32(sum0123a, sum0123b);
-      };
-
-      [[maybe_unused]] auto m512_haddx4 = [m512_hadd128x16_interleave](
-        __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3, __m128i bias) -> __m128i {
-
-        __m512i sum = m512_hadd128x16_interleave(sum0, sum1, sum2, sum3);
-
-        __m256i sum256lo = _mm512_castsi512_si256(sum);
-        __m256i sum256hi = _mm512_extracti64x4_epi64(sum, 1);
-
-        sum256lo = _mm256_add_epi32(sum256lo, sum256hi);
-
-        __m128i sum128lo = _mm256_castsi256_si128(sum256lo);
-        __m128i sum128hi = _mm256_extracti128_si256(sum256lo, 1);
-
-        return _mm_add_epi32(_mm_add_epi32(sum128lo, sum128hi), bias);
-      };
-
-      [[maybe_unused]] auto m512_haddx8 = [m512_hadd128x16_interleave](
-        __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3,
-        __m512i sum4, __m512i sum5, __m512i sum6, __m512i sum7, __m256i bias) -> __m256i {
-
-        __m512i suma = m512_hadd128x16_interleave(sum0, sum1, sum2, sum3);
-        __m512i sumb = m512_hadd128x16_interleave(sum4, sum5, sum6, sum7);
-
-        __m512i indices0 = _mm512_setr_epi64(0, 1, 8, 9, 4, 5, 12, 13);
-        __m512i indices1 = _mm512_setr_epi64(2, 3, 10, 11, 6, 7, 14, 15);
-        __m512i x = _mm512_add_epi32(
-          _mm512_permutex2var_epi64(suma, indices0, sumb),
-          _mm512_permutex2var_epi64(suma, indices1, sumb));
-
-        __m256i sum256lo = _mm512_castsi512_si256(x);
-        __m256i sum256hi = _mm512_extracti64x4_epi64(x, 1);
-
-        return _mm256_add_epi32(_mm256_add_epi32(sum256lo, sum256hi), bias);
-      };
-
-      [[maybe_unused]] auto m512_hadd256x8 =[m512_hadd128x16_interleave](
-        __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3, __m256i bias) -> __m256i {
-
-        __m512i sum = m512_hadd128x16_interleave(sum0, sum1, sum2, sum3);
-
-        __m512i indices = _mm512_setr_epi32(
-          0, 4, 8, 12, 2, 6, 10, 14,
-          1, 5, 9, 13, 3, 7, 11, 15);
-        sum = _mm512_permutexvar_epi32(indices, sum);
-
-        __m256i sum256lo = _mm512_castsi512_si256(sum);
-        __m256i sum256hi = _mm512_extracti64x4_epi64(sum, 1);
-
-        return _mm256_add_epi32(_mm256_hadd_epi32(sum256lo, sum256hi), bias);
-      };
-
-      [[maybe_unused]] auto m512_hadd256x16 = [m512_hadd128x16_interleave](
-        __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3,
-        __m512i sum4, __m512i sum5, __m512i sum6, __m512i sum7, __m512i bias) -> __m512i {
-
-        __m512i suma = m512_hadd128x16_interleave(sum0, sum1, sum2, sum3);
-        __m512i sumb = m512_hadd128x16_interleave(sum4, sum5, sum6, sum7);
-
-        __m512i indices0 = _mm512_setr_epi64(0, 1, 8, 9, 4, 5, 12, 13);
-        __m512i indices1 = _mm512_setr_epi64(2, 3, 10, 11, 6, 7, 14, 15);
-        __m512i x = _mm512_add_epi32(
-          _mm512_permutex2var_epi64(suma, indices0, sumb),
-          _mm512_permutex2var_epi64(suma, indices1, sumb));
-
-        __m512i indices = _mm512_setr_epi32(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15);
-        return _mm512_add_epi32(_mm512_permutexvar_epi32(indices, x), bias);
-      };
-
       [[maybe_unused]] auto m512_add_dpbusd_epi32 = [=](__m512i& acc, __m512i a, __m512i b) {
 #if defined (USE_VNNI)
         acc = _mm512_dpbusd_epi32(acc, a, b);
@@ -238,14 +149,21 @@ namespace Eval::NNUE::Layers {
 #endif
       };
 
-      [[maybe_unused]] auto m512_add_dpbusd_epi32x2 = [=](__m512i& acc, __m512i a0, __m512i b0, __m512i a1, __m512i b1) {
+      [[maybe_unused]] auto m512_add_dpbusd_epi32x4 = [=](__m512i& acc, __m512i a0, __m512i b0, __m512i a1, __m512i b1,
+                                                                        __m512i a2, __m512i b2, __m512i a3, __m512i b3) {
 #if defined (USE_VNNI)
         acc = _mm512_dpbusd_epi32(acc, a0, b0);
         acc = _mm512_dpbusd_epi32(acc, a1, b1);
+        acc = _mm512_dpbusd_epi32(acc, a2, b2);
+        acc = _mm512_dpbusd_epi32(acc, a3, b3);
 #else
         __m512i product0 = _mm512_maddubs_epi16(a0, b0);
         __m512i product1 = _mm512_maddubs_epi16(a1, b1);
-        product0 = _mm512_adds_epi16(product0, product1);
+        __m512i product2 = _mm512_maddubs_epi16(a2, b2);
+        __m512i product3 = _mm512_maddubs_epi16(a3, b3);
+        product0 = _mm512_add_epi16(product0, product1);
+        product2 = _mm512_add_epi16(product2, product3);
+        product0 = _mm512_add_epi16(product0, product2);
         product0 = _mm512_madd_epi16(product0, kOnes512);
         acc = _mm512_add_epi32(acc, product0);
 #endif
@@ -263,18 +181,6 @@ namespace Eval::NNUE::Layers {
         return _mm_cvtsi128_si32(sum128) + bias;
       };
 
-      [[maybe_unused]] auto m256_haddx4 = [](__m256i sum0, __m256i sum1, __m256i sum2, __m256i sum3, __m128i bias) -> __m128i {
-        sum0 = _mm256_hadd_epi32(sum0, sum1);
-        sum2 = _mm256_hadd_epi32(sum2, sum3);
-
-        sum0 = _mm256_hadd_epi32(sum0, sum2);
-
-        __m128i sum128lo = _mm256_castsi256_si128(sum0);
-        __m128i sum128hi = _mm256_extracti128_si256(sum0, 1);
-
-        return _mm_add_epi32(_mm_add_epi32(sum128lo, sum128hi), bias);
-      };
-
       [[maybe_unused]] auto m256_add_dpbusd_epi32 = [=](__m256i& acc, __m256i a, __m256i b) {
 #if defined (USE_VNNI)
         acc = _mm256_dpbusd_epi32(acc, a, b);
@@ -285,21 +191,27 @@ namespace Eval::NNUE::Layers {
 #endif
       };
 
-      [[maybe_unused]] auto m256_add_dpbusd_epi32x2 = [=](__m256i& acc, __m256i a0, __m256i b0, __m256i a1, __m256i b1) {
+      [[maybe_unused]] auto m256_add_dpbusd_epi32x4 = [=](__m256i& acc, __m256i a0, __m256i b0, __m256i a1, __m256i b1,
+                                                                        __m256i a2, __m256i b2, __m256i a3, __m256i b3) {
 #if defined (USE_VNNI)
         acc = _mm256_dpbusd_epi32(acc, a0, b0);
         acc = _mm256_dpbusd_epi32(acc, a1, b1);
+        acc = _mm256_dpbusd_epi32(acc, a2, b2);
+        acc = _mm256_dpbusd_epi32(acc, a3, b3);
 #else
         __m256i product0 = _mm256_maddubs_epi16(a0, b0);
         __m256i product1 = _mm256_maddubs_epi16(a1, b1);
-        product0 = _mm256_adds_epi16(product0, product1);
+        __m256i product2 = _mm256_maddubs_epi16(a2, b2);
+        __m256i product3 = _mm256_maddubs_epi16(a3, b3);
+        product0 = _mm256_add_epi16(product0, product1);
+        product2 = _mm256_add_epi16(product2, product3);
+        product0 = _mm256_add_epi16(product0, product2);
         product0 = _mm256_madd_epi16(product0, kOnes256);
         acc = _mm256_add_epi32(acc, product0);
 #endif
       };
 
 #endif
-
 #if defined (USE_SSSE3)
 
       [[maybe_unused]] const __m128i kOnes128 = _mm_set1_epi16(1);
@@ -310,25 +222,21 @@ namespace Eval::NNUE::Layers {
         return _mm_cvtsi128_si32(sum) + bias;
       };
 
-      [[maybe_unused]] auto m128_haddx4 = [](__m128i sum0, __m128i sum1, __m128i sum2, __m128i sum3, __m128i bias) -> __m128i {
-        sum0 = _mm_hadd_epi32(sum0, sum1);
-        sum2 = _mm_hadd_epi32(sum2, sum3);
-
-        sum0 = _mm_hadd_epi32(sum0, sum2);
-
-        return _mm_add_epi32(sum0, bias);
-      };
-
       [[maybe_unused]] auto m128_add_dpbusd_epi32 = [=](__m128i& acc, __m128i a, __m128i b) {
         __m128i product0 = _mm_maddubs_epi16(a, b);
         product0 = _mm_madd_epi16(product0, kOnes128);
         acc = _mm_add_epi32(acc, product0);
       };
 
-      [[maybe_unused]] auto m128_add_dpbusd_epi32x2 = [=](__m128i& acc, __m128i a0, __m128i b0, __m128i a1, __m128i b1) {
+      [[maybe_unused]] auto m128_add_dpbusd_epi32x4 = [=](__m128i& acc, __m128i a0, __m128i b0, __m128i a1, __m128i b1,
+                                                                        __m128i a2, __m128i b2, __m128i a3, __m128i b3) {
         __m128i product0 = _mm_maddubs_epi16(a0, b0);
         __m128i product1 = _mm_maddubs_epi16(a1, b1);
+        __m128i product2 = _mm_maddubs_epi16(a2, b2);
+        __m128i product3 = _mm_maddubs_epi16(a3, b3);
         product0 = _mm_adds_epi16(product0, product1);
+        product2 = _mm_adds_epi16(product2, product3);
+        product0 = _mm_adds_epi16(product0, product2);
         product0 = _mm_madd_epi16(product0, kOnes128);
         acc = _mm_add_epi32(acc, product0);
       };
@@ -336,353 +244,77 @@ namespace Eval::NNUE::Layers {
 #endif
 
 #if defined (USE_AVX512)
+      using vec_t = __m512i;
+      #define vec_setzero _mm512_setzero_si512
+      #define vec_set_32 _mm512_set1_epi32
+      auto& vec_add_dpbusd_32 = m512_add_dpbusd_epi32;
+      auto& vec_add_dpbusd_32x4 = m512_add_dpbusd_epi32x4;
+      auto& vec_hadd = m512_hadd;
+#elif defined (USE_AVX2)
+      using vec_t = __m256i;
+      #define vec_setzero _mm256_setzero_si256
+      #define vec_set_32 _mm256_set1_epi32
+      auto& vec_add_dpbusd_32 = m256_add_dpbusd_epi32;
+      auto& vec_add_dpbusd_32x4 = m256_add_dpbusd_epi32x4;
+      auto& vec_hadd = m256_hadd;
+#elif defined (USE_SSSE3)
+      using vec_t = __m128i;
+      #define vec_setzero _mm_setzero_si128
+      #define vec_set_32 _mm_set1_epi32
+      auto& vec_add_dpbusd_32 = m128_add_dpbusd_epi32;
+      auto& vec_add_dpbusd_32x4 = m128_add_dpbusd_epi32x4;
+      auto& vec_hadd = m128_hadd;
+#endif
 
-      constexpr IndexType kNumChunks512 = kPaddedInputDimensions / (kSimdWidth * 2);
-      constexpr IndexType kNumChunks256 = kPaddedInputDimensions / kSimdWidth;
+#if defined (USE_SSSE3)
 
       const auto output = reinterpret_cast<OutputType*>(buffer);
+      const auto input_vector = reinterpret_cast<const vec_t*>(input);
 
-      // Since to saturate a zmm register it takes 64 bytes we
-      // cannot use AVX512 for the smaller affine transforms.
-      // Instead we fallback to a AVX2 implementation if the
-      // kInputDimensions isn't a multiple of 64.
-      // Note that this means that for example for
-      // kInputDimensions of 96 we fallback to AVX2 even though
-      // the first 64 elements could be processed with AVX512.
-      // This is caused by mixing the __m256 and __m512 variables
-      // required to better handle that case and it would
-      // require handling more cases statically not to lose performance.
-      // This should be revisited if such input dimensions are to be considered.
-      [[maybe_unused]] const auto input_vector512 = reinterpret_cast<const __m512i*>(input);
-      [[maybe_unused]] const auto input_vector256 = reinterpret_cast<const __m256i*>(input);
+      static_assert(kOutputDimensions % kOutputSimdWidth == 0 || kOutputDimensions == 1);
 
       // kOutputDimensions is either 1 or a multiple of kSimdWidth
       // because then it is also an input dimension.
-      if constexpr (kOutputDimensions % 16 == 0 && kNumChunks256 == 1)
-      {
-        for (IndexType i = 0; i < kOutputDimensions; i += 16)
-        {
-          const IndexType offset01a = (i + 0) * kPaddedInputDimensions;
-          const IndexType offset23a = (i + 2) * kPaddedInputDimensions;
-          const IndexType offset45a = (i + 4) * kPaddedInputDimensions;
-          const IndexType offset67a = (i + 6) * kPaddedInputDimensions;
-          const IndexType offset01b = (i + 8) * kPaddedInputDimensions;
-          const IndexType offset23b = (i + 10) * kPaddedInputDimensions;
-          const IndexType offset45b = (i + 12) * kPaddedInputDimensions;
-          const IndexType offset67b = (i + 14) * kPaddedInputDimensions;
-
-          const __m512i bias = *reinterpret_cast<const __m512i*>(&biases_[i]);
-          __m512i* outptr = reinterpret_cast<__m512i*>(&output[i]);
-
-          __m512i sum01a = _mm512_setzero_si512();
-          __m512i sum23a = _mm512_setzero_si512();
-          __m512i sum45a = _mm512_setzero_si512();
-          __m512i sum67a = _mm512_setzero_si512();
-          __m512i sum01b = _mm512_setzero_si512();
-          __m512i sum23b = _mm512_setzero_si512();
-          __m512i sum45b = _mm512_setzero_si512();
-          __m512i sum67b = _mm512_setzero_si512();
-
-          const auto row01a = *reinterpret_cast<const __m512i*>(&weights_[offset01a]);
-          const auto row23a = *reinterpret_cast<const __m512i*>(&weights_[offset23a]);
-          const auto row45a = *reinterpret_cast<const __m512i*>(&weights_[offset45a]);
-          const auto row67a = *reinterpret_cast<const __m512i*>(&weights_[offset67a]);
-          const auto row01b = *reinterpret_cast<const __m512i*>(&weights_[offset01b]);
-          const auto row23b = *reinterpret_cast<const __m512i*>(&weights_[offset23b]);
-          const auto row45b = *reinterpret_cast<const __m512i*>(&weights_[offset45b]);
-          const auto row67b = *reinterpret_cast<const __m512i*>(&weights_[offset67b]);
-
-          const __m256i in256 = input_vector256[0];
-          const __m512i in = _mm512_inserti64x4(_mm512_castsi256_si512(in256), in256, 1);
-
-          m512_add_dpbusd_epi32(sum01a, in, row01a);
-          m512_add_dpbusd_epi32(sum23a, in, row23a);
-          m512_add_dpbusd_epi32(sum45a, in, row45a);
-          m512_add_dpbusd_epi32(sum67a, in, row67a);
-          m512_add_dpbusd_epi32(sum01b, in, row01b);
-          m512_add_dpbusd_epi32(sum23b, in, row23b);
-          m512_add_dpbusd_epi32(sum45b, in, row45b);
-          m512_add_dpbusd_epi32(sum67b, in, row67b);
-
-          *outptr = m512_hadd256x16(
-            sum01a, sum23a, sum45a, sum67a,
-            sum01b, sum23b, sum45b, sum67b, bias);
-        }
-      }
-      else if constexpr (kOutputDimensions % 4 == 0)
-      {
-        for (IndexType i = 0; i < kOutputDimensions; i += 4)
-        {
-          const IndexType offset0 = (i + 0) * kPaddedInputDimensions;
-          const IndexType offset1 = (i + 1) * kPaddedInputDimensions;
-          const IndexType offset2 = (i + 2) * kPaddedInputDimensions;
-          const IndexType offset3 = (i + 3) * kPaddedInputDimensions;
-
-          const __m128i bias = *reinterpret_cast<const __m128i*>(&biases_[i]);
-          __m128i* outptr = reinterpret_cast<__m128i*>(&output[i]);
-
-          if constexpr (kPaddedInputDimensions % (kSimdWidth * 2) == 0)
-          {
-            __m512i sum0 = _mm512_setzero_si512();
-            __m512i sum1 = _mm512_setzero_si512();
-            __m512i sum2 = _mm512_setzero_si512();
-            __m512i sum3 = _mm512_setzero_si512();
-
-            const auto row0 = reinterpret_cast<const __m512i*>(&weights_[offset0]);
-            const auto row1 = reinterpret_cast<const __m512i*>(&weights_[offset1]);
-            const auto row2 = reinterpret_cast<const __m512i*>(&weights_[offset2]);
-            const auto row3 = reinterpret_cast<const __m512i*>(&weights_[offset3]);
-
-            int j = 0;
-            if (!canSaturate16x4[i / 4])
-            {
-                for (; j < (int)kNumChunks512 - 1; j += 2)
-                {
-                    const __m512i in0 = input_vector512[j];
-                    const __m512i in1 = input_vector512[j + 1];
-
-                    m512_add_dpbusd_epi32x2(sum0, in0, row0[j], in1, row0[j + 1]);
-                    m512_add_dpbusd_epi32x2(sum1, in0, row1[j], in1, row1[j + 1]);
-                    m512_add_dpbusd_epi32x2(sum2, in0, row2[j], in1, row2[j + 1]);
-                    m512_add_dpbusd_epi32x2(sum3, in0, row3[j], in1, row3[j + 1]);
-                }
-            }
-            for (; j < (int)kNumChunks512; ++j)
-            {
-              const __m512i in = input_vector512[j];
-
-              m512_add_dpbusd_epi32(sum0, in, row0[j]);
-              m512_add_dpbusd_epi32(sum1, in, row1[j]);
-              m512_add_dpbusd_epi32(sum2, in, row2[j]);
-              m512_add_dpbusd_epi32(sum3, in, row3[j]);
-            }
-
-            *outptr = m512_haddx4(sum0, sum1, sum2, sum3, bias);
-          }
-          else
-          {
-            __m256i sum0 = _mm256_setzero_si256();
-            __m256i sum1 = _mm256_setzero_si256();
-            __m256i sum2 = _mm256_setzero_si256();
-            __m256i sum3 = _mm256_setzero_si256();
-
-            const auto row0 = reinterpret_cast<const __m256i*>(&weights_[offset0]);
-            const auto row1 = reinterpret_cast<const __m256i*>(&weights_[offset1]);
-            const auto row2 = reinterpret_cast<const __m256i*>(&weights_[offset2]);
-            const auto row3 = reinterpret_cast<const __m256i*>(&weights_[offset3]);
-
-            for (IndexType j = 0; j < kNumChunks256; ++j)
-            {
-              const __m256i in = input_vector256[j];
-
-              m256_add_dpbusd_epi32(sum0, in, row0[j]);
-              m256_add_dpbusd_epi32(sum1, in, row1[j]);
-              m256_add_dpbusd_epi32(sum2, in, row2[j]);
-              m256_add_dpbusd_epi32(sum3, in, row3[j]);
-            }
-
-            *outptr = m256_haddx4(sum0, sum1, sum2, sum3, bias);
-          }
-        }
-      }
-      else if constexpr (kOutputDimensions == 1)
+      if constexpr (kOutputDimensions % kOutputSimdWidth == 0)
       {
-        if constexpr (kPaddedInputDimensions % (kSimdWidth * 2) == 0)
-        {
-          __m512i sum0 = _mm512_setzero_si512();
+          constexpr IndexType kNumChunks = kPaddedInputDimensions / 4;
 
-          const auto row0 = reinterpret_cast<const __m512i*>(&weights_[0]);
+          const auto input32 = reinterpret_cast<const std::int32_t*>(input);
+          vec_t* outptr = reinterpret_cast<vec_t*>(output);
+          std::memcpy(output, biases_, kOutputDimensions * sizeof(OutputType));
 
-          for (IndexType j = 0; j < kNumChunks512; ++j)
+          for (int i = 0; i < (int)kNumChunks - 3; i += 4)
           {
-            const __m512i in = input_vector512[j];
-
-            m512_add_dpbusd_epi32(sum0, in, row0[j]);
+              const vec_t in0 = vec_set_32(input32[i + 0]);
+              const vec_t in1 = vec_set_32(input32[i + 1]);
+              const vec_t in2 = vec_set_32(input32[i + 2]);
+              const vec_t in3 = vec_set_32(input32[i + 3]);
+              const auto col0 = reinterpret_cast<const vec_t*>(&weights_[(i + 0) * kOutputDimensions * 4]);
+              const auto col1 = reinterpret_cast<const vec_t*>(&weights_[(i + 1) * kOutputDimensions * 4]);
+              const auto col2 = reinterpret_cast<const vec_t*>(&weights_[(i + 2) * kOutputDimensions * 4]);
+              const auto col3 = reinterpret_cast<const vec_t*>(&weights_[(i + 3) * kOutputDimensions * 4]);
+              for (int j = 0; j * kOutputSimdWidth < kOutputDimensions; ++j)
+                  vec_add_dpbusd_32x4(outptr[j], in0, col0[j], in1, col1[j], in2, col2[j], in3, col3[j]);
           }
-
-          output[0] = m512_hadd(sum0, biases_[0]);
-        }
-        else
-        {
-          __m256i sum0 = _mm256_setzero_si256();
-
-          const auto row0 = reinterpret_cast<const __m256i*>(&weights_[0]);
-
-          for (IndexType j = 0; j < kNumChunks256; ++j)
-          {
-            const __m256i in = input_vector256[j];
-
-            m256_add_dpbusd_epi32(sum0, in, row0[j]);
-          }
-
-          output[0] = m256_hadd(sum0, biases_[0]);
-        }
-      }
-      else
-      {
-        // This case can never happen because kOutputDimensions
-        // is always 1 or a multiple of kSimdWidth.
-        assert(false);
-      }
-
-#elif defined (USE_AVX2)
-
-      constexpr IndexType kNumChunks = kPaddedInputDimensions / kSimdWidth;
-
-      const auto output = reinterpret_cast<OutputType*>(buffer);
-      const auto input_vector = reinterpret_cast<const __m256i*>(input);
-
-      // kOutputDimensions is either 1 or a multiple of kSimdWidth
-      // because then it is also an input dimension.
-      if constexpr (kOutputDimensions % 4 == 0)
-      {
-        for (IndexType i = 0; i < kOutputDimensions; i += 4)
-        {
-          const IndexType offset0 = (i + 0) * kPaddedInputDimensions;
-          const IndexType offset1 = (i + 1) * kPaddedInputDimensions;
-          const IndexType offset2 = (i + 2) * kPaddedInputDimensions;
-          const IndexType offset3 = (i + 3) * kPaddedInputDimensions;
-
-          const __m128i bias = *reinterpret_cast<const __m128i*>(&biases_[i]);
-          __m128i* outptr = reinterpret_cast<__m128i*>(&output[i]);
-
-          __m256i sum0 = _mm256_setzero_si256();
-          __m256i sum1 = _mm256_setzero_si256();
-          __m256i sum2 = _mm256_setzero_si256();
-          __m256i sum3 = _mm256_setzero_si256();
-
-          const auto row0 = reinterpret_cast<const __m256i*>(&weights_[offset0]);
-          const auto row1 = reinterpret_cast<const __m256i*>(&weights_[offset1]);
-          const auto row2 = reinterpret_cast<const __m256i*>(&weights_[offset2]);
-          const auto row3 = reinterpret_cast<const __m256i*>(&weights_[offset3]);
-
-          int j = 0;
-          if (!canSaturate16x4[i / 4])
-          {
-              for (; j < (int)kNumChunks - 1; j += 2)
-              {
-                  const __m256i in0 = input_vector[j];
-                  const __m256i in1 = input_vector[j + 1];
-
-                  m256_add_dpbusd_epi32x2(sum0, in0, row0[j], in1, row0[j + 1]);
-                  m256_add_dpbusd_epi32x2(sum1, in0, row1[j], in1, row1[j + 1]);
-                  m256_add_dpbusd_epi32x2(sum2, in0, row2[j], in1, row2[j + 1]);
-                  m256_add_dpbusd_epi32x2(sum3, in0, row3[j], in1, row3[j + 1]);
-              }
-          }
-          for (; j < (int)kNumChunks; ++j)
-          {
-                const __m256i in = input_vector[j];
-
-                m256_add_dpbusd_epi32(sum0, in, row0[j]);
-                m256_add_dpbusd_epi32(sum1, in, row1[j]);
-                m256_add_dpbusd_epi32(sum2, in, row2[j]);
-                m256_add_dpbusd_epi32(sum3, in, row3[j]);
-          }
-
-          *outptr = m256_haddx4(sum0, sum1, sum2, sum3, bias);
-        }
+          for (int i = 0; i < canSaturate16.count; ++i)
+              output[canSaturate16.ids[i].out] += input[canSaturate16.ids[i].in] * canSaturate16.ids[i].w;
       }
       else if constexpr (kOutputDimensions == 1)
       {
-        __m256i sum0 = _mm256_setzero_si256();
-
-        const auto row0 = reinterpret_cast<const __m256i*>(&weights_[0]);
-
-        for (IndexType j = 0; j < kNumChunks; ++j)
-        {
-            const __m256i in = input_vector[j];
-
-            m256_add_dpbusd_epi32(sum0, in, row0[j]);
-        }
-
-        output[0] = m256_hadd(sum0, biases_[0]);
-      }
-      else
-      {
-        // This case can never happen because kOutputDimensions
-        // is always 1 or a multiple of kSimdWidth.
-        assert(false);
-      }
+          constexpr IndexType kNumChunks = kPaddedInputDimensions / kSimdWidth;
 
-#elif defined (USE_SSSE3)
+          vec_t sum0 = vec_setzero();
 
-      constexpr IndexType kNumChunks = kPaddedInputDimensions / kSimdWidth;
+          const auto row0 = reinterpret_cast<const vec_t*>(&weights_[0]);
 
-      auto output = reinterpret_cast<OutputType*>(buffer);
-      const auto input_vector = reinterpret_cast<const __m128i*>(input);
-
-      // kOutputDimensions is either 1 or a multiple of kSimdWidth
-      // because then it is also an input dimension.
-      if constexpr (kOutputDimensions % 4 == 0)
-      {
-        for (IndexType i = 0; i < kOutputDimensions; i += 4)
-        {
-          const IndexType offset0 = (i + 0) * kPaddedInputDimensions;
-          const IndexType offset1 = (i + 1) * kPaddedInputDimensions;
-          const IndexType offset2 = (i + 2) * kPaddedInputDimensions;
-          const IndexType offset3 = (i + 3) * kPaddedInputDimensions;
-
-          const __m128i bias = *reinterpret_cast<const __m128i*>(&biases_[i]);
-          __m128i* outptr = reinterpret_cast<__m128i*>(&output[i]);
-
-          __m128i sum0 = _mm_setzero_si128();
-          __m128i sum1 = _mm_setzero_si128();
-          __m128i sum2 = _mm_setzero_si128();
-          __m128i sum3 = _mm_setzero_si128();
-
-          const auto row0 = reinterpret_cast<const __m128i*>(&weights_[offset0]);
-          const auto row1 = reinterpret_cast<const __m128i*>(&weights_[offset1]);
-          const auto row2 = reinterpret_cast<const __m128i*>(&weights_[offset2]);
-          const auto row3 = reinterpret_cast<const __m128i*>(&weights_[offset3]);
-
-          int j = 0;
-          if (!canSaturate16x4[i / 4])
-          {
-              for (; j < (int)kNumChunks - 1; j += 2)
-              {
-                  const __m128i in0 = input_vector[j];
-                  const __m128i in1 = input_vector[j + 1];
-
-                  m128_add_dpbusd_epi32x2(sum0, in0, row0[j], in1, row0[j + 1]);
-                  m128_add_dpbusd_epi32x2(sum1, in0, row1[j], in1, row1[j + 1]);
-                  m128_add_dpbusd_epi32x2(sum2, in0, row2[j], in1, row2[j + 1]);
-                  m128_add_dpbusd_epi32x2(sum3, in0, row3[j], in1, row3[j + 1]);
-              }
-          }
-          for (; j < (int)kNumChunks; ++j)
+          for (int j = 0; j < (int)kNumChunks; ++j)
           {
-              const __m128i in = input_vector[j];
+              const vec_t in = input_vector[j];
 
-              m128_add_dpbusd_epi32(sum0, in, row0[j]);
-              m128_add_dpbusd_epi32(sum1, in, row1[j]);
-              m128_add_dpbusd_epi32(sum2, in, row2[j]);
-              m128_add_dpbusd_epi32(sum3, in, row3[j]);
+              vec_add_dpbusd_32(sum0, in, row0[j]);
           }
 
-          *outptr = m128_haddx4(sum0, sum1, sum2, sum3, bias);
-        }
-      }
-      else if constexpr (kOutputDimensions == 1)
-      {
-        __m128i sum0 = _mm_setzero_si128();
-
-        const auto row0 = reinterpret_cast<const __m128i*>(&weights_[0]);
-
-        for (int j = 0; j < (int)kNumChunks; ++j)
-        {
-          const __m128i in = input_vector[j];
-
-          m128_add_dpbusd_epi32(sum0, in, row0[j]);
-        }
-
-        output[0] = m128_hadd(sum0, biases_[0]);
-      }
-      else
-      {
-        // This case can never happen because kOutputDimensions
-        // is always 1 or a multiple of kSimdWidth.
-        assert(false);
+          output[0] = vec_hadd(sum0, biases_[0]);
       }
 
 #else
@@ -693,11 +325,7 @@ namespace Eval::NNUE::Layers {
 
 #if defined(USE_SSE2)
       constexpr IndexType kNumChunks = kPaddedInputDimensions / kSimdWidth;
-#ifndef USE_SSSE3
       const __m128i kZeros = _mm_setzero_si128();
-#else
-      const __m128i kOnes = _mm_set1_epi16(1);
-#endif
       const auto input_vector = reinterpret_cast<const __m128i*>(input);
 
 #elif defined(USE_MMX)
@@ -792,10 +420,23 @@ namespace Eval::NNUE::Layers {
 
     alignas(kCacheLineSize) BiasType biases_[kOutputDimensions];
     alignas(kCacheLineSize) WeightType weights_[kOutputDimensions * kPaddedInputDimensions];
-    union {
-        uint32_t canSaturate16x4[(kOutputDimensions + 3) / 4];
-        bool canSaturate16[kOutputDimensions];
-    };
+#if defined (USE_SSSE3)
+    struct CanSaturate {
+        int count;
+        struct Entry {
+            uint16_t out;
+            uint16_t in;
+            int8_t w;
+        } ids[kPaddedInputDimensions * kOutputDimensions * 3 / 4];
+
+        void add(int i, int j, int8_t w) {
+            ids[count].out = i;
+            ids[count].in = j;
+            ids[count].w = w;
+            ++count;
+        }
+    } canSaturate16;
+#endif
   };
 
 }  // namespace Eval::NNUE::Layers