Affine transform robust implementation
[stockfish] / src / nnue / layers / affine_transform.h
index 34777ef66e70351fefc6f5310f10d3f6983d3926..adf152eea5b8894fcdf26cdfebffcdbd602b5d7f 100644 (file)
@@ -301,20 +301,40 @@ namespace Eval::NNUE::Layers {
       }
       else if constexpr (kOutputDimensions == 1)
       {
-          constexpr IndexType kNumChunks = kPaddedInputDimensions / kSimdWidth;
-
-          vec_t sum0 = vec_setzero();
-
-          const auto row0 = reinterpret_cast<const vec_t*>(&weights_[0]);
-
-          for (int j = 0; j < (int)kNumChunks; ++j)
+#if defined (USE_AVX512)
+          if constexpr (kPaddedInputDimensions % (kSimdWidth * 2) != 0)
           {
-              const vec_t in = input_vector[j];
-
-              vec_add_dpbusd_32(sum0, in, row0[j]);
+              constexpr IndexType kNumChunks = kPaddedInputDimensions / kSimdWidth;
+              const auto input_vector256 = reinterpret_cast<const __m256i*>(input);
+
+              __m256i sum0 = _mm256_setzero_si256();
+              const auto row0 = reinterpret_cast<const __m256i*>(&weights_[0]);
+
+              for (int j = 0; j < (int)kNumChunks; ++j)
+              {
+                  const __m256i in = input_vector256[j];
+                  m256_add_dpbusd_epi32(sum0, in, row0[j]);
+              }
+              output[0] = m256_hadd(sum0, biases_[0]);
+          }
+          else
+#endif
+          {
+#if defined (USE_AVX512)
+              constexpr IndexType kNumChunks = kPaddedInputDimensions / (kSimdWidth * 2);
+#else
+              constexpr IndexType kNumChunks = kPaddedInputDimensions / kSimdWidth;
+#endif
+              vec_t sum0 = vec_setzero();
+              const auto row0 = reinterpret_cast<const vec_t*>(&weights_[0]);
+
+              for (int j = 0; j < (int)kNumChunks; ++j)
+              {
+                  const vec_t in = input_vector[j];
+                  vec_add_dpbusd_32(sum0, in, row0[j]);
+              }
+              output[0] = vec_hadd(sum0, biases_[0]);
           }
-
-          output[0] = vec_hadd(sum0, biases_[0]);
       }
 
 #else