]> git.sesse.net Git - stockfish/commitdiff
Exporting the currently loaded network file
authorTomasz Sobczyk <tomasz.sobczyk1997@gmail.com>
Fri, 7 May 2021 10:24:12 +0000 (12:24 +0200)
committerJoost VandeVondele <Joost.VandeVondele@gmail.com>
Tue, 11 May 2021 17:36:11 +0000 (19:36 +0200)
This PR adds an ability to export any currently loaded network.
The export_net command now takes an optional filename parameter.
If the loaded net is not the embedded net the filename parameter is required.

Two changes were required to support this:

* the "architecture" string, which is really just a some kind of description in the net, is now saved into netDescription on load and correctly saved on export.
* the AffineTransform scrambles weights for some architectures and sparsifies them, such that retrieving the index is hard. This is solved by having a temporary scrambled<->unscrambled index lookup table when loading the network, and the actual index is saved for each individual weight that makes it to canSaturate16. This increases the size of the canSaturate16 entries by 6 bytes.

closes https://github.com/official-stockfish/Stockfish/pull/3456

No functional change

README.md
src/evaluate.cpp
src/evaluate.h
src/nnue/evaluate_nnue.cpp
src/nnue/layers/affine_transform.h
src/nnue/layers/clipped_relu.h
src/nnue/layers/input_slice.h
src/nnue/nnue_common.h
src/nnue/nnue_feature_transformer.h
src/uci.cpp

index 013d4b32caafaeaaab44be436f7b30d65fcea2cd..8d5ce8d0edac51f7c24b66290f2cb973a61afa9e 100644 (file)
--- a/README.md
+++ b/README.md
@@ -24,13 +24,13 @@ This distribution of Stockfish consists of the following files:
   * Readme.md, the file you are currently reading.
 
   * Copying.txt, a text file containing the GNU General Public License version 3.
   * Readme.md, the file you are currently reading.
 
   * Copying.txt, a text file containing the GNU General Public License version 3.
-  
+
   * AUTHORS, a text file with the list of authors for the project
 
   * src, a subdirectory containing the full source code, including a Makefile
     that can be used to compile Stockfish on Unix-like systems.
 
   * AUTHORS, a text file with the list of authors for the project
 
   * src, a subdirectory containing the full source code, including a Makefile
     that can be used to compile Stockfish on Unix-like systems.
 
-  * a file with the .nnue extension, storing the neural network for the NNUE 
+  * a file with the .nnue extension, storing the neural network for the NNUE
     evaluation. Binary distributions will have this file embedded.
 
 ## The UCI protocol and available options
     evaluation. Binary distributions will have this file embedded.
 
 ## The UCI protocol and available options
@@ -156,8 +156,14 @@ For developers the following non-standard commands might be of interest, mainly
   * #### eval
     Return the evaluation of the current position.
 
   * #### eval
     Return the evaluation of the current position.
 
-  * #### export_net
-    If the binary contains an embedded net, save it in a file (named according to the default value of EvalFile).
+  * #### export_net [filename]
+    Exports the currently loaded network to a file.
+    If the currently loaded network is the embedded network and the filename
+    is not specified then the network is saved to the file matching the name
+    of the embedded network, as defined in evaluate.h.
+    If the currently loaded network is not the embedded network (some net set
+    through the UCI setoption) then the filename parameter is required and the
+    network is saved into that file.
 
   * #### flip
     Flips the side to move.
 
   * #### flip
     Flips the side to move.
@@ -189,7 +195,7 @@ Stockfish binary, but the default value of the EvalFile UCI option is the name o
 that is guaranteed to be compatible with that binary.
 
 2) to use the NNUE evaluation, the additional data file with neural network parameters
 that is guaranteed to be compatible with that binary.
 
 2) to use the NNUE evaluation, the additional data file with neural network parameters
