]> git.sesse.net Git - stockfish/commitdiff
AVX-512 for smaller affine and feature transforms.
authorTomasz Sobczyk <tomasz.sobczyk1997@gmail.com>
Tue, 3 Nov 2020 21:49:10 +0000 (22:49 +0100)
committerJoost VandeVondele <Joost.VandeVondele@gmail.com>
Sat, 7 Nov 2020 15:49:49 +0000 (16:49 +0100)
For the feature transformer the code is analogical to AVX2 since there was room for easy adaptation of wider simd registers.

For the smaller affine transforms that have 32 byte stride we keep 2 columns in one zmm register. We also unroll more aggressively so that in the end we have to do 16 parallel horizontal additions on ymm slices each consisting of 4 32-bit integers. The slices are embedded in 8 zmm registers.

These changes provide about 1.5% speedup for AVX-512 builds.

Closes https://github.com/official-stockfish/Stockfish/pull/3218

No functional change.

src/nnue/layers/affine_transform.h
src/nnue/nnue_feature_transformer.h

index f0292e453c14237e59cd86717c06158103308bbe..47c9c488b0c06ab137e532157163644c3d37d6af 100644 (file)
@@ -83,7 +83,21 @@ namespace Eval::NNUE::Layers {
         return _mm512_reduce_add_epi32(sum) + bias;
       };
 
         return _mm512_reduce_add_epi32(sum) + bias;
       };
 
