]> git.sesse.net Git - stockfish/blobdiff - src/nnue/nnue_feature_transformer.h
Reduce SIMD register count from 32 to 16
[stockfish] / src / nnue / nnue_feature_transformer.h
index 8087ea55dd01c59baaabbebdad2acc5b1557dd66..902918b2c6ff433656a110c9714481a7b72eac77 100644 (file)
 #ifndef NNUE_FEATURE_TRANSFORMER_H_INCLUDED
 #define NNUE_FEATURE_TRANSFORMER_H_INCLUDED
 
-#include "nnue_common.h"
+#include <algorithm>
+#include <cassert>
+#include <cstdint>
+#include <cstring>
+#include <iosfwd>
+#include <utility>
+
+#include "../position.h"
+#include "../types.h"
+#include "nnue_accumulator.h"
 #include "nnue_architecture.h"
-
-#include <cstring> // std::memset()
-#include <utility> // std::pair
+#include "nnue_common.h"
 
 namespace Stockfish::Eval::NNUE {
 
@@ -62,7 +69,7 @@ namespace Stockfish::Eval::NNUE {
   #define vec_add_psqt_32(a,b) _mm256_add_epi32(a,b)
   #define vec_sub_psqt_32(a,b) _mm256_sub_epi32(a,b)
   #define vec_zero_psqt() _mm256_setzero_si256()
-  #define NumRegistersSIMD 32
+  #define NumRegistersSIMD 16
   #define MaxChunkSize 64
 
   #elif USE_AVX2
@@ -253,9 +260,9 @@ namespace Stockfish::Eval::NNUE {
     // Read network parameters
     bool read_parameters(std::istream& stream) {
 
-      read_little_endian<BiasType      >(stream, biases     , HalfDimensions                  );
-      read_little_endian<WeightType    >(stream, weights    , HalfDimensions * InputDimensions);
-      read_little_endian<PSQTWeightType>(stream, psqtWeights, PSQTBuckets    * InputDimensions);
+      read_leb_128<BiasType      >(stream, biases     , HalfDimensions                  );
+      read_leb_128<WeightType    >(stream, weights    , HalfDimensions * InputDimensions);
+      read_leb_128<PSQTWeightType>(stream, psqtWeights, PSQTBuckets    * InputDimensions);
 
       return !stream.fail();
     }
@@ -263,9 +270,9 @@ namespace Stockfish::Eval::NNUE {
     // Write network parameters
     bool write_parameters(std::ostream& stream) const {
 
-      write_little_endian<BiasType      >(stream, biases     , HalfDimensions                  );
-      write_little_endian<WeightType    >(stream, weights    , HalfDimensions * InputDimensions);
-      write_little_endian<PSQTWeightType>(stream, psqtWeights, PSQTBuckets    * InputDimensions);
+      write_leb_128<BiasType      >(stream, biases     , HalfDimensions                  );
+      write_leb_128<WeightType    >(stream, weights    , HalfDimensions * InputDimensions);
+      write_leb_128<PSQTWeightType>(stream, psqtWeights, PSQTBuckets    * InputDimensions);
 
       return !stream.fail();
     }
@@ -320,9 +327,9 @@ namespace Stockfish::Eval::NNUE {
           for (IndexType j = 0; j < HalfDimensions / 2; ++j) {
               BiasType sum0 = accumulation[static_cast<int>(perspectives[p])][j + 0];
               BiasType sum1 = accumulation[static_cast<int>(perspectives[p])][j + HalfDimensions / 2];
-              sum0 = std::max<int>(0, std::min<int>(127, sum0));
-              sum1 = std::max<int>(0, std::min<int>(127, sum1));
-              output[offset + j] = static_cast<OutputType>(sum0 * sum1 / 128);
+              sum0 = std::clamp<BiasType>(sum0, 0, 127);
+              sum1 = std::clamp<BiasType>(sum1, 0, 127);
+              output[offset + j] = static_cast<OutputType>(unsigned(sum0 * sum1) / 128);
           }
 
 #endif
@@ -363,7 +370,7 @@ namespace Stockfish::Eval::NNUE {
     // NOTE: The parameter states_to_update is an array of position states, ending with nullptr.
     //       All states must be sequential, that is states_to_update[i] must either be reachable
     //       by repeatedly applying ->previous from states_to_update[i+1] or states_to_update[i] == nullptr.
-    //       computed_st must be reachable by repeatadly applying ->previous on states_to_update[0], if not nullptr.
+    //       computed_st must be reachable by repeatedly applying ->previous on states_to_update[0], if not nullptr.
     template<Color Perspective, size_t N>
     void update_accumulator_incremental(const Position& pos, StateInfo* computed_st, StateInfo* states_to_update[N]) const {
       static_assert(N > 0);