Merge remote-tracking branch 'upstream/master' into HEAD

diff --git a/src/nnue/layers/clipped_relu.h b/src/nnue/layers/clipped_relu.h
index a10e3e482b722e919f2d0b9740090792a3c1e30b..65455df4944324a12870ca29d68c9ff5e0b379b1 100644
@@ -35,130 +35,155 @@ namespace Stockfish::Eval::NNUE::Layers {
     static_assert(std::is_same<InputType, std::int32_t>::value, "");
 
     // Number of input/output dimensions
-    static constexpr IndexType kInputDimensions =
-        PreviousLayer::kOutputDimensions;
-    static constexpr IndexType kOutputDimensions = kInputDimensions;
+    static constexpr IndexType InputDimensions =
+        PreviousLayer::OutputDimensions;
+    static constexpr IndexType OutputDimensions = InputDimensions;
 
     // Size of forward propagation buffer used in this layer
-    static constexpr std::size_t kSelfBufferSize =
-        CeilToMultiple(kOutputDimensions * sizeof(OutputType), kCacheLineSize);
+    static constexpr std::size_t SelfBufferSize =
+        ceil_to_multiple(OutputDimensions * sizeof(OutputType), CacheLineSize);
 
     // Size of the forward propagation buffer used from the input layer to this layer
-    static constexpr std::size_t kBufferSize =
-        PreviousLayer::kBufferSize + kSelfBufferSize;
+    static constexpr std::size_t BufferSize =
+        PreviousLayer::BufferSize + SelfBufferSize;
 
     // Hash value embedded in the evaluation file
-    static constexpr std::uint32_t GetHashValue() {
-      std::uint32_t hash_value = 0x538D24C7u;
-      hash_value += PreviousLayer::GetHashValue();
-      return hash_value;
+    static constexpr std::uint32_t get_hash_value() {
+      std::uint32_t hashValue = 0x538D24C7u;
+      hashValue += PreviousLayer::get_hash_value();
+      return hashValue;
     }
 
     // Read network parameters
-    bool ReadParameters(std::istream& stream) {
-      return previous_layer_.ReadParameters(stream);
+    bool read_parameters(std::istream& stream) {
+      return previousLayer.read_parameters(stream);
+    }
+
+    // Write network parameters
+    bool write_parameters(std::ostream& stream) const {
+      return previousLayer.write_parameters(stream);
     }
 
     // Forward propagation
-    const OutputType* Propagate(
-        const TransformedFeatureType* transformed_features, char* buffer) const {
-      const auto input = previous_layer_.Propagate(
-          transformed_features, buffer + kSelfBufferSize);
+    const OutputType* propagate(
+        const TransformedFeatureType* transformedFeatures, char* buffer) const {
+      const auto input = previousLayer.propagate(
+          transformedFeatures, buffer + SelfBufferSize);
       const auto output = reinterpret_cast<OutputType*>(buffer);
 
   #if defined(USE_AVX2)
-      constexpr IndexType kNumChunks = kInputDimensions / kSimdWidth;
-      const __m256i kZero = _mm256_setzero_si256();
-      const __m256i kOffsets = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
-      const auto in = reinterpret_cast<const __m256i*>(input);
-      const auto out = reinterpret_cast<__m256i*>(output);
-      for (IndexType i = 0; i < kNumChunks; ++i) {
-        const __m256i words0 = _mm256_srai_epi16(_mm256_packs_epi32(
-            _mm256_load_si256(&in[i * 4 + 0]),
-            _mm256_load_si256(&in[i * 4 + 1])), kWeightScaleBits);
-        const __m256i words1 = _mm256_srai_epi16(_mm256_packs_epi32(
-            _mm256_load_si256(&in[i * 4 + 2]),
-            _mm256_load_si256(&in[i * 4 + 3])), kWeightScaleBits);
-        _mm256_store_si256(&out[i], _mm256_permutevar8x32_epi32(_mm256_max_epi8(
-            _mm256_packs_epi16(words0, words1), kZero), kOffsets));
+      if constexpr (InputDimensions % SimdWidth == 0) {
+        constexpr IndexType NumChunks = InputDimensions / SimdWidth;
+        const __m256i Zero = _mm256_setzero_si256();
+        const __m256i Offsets = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
+        const auto in = reinterpret_cast<const __m256i*>(input);
+        const auto out = reinterpret_cast<__m256i*>(output);
+        for (IndexType i = 0; i < NumChunks; ++i) {
+          const __m256i words0 = _mm256_srai_epi16(_mm256_packs_epi32(
+              _mm256_load_si256(&in[i * 4 + 0]),
+              _mm256_load_si256(&in[i * 4 + 1])), WeightScaleBits);
+          const __m256i words1 = _mm256_srai_epi16(_mm256_packs_epi32(
+              _mm256_load_si256(&in[i * 4 + 2]),
+              _mm256_load_si256(&in[i * 4 + 3])), WeightScaleBits);
+          _mm256_store_si256(&out[i], _mm256_permutevar8x32_epi32(_mm256_max_epi8(
+              _mm256_packs_epi16(words0, words1), Zero), Offsets));
+        }
+      } else {
+        constexpr IndexType NumChunks = InputDimensions / (SimdWidth / 2);
+        const __m128i Zero = _mm_setzero_si128();
+        const auto in = reinterpret_cast<const __m128i*>(input);
+        const auto out = reinterpret_cast<__m128i*>(output);
+        for (IndexType i = 0; i < NumChunks; ++i) {
+          const __m128i words0 = _mm_srai_epi16(_mm_packs_epi32(
+              _mm_load_si128(&in[i * 4 + 0]),
+              _mm_load_si128(&in[i * 4 + 1])), WeightScaleBits);
+          const __m128i words1 = _mm_srai_epi16(_mm_packs_epi32(
+              _mm_load_si128(&in[i * 4 + 2]),
+              _mm_load_si128(&in[i * 4 + 3])), WeightScaleBits);
+          const __m128i packedbytes = _mm_packs_epi16(words0, words1);
+          _mm_store_si128(&out[i], _mm_max_epi8(packedbytes, Zero));
+        }
       }
-      constexpr IndexType kStart = kNumChunks * kSimdWidth;
+      constexpr IndexType Start =
+        InputDimensions % SimdWidth == 0
+        ? InputDimensions / SimdWidth * SimdWidth
+        : InputDimensions / (SimdWidth / 2) * (SimdWidth / 2);
 
   #elif defined(USE_SSE2)
-      constexpr IndexType kNumChunks = kInputDimensions / kSimdWidth;
+      constexpr IndexType NumChunks = InputDimensions / SimdWidth;
 
   #ifdef USE_SSE41
-      const __m128i kZero = _mm_setzero_si128();
+      const __m128i Zero = _mm_setzero_si128();
   #else
       const __m128i k0x80s = _mm_set1_epi8(-128);
   #endif
 
       const auto in = reinterpret_cast<const __m128i*>(input);
       const auto out = reinterpret_cast<__m128i*>(output);
-      for (IndexType i = 0; i < kNumChunks; ++i) {
+      for (IndexType i = 0; i < NumChunks; ++i) {
         const __m128i words0 = _mm_srai_epi16(_mm_packs_epi32(
             _mm_load_si128(&in[i * 4 + 0]),
-            _mm_load_si128(&in[i * 4 + 1])), kWeightScaleBits);
+            _mm_load_si128(&in[i * 4 + 1])), WeightScaleBits);
         const __m128i words1 = _mm_srai_epi16(_mm_packs_epi32(
             _mm_load_si128(&in[i * 4 + 2]),
-            _mm_load_si128(&in[i * 4 + 3])), kWeightScaleBits);
+            _mm_load_si128(&in[i * 4 + 3])), WeightScaleBits);
         const __m128i packedbytes = _mm_packs_epi16(words0, words1);
         _mm_store_si128(&out[i],
 
   #ifdef USE_SSE41
-          _mm_max_epi8(packedbytes, kZero)
+          _mm_max_epi8(packedbytes, Zero)
   #else
           _mm_subs_epi8(_mm_adds_epi8(packedbytes, k0x80s), k0x80s)
   #endif
 
         );
       }
-      constexpr IndexType kStart = kNumChunks * kSimdWidth;
+      constexpr IndexType Start = NumChunks * SimdWidth;
 
   #elif defined(USE_MMX)
-      constexpr IndexType kNumChunks = kInputDimensions / kSimdWidth;
+      constexpr IndexType NumChunks = InputDimensions / SimdWidth;
       const __m64 k0x80s = _mm_set1_pi8(-128);
       const auto in = reinterpret_cast<const __m64*>(input);
       const auto out = reinterpret_cast<__m64*>(output);
-      for (IndexType i = 0; i < kNumChunks; ++i) {
+      for (IndexType i = 0; i < NumChunks; ++i) {
         const __m64 words0 = _mm_srai_pi16(
             _mm_packs_pi32(in[i * 4 + 0], in[i * 4 + 1]),
-            kWeightScaleBits);
+            WeightScaleBits);
         const __m64 words1 = _mm_srai_pi16(
             _mm_packs_pi32(in[i * 4 + 2], in[i * 4 + 3]),
-            kWeightScaleBits);
+            WeightScaleBits);
         const __m64 packedbytes = _mm_packs_pi16(words0, words1);
         out[i] = _mm_subs_pi8(_mm_adds_pi8(packedbytes, k0x80s), k0x80s);
       }
       _mm_empty();
-      constexpr IndexType kStart = kNumChunks * kSimdWidth;
+      constexpr IndexType Start = NumChunks * SimdWidth;
 
   #elif defined(USE_NEON)
-      constexpr IndexType kNumChunks = kInputDimensions / (kSimdWidth / 2);
-      const int8x8_t kZero = {0};
+      constexpr IndexType NumChunks = InputDimensions / (SimdWidth / 2);
+      const int8x8_t Zero = {0};
       const auto in = reinterpret_cast<const int32x4_t*>(input);
       const auto out = reinterpret_cast<int8x8_t*>(output);
-      for (IndexType i = 0; i < kNumChunks; ++i) {
+      for (IndexType i = 0; i < NumChunks; ++i) {
         int16x8_t shifted;
         const auto pack = reinterpret_cast<int16x4_t*>(&shifted);
-        pack[0] = vqshrn_n_s32(in[i * 2 + 0], kWeightScaleBits);
-        pack[1] = vqshrn_n_s32(in[i * 2 + 1], kWeightScaleBits);
-        out[i] = vmax_s8(vqmovn_s16(shifted), kZero);
+        pack[0] = vqshrn_n_s32(in[i * 2 + 0], WeightScaleBits);
+        pack[1] = vqshrn_n_s32(in[i * 2 + 1], WeightScaleBits);
+        out[i] = vmax_s8(vqmovn_s16(shifted), Zero);
       }
-      constexpr IndexType kStart = kNumChunks * (kSimdWidth / 2);
+      constexpr IndexType Start = NumChunks * (SimdWidth / 2);
   #else
-      constexpr IndexType kStart = 0;
+      constexpr IndexType Start = 0;
   #endif
 
-      for (IndexType i = kStart; i < kInputDimensions; ++i) {
+      for (IndexType i = Start; i < InputDimensions; ++i) {
         output[i] = static_cast<OutputType>(
-            std::max(0, std::min(127, input[i] >> kWeightScaleBits)));
+            std::max(0, std::min(127, input[i] >> WeightScaleBits)));
       }
       return output;
     }
 
    private:
-    PreviousLayer previous_layer_;
+    PreviousLayer previousLayer;
   };
 
 }  // namespace Stockfish::Eval::NNUE::Layers
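
Taken together, every branch above computes the same function: shift each 32-bit input right by WeightScaleBits and clamp the result to the int8 activation range [0, 127]. A minimal scalar sketch of that transformation follows; it is not part of the patch, and it assumes WeightScaleBits = 6 (the value defined in nnue_common.h) and the layer's usual std::int32_t to std::uint8_t types.

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>

    constexpr int WeightScaleBits = 6;  // assumed; defined in nnue_common.h

    // Scalar clipped ReLU: exactly the computation the SIMD paths vectorize.
    void clipped_relu(const std::int32_t* input, std::uint8_t* output,
                      std::size_t dimensions) {
      for (std::size_t i = 0; i < dimensions; ++i)
        output[i] = static_cast<std::uint8_t>(
            std::max(0, std::min(127, input[i] >> WeightScaleBits)));
    }

The saturating pack instructions in the SIMD branches preserve this result even for inputs far outside the int16 range: any value that saturates to 32767 during the int32-to-int16 pack still shifts down to 511, which the int16-to-int8 pack clamps to 127, matching the scalar formula.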
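
The Offsets constant in the AVX2 branch deserves a note: _mm256_packs_epi32 and _mm256_packs_epi16 operate on each 128-bit lane independently, so after the two packs the 4-byte groups sit in the order 0,4,1,5,2,6,3,7, and _mm256_permutevar8x32_epi32 restores them to 0..7. A self-contained demonstration, not part of the patch (assumes an AVX2-capable CPU and a compiler invoked with something like -mavx2):

    #include <immintrin.h>
    #include <cstdint>
    #include <cstdio>

    int main() {
      alignas(32) std::int32_t in[32];
      for (int i = 0; i < 32; ++i)
        in[i] = i << 6;  // values 0..31, pre-scaled by WeightScaleBits = 6
      const auto* p = reinterpret_cast<const __m256i*>(in);
      const __m256i words0 = _mm256_srai_epi16(_mm256_packs_epi32(
          _mm256_load_si256(&p[0]), _mm256_load_si256(&p[1])), 6);
      const __m256i words1 = _mm256_srai_epi16(_mm256_packs_epi32(
          _mm256_load_si256(&p[2]), _mm256_load_si256(&p[3])), 6);
      const __m256i packed = _mm256_max_epi8(
          _mm256_packs_epi16(words0, words1), _mm256_setzero_si256());
      // The lane-wise packs left the 4-byte groups as 0,4,1,5,2,6,3,7;
      // this permute puts them back in ascending order.
      const __m256i fixed = _mm256_permutevar8x32_epi32(
          packed, _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0));
      alignas(32) std::int8_t out[32];
      _mm256_store_si256(reinterpret_cast<__m256i*>(out), fixed);
      for (int i = 0; i < 32; ++i)
        std::printf("%d ", out[i]);  // prints 0 1 2 ... 31 thanks to the permute
      std::printf("\n");
      return 0;
    }

Without the final permute the program would print the 4-element groups out of order (0..3, 8..11, 16..19, ...), which is why the SSE paths, operating on single 128-bit registers, need no such fix-up.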
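
Functionally, the new if constexpr split lets the AVX2 build handle input sizes that are not a multiple of SimdWidth: such layers now go through 128-bit chunks, and Start marks where the scalar tail loop takes over. A compile-time sketch of that arithmetic, assuming the AVX2 value SimdWidth = 32 (one 256-bit chunk covers 32 inputs, one 128-bit chunk covers 16):

    #include <cstdint>

    using IndexType = std::uint32_t;
    constexpr IndexType SimdWidth = 32;  // assumed: the AVX2 register width in bytes

    // Mirrors the Start expression in the AVX2 branch of the patch.
    constexpr IndexType start(IndexType inputDimensions) {
      return inputDimensions % SimdWidth == 0
          ? inputDimensions / SimdWidth * SimdWidth                // 256-bit chunks cover everything
          : inputDimensions / (SimdWidth / 2) * (SimdWidth / 2);   // 128-bit chunks, rest is scalar
    }

    static_assert(start(512) == 512, "multiple of SimdWidth: no scalar tail");
    static_assert(start(48) == 48,   "48 = 3 x 16: the 128-bit path covers it all");
    static_assert(start(8) == 0,     "fewer than 16 inputs: the scalar loop does the work");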