-needs to be available. Normally, this file is already embedded in the binary or it 
+needs to be available. Normally, this file is already embedded in the binary or it
 can be downloaded. The filename for the default (recommended) net can be found as the default
 value of the `EvalFile` UCI option, with the format `nn-[SHA256 first 12 digits].nnue`
 (for instance, `nn-c157e0a5755b.nnue`). This file can be downloaded from
 can be downloaded. The filename for the default (recommended) net can be found as the default
 value of the `EvalFile` UCI option, with the format `nn-[SHA256 first 12 digits].nnue`
 (for instance, `nn-c157e0a5755b.nnue`). This file can be downloaded from
@@ -202,7 +208,7 @@ replacing `[filename]` as needed.
 
 If the engine is searching a position that is not in the tablebases (e.g.
 a position with 8 pieces), it will access the tablebases during the search.
 
 If the engine is searching a position that is not in the tablebases (e.g.
 a position with 8 pieces), it will access the tablebases during the search.
-If the engine reports a very large score (typically 153.xx), this means 
+If the engine reports a very large score (typically 153.xx), this means
 it has found a winning line into a tablebase position.
 
 If the engine is given a position to search that is in the tablebases, it
 it has found a winning line into a tablebase position.
 
 If the engine is given a position to search that is in the tablebases, it
index f0784e8f915e113f414ba85e2496c9e7eab66156..c396e0f757569bbf1674c3fb44bd2b303b1c40d0 100644 (file)
@@ -47,9 +47,7 @@
 // Note that this does not work in Microsoft Visual Studio.
 #if !defined(_MSC_VER) && !defined(NNUE_EMBEDDING_OFF)
   INCBIN(EmbeddedNNUE, EvalFileDefaultName);
 // Note that this does not work in Microsoft Visual Studio.
 #if !defined(_MSC_VER) && !defined(NNUE_EMBEDDING_OFF)
   INCBIN(EmbeddedNNUE, EvalFileDefaultName);
-  constexpr bool             gHasEmbeddedNet = true;
 #else
 #else
-  constexpr bool             gHasEmbeddedNet = false;
   const unsigned char        gEmbeddedNNUEData[1] = {0x0};
   const unsigned char *const gEmbeddedNNUEEnd = &gEmbeddedNNUEData[1];
   const unsigned int         gEmbeddedNNUESize = 1;
   const unsigned char        gEmbeddedNNUEData[1] = {0x0};
   const unsigned char *const gEmbeddedNNUEEnd = &gEmbeddedNNUEData[1];
   const unsigned int         gEmbeddedNNUESize = 1;
@@ -116,12 +114,23 @@ namespace Eval {
         }
   }
 
         }
   }
 
-  void NNUE::export_net() {
-    if constexpr (gHasEmbeddedNet) {
-      ofstream stream(EvalFileDefaultName, std::ios_base::binary);
-      stream.write(reinterpret_cast<const char*>(gEmbeddedNNUEData), gEmbeddedNNUESize);
+  void NNUE::export_net(const std::optional<std::string>& filename) {
+    std::string actualFilename;
+    if (filename.has_value()) {
+      actualFilename = filename.value();
     } else {
     } else {
-      sync_cout << "No embedded network file." << sync_endl;
+      if (eval_file_loaded != EvalFileDefaultName) {
+        sync_cout << "Failed to export a net. A non-embedded net can only be saved if the filename is specified." << sync_endl;
+        return;
+      }
+      actualFilename = EvalFileDefaultName;
+    }
+
+    ofstream stream(actualFilename, std::ios_base::binary);
+    if (save_eval(stream)) {
+        sync_cout << "Network saved successfully to " << actualFilename << "." << sync_endl;
+    } else {
+        sync_cout << "Failed to export a net." << sync_endl;
     }
   }
 
     }
   }
 
