X-Git-Url: https://git.sesse.net/?p=stockfish;a=blobdiff_plain;f=src%2Fnnue%2Ffeatures%2Ffeature_set.h;fp=src%2Fnnue%2Ffeatures%2Ffeature_set.h;h=79ca83aed0b973df26be9f780394a63296e13c91;hp=0000000000000000000000000000000000000000;hb=84f3e867903f62480c33243dd0ecbffd342796fc;hpb=9587eeeb5ed29f834d4f956b92e0e732877c47a7
diff --git a/src/nnue/features/feature_set.h b/src/nnue/features/feature_set.h
new file mode 100644
index 00000000..79ca83ae
--- /dev/null
+++ b/src/nnue/features/feature_set.h
@@ -0,0 +1,135 @@
+/*
+ Stockfish, a UCI chess playing engine derived from Glaurung 2.1
+ Copyright (C) 2004-2020 The Stockfish developers (see AUTHORS file)
+
+ Stockfish is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ Stockfish is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see .
+*/
+
+// A class template that represents the input feature set of the NNUE evaluation function
+
+#ifndef NNUE_FEATURE_SET_H_INCLUDED
+#define NNUE_FEATURE_SET_H_INCLUDED
+
+#include "features_common.h"
+#include
+
+namespace Eval::NNUE::Features {
+
+ // Class template that represents a list of values
+ template
+ struct CompileTimeList;
+
+ template
+ struct CompileTimeList {
+ static constexpr bool Contains(T value) {
+ return value == First || CompileTimeList::Contains(value);
+ }
+ static constexpr std::array
+ kValues = {{First, Remaining...}};
+ };
+
+ // Base class of feature set
+ template
+ class FeatureSetBase {
+
+ public:
+ // Get a list of indices for active features
+ template
+ static void AppendActiveIndices(
+ const Position& pos, TriggerEvent trigger, IndexListType active[2]) {
+
+ for (Color perspective : { WHITE, BLACK }) {
+ Derived::CollectActiveIndices(
+ pos, trigger, perspective, &active[perspective]);
+ }
+ }
+
+ // Get a list of indices for recently changed features
+ template
+ static void AppendChangedIndices(
+ const PositionType& pos, TriggerEvent trigger,
+ IndexListType removed[2], IndexListType added[2], bool reset[2]) {
+
+ const auto& dp = pos.state()->dirtyPiece;
+ if (dp.dirty_num == 0) return;
+
+ for (Color perspective : { WHITE, BLACK }) {
+ reset[perspective] = false;
+ switch (trigger) {
+ case TriggerEvent::kFriendKingMoved:
+ reset[perspective] =
+ dp.pieceId[0] == PIECE_ID_KING + perspective;
+ break;
+ default:
+ assert(false);
+ break;
+ }
+ if (reset[perspective]) {
+ Derived::CollectActiveIndices(
+ pos, trigger, perspective, &added[perspective]);
+ } else {
+ Derived::CollectChangedIndices(
+ pos, trigger, perspective,
+ &removed[perspective], &added[perspective]);
+ }
+ }
+ }
+ };
+
+ // Class template that represents the feature set
+ template
+ class FeatureSet : public FeatureSetBase> {
+
+ public:
+ // Hash value embedded in the evaluation file
+ static constexpr std::uint32_t kHashValue = FeatureType::kHashValue;
+ // Number of feature dimensions
+ static constexpr IndexType kDimensions = FeatureType::kDimensions;
+ // Maximum number of simultaneously active features
+ static constexpr IndexType kMaxActiveDimensions =
+ FeatureType::kMaxActiveDimensions;
+ // Trigger for full calculation instead of difference calculation
+ using SortedTriggerSet =
+ CompileTimeList;
+ static constexpr auto kRefreshTriggers = SortedTriggerSet::kValues;
+
+ private:
+ // Get a list of indices for active features
+ static void CollectActiveIndices(
+ const Position& pos, const TriggerEvent trigger, const Color perspective,
+ IndexList* const active) {
+ if (FeatureType::kRefreshTrigger == trigger) {
+ FeatureType::AppendActiveIndices(pos, perspective, active);
+ }
+ }
+
+ // Get a list of indices for recently changed features
+ static void CollectChangedIndices(
+ const Position& pos, const TriggerEvent trigger, const Color perspective,
+ IndexList* const removed, IndexList* const added) {
+
+ if (FeatureType::kRefreshTrigger == trigger) {
+ FeatureType::AppendChangedIndices(pos, perspective, removed, added);
+ }
+ }
+
+ // Make the base class and the class template that recursively uses itself a friend
+ friend class FeatureSetBase;
+ template
+ friend class FeatureSet;
+ };
+
+} // namespace Eval::NNUE::Features
+
+#endif // #ifndef NNUE_FEATURE_SET_H_INCLUDED