]> git.sesse.net Git - stockfish/blobdiff - src/nnue/nnue_feature_transformer.h
Cleanup and simplify NNUE code.
[stockfish] / src / nnue / nnue_feature_transformer.h
index de4b49374f91768dce1d04f1dc2dd2615b69bf3f..f441274915329f3605831414841920a65093f0db 100644 (file)
@@ -23,7 +23,8 @@
 
 #include "nnue_common.h"
 #include "nnue_architecture.h"
-#include "features/index_list.h"
+
+#include "../misc.h"
 
 #include <cstring> // std::memset()
 
@@ -96,7 +97,7 @@ namespace Stockfish::Eval::NNUE {
     using OutputType = TransformedFeatureType;
 
     // Number of input/output dimensions
-    static constexpr IndexType InputDimensions = RawFeatures::Dimensions;
+    static constexpr IndexType InputDimensions = FeatureSet::Dimensions;
     static constexpr IndexType OutputDimensions = HalfDimensions * 2;
 
     // Size of forward propagation buffer
@@ -105,7 +106,7 @@ namespace Stockfish::Eval::NNUE {
 
     // Hash value embedded in the evaluation file
     static constexpr std::uint32_t get_hash_value() {
-      return RawFeatures::HashValue ^ OutputDimensions;
+      return FeatureSet::HashValue ^ OutputDimensions;
     }
 
     // Read network parameters
@@ -161,9 +162,9 @@ namespace Stockfish::Eval::NNUE {
         auto out = reinterpret_cast<__m512i*>(&output[offset]);
         for (IndexType j = 0; j < NumChunks; ++j) {
           __m512i sum0 = _mm512_load_si512(
-              &reinterpret_cast<const __m512i*>(accumulation[perspectives[p]][0])[j * 2 + 0]);
+              &reinterpret_cast<const __m512i*>(accumulation[perspectives[p]])[j * 2 + 0]);
           __m512i sum1 = _mm512_load_si512(
-              &reinterpret_cast<const __m512i*>(accumulation[perspectives[p]][0])[j * 2 + 1]);
+              &reinterpret_cast<const __m512i*>(accumulation[perspectives[p]])[j * 2 + 1]);
           _mm512_store_si512(&out[j], _mm512_permutexvar_epi64(Control,
               _mm512_max_epi8(_mm512_packs_epi16(sum0, sum1), Zero)));
         }
@@ -172,9 +173,9 @@ namespace Stockfish::Eval::NNUE {
         auto out = reinterpret_cast<__m256i*>(&output[offset]);
         for (IndexType j = 0; j < NumChunks; ++j) {
           __m256i sum0 = _mm256_load_si256(
-              &reinterpret_cast<const __m256i*>(accumulation[perspectives[p]][0])[j * 2 + 0]);
+              &reinterpret_cast<const __m256i*>(accumulation[perspectives[p]])[j * 2 + 0]);
           __m256i sum1 = _mm256_load_si256(
-              &reinterpret_cast<const __m256i*>(accumulation[perspectives[p]][0])[j * 2 + 1]);
+              &reinterpret_cast<const __m256i*>(accumulation[perspectives[p]])[j * 2 + 1]);
           _mm256_store_si256(&out[j], _mm256_permute4x64_epi64(_mm256_max_epi8(
               _mm256_packs_epi16(sum0, sum1), Zero), Control));
         }
@@ -183,9 +184,9 @@ namespace Stockfish::Eval::NNUE {
         auto out = reinterpret_cast<__m128i*>(&output[offset]);
         for (IndexType j = 0; j < NumChunks; ++j) {
           __m128i sum0 = _mm_load_si128(&reinterpret_cast<const __m128i*>(
-              accumulation[perspectives[p]][0])[j * 2 + 0]);
+              accumulation[perspectives[p]])[j * 2 + 0]);
           __m128i sum1 = _mm_load_si128(&reinterpret_cast<const __m128i*>(
-              accumulation[perspectives[p]][0])[j * 2 + 1]);
+              accumulation[perspectives[p]])[j * 2 + 1]);
       const __m128i packedbytes = _mm_packs_epi16(sum0, sum1);
 
           _mm_store_si128(&out[j],
@@ -203,9 +204,9 @@ namespace Stockfish::Eval::NNUE {
         auto out = reinterpret_cast<__m64*>(&output[offset]);
         for (IndexType j = 0; j < NumChunks; ++j) {
           __m64 sum0 = *(&reinterpret_cast<const __m64*>(
-              accumulation[perspectives[p]][0])[j * 2 + 0]);
+              accumulation[perspectives[p]])[j * 2 + 0]);
           __m64 sum1 = *(&reinterpret_cast<const __m64*>(
-              accumulation[perspectives[p]][0])[j * 2 + 1]);
+              accumulation[perspectives[p]])[j * 2 + 1]);
           const __m64 packedbytes = _mm_packs_pi16(sum0, sum1);
           out[j] = _mm_subs_pi8(_mm_adds_pi8(packedbytes, k0x80s), k0x80s);
         }
