]> git.sesse.net Git - stockfish/blobdiff - src/nnue/nnue_feature_transformer.h
New NNUE architecture and net
[stockfish] / src / nnue / nnue_feature_transformer.h
index a4a8e98f9c5e8f579cea140b77126f9763184421..2c0a0c6d3134b61270a9a16a3a1ae199176d9d07 100644 (file)
@@ -35,45 +35,82 @@ namespace Stockfish::Eval::NNUE {
   // vector registers.
   #define VECTOR
 
+  static_assert(PSQTBuckets == 8, "Assumed by the current choice of constants.");
+
   #ifdef USE_AVX512
   typedef __m512i vec_t;
+  typedef __m256i psqt_vec_t;
   #define vec_load(a) _mm512_load_si512(a)
   #define vec_store(a,b) _mm512_store_si512(a,b)
   #define vec_add_16(a,b) _mm512_add_epi16(a,b)
   #define vec_sub_16(a,b) _mm512_sub_epi16(a,b)
+  #define vec_load_psqt(a) _mm256_load_si256(a)
+  #define vec_store_psqt(a,b) _mm256_store_si256(a,b)
+  #define vec_add_psqt_32(a,b) _mm256_add_epi32(a,b)
+  #define vec_sub_psqt_32(a,b) _mm256_sub_epi32(a,b)
+  #define vec_zero_psqt() _mm256_setzero_si256()
   static constexpr IndexType NumRegs = 8; // only 8 are needed
+  static constexpr IndexType NumPsqtRegs = 1;
 
   #elif USE_AVX2
   typedef __m256i vec_t;
+  typedef __m256i psqt_vec_t;
   #define vec_load(a) _mm256_load_si256(a)
   #define vec_store(a,b) _mm256_store_si256(a,b)
   #define vec_add_16(a,b) _mm256_add_epi16(a,b)
   #define vec_sub_16(a,b) _mm256_sub_epi16(a,b)
+  #define vec_load_psqt(a) _mm256_load_si256(a)
+  #define vec_store_psqt(a,b) _mm256_store_si256(a,b)
+  #define vec_add_psqt_32(a,b) _mm256_add_epi32(a,b)
+  #define vec_sub_psqt_32(a,b) _mm256_sub_epi32(a,b)
+  #define vec_zero_psqt() _mm256_setzero_si256()
   static constexpr IndexType NumRegs = 16;
+  static constexpr IndexType NumPsqtRegs = 1;
 
   #elif USE_SSE2
   typedef __m128i vec_t;
+  typedef __m128i psqt_vec_t;
   #define vec_load(a) (*(a))
   #define vec_store(a,b) *(a)=(b)
   #define vec_add_16(a,b) _mm_add_epi16(a,b)
   #define vec_sub_16(a,b) _mm_sub_epi16(a,b)
+  #define vec_load_psqt(a) (*(a))
+  #define vec_store_psqt(a,b) *(a)=(b)
+  #define vec_add_psqt_32(a,b) _mm_add_epi32(a,b)
+  #define vec_sub_psqt_32(a,b) _mm_sub_epi32(a,b)
+  #define vec_zero_psqt() _mm_setzero_si128()
   static constexpr IndexType NumRegs = Is64Bit ? 16 : 8;
+  static constexpr IndexType NumPsqtRegs = 2;
 
   #elif USE_MMX
   typedef __m64 vec_t;
+  typedef std::int32_t psqt_vec_t;
   #define vec_load(a) (*(a))
   #define vec_store(a,b) *(a)=(b)
   #define vec_add_16(a,b) _mm_add_pi16(a,b)
   #define vec_sub_16(a,b) _mm_sub_pi16(a,b)
+  #define vec_load_psqt(a) (*(a))
+  #define vec_store_psqt(a,b) *(a)=(b)
+  #define vec_add_psqt_32(a,b) a+b
+  #define vec_sub_psqt_32(a,b) a-b
+  #define vec_zero_psqt() 0
   static constexpr IndexType NumRegs = 8;
+  static constexpr IndexType NumPsqtRegs = 8;
 
   #elif USE_NEON
   typedef int16x8_t vec_t;
+  typedef int32x4_t psqt_vec_t;
   #define vec_load(a) (*(a))
   #define vec_store(a,b) *(a)=(b)
   #define vec_add_16(a,b) vaddq_s16(a,b)
   #define vec_sub_16(a,b) vsubq_s16(a,b)
+  #define vec_load_psqt(a) (*(a))
+  #define vec_store_psqt(a,b) *(a)=(b)
+  #define vec_add_psqt_32(a,b) vaddq_s32(a,b)
+  #define vec_sub_psqt_32(a,b) vsubq_s32(a,b)
+  #define vec_zero_psqt() psqt_vec_t{0}
   static constexpr IndexType NumRegs = 16;
+  static constexpr IndexType NumPsqtRegs = 2;
 
   #else
   #undef VECTOR
@@ -87,9 +124,13 @@ namespace Stockfish::Eval::NNUE {
     // Number of output dimensions for one side
     static constexpr IndexType HalfDimensions = TransformedFeatureDimensions;
 
+    static constexpr int LazyThreshold = 1400;
+
     #ifdef VECTOR
     static constexpr IndexType TileHeight = NumRegs * sizeof(vec_t) / 2;
+    static constexpr IndexType PsqtTileHeight = NumPsqtRegs * sizeof(psqt_vec_t) / 4;
     static_assert(HalfDimensions % TileHeight == 0, "TileHeight must divide HalfDimensions");
+    static_assert(PSQTBuckets % PsqtTileHeight == 0, "PsqtTileHeight must divide PSQTBuckets");
     #endif
 
    public:
@@ -115,6 +156,8 @@ namespace Stockfish::Eval::NNUE {
         biases[i] = read_little_endian<BiasType>(stream);
       for (std::size_t i = 0; i < HalfDimensions * InputDimensions; ++i)
         weights[i] = read_little_endian<WeightType>(stream);
+      for (std::size_t i = 0; i < PSQTBuckets * InputDimensions; ++i)
+        psqtWeights[i] = read_little_endian<PSQTWeightType>(stream);
       return !stream.fail();
     }
 
@@ -128,11 +171,21 @@ namespace Stockfish::Eval::NNUE {
     }
 
     // Convert input features
-    void transform(const Position& pos, OutputType* output) const {
+    std::pair<std::int32_t, bool> transform(const Position& pos, OutputType* output, int bucket) const {
       update_accumulator(pos, WHITE);
       update_accumulator(pos, BLACK);
 
+      const Color perspectives[2] = {pos.side_to_move(), ~pos.side_to_move()};
       const auto& accumulation = pos.state()->accumulator.accumulation;
+      const auto& psqtAccumulation = pos.state()->accumulator.psqtAccumulation;
+
+      const auto psqt = (
+            psqtAccumulation[static_cast<int>(perspectives[0])][bucket]
+          - psqtAccumulation[static_cast<int>(perspectives[1])][bucket]
+        ) / 2;
+
+      if (abs(psqt) > LazyThreshold * OutputScale)
+        return { psqt, true };
 
   #if defined(USE_AVX512)
       constexpr IndexType NumChunks = HalfDimensions / (SimdWidth * 2);
@@ -163,7 +216,6 @@ namespace Stockfish::Eval::NNUE {
       const int8x8_t Zero = {0};
   #endif
 
-      const Color perspectives[2] = {pos.side_to_move(), ~pos.side_to_move()};
       for (IndexType p = 0; p < 2; ++p) {
         const IndexType offset = HalfDimensions * p;
 
@@ -240,6 +292,8 @@ namespace Stockfish::Eval::NNUE {
   #if defined(USE_MMX)
       _mm_empty();
   #endif
+
+      return { psqt, false };
     }
 
    private:
@@ -255,6 +309,7 @@ namespace Stockfish::Eval::NNUE {
       // Gcc-10.2 unnecessarily spills AVX2 registers if this array
       // is defined in the VECTOR code below, once in each branch
       vec_t acc[NumRegs];
+      psqt_vec_t psqt[NumPsqtRegs];
   #endif
 
       // Look for a usable accumulator of an earlier position. We keep track
@@ -333,12 +388,52 @@ namespace Stockfish::Eval::NNUE {
           }
         }
 
+        for (IndexType j = 0; j < PSQTBuckets / PsqtTileHeight; ++j)
+        {
+          // Load accumulator
+          auto accTilePsqt = reinterpret_cast<psqt_vec_t*>(
+            &st->accumulator.psqtAccumulation[perspective][j * PsqtTileHeight]);
+          for (std::size_t k = 0; k < NumPsqtRegs; ++k)
+            psqt[k] = vec_load_psqt(&accTilePsqt[k]);
+
+          for (IndexType i = 0; states_to_update[i]; ++i)
+          {
+            // Difference calculation for the deactivated features
+            for (const auto index : removed[i])
+            {
+              const IndexType offset = PSQTBuckets * index + j * PsqtTileHeight;
+              auto columnPsqt = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offset]);
+              for (std::size_t k = 0; k < NumPsqtRegs; ++k)
+                psqt[k] = vec_sub_psqt_32(psqt[k], columnPsqt[k]);
+            }
+
+            // Difference calculation for the activated features
+            for (const auto index : added[i])
+            {
+              const IndexType offset = PSQTBuckets * index + j * PsqtTileHeight;
+              auto columnPsqt = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offset]);
+              for (std::size_t k = 0; k < NumPsqtRegs; ++k)
+                psqt[k] = vec_add_psqt_32(psqt[k], columnPsqt[k]);
+            }
+
+            // Store accumulator
+            accTilePsqt = reinterpret_cast<psqt_vec_t*>(
+              &states_to_update[i]->accumulator.psqtAccumulation[perspective][j * PsqtTileHeight]);
+            for (std::size_t k = 0; k < NumPsqtRegs; ++k)
+              vec_store_psqt(&accTilePsqt[k], psqt[k]);
+          }
+        }
+
   #else
         for (IndexType i = 0; states_to_update[i]; ++i)
         {
           std::memcpy(states_to_update[i]->accumulator.accumulation[perspective],
               st->accumulator.accumulation[perspective],
               HalfDimensions * sizeof(BiasType));
+
+          for (std::size_t k = 0; k < PSQTBuckets; ++k)
+            states_to_update[i]->accumulator.psqtAccumulation[perspective][k] = st->accumulator.psqtAccumulation[perspective][k];
+
           st = states_to_update[i];
 
           // Difference calculation for the deactivated features
@@ -348,6 +443,9 @@ namespace Stockfish::Eval::NNUE {
 
             for (IndexType j = 0; j < HalfDimensions; ++j)
               st->accumulator.accumulation[perspective][j] -= weights[offset + j];
+
+            for (std::size_t k = 0; k < PSQTBuckets; ++k)
+              st->accumulator.psqtAccumulation[perspective][k] -= psqtWeights[index * PSQTBuckets + k];
           }
 
           // Difference calculation for the activated features
@@ -357,6 +455,9 @@ namespace Stockfish::Eval::NNUE {
 
             for (IndexType j = 0; j < HalfDimensions; ++j)
               st->accumulator.accumulation[perspective][j] += weights[offset + j];
+
+            for (std::size_t k = 0; k < PSQTBuckets; ++k)
+              st->accumulator.psqtAccumulation[perspective][k] += psqtWeights[index * PSQTBuckets + k];
           }
         }
   #endif
@@ -392,16 +493,42 @@ namespace Stockfish::Eval::NNUE {
             vec_store(&accTile[k], acc[k]);
         }
 
+        for (IndexType j = 0; j < PSQTBuckets / PsqtTileHeight; ++j)
+        {
+          for (std::size_t k = 0; k < NumPsqtRegs; ++k)
+            psqt[k] = vec_zero_psqt();
+
+          for (const auto index : active)
+          {
+            const IndexType offset = PSQTBuckets * index + j * PsqtTileHeight;
+            auto columnPsqt = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offset]);
+
+            for (std::size_t k = 0; k < NumPsqtRegs; ++k)
+              psqt[k] = vec_add_psqt_32(psqt[k], columnPsqt[k]);
+          }
+
+          auto accTilePsqt = reinterpret_cast<psqt_vec_t*>(
+            &accumulator.psqtAccumulation[perspective][j * PsqtTileHeight]);
+          for (std::size_t k = 0; k < NumPsqtRegs; ++k)
+            vec_store_psqt(&accTilePsqt[k], psqt[k]);
+        }
+
   #else
         std::memcpy(accumulator.accumulation[perspective], biases,
             HalfDimensions * sizeof(BiasType));
 
+        for (std::size_t k = 0; k < PSQTBuckets; ++k)
+          accumulator.psqtAccumulation[perspective][k] = 0;
+
         for (const auto index : active)
         {
           const IndexType offset = HalfDimensions * index;
 
           for (IndexType j = 0; j < HalfDimensions; ++j)
             accumulator.accumulation[perspective][j] += weights[offset + j];
+
+          for (std::size_t k = 0; k < PSQTBuckets; ++k)
+            accumulator.psqtAccumulation[perspective][k] += psqtWeights[index * PSQTBuckets + k];
         }
   #endif
       }
@@ -413,9 +540,11 @@ namespace Stockfish::Eval::NNUE {
 
     using BiasType = std::int16_t;
     using WeightType = std::int16_t;
+    using PSQTWeightType = std::int32_t;
 
     alignas(CacheLineSize) BiasType biases[HalfDimensions];
     alignas(CacheLineSize) WeightType weights[HalfDimensions * InputDimensions];
+    alignas(CacheLineSize) PSQTWeightType psqtWeights[InputDimensions * PSQTBuckets];
   };
 
 }  // namespace Stockfish::Eval::NNUE