-      [[maybe_unused]] auto m512_haddx4 = [](__m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3, __m128i bias) -> __m128i {
+      // This function takes
+      //   sum0 = [xmm0a, xmm0b, xmm0c, xmm0d]
+      //   sum1 = [xmm1a, xmm1b, xmm1c, xmm1d]
+      //   sum2 = [xmm2a, xmm2b, xmm2c, xmm2d]
+      //   sum3 = [xmm3a, xmm3b, xmm3c, xmm3d]
+      // and returns
+      //   ret = [
+      //     reduce_add_epi32(xmm0a), reduce_add_epi32(xmm1a), reduce_add_epi32(xmm2a), reduce_add_epi32(xmm3a),
+      //     reduce_add_epi32(xmm0b), reduce_add_epi32(xmm1b), reduce_add_epi32(xmm2b), reduce_add_epi32(xmm3b),
+      //     reduce_add_epi32(xmm0c), reduce_add_epi32(xmm1c), reduce_add_epi32(xmm2c), reduce_add_epi32(xmm3c),
+      //     reduce_add_epi32(xmm0d), reduce_add_epi32(xmm1d), reduce_add_epi32(xmm2d), reduce_add_epi32(xmm3d)
+      //   ]
+      [[maybe_unused]] auto m512_hadd128x16_interleave = [](
+        __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3) -> __m512i {
+
         __m512i sum01a = _mm512_unpacklo_epi32(sum0, sum1);
         __m512i sum01b = _mm512_unpackhi_epi32(sum0, sum1);
 
         __m512i sum01a = _mm512_unpacklo_epi32(sum0, sum1);
         __m512i sum01b = _mm512_unpackhi_epi32(sum0, sum1);
 
@@ -96,7 +110,13 @@ namespace Eval::NNUE::Layers {
         __m512i sum0123a = _mm512_unpacklo_epi64(sum01, sum23);
         __m512i sum0123b = _mm512_unpackhi_epi64(sum01, sum23);
 
         __m512i sum0123a = _mm512_unpacklo_epi64(sum01, sum23);
         __m512i sum0123b = _mm512_unpackhi_epi64(sum01, sum23);
 
-        __m512i sum = _mm512_add_epi32(sum0123a, sum0123b);
+        return _mm512_add_epi32(sum0123a, sum0123b);
+      };
+
+      [[maybe_unused]] auto m512_haddx4 = [m512_hadd128x16_interleave](
+        __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3, __m128i bias) -> __m128i {
+
+        __m512i sum = m512_hadd128x16_interleave(sum0, sum1, sum2, sum3);
 
         __m256i sum256lo = _mm512_castsi512_si256(sum);
         __m256i sum256hi = _mm512_extracti64x4_epi64(sum, 1);
 
         __m256i sum256lo = _mm512_castsi512_si256(sum);
         __m256i sum256hi = _mm512_extracti64x4_epi64(sum, 1);
@@ -109,6 +129,58 @@ namespace Eval::NNUE::Layers {
         return _mm_add_epi32(_mm_add_epi32(sum128lo, sum128hi), bias);
       };
 
         return _mm_add_epi32(_mm_add_epi32(sum128lo, sum128hi), bias);
       };
 
+      [[maybe_unused]] auto m512_haddx8 = [m512_hadd128x16_interleave](
+        __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3,
+        __m512i sum4, __m512i sum5, __m512i sum6, __m512i sum7, __m256i bias) -> __m256i {
+
+        __m512i suma = m512_hadd128x16_interleave(sum0, sum1, sum2, sum3);
+        __m512i sumb = m512_hadd128x16_interleave(sum4, sum5, sum6, sum7);
+
+        __m512i indices0 = _mm512_setr_epi64(0, 1, 8, 9, 4, 5, 12, 13);
+        __m512i indices1 = _mm512_setr_epi64(2, 3, 10, 11, 6, 7, 14, 15);
+        __m512i x = _mm512_add_epi32(
+          _mm512_permutex2var_epi64(suma, indices0, sumb),
+          _mm512_permutex2var_epi64(suma, indices1, sumb));
+
+        __m256i sum256lo = _mm512_castsi512_si256(x);
+        __m256i sum256hi = _mm512_extracti64x4_epi64(x, 1);
+
+        return _mm256_add_epi32(_mm256_add_epi32(sum256lo, sum256hi), bias);
+      };
+
+      [[maybe_unused]] auto m512_hadd256x8 =[m512_hadd128x16_interleave](
+        __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3, __m256i bias) -> __m256i {
+
+        __m512i sum = m512_hadd128x16_interleave(sum0, sum1, sum2, sum3);
+
+        __m512i indices = _mm512_setr_epi32(
+          0, 4, 8, 12, 2, 6, 10, 14,
+          1, 5, 9, 13, 3, 7, 11, 15);
+        sum = _mm512_permutexvar_epi32(indices, sum);
+
+        __m256i sum256lo = _mm512_castsi512_si256(sum);
+        __m256i sum256hi = _mm512_extracti64x4_epi64(sum, 1);
+
+        return _mm256_add_epi32(_mm256_hadd_epi32(sum256lo, sum256hi), bias);
+      };
+
+      [[maybe_unused]] auto m512_hadd256x16 = [m512_hadd128x16_interleave](
+        __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3,
+        __m512i sum4, __m512i sum5, __m512i sum6, __m512i sum7, __m512i bias) -> __m512i {
+
+        __m512i suma = m512_hadd128x16_interleave(sum0, sum1, sum2, sum3);
+        __m512i sumb = m512_hadd128x16_interleave(sum4, sum5, sum6, sum7);
+
+        __m512i indices0 = _mm512_setr_epi64(0, 1, 8, 9, 4, 5, 12, 13);
+        __m512i indices1 = _mm512_setr_epi64(2, 3, 10, 11, 6, 7, 14, 15);
+        __m512i x = _mm512_add_epi32(
+          _mm512_permutex2var_epi64(suma, indices0, sumb),
+          _mm512_permutex2var_epi64(suma, indices1, sumb));
+
+        __m512i indices = _mm512_setr_epi32(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15);
+        return _mm512_add_epi32(_mm512_permutexvar_epi32(indices, x), bias);
+      };
+
       [[maybe_unused]] auto m512_add_dpbusd_epi32 = [=](__m512i& acc, __m512i a, __m512i b) {
 #if defined (USE_VNNI)
         acc = _mm512_dpbusd_epi32(acc, a, b);
       [[maybe_unused]] auto m512_add_dpbusd_epi32 = [=](__m512i& acc, __m512i a, __m512i b) {
 #if defined (USE_VNNI)
         acc = _mm512_dpbusd_epi32(acc, a, b);
@@ -205,7 +277,58 @@ namespace Eval::NNUE::Layers {
 
       // kOutputDimensions is either 1 or a multiple of kSimdWidth
       // because then it is also an input dimension.
 
       // kOutputDimensions is either 1 or a multiple of kSimdWidth
       // because then it is also an input dimension.
-      if constexpr (kOutputDimensions % 4 == 0)
+      if constexpr (kOutputDimensions % 16 == 0 && kNumChunks256 == 1)
+      {
+        for (IndexType i = 0; i < kOutputDimensions; i += 16)
+        {
+          const IndexType offset01a = (i + 0) * kPaddedInputDimensions;
+          const IndexType offset23a = (i + 2) * kPaddedInputDimensions;
+          const IndexType offset45a = (i + 4) * kPaddedInputDimensions;
+          const IndexType offset67a = (i + 6) * kPaddedInputDimensions;
+          const IndexType offset01b = (i + 8) * kPaddedInputDimensions;
+          const IndexType offset23b = (i + 10) * kPaddedInputDimensions;
+          const IndexType offset45b = (i + 12) * kPaddedInputDimensions;
+          const IndexType offset67b = (i + 14) * kPaddedInputDimensions;
+
+          const __m512i bias = *reinterpret_cast<const __m512i*>(&biases_[i]);
+          __m512i* outptr = reinterpret_cast<__m512i*>(&output[i]);
+
+          __m512i sum01a = _mm512_setzero_si512();
+          __m512i sum23a = _mm512_setzero_si512();
+          __m512i sum45a = _mm512_setzero_si512();
+          __m512i sum67a = _mm512_setzero_si512();
+          __m512i sum01b = _mm512_setzero_si512();
+          __m512i sum23b = _mm512_setzero_si512();
+          __m512i sum45b = _mm512_setzero_si512();
+          __m512i sum67b = _mm512_setzero_si512();
+
+          const auto row01a = *reinterpret_cast<const __m512i*>(&weights_[offset01a]);
+          const auto row23a = *reinterpret_cast<const __m512i*>(&weights_[offset23a]);
+          const auto row45a = *reinterpret_cast<const __m512i*>(&weights_[offset45a]);
+          const auto row67a = *reinterpret_cast<const __m512i*>(&weights_[offset67a]);
+          const auto row01b = *reinterpret_cast<const __m512i*>(&weights_[offset01b]);
+          const auto row23b = *reinterpret_cast<const __m512i*>(&weights_[offset23b]);
+          const auto row45b = *reinterpret_cast<const __m512i*>(&weights_[offset45b]);
+          const auto row67b = *reinterpret_cast<const __m512i*>(&weights_[offset67b]);
+
+          const __m256i in256 = input_vector256[0];
+          const __m512i in = _mm512_inserti64x4(_mm512_castsi256_si512(in256), in256, 1);
+
+          m512_add_dpbusd_epi32(sum01a, in, row01a);
+          m512_add_dpbusd_epi32(sum23a, in, row23a);
+          m512_add_dpbusd_epi32(sum45a, in, row45a);
+          m512_add_dpbusd_epi32(sum67a, in, row67a);
+          m512_add_dpbusd_epi32(sum01b, in, row01b);
+          m512_add_dpbusd_epi32(sum23b, in, row23b);
+          m512_add_dpbusd_epi32(sum45b, in, row45b);
+          m512_add_dpbusd_epi32(sum67b, in, row67b);
+
+          *outptr = m512_hadd256x16(
+            sum01a, sum23a, sum45a, sum67a,
+            sum01b, sum23b, sum45b, sum67b, bias);
+        }
+      }
+      else if constexpr (kOutputDimensions % 4 == 0)
       {
         for (IndexType i = 0; i < kOutputDimensions; i += 4)
         {
       {
         for (IndexType i = 0; i < kOutputDimensions; i += 4)
         {
index c3f012e412ec2415b726a469d080378b8049526d..f49777b50bbe3f6809acd358399ae9b5dbf1bda1 100644 (file)
@@ -127,7 +127,13 @@ namespace Eval::NNUE {
 
       const auto& accumulation = pos.state()->accumulator.accumulation;
 
 
       const auto& accumulation = pos.state()->accumulator.accumulation;
 
-  #if defined(USE_AVX2)
+  #if defined(USE_AVX512)
+      constexpr IndexType kNumChunks = kHalfDimensions / (kSimdWidth * 2);
+      static_assert(kHalfDimensions % (kSimdWidth * 2) == 0);
+      const __m512i kControl = _mm512_setr_epi64(0, 2, 4, 6, 1, 3, 5, 7);
+      const __m512i kZero = _mm512_setzero_si512();
+
+  #elif defined(USE_AVX2)
       constexpr IndexType kNumChunks = kHalfDimensions / kSimdWidth;
       constexpr int kControl = 0b11011000;
       const __m256i kZero = _mm256_setzero_si256();
       constexpr IndexType kNumChunks = kHalfDimensions / kSimdWidth;
       constexpr int kControl = 0b11011000;
       const __m256i kZero = _mm256_setzero_si256();
@@ -154,13 +160,24 @@ namespace Eval::NNUE {
       for (IndexType p = 0; p < 2; ++p) {
         const IndexType offset = kHalfDimensions * p;
 
       for (IndexType p = 0; p < 2; ++p) {
         const IndexType offset = kHalfDimensions * p;
 
-  #if defined(USE_AVX2)
+  #if defined(USE_AVX512)
+        auto out = reinterpret_cast<__m512i*>(&output[offset]);
+        for (IndexType j = 0; j < kNumChunks; ++j) {
+          __m512i sum0 = _mm512_load_si512(
+              &reinterpret_cast<const __m512i*>(accumulation[perspectives[p]][0])[j * 2 + 0]);
+          __m512i sum1 = _mm512_load_si512(
+              &reinterpret_cast<const __m512i*>(accumulation[perspectives[p]][0])[j * 2 + 1]);
+          _mm512_store_si512(&out[j], _mm512_permutexvar_epi64(kControl,
+              _mm512_max_epi8(_mm512_packs_epi16(sum0, sum1), kZero)));
+        }
+
+  #elif defined(USE_AVX2)
         auto out = reinterpret_cast<__m256i*>(&output[offset]);
         for (IndexType j = 0; j < kNumChunks; ++j) {
           __m256i sum0 = _mm256_load_si256(
               &reinterpret_cast<const __m256i*>(accumulation[perspectives[p]][0])[j * 2 + 0]);
           __m256i sum1 = _mm256_load_si256(
         auto out = reinterpret_cast<__m256i*>(&output[offset]);
         for (IndexType j = 0; j < kNumChunks; ++j) {
           __m256i sum0 = _mm256_load_si256(
               &reinterpret_cast<const __m256i*>(accumulation[perspectives[p]][0])[j * 2 + 0]);
           __m256i sum1 = _mm256_load_si256(
-            &reinterpret_cast<const __m256i*>(accumulation[perspectives[p]][0])[j * 2 + 1]);
+              &reinterpret_cast<const __m256i*>(accumulation[perspectives[p]][0])[j * 2 + 1]);
           _mm256_store_si256(&out[j], _mm256_permute4x64_epi64(_mm256_max_epi8(
               _mm256_packs_epi16(sum0, sum1), kZero), kControl));
         }
           _mm256_store_si256(&out[j], _mm256_permute4x64_epi64(_mm256_max_epi8(
               _mm256_packs_epi16(sum0, sum1), kZero), kControl));
         }
@@ -177,9 +194,9 @@ namespace Eval::NNUE {
           _mm_store_si128(&out[j],
 
   #ifdef USE_SSE41
           _mm_store_si128(&out[j],
 
   #ifdef USE_SSE41
-            _mm_max_epi8(packedbytes, kZero)
+              _mm_max_epi8(packedbytes, kZero)
   #else
   #else
-            _mm_subs_epi8(_mm_adds_epi8(packedbytes, k0x80s), k0x80s)
+              _mm_subs_epi8(_mm_adds_epi8(packedbytes, k0x80s), k0x80s)
   #endif
 
           );
   #endif
 
           );