]> git.sesse.net Git - stockfish/blobdiff - src/nnue/nnue_feature_transformer.h
Clean up and simplify some nnue code.
[stockfish] / src / nnue / nnue_feature_transformer.h
index 59a965ac769450f11c4fde517ac09d731c0749f6..855980182fccafd731879a91b6458f93e3c6570d 100644 (file)
@@ -1,6 +1,6 @@
 /*
   Stockfish, a UCI chess playing engine derived from Glaurung 2.1
-  Copyright (C) 2004-2021 The Stockfish developers (see AUTHORS file)
+  Copyright (C) 2004-2022 The Stockfish developers (see AUTHORS file)
 
   Stockfish is free software: you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
@@ -123,8 +123,10 @@ namespace Stockfish::Eval::NNUE {
       // We use __m* types as template arguments, which causes GCC to emit warnings
       // about losing some attribute information. This is irrelevant to us as we
       // only take their size, so the following pragma are harmless.
+      #if defined(__GNUC__)
       #pragma GCC diagnostic push
       #pragma GCC diagnostic ignored "-Wignored-attributes"
+      #endif
 
       template <typename SIMDRegisterType,
                 typename LaneType,
@@ -156,9 +158,9 @@ namespace Stockfish::Eval::NNUE {
 
       static constexpr int NumRegs     = BestRegisterCount<vec_t, WeightType, TransformedFeatureDimensions, NumRegistersSIMD>();
       static constexpr int NumPsqtRegs = BestRegisterCount<psqt_vec_t, PSQTWeightType, PSQTBuckets, NumRegistersSIMD>();
-
+      #if defined(__GNUC__)
       #pragma GCC diagnostic pop
-
+      #endif
   #endif
 
 
@@ -183,7 +185,7 @@ namespace Stockfish::Eval::NNUE {
 
     // Number of input/output dimensions
     static constexpr IndexType InputDimensions = FeatureSet::Dimensions;
-    static constexpr IndexType OutputDimensions = HalfDimensions * 2;
+    static constexpr IndexType OutputDimensions = HalfDimensions;
 
     // Size of forward propagation buffer
     static constexpr std::size_t BufferSize =
@@ -191,7 +193,7 @@ namespace Stockfish::Eval::NNUE {
 
     // Hash value embedded in the evaluation file
     static constexpr std::uint32_t get_hash_value() {
-      return FeatureSet::HashValue ^ OutputDimensions;
+      return FeatureSet::HashValue ^ (OutputDimensions * 2);
     }
 
     // Read network parameters
@@ -229,135 +231,130 @@ namespace Stockfish::Eval::NNUE {
         ) / 2;
 
 
-  #if defined(USE_AVX512)
-
-      constexpr IndexType NumChunks = HalfDimensions / (SimdWidth * 2);
-      static_assert(HalfDimensions % (SimdWidth * 2) == 0);
-      const __m512i Control = _mm512_setr_epi64(0, 2, 4, 6, 1, 3, 5, 7);
-      const __m512i Zero = _mm512_setzero_si512();
-
       for (IndexType p = 0; p < 2; ++p)
       {
-          const IndexType offset = HalfDimensions * p;
-          auto out = reinterpret_cast<__m512i*>(&output[offset]);
-          for (IndexType j = 0; j < NumChunks; ++j)
-          {
-              __m512i sum0 = _mm512_load_si512(&reinterpret_cast<const __m512i*>
-                                              (accumulation[perspectives[p]])[j * 2 + 0]);
-              __m512i sum1 = _mm512_load_si512(&reinterpret_cast<const __m512i*>
-                                              (accumulation[perspectives[p]])[j * 2 + 1]);
+          const IndexType offset = (HalfDimensions / 2) * p;
 
-              _mm512_store_si512(&out[j], _mm512_permutexvar_epi64(Control,
-                                 _mm512_max_epi8(_mm512_packs_epi16(sum0, sum1), Zero)));
-          }
-      }
-      return psqt;
+#if defined(USE_AVX512)
 
-  #elif defined(USE_AVX2)
+          constexpr IndexType OutputChunkSize = 512 / 8;
+          static_assert((HalfDimensions / 2) % OutputChunkSize == 0);
+          constexpr IndexType NumOutputChunks = HalfDimensions / 2 / OutputChunkSize;
 
-      constexpr IndexType NumChunks = HalfDimensions / SimdWidth;
-      constexpr int Control = 0b11011000;
-      const __m256i Zero = _mm256_setzero_si256();
+          const __m512i Zero = _mm512_setzero_si512();
+          const __m512i One = _mm512_set1_epi16(127);
+          const __m512i Control = _mm512_setr_epi64(0, 2, 4, 6, 1, 3, 5, 7);
 
-      for (IndexType p = 0; p < 2; ++p)
-      {
-          const IndexType offset = HalfDimensions * p;
-          auto out = reinterpret_cast<__m256i*>(&output[offset]);
-          for (IndexType j = 0; j < NumChunks; ++j)
+          const __m512i* in0 = reinterpret_cast<const __m512i*>(&(accumulation[perspectives[p]][0]));
+          const __m512i* in1 = reinterpret_cast<const __m512i*>(&(accumulation[perspectives[p]][HalfDimensions / 2]));
+                __m512i* out = reinterpret_cast<      __m512i*>(output + offset);
+
+          for (IndexType j = 0; j < NumOutputChunks; j += 1)
           {
-              __m256i sum0 = _mm256_load_si256(&reinterpret_cast<const __m256i*>
-                                              (accumulation[perspectives[p]])[j * 2 + 0]);
-              __m256i sum1 = _mm256_load_si256(&reinterpret_cast<const __m256i*>
-                                              (accumulation[perspectives[p]])[j * 2 + 1]);
+              const __m512i sum0a = _mm512_max_epi16(_mm512_min_epi16(in0[j * 2 + 0], One), Zero);
+              const __m512i sum0b = _mm512_max_epi16(_mm512_min_epi16(in0[j * 2 + 1], One), Zero);
+              const __m512i sum1a = _mm512_max_epi16(_mm512_min_epi16(in1[j * 2 + 0], One), Zero);
+              const __m512i sum1b = _mm512_max_epi16(_mm512_min_epi16(in1[j * 2 + 1], One), Zero);
 
-              _mm256_store_si256(&out[j], _mm256_permute4x64_epi64(
-                                 _mm256_max_epi8(_mm256_packs_epi16(sum0, sum1), Zero), Control));
+              const __m512i pa = _mm512_srli_epi16(_mm512_mullo_epi16(sum0a, sum1a), 7);
+              const __m512i pb = _mm512_srli_epi16(_mm512_mullo_epi16(sum0b, sum1b), 7);
+
+              out[j] = _mm512_permutexvar_epi64(Control, _mm512_packs_epi16(pa, pb));
           }
-      }
-      return psqt;
 
-  #elif defined(USE_SSE2)
+#elif defined(USE_AVX2)
 
-      #ifdef USE_SSE41
-      constexpr IndexType NumChunks = HalfDimensions / SimdWidth;
-      const __m128i Zero = _mm_setzero_si128();
-      #else
-      constexpr IndexType NumChunks = HalfDimensions / SimdWidth;
-      const __m128i k0x80s = _mm_set1_epi8(-128);
-      #endif
+          constexpr IndexType OutputChunkSize = 256 / 8;
+          static_assert((HalfDimensions / 2) % OutputChunkSize == 0);
+          constexpr IndexType NumOutputChunks = HalfDimensions / 2 / OutputChunkSize;
 
-      for (IndexType p = 0; p < 2; ++p)
-      {
-          const IndexType offset = HalfDimensions * p;
-          auto out = reinterpret_cast<__m128i*>(&output[offset]);
-          for (IndexType j = 0; j < NumChunks; ++j)
+          const __m256i Zero = _mm256_setzero_si256();
+          const __m256i One = _mm256_set1_epi16(127);
+          constexpr int Control = 0b11011000;
+
+          const __m256i* in0 = reinterpret_cast<const __m256i*>(&(accumulation[perspectives[p]][0]));
+          const __m256i* in1 = reinterpret_cast<const __m256i*>(&(accumulation[perspectives[p]][HalfDimensions / 2]));
+                __m256i* out = reinterpret_cast<      __m256i*>(output + offset);
+
+          for (IndexType j = 0; j < NumOutputChunks; j += 1)
           {
-              __m128i sum0 = _mm_load_si128(&reinterpret_cast<const __m128i*>
-                                           (accumulation[perspectives[p]])[j * 2 + 0]);
-              __m128i sum1 = _mm_load_si128(&reinterpret_cast<const __m128i*>
-                                           (accumulation[perspectives[p]])[j * 2 + 1]);
-              const __m128i packedbytes = _mm_packs_epi16(sum0, sum1);
-
-              #ifdef USE_SSE41
-              _mm_store_si128(&out[j], _mm_max_epi8(packedbytes, Zero));
-              #else
-              _mm_store_si128(&out[j], _mm_subs_epi8(_mm_adds_epi8(packedbytes, k0x80s), k0x80s));
-              #endif
+              const __m256i sum0a = _mm256_max_epi16(_mm256_min_epi16(in0[j * 2 + 0], One), Zero);
+              const __m256i sum0b = _mm256_max_epi16(_mm256_min_epi16(in0[j * 2 + 1], One), Zero);
+              const __m256i sum1a = _mm256_max_epi16(_mm256_min_epi16(in1[j * 2 + 0], One), Zero);
+              const __m256i sum1b = _mm256_max_epi16(_mm256_min_epi16(in1[j * 2 + 1], One), Zero);
+
+              const __m256i pa = _mm256_srli_epi16(_mm256_mullo_epi16(sum0a, sum1a), 7);
+              const __m256i pb = _mm256_srli_epi16(_mm256_mullo_epi16(sum0b, sum1b), 7);
+
+              out[j] = _mm256_permute4x64_epi64(_mm256_packs_epi16(pa, pb), Control);
           }
-      }
-      return psqt;
 
-  #elif defined(USE_MMX)
+#elif defined(USE_SSE2)
 
-      constexpr IndexType NumChunks = HalfDimensions / SimdWidth;
-      const __m64 k0x80s = _mm_set1_pi8(-128);
+          constexpr IndexType OutputChunkSize = 128 / 8;
+          static_assert((HalfDimensions / 2) % OutputChunkSize == 0);
+          constexpr IndexType NumOutputChunks = HalfDimensions / 2 / OutputChunkSize;
 
-      for (IndexType p = 0; p < 2; ++p)
-      {
-          const IndexType offset = HalfDimensions * p;
-          auto out = reinterpret_cast<__m64*>(&output[offset]);
-          for (IndexType j = 0; j < NumChunks; ++j)
+          const __m128i Zero = _mm_setzero_si128();
+          const __m128i One = _mm_set1_epi16(127);
+
+          const __m128i* in0 = reinterpret_cast<const __m128i*>(&(accumulation[perspectives[p]][0]));
+          const __m128i* in1 = reinterpret_cast<const __m128i*>(&(accumulation[perspectives[p]][HalfDimensions / 2]));
+                __m128i* out = reinterpret_cast<      __m128i*>(output + offset);
+
+          for (IndexType j = 0; j < NumOutputChunks; j += 1)
           {
-              __m64 sum0 = *(&reinterpret_cast<const __m64*>(accumulation[perspectives[p]])[j * 2 + 0]);
-              __m64 sum1 = *(&reinterpret_cast<const __m64*>(accumulation[perspectives[p]])[j * 2 + 1]);
-              const __m64 packedbytes = _mm_packs_pi16(sum0, sum1);
-              out[j] = _mm_subs_pi8(_mm_adds_pi8(packedbytes, k0x80s), k0x80s);
+              const __m128i sum0a = _mm_max_epi16(_mm_min_epi16(in0[j * 2 + 0], One), Zero);
+              const __m128i sum0b = _mm_max_epi16(_mm_min_epi16(in0[j * 2 + 1], One), Zero);
+              const __m128i sum1a = _mm_max_epi16(_mm_min_epi16(in1[j * 2 + 0], One), Zero);
+              const __m128i sum1b = _mm_max_epi16(_mm_min_epi16(in1[j * 2 + 1], One), Zero);
+
+              const __m128i pa = _mm_srli_epi16(_mm_mullo_epi16(sum0a, sum1a), 7);
+              const __m128i pb = _mm_srli_epi16(_mm_mullo_epi16(sum0b, sum1b), 7);
+
+              out[j] = _mm_packs_epi16(pa, pb);
           }
-      }
-      _mm_empty();
-      return psqt;
 
-  #elif defined(USE_NEON)
+#elif defined(USE_NEON)
 
-      constexpr IndexType NumChunks = HalfDimensions / (SimdWidth / 2);
-      const int8x8_t Zero = {0};
+          constexpr IndexType OutputChunkSize = 128 / 8;
+          static_assert((HalfDimensions / 2) % OutputChunkSize == 0);
+          constexpr IndexType NumOutputChunks = HalfDimensions / 2 / OutputChunkSize;
 
-      for (IndexType p = 0; p < 2; ++p)
-      {
-          const IndexType offset = HalfDimensions * p;
-          const auto out = reinterpret_cast<int8x8_t*>(&output[offset]);
-          for (IndexType j = 0; j < NumChunks; ++j)
+          const int16x8_t Zero = vdupq_n_s16(0);
+          const int16x8_t One  = vdupq_n_s16(127);
+
+          const int16x8_t* in0 = reinterpret_cast<const int16x8_t*>(&(accumulation[perspectives[p]][0]));
+          const int16x8_t* in1 = reinterpret_cast<const int16x8_t*>(&(accumulation[perspectives[p]][HalfDimensions / 2]));
+                int8x16_t* out = reinterpret_cast<      int8x16_t*>(output + offset);
+
+          for (IndexType j = 0; j < NumOutputChunks; j += 1)
           {
-              int16x8_t sum = reinterpret_cast<const int16x8_t*>(accumulation[perspectives[p]])[j];
-              out[j] = vmax_s8(vqmovn_s16(sum), Zero);
+              const int16x8_t sum0a = vmaxq_s16(vminq_s16(in0[j * 2 + 0], One), Zero);
+              const int16x8_t sum0b = vmaxq_s16(vminq_s16(in0[j * 2 + 1], One), Zero);
+              const int16x8_t sum1a = vmaxq_s16(vminq_s16(in1[j * 2 + 0], One), Zero);
+              const int16x8_t sum1b = vmaxq_s16(vminq_s16(in1[j * 2 + 1], One), Zero);
+
+              const int8x8_t pa = vshrn_n_s16(vmulq_s16(sum0a, sum1a), 7);
+              const int8x8_t pb = vshrn_n_s16(vmulq_s16(sum0b, sum1b), 7);
+
+              out[j] = vcombine_s8(pa, pb);
           }
-      }
-      return psqt;
 
-  #else
+#else
 
-      for (IndexType p = 0; p < 2; ++p)
-      {
-          const IndexType offset = HalfDimensions * p;
-          for (IndexType j = 0; j < HalfDimensions; ++j)
-          {
-              BiasType sum = accumulation[perspectives[p]][j];
-              output[offset + j] = static_cast<OutputType>(std::max<int>(0, std::min<int>(127, sum)));
+          for (IndexType j = 0; j < HalfDimensions / 2; ++j) {
+              BiasType sum0 = accumulation[static_cast<int>(perspectives[p])][j + 0];
+              BiasType sum1 = accumulation[static_cast<int>(perspectives[p])][j + HalfDimensions / 2];
+              sum0 = std::max<int>(0, std::min<int>(127, sum0));
+              sum1 = std::max<int>(0, std::min<int>(127, sum1));
+              output[offset + j] = static_cast<OutputType>(sum0 * sum1 / 128);
           }
+
+#endif
       }
-      return psqt;
 
-  #endif
+      return psqt;
 
    } // end of function transform()
 
@@ -370,7 +367,6 @@ namespace Stockfish::Eval::NNUE {
       // That might depend on the feature set and generally relies on the
       // feature set's update cost calculation to be correct and never
       // allow updates with more added/removed features than MaxActiveDimensions.
-      using IndexList = ValueList<IndexType, FeatureSet::MaxActiveDimensions>;
 
   #ifdef VECTOR
       // Gcc-10.2 unnecessarily spills AVX2 registers if this array
@@ -404,12 +400,12 @@ namespace Stockfish::Eval::NNUE {
 
         // Gather all features to be updated.
         const Square ksq = pos.square<KING>(perspective);
-        IndexList removed[2], added[2];
+        FeatureSet::IndexList removed[2], added[2];
         FeatureSet::append_changed_indices(
-          ksq, next, perspective, removed[0], added[0]);
+          ksq, next->dirtyPiece, perspective, removed[0], added[0]);
         for (StateInfo *st2 = pos.state(); st2 != next; st2 = st2->previous)
           FeatureSet::append_changed_indices(
-            ksq, st2, perspective, removed[1], added[1]);
+            ksq, st2->dirtyPiece, perspective, removed[1], added[1]);
 
         // Mark the accumulators as computed.
         next->accumulator.computed[perspective] = true;
@@ -534,7 +530,7 @@ namespace Stockfish::Eval::NNUE {
         // Refresh the accumulator
         auto& accumulator = pos.state()->accumulator;
         accumulator.computed[perspective] = true;
-        IndexList active;
+        FeatureSet::IndexList active;
         FeatureSet::append_active_indices(pos, perspective, active);
 
   #ifdef VECTOR