@@ -1128,7 +1137,7 @@ Value Eval::evaluate(const Position& pos) {
       bool lowPieceEndgame =   pos.non_pawn_material() == BishopValueMg
                             || (pos.non_pawn_material() < 2 * RookValueMg && pos.count<PAWN>() < 2);
 
       bool lowPieceEndgame =   pos.non_pawn_material() == BishopValueMg
                             || (pos.non_pawn_material() < 2 * RookValueMg && pos.count<PAWN>() < 2);
 
-      v = classical || lowPieceEndgame ? Evaluation<NO_TRACE>(pos).value() 
+      v = classical || lowPieceEndgame ? Evaluation<NO_TRACE>(pos).value()
                                        : adjusted_NNUE();
 
       // If the classical eval is small and imbalance large, use NNUE nevertheless.
                                        : adjusted_NNUE();
 
       // If the classical eval is small and imbalance large, use NNUE nevertheless.
index b7525aab8bbed6d92b72ff376e61348962f91b07..128a7caefcc13611f05437d19e8e4e42207ba354 100644 (file)
@@ -20,6 +20,7 @@
 #define EVALUATE_H_INCLUDED
 
 #include <string>
 #define EVALUATE_H_INCLUDED
 
 #include <string>
+#include <optional>
 
 #include "types.h"
 
 
 #include "types.h"
 
@@ -44,8 +45,9 @@ namespace Eval {
 
     Value evaluate(const Position& pos);
     bool load_eval(std::string name, std::istream& stream);
 
     Value evaluate(const Position& pos);
     bool load_eval(std::string name, std::istream& stream);
+    bool save_eval(std::ostream& stream);
     void init();
     void init();
-    void export_net();
+    void export_net(const std::optional<std::string>& filename);
     void verify();
 
   } // namespace NNUE
     void verify();
 
   } // namespace NNUE
index 0e53961167140228e163cf55b98073269624c918..e0d4b9117c707676e2300e5424503e8f27e59114 100644 (file)
@@ -39,6 +39,7 @@ namespace Stockfish::Eval::NNUE {
 
   // Evaluation function file name
   std::string fileName;
 
   // Evaluation function file name
   std::string fileName;
+  std::string netDescription;
 
   namespace Detail {
 
 
   namespace Detail {
 
@@ -68,6 +69,14 @@ namespace Stockfish::Eval::NNUE {
     return reference.read_parameters(stream);
   }
 
     return reference.read_parameters(stream);
   }
 
+  // Write evaluation function parameters
+  template <typename T>
+  bool write_parameters(std::ostream& stream, const T& reference) {
+
+    write_little_endian<std::uint32_t>(stream, T::get_hash_value());
+    return reference.write_parameters(stream);
+  }
+
   }  // namespace Detail
 
   // Initialize the evaluation function parameters
   }  // namespace Detail
 
   // Initialize the evaluation function parameters
@@ -78,7 +87,7 @@ namespace Stockfish::Eval::NNUE {
   }
 
   // Read network header
   }
 
   // Read network header
