]> git.sesse.net Git - stockfish/blob - src/nnue/evaluate_nnue.cpp
b0ed7d2f5a4f78e0092603edf4756402c69de2ae
[stockfish] / src / nnue / evaluate_nnue.cpp
1 /*
2   Stockfish, a UCI chess playing engine derived from Glaurung 2.1
3   Copyright (C) 2004-2020 The Stockfish developers (see AUTHORS file)
4
5   Stockfish is free software: you can redistribute it and/or modify
6   it under the terms of the GNU General Public License as published by
7   the Free Software Foundation, either version 3 of the License, or
8   (at your option) any later version.
9
10   Stockfish is distributed in the hope that it will be useful,
11   but WITHOUT ANY WARRANTY; without even the implied warranty of
12   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13   GNU General Public License for more details.
14
15   You should have received a copy of the GNU General Public License
16   along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 */
18
19 // Code for calculating NNUE evaluation function
20
21 #include <iostream>
22 #include <set>
23
24 #include "../evaluate.h"
25 #include "../position.h"
26 #include "../misc.h"
27 #include "../uci.h"
28 #include "../types.h"
29
30 #include "evaluate_nnue.h"
31
32 namespace Eval::NNUE {
33
34   const uint32_t kpp_board_index[PIECE_NB][COLOR_NB] = {
35    // convention: W - us, B - them
36    // viewed from other side, W and B are reversed
37       { PS_NONE,     PS_NONE     },
38       { PS_W_PAWN,   PS_B_PAWN   },
39       { PS_W_KNIGHT, PS_B_KNIGHT },
40       { PS_W_BISHOP, PS_B_BISHOP },
41       { PS_W_ROOK,   PS_B_ROOK   },
42       { PS_W_QUEEN,  PS_B_QUEEN  },
43       { PS_W_KING,   PS_B_KING   },
44       { PS_NONE,     PS_NONE     },
45       { PS_NONE,     PS_NONE     },
46       { PS_B_PAWN,   PS_W_PAWN   },
47       { PS_B_KNIGHT, PS_W_KNIGHT },
48       { PS_B_BISHOP, PS_W_BISHOP },
49       { PS_B_ROOK,   PS_W_ROOK   },
50       { PS_B_QUEEN,  PS_W_QUEEN  },
51       { PS_B_KING,   PS_W_KING   },
52       { PS_NONE,     PS_NONE     }
53   };
54
55   // Input feature converter
56   LargePagePtr<FeatureTransformer> feature_transformer;
57
58   // Evaluation function
59   AlignedPtr<Network> network;
60
61   // Evaluation function file name
62   std::string fileName;
63
64   namespace Detail {
65
66   // Initialize the evaluation function parameters
67   template <typename T>
68   void Initialize(AlignedPtr<T>& pointer) {
69
70     pointer.reset(reinterpret_cast<T*>(std_aligned_alloc(alignof(T), sizeof(T))));
71     std::memset(pointer.get(), 0, sizeof(T));
72   }
73
74   template <typename T>
75   void Initialize(LargePagePtr<T>& pointer) {
76
77     static_assert(alignof(T) <= 4096, "aligned_large_pages_alloc() may fail for such a big alignment requirement of T");
78     pointer.reset(reinterpret_cast<T*>(aligned_large_pages_alloc(sizeof(T))));
79     std::memset(pointer.get(), 0, sizeof(T));
80   }
81
82   // Read evaluation function parameters
83   template <typename T>
84   bool ReadParameters(std::istream& stream, T& reference) {
85
86     std::uint32_t header;
87     header = read_little_endian<std::uint32_t>(stream);
88     if (!stream || header != T::GetHashValue()) return false;
89     return reference.ReadParameters(stream);
90   }
91
92   }  // namespace Detail
93
94   // Initialize the evaluation function parameters
95   void Initialize() {
96
97     Detail::Initialize(feature_transformer);
98     Detail::Initialize(network);
99   }
100
101   // Read network header
102   bool ReadHeader(std::istream& stream, std::uint32_t* hash_value, std::string* architecture)
103   {
104     std::uint32_t version, size;
105
106     version     = read_little_endian<std::uint32_t>(stream);
107     *hash_value = read_little_endian<std::uint32_t>(stream);
108     size        = read_little_endian<std::uint32_t>(stream);
109     if (!stream || version != kVersion) return false;
110     architecture->resize(size);
111     stream.read(&(*architecture)[0], size);
112     return !stream.fail();
113   }
114
115   // Read network parameters
116   bool ReadParameters(std::istream& stream) {
117
118     std::uint32_t hash_value;
119     std::string architecture;
120     if (!ReadHeader(stream, &hash_value, &architecture)) return false;
121     if (hash_value != kHashValue) return false;
122     if (!Detail::ReadParameters(stream, *feature_transformer)) return false;
123     if (!Detail::ReadParameters(stream, *network)) return false;
124     return stream && stream.peek() == std::ios::traits_type::eof();
125   }
126
127   // Evaluation function. Perform differential calculation.
128   Value evaluate(const Position& pos) {
129
130     // We manually align the arrays on the stack because with gcc < 9.3
131     // overaligning stack variables with alignas() doesn't work correctly.
132
133     constexpr uint64_t alignment = kCacheLineSize;
134
135 #if defined(ALIGNAS_ON_STACK_VARIABLES_BROKEN)
136     TransformedFeatureType transformed_features_unaligned[
137       FeatureTransformer::kBufferSize + alignment / sizeof(TransformedFeatureType)];
138     char buffer_unaligned[Network::kBufferSize + alignment];
139
140     auto* transformed_features = align_ptr_up<alignment>(&transformed_features_unaligned[0]);
141     auto* buffer = align_ptr_up<alignment>(&buffer_unaligned[0]);
142 #else
143     alignas(alignment)
144       TransformedFeatureType transformed_features[FeatureTransformer::kBufferSize];
145     alignas(alignment) char buffer[Network::kBufferSize];
146 #endif
147
148     ASSERT_ALIGNED(transformed_features, alignment);
149     ASSERT_ALIGNED(buffer, alignment);
150
151     feature_transformer->Transform(pos, transformed_features);
152     const auto output = network->Propagate(transformed_features, buffer);
153
154     return static_cast<Value>(output[0] / FV_SCALE);
155   }
156
157   // Load eval, from a file stream or a memory stream
158   bool load_eval(std::string name, std::istream& stream) {
159
160     Initialize();
161     fileName = name;
162     return ReadParameters(stream);
163   }
164
165 } // namespace Eval::NNUE