]> git.sesse.net Git - stockfish/blobdiff - src/nnue/layers/affine_transform.h
Exporting the currently loaded network file
[stockfish] / src / nnue / layers / affine_transform.h
index 424fad5650f164b5c013ad775b3867b3109eb579..fc1926912df1f9a15a731db50f47dd86e4f8d33e 100644 (file)
@@ -69,15 +69,19 @@ namespace Stockfish::Eval::NNUE::Layers {
       if (!previousLayer.read_parameters(stream)) return false;
       for (std::size_t i = 0; i < OutputDimensions; ++i)
         biases[i] = read_little_endian<BiasType>(stream);
-      for (std::size_t i = 0; i < OutputDimensions * PaddedInputDimensions; ++i)
 #if !defined (USE_SSSE3)
+      for (std::size_t i = 0; i < OutputDimensions * PaddedInputDimensions; ++i)
         weights[i] = read_little_endian<WeightType>(stream);
 #else
-        weights[
+      std::unique_ptr<uint32_t[]> indexMap = std::make_unique<uint32_t[]>(OutputDimensions * PaddedInputDimensions);
+      for (std::size_t i = 0; i < OutputDimensions * PaddedInputDimensions; ++i) {
+        const uint32_t scrambledIdx =
           (i / 4) % (PaddedInputDimensions / 4) * OutputDimensions * 4 +
           i / PaddedInputDimensions * 4 +
-          i % 4
-        ] = read_little_endian<WeightType>(stream);
+          i % 4;
+        weights[scrambledIdx] = read_little_endian<WeightType>(stream);
+        indexMap[scrambledIdx] = i;
+      }
 
       // Determine if eights of weight and input products can be summed using 16bits
       // without saturation. We assume worst case combinations of 0 and 127 for all inputs.
@@ -109,7 +113,8 @@ namespace Stockfish::Eval::NNUE::Layers {
 
                               IndexType idx = maxK / 2 * OutputDimensions * 4 + maxK % 2;
                               sum[sign == -1] -= w[idx];
-                              canSaturate16.add(j, i + maxK / 2 * 4 + maxK % 2 + x * 2, w[idx]);
+                              const uint32_t scrambledIdx = idx + i * OutputDimensions + j * 4 + x * 2;
+                              canSaturate16.add(j, i + maxK / 2 * 4 + maxK % 2 + x * 2, w[idx], indexMap[scrambledIdx]);
                               w[idx] = 0;
                           }
                   }
@@ -125,6 +130,34 @@ namespace Stockfish::Eval::NNUE::Layers {
       return !stream.fail();
     }
 
+    // Write network parameters
+    bool write_parameters(std::ostream& stream) const {
+      if (!previousLayer.write_parameters(stream)) return false;
+      for (std::size_t i = 0; i < OutputDimensions; ++i)
+          write_little_endian<BiasType>(stream, biases[i]);
+#if !defined (USE_SSSE3)
+      for (std::size_t i = 0; i < OutputDimensions * PaddedInputDimensions; ++i)
+          write_little_endian<WeightType>(stream, weights[i]);
+#else
+      std::unique_ptr<WeightType[]> unscrambledWeights = std::make_unique<WeightType[]>(OutputDimensions * PaddedInputDimensions);
+      for (std::size_t i = 0; i < OutputDimensions * PaddedInputDimensions; ++i) {
+          unscrambledWeights[i] =
+              weights[
+                (i / 4) % (PaddedInputDimensions / 4) * OutputDimensions * 4 +
+                i / PaddedInputDimensions * 4 +
+                i % 4
+              ];
+      }
+      for (int i = 0; i < canSaturate16.count; ++i)
+          unscrambledWeights[canSaturate16.ids[i].wIdx] = canSaturate16.ids[i].w;
+
+      for (std::size_t i = 0; i < OutputDimensions * PaddedInputDimensions; ++i)
+          write_little_endian<WeightType>(stream, unscrambledWeights[i]);
+#endif
+
+      return !stream.fail();
+    }
+
     // Forward propagation
     const OutputType* propagate(
         const TransformedFeatureType* transformedFeatures, char* buffer) const {
@@ -444,12 +477,14 @@ namespace Stockfish::Eval::NNUE::Layers {
     struct CanSaturate {
         int count;
         struct Entry {
+            uint32_t wIdx;
             uint16_t out;
             uint16_t in;
             int8_t w;
         } ids[PaddedInputDimensions * OutputDimensions * 3 / 4];
 
-        void add(int i, int j, int8_t w) {
+        void add(int i, int j, int8_t w, uint32_t wIdx) {
+            ids[count].wIdx = wIdx;
             ids[count].out = i;
             ids[count].in = j;
             ids[count].w = w;