Calculate sum from first elements
[stockfish] / src / nnue / layers / affine_transform.h
index 47c9c488b0c06ab137e532157163644c3d37d6af..caf315b2792897df8b206c57aa718cb8331ec496 100644 (file)
@@ -181,13 +181,13 @@ namespace Eval::NNUE::Layers {
         return _mm512_add_epi32(_mm512_permutexvar_epi32(indices, x), bias);
       };
 
-      [[maybe_unused]] auto m512_add_dpbusd_epi32 = [=](__m512i& acc, __m512i a, __m512i b) {
 #if defined (USE_VNNI)
+      [[maybe_unused]] auto m512_add_dpbusd_epi32 = [=](__m512i& acc, __m512i a, __m512i b) {
         acc = _mm512_dpbusd_epi32(acc, a, b);
 #else
+      [[maybe_unused]] auto m512_dpbusd_epi32 = [=](__m512i a, __m512i b) -> __m512i {
         __m512i product0 = _mm512_maddubs_epi16(a, b);
-        product0 = _mm512_madd_epi16(product0, kOnes512);
-        acc = _mm512_add_epi32(acc, product0);
+        return _mm512_madd_epi16(product0, kOnes512);
 #endif
       };
 
@@ -214,14 +214,13 @@ namespace Eval::NNUE::Layers {
 
         return _mm_add_epi32(_mm_add_epi32(sum128lo, sum128hi), bias);
       };
-
-      [[maybe_unused]] auto m256_add_dpbusd_epi32 = [=](__m256i& acc, __m256i a, __m256i b) {
 #if defined (USE_VNNI)
+      [[maybe_unused]] auto m256_add_dpbusd_epi32 = [=](__m256i& acc, __m256i a, __m256i b) {
         acc = _mm256_dpbusd_epi32(acc, a, b);
 #else
+      [[maybe_unused]] auto m256_dpbusd_epi32 = [=](__m256i a, __m256i b) -> __m256i {
         __m256i product0 = _mm256_maddubs_epi16(a, b);
-        product0 = _mm256_madd_epi16(product0, kOnes256);
-        acc = _mm256_add_epi32(acc, product0);
+        return _mm256_madd_epi16(product0, kOnes256);
 #endif
       };
 
@@ -246,10 +245,9 @@ namespace Eval::NNUE::Layers {
         return _mm_add_epi32(sum0, bias);
       };
 
-      [[maybe_unused]] auto m128_add_dpbusd_epi32 = [=](__m128i& acc, __m128i a, __m128i b) {
+      [[maybe_unused]] auto m128_dpbusd_epi32 = [=](__m128i a, __m128i b) -> __m128i {
         __m128i product0 = _mm_maddubs_epi16(a, b);
-        product0 = _mm_madd_epi16(product0, kOnes128);
-        acc = _mm_add_epi32(acc, product0);
+        return _mm_madd_epi16(product0, kOnes128);
       };
 
 #endif
@@ -293,15 +291,6 @@ namespace Eval::NNUE::Layers {
           const __m512i bias = *reinterpret_cast<const __m512i*>(&biases_[i]);
           __m512i* outptr = reinterpret_cast<__m512i*>(&output[i]);
 
-          __m512i sum01a = _mm512_setzero_si512();
-          __m512i sum23a = _mm512_setzero_si512();
-          __m512i sum45a = _mm512_setzero_si512();
-          __m512i sum67a = _mm512_setzero_si512();
-          __m512i sum01b = _mm512_setzero_si512();
-          __m512i sum23b = _mm512_setzero_si512();
-          __m512i sum45b = _mm512_setzero_si512();
-          __m512i sum67b = _mm512_setzero_si512();
-
           const auto row01a = *reinterpret_cast<const __m512i*>(&weights_[offset01a]);
           const auto row23a = *reinterpret_cast<const __m512i*>(&weights_[offset23a]);
           const auto row45a = *reinterpret_cast<const __m512i*>(&weights_[offset45a]);
@@ -314,6 +303,16 @@ namespace Eval::NNUE::Layers {
           const __m256i in256 = input_vector256[0];
           const __m512i in = _mm512_inserti64x4(_mm512_castsi256_si512(in256), in256, 1);
 
+#if defined (USE_VNNI)
+          __m512i sum01a = _mm512_setzero_si512();
+          __m512i sum23a = _mm512_setzero_si512();
+          __m512i sum45a = _mm512_setzero_si512();
+          __m512i sum67a = _mm512_setzero_si512();
+          __m512i sum01b = _mm512_setzero_si512();
+          __m512i sum23b = _mm512_setzero_si512();
+          __m512i sum45b = _mm512_setzero_si512();
+          __m512i sum67b = _mm512_setzero_si512();
+
           m512_add_dpbusd_epi32(sum01a, in, row01a);
           m512_add_dpbusd_epi32(sum23a, in, row23a);
           m512_add_dpbusd_epi32(sum45a, in, row45a);
@@ -322,6 +321,16 @@ namespace Eval::NNUE::Layers {
           m512_add_dpbusd_epi32(sum23b, in, row23b);
           m512_add_dpbusd_epi32(sum45b, in, row45b);
           m512_add_dpbusd_epi32(sum67b, in, row67b);
+#else
+          __m512i sum01a = m512_dpbusd_epi32(in, row01a);
+          __m512i sum23a = m512_dpbusd_epi32(in, row23a);
+          __m512i sum45a = m512_dpbusd_epi32(in, row45a);
+          __m512i sum67a = m512_dpbusd_epi32(in, row67a);
+          __m512i sum01b = m512_dpbusd_epi32(in, row01b);
+          __m512i sum23b = m512_dpbusd_epi32(in, row23b);
+          __m512i sum45b = m512_dpbusd_epi32(in, row45b);
+          __m512i sum67b = m512_dpbusd_epi32(in, row67b);
+#endif
 
           *outptr = m512_hadd256x16(
             sum01a, sum23a, sum45a, sum67a,
@@ -342,48 +351,80 @@ namespace Eval::NNUE::Layers {
 
           if constexpr (kPaddedInputDimensions % (kSimdWidth * 2) == 0)
           {
-            __m512i sum0 = _mm512_setzero_si512();
-            __m512i sum1 = _mm512_setzero_si512();
-            __m512i sum2 = _mm512_setzero_si512();
-            __m512i sum3 = _mm512_setzero_si512();
-
             const auto row0 = reinterpret_cast<const __m512i*>(&weights_[offset0]);
             const auto row1 = reinterpret_cast<const __m512i*>(&weights_[offset1]);
             const auto row2 = reinterpret_cast<const __m512i*>(&weights_[offset2]);
             const auto row3 = reinterpret_cast<const __m512i*>(&weights_[offset3]);
 
-            for (IndexType j = 0; j < kNumChunks512; ++j)
+#if defined (USE_VNNI)
+            __m512i sum0 = _mm512_setzero_si512();
+            __m512i sum1 = _mm512_setzero_si512();
+            __m512i sum2 = _mm512_setzero_si512();
+            __m512i sum3 = _mm512_setzero_si512();
+            const IndexType kStart = 0;
+#else
+            __m512i sum0 = m512_dpbusd_epi32(input_vector512[0], row0[0]);
+            __m512i sum1 = m512_dpbusd_epi32(input_vector512[0], row1[0]);
+            __m512i sum2 = m512_dpbusd_epi32(input_vector512[0], row2[0]);
+            __m512i sum3 = m512_dpbusd_epi32(input_vector512[0], row3[0]);
+            const IndexType kStart = 1;
+#endif
+
+            for (IndexType j = kStart; j < kNumChunks512; ++j)
             {
               const __m512i in = input_vector512[j];
 
+#if defined (USE_VNNI)
               m512_add_dpbusd_epi32(sum0, in, row0[j]);
               m512_add_dpbusd_epi32(sum1, in, row1[j]);
               m512_add_dpbusd_epi32(sum2, in, row2[j]);
               m512_add_dpbusd_epi32(sum3, in, row3[j]);
+#else
+              sum0 = _mm512_add_epi32(sum0, m512_dpbusd_epi32(in, row0[j]));
+              sum1 = _mm512_add_epi32(sum1, m512_dpbusd_epi32(in, row1[j]));
+              sum2 = _mm512_add_epi32(sum2, m512_dpbusd_epi32(in, row2[j]));
+              sum3 = _mm512_add_epi32(sum3, m512_dpbusd_epi32(in, row3[j]));
+#endif
             }
 
             *outptr = m512_haddx4(sum0, sum1, sum2, sum3, bias);
           }
           else
           {
-            __m256i sum0 = _mm256_setzero_si256();
-            __m256i sum1 = _mm256_setzero_si256();
-            __m256i sum2 = _mm256_setzero_si256();
-            __m256i sum3 = _mm256_setzero_si256();
-
             const auto row0 = reinterpret_cast<const __m256i*>(&weights_[offset0]);
             const auto row1 = reinterpret_cast<const __m256i*>(&weights_[offset1]);
             const auto row2 = reinterpret_cast<const __m256i*>(&weights_[offset2]);
             const auto row3 = reinterpret_cast<const __m256i*>(&weights_[offset3]);
 
-            for (IndexType j = 0; j < kNumChunks256; ++j)
+#if defined (USE_VNNI)
+            __m256i sum0 = _mm256_setzero_si256();
+            __m256i sum1 = _mm256_setzero_si256();
+            __m256i sum2 = _mm256_setzero_si256();
+            __m256i sum3 = _mm256_setzero_si256();
+            const IndexType kStart = 0;
+#else
+            __m256i sum0 = m256_dpbusd_epi32(input_vector256[0], row0[0]);
+            __m256i sum1 = m256_dpbusd_epi32(input_vector256[0], row1[0]);
+            __m256i sum2 = m256_dpbusd_epi32(input_vector256[0], row2[0]);
+            __m256i sum3 = m256_dpbusd_epi32(input_vector256[0], row3[0]);
+            const IndexType kStart = 1;
+#endif
+
+            for (IndexType j = kStart; j < kNumChunks256; ++j)
             {
               const __m256i in = input_vector256[j];
 
+#if defined (USE_VNNI)
               m256_add_dpbusd_epi32(sum0, in, row0[j]);
               m256_add_dpbusd_epi32(sum1, in, row1[j]);
               m256_add_dpbusd_epi32(sum2, in, row2[j]);
               m256_add_dpbusd_epi32(sum3, in, row3[j]);
+#else
+              sum0 = _mm256_add_epi32(sum0, m256_dpbusd_epi32(in, row0[j]));
+              sum1 = _mm256_add_epi32(sum1, m256_dpbusd_epi32(in, row1[j]));
+              sum2 = _mm256_add_epi32(sum2, m256_dpbusd_epi32(in, row2[j]));
+              sum3 = _mm256_add_epi32(sum3, m256_dpbusd_epi32(in, row3[j]));
+#endif
             }
 
             *outptr = m256_haddx4(sum0, sum1, sum2, sum3, bias);
@@ -394,30 +435,50 @@ namespace Eval::NNUE::Layers {
       {
         if constexpr (kPaddedInputDimensions % (kSimdWidth * 2) == 0)
         {
-          __m512i sum0 = _mm512_setzero_si512();
-
           const auto row0 = reinterpret_cast<const __m512i*>(&weights_[0]);
 
-          for (IndexType j = 0; j < kNumChunks512; ++j)
+#if defined (USE_VNNI)
+          __m512i sum0 = _mm512_setzero_si512();
+          const IndexType kStart = 0;
+#else
+          __m512i sum0 = m512_dpbusd_epi32(input_vector512[0], row0[0]);
+          const IndexType kStart = 1;
+#endif
+
+          for (IndexType j = kStart; j < kNumChunks512; ++j)
           {
             const __m512i in = input_vector512[j];
 
+#if defined (USE_VNNI)
             m512_add_dpbusd_epi32(sum0, in, row0[j]);
+#else
+            sum0 = _mm512_add_epi32(sum0, m512_dpbusd_epi32(in, row0[j]));
+#endif
           }
 
           output[0] = m512_hadd(sum0, biases_[0]);
         }
         else
         {
-          __m256i sum0 = _mm256_setzero_si256();
-
           const auto row0 = reinterpret_cast<const __m256i*>(&weights_[0]);
 
-          for (IndexType j = 0; j < kNumChunks256; ++j)
+#if defined (USE_VNNI)
+          __m256i sum0 = _mm256_setzero_si256();
+          const IndexType kStart = 0;
+#else
+          __m256i sum0 = m256_dpbusd_epi32(input_vector256[0], row0[0]);
+          const IndexType kStart = 1;
+#endif
+
+          for (IndexType j = kStart; j < kNumChunks256; ++j)
           {
             const __m256i in = input_vector256[j];
 
+#if defined (USE_VNNI)
             m256_add_dpbusd_epi32(sum0, in, row0[j]);
+#else
+            sum0 = _mm256_add_epi32(sum0, m256_dpbusd_epi32(in, row0[j]));
+#endif
           }
 
           output[0] = m256_hadd(sum0, biases_[0]);
@@ -451,24 +512,40 @@ namespace Eval::NNUE::Layers {
           const __m128i bias = *reinterpret_cast<const __m128i*>(&biases_[i]);
           __m128i* outptr = reinterpret_cast<__m128i*>(&output[i]);
 
-          __m256i sum0 = _mm256_setzero_si256();
-          __m256i sum1 = _mm256_setzero_si256();
-          __m256i sum2 = _mm256_setzero_si256();
-          __m256i sum3 = _mm256_setzero_si256();
-
           const auto row0 = reinterpret_cast<const __m256i*>(&weights_[offset0]);
           const auto row1 = reinterpret_cast<const __m256i*>(&weights_[offset1]);
           const auto row2 = reinterpret_cast<const __m256i*>(&weights_[offset2]);
           const auto row3 = reinterpret_cast<const __m256i*>(&weights_[offset3]);
 
-          for (IndexType j = 0; j < kNumChunks; ++j)
+#if defined (USE_VNNI)
+          __m256i sum0 = _mm256_setzero_si256();
+          __m256i sum1 = _mm256_setzero_si256();
+          __m256i sum2 = _mm256_setzero_si256();
+          __m256i sum3 = _mm256_setzero_si256();
+          const IndexType kStart = 0;
+#else
+          __m256i sum0 = m256_dpbusd_epi32(input_vector[0], row0[0]);
+          __m256i sum1 = m256_dpbusd_epi32(input_vector[0], row1[0]);
+          __m256i sum2 = m256_dpbusd_epi32(input_vector[0], row2[0]);
+          __m256i sum3 = m256_dpbusd_epi32(input_vector[0], row3[0]);
+          const IndexType kStart = 1;
+#endif
+
+          for (IndexType j = kStart; j < kNumChunks; ++j)
           {
             const __m256i in = input_vector[j];
 
+#if defined (USE_VNNI)
             m256_add_dpbusd_epi32(sum0, in, row0[j]);
             m256_add_dpbusd_epi32(sum1, in, row1[j]);
             m256_add_dpbusd_epi32(sum2, in, row2[j]);
             m256_add_dpbusd_epi32(sum3, in, row3[j]);
+#else
+            sum0 = _mm256_add_epi32(sum0, m256_dpbusd_epi32(in, row0[j]));
+            sum1 = _mm256_add_epi32(sum1, m256_dpbusd_epi32(in, row1[j]));
+            sum2 = _mm256_add_epi32(sum2, m256_dpbusd_epi32(in, row2[j]));
+            sum3 = _mm256_add_epi32(sum3, m256_dpbusd_epi32(in, row3[j]));
+#endif
           }
 
           *outptr = m256_haddx4(sum0, sum1, sum2, sum3, bias);
@@ -476,15 +553,25 @@ namespace Eval::NNUE::Layers {
       }
       else if constexpr (kOutputDimensions == 1)
       {
-        __m256i sum0 = _mm256_setzero_si256();
-
         const auto row0 = reinterpret_cast<const __m256i*>(&weights_[0]);
 
-        for (IndexType j = 0; j < kNumChunks; ++j)
+#if defined (USE_VNNI)
+        __m256i sum0 = _mm256_setzero_si256();
+        const IndexType kStart = 0;
+#else
+        __m256i sum0 = m256_dpbusd_epi32(input_vector[0], row0[0]);
+        const IndexType kStart = 1;
+#endif
+
+        for (IndexType j = kStart; j < kNumChunks; ++j)
         {
           const __m256i in = input_vector[j];
 
-            m256_add_dpbusd_epi32(sum0, in, row0[j]);
+#if defined (USE_VNNI)
+          m256_add_dpbusd_epi32(sum0, in, row0[j]);
+#else
+          sum0 = _mm256_add_epi32(sum0, m256_dpbusd_epi32(in, row0[j]));
+#endif
         }
 
         output[0] = m256_hadd(sum0, biases_[0]);
@@ -517,24 +604,24 @@ namespace Eval::NNUE::Layers {
           const __m128i bias = *reinterpret_cast<const __m128i*>(&biases_[i]);
           __m128i* outptr = reinterpret_cast<__m128i*>(&output[i]);
 
-          __m128i sum0 = _mm_setzero_si128();
-          __m128i sum1 = _mm_setzero_si128();
-          __m128i sum2 = _mm_setzero_si128();
-          __m128i sum3 = _mm_setzero_si128();
-
           const auto row0 = reinterpret_cast<const __m128i*>(&weights_[offset0]);
           const auto row1 = reinterpret_cast<const __m128i*>(&weights_[offset1]);
           const auto row2 = reinterpret_cast<const __m128i*>(&weights_[offset2]);
           const auto row3 = reinterpret_cast<const __m128i*>(&weights_[offset3]);
 
-          for (int j = 0; j < (int)kNumChunks; j += 1)
+          __m128i sum0 = m128_dpbusd_epi32(input_vector[0], row0[0]);
+          __m128i sum1 = m128_dpbusd_epi32(input_vector[0], row1[0]);
+          __m128i sum2 = m128_dpbusd_epi32(input_vector[0], row2[0]);
+          __m128i sum3 = m128_dpbusd_epi32(input_vector[0], row3[0]);
+
+          for (int j = 1; j < (int)kNumChunks; ++j)
           {
             const __m128i in = input_vector[j];
 
-            m128_add_dpbusd_epi32(sum0, in, row0[j]);
-            m128_add_dpbusd_epi32(sum1, in, row1[j]);
-            m128_add_dpbusd_epi32(sum2, in, row2[j]);
-            m128_add_dpbusd_epi32(sum3, in, row3[j]);
+            sum0 = _mm_add_epi32(sum0, m128_dpbusd_epi32(in, row0[j]));
+            sum1 = _mm_add_epi32(sum1, m128_dpbusd_epi32(in, row1[j]));
+            sum2 = _mm_add_epi32(sum2, m128_dpbusd_epi32(in, row2[j]));
+            sum3 = _mm_add_epi32(sum3, m128_dpbusd_epi32(in, row3[j]));
           }
 
           *outptr = m128_haddx4(sum0, sum1, sum2, sum3, bias);
@@ -542,16 +629,12 @@ namespace Eval::NNUE::Layers {
       }
       else if constexpr (kOutputDimensions == 1)
       {
-        __m128i sum0 = _mm_setzero_si128();
-
         const auto row0 = reinterpret_cast<const __m128i*>(&weights_[0]);
 
-        for (int j = 0; j < (int)kNumChunks; j += 1)
-        {
-          const __m128i in = input_vector[j];
+        __m128i sum0 = m128_dpbusd_epi32(input_vector[0], row0[0]);
 
-          m128_add_dpbusd_epi32(sum0, in, row0[j]);
-        }
+        for (int j = 1; j < (int)kNumChunks; ++j)
+          sum0 = _mm_add_epi32(sum0, m128_dpbusd_epi32(input_vector[j], row0[j]));
 
         output[0] = m128_hadd(sum0, biases_[0]);
       }