]> git.sesse.net Git - stockfish/blobdiff - src/nnue/layers/clipped_relu.h
Optimize and tidy up affine transform code.
[stockfish] / src / nnue / layers / clipped_relu.h
index 00809c507b3d3cf1eaac5c0f22d4c67054fcf652..c6f3ccade7db51917dfa3e0bcf88540ffbff25e1 100644 (file)
@@ -35,9 +35,10 @@ namespace Stockfish::Eval::NNUE::Layers {
     static_assert(std::is_same<InputType, std::int32_t>::value, "");
 
     // Number of input/output dimensions
-    static constexpr IndexType InputDimensions =
-        PreviousLayer::OutputDimensions;
+    static constexpr IndexType InputDimensions = PreviousLayer::OutputDimensions;
     static constexpr IndexType OutputDimensions = InputDimensions;
+    static constexpr IndexType PaddedOutputDimensions =
+        ceil_to_multiple<IndexType>(OutputDimensions, 32);
 
     // Size of forward propagation buffer used in this layer
     static constexpr std::size_t SelfBufferSize =
@@ -59,6 +60,11 @@ namespace Stockfish::Eval::NNUE::Layers {
       return previousLayer.read_parameters(stream);
     }
 
+    // Write network parameters
+    bool write_parameters(std::ostream& stream) const {
+      return previousLayer.write_parameters(stream);
+    }
+
     // Forward propagation
     const OutputType* propagate(
         const TransformedFeatureType* transformedFeatures, char* buffer) const {
@@ -67,22 +73,42 @@ namespace Stockfish::Eval::NNUE::Layers {
       const auto output = reinterpret_cast<OutputType*>(buffer);
 
   #if defined(USE_AVX2)
-      constexpr IndexType NumChunks = InputDimensions / SimdWidth;
-      const __m256i Zero = _mm256_setzero_si256();
-      const __m256i Offsets = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
-      const auto in = reinterpret_cast<const __m256i*>(input);
-      const auto out = reinterpret_cast<__m256i*>(output);
-      for (IndexType i = 0; i < NumChunks; ++i) {
-        const __m256i words0 = _mm256_srai_epi16(_mm256_packs_epi32(
-            _mm256_load_si256(&in[i * 4 + 0]),
-            _mm256_load_si256(&in[i * 4 + 1])), WeightScaleBits);
-        const __m256i words1 = _mm256_srai_epi16(_mm256_packs_epi32(
-            _mm256_load_si256(&in[i * 4 + 2]),
-            _mm256_load_si256(&in[i * 4 + 3])), WeightScaleBits);
-        _mm256_store_si256(&out[i], _mm256_permutevar8x32_epi32(_mm256_max_epi8(
-            _mm256_packs_epi16(words0, words1), Zero), Offsets));
+      if constexpr (InputDimensions % SimdWidth == 0) {
+        constexpr IndexType NumChunks = InputDimensions / SimdWidth;
+        const __m256i Zero = _mm256_setzero_si256();
+        const __m256i Offsets = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
+        const auto in = reinterpret_cast<const __m256i*>(input);
+        const auto out = reinterpret_cast<__m256i*>(output);
+        for (IndexType i = 0; i < NumChunks; ++i) {
+          const __m256i words0 = _mm256_srai_epi16(_mm256_packs_epi32(
+              _mm256_load_si256(&in[i * 4 + 0]),
+              _mm256_load_si256(&in[i * 4 + 1])), WeightScaleBits);
+          const __m256i words1 = _mm256_srai_epi16(_mm256_packs_epi32(
+              _mm256_load_si256(&in[i * 4 + 2]),
+              _mm256_load_si256(&in[i * 4 + 3])), WeightScaleBits);
+          _mm256_store_si256(&out[i], _mm256_permutevar8x32_epi32(_mm256_max_epi8(
+              _mm256_packs_epi16(words0, words1), Zero), Offsets));
+        }
+      } else {
+        constexpr IndexType NumChunks = InputDimensions / (SimdWidth / 2);
+        const __m128i Zero = _mm_setzero_si128();
+        const auto in = reinterpret_cast<const __m128i*>(input);
+        const auto out = reinterpret_cast<__m128i*>(output);
+        for (IndexType i = 0; i < NumChunks; ++i) {
+          const __m128i words0 = _mm_srai_epi16(_mm_packs_epi32(
+              _mm_load_si128(&in[i * 4 + 0]),
+              _mm_load_si128(&in[i * 4 + 1])), WeightScaleBits);
+          const __m128i words1 = _mm_srai_epi16(_mm_packs_epi32(
+              _mm_load_si128(&in[i * 4 + 2]),
+              _mm_load_si128(&in[i * 4 + 3])), WeightScaleBits);
+          const __m128i packedbytes = _mm_packs_epi16(words0, words1);
+          _mm_store_si128(&out[i], _mm_max_epi8(packedbytes, Zero));
+        }
       }
-      constexpr IndexType Start = NumChunks * SimdWidth;
+      constexpr IndexType Start =
+        InputDimensions % SimdWidth == 0
+        ? InputDimensions / SimdWidth * SimdWidth
+        : InputDimensions / (SimdWidth / 2) * (SimdWidth / 2);
 
   #elif defined(USE_SSE2)
       constexpr IndexType NumChunks = InputDimensions / SimdWidth;
@@ -154,6 +180,15 @@ namespace Stockfish::Eval::NNUE::Layers {
         output[i] = static_cast<OutputType>(
             std::max(0, std::min(127, input[i] >> WeightScaleBits)));
       }
+
+      // Affine transform layers expect that there is at least
+      // ceil_to_multiple(OutputDimensions, 32) initialized values.
+      // We cannot do this in the affine transform because it requires
+      // preallocating space here.
+      for (IndexType i = OutputDimensions; i < PaddedOutputDimensions; ++i) {
+        output[i] = 0;
+      }
+
       return output;
     }