]> git.sesse.net Git - stockfish/blobdiff - src/nnue/features/feature_set.h
Use incremental updates more often
[stockfish] / src / nnue / features / feature_set.h
index 558a6b228779e38607793b94623ecc0c4516caba..26198114a3054e5c34aaa2f74a3ef19b2af40b5f 100644 (file)
@@ -61,26 +61,69 @@ namespace Eval::NNUE::Features {
         const PositionType& pos, TriggerEvent trigger,
         IndexListType removed[2], IndexListType added[2], bool reset[2]) {
 
-      const auto& dp = pos.state()->dirtyPiece;
-      if (dp.dirty_num == 0) return;
-
-      for (Color perspective : { WHITE, BLACK }) {
-        reset[perspective] = false;
-        switch (trigger) {
-          case TriggerEvent::kFriendKingMoved:
-            reset[perspective] = dp.piece[0] == make_piece(perspective, KING);
-            break;
-          default:
-            assert(false);
-            break;
+      auto collect_for_one = [&](const DirtyPiece& dp) {
+        for (Color perspective : { WHITE, BLACK }) {
+          switch (trigger) {
+            case TriggerEvent::kFriendKingMoved:
+              reset[perspective] = dp.piece[0] == make_piece(perspective, KING);
+              break;
+            default:
+              assert(false);
+              break;
+          }
+          if (reset[perspective]) {
+            Derived::CollectActiveIndices(
+                pos, trigger, perspective, &added[perspective]);
+          } else {
+            Derived::CollectChangedIndices(
+                pos, dp, trigger, perspective,
+                &removed[perspective], &added[perspective]);
+          }
         }
-        if (reset[perspective]) {
-          Derived::CollectActiveIndices(
-              pos, trigger, perspective, &added[perspective]);
+      };
+
+      auto collect_for_two = [&](const DirtyPiece& dp1, const DirtyPiece& dp2) {
+        for (Color perspective : { WHITE, BLACK }) {
+          switch (trigger) {
+            case TriggerEvent::kFriendKingMoved:
+              reset[perspective] = dp1.piece[0] == make_piece(perspective, KING)
+                                || dp2.piece[0] == make_piece(perspective, KING);
+              break;
+            default:
+              assert(false);
+              break;
+          }
+          if (reset[perspective]) {
+            Derived::CollectActiveIndices(
+                pos, trigger, perspective, &added[perspective]);
+          } else {
+            Derived::CollectChangedIndices(
+                pos, dp1, trigger, perspective,
+                &removed[perspective], &added[perspective]);
+            Derived::CollectChangedIndices(
+                pos, dp2, trigger, perspective,
+                &removed[perspective], &added[perspective]);
+          }
+        }
+      };
+
+      if (pos.state()->previous->accumulator.computed_accumulation) {
+        const auto& prev_dp = pos.state()->dirtyPiece;
+        if (prev_dp.dirty_num == 0) return;
+        collect_for_one(prev_dp);
+      } else {
+        const auto& prev_dp = pos.state()->previous->dirtyPiece;
+        if (prev_dp.dirty_num == 0) {
+          const auto& prev2_dp = pos.state()->dirtyPiece;
+          if (prev2_dp.dirty_num == 0) return;
+          collect_for_one(prev2_dp);
         } else {
-          Derived::CollectChangedIndices(
-              pos, trigger, perspective,
-              &removed[perspective], &added[perspective]);
+          const auto& prev2_dp = pos.state()->dirtyPiece;
+          if (prev2_dp.dirty_num == 0) {
+            collect_for_one(prev_dp);
+          } else {
+            collect_for_two(prev_dp, prev2_dp);
+          }
         }
       }
     }
@@ -115,11 +158,11 @@ namespace Eval::NNUE::Features {
 
     // Get a list of indices for recently changed features
     static void CollectChangedIndices(
-        const Position& pos, const TriggerEvent trigger, const Color perspective,
+        const Position& pos, const DirtyPiece& dp, const TriggerEvent trigger, const Color perspective,
         IndexList* const removed, IndexList* const added) {
 
       if (FeatureType::kRefreshTrigger == trigger) {
-        FeatureType::AppendChangedIndices(pos, perspective, removed, added);
+        FeatureType::AppendChangedIndices(pos, dp, perspective, removed, added);
       }
     }