]> git.sesse.net Git - stockfish/blobdiff - src/nnue/layers/affine_transform.h
Merge remote-tracking branch 'upstream/master'
[stockfish] / src / nnue / layers / affine_transform.h
index 3fba45ed87dc57e5ace02a366bad1a63e1f1c5d7..59a6149f0c40dd6d379fba733354776ee3f7417b 100644 (file)
@@ -1,6 +1,6 @@
 /*
   Stockfish, a UCI chess playing engine derived from Glaurung 2.1
 /*
   Stockfish, a UCI chess playing engine derived from Glaurung 2.1
-  Copyright (C) 2004-2023 The Stockfish developers (see AUTHORS file)
+  Copyright (C) 2004-2024 The Stockfish developers (see AUTHORS file)
 
   Stockfish is free software: you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
 
   Stockfish is free software: you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
 
 namespace Stockfish::Eval::NNUE::Layers {
 
 
 namespace Stockfish::Eval::NNUE::Layers {
 
+#if defined(USE_SSSE3) || defined(USE_NEON_DOTPROD)
+    #define ENABLE_SEQ_OPT
+#endif
+
 // Fallback implementation for older/other architectures.
 // Requires the input to be padded to at least 16 values.
 // Fallback implementation for older/other architectures.
 // Requires the input to be padded to at least 16 values.
-#if !defined(USE_SSSE3)
+#ifndef ENABLE_SEQ_OPT
+
 template<IndexType InputDimensions, IndexType PaddedInputDimensions, IndexType OutputDimensions>
 static void affine_transform_non_ssse3(std::int32_t*       output,
                                        const std::int8_t*  weights,
                                        const std::int32_t* biases,
                                        const std::uint8_t* input) {
 template<IndexType InputDimensions, IndexType PaddedInputDimensions, IndexType OutputDimensions>
 static void affine_transform_non_ssse3(std::int32_t*       output,
                                        const std::int8_t*  weights,
                                        const std::int32_t* biases,
                                        const std::uint8_t* input) {
-    #if defined(USE_SSE2) || defined(USE_NEON_DOTPROD) || defined(USE_NEON)
+    #if defined(USE_SSE2) || defined(USE_NEON)
         #if defined(USE_SSE2)
     // At least a multiple of 16, with SSE2.
     constexpr IndexType NumChunks   = ceil_to_multiple<IndexType>(InputDimensions, 16) / 16;
     const __m128i       Zeros       = _mm_setzero_si128();
     const auto          inputVector = reinterpret_cast<const __m128i*>(input);
 
         #if defined(USE_SSE2)
     // At least a multiple of 16, with SSE2.
     constexpr IndexType NumChunks   = ceil_to_multiple<IndexType>(InputDimensions, 16) / 16;
     const __m128i       Zeros       = _mm_setzero_si128();
     const auto          inputVector = reinterpret_cast<const __m128i*>(input);
 
-        #elif defined(USE_NEON_DOTPROD)
-    constexpr IndexType NumChunks   = ceil_to_multiple<IndexType>(InputDimensions, 16) / 16;
-    const auto          inputVector = reinterpret_cast<const int8x16_t*>(input);
-
         #elif defined(USE_NEON)
     constexpr IndexType NumChunks   = ceil_to_multiple<IndexType>(InputDimensions, 16) / 16;
     const auto          inputVector = reinterpret_cast<const int8x8_t*>(input);
         #elif defined(USE_NEON)
     constexpr IndexType NumChunks   = ceil_to_multiple<IndexType>(InputDimensions, 16) / 16;
     const auto          inputVector = reinterpret_cast<const int8x8_t*>(input);
@@ -91,16 +92,8 @@ static void affine_transform_non_ssse3(std::int32_t*       output,
         sum                   = _mm_add_epi32(sum, sum_second_32);
         output[i]             = _mm_cvtsi128_si32(sum);
 
         sum                   = _mm_add_epi32(sum, sum_second_32);
         output[i]             = _mm_cvtsi128_si32(sum);
 
-        #elif defined(USE_NEON_DOTPROD)
-        int32x4_t  sum = {biases[i]};
-        const auto row = reinterpret_cast<const int8x16_t*>(&weights[offset]);
-        for (IndexType j = 0; j < NumChunks; ++j)
-        {
-            sum = vdotq_s32(sum, inputVector[j], row[j]);
-        }
-        output[i] = vaddvq_s32(sum);
-
         #elif defined(USE_NEON)
         #elif defined(USE_NEON)
+
         int32x4_t  sum = {biases[i]};
         const auto row = reinterpret_cast<const int8x8_t*>(&weights[offset]);
         for (IndexType j = 0; j < NumChunks; ++j)
         int32x4_t  sum = {biases[i]};
         const auto row = reinterpret_cast<const int8x8_t*>(&weights[offset]);
         for (IndexType j = 0; j < NumChunks; ++j)
@@ -127,7 +120,8 @@ static void affine_transform_non_ssse3(std::int32_t*       output,
         }
     #endif
 }
         }
     #endif
 }
-#endif
+
+#endif  // !ENABLE_SEQ_OPT
 
 template<IndexType InDims, IndexType OutDims>
 class AffineTransform {
 
 template<IndexType InDims, IndexType OutDims>
 class AffineTransform {
@@ -162,7 +156,7 @@ class AffineTransform {
     }
 
     static constexpr IndexType get_weight_index(IndexType i) {
     }
 
     static constexpr IndexType get_weight_index(IndexType i) {
-#if defined(USE_SSSE3)
+#ifdef ENABLE_SEQ_OPT
         return get_weight_index_scrambled(i);
 #else
         return i;
         return get_weight_index_scrambled(i);
 #else
         return i;
@@ -190,32 +184,28 @@ class AffineTransform {
     // Forward propagation
     void propagate(const InputType* input, OutputType* output) const {
 
     // Forward propagation
     void propagate(const InputType* input, OutputType* output) const {
 
-#if defined(USE_SSSE3)
+#ifdef ENABLE_SEQ_OPT
 
         if constexpr (OutputDimensions > 1)
         {
 
         if constexpr (OutputDimensions > 1)
         {
-
     #if defined(USE_AVX512)
             using vec_t = __m512i;
     #if defined(USE_AVX512)
             using vec_t = __m512i;
-        #define vec_setzero _mm512_setzero_si512
         #define vec_set_32 _mm512_set1_epi32
         #define vec_add_dpbusd_32 Simd::m512_add_dpbusd_epi32
         #define vec_set_32 _mm512_set1_epi32
         #define vec_add_dpbusd_32 Simd::m512_add_dpbusd_epi32
-        #define vec_add_dpbusd_32x2 Simd::m512_add_dpbusd_epi32x2
-        #define vec_hadd Simd::m512_hadd
     #elif defined(USE_AVX2)
             using vec_t = __m256i;
     #elif defined(USE_AVX2)
             using vec_t = __m256i;
-        #define vec_setzero _mm256_setzero_si256
         #define vec_set_32 _mm256_set1_epi32
         #define vec_add_dpbusd_32 Simd::m256_add_dpbusd_epi32
         #define vec_set_32 _mm256_set1_epi32
         #define vec_add_dpbusd_32 Simd::m256_add_dpbusd_epi32
-        #define vec_add_dpbusd_32x2 Simd::m256_add_dpbusd_epi32x2
-        #define vec_hadd Simd::m256_hadd
     #elif defined(USE_SSSE3)
             using vec_t = __m128i;
     #elif defined(USE_SSSE3)
             using vec_t = __m128i;
-        #define vec_setzero _mm_setzero_si128
         #define vec_set_32 _mm_set1_epi32
         #define vec_add_dpbusd_32 Simd::m128_add_dpbusd_epi32
         #define vec_set_32 _mm_set1_epi32
         #define vec_add_dpbusd_32 Simd::m128_add_dpbusd_epi32
-        #define vec_add_dpbusd_32x2 Simd::m128_add_dpbusd_epi32x2
-        #define vec_hadd Simd::m128_hadd
+    #elif defined(USE_NEON_DOTPROD)
+            using vec_t = int32x4_t;
+        #define vec_set_32 vdupq_n_s32
+        #define vec_add_dpbusd_32(acc, a, b) \
+            Simd::dotprod_m128_add_dpbusd_epi32(acc, vreinterpretq_s8_s32(a), \
+                                                vreinterpretq_s8_s32(b))
     #endif
 
             static constexpr IndexType OutputSimdWidth = sizeof(vec_t) / sizeof(OutputType);
     #endif
 
             static constexpr IndexType OutputSimdWidth = sizeof(vec_t) / sizeof(OutputType);
@@ -231,46 +221,47 @@ class AffineTransform {
             for (IndexType k = 0; k < NumRegs; ++k)
                 acc[k] = biasvec[k];
 
             for (IndexType k = 0; k < NumRegs; ++k)
                 acc[k] = biasvec[k];
 
-            for (IndexType i = 0; i < NumChunks; i += 2)
+            for (IndexType i = 0; i < NumChunks; ++i)
             {
             {
-                const vec_t in0 = vec_set_32(input32[i + 0]);
-                const vec_t in1 = vec_set_32(input32[i + 1]);
+                const vec_t in0 = vec_set_32(input32[i]);
                 const auto  col0 =
                 const auto  col0 =
-                  reinterpret_cast<const vec_t*>(&weights[(i + 0) * OutputDimensions * 4]);
-                const auto col1 =
-                  reinterpret_cast<const vec_t*>(&weights[(i + 1) * OutputDimensions * 4]);
+                  reinterpret_cast<const vec_t*>(&weights[i * OutputDimensions * 4]);
+
                 for (IndexType k = 0; k < NumRegs; ++k)
                 for (IndexType k = 0; k < NumRegs; ++k)
-                    vec_add_dpbusd_32x2(acc[k], in0, col0[k], in1, col1[k]);
+                    vec_add_dpbusd_32(acc[k], in0, col0[k]);
             }
 
             vec_t* outptr = reinterpret_cast<vec_t*>(output);
             for (IndexType k = 0; k < NumRegs; ++k)
                 outptr[k] = acc[k];
 
             }
 
             vec_t* outptr = reinterpret_cast<vec_t*>(output);
             for (IndexType k = 0; k < NumRegs; ++k)
                 outptr[k] = acc[k];
 
-    #undef vec_setzero
     #undef vec_set_32
     #undef vec_add_dpbusd_32
     #undef vec_set_32
     #undef vec_add_dpbusd_32
-    #undef vec_add_dpbusd_32x2
-    #undef vec_hadd
         }
         else if constexpr (OutputDimensions == 1)
         {
         }
         else if constexpr (OutputDimensions == 1)
         {
-
-    // We cannot use AVX512 for the last layer because there's only 32 inputs and the buffer is not padded to 64 elements.
+    // We cannot use AVX512 for the last layer because there are only 32 inputs
+    // and the buffer is not padded to 64 elements.
     #if defined(USE_AVX2)
             using vec_t = __m256i;
     #if defined(USE_AVX2)
             using vec_t = __m256i;
-        #define vec_setzero _mm256_setzero_si256
+        #define vec_setzero() _mm256_setzero_si256()
         #define vec_set_32 _mm256_set1_epi32
         #define vec_add_dpbusd_32 Simd::m256_add_dpbusd_epi32
         #define vec_set_32 _mm256_set1_epi32
         #define vec_add_dpbusd_32 Simd::m256_add_dpbusd_epi32
-        #define vec_add_dpbusd_32x2 Simd::m256_add_dpbusd_epi32x2
         #define vec_hadd Simd::m256_hadd
     #elif defined(USE_SSSE3)
             using vec_t = __m128i;
         #define vec_hadd Simd::m256_hadd
     #elif defined(USE_SSSE3)
             using vec_t = __m128i;
-        #define vec_setzero _mm_setzero_si128
+        #define vec_setzero() _mm_setzero_si128()
         #define vec_set_32 _mm_set1_epi32
         #define vec_add_dpbusd_32 Simd::m128_add_dpbusd_epi32
         #define vec_set_32 _mm_set1_epi32
         #define vec_add_dpbusd_32 Simd::m128_add_dpbusd_epi32
-        #define vec_add_dpbusd_32x2 Simd::m128_add_dpbusd_epi32x2
         #define vec_hadd Simd::m128_hadd
         #define vec_hadd Simd::m128_hadd
+    #elif defined(USE_NEON_DOTPROD)
+            using vec_t = int32x4_t;
+        #define vec_setzero() vdupq_n_s32(0)
+        #define vec_set_32 vdupq_n_s32
+        #define vec_add_dpbusd_32(acc, a, b) \
+            Simd::dotprod_m128_add_dpbusd_epi32(acc, vreinterpretq_s8_s32(a), \
+                                                vreinterpretq_s8_s32(b))
+        #define vec_hadd Simd::neon_m128_hadd
     #endif
 
             const auto inputVector = reinterpret_cast<const vec_t*>(input);
     #endif
 
             const auto inputVector = reinterpret_cast<const vec_t*>(input);
@@ -293,7 +284,6 @@ class AffineTransform {
     #undef vec_setzero
     #undef vec_set_32
     #undef vec_add_dpbusd_32
     #undef vec_setzero
     #undef vec_set_32
     #undef vec_add_dpbusd_32
-    #undef vec_add_dpbusd_32x2
     #undef vec_hadd
         }
 #else
     #undef vec_hadd
         }
 #else