]> git.sesse.net Git - stockfish/blobdiff - src/nnue/layers/affine_transform.h
AVX512, AVX2 and SSSE3 speedups
[stockfish] / src / nnue / layers / affine_transform.h
index 0e0515f932a0773cc82f72c1620bc0a1afe5e5eb..a715ca85090b8d5c3d530152768810fdd2c94da5 100644 (file)
@@ -66,6 +66,53 @@ namespace Eval::NNUE::Layers {
         biases_[i] = read_little_endian<BiasType>(stream);
       for (std::size_t i = 0; i < kOutputDimensions * kPaddedInputDimensions; ++i)
         weights_[i] = read_little_endian<WeightType>(stream);
+
+#if defined (USE_SSSE3)
+      // Determine if quadruplets of weight and input products can be summed using 16bits
+      // without saturation. We assume worst case combinations of 0 and 127 for all inputs.
+      if (!stream.fail())
+      {
+          auto can_saturate = [](const WeightType* w, int idx[4]) {
+              int pSum = 0, nSum = 0;
+              for (int p = 0; p < 4; ++p)
+                  if (w[idx[p]] > 0)
+                      pSum += w[idx[p]];
+                  else
+                      nSum += w[idx[p]];
+
+              return pSum > 258 || nSum < -258;
+          };
+
+          for (IndexType i = 0; i < kOutputDimensions; ++i)
+          {
+              canSaturate16[i] = false;
+              const WeightType* w = &weights_[i * kPaddedInputDimensions];
+#if defined (USE_AVX512)
+              for (IndexType j = 0; j < (kPaddedInputDimensions & ~127) && !canSaturate16[i]; j += 128)
+                  for (int k = 0; k < 64 && !canSaturate16[i]; k += 2)
+                  {
+                      int spacing[4] = { 0, 1, 64, 65 };
+                      canSaturate16[i] = can_saturate(&w[j + k], spacing);
+                  }
+#elif defined (USE_AVX2)
+              for (IndexType j = 0; j < (kPaddedInputDimensions & ~63) && !canSaturate16[i]; j += 64)
+                  for (int k = 0; k < 32 && !canSaturate16[i]; k += 2)
+                  {
+                      int spacing[4] = { 0, 1, 32, 33 };
+                      canSaturate16[i] = can_saturate(&w[j + k], spacing);
+                  }
+#elif defined (USE_SSSE3)
+              for (IndexType j = 0; j < (kPaddedInputDimensions & ~31) && !canSaturate16[i]; j += 32)
+                  for (int k = 0; k < 16 && !canSaturate16[i]; k += 2)
+                  {
+                      int spacing[4] = { 0, 1, 16, 17 };
+                      canSaturate16[i] = can_saturate(&w[j + k], spacing);
+                  }
+#endif
+          }
+      }
+#endif
+
       return !stream.fail();
     }
 
@@ -181,13 +228,26 @@ namespace Eval::NNUE::Layers {
         return _mm512_add_epi32(_mm512_permutexvar_epi32(indices, x), bias);
       };
 
