]> git.sesse.net Git - stockfish/blobdiff - src/nnue/nnue_architecture.h
Cleanup and simplify NNUE code.
[stockfish] / src / nnue / nnue_architecture.h
index f59474df463f9961253c537c09123de6d2cb46af..55a01fbe15db42d56880424ba8ff5a07808edb6a 100644 (file)
 #ifndef NNUE_ARCHITECTURE_H_INCLUDED
 #define NNUE_ARCHITECTURE_H_INCLUDED
 
-// Defines the network structure
-#include "architectures/halfkp_256x2-32-32.h"
+#include "nnue_common.h"
+
+#include "features/half_kp.h"
+
+#include "layers/input_slice.h"
+#include "layers/affine_transform.h"
+#include "layers/clipped_relu.h"
 
 namespace Stockfish::Eval::NNUE {
 
+  // Input features used in evaluation function
+  using FeatureSet = Features::HalfKP;
+
+  // Number of input feature dimensions after conversion
+  constexpr IndexType TransformedFeatureDimensions = 256;
+
+  namespace Layers {
+
+    // Define network structure
+    using InputLayer = InputSlice<TransformedFeatureDimensions * 2>;
+    using HiddenLayer1 = ClippedReLU<AffineTransform<InputLayer, 32>>;
+    using HiddenLayer2 = ClippedReLU<AffineTransform<HiddenLayer1, 32>>;
+    using OutputLayer = AffineTransform<HiddenLayer2, 1>;
+
+  }  // namespace Layers
+
+  using Network = Layers::OutputLayer;
+
   static_assert(TransformedFeatureDimensions % MaxSimdWidth == 0, "");
   static_assert(Network::OutputDimensions == 1, "");
   static_assert(std::is_same<Network::OutputType, std::int32_t>::value, "");
 
-  // Trigger for full calculation instead of difference calculation
-  constexpr auto RefreshTriggers = RawFeatures::RefreshTriggers;
-
 }  // namespace Stockfish::Eval::NNUE
 
 #endif // #ifndef NNUE_ARCHITECTURE_H_INCLUDED