X-Git-Url: https://git.sesse.net/?p=stockfish;a=blobdiff_plain;f=src%2Fnnue%2Flayers%2Faffine_transform.h;fp=src%2Fnnue%2Flayers%2Faffine_transform.h;h=461a7b83eca829a9a45c5d215e11dd5991280af6;hp=9a992608cc1a06a9808c4c96228a696bde3cc280;hb=82bb21dc7a198609589ef0cc78d185f00f619a90;hpb=1591e5ac3b24f068f965471f17d7aae33ceaab9f diff --git a/src/nnue/layers/affine_transform.h b/src/nnue/layers/affine_transform.h index 9a992608..461a7b83 100644 --- a/src/nnue/layers/affine_transform.h +++ b/src/nnue/layers/affine_transform.h @@ -25,7 +25,7 @@ #include #include #include "../nnue_common.h" -#include "../../simd.h" +#include "simd.h" /* This file contains the definition for a fully connected layer (aka affine transform). @@ -151,9 +151,15 @@ namespace Stockfish::Eval::NNUE::Layers { template <IndexType InDims, IndexType OutDims, typename Enabled = void> class AffineTransform; +#if defined (USE_AVX512) + constexpr IndexType LargeInputSize = 2 * 64; +#else + constexpr IndexType LargeInputSize = std::numeric_limits<IndexType>::max(); +#endif + // A specialization for large inputs. template <IndexType InDims, IndexType OutDims> - class AffineTransform<InDims, OutDims, std::enable_if_t<(ceil_to_multiple<IndexType>(InDims, MaxSimdWidth) >= 2*64)>> { + class AffineTransform<InDims, OutDims, std::enable_if_t<(ceil_to_multiple<IndexType>(InDims, MaxSimdWidth) >= LargeInputSize)>> { public: // Input/output type using InputType = std::uint8_t; @@ -170,7 +176,7 @@ namespace Stockfish::Eval::NNUE::Layers { using OutputBuffer = OutputType[PaddedOutputDimensions]; - static_assert(PaddedInputDimensions >= 128, "Something went wrong. This specialization should not have been chosen."); + static_assert(PaddedInputDimensions >= LargeInputSize, "Something went wrong. 
This specialization should not have been chosen."); #if defined (USE_AVX512) static constexpr const IndexType InputSimdWidth = 64; @@ -369,7 +375,7 @@ namespace Stockfish::Eval::NNUE::Layers { }; template <IndexType InDims, IndexType OutDims> - class AffineTransform<InDims, OutDims, std::enable_if_t<(ceil_to_multiple<IndexType>(InDims, MaxSimdWidth) < 2*64)>> { + class AffineTransform<InDims, OutDims, std::enable_if_t<(ceil_to_multiple<IndexType>(InDims, MaxSimdWidth) < LargeInputSize)>> { public: // Input/output type // Input/output type @@ -387,7 +393,7 @@ namespace Stockfish::Eval::NNUE::Layers { using OutputBuffer = OutputType[PaddedOutputDimensions]; - static_assert(PaddedInputDimensions < 128, "Something went wrong. This specialization should not have been chosen."); + static_assert(PaddedInputDimensions < LargeInputSize, "Something went wrong. This specialization should not have been chosen."); #if defined (USE_SSSE3) static constexpr const IndexType OutputSimdWidth = SimdWidth / 4;