]> git.sesse.net Git - stockfish/blobdiff - src/nnue/layers/affine_transform.h
New NNUE architecture and net
[stockfish] / src / nnue / layers / affine_transform.h
index 9a5f62c0a098691322f911fee4d212d12a199685..d131836865f73800920fcb29ed635bd199c29b4f 100644 (file)
@@ -46,6 +46,11 @@ namespace Stockfish::Eval::NNUE::Layers {
 #elif defined (USE_SSSE3)
     static constexpr const IndexType OutputSimdWidth = SimdWidth / 4;
 #endif
+#if defined (USE_AVX512)
+    static constexpr const IndexType InputSimdWidth = SimdWidth * 2;
+#elif defined (USE_SSSE3)
+    static constexpr const IndexType InputSimdWidth = SimdWidth;
+#endif
 
     // Size of forward propagation buffer used in this layer
     static constexpr std::size_t SelfBufferSize =
@@ -72,6 +77,15 @@ namespace Stockfish::Eval::NNUE::Layers {
       for (std::size_t i = 0; i < OutputDimensions * PaddedInputDimensions; ++i)
 #if !defined (USE_SSSE3)
         weights[i] = read_little_endian<WeightType>(stream);
+#elif defined (USE_VNNI) || defined (USE_AVX512)
+        if constexpr (OutputDimensions <= 8 && OutputDimensions != 1)
+            weights[i] = read_little_endian<WeightType>(stream);
+        else
+            weights[
+              (i / 4) % (PaddedInputDimensions / 4) * OutputDimensions * 4 +
+              i / PaddedInputDimensions * 4 +
+              i % 4
+            ] = read_little_endian<WeightType>(stream);
 #else
         weights[
           (i / 4) % (PaddedInputDimensions / 4) * OutputDimensions * 4 +
@@ -108,7 +122,6 @@ namespace Stockfish::Eval::NNUE::Layers {
 
       return !stream.fail();
     }
-
     // Forward propagation
     const OutputType* propagate(
         const TransformedFeatureType* transformedFeatures, char* buffer) const {
@@ -123,6 +136,40 @@ namespace Stockfish::Eval::NNUE::Layers {
         return _mm512_reduce_add_epi32(sum) + bias;
       };
 
+      [[maybe_unused]] auto m512_hadd128x16_interleave = [](
+        __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3) -> __m512i {
+
+        __m512i sum01a = _mm512_unpacklo_epi32(sum0, sum1);
+        __m512i sum01b = _mm512_unpackhi_epi32(sum0, sum1);
+
+        __m512i sum23a = _mm512_unpacklo_epi32(sum2, sum3);
+        __m512i sum23b = _mm512_unpackhi_epi32(sum2, sum3);
+
+        __m512i sum01 = _mm512_add_epi32(sum01a, sum01b);
+        __m512i sum23 = _mm512_add_epi32(sum23a, sum23b);
+
+        __m512i sum0123a = _mm512_unpacklo_epi64(sum01, sum23);
+        __m512i sum0123b = _mm512_unpackhi_epi64(sum01, sum23);
+
+        return _mm512_add_epi32(sum0123a, sum0123b);
+      };
+
+      [[maybe_unused]] auto m512_haddx4 = [m512_hadd128x16_interleave](
+        __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3, __m128i bias) -> __m128i {
+
+        __m512i sum = m512_hadd128x16_interleave(sum0, sum1, sum2, sum3);
+
+        __m256i sum256lo = _mm512_castsi512_si256(sum);
+        __m256i sum256hi = _mm512_extracti64x4_epi64(sum, 1);
+
+        sum256lo = _mm256_add_epi32(sum256lo, sum256hi);
+
+        __m128i sum128lo = _mm256_castsi256_si128(sum256lo);
+        __m128i sum128hi = _mm256_extracti128_si256(sum256lo, 1);
+
+        return _mm_add_epi32(_mm_add_epi32(sum128lo, sum128hi), bias);
+      };
+
       [[maybe_unused]] auto m512_add_dpbusd_epi32 = [=](__m512i& acc, __m512i a, __m512i b) {
 #if defined (USE_VNNI)
         acc = _mm512_dpbusd_epi32(acc, a, b);
@@ -133,6 +180,19 @@ namespace Stockfish::Eval::NNUE::Layers {
 #endif
       };
 
+      [[maybe_unused]] auto m512_add_dpbusd_epi32x2 = [=](__m512i& acc, __m512i a0, __m512i b0, __m512i a1, __m512i b1) {
+#if defined (USE_VNNI)
+        acc = _mm512_dpbusd_epi32(acc, a0, b0);
+        acc = _mm512_dpbusd_epi32(acc, a1, b1);
+#else
+        __m512i product0 = _mm512_maddubs_epi16(a0, b0);
+        __m512i product1 = _mm512_maddubs_epi16(a1, b1);
+        product0 = _mm512_adds_epi16(product0, product1);
+        product0 = _mm512_madd_epi16(product0, Ones512);
+        acc = _mm512_add_epi32(acc, product0);
+#endif
+      };
+
       [[maybe_unused]] auto m512_add_dpbusd_epi32x4 = [=](__m512i& acc, __m512i a0, __m512i b0, __m512i a1, __m512i b1,
                                                                         __m512i a2, __m512i b2, __m512i a3, __m512i b3) {
 #if defined (USE_VNNI)
@@ -165,6 +225,18 @@ namespace Stockfish::Eval::NNUE::Layers {
         return _mm_cvtsi128_si32(sum128) + bias;
       };
 
+      [[maybe_unused]] auto m256_haddx4 = [](__m256i sum0, __m256i sum1, __m256i sum2, __m256i sum3, __m128i bias) -> __m128i {
+        sum0 = _mm256_hadd_epi32(sum0, sum1);
+        sum2 = _mm256_hadd_epi32(sum2, sum3);
+
+        sum0 = _mm256_hadd_epi32(sum0, sum2);
+
+        __m128i sum128lo = _mm256_castsi256_si128(sum0);
+        __m128i sum128hi = _mm256_extracti128_si256(sum0, 1);
+
+        return _mm_add_epi32(_mm_add_epi32(sum128lo, sum128hi), bias);
+      };
+
       [[maybe_unused]] auto m256_add_dpbusd_epi32 = [=](__m256i& acc, __m256i a, __m256i b) {
 #if defined (USE_VNNI)
         acc = _mm256_dpbusd_epi32(acc, a, b);
@@ -175,6 +247,19 @@ namespace Stockfish::Eval::NNUE::Layers {
 #endif
       };
 
+      [[maybe_unused]] auto m256_add_dpbusd_epi32x2 = [=](__m256i& acc, __m256i a0, __m256i b0, __m256i a1, __m256i b1) {
+#if defined (USE_VNNI)
+        acc = _mm256_dpbusd_epi32(acc, a0, b0);
+        acc = _mm256_dpbusd_epi32(acc, a1, b1);
+#else
+        __m256i product0 = _mm256_maddubs_epi16(a0, b0);
+        __m256i product1 = _mm256_maddubs_epi16(a1, b1);
+        product0 = _mm256_adds_epi16(product0, product1);
+        product0 = _mm256_madd_epi16(product0, Ones256);
+        acc = _mm256_add_epi32(acc, product0);
+#endif
+      };
+
       [[maybe_unused]] auto m256_add_dpbusd_epi32x4 = [=](__m256i& acc, __m256i a0, __m256i b0, __m256i a1, __m256i b1,
                                                                         __m256i a2, __m256i b2, __m256i a3, __m256i b3) {
 #if defined (USE_VNNI)
@@ -206,12 +291,27 @@ namespace Stockfish::Eval::NNUE::Layers {
         return _mm_cvtsi128_si32(sum) + bias;
       };
 
+      [[maybe_unused]] auto m128_haddx4 = [](__m128i sum0, __m128i sum1, __m128i sum2, __m128i sum3, __m128i bias) -> __m128i {
+        sum0 = _mm_hadd_epi32(sum0, sum1);
+        sum2 = _mm_hadd_epi32(sum2, sum3);
+        sum0 = _mm_hadd_epi32(sum0, sum2);
+        return _mm_add_epi32(sum0, bias);
+      };
+
       [[maybe_unused]] auto m128_add_dpbusd_epi32 = [=](__m128i& acc, __m128i a, __m128i b) {
         __m128i product0 = _mm_maddubs_epi16(a, b);
         product0 = _mm_madd_epi16(product0, Ones128);
         acc = _mm_add_epi32(acc, product0);
       };
 
+      [[maybe_unused]] auto m128_add_dpbusd_epi32x2 = [=](__m128i& acc, __m128i a0, __m128i b0, __m128i a1, __m128i b1) {
+        __m128i product0 = _mm_maddubs_epi16(a0, b0);
+        __m128i product1 = _mm_maddubs_epi16(a1, b1);
+        product0 = _mm_adds_epi16(product0, product1);
+        product0 = _mm_madd_epi16(product0, Ones128);
+        acc = _mm_add_epi32(acc, product0);
+      };
+
       [[maybe_unused]] auto m128_add_dpbusd_epi32x4 = [=](__m128i& acc, __m128i a0, __m128i b0, __m128i a1, __m128i b1,
                                                                         __m128i a2, __m128i b2, __m128i a3, __m128i b3) {
         __m128i product0 = _mm_maddubs_epi16(a0, b0);
@@ -231,33 +331,116 @@ namespace Stockfish::Eval::NNUE::Layers {
       using vec_t = __m512i;
       #define vec_setzero _mm512_setzero_si512
       #define vec_set_32 _mm512_set1_epi32
-      auto& vec_add_dpbusd_32 = m512_add_dpbusd_epi32;
-      auto& vec_add_dpbusd_32x4 = m512_add_dpbusd_epi32x4;
-      auto& vec_hadd = m512_hadd;
+      [[maybe_unused]] auto& vec_add_dpbusd_32 = m512_add_dpbusd_epi32;
+      [[maybe_unused]] auto& vec_add_dpbusd_32x2 = m512_add_dpbusd_epi32x2;
+      [[maybe_unused]] auto& vec_add_dpbusd_32x4 = m512_add_dpbusd_epi32x4;
+      [[maybe_unused]] auto& vec_hadd = m512_hadd;
+      [[maybe_unused]] auto& vec_haddx4 = m512_haddx4;
 #elif defined (USE_AVX2)
       using vec_t = __m256i;
       #define vec_setzero _mm256_setzero_si256
       #define vec_set_32 _mm256_set1_epi32
-      auto& vec_add_dpbusd_32 = m256_add_dpbusd_epi32;
-      auto& vec_add_dpbusd_32x4 = m256_add_dpbusd_epi32x4;
-      auto& vec_hadd = m256_hadd;
+      [[maybe_unused]] auto& vec_add_dpbusd_32 = m256_add_dpbusd_epi32;
+      [[maybe_unused]] auto& vec_add_dpbusd_32x2 = m256_add_dpbusd_epi32x2;
+      [[maybe_unused]] auto& vec_add_dpbusd_32x4 = m256_add_dpbusd_epi32x4;
+      [[maybe_unused]] auto& vec_hadd = m256_hadd;
+      [[maybe_unused]] auto& vec_haddx4 = m256_haddx4;
 #elif defined (USE_SSSE3)
       using vec_t = __m128i;
       #define vec_setzero _mm_setzero_si128
       #define vec_set_32 _mm_set1_epi32
-      auto& vec_add_dpbusd_32 = m128_add_dpbusd_epi32;
-      auto& vec_add_dpbusd_32x4 = m128_add_dpbusd_epi32x4;
-      auto& vec_hadd = m128_hadd;
+      [[maybe_unused]] auto& vec_add_dpbusd_32 = m128_add_dpbusd_epi32;
+      [[maybe_unused]] auto& vec_add_dpbusd_32x2 = m128_add_dpbusd_epi32x2;
+      [[maybe_unused]] auto& vec_add_dpbusd_32x4 = m128_add_dpbusd_epi32x4;
+      [[maybe_unused]] auto& vec_hadd = m128_hadd;
+      [[maybe_unused]] auto& vec_haddx4 = m128_haddx4;
 #endif
 
 #if defined (USE_SSSE3)
       const auto output = reinterpret_cast<OutputType*>(buffer);
       const auto inputVector = reinterpret_cast<const vec_t*>(input);
+#endif
+
+#if defined (USE_VNNI) || defined (USE_AVX512)
 
-      static_assert(OutputDimensions % OutputSimdWidth == 0 || OutputDimensions == 1);
+      static_assert(OutputDimensions == 1 || OutputDimensions % 4 == 0);
 
       // OutputDimensions is either 1 or a multiple of SimdWidth
       // because then it is also an input dimension.
+      if constexpr (OutputDimensions <= 8 && OutputDimensions != 1)
+      {
+          constexpr IndexType NumChunks = PaddedInputDimensions / InputSimdWidth;
+
+          static_assert(NumChunks % 2 == 0);
+
+          const auto input_vec = reinterpret_cast<const vec_t*>(input);
+          const auto bias_vec = reinterpret_cast<const __m128i*>(biases);
+          auto out_vec = reinterpret_cast<__m128i*>(output);
+
+          vec_t regs[OutputDimensions];
+          for (IndexType k = 0; k < OutputDimensions; ++k)
+            regs[k] = vec_setzero();
+
+          for (IndexType i = 0; i < NumChunks / 2; ++i)
+          {
+              const vec_t in0 = input_vec[i * 2 + 0];
+              const vec_t in1 = input_vec[i * 2 + 1];
+              for (IndexType k = 0; k < OutputDimensions; ++k)
+              {
+                  const vec_t w0 = reinterpret_cast<const vec_t*>(&weights[k * PaddedInputDimensions])[i * 2 + 0];
+                  const vec_t w1 = reinterpret_cast<const vec_t*>(&weights[k * PaddedInputDimensions])[i * 2 + 1];
+                  vec_add_dpbusd_32(regs[k], in0, w0);
+                  vec_add_dpbusd_32(regs[k], in1, w1);
+              }
+          }
+
+          for (IndexType k = 0; k < OutputDimensions / 4; ++k)
+          {
+            out_vec[k] = vec_haddx4(
+              regs[k * 4 + 0],
+              regs[k * 4 + 1],
+              regs[k * 4 + 2],
+              regs[k * 4 + 3],
+              bias_vec[k]
+            );
+          }
+      }
+      else if constexpr (InputDimensions == 8)
+      {
+          const auto input32 = reinterpret_cast<const std::int32_t*>(input);
+          __m256i* outptr = reinterpret_cast<__m256i*>(output);
+          std::memcpy(output, biases, OutputDimensions * sizeof(OutputType));
+
+          const __m256i in0 = _mm256_set1_epi32(input32[0]);
+          const __m256i in1 = _mm256_set1_epi32(input32[1]);
+          const auto col0 = reinterpret_cast<const __m256i*>(&weights[0]);
+          const auto col1 = reinterpret_cast<const __m256i*>(&weights[OutputDimensions * 4]);
+          for (IndexType j = 0; j * 8 < OutputDimensions; ++j)
+              m256_add_dpbusd_epi32x2(outptr[j], in0, col0[j], in1, col1[j]);
+      }
+      else
+
+#elif defined (USE_SSSE3)
+
+      if constexpr (OutputDimensions % OutputSimdWidth == 0 && InputDimensions == 8)
+      {
+          const auto input32 = reinterpret_cast<const std::int32_t*>(input);
+          vec_t* outptr = reinterpret_cast<vec_t*>(output);
+          std::memcpy(output, biases, OutputDimensions * sizeof(OutputType));
+
+          const vec_t in0 = vec_set_32(input32[0]);
+          const vec_t in1 = vec_set_32(input32[1]);
+          const auto col0 = reinterpret_cast<const vec_t*>(&weights[0]);
+          const auto col1 = reinterpret_cast<const vec_t*>(&weights[OutputDimensions * 4]);
+          for (IndexType j = 0; j * OutputSimdWidth < OutputDimensions; ++j)
+              vec_add_dpbusd_32x2(outptr[j], in0, col0[j], in1, col1[j]);
+      }
+      else
+
+#endif
+
+#if defined (USE_SSSE3)
+
       if constexpr (OutputDimensions % OutputSimdWidth == 0)
       {
           static_assert(InputDimensions % 16 == 0);
@@ -337,8 +520,8 @@ namespace Stockfish::Eval::NNUE::Layers {
 
 #if defined(USE_SSE2)
       // At least a multiple of 16, with SSE2.
-      static_assert(InputDimensions % SimdWidth == 0);
-      constexpr IndexType NumChunks = InputDimensions / SimdWidth;
+      static_assert(PaddedInputDimensions % SimdWidth == 0);
+      constexpr IndexType NumChunks = PaddedInputDimensions / SimdWidth;
       const __m128i Zeros = _mm_setzero_si128();
       const auto inputVector = reinterpret_cast<const __m128i*>(input);
 
@@ -349,8 +532,8 @@ namespace Stockfish::Eval::NNUE::Layers {
       const auto inputVector = reinterpret_cast<const __m64*>(input);
 
 #elif defined(USE_NEON)
-      static_assert(InputDimensions % SimdWidth == 0);
-      constexpr IndexType NumChunks = InputDimensions / SimdWidth;
+      static_assert(PaddedInputDimensions % SimdWidth == 0);
+      constexpr IndexType NumChunks = PaddedInputDimensions / SimdWidth;
       const auto inputVector = reinterpret_cast<const int8x8_t*>(input);
 #endif
 
@@ -423,6 +606,13 @@ namespace Stockfish::Eval::NNUE::Layers {
       _mm_empty();
 #endif
 
+#endif
+
+#if (!defined (USE_SSSE3) && defined (USE_SSE2)) || defined (USE_NEON)
+      static_assert(SimdWidth <= 16, "Otherwise we run outside of the padding for the output.");
+      if constexpr (SimdWidth > OutputDimensions && OutputDimensions != 1)
+          for (IndexType i = OutputDimensions; i < SimdWidth; ++i)
+            output[i] = 0;
 #endif
 
       return output;