]> git.sesse.net Git - stockfish/blobdiff - src/nnue/evaluate_nnue.cpp
Do not use lazy evaluation inside NNUE
[stockfish] / src / nnue / evaluate_nnue.cpp
index 0e53961167140228e163cf55b98073269624c918..4a3c206b8087fcb065ad0b9c4e595c68495701a8 100644 (file)
@@ -35,10 +35,11 @@ namespace Stockfish::Eval::NNUE {
   LargePagePtr<FeatureTransformer> featureTransformer;
 
   // Evaluation function
-  AlignedPtr<Network> network;
+  AlignedPtr<Network> network[LayerStacks];
 
   // Evaluation function file name
   std::string fileName;
+  std::string netDescription;
 
   namespace Detail {
 
@@ -68,26 +69,45 @@ namespace Stockfish::Eval::NNUE {
     return reference.read_parameters(stream);
   }
 
+  // Write evaluation function parameters
+  template <typename T>
+  bool write_parameters(std::ostream& stream, const T& reference) {
+
+    write_little_endian<std::uint32_t>(stream, T::get_hash_value());
+    return reference.write_parameters(stream);
+  }
+
   }  // namespace Detail
 
   // Initialize the evaluation function parameters
   void initialize() {
 
     Detail::initialize(featureTransformer);
-    Detail::initialize(network);
+    for (std::size_t i = 0; i < LayerStacks; ++i)
+      Detail::initialize(network[i]);
   }
 
   // Read network header
-  bool read_header(std::istream& stream, std::uint32_t* hashValue, std::string* architecture)
+  bool read_header(std::istream& stream, std::uint32_t* hashValue, std::string* desc)
   {
     std::uint32_t version, size;
 
     version     = read_little_endian<std::uint32_t>(stream);
-    *hashValue = read_little_endian<std::uint32_t>(stream);
+    *hashValue  = read_little_endian<std::uint32_t>(stream);
     size        = read_little_endian<std::uint32_t>(stream);
     if (!stream || version != Version) return false;
-    architecture->resize(size);
-    stream.read(&(*architecture)[0], size);
+    desc->resize(size);
+    stream.read(&(*desc)[0], size);
+    return !stream.fail();
+  }
+
+  // Write network header
+  bool write_header(std::ostream& stream, std::uint32_t hashValue, const std::string& desc)
+  {
+    write_little_endian<std::uint32_t>(stream, Version);
+    write_little_endian<std::uint32_t>(stream, hashValue);
+    write_little_endian<std::uint32_t>(stream, desc.size());
+    stream.write(&desc[0], desc.size());
     return !stream.fail();
   }
 
@@ -95,16 +115,26 @@ namespace Stockfish::Eval::NNUE {
   bool read_parameters(std::istream& stream) {
 
     std::uint32_t hashValue;
-    std::string architecture;
-    if (!read_header(stream, &hashValue, &architecture)) return false;
+    if (!read_header(stream, &hashValue, &netDescription)) return false;
     if (hashValue != HashValue) return false;
     if (!Detail::read_parameters(stream, *featureTransformer)) return false;
-    if (!Detail::read_parameters(stream, *network)) return false;
+    for (std::size_t i = 0; i < LayerStacks; ++i)
+      if (!Detail::read_parameters(stream, *(network[i]))) return false;
     return stream && stream.peek() == std::ios::traits_type::eof();
   }
 
+  // Write network parameters
+  bool write_parameters(std::ostream& stream) {
+
+    if (!write_header(stream, HashValue, netDescription)) return false;
+    if (!Detail::write_parameters(stream, *featureTransformer)) return false;
+    for (std::size_t i = 0; i < LayerStacks; ++i)
+      if (!Detail::write_parameters(stream, *(network[i]))) return false;
+    return (bool)stream;
+  }
+
   // Evaluation function. Perform differential calculation.
-  Value evaluate(const Position& pos) {
+  Value evaluate(const Position& pos, bool adjusted) {
 
     // We manually align the arrays on the stack because with gcc < 9.3
     // overaligning stack variables with alignas() doesn't work correctly.
@@ -127,10 +157,22 @@ namespace Stockfish::Eval::NNUE {
     ASSERT_ALIGNED(transformedFeatures, alignment);
     ASSERT_ALIGNED(buffer, alignment);
 
-    featureTransformer->transform(pos, transformedFeatures);
-    const auto output = network->propagate(transformedFeatures, buffer);
+    const std::size_t bucket = (pos.count<ALL_PIECES>() - 1) / 4;
+    const auto psqt = featureTransformer->transform(pos, transformedFeatures, bucket);
+    const auto output = network[bucket]->propagate(transformedFeatures, buffer);
+
+    int materialist = psqt;
+    int positional  = output[0];
+
+    int delta_npm = abs(pos.non_pawn_material(WHITE) - pos.non_pawn_material(BLACK));
+    int entertainment = (adjusted && delta_npm <= BishopValueMg - KnightValueMg ? 7 : 0);
+
+    int A = 128 - entertainment;
+    int B = 128 + entertainment;
+
+    int sum = (A * materialist + B * positional) / 128;
 
-    return static_cast<Value>(output[0] / OutputScale);
+    return static_cast<Value>( sum / OutputScale );
   }
 
   // Load eval, from a file stream or a memory stream
@@ -141,4 +183,13 @@ namespace Stockfish::Eval::NNUE {
     return read_parameters(stream);
   }
 
+  // Save eval, to a file stream or a memory stream
+  bool save_eval(std::ostream& stream) {
+
+    if (fileName.empty())
+      return false;
+
+    return write_parameters(stream);
+  }
+
 } // namespace Stockfish::Eval::NNUE