@@ -214,13 +215,13 @@ namespace Stockfish::Eval::NNUE {
         const auto out = reinterpret_cast<int8x8_t*>(&output[offset]);
         for (IndexType j = 0; j < NumChunks; ++j) {
           int16x8_t sum = reinterpret_cast<const int16x8_t*>(
-              accumulation[perspectives[p]][0])[j];
+              accumulation[perspectives[p]])[j];
           out[j] = vmax_s8(vqmovn_s16(sum), Zero);
         }
 
   #else
         for (IndexType j = 0; j < HalfDimensions; ++j) {
-          BiasType sum = accumulation[static_cast<int>(perspectives[p])][0][j];
+          BiasType sum = accumulation[static_cast<int>(perspectives[p])][j];
           output[offset + j] = static_cast<OutputType>(
               std::max<int>(0, std::min<int>(127, sum)));
         }
@@ -233,7 +234,13 @@ namespace Stockfish::Eval::NNUE {
     }
 
    private:
-    void update_accumulator(const Position& pos, const Color c) const {
+    void update_accumulator(const Position& pos, const Color perspective) const {
+
+      // The size must be enough to contain the largest possible update.
+      // That might depend on the feature set and generally relies on the
+      // feature set's update cost calculation to be correct and never
+      // allow updates with more added/removed features than MaxActiveDimensions.
+      using IndexList = ValueList<IndexType, FeatureSet::MaxActiveDimensions>;
 
   #ifdef VECTOR
       // Gcc-10.2 unnecessarily spills AVX2 registers if this array
@@ -244,23 +251,19 @@ namespace Stockfish::Eval::NNUE {
       // Look for a usable accumulator of an earlier position. We keep track
       // of the estimated gain in terms of features to be added/subtracted.
       StateInfo *st = pos.state(), *next = nullptr;
-      int gain = pos.count<ALL_PIECES>() - 2;
-      while (st->accumulator.state[c] == EMPTY)
+      int gain = FeatureSet::refresh_cost(pos);
+      while (st->accumulator.state[perspective] == EMPTY)
       {
-        auto& dp = st->dirtyPiece;
-        // The first condition tests whether an incremental update is
-        // possible at all: if this side's king has moved, it is not possible.
-        static_assert(std::is_same_v<RawFeatures::SortedTriggerSet,
-              Features::CompileTimeList<Features::TriggerEvent, Features::TriggerEvent::FriendKingMoved>>,
-              "Current code assumes that only FriendlyKingMoved refresh trigger is being used.");
-        if (   dp.piece[0] == make_piece(c, KING)
-            || (gain -= dp.dirty_num + 1) < 0)
+        // This governs when a full feature refresh is needed and how many
+        // updates are better than just one full refresh.
+        if (   FeatureSet::requires_refresh(st, perspective)
+            || (gain -= FeatureSet::update_cost(st) + 1) < 0)
           break;
         next = st;
         st = st->previous;
       }
 
-      if (st->accumulator.state[c] == COMPUTED)
+      if (st->accumulator.state[perspective] == COMPUTED)
       {
         if (next == nullptr)
           return;
@@ -268,34 +271,32 @@ namespace Stockfish::Eval::NNUE {
         // Update incrementally in two steps. First, we update the "next"
         // accumulator. Then, we update the current accumulator (pos.state()).
 
-        // Gather all features to be updated. This code assumes HalfKP features
-        // only and doesn't support refresh triggers.
-        static_assert(std::is_same_v<Features::FeatureSet<Features::HalfKP<Features::Side::Friend>>,
-                                     RawFeatures>);
-        Features::IndexList removed[2], added[2];
-        Features::HalfKP<Features::Side::Friend>::append_changed_indices(pos,
-            next->dirtyPiece, c, &removed[0], &added[0]);
+        // Gather all features to be updated.
+        const Square ksq = pos.square<KING>(perspective);
+        IndexList removed[2], added[2];
+        FeatureSet::append_changed_indices(
+          ksq, next, perspective, removed[0], added[0]);
         for (StateInfo *st2 = pos.state(); st2 != next; st2 = st2->previous)
-          Features::HalfKP<Features::Side::Friend>::append_changed_indices(pos,
-              st2->dirtyPiece, c, &removed[1], &added[1]);
+          FeatureSet::append_changed_indices(
+            ksq, st2, perspective, removed[1], added[1]);
 
         // Mark the accumulators as computed.
-        next->accumulator.state[c] = COMPUTED;
-        pos.state()->accumulator.state[c] = COMPUTED;
+        next->accumulator.state[perspective] = COMPUTED;
+        pos.state()->accumulator.state[perspective] = COMPUTED;
 
-        // Now update the accumulators listed in info[], where the last element is a sentinel.
-        StateInfo *info[3] =
+        // Now update the accumulators listed in states_to_update[], where the last element is a sentinel.
+        StateInfo *states_to_update[3] =
           { next, next == pos.state() ? nullptr : pos.state(), nullptr };
   #ifdef VECTOR
         for (IndexType j = 0; j < HalfDimensions / TileHeight; ++j)
         {
           // Load accumulator
           auto accTile = reinterpret_cast<vec_t*>(
-            &st->accumulator.accumulation[c][0][j * TileHeight]);
+            &st->accumulator.accumulation[perspective][j * TileHeight]);
           for (IndexType k = 0; k < NumRegs; ++k)
             acc[k] = vec_load(&accTile[k]);
 
-          for (IndexType i = 0; info[i]; ++i)
+          for (IndexType i = 0; states_to_update[i]; ++i)
           {
             // Difference calculation for the deactivated features
             for (const auto index : removed[i])
@@ -317,19 +318,19 @@ namespace Stockfish::Eval::NNUE {
 
             // Store accumulator
             accTile = reinterpret_cast<vec_t*>(
-              &info[i]->accumulator.accumulation[c][0][j * TileHeight]);
+              &states_to_update[i]->accumulator.accumulation[perspective][j * TileHeight]);
             for (IndexType k = 0; k < NumRegs; ++k)
               vec_store(&accTile[k], acc[k]);
           }
         }
 
   #else
-        for (IndexType i = 0; info[i]; ++i)
+        for (IndexType i = 0; states_to_update[i]; ++i)
         {
-          std::memcpy(info[i]->accumulator.accumulation[c][0],
-              st->accumulator.accumulation[c][0],
+          std::memcpy(states_to_update[i]->accumulator.accumulation[perspective],
+              st->accumulator.accumulation[perspective],
               HalfDimensions * sizeof(BiasType));
-          st = info[i];
+          st = states_to_update[i];
 
           // Difference calculation for the deactivated features
           for (const auto index : removed[i])
@@ -337,7 +338,7 @@ namespace Stockfish::Eval::NNUE {
             const IndexType offset = HalfDimensions * index;
 
             for (IndexType j = 0; j < HalfDimensions; ++j)
-              st->accumulator.accumulation[c][0][j] -= weights[offset + j];
+              st->accumulator.accumulation[perspective][j] -= weights[offset + j];
           }
 
           // Difference calculation for the activated features
@@ -346,7 +347,7 @@ namespace Stockfish::Eval::NNUE {
             const IndexType offset = HalfDimensions * index;
 
             for (IndexType j = 0; j < HalfDimensions; ++j)
-              st->accumulator.accumulation[c][0][j] += weights[offset + j];
+              st->accumulator.accumulation[perspective][j] += weights[offset + j];
           }
         }
   #endif
@@ -355,9 +356,9 @@ namespace Stockfish::Eval::NNUE {
       {
         // Refresh the accumulator
         auto& accumulator = pos.state()->accumulator;
-        accumulator.state[c] = COMPUTED;
-        Features::IndexList active;
-        Features::HalfKP<Features::Side::Friend>::append_active_indices(pos, c, &active);
+        accumulator.state[perspective] = COMPUTED;
+        IndexList active;
+        FeatureSet::append_active_indices(pos, perspective, active);
 
   #ifdef VECTOR
         for (IndexType j = 0; j < HalfDimensions / TileHeight; ++j)
@@ -377,13 +378,13 @@ namespace Stockfish::Eval::NNUE {
           }
 
           auto accTile = reinterpret_cast<vec_t*>(
-              &accumulator.accumulation[c][0][j * TileHeight]);
+              &accumulator.accumulation[perspective][j * TileHeight]);
           for (unsigned k = 0; k < NumRegs; k++)
             vec_store(&accTile[k], acc[k]);
         }
 
   #else
-        std::memcpy(accumulator.accumulation[c][0], biases,
+        std::memcpy(accumulator.accumulation[perspective], biases,
             HalfDimensions * sizeof(BiasType));
 
         for (const auto index : active)
@@ -391,7 +392,7 @@ namespace Stockfish::Eval::NNUE {
           const IndexType offset = HalfDimensions * index;
 
           for (IndexType j = 0; j < HalfDimensions; ++j)
-            accumulator.accumulation[c][0][j] += weights[offset + j];
+            accumulator.accumulation[perspective][j] += weights[offset + j];
         }
   #endif
       }
@@ -405,8 +406,7 @@ namespace Stockfish::Eval::NNUE {
     using WeightType = std::int16_t;
 
     alignas(CacheLineSize) BiasType biases[HalfDimensions];
-    alignas(CacheLineSize)
-        WeightType weights[HalfDimensions * InputDimensions];
+    alignas(CacheLineSize) WeightType weights[HalfDimensions * InputDimensions];
   };
 
 }  // namespace Stockfish::Eval::NNUE