Update architecture to "SFNNv4". Update network to nn-6877cd24400e.nnue.
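This commit halves the declared output size of the feature transformer: OutputDimensions drops from HalfDimensions * 2 to HalfDimensions, and get_hash_value() now XORs with OutputDimensions * 2, so the hash embedded in the network file is still derived from HalfDimensions * 2 as before. In transform(), the clipped-ReLU activation is replaced by a pairwise product: for each perspective, the two halves of the int16 accumulator are clipped to [0, 127], multiplied element-wise, and scaled by 1/128 before being packed to int8. The main hunk below provides AVX-512, AVX2, SSE2 and NEON implementations plus a generic fallback, and drops the old MMX path; a scalar sketch of the new activation follows the diff.
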
diff --git a/src/nnue/nnue_feature_transformer.h b/src/nnue/nnue_feature_transformer.h
index f4024dce83b9d1e5f02ce87528a80d70a68874fc..fb867421f6c249b7513190ef5f2a48296726ac03 100644
--- a/src/nnue/nnue_feature_transformer.h
+++ b/src/nnue/nnue_feature_transformer.h
@@ -183,7 +183,7 @@ namespace Stockfish::Eval::NNUE {
 
     // Number of input/output dimensions
     static constexpr IndexType InputDimensions = FeatureSet::Dimensions;
-    static constexpr IndexType OutputDimensions = HalfDimensions * 2;
+    static constexpr IndexType OutputDimensions = HalfDimensions;
 
     // Size of forward propagation buffer
     static constexpr std::size_t BufferSize =
@@ -191,7 +191,7 @@ namespace Stockfish::Eval::NNUE {
 
     // Hash value embedded in the evaluation file
     static constexpr std::uint32_t get_hash_value() {
-      return FeatureSet::HashValue ^ OutputDimensions;
+      return FeatureSet::HashValue ^ (OutputDimensions * 2);
     }
 
     // Read network parameters
@@ -229,142 +229,130 @@ namespace Stockfish::Eval::NNUE {
         ) / 2;
 
 
-  #if defined(USE_AVX512)
-
-      constexpr IndexType NumChunks = HalfDimensions / (SimdWidth * 2);
-      static_assert(HalfDimensions % (SimdWidth * 2) == 0);
-      const __m512i Control = _mm512_setr_epi64(0, 2, 4, 6, 1, 3, 5, 7);
-      const __m512i Zero = _mm512_setzero_si512();
-
       for (IndexType p = 0; p < 2; ++p)
       {
-          const IndexType offset = HalfDimensions * p;
-          auto out = reinterpret_cast<__m512i*>(&output[offset]);
-          for (IndexType j = 0; j < NumChunks; ++j)
-          {
-              __m512i sum0 = _mm512_load_si512(&reinterpret_cast<const __m512i*>
-                                              (accumulation[perspectives[p]])[j * 2 + 0]);
-              __m512i sum1 = _mm512_load_si512(&reinterpret_cast<const __m512i*>
-                                              (accumulation[perspectives[p]])[j * 2 + 1]);
+          const IndexType offset = (HalfDimensions / 2) * p;
 
-              _mm512_store_si512(&out[j], _mm512_permutexvar_epi64(Control,
-                                 _mm512_max_epi8(_mm512_packs_epi16(sum0, sum1), Zero)));
-          }
-      }
-      return psqt;
+#if defined(USE_AVX512)
 
-  #elif defined(USE_AVX2)
+          constexpr IndexType OutputChunkSize = 512 / 8;
+          static_assert((HalfDimensions / 2) % OutputChunkSize == 0);
+          constexpr IndexType NumOutputChunks = HalfDimensions / 2 / OutputChunkSize;
 
-      constexpr IndexType NumChunks = HalfDimensions / SimdWidth;
-      constexpr int Control = 0b11011000;
-      const __m256i Zero = _mm256_setzero_si256();
+          const __m512i Zero = _mm512_setzero_si512();
+          const __m512i One = _mm512_set1_epi16(127);
+          const __m512i Control = _mm512_setr_epi64(0, 2, 4, 6, 1, 3, 5, 7);
 
-      for (IndexType p = 0; p < 2; ++p)
-      {
-          const IndexType offset = HalfDimensions * p;
-          auto out = reinterpret_cast<__m256i*>(&output[offset]);
-          for (IndexType j = 0; j < NumChunks; ++j)
+          const __m512i* in0 = reinterpret_cast<const __m512i*>(&(accumulation[perspectives[p]][0]));
+          const __m512i* in1 = reinterpret_cast<const __m512i*>(&(accumulation[perspectives[p]][HalfDimensions / 2]));
+                __m512i* out = reinterpret_cast<      __m512i*>(output + offset);
+
+          for (IndexType j = 0; j < NumOutputChunks; j += 1)
           {
-              __m256i sum0 = _mm256_load_si256(&reinterpret_cast<const __m256i*>
-                                              (accumulation[perspectives[p]])[j * 2 + 0]);
-              __m256i sum1 = _mm256_load_si256(&reinterpret_cast<const __m256i*>
-                                              (accumulation[perspectives[p]])[j * 2 + 1]);
+              const __m512i sum0a = _mm512_max_epi16(_mm512_min_epi16(in0[j * 2 + 0], One), Zero);
+              const __m512i sum0b = _mm512_max_epi16(_mm512_min_epi16(in0[j * 2 + 1], One), Zero);
+              const __m512i sum1a = _mm512_max_epi16(_mm512_min_epi16(in1[j * 2 + 0], One), Zero);
+              const __m512i sum1b = _mm512_max_epi16(_mm512_min_epi16(in1[j * 2 + 1], One), Zero);
 
-              _mm256_store_si256(&out[j], _mm256_permute4x64_epi64(
-                                 _mm256_max_epi8(_mm256_packs_epi16(sum0, sum1), Zero), Control));
+              const __m512i pa = _mm512_srli_epi16(_mm512_mullo_epi16(sum0a, sum1a), 7);
+              const __m512i pb = _mm512_srli_epi16(_mm512_mullo_epi16(sum0b, sum1b), 7);
+
+              out[j] = _mm512_permutexvar_epi64(Control, _mm512_packs_epi16(pa, pb));
           }
-      }
-      return psqt;
 
-  #elif defined(USE_SSE2)
+#elif defined(USE_AVX2)
 
-      #ifdef USE_SSE41
-      constexpr IndexType NumChunks = HalfDimensions / SimdWidth;
-      const __m128i Zero = _mm_setzero_si128();
-      #else
-      constexpr IndexType NumChunks = HalfDimensions / SimdWidth;
-      const __m128i k0x80s = _mm_set1_epi8(-128);
-      #endif
+          constexpr IndexType OutputChunkSize = 256 / 8;
+          static_assert((HalfDimensions / 2) % OutputChunkSize == 0);
+          constexpr IndexType NumOutputChunks = HalfDimensions / 2 / OutputChunkSize;
 
-      for (IndexType p = 0; p < 2; ++p)
-      {
-          const IndexType offset = HalfDimensions * p;
-          auto out = reinterpret_cast<__m128i*>(&output[offset]);
-          for (IndexType j = 0; j < NumChunks; ++j)
+          const __m256i Zero = _mm256_setzero_si256();
+          const __m256i One = _mm256_set1_epi16(127);
+          constexpr int Control = 0b11011000;
+
+          const __m256i* in0 = reinterpret_cast<const __m256i*>(&(accumulation[perspectives[p]][0]));
+          const __m256i* in1 = reinterpret_cast<const __m256i*>(&(accumulation[perspectives[p]][HalfDimensions / 2]));
+                __m256i* out = reinterpret_cast<      __m256i*>(output + offset);
+
+          for (IndexType j = 0; j < NumOutputChunks; j += 1)
           {
-              __m128i sum0 = _mm_load_si128(&reinterpret_cast<const __m128i*>
-                                           (accumulation[perspectives[p]])[j * 2 + 0]);
-              __m128i sum1 = _mm_load_si128(&reinterpret_cast<const __m128i*>
-                                           (accumulation[perspectives[p]])[j * 2 + 1]);
-              const __m128i packedbytes = _mm_packs_epi16(sum0, sum1);
-
-              #ifdef USE_SSE41
-              _mm_store_si128(&out[j], _mm_max_epi8(packedbytes, Zero));
-              #else
-              _mm_store_si128(&out[j], _mm_subs_epi8(_mm_adds_epi8(packedbytes, k0x80s), k0x80s));
-              #endif
+              const __m256i sum0a = _mm256_max_epi16(_mm256_min_epi16(in0[j * 2 + 0], One), Zero);
+              const __m256i sum0b = _mm256_max_epi16(_mm256_min_epi16(in0[j * 2 + 1], One), Zero);
+              const __m256i sum1a = _mm256_max_epi16(_mm256_min_epi16(in1[j * 2 + 0], One), Zero);
+              const __m256i sum1b = _mm256_max_epi16(_mm256_min_epi16(in1[j * 2 + 1], One), Zero);
+
+              const __m256i pa = _mm256_srli_epi16(_mm256_mullo_epi16(sum0a, sum1a), 7);
+              const __m256i pb = _mm256_srli_epi16(_mm256_mullo_epi16(sum0b, sum1b), 7);
+
+              out[j] = _mm256_permute4x64_epi64(_mm256_packs_epi16(pa, pb), Control);
           }
-      }
-      return psqt;
 
-  #elif defined(USE_MMX)
+#elif defined(USE_SSE2)
 
-      constexpr IndexType NumChunks = HalfDimensions / SimdWidth;
-      const __m64 k0x80s = _mm_set1_pi8(-128);
+          constexpr IndexType OutputChunkSize = 128 / 8;
+          static_assert((HalfDimensions / 2) % OutputChunkSize == 0);
+          constexpr IndexType NumOutputChunks = HalfDimensions / 2 / OutputChunkSize;
 
-      for (IndexType p = 0; p < 2; ++p)
-      {
-          const IndexType offset = HalfDimensions * p;
-          auto out = reinterpret_cast<__m64*>(&output[offset]);
-          for (IndexType j = 0; j < NumChunks; ++j)
+          const __m128i Zero = _mm_setzero_si128();
+          const __m128i One = _mm_set1_epi16(127);
+
+          const __m128i* in0 = reinterpret_cast<const __m128i*>(&(accumulation[perspectives[p]][0]));
+          const __m128i* in1 = reinterpret_cast<const __m128i*>(&(accumulation[perspectives[p]][HalfDimensions / 2]));
+                __m128i* out = reinterpret_cast<      __m128i*>(output + offset);
+
+          for (IndexType j = 0; j < NumOutputChunks; j += 1)
           {
-              __m64 sum0 = *(&reinterpret_cast<const __m64*>(accumulation[perspectives[p]])[j * 2 + 0]);
-              __m64 sum1 = *(&reinterpret_cast<const __m64*>(accumulation[perspectives[p]])[j * 2 + 1]);
-              const __m64 packedbytes = _mm_packs_pi16(sum0, sum1);
-              out[j] = _mm_subs_pi8(_mm_adds_pi8(packedbytes, k0x80s), k0x80s);
+              const __m128i sum0a = _mm_max_epi16(_mm_min_epi16(in0[j * 2 + 0], One), Zero);
+              const __m128i sum0b = _mm_max_epi16(_mm_min_epi16(in0[j * 2 + 1], One), Zero);
+              const __m128i sum1a = _mm_max_epi16(_mm_min_epi16(in1[j * 2 + 0], One), Zero);
+              const __m128i sum1b = _mm_max_epi16(_mm_min_epi16(in1[j * 2 + 1], One), Zero);
+
+              const __m128i pa = _mm_srli_epi16(_mm_mullo_epi16(sum0a, sum1a), 7);
+              const __m128i pb = _mm_srli_epi16(_mm_mullo_epi16(sum0b, sum1b), 7);
+
+              out[j] = _mm_packs_epi16(pa, pb);
           }
-      }
-      _mm_empty();
-      return psqt;
 
-  #elif defined(USE_NEON)
+#elif defined(USE_NEON)
 
-      constexpr IndexType NumChunks = HalfDimensions / (SimdWidth / 2);
-      const int8x8_t Zero = {0};
+          constexpr IndexType OutputChunkSize = 128 / 8;
+          static_assert((HalfDimensions / 2) % OutputChunkSize == 0);
+          constexpr IndexType NumOutputChunks = HalfDimensions / 2 / OutputChunkSize;
 
-      for (IndexType p = 0; p < 2; ++p)
-      {
-          const IndexType offset = HalfDimensions * p;
-          const auto out = reinterpret_cast<int8x8_t*>(&output[offset]);
+          const int16x8_t Zero = vdupq_n_s16(0);
+          const int16x8_t One  = vdupq_n_s16(127);
 
-          constexpr IndexType UnrollFactor = 16;
-          static_assert(UnrollFactor % UnrollFactor == 0);
-          for (IndexType j = 0; j < NumChunks; j += UnrollFactor)
+          const int16x8_t* in0 = reinterpret_cast<const int16x8_t*>(&(accumulation[perspectives[p]][0]));
+          const int16x8_t* in1 = reinterpret_cast<const int16x8_t*>(&(accumulation[perspectives[p]][HalfDimensions / 2]));
+                int8x16_t* out = reinterpret_cast<      int8x16_t*>(output + offset);
+
+          for (IndexType j = 0; j < NumOutputChunks; j += 1)
           {
-              int16x8_t sums[UnrollFactor];
-              for (IndexType i = 0; i < UnrollFactor; ++i)
-                sums[i] = reinterpret_cast<const int16x8_t*>(accumulation[perspectives[p]])[j+i];
+              const int16x8_t sum0a = vmaxq_s16(vminq_s16(in0[j * 2 + 0], One), Zero);
+              const int16x8_t sum0b = vmaxq_s16(vminq_s16(in0[j * 2 + 1], One), Zero);
+              const int16x8_t sum1a = vmaxq_s16(vminq_s16(in1[j * 2 + 0], One), Zero);
+              const int16x8_t sum1b = vmaxq_s16(vminq_s16(in1[j * 2 + 1], One), Zero);
+
+              const int8x8_t pa = vshrn_n_s16(vmulq_s16(sum0a, sum1a), 7);
+              const int8x8_t pb = vshrn_n_s16(vmulq_s16(sum0b, sum1b), 7);
 
-              for (IndexType i = 0; i < UnrollFactor; ++i)
-                out[j+i] = vmax_s8(vqmovn_s16(sums[i]), Zero);
+              out[j] = vcombine_s8(pa, pb);
           }
-      }
-      return psqt;
 
-  #else
+#else
 
-      for (IndexType p = 0; p < 2; ++p)
-      {
-          const IndexType offset = HalfDimensions * p;
-          for (IndexType j = 0; j < HalfDimensions; ++j)
-          {
-              BiasType sum = accumulation[perspectives[p]][j];
-              output[offset + j] = static_cast<OutputType>(std::max<int>(0, std::min<int>(127, sum)));
+          for (IndexType j = 0; j < HalfDimensions / 2; ++j) {
+              BiasType sum0 = accumulation[static_cast<int>(perspectives[p])][j + 0];
+              BiasType sum1 = accumulation[static_cast<int>(perspectives[p])][j + HalfDimensions / 2];
+              sum0 = std::max<int>(0, std::min<int>(127, sum0));
+              sum1 = std::max<int>(0, std::min<int>(127, sum1));
+              output[offset + j] = static_cast<OutputType>(sum0 * sum1 / 128);
           }
+
+#endif
       }
-      return psqt;
 
-  #endif
+      return psqt;
 
    } // end of function transform()
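
For reference, a minimal scalar sketch of the new activation, modelled on the generic #else path above. The free-standing function, its name, and the std::vector interface are illustrative only; in Stockfish this loop runs once per perspective inside FeatureTransformer::transform(), writing into the corresponding half of the shared output buffer.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Scalar model of the SFNNv4 feature-transformer activation (hypothetical helper,
// not part of Stockfish). 'accumulation' is the int16 accumulator of one
// perspective and has 'halfDimensions' entries.
std::vector<std::int8_t> activate_half(const std::vector<std::int16_t>& accumulation,
                                       std::size_t halfDimensions)
{
    std::vector<std::int8_t> out(halfDimensions / 2);
    for (std::size_t j = 0; j < halfDimensions / 2; ++j)
    {
        // Clip both halves of the accumulator to [0, 127] ...
        const int sum0 = std::clamp<int>(accumulation[j], 0, 127);
        const int sum1 = std::clamp<int>(accumulation[j + halfDimensions / 2], 0, 127);

        // ... then multiply them pairwise and rescale by 1/128 (the >> 7 in the
        // SIMD paths); the maximum result is 127 * 127 / 128 = 126, so it fits
        // in a signed 8-bit output.
        out[j] = static_cast<std::int8_t>(sum0 * sum1 / 128);
    }
    return out;
}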