/*
  Stockfish, a UCI chess playing engine derived from Glaurung 2.1
  Copyright (C) 2004-2020 The Stockfish developers (see AUTHORS file)

  Stockfish is free software: you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation, either version 3 of the License, or
  (at your option) any later version.

  Stockfish is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/
19 // A class template that represents the input feature set of the NNUE evaluation function
21 #ifndef NNUE_FEATURE_SET_H_INCLUDED
22 #define NNUE_FEATURE_SET_H_INCLUDED
#include "features_common.h"

#include <array>
#include <cassert>
#include <cstdint>
27 namespace Eval::NNUE::Features {
// Class template that represents a compile-time list of values of type T.
// Only the declaration of the primary template is given here; non-empty
// lists are implemented by the partial specialization below.
template <typename T, T... Values>
struct CompileTimeList;

// Non-empty list: First followed by zero or more Remaining values.
template <typename T, T First, T... Remaining>
struct CompileTimeList<T, First, Remaining...> {
  // Returns true if value compares equal to any element of the list.
  // Elements are tested in declaration order and evaluation stops at the
  // first match. A fold expression is used rather than recursing into
  // CompileTimeList<T, Remaining...>: that recursion would end at the empty
  // list CompileTimeList<T>, for which no definition is visible in this
  // file, so the call could not be instantiated.
  static constexpr bool Contains(T value) {
    return value == First || (... || (value == Remaining));
  }

  // The list elements, in declaration order.
  static constexpr std::array<T, sizeof...(Remaining) + 1>
      kValues = {{First, Remaining...}};
};
42 // Base class of feature set
43 template <typename Derived>
44 class FeatureSetBase {
47 // Get a list of indices for active features
48 template <typename IndexListType>
49 static void AppendActiveIndices(
50 const Position& pos, TriggerEvent trigger, IndexListType active[2]) {
52 for (Color perspective : { WHITE, BLACK }) {
53 Derived::CollectActiveIndices(
54 pos, trigger, perspective, &active[perspective]);
58 // Get a list of indices for recently changed features
59 template <typename PositionType, typename IndexListType>
60 static void AppendChangedIndices(
61 const PositionType& pos, TriggerEvent trigger,
62 IndexListType removed[2], IndexListType added[2], bool reset[2]) {
64 auto collect_for_one = [&](const DirtyPiece& dp) {
65 for (Color perspective : { WHITE, BLACK }) {
67 case TriggerEvent::kFriendKingMoved:
68 reset[perspective] = dp.piece[0] == make_piece(perspective, KING);
74 if (reset[perspective]) {
75 Derived::CollectActiveIndices(
76 pos, trigger, perspective, &added[perspective]);
78 Derived::CollectChangedIndices(
79 pos, dp, trigger, perspective,
80 &removed[perspective], &added[perspective]);
85 auto collect_for_two = [&](const DirtyPiece& dp1, const DirtyPiece& dp2) {
86 for (Color perspective : { WHITE, BLACK }) {
88 case TriggerEvent::kFriendKingMoved:
89 reset[perspective] = dp1.piece[0] == make_piece(perspective, KING)
90 || dp2.piece[0] == make_piece(perspective, KING);
96 if (reset[perspective]) {
97 Derived::CollectActiveIndices(
98 pos, trigger, perspective, &added[perspective]);
100 Derived::CollectChangedIndices(
101 pos, dp1, trigger, perspective,
102 &removed[perspective], &added[perspective]);
103 Derived::CollectChangedIndices(
104 pos, dp2, trigger, perspective,
105 &removed[perspective], &added[perspective]);
110 if (pos.state()->previous->accumulator.computed_accumulation) {
111 const auto& prev_dp = pos.state()->dirtyPiece;
112 if (prev_dp.dirty_num == 0) return;
113 collect_for_one(prev_dp);
115 const auto& prev_dp = pos.state()->previous->dirtyPiece;
116 if (prev_dp.dirty_num == 0) {
117 const auto& prev2_dp = pos.state()->dirtyPiece;
118 if (prev2_dp.dirty_num == 0) return;
119 collect_for_one(prev2_dp);
121 const auto& prev2_dp = pos.state()->dirtyPiece;
122 if (prev2_dp.dirty_num == 0) {
123 collect_for_one(prev_dp);
125 collect_for_two(prev_dp, prev2_dp);
132 // Class template that represents the feature set
133 template <typename FeatureType>
134 class FeatureSet<FeatureType> : public FeatureSetBase<FeatureSet<FeatureType>> {
137 // Hash value embedded in the evaluation file
138 static constexpr std::uint32_t kHashValue = FeatureType::kHashValue;
139 // Number of feature dimensions
140 static constexpr IndexType kDimensions = FeatureType::kDimensions;
141 // Maximum number of simultaneously active features
142 static constexpr IndexType kMaxActiveDimensions =
143 FeatureType::kMaxActiveDimensions;
144 // Trigger for full calculation instead of difference calculation
145 using SortedTriggerSet =
146 CompileTimeList<TriggerEvent, FeatureType::kRefreshTrigger>;
147 static constexpr auto kRefreshTriggers = SortedTriggerSet::kValues;
150 // Get a list of indices for active features
151 static void CollectActiveIndices(
152 const Position& pos, const TriggerEvent trigger, const Color perspective,
153 IndexList* const active) {
154 if (FeatureType::kRefreshTrigger == trigger) {
155 FeatureType::AppendActiveIndices(pos, perspective, active);
159 // Get a list of indices for recently changed features
160 static void CollectChangedIndices(
161 const Position& pos, const DirtyPiece& dp, const TriggerEvent trigger, const Color perspective,
162 IndexList* const removed, IndexList* const added) {
164 if (FeatureType::kRefreshTrigger == trigger) {
165 FeatureType::AppendChangedIndices(pos, dp, perspective, removed, added);
169 // Make the base class and the class template that recursively uses itself a friend
170 friend class FeatureSetBase<FeatureSet>;
171 template <typename... FeatureTypes>
172 friend class FeatureSet;
175 } // namespace Eval::NNUE::Features
177 #endif // #ifndef NNUE_FEATURE_SET_H_INCLUDED