-  bool read_header(std::istream& stream, std::uint32_t* hashValue, std::string* architecture)
+  bool read_header(std::istream& stream, std::uint32_t* hashValue, std::string* desc)
   {
     std::uint32_t version, size;
 
   {
     std::uint32_t version, size;
 
@@ -86,8 +95,18 @@ namespace Stockfish::Eval::NNUE {
     *hashValue = read_little_endian<std::uint32_t>(stream);
     size        = read_little_endian<std::uint32_t>(stream);
     if (!stream || version != Version) return false;
     *hashValue = read_little_endian<std::uint32_t>(stream);
     size        = read_little_endian<std::uint32_t>(stream);
     if (!stream || version != Version) return false;
-    architecture->resize(size);
-    stream.read(&(*architecture)[0], size);
+    desc->resize(size);
+    stream.read(&(*desc)[0], size);
+    return !stream.fail();
+  }
+
+  // Write network header
+  bool write_header(std::ostream& stream, std::uint32_t hashValue, const std::string& desc)
+  {
+    write_little_endian<std::uint32_t>(stream, Version);
+    write_little_endian<std::uint32_t>(stream, hashValue);
+    write_little_endian<std::uint32_t>(stream, desc.size());
+    stream.write(&desc[0], desc.size());
     return !stream.fail();
   }
 
     return !stream.fail();
   }
 
@@ -95,14 +114,22 @@ namespace Stockfish::Eval::NNUE {
   bool read_parameters(std::istream& stream) {
 
     std::uint32_t hashValue;
   bool read_parameters(std::istream& stream) {
 
     std::uint32_t hashValue;
-    std::string architecture;
-    if (!read_header(stream, &hashValue, &architecture)) return false;
+    if (!read_header(stream, &hashValue, &netDescription)) return false;
     if (hashValue != HashValue) return false;
     if (!Detail::read_parameters(stream, *featureTransformer)) return false;
     if (!Detail::read_parameters(stream, *network)) return false;
     return stream && stream.peek() == std::ios::traits_type::eof();
   }
 
     if (hashValue != HashValue) return false;
     if (!Detail::read_parameters(stream, *featureTransformer)) return false;
     if (!Detail::read_parameters(stream, *network)) return false;
     return stream && stream.peek() == std::ios::traits_type::eof();
   }
 
+  // Write network parameters
+  bool write_parameters(std::ostream& stream) {
+
+    if (!write_header(stream, HashValue, netDescription)) return false;
+    if (!Detail::write_parameters(stream, *featureTransformer)) return false;
+    if (!Detail::write_parameters(stream, *network)) return false;
+    return (bool)stream;
+  }
+
   // Evaluation function. Perform differential calculation.
   Value evaluate(const Position& pos) {
 
   // Evaluation function. Perform differential calculation.
   Value evaluate(const Position& pos) {
 
@@ -141,4 +168,13 @@ namespace Stockfish::Eval::NNUE {
     return read_parameters(stream);
   }
 
     return read_parameters(stream);
   }
 
+  // Save eval, to a file stream or a memory stream
+  bool save_eval(std::ostream& stream) {
+
+    if (fileName.empty())
+      return false;
+
+    return write_parameters(stream);
+  }
+
 } // namespace Stockfish::Eval::NNUE
 } // namespace Stockfish::Eval::NNUE
index 424fad5650f164b5c013ad775b3867b3109eb579..fc1926912df1f9a15a731db50f47dd86e4f8d33e 100644 (file)
@@ -69,15 +69,19 @@ namespace Stockfish::Eval::NNUE::Layers {
       if (!previousLayer.read_parameters(stream)) return false;
       for (std::size_t i = 0; i < OutputDimensions; ++i)
         biases[i] = read_little_endian<BiasType>(stream);
       if (!previousLayer.read_parameters(stream)) return false;
       for (std::size_t i = 0; i < OutputDimensions; ++i)
         biases[i] = read_little_endian<BiasType>(stream);
-      for (std::size_t i = 0; i < OutputDimensions * PaddedInputDimensions; ++i)
 #if !defined (USE_SSSE3)
 #if !defined (USE_SSSE3)
+      for (std::size_t i = 0; i < OutputDimensions * PaddedInputDimensions; ++i)
         weights[i] = read_little_endian<WeightType>(stream);
 #else
         weights[i] = read_little_endian<WeightType>(stream);
 #else
-        weights[
+      std::unique_ptr<uint32_t[]> indexMap = std::make_unique<uint32_t[]>(OutputDimensions * PaddedInputDimensions);
+      for (std::size_t i = 0; i < OutputDimensions * PaddedInputDimensions; ++i) {
+        const uint32_t scrambledIdx =
           (i / 4) % (PaddedInputDimensions / 4) * OutputDimensions * 4 +
           i / PaddedInputDimensions * 4 +
           (i / 4) % (PaddedInputDimensions / 4) * OutputDimensions * 4 +
           i / PaddedInputDimensions * 4 +
-          i % 4
-        ] = read_little_endian<WeightType>(stream);
+          i % 4;
+        weights[scrambledIdx] = read_little_endian<WeightType>(stream);
+        indexMap[scrambledIdx] = i;
+      }
 
       // Determine if eights of weight and input products can be summed using 16bits
       // without saturation. We assume worst case combinations of 0 and 127 for all inputs.
 
       // Determine if eights of weight and input products can be summed using 16bits
       // without saturation. We assume worst case combinations of 0 and 127 for all inputs.
@@ -109,7 +113,8 @@ namespace Stockfish::Eval::NNUE::Layers {
 
                               IndexType idx = maxK / 2 * OutputDimensions * 4 + maxK % 2;
                               sum[sign == -1] -= w[idx];
 
                               IndexType idx = maxK / 2 * OutputDimensions * 4 + maxK % 2;
                               sum[sign == -1] -= w[idx];
-                              canSaturate16.add(j, i + maxK / 2 * 4 + maxK % 2 + x * 2, w[idx]);
+                              const uint32_t scrambledIdx = idx + i * OutputDimensions + j * 4 + x * 2;
+                              canSaturate16.add(j, i + maxK / 2 * 4 + maxK % 2 + x * 2, w[idx], indexMap[scrambledIdx]);
                               w[idx] = 0;
                           }
                   }
                               w[idx] = 0;
                           }
                   }