-#if defined (USE_VNNI)
       [[maybe_unused]] auto m512_add_dpbusd_epi32 = [=](__m512i& acc, __m512i a, __m512i b) {
+#if defined (USE_VNNI)
         acc = _mm512_dpbusd_epi32(acc, a, b);
 #else
-      [[maybe_unused]] auto m512_dpbusd_epi32 = [=](__m512i a, __m512i b) -> __m512i {
         __m512i product0 = _mm512_maddubs_epi16(a, b);
-        return _mm512_madd_epi16(product0, kOnes512);
+        product0 = _mm512_madd_epi16(product0, kOnes512);
+        acc = _mm512_add_epi32(acc, product0);
+#endif
+      };
+
+      [[maybe_unused]] auto m512_add_dpbusd_epi32x2 = [=](__m512i& acc, __m512i a0, __m512i b0, __m512i a1, __m512i b1) {
+#if defined (USE_VNNI)
+        acc = _mm512_dpbusd_epi32(acc, a0, b0);
+        acc = _mm512_dpbusd_epi32(acc, a1, b1);
+#else
+        __m512i product0 = _mm512_maddubs_epi16(a0, b0);
+        __m512i product1 = _mm512_maddubs_epi16(a1, b1);
+        product0 = _mm512_adds_epi16(product0, product1);
+        product0 = _mm512_madd_epi16(product0, kOnes512);
+        acc = _mm512_add_epi32(acc, product0);
 #endif
       };
 
@@ -214,13 +274,27 @@ namespace Eval::NNUE::Layers {
 
         return _mm_add_epi32(_mm_add_epi32(sum128lo, sum128hi), bias);
       };
-#if defined (USE_VNNI)
+
       [[maybe_unused]] auto m256_add_dpbusd_epi32 = [=](__m256i& acc, __m256i a, __m256i b) {
+#if defined (USE_VNNI)
         acc = _mm256_dpbusd_epi32(acc, a, b);
 #else
-      [[maybe_unused]] auto m256_dpbusd_epi32 = [=](__m256i a, __m256i b) -> __m256i {
         __m256i product0 = _mm256_maddubs_epi16(a, b);
-        return _mm256_madd_epi16(product0, kOnes256);
+        product0 = _mm256_madd_epi16(product0, kOnes256);
+        acc = _mm256_add_epi32(acc, product0);
+#endif
+      };
+
+      [[maybe_unused]] auto m256_add_dpbusd_epi32x2 = [=](__m256i& acc, __m256i a0, __m256i b0, __m256i a1, __m256i b1) {
+#if defined (USE_VNNI)
+        acc = _mm256_dpbusd_epi32(acc, a0, b0);
+        acc = _mm256_dpbusd_epi32(acc, a1, b1);
+#else
+        __m256i product0 = _mm256_maddubs_epi16(a0, b0);
+        __m256i product1 = _mm256_maddubs_epi16(a1, b1);
+        product0 = _mm256_adds_epi16(product0, product1);
+        product0 = _mm256_madd_epi16(product0, kOnes256);
+        acc = _mm256_add_epi32(acc, product0);
 #endif
       };
 
@@ -245,9 +319,18 @@ namespace Eval::NNUE::Layers {
         return _mm_add_epi32(sum0, bias);
       };
 
-      [[maybe_unused]] auto m128_dpbusd_epi32 = [=](__m128i a, __m128i b) -> __m128i {
+      [[maybe_unused]] auto m128_add_dpbusd_epi32 = [=](__m128i& acc, __m128i a, __m128i b) {
         __m128i product0 = _mm_maddubs_epi16(a, b);
-        return _mm_madd_epi16(product0, kOnes128);
+        product0 = _mm_madd_epi16(product0, kOnes128);
+        acc = _mm_add_epi32(acc, product0);
+      };
+
+      [[maybe_unused]] auto m128_add_dpbusd_epi32x2 = [=](__m128i& acc, __m128i a0, __m128i b0, __m128i a1, __m128i b1) {
+        __m128i product0 = _mm_maddubs_epi16(a0, b0);
+        __m128i product1 = _mm_maddubs_epi16(a1, b1);
+        product0 = _mm_adds_epi16(product0, product1);
+        product0 = _mm_madd_epi16(product0, kOnes128);
+        acc = _mm_add_epi32(acc, product0);
       };
 
 #endif
@@ -291,6 +374,15 @@ namespace Eval::NNUE::Layers {
           const __m512i bias = *reinterpret_cast<const __m512i*>(&biases_[i]);
           __m512i* outptr = reinterpret_cast<__m512i*>(&output[i]);
 
+          __m512i sum01a = _mm512_setzero_si512();
+          __m512i sum23a = _mm512_setzero_si512();
+          __m512i sum45a = _mm512_setzero_si512();
+          __m512i sum67a = _mm512_setzero_si512();
+          __m512i sum01b = _mm512_setzero_si512();
+          __m512i sum23b = _mm512_setzero_si512();
+          __m512i sum45b = _mm512_setzero_si512();
+          __m512i sum67b = _mm512_setzero_si512();
+
           const auto row01a = *reinterpret_cast<const __m512i*>(&weights_[offset01a]);
           const auto row23a = *reinterpret_cast<const __m512i*>(&weights_[offset23a]);
           const auto row45a = *reinterpret_cast<const __m512i*>(&weights_[offset45a]);
@@ -303,16 +395,6 @@ namespace Eval::NNUE::Layers {
           const __m256i in256 = input_vector256[0];
           const __m512i in = _mm512_inserti64x4(_mm512_castsi256_si512(in256), in256, 1);
 
-#if defined (USE_VNNI)
-          __m512i sum01a = _mm512_setzero_si512();
-          __m512i sum23a = _mm512_setzero_si512();
-          __m512i sum45a = _mm512_setzero_si512();
-          __m512i sum67a = _mm512_setzero_si512();
-          __m512i sum01b = _mm512_setzero_si512();
-          __m512i sum23b = _mm512_setzero_si512();
-          __m512i sum45b = _mm512_setzero_si512();
-          __m512i sum67b = _mm512_setzero_si512();
-
           m512_add_dpbusd_epi32(sum01a, in, row01a);
           m512_add_dpbusd_epi32(sum23a, in, row23a);
           m512_add_dpbusd_epi32(sum45a, in, row45a);
@@ -321,16 +403,6 @@ namespace Eval::NNUE::Layers {
           m512_add_dpbusd_epi32(sum23b, in, row23b);
           m512_add_dpbusd_epi32(sum45b, in, row45b);
           m512_add_dpbusd_epi32(sum67b, in, row67b);
-#else
-          __m512i sum01a = m512_dpbusd_epi32(in, row01a);
-          __m512i sum23a = m512_dpbusd_epi32(in, row23a);
-          __m512i sum45a = m512_dpbusd_epi32(in, row45a);
-          __m512i sum67a = m512_dpbusd_epi32(in, row67a);
-          __m512i sum01b = m512_dpbusd_epi32(in, row01b);
-          __m512i sum23b = m512_dpbusd_epi32(in, row23b);
-          __m512i sum45b = m512_dpbusd_epi32(in, row45b);
-          __m512i sum67b = m512_dpbusd_epi32(in, row67b);
-#endif
 
           *outptr = m512_hadd256x16(
             sum01a, sum23a, sum45a, sum67a,
@@ -351,80 +423,62 @@ namespace Eval::NNUE::Layers {
 
           if constexpr (kPaddedInputDimensions % (kSimdWidth * 2) == 0)
           {
-            const auto row0 = reinterpret_cast<const __m512i*>(&weights_[offset0]);
-            const auto row1 = reinterpret_cast<const __m512i*>(&weights_[offset1]);
-            const auto row2 = reinterpret_cast<const __m512i*>(&weights_[offset2]);
-            const auto row3 = reinterpret_cast<const __m512i*>(&weights_[offset3]);
-
-#if defined (USE_VNNI)
             __m512i sum0 = _mm512_setzero_si512();
             __m512i sum1 = _mm512_setzero_si512();
             __m512i sum2 = _mm512_setzero_si512();
             __m512i sum3 = _mm512_setzero_si512();
-            const IndexType kStart = 0;
-#else
-            __m512i sum0 = m512_dpbusd_epi32(input_vector512[0], row0[0]);
-            __m512i sum1 = m512_dpbusd_epi32(input_vector512[0], row1[0]);
-            __m512i sum2 = m512_dpbusd_epi32(input_vector512[0], row2[0]);
-            __m512i sum3 = m512_dpbusd_epi32(input_vector512[0], row3[0]);
-            const IndexType kStart = 1;
-#endif
 
-            for (IndexType j = kStart; j < kNumChunks512; ++j)
+            const auto row0 = reinterpret_cast<const __m512i*>(&weights_[offset0]);
+            const auto row1 = reinterpret_cast<const __m512i*>(&weights_[offset1]);
+            const auto row2 = reinterpret_cast<const __m512i*>(&weights_[offset2]);
+            const auto row3 = reinterpret_cast<const __m512i*>(&weights_[offset3]);
+
+            int j = 0;
+            if (!canSaturate16x4[i / 4])
+            {
+                for (; j < (int)kNumChunks512 - 1; j += 2)
+                {
+                    const __m512i in0 = input_vector512[j];
+                    const __m512i in1 = input_vector512[j + 1];
+
+                    m512_add_dpbusd_epi32x2(sum0, in0, row0[j], in1, row0[j + 1]);
+                    m512_add_dpbusd_epi32x2(sum1, in0, row1[j], in1, row1[j + 1]);
+                    m512_add_dpbusd_epi32x2(sum2, in0, row2[j], in1, row2[j + 1]);
+                    m512_add_dpbusd_epi32x2(sum3, in0, row3[j], in1, row3[j + 1]);
+                }
+            }
+            for (; j < (int)kNumChunks512; ++j)
             {
               const __m512i in = input_vector512[j];
 
-#if defined (USE_VNNI)
               m512_add_dpbusd_epi32(sum0, in, row0[j]);
               m512_add_dpbusd_epi32(sum1, in, row1[j]);
               m512_add_dpbusd_epi32(sum2, in, row2[j]);
               m512_add_dpbusd_epi32(sum3, in, row3[j]);
-#else
-              sum0 = _mm512_add_epi32(sum0, m512_dpbusd_epi32(in, row0[j]));
-              sum1 = _mm512_add_epi32(sum1, m512_dpbusd_epi32(in, row1[j]));
-              sum2 = _mm512_add_epi32(sum2, m512_dpbusd_epi32(in, row2[j]));
-              sum3 = _mm512_add_epi32(sum3, m512_dpbusd_epi32(in, row3[j]));
-#endif
             }
 
             *outptr = m512_haddx4(sum0, sum1, sum2, sum3, bias);
           }
           else
           {
-            const auto row0 = reinterpret_cast<const __m256i*>(&weights_[offset0]);
-            const auto row1 = reinterpret_cast<const __m256i*>(&weights_[offset1]);
-            const auto row2 = reinterpret_cast<const __m256i*>(&weights_[offset2]);
-            const auto row3 = reinterpret_cast<const __m256i*>(&weights_[offset3]);
-
-#if defined (USE_VNNI)
             __m256i sum0 = _mm256_setzero_si256();
             __m256i sum1 = _mm256_setzero_si256();
             __m256i sum2 = _mm256_setzero_si256();
             __m256i sum3 = _mm256_setzero_si256();
-            const IndexType kStart = 0;
-#else
-            __m256i sum0 = m256_dpbusd_epi32(input_vector256[0], row0[0]);
-            __m256i sum1 = m256_dpbusd_epi32(input_vector256[0], row1[0]);
-            __m256i sum2 = m256_dpbusd_epi32(input_vector256[0], row2[0]);
-            __m256i sum3 = m256_dpbusd_epi32(input_vector256[0], row3[0]);
-            const IndexType kStart = 1;
-#endif
 
-            for (IndexType j = kStart; j < kNumChunks256; ++j)
+            const auto row0 = reinterpret_cast<const __m256i*>(&weights_[offset0]);
+            const auto row1 = reinterpret_cast<const __m256i*>(&weights_[offset1]);
+            const auto row2 = reinterpret_cast<const __m256i*>(&weights_[offset2]);
+            const auto row3 = reinterpret_cast<const __m256i*>(&weights_[offset3]);
+
+            for (IndexType j = 0; j < kNumChunks256; ++j)
             {
               const __m256i in = input_vector256[j];
 
-#if defined (USE_VNNI)
               m256_add_dpbusd_epi32(sum0, in, row0[j]);
               m256_add_dpbusd_epi32(sum1, in, row1[j]);
               m256_add_dpbusd_epi32(sum2, in, row2[j]);
               m256_add_dpbusd_epi32(sum3, in, row3[j]);
-#else
-              sum0 = _mm256_add_epi32(sum0, m256_dpbusd_epi32(in, row0[j]));
-              sum1 = _mm256_add_epi32(sum1, m256_dpbusd_epi32(in, row1[j]));
-              sum2 = _mm256_add_epi32(sum2, m256_dpbusd_epi32(in, row2[j]));
-              sum3 = _mm256_add_epi32(sum3, m256_dpbusd_epi32(in, row3[j]));
-#endif
             }
 
             *outptr = m256_haddx4(sum0, sum1, sum2, sum3, bias);
@@ -435,50 +489,30 @@ namespace Eval::NNUE::Layers {
       {
         if constexpr (kPaddedInputDimensions % (kSimdWidth * 2) == 0)
         {
-          const auto row0 = reinterpret_cast<const __m512i*>(&weights_[0]);
-
-#if defined (USE_VNNI)
           __m512i sum0 = _mm512_setzero_si512();
-          const IndexType kStart = 0;
-#else
-          __m512i sum0 = m512_dpbusd_epi32(input_vector512[0], row0[0]);
-          const IndexType kStart = 1;
-#endif
 
-          for (IndexType j = kStart; j < kNumChunks512; ++j)
+          const auto row0 = reinterpret_cast<const __m512i*>(&weights_[0]);
+
+          for (IndexType j = 0; j < kNumChunks512; ++j)
           {
             const __m512i in = input_vector512[j];
 
-#if defined (USE_VNNI)
             m512_add_dpbusd_epi32(sum0, in, row0[j]);
-#else
-            sum0 = _mm512_add_epi32(sum0, m512_dpbusd_epi32(in, row0[j]));
-#endif
           }
 
           output[0] = m512_hadd(sum0, biases_[0]);
         }
         else
         {
-          const auto row0 = reinterpret_cast<const __m256i*>(&weights_[0]);
-
-#if defined (USE_VNNI)
           __m256i sum0 = _mm256_setzero_si256();
-          const IndexType kStart = 0;
-#else
-          __m256i sum0 = m256_dpbusd_epi32(input_vector256[0], row0[0]);
-          const IndexType kStart = 1;
-#endif
 
-          for (IndexType j = kStart; j < kNumChunks256; ++j)
+          const auto row0 = reinterpret_cast<const __m256i*>(&weights_[0]);
+
+          for (IndexType j = 0; j < kNumChunks256; ++j)
           {
             const __m256i in = input_vector256[j];
 
-#if defined (USE_VNNI)
             m256_add_dpbusd_epi32(sum0, in, row0[j]);
-#else
-            sum0 = _mm256_add_epi32(sum0, m256_dpbusd_epi32(in, row0[j]));
-#endif
           }
 
           output[0] = m256_hadd(sum0, biases_[0]);
@@ -512,40 +546,38 @@ namespace Eval::NNUE::Layers {
           const __m128i bias = *reinterpret_cast<const __m128i*>(&biases_[i]);
           __m128i* outptr = reinterpret_cast<__m128i*>(&output[i]);
 
-          const auto row0 = reinterpret_cast<const __m256i*>(&weights_[offset0]);
-          const auto row1 = reinterpret_cast<const __m256i*>(&weights_[offset1]);
-          const auto row2 = reinterpret_cast<const __m256i*>(&weights_[offset2]);
-          const auto row3 = reinterpret_cast<const __m256i*>(&weights_[offset3]);
-
-#if defined (USE_VNNI)
           __m256i sum0 = _mm256_setzero_si256();
           __m256i sum1 = _mm256_setzero_si256();
           __m256i sum2 = _mm256_setzero_si256();
           __m256i sum3 = _mm256_setzero_si256();
-          const IndexType kStart = 0;
-#else
-          __m256i sum0 = m256_dpbusd_epi32(input_vector[0], row0[0]);
-          __m256i sum1 = m256_dpbusd_epi32(input_vector[0], row1[0]);
-          __m256i sum2 = m256_dpbusd_epi32(input_vector[0], row2[0]);
-          __m256i sum3 = m256_dpbusd_epi32(input_vector[0], row3[0]);
-          const IndexType kStart = 1;
-#endif
 
-          for (IndexType j = kStart; j < kNumChunks; ++j)
+          const auto row0 = reinterpret_cast<const __m256i*>(&weights_[offset0]);
+          const auto row1 = reinterpret_cast<const __m256i*>(&weights_[offset1]);
+          const auto row2 = reinterpret_cast<const __m256i*>(&weights_[offset2]);
+          const auto row3 = reinterpret_cast<const __m256i*>(&weights_[offset3]);
+
+          int j = 0;
+          if (!canSaturate16x4[i / 4])
           {
-            const __m256i in = input_vector[j];
+              for (; j < (int)kNumChunks - 1; j += 2)
+              {
+                  const __m256i in0 = input_vector[j];
+                  const __m256i in1 = input_vector[j + 1];
+
+                  m256_add_dpbusd_epi32x2(sum0, in0, row0[j], in1, row0[j + 1]);
+                  m256_add_dpbusd_epi32x2(sum1, in0, row1[j], in1, row1[j + 1]);
+                  m256_add_dpbusd_epi32x2(sum2, in0, row2[j], in1, row2[j + 1]);
+                  m256_add_dpbusd_epi32x2(sum3, in0, row3[j], in1, row3[j + 1]);
+              }
+          }
+          for (; j < (int)kNumChunks; ++j)
+          {
+                const __m256i in = input_vector[j];
 
-#if defined (USE_VNNI)
-            m256_add_dpbusd_epi32(sum0, in, row0[j]);
-            m256_add_dpbusd_epi32(sum1, in, row1[j]);
-            m256_add_dpbusd_epi32(sum2, in, row2[j]);
-            m256_add_dpbusd_epi32(sum3, in, row3[j]);
-#else
-            sum0 = _mm256_add_epi32(sum0, m256_dpbusd_epi32(in, row0[j]));
-            sum1 = _mm256_add_epi32(sum1, m256_dpbusd_epi32(in, row1[j]));
-            sum2 = _mm256_add_epi32(sum2, m256_dpbusd_epi32(in, row2[j]));
-            sum3 = _mm256_add_epi32(sum3, m256_dpbusd_epi32(in, row3[j]));
-#endif
+                m256_add_dpbusd_epi32(sum0, in, row0[j]);
+                m256_add_dpbusd_epi32(sum1, in, row1[j]);
+                m256_add_dpbusd_epi32(sum2, in, row2[j]);
+                m256_add_dpbusd_epi32(sum3, in, row3[j]);
           }
 
           *outptr = m256_haddx4(sum0, sum1, sum2, sum3, bias);
@@ -553,25 +585,15 @@ namespace Eval::NNUE::Layers {
       }
       else if constexpr (kOutputDimensions == 1)
       {
-        const auto row0 = reinterpret_cast<const __m256i*>(&weights_[0]);
-
-#if defined (USE_VNNI)
         __m256i sum0 = _mm256_setzero_si256();
-        const IndexType kStart = 0;
-#else
-        __m256i sum0 = m256_dpbusd_epi32(input_vector[0], row0[0]);
-        const IndexType kStart = 1;
-#endif
 
-        for (IndexType j = kStart; j < kNumChunks; ++j)
+        const auto row0 = reinterpret_cast<const __m256i*>(&weights_[0]);
+
+        for (IndexType j = 0; j < kNumChunks; ++j)
         {
-          const __m256i in = input_vector[j];
+            const __m256i in = input_vector[j];
 
-#if defined (USE_VNNI)
-          m256_add_dpbusd_epi32(sum0, in, row0[j]);
-#else
-          sum0 = _mm256_add_epi32(sum0, m256_dpbusd_epi32(in, row0[j]));
-#endif
+            m256_add_dpbusd_epi32(sum0, in, row0[j]);
         }
 
         output[0] = m256_hadd(sum0, biases_[0]);
@@ -604,24 +626,38 @@ namespace Eval::NNUE::Layers {
           const __m128i bias = *reinterpret_cast<const __m128i*>(&biases_[i]);
           __m128i* outptr = reinterpret_cast<__m128i*>(&output[i]);
 
+          __m128i sum0 = _mm_setzero_si128();
+          __m128i sum1 = _mm_setzero_si128();
+          __m128i sum2 = _mm_setzero_si128();
+          __m128i sum3 = _mm_setzero_si128();
+
           const auto row0 = reinterpret_cast<const __m128i*>(&weights_[offset0]);
           const auto row1 = reinterpret_cast<const __m128i*>(&weights_[offset1]);
           const auto row2 = reinterpret_cast<const __m128i*>(&weights_[offset2]);
           const auto row3 = reinterpret_cast<const __m128i*>(&weights_[offset3]);
 
-          __m128i sum0 = m128_dpbusd_epi32(input_vector[0], row0[0]);
-          __m128i sum1 = m128_dpbusd_epi32(input_vector[0], row1[0]);
-          __m128i sum2 = m128_dpbusd_epi32(input_vector[0], row2[0]);
-          __m128i sum3 = m128_dpbusd_epi32(input_vector[0], row3[0]);
-
-          for (int j = 1; j < (int)kNumChunks; ++j)
+          int j = 0;
+          if (!canSaturate16x4[i / 4])
+          {
+              for (; j < (int)kNumChunks - 1; j += 2)
+              {
+                  const __m128i in0 = input_vector[j];
+                  const __m128i in1 = input_vector[j + 1];
+
+                  m128_add_dpbusd_epi32x2(sum0, in0, row0[j], in1, row0[j + 1]);
+                  m128_add_dpbusd_epi32x2(sum1, in0, row1[j], in1, row1[j + 1]);
+                  m128_add_dpbusd_epi32x2(sum2, in0, row2[j], in1, row2[j + 1]);
+                  m128_add_dpbusd_epi32x2(sum3, in0, row3[j], in1, row3[j + 1]);
+              }
+          }
+          for (; j < (int)kNumChunks; ++j)
           {
-            const __m128i in = input_vector[j];
+              const __m128i in = input_vector[j];
 
-            sum0 = _mm_add_epi32(sum0, m128_dpbusd_epi32(in, row0[j]));
-            sum1 = _mm_add_epi32(sum1, m128_dpbusd_epi32(in, row1[j]));
-            sum2 = _mm_add_epi32(sum2, m128_dpbusd_epi32(in, row2[j]));
-            sum3 = _mm_add_epi32(sum3, m128_dpbusd_epi32(in, row3[j]));
+              m128_add_dpbusd_epi32(sum0, in, row0[j]);
+              m128_add_dpbusd_epi32(sum1, in, row1[j]);
+              m128_add_dpbusd_epi32(sum2, in, row2[j]);
+              m128_add_dpbusd_epi32(sum3, in, row3[j]);
           }
 
           *outptr = m128_haddx4(sum0, sum1, sum2, sum3, bias);
@@ -629,12 +665,16 @@ namespace Eval::NNUE::Layers {
       }
       else if constexpr (kOutputDimensions == 1)
       {
+        __m128i sum0 = _mm_setzero_si128();
+
         const auto row0 = reinterpret_cast<const __m128i*>(&weights_[0]);
 
-        __m128i sum0 = m128_dpbusd_epi32(input_vector[0], row0[0]);
+        for (int j = 0; j < (int)kNumChunks; ++j)
+        {
+          const __m128i in = input_vector[j];
 
-        for (int j = 1; j < (int)kNumChunks; ++j)
-          sum0 = _mm_add_epi32(sum0, m128_dpbusd_epi32(input_vector[j], row0[j]));
+          m128_add_dpbusd_epi32(sum0, in, row0[j]);
+        }
 
         output[0] = m128_hadd(sum0, biases_[0]);
       }
@@ -751,8 +791,11 @@ namespace Eval::NNUE::Layers {
     PreviousLayer previous_layer_;
 
     alignas(kCacheLineSize) BiasType biases_[kOutputDimensions];
-    alignas(kCacheLineSize)
-        WeightType weights_[kOutputDimensions * kPaddedInputDimensions];
+    alignas(kCacheLineSize) WeightType weights_[kOutputDimensions * kPaddedInputDimensions];
+    union {
+        uint32_t canSaturate16x4[(kOutputDimensions + 3) / 4];
+        bool canSaturate16[kOutputDimensions];
+    };
   };
 
 }  // namespace Eval::NNUE::Layers