]> git.sesse.net Git - stockfish/blob - feature_set.h
79ca83aed0b973df26be9f780394a63296e13c91
[stockfish] / feature_set.h
1 /*
2   Stockfish, a UCI chess playing engine derived from Glaurung 2.1
3   Copyright (C) 2004-2020 The Stockfish developers (see AUTHORS file)
4
5   Stockfish is free software: you can redistribute it and/or modify
6   it under the terms of the GNU General Public License as published by
7   the Free Software Foundation, either version 3 of the License, or
8   (at your option) any later version.
9
10   Stockfish is distributed in the hope that it will be useful,
11   but WITHOUT ANY WARRANTY; without even the implied warranty of
12   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13   GNU General Public License for more details.
14
15   You should have received a copy of the GNU General Public License
16   along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 */
18
19 // A class template that represents the input feature set of the NNUE evaluation function
20
21 #ifndef NNUE_FEATURE_SET_H_INCLUDED
22 #define NNUE_FEATURE_SET_H_INCLUDED
23
24 #include "features_common.h"
25 #include <array>
26
27 namespace Eval::NNUE::Features {
28
29   // Class template that represents a list of values
30   template <typename T, T... Values>
31   struct CompileTimeList;
32
33   template <typename T, T First, T... Remaining>
34   struct CompileTimeList<T, First, Remaining...> {
35     static constexpr bool Contains(T value) {
36       return value == First || CompileTimeList<T, Remaining...>::Contains(value);
37     }
38     static constexpr std::array<T, sizeof...(Remaining) + 1>
39         kValues = {{First, Remaining...}};
40   };
41
42   // Base class of feature set
43   template <typename Derived>
44   class FeatureSetBase {
45
46    public:
47     // Get a list of indices for active features
48     template <typename IndexListType>
49     static void AppendActiveIndices(
50         const Position& pos, TriggerEvent trigger, IndexListType active[2]) {
51
52       for (Color perspective : { WHITE, BLACK }) {
53         Derived::CollectActiveIndices(
54             pos, trigger, perspective, &active[perspective]);
55       }
56     }
57
58     // Get a list of indices for recently changed features
59     template <typename PositionType, typename IndexListType>
60     static void AppendChangedIndices(
61         const PositionType& pos, TriggerEvent trigger,
62         IndexListType removed[2], IndexListType added[2], bool reset[2]) {
63
64       const auto& dp = pos.state()->dirtyPiece;
65       if (dp.dirty_num == 0) return;
66
67       for (Color perspective : { WHITE, BLACK }) {
68         reset[perspective] = false;
69         switch (trigger) {
70           case TriggerEvent::kFriendKingMoved:
71             reset[perspective] =
72                 dp.pieceId[0] == PIECE_ID_KING + perspective;
73             break;
74           default:
75             assert(false);
76             break;
77         }
78         if (reset[perspective]) {
79           Derived::CollectActiveIndices(
80               pos, trigger, perspective, &added[perspective]);
81         } else {
82           Derived::CollectChangedIndices(
83               pos, trigger, perspective,
84               &removed[perspective], &added[perspective]);
85         }
86       }
87     }
88   };
89
90   // Class template that represents the feature set
91   template <typename FeatureType>
92   class FeatureSet<FeatureType> : public FeatureSetBase<FeatureSet<FeatureType>> {
93
94    public:
95     // Hash value embedded in the evaluation file
96     static constexpr std::uint32_t kHashValue = FeatureType::kHashValue;
97     // Number of feature dimensions
98     static constexpr IndexType kDimensions = FeatureType::kDimensions;
99     // Maximum number of simultaneously active features
100     static constexpr IndexType kMaxActiveDimensions =
101         FeatureType::kMaxActiveDimensions;
102     // Trigger for full calculation instead of difference calculation
103     using SortedTriggerSet =
104         CompileTimeList<TriggerEvent, FeatureType::kRefreshTrigger>;
105     static constexpr auto kRefreshTriggers = SortedTriggerSet::kValues;
106
107    private:
108     // Get a list of indices for active features
109     static void CollectActiveIndices(
110         const Position& pos, const TriggerEvent trigger, const Color perspective,
111         IndexList* const active) {
112       if (FeatureType::kRefreshTrigger == trigger) {
113         FeatureType::AppendActiveIndices(pos, perspective, active);
114       }
115     }
116
117     // Get a list of indices for recently changed features
118     static void CollectChangedIndices(
119         const Position& pos, const TriggerEvent trigger, const Color perspective,
120         IndexList* const removed, IndexList* const added) {
121
122       if (FeatureType::kRefreshTrigger == trigger) {
123         FeatureType::AppendChangedIndices(pos, perspective, removed, added);
124       }
125     }
126
127     // Make the base class and the class template that recursively uses itself a friend
128     friend class FeatureSetBase<FeatureSet>;
129     template <typename... FeatureTypes>
130     friend class FeatureSet;
131   };
132
133 }  // namespace Eval::NNUE::Features
134
135 #endif // #ifndef NNUE_FEATURE_SET_H_INCLUDED