]> git.sesse.net Git - stockfish/blob - src/nnue/evaluate_nnue.cpp
d6ac9894cbbd4203303478c554a0cc5c1ce1ba78
[stockfish] / src / nnue / evaluate_nnue.cpp
1 /*
2   Stockfish, a UCI chess playing engine derived from Glaurung 2.1
3   Copyright (C) 2004-2020 The Stockfish developers (see AUTHORS file)
4
5   Stockfish is free software: you can redistribute it and/or modify
6   it under the terms of the GNU General Public License as published by
7   the Free Software Foundation, either version 3 of the License, or
8   (at your option) any later version.
9
10   Stockfish is distributed in the hope that it will be useful,
11   but WITHOUT ANY WARRANTY; without even the implied warranty of
12   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13   GNU General Public License for more details.
14
15   You should have received a copy of the GNU General Public License
16   along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 */
18
19 // Code for calculating NNUE evaluation function
20
21 #include <iostream>
22 #include <set>
23
24 #include "../evaluate.h"
25 #include "../position.h"
26 #include "../misc.h"
27 #include "../uci.h"
28
29 #include "evaluate_nnue.h"
30
31 namespace Eval::NNUE {
32
33   uint32_t kpp_board_index[PIECE_NB][COLOR_NB] = {
34    // convention: W - us, B - them
35    // viewed from other side, W and B are reversed
36       { PS_NONE,     PS_NONE     },
37       { PS_W_PAWN,   PS_B_PAWN   },
38       { PS_W_KNIGHT, PS_B_KNIGHT },
39       { PS_W_BISHOP, PS_B_BISHOP },
40       { PS_W_ROOK,   PS_B_ROOK   },
41       { PS_W_QUEEN,  PS_B_QUEEN  },
42       { PS_W_KING,   PS_B_KING   },
43       { PS_NONE,     PS_NONE     },
44       { PS_NONE,     PS_NONE     },
45       { PS_B_PAWN,   PS_W_PAWN   },
46       { PS_B_KNIGHT, PS_W_KNIGHT },
47       { PS_B_BISHOP, PS_W_BISHOP },
48       { PS_B_ROOK,   PS_W_ROOK   },
49       { PS_B_QUEEN,  PS_W_QUEEN  },
50       { PS_B_KING,   PS_W_KING   },
51       { PS_NONE,     PS_NONE     }
52   };
53
54   // Input feature converter
55   AlignedPtr<FeatureTransformer> feature_transformer;
56
57   // Evaluation function
58   AlignedPtr<Network> network;
59
60   // Evaluation function file name
61   std::string fileName;
62
63   namespace Detail {
64
65   // Initialize the evaluation function parameters
66   template <typename T>
67   void Initialize(AlignedPtr<T>& pointer) {
68
69     pointer.reset(reinterpret_cast<T*>(std_aligned_alloc(alignof(T), sizeof(T))));
70     std::memset(pointer.get(), 0, sizeof(T));
71   }
72
73   // Read evaluation function parameters
74   template <typename T>
75   bool ReadParameters(std::istream& stream, const AlignedPtr<T>& pointer) {
76
77     std::uint32_t header;
78     header = read_little_endian<std::uint32_t>(stream);
79     if (!stream || header != T::GetHashValue()) return false;
80     return pointer->ReadParameters(stream);
81   }
82
83   }  // namespace Detail
84
85   // Initialize the evaluation function parameters
86   void Initialize() {
87
88     Detail::Initialize(feature_transformer);
89     Detail::Initialize(network);
90   }
91
92   // Read network header
93   bool ReadHeader(std::istream& stream, std::uint32_t* hash_value, std::string* architecture)
94   {
95     std::uint32_t version, size;
96
97     version     = read_little_endian<std::uint32_t>(stream);
98     *hash_value = read_little_endian<std::uint32_t>(stream);
99     size        = read_little_endian<std::uint32_t>(stream);
100     if (!stream || version != kVersion) return false;
101     architecture->resize(size);
102     stream.read(&(*architecture)[0], size);
103     return !stream.fail();
104   }
105
106   // Read network parameters
107   bool ReadParameters(std::istream& stream) {
108
109     std::uint32_t hash_value;
110     std::string architecture;
111     if (!ReadHeader(stream, &hash_value, &architecture)) return false;
112     if (hash_value != kHashValue) return false;
113     if (!Detail::ReadParameters(stream, feature_transformer)) return false;
114     if (!Detail::ReadParameters(stream, network)) return false;
115     return stream && stream.peek() == std::ios::traits_type::eof();
116   }
117
118   // Proceed with the difference calculation if possible
119   static void UpdateAccumulatorIfPossible(const Position& pos) {
120
121     feature_transformer->UpdateAccumulatorIfPossible(pos);
122   }
123
124   // Calculate the evaluation value
125   static Value ComputeScore(const Position& pos, bool refresh) {
126
127     auto& accumulator = pos.state()->accumulator;
128     if (!refresh && accumulator.computed_score) {
129       return accumulator.score;
130     }
131
132     alignas(kCacheLineSize) TransformedFeatureType
133         transformed_features[FeatureTransformer::kBufferSize];
134     feature_transformer->Transform(pos, transformed_features, refresh);
135     alignas(kCacheLineSize) char buffer[Network::kBufferSize];
136     const auto output = network->Propagate(transformed_features, buffer);
137
138     auto score = static_cast<Value>(output[0] / FV_SCALE);
139
140     accumulator.score = score;
141     accumulator.computed_score = true;
142     return accumulator.score;
143   }
144
145   // Load eval, from a file stream or a memory stream
146   bool load_eval(std::string streamName, std::istream& stream) {
147
148     Initialize();
149     fileName = streamName;
150     return ReadParameters(stream);
151   }
152
153   // Evaluation function. Perform differential calculation.
154   Value evaluate(const Position& pos) {
155     return ComputeScore(pos, false);
156   }
157
158   // Evaluation function. Perform full calculation.
159   Value compute_eval(const Position& pos) {
160     return ComputeScore(pos, true);
161   }
162
163   // Proceed with the difference calculation if possible
164   void update_eval(const Position& pos) {
165     UpdateAccumulatorIfPossible(pos);
166   }
167
168 } // namespace Eval::NNUE