@@ -125,6 +130,34 @@ namespace Stockfish::Eval::NNUE::Layers {
       return !stream.fail();
     }
 
       return !stream.fail();
     }
 
+    // Write network parameters
+    bool write_parameters(std::ostream& stream) const {
+      if (!previousLayer.write_parameters(stream)) return false;
+      for (std::size_t i = 0; i < OutputDimensions; ++i)
+          write_little_endian<BiasType>(stream, biases[i]);
+#if !defined (USE_SSSE3)
+      for (std::size_t i = 0; i < OutputDimensions * PaddedInputDimensions; ++i)
+          write_little_endian<WeightType>(stream, weights[i]);
+#else
+      std::unique_ptr<WeightType[]> unscrambledWeights = std::make_unique<WeightType[]>(OutputDimensions * PaddedInputDimensions);
+      for (std::size_t i = 0; i < OutputDimensions * PaddedInputDimensions; ++i) {
+          unscrambledWeights[i] =
+              weights[
+                (i / 4) % (PaddedInputDimensions / 4) * OutputDimensions * 4 +
+                i / PaddedInputDimensions * 4 +
+                i % 4
+              ];
+      }
+      for (int i = 0; i < canSaturate16.count; ++i)
+          unscrambledWeights[canSaturate16.ids[i].wIdx] = canSaturate16.ids[i].w;
+
+      for (std::size_t i = 0; i < OutputDimensions * PaddedInputDimensions; ++i)
+          write_little_endian<WeightType>(stream, unscrambledWeights[i]);
+#endif
+
+      return !stream.fail();
+    }
+
     // Forward propagation
     const OutputType* propagate(
         const TransformedFeatureType* transformedFeatures, char* buffer) const {
     // Forward propagation
     const OutputType* propagate(
         const TransformedFeatureType* transformedFeatures, char* buffer) const {
@@ -444,12 +477,14 @@ namespace Stockfish::Eval::NNUE::Layers {
     struct CanSaturate {
         int count;
         struct Entry {
     struct CanSaturate {
         int count;
         struct Entry {
+            uint32_t wIdx;
             uint16_t out;
             uint16_t in;
             int8_t w;
         } ids[PaddedInputDimensions * OutputDimensions * 3 / 4];
 
             uint16_t out;
             uint16_t in;
             int8_t w;
         } ids[PaddedInputDimensions * OutputDimensions * 3 / 4];
 
-        void add(int i, int j, int8_t w) {
+        void add(int i, int j, int8_t w, uint32_t wIdx) {
+            ids[count].wIdx = wIdx;
             ids[count].out = i;
             ids[count].in = j;
             ids[count].w = w;
             ids[count].out = i;
             ids[count].in = j;
             ids[count].w = w;
index 00809c507b3d3cf1eaac5c0f22d4c67054fcf652..f1ac2dfe64455a807d4ee1083f157a2dc9b33b6c 100644 (file)
@@ -59,6 +59,11 @@ namespace Stockfish::Eval::NNUE::Layers {
       return previousLayer.read_parameters(stream);
     }
 
       return previousLayer.read_parameters(stream);
     }
 
+    // Write network parameters
+    bool write_parameters(std::ostream& stream) const {
+      return previousLayer.write_parameters(stream);
+    }
+
     // Forward propagation
     const OutputType* propagate(
         const TransformedFeatureType* transformedFeatures, char* buffer) const {
     // Forward propagation
     const OutputType* propagate(
         const TransformedFeatureType* transformedFeatures, char* buffer) const {
index f113b911239d789fc487a6ae9dc32462fa8c0aa5..bd4d74478de6a44cfe2474845705e5f30aa42f71 100644 (file)
@@ -53,6 +53,11 @@ class InputSlice {
     return true;
   }
 
     return true;
   }
 
+  // Read network parameters
+  bool write_parameters(std::ostream& /*stream*/) const {
+    return true;
+  }
+
   // Forward propagation
   const OutputType* propagate(
       const TransformedFeatureType* transformedFeatures,
   // Forward propagation
   const OutputType* propagate(
       const TransformedFeatureType* transformedFeatures,
index 8c54f9baeebb46eeac67138f3d56665612054124..d41e02377ac34b35378a969034682a9b3029190b 100644 (file)
@@ -99,6 +99,24 @@ namespace Stockfish::Eval::NNUE {
       return result;
   }
 
       return result;
   }
 
+  template <typename IntType>
+  inline void write_little_endian(std::ostream& stream, IntType value) {
+
+      std::uint8_t u[sizeof(IntType)];
+      typename std::make_unsigned<IntType>::type v = value;
+
+      std::size_t i = 0;
+      // if constexpr to silence the warning about shift by 8
+      if constexpr (sizeof(IntType) > 1) {
+        for (; i + 1 < sizeof(IntType); ++i) {
+            u[i] = v;
+            v >>= 8;
+        }
+      }
+      u[i] = v;
+
+      stream.write(reinterpret_cast<char*>(u), sizeof(IntType));
+  }
 }  // namespace Stockfish::Eval::NNUE
 
 #endif // #ifndef NNUE_COMMON_H_INCLUDED
 }  // namespace Stockfish::Eval::NNUE
 
 #endif // #ifndef NNUE_COMMON_H_INCLUDED
index f441274915329f3605831414841920a65093f0db..a4a8e98f9c5e8f579cea140b77126f9763184421 100644 (file)
@@ -118,6 +118,15 @@ namespace Stockfish::Eval::NNUE {
       return !stream.fail();
     }
 
       return !stream.fail();
     }
 
+    // Write network parameters
+    bool write_parameters(std::ostream& stream) const {
+      for (std::size_t i = 0; i < HalfDimensions; ++i)
+        write_little_endian<BiasType>(stream, biases[i]);
+      for (std::size_t i = 0; i < HalfDimensions * InputDimensions; ++i)
+        write_little_endian<WeightType>(stream, weights[i]);
+      return !stream.fail();
+    }
+
     // Convert input features
     void transform(const Position& pos, OutputType* output) const {
       update_accumulator(pos, WHITE);
     // Convert input features
     void transform(const Position& pos, OutputType* output) const {
       update_accumulator(pos, WHITE);
index 64bb7a7cb59b41f2093c1ef66de6d5b7f8d7d6ae..bb17b8d79b21ab8cd3e13855bcd404f7b074d545 100644 (file)
@@ -277,7 +277,14 @@ void UCI::loop(int argc, char* argv[]) {
       else if (token == "d")        sync_cout << pos << sync_endl;
       else if (token == "eval")     trace_eval(pos);
       else if (token == "compiler") sync_cout << compiler_info() << sync_endl;
       else if (token == "d")        sync_cout << pos << sync_endl;
       else if (token == "eval")     trace_eval(pos);
       else if (token == "compiler") sync_cout << compiler_info() << sync_endl;
-      else if (token == "export_net") Eval::NNUE::export_net();
+      else if (token == "export_net") {
+          std::optional<std::string> filename;
+          std::string f;
+          if (is >> skipws >> f) {
+            filename = f;
+          }
+          Eval::NNUE::export_net(filename);
+      }
       else if (!token.empty() && token[0] != '#')
           sync_cout << "Unknown command: " << cmd << sync_endl;
 
       else if (!token.empty() && token[0] != '#')
           sync_cout << "Unknown command: " << cmd << sync_endl;