]> git.sesse.net Git - stockfish/blob - src/nnue/features/feature_set.h
26198114a3054e5c34aaa2f74a3ef19b2af40b5f
[stockfish] / src / nnue / features / feature_set.h
1 /*
2   Stockfish, a UCI chess playing engine derived from Glaurung 2.1
3   Copyright (C) 2004-2020 The Stockfish developers (see AUTHORS file)
4
5   Stockfish is free software: you can redistribute it and/or modify
6   it under the terms of the GNU General Public License as published by
7   the Free Software Foundation, either version 3 of the License, or
8   (at your option) any later version.
9
10   Stockfish is distributed in the hope that it will be useful,
11   but WITHOUT ANY WARRANTY; without even the implied warranty of
12   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13   GNU General Public License for more details.
14
15   You should have received a copy of the GNU General Public License
16   along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 */
18
19 // A class template that represents the input feature set of the NNUE evaluation function
20
21 #ifndef NNUE_FEATURE_SET_H_INCLUDED
22 #define NNUE_FEATURE_SET_H_INCLUDED
23
24 #include "features_common.h"
25 #include <array>
26
27 namespace Eval::NNUE::Features {
28
// Compile-time list of values of type T.
//
// Only the non-empty partial specialization below is defined. The primary
// template is declared but never defined, so a CompileTimeList with no
// values is an incomplete type.
template <typename T, T... Values>
struct CompileTimeList;

// Non-empty list: First is the head element and Remaining... is the
// (possibly empty) tail.
template <typename T, T First, T... Remaining>
struct CompileTimeList<T, First, Remaining...> {
  // Returns true if value compares equal to any element of the list.
  //
  // A fold expression is used instead of recursing into
  // CompileTimeList<T, Remaining...>: the recursion would end up naming the
  // undefined empty list, which makes every call ill-formed. An empty fold
  // over || evaluates to false, which handles the one-element list.
  static constexpr bool Contains(T value) {
    return value == First || (... || (value == Remaining));
  }

  // The list elements, in the order they were given.
  static constexpr std::array<T, sizeof...(Remaining) + 1>
      kValues = {{First, Remaining...}};
};
41
42   // Base class of feature set
43   template <typename Derived>
44   class FeatureSetBase {
45
46    public:
47     // Get a list of indices for active features
48     template <typename IndexListType>
49     static void AppendActiveIndices(
50         const Position& pos, TriggerEvent trigger, IndexListType active[2]) {
51
52       for (Color perspective : { WHITE, BLACK }) {
53         Derived::CollectActiveIndices(
54             pos, trigger, perspective, &active[perspective]);
55       }
56     }
57
58     // Get a list of indices for recently changed features
59     template <typename PositionType, typename IndexListType>
60     static void AppendChangedIndices(
61         const PositionType& pos, TriggerEvent trigger,
62         IndexListType removed[2], IndexListType added[2], bool reset[2]) {
63
64       auto collect_for_one = [&](const DirtyPiece& dp) {
65         for (Color perspective : { WHITE, BLACK }) {
66           switch (trigger) {
67             case TriggerEvent::kFriendKingMoved:
68               reset[perspective] = dp.piece[0] == make_piece(perspective, KING);
69               break;
70             default:
71               assert(false);
72               break;
73           }
74           if (reset[perspective]) {
75             Derived::CollectActiveIndices(
76                 pos, trigger, perspective, &added[perspective]);
77           } else {
78             Derived::CollectChangedIndices(
79                 pos, dp, trigger, perspective,
80                 &removed[perspective], &added[perspective]);
81           }
82         }
83       };
84
85       auto collect_for_two = [&](const DirtyPiece& dp1, const DirtyPiece& dp2) {
86         for (Color perspective : { WHITE, BLACK }) {
87           switch (trigger) {
88             case TriggerEvent::kFriendKingMoved:
89               reset[perspective] = dp1.piece[0] == make_piece(perspective, KING)
90                                 || dp2.piece[0] == make_piece(perspective, KING);
91               break;
92             default:
93               assert(false);
94               break;
95           }
96           if (reset[perspective]) {
97             Derived::CollectActiveIndices(
98                 pos, trigger, perspective, &added[perspective]);
99           } else {
100             Derived::CollectChangedIndices(
101                 pos, dp1, trigger, perspective,
102                 &removed[perspective], &added[perspective]);
103             Derived::CollectChangedIndices(
104                 pos, dp2, trigger, perspective,
105                 &removed[perspective], &added[perspective]);
106           }
107         }
108       };
109
110       if (pos.state()->previous->accumulator.computed_accumulation) {
111         const auto& prev_dp = pos.state()->dirtyPiece;
112         if (prev_dp.dirty_num == 0) return;
113         collect_for_one(prev_dp);
114       } else {
115         const auto& prev_dp = pos.state()->previous->dirtyPiece;
116         if (prev_dp.dirty_num == 0) {
117           const auto& prev2_dp = pos.state()->dirtyPiece;
118           if (prev2_dp.dirty_num == 0) return;
119           collect_for_one(prev2_dp);
120         } else {
121           const auto& prev2_dp = pos.state()->dirtyPiece;
122           if (prev2_dp.dirty_num == 0) {
123             collect_for_one(prev_dp);
124           } else {
125             collect_for_two(prev_dp, prev2_dp);
126           }
127         }
128       }
129     }
130   };
131
132   // Class template that represents the feature set
133   template <typename FeatureType>
134   class FeatureSet<FeatureType> : public FeatureSetBase<FeatureSet<FeatureType>> {
135
136    public:
137     // Hash value embedded in the evaluation file
138     static constexpr std::uint32_t kHashValue = FeatureType::kHashValue;
139     // Number of feature dimensions
140     static constexpr IndexType kDimensions = FeatureType::kDimensions;
141     // Maximum number of simultaneously active features
142     static constexpr IndexType kMaxActiveDimensions =
143         FeatureType::kMaxActiveDimensions;
144     // Trigger for full calculation instead of difference calculation
145     using SortedTriggerSet =
146         CompileTimeList<TriggerEvent, FeatureType::kRefreshTrigger>;
147     static constexpr auto kRefreshTriggers = SortedTriggerSet::kValues;
148
149    private:
150     // Get a list of indices for active features
151     static void CollectActiveIndices(
152         const Position& pos, const TriggerEvent trigger, const Color perspective,
153         IndexList* const active) {
154       if (FeatureType::kRefreshTrigger == trigger) {
155         FeatureType::AppendActiveIndices(pos, perspective, active);
156       }
157     }
158
159     // Get a list of indices for recently changed features
160     static void CollectChangedIndices(
161         const Position& pos, const DirtyPiece& dp, const TriggerEvent trigger, const Color perspective,
162         IndexList* const removed, IndexList* const added) {
163
164       if (FeatureType::kRefreshTrigger == trigger) {
165         FeatureType::AppendChangedIndices(pos, dp, perspective, removed, added);
166       }
167     }
168
169     // Make the base class and the class template that recursively uses itself a friend
170     friend class FeatureSetBase<FeatureSet>;
171     template <typename... FeatureTypes>
172     friend class FeatureSet;
173   };
174
175 }  // namespace Eval::NNUE::Features
176
177 #endif // #ifndef NNUE_FEATURE_SET_H_INCLUDED