X-Git-Url: https://git.sesse.net/?p=stockfish;a=blobdiff_plain;f=src%2Fnnue%2Ffeatures%2Ffeature_set.h;h=26198114a3054e5c34aaa2f74a3ef19b2af40b5f;hp=79ca83aed0b973df26be9f780394a63296e13c91;hb=c065abdcafe0486cb5cfa7de12a4ac6a905a54c5;hpb=84f3e867903f62480c33243dd0ecbffd342796fc diff --git a/src/nnue/features/feature_set.h b/src/nnue/features/feature_set.h index 79ca83ae..26198114 100644 --- a/src/nnue/features/feature_set.h +++ b/src/nnue/features/feature_set.h @@ -61,27 +61,69 @@ namespace Eval::NNUE::Features { const PositionType& pos, TriggerEvent trigger, IndexListType removed[2], IndexListType added[2], bool reset[2]) { - const auto& dp = pos.state()->dirtyPiece; - if (dp.dirty_num == 0) return; - - for (Color perspective : { WHITE, BLACK }) { - reset[perspective] = false; - switch (trigger) { - case TriggerEvent::kFriendKingMoved: - reset[perspective] = - dp.pieceId[0] == PIECE_ID_KING + perspective; - break; - default: - assert(false); - break; + auto collect_for_one = [&](const DirtyPiece& dp) { + for (Color perspective : { WHITE, BLACK }) { + switch (trigger) { + case TriggerEvent::kFriendKingMoved: + reset[perspective] = dp.piece[0] == make_piece(perspective, KING); + break; + default: + assert(false); + break; + } + if (reset[perspective]) { + Derived::CollectActiveIndices( + pos, trigger, perspective, &added[perspective]); + } else { + Derived::CollectChangedIndices( + pos, dp, trigger, perspective, + &removed[perspective], &added[perspective]); + } } - if (reset[perspective]) { - Derived::CollectActiveIndices( - pos, trigger, perspective, &added[perspective]); + }; + + auto collect_for_two = [&](const DirtyPiece& dp1, const DirtyPiece& dp2) { + for (Color perspective : { WHITE, BLACK }) { + switch (trigger) { + case TriggerEvent::kFriendKingMoved: + reset[perspective] = dp1.piece[0] == make_piece(perspective, KING) + || dp2.piece[0] == make_piece(perspective, KING); + break; + default: + assert(false); + break; + } + if (reset[perspective]) { + Derived::CollectActiveIndices( + pos, trigger, perspective, &added[perspective]); + } else { + Derived::CollectChangedIndices( + pos, dp1, trigger, perspective, + &removed[perspective], &added[perspective]); + Derived::CollectChangedIndices( + pos, dp2, trigger, perspective, + &removed[perspective], &added[perspective]); + } + } + }; + + if (pos.state()->previous->accumulator.computed_accumulation) { + const auto& prev_dp = pos.state()->dirtyPiece; + if (prev_dp.dirty_num == 0) return; + collect_for_one(prev_dp); + } else { + const auto& prev_dp = pos.state()->previous->dirtyPiece; + if (prev_dp.dirty_num == 0) { + const auto& prev2_dp = pos.state()->dirtyPiece; + if (prev2_dp.dirty_num == 0) return; + collect_for_one(prev2_dp); } else { - Derived::CollectChangedIndices( - pos, trigger, perspective, - &removed[perspective], &added[perspective]); + const auto& prev2_dp = pos.state()->dirtyPiece; + if (prev2_dp.dirty_num == 0) { + collect_for_one(prev_dp); + } else { + collect_for_two(prev_dp, prev2_dp); + } } } } @@ -116,11 +158,11 @@ namespace Eval::NNUE::Features { // Get a list of indices for recently changed features static void CollectChangedIndices( - const Position& pos, const TriggerEvent trigger, const Color perspective, + const Position& pos, const DirtyPiece& dp, const TriggerEvent trigger, const Color perspective, IndexList* const removed, IndexList* const added) { if (FeatureType::kRefreshTrigger == trigger) { - FeatureType::AppendChangedIndices(pos, perspective, removed, added); + FeatureType::AppendChangedIndices(pos, dp, perspective, removed, added); } }