]> git.sesse.net Git - stockfish/blobdiff - src/nnue/evaluate_nnue.cpp
New NNUE architecture and net
[stockfish] / src / nnue / evaluate_nnue.cpp
index e0d4b9117c707676e2300e5424503e8f27e59114..97cef81480fa8033525c2ce05eaf5eb495a675f7 100644 (file)
@@ -35,7 +35,7 @@ namespace Stockfish::Eval::NNUE {
   LargePagePtr<FeatureTransformer> featureTransformer;
 
   // Evaluation function
-  AlignedPtr<Network> network;
+  AlignedPtr<Network> network[LayerStacks];
 
   // Evaluation function file name
   std::string fileName;
@@ -83,7 +83,8 @@ namespace Stockfish::Eval::NNUE {
   void initialize() {
 
     Detail::initialize(featureTransformer);
-    Detail::initialize(network);
+    for (std::size_t i = 0; i < LayerStacks; ++i)
+      Detail::initialize(network[i]);
   }
 
   // Read network header
@@ -92,7 +93,7 @@ namespace Stockfish::Eval::NNUE {
     std::uint32_t version, size;
 
     version     = read_little_endian<std::uint32_t>(stream);
-    *hashValue = read_little_endian<std::uint32_t>(stream);
+    *hashValue  = read_little_endian<std::uint32_t>(stream);
     size        = read_little_endian<std::uint32_t>(stream);
     if (!stream || version != Version) return false;
     desc->resize(size);
@@ -117,7 +118,8 @@ namespace Stockfish::Eval::NNUE {
     if (!read_header(stream, &hashValue, &netDescription)) return false;
     if (hashValue != HashValue) return false;
     if (!Detail::read_parameters(stream, *featureTransformer)) return false;
-    if (!Detail::read_parameters(stream, *network)) return false;
+    for (std::size_t i = 0; i < LayerStacks; ++i)
+      if (!Detail::read_parameters(stream, *(network[i]))) return false;
     return stream && stream.peek() == std::ios::traits_type::eof();
   }
 
@@ -126,7 +128,8 @@ namespace Stockfish::Eval::NNUE {
 
     if (!write_header(stream, HashValue, netDescription)) return false;
     if (!Detail::write_parameters(stream, *featureTransformer)) return false;
-    if (!Detail::write_parameters(stream, *network)) return false;
+    for (std::size_t i = 0; i < LayerStacks; ++i)
+      if (!Detail::write_parameters(stream, *(network[i]))) return false;
     return (bool)stream;
   }
 
@@ -154,10 +157,15 @@ namespace Stockfish::Eval::NNUE {
     ASSERT_ALIGNED(transformedFeatures, alignment);
     ASSERT_ALIGNED(buffer, alignment);
 
-    featureTransformer->transform(pos, transformedFeatures);
-    const auto output = network->propagate(transformedFeatures, buffer);
+    const std::size_t bucket = (pos.count<ALL_PIECES>() - 1) / 4;
 
-    return static_cast<Value>(output[0] / OutputScale);
+    const auto [psqt, lazy] = featureTransformer->transform(pos, transformedFeatures, bucket);
+    if (lazy) {
+      return static_cast<Value>(psqt / OutputScale);
+    } else {
+      const auto output = network[bucket]->propagate(transformedFeatures, buffer);
+      return static_cast<Value>((output[0] + psqt) / OutputScale);
+    }
   }
 
   // Load eval, from a file stream or a memory stream