]> git.sesse.net Git - stockfish/blobdiff - src/nnue/evaluate_nnue.cpp
Exporting the currently loaded network file
[stockfish] / src / nnue / evaluate_nnue.cpp
index 0e53961167140228e163cf55b98073269624c918..e0d4b9117c707676e2300e5424503e8f27e59114 100644 (file)
@@ -39,6 +39,7 @@ namespace Stockfish::Eval::NNUE {
 
   // Evaluation function file name
   std::string fileName;
+  std::string netDescription;
 
   namespace Detail {
 
@@ -68,6 +69,14 @@ namespace Stockfish::Eval::NNUE {
     return reference.read_parameters(stream);
   }
 
+  // Write evaluation function parameters
+  template <typename T>
+  bool write_parameters(std::ostream& stream, const T& reference) {
+
+    write_little_endian<std::uint32_t>(stream, T::get_hash_value());
+    return reference.write_parameters(stream);
+  }
+
   }  // namespace Detail
 
   // Initialize the evaluation function parameters
@@ -78,7 +87,7 @@ namespace Stockfish::Eval::NNUE {
   }
 
   // Read network header
-  bool read_header(std::istream& stream, std::uint32_t* hashValue, std::string* architecture)
+  bool read_header(std::istream& stream, std::uint32_t* hashValue, std::string* desc)
   {
     std::uint32_t version, size;
 
@@ -86,8 +95,18 @@ namespace Stockfish::Eval::NNUE {
     *hashValue = read_little_endian<std::uint32_t>(stream);
     size        = read_little_endian<std::uint32_t>(stream);
     if (!stream || version != Version) return false;
-    architecture->resize(size);
-    stream.read(&(*architecture)[0], size);
+    desc->resize(size);
+    stream.read(&(*desc)[0], size);
+    return !stream.fail();
+  }
+
+  // Write network header
+  bool write_header(std::ostream& stream, std::uint32_t hashValue, const std::string& desc)
+  {
+    write_little_endian<std::uint32_t>(stream, Version);
+    write_little_endian<std::uint32_t>(stream, hashValue);
+    write_little_endian<std::uint32_t>(stream, desc.size());
+    stream.write(&desc[0], desc.size());
     return !stream.fail();
   }
 
@@ -95,14 +114,22 @@ namespace Stockfish::Eval::NNUE {
   bool read_parameters(std::istream& stream) {
 
     std::uint32_t hashValue;
-    std::string architecture;
-    if (!read_header(stream, &hashValue, &architecture)) return false;
+    if (!read_header(stream, &hashValue, &netDescription)) return false;
     if (hashValue != HashValue) return false;
     if (!Detail::read_parameters(stream, *featureTransformer)) return false;
     if (!Detail::read_parameters(stream, *network)) return false;
     return stream && stream.peek() == std::ios::traits_type::eof();
   }
 
+  // Write network parameters
+  bool write_parameters(std::ostream& stream) {
+
+    if (!write_header(stream, HashValue, netDescription)) return false;
+    if (!Detail::write_parameters(stream, *featureTransformer)) return false;
+    if (!Detail::write_parameters(stream, *network)) return false;
+    return (bool)stream;
+  }
+
   // Evaluation function. Perform differential calculation.
   Value evaluate(const Position& pos) {
 
@@ -141,4 +168,13 @@ namespace Stockfish::Eval::NNUE {
     return read_parameters(stream);
   }
 
+  // Save eval, to a file stream or a memory stream
+  bool save_eval(std::ostream& stream) {
+
+    if (fileName.empty())
+      return false;
+
+    return write_parameters(stream);
+  }
+
 } // namespace Stockfish::Eval::NNUE