]> git.sesse.net Git - stockfish/commitdiff
Optimize and tidy up affine transform code.
authorTomasz Sobczyk <tomasz.sobczyk1997@gmail.com>
Mon, 16 Aug 2021 10:19:26 +0000 (12:19 +0200)
committerJoost VandeVondele <Joost.VandeVondele@gmail.com>
Fri, 20 Aug 2021 06:50:25 +0000 (08:50 +0200)
The new network caused some issues initially due to the very narrow neuron set between the first two FC layers. Necessary changes were hacked together to make it work. This patch is a mature approach to make the affine transform code faster, more readable, and easier to maintain should the layer sizes change again.

The following changes were made:

* ClippedReLU always produces a multiple of 32 outputs. This is about as good of a solution for AffineTransform's SIMD requirements as it can get without a bigger rewrite.

* All self-contained simd helpers are moved to a separate file (simd.h). Inline asm is utilized to work around GCC's issues with code generation and register assignment. See https://gcc.gnu.org/bugzilla/show_bug.cgi?id=101693, https://godbolt.org/z/da76fY1n7

* AffineTransform has 2 specializations. While it's more lines of code due to the boilerplate, the logic in both is significantly reduced, as these two are impossible to nicely combine into one.
 1) The first specialization is for cases when there's >=128 inputs. It uses a different approach to perform the affine transform and can make full use of AVX512 without any edge cases. Furthermore, it has higher theoretical throughput because less loads are needed in the hot path, requiring only a fixed amount of instructions for horizontal additions at the end, which are amortized by the large number of inputs.
 2) The second specialization is made to handle smaller layers where performance is still necessary but edge cases need to be handled. AVX512 implementation for this was ommited by mistake, a remnant from the temporary implementation for the new... This could be easily reintroduced if needed. A slightly more detailed description of both implementations is in the code.

Overall it should be a minor speedup, as shown on fishtest:

passed STC:
LLR: 2.96 (-2.94,2.94) <-0.50,2.50>
Total: 51520 W: 4074 L: 3888 D: 43558
Ptnml(0-2): 111, 3136, 19097, 3288, 128

and various tests shown in the pull request

closes https://github.com/official-stockfish/Stockfish/pull/3663

No functional change

src/nnue/layers/affine_transform.h
src/nnue/layers/clipped_relu.h
src/simd.h [new file with mode: 0644]

index d131836865f73800920fcb29ed635bd199c29b4f..b28712780b2684868bc2c2937628e4112b72b69c 100644 (file)
 #define NNUE_LAYERS_AFFINE_TRANSFORM_H_INCLUDED
 
 #include <iostream>
+#include <algorithm>
+#include <type_traits>
 #include "../nnue_common.h"
+#include "../../simd.h"
+
+/*
+  This file contains the definition for a fully connected layer (aka affine transform).
+  Two approaches are employed, depending on the sizes of the transform.
+
+  Approach 1:
+    - used when the PaddedInputDimensions >= 128
+    - uses AVX512 if possible
+    - processes inputs in batches of 2*InputSimdWidth
+      - so in batches of 128 for AVX512
+    - the weight blocks of size InputSimdWidth are transposed such that
+      access is sequential
+    - N columns of the weight matrix are processed a time, where N
+      depends on the architecture (the amount of registers)
+    - accumulate + hadd is used
+
+  Approach 2:
+    - used when the PaddedInputDimensions < 128
+    - does not use AVX512
+    - expected use-case is for when PaddedInputDimensions == 32 and InputDimensions <= 32.
+      - that's why AVX512 is hard to implement
+    - expected use-case is small layers
+      - not optimized as well as the approach 1
+    - inputs are processed in chunks of 4, weights are respectively transposed
+    - accumulation happens directly to int32s
+*/
 
 namespace Stockfish::Eval::NNUE::Layers {
 
-  // Affine transformation layer
+// Fallback implementation for older/other architectures.
+// Identical for both approaches. Requires the input to be padded to at least 16 values.
+#if !defined(USE_SSSE3)
+  template <IndexType InputDimensions, IndexType PaddedInputDimensions, IndexType OutputDimensions>
+  static void affine_transform_non_ssse3(std::int32_t* output, const std::int8_t* weights, const std::int32_t* biases, const std::uint8_t* input)
+  {
+# if defined(USE_SSE2)
+    // At least a multiple of 16, with SSE2.
+    static_assert(PaddedInputDimensions % 16 == 0);
+    constexpr IndexType NumChunks = PaddedInputDimensions / 16;
+    const __m128i Zeros = _mm_setzero_si128();
+    const auto inputVector = reinterpret_cast<const __m128i*>(input);
+
+# elif defined(USE_MMX)
+    static_assert(InputDimensions % 8 == 0);
+    constexpr IndexType NumChunks = InputDimensions / 8;
+    const __m64 Zeros = _mm_setzero_si64();
+    const auto inputVector = reinterpret_cast<const __m64*>(input);
+
+# elif defined(USE_NEON)
+    static_assert(PaddedInputDimensions % 16 == 0);
+    constexpr IndexType NumChunks = PaddedInputDimensions / 16;
+    const auto inputVector = reinterpret_cast<const int8x8_t*>(input);
+# endif
+
+    for (IndexType i = 0; i < OutputDimensions; ++i) {
+      const IndexType offset = i * PaddedInputDimensions;
+
+# if defined(USE_SSE2)
+      __m128i sumLo = _mm_cvtsi32_si128(biases[i]);
+      __m128i sumHi = Zeros;
+      const auto row = reinterpret_cast<const __m128i*>(&weights[offset]);
+      for (IndexType j = 0; j < NumChunks; ++j) {
+        __m128i row_j = _mm_load_si128(&row[j]);
+        __m128i input_j = _mm_load_si128(&inputVector[j]);
+        __m128i extendedRowLo = _mm_srai_epi16(_mm_unpacklo_epi8(row_j, row_j), 8);
+        __m128i extendedRowHi = _mm_srai_epi16(_mm_unpackhi_epi8(row_j, row_j), 8);
+        __m128i extendedInputLo = _mm_unpacklo_epi8(input_j, Zeros);
+        __m128i extendedInputHi = _mm_unpackhi_epi8(input_j, Zeros);
+        __m128i productLo = _mm_madd_epi16(extendedRowLo, extendedInputLo);
+        __m128i productHi = _mm_madd_epi16(extendedRowHi, extendedInputHi);
+        sumLo = _mm_add_epi32(sumLo, productLo);
+        sumHi = _mm_add_epi32(sumHi, productHi);
+      }
+      __m128i sum = _mm_add_epi32(sumLo, sumHi);
+      __m128i sumHigh_64 = _mm_shuffle_epi32(sum, _MM_SHUFFLE(1, 0, 3, 2));
+      sum = _mm_add_epi32(sum, sumHigh_64);
+      __m128i sum_second_32 = _mm_shufflelo_epi16(sum, _MM_SHUFFLE(1, 0, 3, 2));
+      sum = _mm_add_epi32(sum, sum_second_32);
+      output[i] = _mm_cvtsi128_si32(sum);
+
+# elif defined(USE_MMX)
+      __m64 sumLo = _mm_cvtsi32_si64(biases[i]);
+      __m64 sumHi = Zeros;
+      const auto row = reinterpret_cast<const __m64*>(&weights[offset]);
+      for (IndexType j = 0; j < NumChunks; ++j) {
+        __m64 row_j = row[j];
+        __m64 input_j = inputVector[j];
+        __m64 extendedRowLo = _mm_srai_pi16(_mm_unpacklo_pi8(row_j, row_j), 8);
+        __m64 extendedRowHi = _mm_srai_pi16(_mm_unpackhi_pi8(row_j, row_j), 8);
+        __m64 extendedInputLo = _mm_unpacklo_pi8(input_j, Zeros);
+        __m64 extendedInputHi = _mm_unpackhi_pi8(input_j, Zeros);
+        __m64 productLo = _mm_madd_pi16(extendedRowLo, extendedInputLo);
+        __m64 productHi = _mm_madd_pi16(extendedRowHi, extendedInputHi);
+        sumLo = _mm_add_pi32(sumLo, productLo);
+        sumHi = _mm_add_pi32(sumHi, productHi);
+      }
+      __m64 sum = _mm_add_pi32(sumLo, sumHi);
+      sum = _mm_add_pi32(sum, _mm_unpackhi_pi32(sum, sum));
+      output[i] = _mm_cvtsi64_si32(sum);
+
+# elif defined(USE_NEON)
+      int32x4_t sum = {biases[i]};
+      const auto row = reinterpret_cast<const int8x8_t*>(&weights[offset]);
+      for (IndexType j = 0; j < NumChunks; ++j) {
+        int16x8_t product = vmull_s8(inputVector[j * 2], row[j * 2]);
+        product = vmlal_s8(product, inputVector[j * 2 + 1], row[j * 2 + 1]);
+        sum = vpadalq_s16(sum, product);
+      }
+      output[i] = sum[0] + sum[1] + sum[2] + sum[3];
+
+# else
+      std::int32_t sum = biases[i];
+      for (IndexType j = 0; j < InputDimensions; ++j) {
+        sum += weights[offset + j] * input[j];
+      }
+      output[i] = sum;
+# endif
+    }
+
+# if defined(USE_MMX)
+    _mm_empty();
+# endif
+  }
+#endif
+
+  template <typename PreviousLayer, IndexType OutDims, typename Enabled = void>
+  class AffineTransform;
+
+  // A specialization for large inputs.
   template <typename PreviousLayer, IndexType OutDims>
-  class AffineTransform {
+  class AffineTransform<PreviousLayer, OutDims, std::enable_if_t<(PreviousLayer::OutputDimensions >= 2*64-1)>> {
    public:
     // Input/output type
     using InputType = typename PreviousLayer::OutputType;
@@ -36,29 +164,49 @@ namespace Stockfish::Eval::NNUE::Layers {
     static_assert(std::is_same<InputType, std::uint8_t>::value, "");
 
     // Number of input/output dimensions
-    static constexpr IndexType InputDimensions =
-        PreviousLayer::OutputDimensions;
+    static constexpr IndexType InputDimensions = PreviousLayer::OutputDimensions;
     static constexpr IndexType OutputDimensions = OutDims;
+
     static constexpr IndexType PaddedInputDimensions =
-        ceil_to_multiple<IndexType>(InputDimensions, MaxSimdWidth);
-#if defined (USE_AVX512)
-    static constexpr const IndexType OutputSimdWidth = SimdWidth / 2;
-#elif defined (USE_SSSE3)
-    static constexpr const IndexType OutputSimdWidth = SimdWidth / 4;
-#endif
+      ceil_to_multiple<IndexType>(InputDimensions, MaxSimdWidth);
+
+    static_assert(PaddedInputDimensions >= 128, "Something went wrong. This specialization should not have been chosen.");
+
 #if defined (USE_AVX512)
-    static constexpr const IndexType InputSimdWidth = SimdWidth * 2;
+    static constexpr const IndexType InputSimdWidth = 64;
+    static constexpr const IndexType MaxNumOutputRegs = 16;
+#elif defined (USE_AVX2)
+    static constexpr const IndexType InputSimdWidth = 32;
+    static constexpr const IndexType MaxNumOutputRegs = 8;
 #elif defined (USE_SSSE3)
-    static constexpr const IndexType InputSimdWidth = SimdWidth;
+    static constexpr const IndexType InputSimdWidth = 16;
+    static constexpr const IndexType MaxNumOutputRegs = 8;
+#else
+    // The fallback implementation will not have permuted weights.
+    // We define these to avoid a lot of ifdefs later.
+    static constexpr const IndexType InputSimdWidth = 1;
+    static constexpr const IndexType MaxNumOutputRegs = 1;
 #endif
 
+    // A big block is a region in the weight matrix of the size [PaddedInputDimensions, NumOutputRegs].
+    // A small block is a region of size [InputSimdWidth, 1]
+
+    static constexpr const IndexType NumOutputRegs = std::min(MaxNumOutputRegs, OutputDimensions);
+    static constexpr const IndexType SmallBlockSize = InputSimdWidth;
+    static constexpr const IndexType BigBlockSize = NumOutputRegs * PaddedInputDimensions;
+    static constexpr const IndexType NumSmallBlocksInBigBlock = BigBlockSize / SmallBlockSize;
+    static constexpr const IndexType NumSmallBlocksPerOutput = PaddedInputDimensions / SmallBlockSize;
+    static constexpr const IndexType NumBigBlocks = OutputDimensions / NumOutputRegs;
+
+    static_assert(OutputDimensions % NumOutputRegs == 0);
+
     // Size of forward propagation buffer used in this layer
     static constexpr std::size_t SelfBufferSize =
-        ceil_to_multiple(OutputDimensions * sizeof(OutputType), CacheLineSize);
+      ceil_to_multiple(OutputDimensions * sizeof(OutputType), CacheLineSize);
 
     // Size of the forward propagation buffer used from the input layer to this layer
     static constexpr std::size_t BufferSize =
-        PreviousLayer::BufferSize + SelfBufferSize;
+      PreviousLayer::BufferSize + SelfBufferSize;
 
     // Hash value embedded in the evaluation file
     static constexpr std::uint32_t get_hash_value() {
@@ -69,30 +217,35 @@ namespace Stockfish::Eval::NNUE::Layers {
       return hashValue;
     }
 
+    /*
+      Transposes the small blocks within a block.
+      Effectively means that weights can be traversed sequentially during inference.
+    */
+    static IndexType get_weight_index(IndexType i)
+    {
+      const IndexType smallBlock = (i / SmallBlockSize) % NumSmallBlocksInBigBlock;
+      const IndexType smallBlockCol = smallBlock / NumSmallBlocksPerOutput;
+      const IndexType smallBlockRow = smallBlock % NumSmallBlocksPerOutput;
+      const IndexType bigBlock   = i / BigBlockSize;
+      const IndexType rest       = i % SmallBlockSize;
+
+      const IndexType idx =
+          bigBlock * BigBlockSize
+        + smallBlockRow * SmallBlockSize * NumOutputRegs
+        + smallBlockCol * SmallBlockSize
+        + rest;
+
+      return idx;
+    }
+
     // Read network parameters
     bool read_parameters(std::istream& stream) {
       if (!previousLayer.read_parameters(stream)) return false;
       for (std::size_t i = 0; i < OutputDimensions; ++i)
         biases[i] = read_little_endian<BiasType>(stream);
+
       for (std::size_t i = 0; i < OutputDimensions * PaddedInputDimensions; ++i)
-#if !defined (USE_SSSE3)
-        weights[i] = read_little_endian<WeightType>(stream);
-#elif defined (USE_VNNI) || defined (USE_AVX512)
-        if constexpr (OutputDimensions <= 8 && OutputDimensions != 1)
-            weights[i] = read_little_endian<WeightType>(stream);
-        else
-            weights[
-              (i / 4) % (PaddedInputDimensions / 4) * OutputDimensions * 4 +
-              i / PaddedInputDimensions * 4 +
-              i % 4
-            ] = read_little_endian<WeightType>(stream);
-#else
-        weights[
-          (i / 4) % (PaddedInputDimensions / 4) * OutputDimensions * 4 +
-          i / PaddedInputDimensions * 4 +
-          i % 4
-        ] = read_little_endian<WeightType>(stream);
-#endif
+        weights[get_weight_index(i)] = read_little_endian<WeightType>(stream);
 
       return !stream.fail();
     }
@@ -102,517 +255,285 @@ namespace Stockfish::Eval::NNUE::Layers {
       if (!previousLayer.write_parameters(stream)) return false;
       for (std::size_t i = 0; i < OutputDimensions; ++i)
           write_little_endian<BiasType>(stream, biases[i]);
-#if !defined (USE_SSSE3)
-      for (std::size_t i = 0; i < OutputDimensions * PaddedInputDimensions; ++i)
-          write_little_endian<WeightType>(stream, weights[i]);
-#else
-      std::unique_ptr<WeightType[]> unscrambledWeights = std::make_unique<WeightType[]>(OutputDimensions * PaddedInputDimensions);
-      for (std::size_t i = 0; i < OutputDimensions * PaddedInputDimensions; ++i) {
-          unscrambledWeights[i] =
-              weights[
-                (i / 4) % (PaddedInputDimensions / 4) * OutputDimensions * 4 +
-                i / PaddedInputDimensions * 4 +
-                i % 4
-              ];
-      }
 
       for (std::size_t i = 0; i < OutputDimensions * PaddedInputDimensions; ++i)
-          write_little_endian<WeightType>(stream, unscrambledWeights[i]);
-#endif
+        write_little_endian<WeightType>(stream, weights[get_weight_index(i)]);
 
       return !stream.fail();
     }
+
     // Forward propagation
     const OutputType* propagate(
         const TransformedFeatureType* transformedFeatures, char* buffer) const {
       const auto input = previousLayer.propagate(
-          transformedFeatures, buffer + SelfBufferSize);
+        transformedFeatures, buffer + SelfBufferSize);
+      OutputType* output = reinterpret_cast<OutputType*>(buffer);
 
 #if defined (USE_AVX512)
+      using vec_t = __m512i;
+      #define vec_setzero _mm512_setzero_si512
+      #define vec_set_32 _mm512_set1_epi32
+      #define vec_add_dpbusd_32 Simd::m512_add_dpbusd_epi32
+      #define vec_add_dpbusd_32x2 Simd::m512_add_dpbusd_epi32x2
+      #define vec_hadd Simd::m512_hadd
+      #define vec_haddx4 Simd::m512_haddx4
+#elif defined (USE_AVX2)
+      using vec_t = __m256i;
+      #define vec_setzero _mm256_setzero_si256
+      #define vec_set_32 _mm256_set1_epi32
+      #define vec_add_dpbusd_32 Simd::m256_add_dpbusd_epi32
+      #define vec_add_dpbusd_32x2 Simd::m256_add_dpbusd_epi32x2
+      #define vec_hadd Simd::m256_hadd
+      #define vec_haddx4 Simd::m256_haddx4
+#elif defined (USE_SSSE3)
+      using vec_t = __m128i;
+      #define vec_setzero _mm_setzero_si128
+      #define vec_set_32 _mm_set1_epi32
+      #define vec_add_dpbusd_32 Simd::m128_add_dpbusd_epi32
+      #define vec_add_dpbusd_32x2 Simd::m128_add_dpbusd_epi32x2
+      #define vec_hadd Simd::m128_hadd
+      #define vec_haddx4 Simd::m128_haddx4
+#endif
 
-      [[maybe_unused]] const __m512i Ones512 = _mm512_set1_epi16(1);
-
-      [[maybe_unused]] auto m512_hadd = [](__m512i sum, int bias) -> int {
-        return _mm512_reduce_add_epi32(sum) + bias;
-      };
-
-      [[maybe_unused]] auto m512_hadd128x16_interleave = [](
-        __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3) -> __m512i {
+#if defined (USE_SSSE3)
+      const vec_t* invec = reinterpret_cast<const vec_t*>(input);
 
-        __m512i sum01a = _mm512_unpacklo_epi32(sum0, sum1);
-        __m512i sum01b = _mm512_unpackhi_epi32(sum0, sum1);
 
-        __m512i sum23a = _mm512_unpacklo_epi32(sum2, sum3);
-        __m512i sum23b = _mm512_unpackhi_epi32(sum2, sum3);
+      // Perform accumulation to registers for each big block
+      for (IndexType bigBlock = 0; bigBlock < NumBigBlocks; ++bigBlock)
+      {
+        vec_t acc[NumOutputRegs] = { vec_setzero() };
+
+        // Each big block has NumOutputRegs small blocks in each "row", one per register.
+        // We process two small blocks at a time to save on one addition without VNNI.
+        for (IndexType smallBlock = 0; smallBlock < NumSmallBlocksPerOutput; smallBlock += 2)
+        {
+          const vec_t* weightvec =
+            reinterpret_cast<const vec_t*>(
+                weights
+              + bigBlock * BigBlockSize
+              + smallBlock * SmallBlockSize * NumOutputRegs);
+
+          const vec_t in0 = invec[smallBlock + 0];
+          const vec_t in1 = invec[smallBlock + 1];
+
+          for (IndexType k = 0; k < NumOutputRegs; ++k)
+            vec_add_dpbusd_32x2(acc[k], in0, weightvec[k], in1, weightvec[k + NumOutputRegs]);
+        }
 
-        __m512i sum01 = _mm512_add_epi32(sum01a, sum01b);
-        __m512i sum23 = _mm512_add_epi32(sum23a, sum23b);
+        // Horizontally add all accumulators.
+        if constexpr (NumOutputRegs % 4 == 0)
+        {
+          __m128i* outputvec = reinterpret_cast<__m128i*>(output);
+          const __m128i* biasvec = reinterpret_cast<const __m128i*>(biases);
 
-        __m512i sum0123a = _mm512_unpacklo_epi64(sum01, sum23);
-        __m512i sum0123b = _mm512_unpackhi_epi64(sum01, sum23);
+          for (IndexType k = 0; k < NumOutputRegs; k += 4)
+          {
+            const IndexType idx = (bigBlock * NumOutputRegs + k) / 4;
+            outputvec[idx] = vec_haddx4(acc[k+0], acc[k+1], acc[k+2], acc[k+3], biasvec[idx]);
+          }
+        }
+        else
+        {
+          for (IndexType k = 0; k < NumOutputRegs; ++k)
+          {
+            const IndexType idx = (bigBlock * NumOutputRegs + k);
+            output[idx] = vec_hadd(acc[k], biases[idx]);
+          }
+        }
+      }
 
-        return _mm512_add_epi32(sum0123a, sum0123b);
-      };
+# undef vec_setzero
+# undef vec_set_32
+# undef vec_add_dpbusd_32
+# undef vec_add_dpbusd_32x2
+# undef vec_hadd
+# undef vec_haddx4
+#else
+      // Use old implementation for the other architectures.
+      affine_transform_non_ssse3<
+        InputDimensions,
+        PaddedInputDimensions,
+        OutputDimensions>(output, weights, biases, input);
 
-      [[maybe_unused]] auto m512_haddx4 = [m512_hadd128x16_interleave](
-        __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3, __m128i bias) -> __m128i {
+#endif
 
-        __m512i sum = m512_hadd128x16_interleave(sum0, sum1, sum2, sum3);
+      return output;
+    }
 
-        __m256i sum256lo = _mm512_castsi512_si256(sum);
-        __m256i sum256hi = _mm512_extracti64x4_epi64(sum, 1);
+   private:
+    using BiasType = OutputType;
+    using WeightType = std::int8_t;
 
-        sum256lo = _mm256_add_epi32(sum256lo, sum256hi);
+    PreviousLayer previousLayer;
 
-        __m128i sum128lo = _mm256_castsi256_si128(sum256lo);
-        __m128i sum128hi = _mm256_extracti128_si256(sum256lo, 1);
+    alignas(CacheLineSize) BiasType biases[OutputDimensions];
+    alignas(CacheLineSize) WeightType weights[OutputDimensions * PaddedInputDimensions];
+  };
 
-        return _mm_add_epi32(_mm_add_epi32(sum128lo, sum128hi), bias);
-      };
+  template <typename PreviousLayer, IndexType OutDims>
+  class AffineTransform<PreviousLayer, OutDims, std::enable_if_t<(PreviousLayer::OutputDimensions < 2*64-1)>> {
+   public:
+    // Input/output type
+    using InputType = typename PreviousLayer::OutputType;
+    using OutputType = std::int32_t;
+    static_assert(std::is_same<InputType, std::uint8_t>::value, "");
 
-      [[maybe_unused]] auto m512_add_dpbusd_epi32 = [=](__m512i& acc, __m512i a, __m512i b) {
-#if defined (USE_VNNI)
-        acc = _mm512_dpbusd_epi32(acc, a, b);
-#else
-        __m512i product0 = _mm512_maddubs_epi16(a, b);
-        product0 = _mm512_madd_epi16(product0, Ones512);
-        acc = _mm512_add_epi32(acc, product0);
-#endif
-      };
+    // Number of input/output dimensions
+    static constexpr IndexType InputDimensions =
+        PreviousLayer::OutputDimensions;
+    static constexpr IndexType OutputDimensions = OutDims;
+    static constexpr IndexType PaddedInputDimensions =
+        ceil_to_multiple<IndexType>(InputDimensions, MaxSimdWidth);
 
-      [[maybe_unused]] auto m512_add_dpbusd_epi32x2 = [=](__m512i& acc, __m512i a0, __m512i b0, __m512i a1, __m512i b1) {
-#if defined (USE_VNNI)
-        acc = _mm512_dpbusd_epi32(acc, a0, b0);
-        acc = _mm512_dpbusd_epi32(acc, a1, b1);
-#else
-        __m512i product0 = _mm512_maddubs_epi16(a0, b0);
-        __m512i product1 = _mm512_maddubs_epi16(a1, b1);
-        product0 = _mm512_adds_epi16(product0, product1);
-        product0 = _mm512_madd_epi16(product0, Ones512);
-        acc = _mm512_add_epi32(acc, product0);
-#endif
-      };
-
-      [[maybe_unused]] auto m512_add_dpbusd_epi32x4 = [=](__m512i& acc, __m512i a0, __m512i b0, __m512i a1, __m512i b1,
-                                                                        __m512i a2, __m512i b2, __m512i a3, __m512i b3) {
-#if defined (USE_VNNI)
-        acc = _mm512_dpbusd_epi32(acc, a0, b0);
-        acc = _mm512_dpbusd_epi32(acc, a1, b1);
-        acc = _mm512_dpbusd_epi32(acc, a2, b2);
-        acc = _mm512_dpbusd_epi32(acc, a3, b3);
-#else
-        __m512i product0 = _mm512_maddubs_epi16(a0, b0);
-        __m512i product1 = _mm512_maddubs_epi16(a1, b1);
-        __m512i product2 = _mm512_maddubs_epi16(a2, b2);
-        __m512i product3 = _mm512_maddubs_epi16(a3, b3);
-        product0 = _mm512_adds_epi16(product0, product1);
-        product0 = _mm512_madd_epi16(product0, Ones512);
-        product2 = _mm512_adds_epi16(product2, product3);
-        product2 = _mm512_madd_epi16(product2, Ones512);
-        acc = _mm512_add_epi32(acc, _mm512_add_epi32(product0, product2));
-#endif
-      };
+    static_assert(PaddedInputDimensions < 128, "Something went wrong. This specialization should not have been chosen.");
 
+#if defined (USE_SSSE3)
+    static constexpr const IndexType OutputSimdWidth = SimdWidth / 4;
+    static constexpr const IndexType InputSimdWidth = SimdWidth;
 #endif
-#if defined (USE_AVX2)
 
-      [[maybe_unused]] const __m256i Ones256 = _mm256_set1_epi16(1);
-
-      [[maybe_unused]] auto m256_hadd = [](__m256i sum, int bias) -> int {
-        __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(sum), _mm256_extracti128_si256(sum, 1));
-        sum128 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, _MM_PERM_BADC));
-        sum128 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, _MM_PERM_CDAB));
-        return _mm_cvtsi128_si32(sum128) + bias;
-      };
-
-      [[maybe_unused]] auto m256_haddx4 = [](__m256i sum0, __m256i sum1, __m256i sum2, __m256i sum3, __m128i bias) -> __m128i {
-        sum0 = _mm256_hadd_epi32(sum0, sum1);
-        sum2 = _mm256_hadd_epi32(sum2, sum3);
+    // Size of forward propagation buffer used in this layer
+    static constexpr std::size_t SelfBufferSize =
+      ceil_to_multiple(OutputDimensions * sizeof(OutputType), CacheLineSize);
 
-        sum0 = _mm256_hadd_epi32(sum0, sum2);
+    // Size of the forward propagation buffer used from the input layer to this layer
+    static constexpr std::size_t BufferSize =
+      PreviousLayer::BufferSize + SelfBufferSize;
 
-        __m128i sum128lo = _mm256_castsi256_si128(sum0);
-        __m128i sum128hi = _mm256_extracti128_si256(sum0, 1);
+    // Hash value embedded in the evaluation file
+    static constexpr std::uint32_t get_hash_value() {
+      std::uint32_t hashValue = 0xCC03DAE4u;
+      hashValue += OutputDimensions;
+      hashValue ^= PreviousLayer::get_hash_value() >> 1;
+      hashValue ^= PreviousLayer::get_hash_value() << 31;
+      return hashValue;
+    }
 
-        return _mm_add_epi32(_mm_add_epi32(sum128lo, sum128hi), bias);
-      };
+    static IndexType get_weight_index_scrambled(IndexType i)
+    {
+      return
+        (i / 4) % (PaddedInputDimensions / 4) * OutputDimensions * 4 +
+        i / PaddedInputDimensions * 4 +
+        i % 4;
+    }
 
-      [[maybe_unused]] auto m256_add_dpbusd_epi32 = [=](__m256i& acc, __m256i a, __m256i b) {
-#if defined (USE_VNNI)
-        acc = _mm256_dpbusd_epi32(acc, a, b);
+    static IndexType get_weight_index(IndexType i)
+    {
+#if defined (USE_SSSE3)
+      return get_weight_index_scrambled(i);
 #else
-        __m256i product0 = _mm256_maddubs_epi16(a, b);
-        product0 = _mm256_madd_epi16(product0, Ones256);
-        acc = _mm256_add_epi32(acc, product0);
+      return i;
 #endif
-      };
+    }
 
-      [[maybe_unused]] auto m256_add_dpbusd_epi32x2 = [=](__m256i& acc, __m256i a0, __m256i b0, __m256i a1, __m256i b1) {
-#if defined (USE_VNNI)
-        acc = _mm256_dpbusd_epi32(acc, a0, b0);
-        acc = _mm256_dpbusd_epi32(acc, a1, b1);
-#else
-        __m256i product0 = _mm256_maddubs_epi16(a0, b0);
-        __m256i product1 = _mm256_maddubs_epi16(a1, b1);
-        product0 = _mm256_adds_epi16(product0, product1);
-        product0 = _mm256_madd_epi16(product0, Ones256);
-        acc = _mm256_add_epi32(acc, product0);
-#endif
-      };
-
-      [[maybe_unused]] auto m256_add_dpbusd_epi32x4 = [=](__m256i& acc, __m256i a0, __m256i b0, __m256i a1, __m256i b1,
-                                                                        __m256i a2, __m256i b2, __m256i a3, __m256i b3) {
-#if defined (USE_VNNI)
-        acc = _mm256_dpbusd_epi32(acc, a0, b0);
-        acc = _mm256_dpbusd_epi32(acc, a1, b1);
-        acc = _mm256_dpbusd_epi32(acc, a2, b2);
-        acc = _mm256_dpbusd_epi32(acc, a3, b3);
-#else
-        __m256i product0 = _mm256_maddubs_epi16(a0, b0);
-        __m256i product1 = _mm256_maddubs_epi16(a1, b1);
-        __m256i product2 = _mm256_maddubs_epi16(a2, b2);
-        __m256i product3 = _mm256_maddubs_epi16(a3, b3);
-        product0 = _mm256_adds_epi16(product0, product1);
-        product0 = _mm256_madd_epi16(product0, Ones256);
-        product2 = _mm256_adds_epi16(product2, product3);
-        product2 = _mm256_madd_epi16(product2, Ones256);
-        acc = _mm256_add_epi32(acc, _mm256_add_epi32(product0, product2));
-#endif
-      };
+    // Read network parameters
+    bool read_parameters(std::istream& stream) {
+      if (!previousLayer.read_parameters(stream)) return false;
+      for (std::size_t i = 0; i < OutputDimensions; ++i)
+        biases[i] = read_little_endian<BiasType>(stream);
+      for (std::size_t i = 0; i < OutputDimensions * PaddedInputDimensions; ++i)
+        weights[get_weight_index(i)] = read_little_endian<WeightType>(stream);
 
-#endif
-#if defined (USE_SSSE3)
+      return !stream.fail();
+    }
 
-      [[maybe_unused]] const __m128i Ones128 = _mm_set1_epi16(1);
-
-      [[maybe_unused]] auto m128_hadd = [](__m128i sum, int bias) -> int {
-        sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, 0x4E)); //_MM_PERM_BADC
-        sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, 0xB1)); //_MM_PERM_CDAB
-        return _mm_cvtsi128_si32(sum) + bias;
-      };
-
-      [[maybe_unused]] auto m128_haddx4 = [](__m128i sum0, __m128i sum1, __m128i sum2, __m128i sum3, __m128i bias) -> __m128i {
-        sum0 = _mm_hadd_epi32(sum0, sum1);
-        sum2 = _mm_hadd_epi32(sum2, sum3);
-        sum0 = _mm_hadd_epi32(sum0, sum2);
-        return _mm_add_epi32(sum0, bias);
-      };
-
-      [[maybe_unused]] auto m128_add_dpbusd_epi32 = [=](__m128i& acc, __m128i a, __m128i b) {
-        __m128i product0 = _mm_maddubs_epi16(a, b);
-        product0 = _mm_madd_epi16(product0, Ones128);
-        acc = _mm_add_epi32(acc, product0);
-      };
-
-      [[maybe_unused]] auto m128_add_dpbusd_epi32x2 = [=](__m128i& acc, __m128i a0, __m128i b0, __m128i a1, __m128i b1) {
-        __m128i product0 = _mm_maddubs_epi16(a0, b0);
-        __m128i product1 = _mm_maddubs_epi16(a1, b1);
-        product0 = _mm_adds_epi16(product0, product1);
-        product0 = _mm_madd_epi16(product0, Ones128);
-        acc = _mm_add_epi32(acc, product0);
-      };
-
-      [[maybe_unused]] auto m128_add_dpbusd_epi32x4 = [=](__m128i& acc, __m128i a0, __m128i b0, __m128i a1, __m128i b1,
-                                                                        __m128i a2, __m128i b2, __m128i a3, __m128i b3) {
-        __m128i product0 = _mm_maddubs_epi16(a0, b0);
-        __m128i product1 = _mm_maddubs_epi16(a1, b1);
-        __m128i product2 = _mm_maddubs_epi16(a2, b2);
-        __m128i product3 = _mm_maddubs_epi16(a3, b3);
-        product0 = _mm_adds_epi16(product0, product1);
-        product0 = _mm_madd_epi16(product0, Ones128);
-        product2 = _mm_adds_epi16(product2, product3);
-        product2 = _mm_madd_epi16(product2, Ones128);
-        acc = _mm_add_epi32(acc, _mm_add_epi32(product0, product2));
-      };
+    // Write network parameters
+    bool write_parameters(std::ostream& stream) const {
+      if (!previousLayer.write_parameters(stream)) return false;
+      for (std::size_t i = 0; i < OutputDimensions; ++i)
+        write_little_endian<BiasType>(stream, biases[i]);
 
-#endif
+      for (std::size_t i = 0; i < OutputDimensions * PaddedInputDimensions; ++i)
+        write_little_endian<WeightType>(stream, weights[get_weight_index(i)]);
 
-#if defined (USE_AVX512)
-      using vec_t = __m512i;
-      #define vec_setzero _mm512_setzero_si512
-      #define vec_set_32 _mm512_set1_epi32
-      [[maybe_unused]] auto& vec_add_dpbusd_32 = m512_add_dpbusd_epi32;
-      [[maybe_unused]] auto& vec_add_dpbusd_32x2 = m512_add_dpbusd_epi32x2;
-      [[maybe_unused]] auto& vec_add_dpbusd_32x4 = m512_add_dpbusd_epi32x4;
-      [[maybe_unused]] auto& vec_hadd = m512_hadd;
-      [[maybe_unused]] auto& vec_haddx4 = m512_haddx4;
-#elif defined (USE_AVX2)
+      return !stream.fail();
+    }
+    // Forward propagation
+    const OutputType* propagate(
+        const TransformedFeatureType* transformedFeatures, char* buffer) const {
+      const auto input = previousLayer.propagate(
+        transformedFeatures, buffer + SelfBufferSize);
+      const auto output = reinterpret_cast<OutputType*>(buffer);
+
+#if defined (USE_AVX2)
       using vec_t = __m256i;
       #define vec_setzero _mm256_setzero_si256
       #define vec_set_32 _mm256_set1_epi32
-      [[maybe_unused]] auto& vec_add_dpbusd_32 = m256_add_dpbusd_epi32;
-      [[maybe_unused]] auto& vec_add_dpbusd_32x2 = m256_add_dpbusd_epi32x2;
-      [[maybe_unused]] auto& vec_add_dpbusd_32x4 = m256_add_dpbusd_epi32x4;
-      [[maybe_unused]] auto& vec_hadd = m256_hadd;
-      [[maybe_unused]] auto& vec_haddx4 = m256_haddx4;
+      #define vec_add_dpbusd_32 Simd::m256_add_dpbusd_epi32
+      #define vec_add_dpbusd_32x2 Simd::m256_add_dpbusd_epi32x2
+      #define vec_add_dpbusd_32x4 Simd::m256_add_dpbusd_epi32x4
+      #define vec_hadd Simd::m256_hadd
+      #define vec_haddx4 Simd::m256_haddx4
 #elif defined (USE_SSSE3)
       using vec_t = __m128i;
       #define vec_setzero _mm_setzero_si128
       #define vec_set_32 _mm_set1_epi32
-      [[maybe_unused]] auto& vec_add_dpbusd_32 = m128_add_dpbusd_epi32;
-      [[maybe_unused]] auto& vec_add_dpbusd_32x2 = m128_add_dpbusd_epi32x2;
-      [[maybe_unused]] auto& vec_add_dpbusd_32x4 = m128_add_dpbusd_epi32x4;
-      [[maybe_unused]] auto& vec_hadd = m128_hadd;
-      [[maybe_unused]] auto& vec_haddx4 = m128_haddx4;
+      #define vec_add_dpbusd_32 Simd::m128_add_dpbusd_epi32
+      #define vec_add_dpbusd_32x2 Simd::m128_add_dpbusd_epi32x2
+      #define vec_add_dpbusd_32x4 Simd::m128_add_dpbusd_epi32x4
+      #define vec_hadd Simd::m128_hadd
+      #define vec_haddx4 Simd::m128_haddx4
 #endif
 
 #if defined (USE_SSSE3)
-      const auto output = reinterpret_cast<OutputType*>(buffer);
       const auto inputVector = reinterpret_cast<const vec_t*>(input);
-#endif
-
-#if defined (USE_VNNI) || defined (USE_AVX512)
-
-      static_assert(OutputDimensions == 1 || OutputDimensions % 4 == 0);
-
-      // OutputDimensions is either 1 or a multiple of SimdWidth
-      // because then it is also an input dimension.
-      if constexpr (OutputDimensions <= 8 && OutputDimensions != 1)
-      {
-          constexpr IndexType NumChunks = PaddedInputDimensions / InputSimdWidth;
-
-          static_assert(NumChunks % 2 == 0);
-
-          const auto input_vec = reinterpret_cast<const vec_t*>(input);
-          const auto bias_vec = reinterpret_cast<const __m128i*>(biases);
-          auto out_vec = reinterpret_cast<__m128i*>(output);
-
-          vec_t regs[OutputDimensions];
-          for (IndexType k = 0; k < OutputDimensions; ++k)
-            regs[k] = vec_setzero();
-
-          for (IndexType i = 0; i < NumChunks / 2; ++i)
-          {
-              const vec_t in0 = input_vec[i * 2 + 0];
-              const vec_t in1 = input_vec[i * 2 + 1];
-              for (IndexType k = 0; k < OutputDimensions; ++k)
-              {
-                  const vec_t w0 = reinterpret_cast<const vec_t*>(&weights[k * PaddedInputDimensions])[i * 2 + 0];
-                  const vec_t w1 = reinterpret_cast<const vec_t*>(&weights[k * PaddedInputDimensions])[i * 2 + 1];
-                  vec_add_dpbusd_32(regs[k], in0, w0);
-                  vec_add_dpbusd_32(regs[k], in1, w1);
-              }
-          }
-
-          for (IndexType k = 0; k < OutputDimensions / 4; ++k)
-          {
-            out_vec[k] = vec_haddx4(
-              regs[k * 4 + 0],
-              regs[k * 4 + 1],
-              regs[k * 4 + 2],
-              regs[k * 4 + 3],
-              bias_vec[k]
-            );
-          }
-      }
-      else if constexpr (InputDimensions == 8)
-      {
-          const auto input32 = reinterpret_cast<const std::int32_t*>(input);
-          __m256i* outptr = reinterpret_cast<__m256i*>(output);
-          std::memcpy(output, biases, OutputDimensions * sizeof(OutputType));
-
-          const __m256i in0 = _mm256_set1_epi32(input32[0]);
-          const __m256i in1 = _mm256_set1_epi32(input32[1]);
-          const auto col0 = reinterpret_cast<const __m256i*>(&weights[0]);
-          const auto col1 = reinterpret_cast<const __m256i*>(&weights[OutputDimensions * 4]);
-          for (IndexType j = 0; j * 8 < OutputDimensions; ++j)
-              m256_add_dpbusd_epi32x2(outptr[j], in0, col0[j], in1, col1[j]);
-      }
-      else
 
-#elif defined (USE_SSSE3)
-
-      if constexpr (OutputDimensions % OutputSimdWidth == 0 && InputDimensions == 8)
-      {
-          const auto input32 = reinterpret_cast<const std::int32_t*>(input);
-          vec_t* outptr = reinterpret_cast<vec_t*>(output);
-          std::memcpy(output, biases, OutputDimensions * sizeof(OutputType));
-
-          const vec_t in0 = vec_set_32(input32[0]);
-          const vec_t in1 = vec_set_32(input32[1]);
-          const auto col0 = reinterpret_cast<const vec_t*>(&weights[0]);
-          const auto col1 = reinterpret_cast<const vec_t*>(&weights[OutputDimensions * 4]);
-          for (IndexType j = 0; j * OutputSimdWidth < OutputDimensions; ++j)
-              vec_add_dpbusd_32x2(outptr[j], in0, col0[j], in1, col1[j]);
-      }
-      else
-
-#endif
-
-#if defined (USE_SSSE3)
+      static_assert(InputDimensions % 8 == 0);
+      static_assert(OutputDimensions % OutputSimdWidth == 0 || OutputDimensions == 1);
 
       if constexpr (OutputDimensions % OutputSimdWidth == 0)
       {
-          static_assert(InputDimensions % 16 == 0);
-
-          constexpr IndexType NumChunks = InputDimensions / 4;
-          constexpr IndexType NumRegs = OutputDimensions / OutputSimdWidth;
-
-          const auto input32 = reinterpret_cast<const std::int32_t*>(input);
-          const vec_t* biasvec = reinterpret_cast<const vec_t*>(biases);
-          vec_t outs[NumRegs];
+        constexpr IndexType NumChunks = InputDimensions / 4;
+        constexpr IndexType NumRegs = OutputDimensions / OutputSimdWidth;
+
+        const auto input32 = reinterpret_cast<const std::int32_t*>(input);
+        const vec_t* biasvec = reinterpret_cast<const vec_t*>(biases);
+        vec_t acc[NumRegs];
+        for (IndexType k = 0; k < NumRegs; ++k)
+          acc[k] = biasvec[k];
+
+        for (IndexType i = 0; i < NumChunks; i += 2)
+        {
+          const vec_t in0 = vec_set_32(input32[i + 0]);
+          const vec_t in1 = vec_set_32(input32[i + 1]);
+          const auto col0 = reinterpret_cast<const vec_t*>(&weights[(i + 0) * OutputDimensions * 4]);
+          const auto col1 = reinterpret_cast<const vec_t*>(&weights[(i + 1) * OutputDimensions * 4]);
           for (IndexType k = 0; k < NumRegs; ++k)
-              outs[k] = biasvec[k];
-
-          for (IndexType i = 0; i < NumChunks; i += 4)
-          {
-              const vec_t in0 = vec_set_32(input32[i + 0]);
-              const vec_t in1 = vec_set_32(input32[i + 1]);
-              const vec_t in2 = vec_set_32(input32[i + 2]);
-              const vec_t in3 = vec_set_32(input32[i + 3]);
-              const auto col0 = reinterpret_cast<const vec_t*>(&weights[(i + 0) * OutputDimensions * 4]);
-              const auto col1 = reinterpret_cast<const vec_t*>(&weights[(i + 1) * OutputDimensions * 4]);
-              const auto col2 = reinterpret_cast<const vec_t*>(&weights[(i + 2) * OutputDimensions * 4]);
-              const auto col3 = reinterpret_cast<const vec_t*>(&weights[(i + 3) * OutputDimensions * 4]);
-              for (IndexType k = 0; k < NumRegs; ++k)
-                  vec_add_dpbusd_32x4(outs[k], in0, col0[k], in1, col1[k], in2, col2[k], in3, col3[k]);
-          }
+            vec_add_dpbusd_32x2(acc[k], in0, col0[k], in1, col1[k]);
+        }
 
-          vec_t* outptr = reinterpret_cast<vec_t*>(output);
-          for (IndexType k = 0; k < NumRegs; ++k)
-              outptr[k] = outs[k];
+        vec_t* outptr = reinterpret_cast<vec_t*>(output);
+        for (IndexType k = 0; k < NumRegs; ++k)
+          outptr[k] = acc[k];
       }
       else if constexpr (OutputDimensions == 1)
       {
-          static_assert(InputDimensions % 4 == 0);
-
-#if defined (USE_AVX512)
-          if constexpr (PaddedInputDimensions % (SimdWidth * 2) != 0)
-          {
-              constexpr IndexType NumChunks = PaddedInputDimensions / SimdWidth;
-              const auto inputVector256 = reinterpret_cast<const __m256i*>(input);
-
-              __m256i sum0 = _mm256_setzero_si256();
-              const auto row0 = reinterpret_cast<const __m256i*>(&weights[0]);
-
-              for (int j = 0; j < (int)NumChunks; ++j)
-              {
-                  const __m256i in = inputVector256[j];
-                  m256_add_dpbusd_epi32(sum0, in, row0[j]);
-              }
-              output[0] = m256_hadd(sum0, biases[0]);
-          }
-          else
-#endif
-          {
-#if defined (USE_AVX512)
-              constexpr IndexType NumChunks = PaddedInputDimensions / (SimdWidth * 2);
-#else
-              constexpr IndexType NumChunks = PaddedInputDimensions / SimdWidth;
-#endif
-              vec_t sum0 = vec_setzero();
-              const auto row0 = reinterpret_cast<const vec_t*>(&weights[0]);
-
-              for (int j = 0; j < (int)NumChunks; ++j)
-              {
-                  const vec_t in = inputVector[j];
-                  vec_add_dpbusd_32(sum0, in, row0[j]);
-              }
-              output[0] = vec_hadd(sum0, biases[0]);
-          }
-      }
-
-#else
-
-// Use old implementation for the other architectures.
-
-      auto output = reinterpret_cast<OutputType*>(buffer);
-
-#if defined(USE_SSE2)
-      // At least a multiple of 16, with SSE2.
-      static_assert(PaddedInputDimensions % SimdWidth == 0);
-      constexpr IndexType NumChunks = PaddedInputDimensions / SimdWidth;
-      const __m128i Zeros = _mm_setzero_si128();
-      const auto inputVector = reinterpret_cast<const __m128i*>(input);
-
-#elif defined(USE_MMX)
-      static_assert(InputDimensions % SimdWidth == 0);
-      constexpr IndexType NumChunks = InputDimensions / SimdWidth;
-      const __m64 Zeros = _mm_setzero_si64();
-      const auto inputVector = reinterpret_cast<const __m64*>(input);
-
-#elif defined(USE_NEON)
-      static_assert(PaddedInputDimensions % SimdWidth == 0);
-      constexpr IndexType NumChunks = PaddedInputDimensions / SimdWidth;
-      const auto inputVector = reinterpret_cast<const int8x8_t*>(input);
-#endif
-
-      for (IndexType i = 0; i < OutputDimensions; ++i) {
-        const IndexType offset = i * PaddedInputDimensions;
-
-#if defined(USE_SSE2)
-        __m128i sumLo = _mm_cvtsi32_si128(biases[i]);
-        __m128i sumHi = Zeros;
-        const auto row = reinterpret_cast<const __m128i*>(&weights[offset]);
-        for (IndexType j = 0; j < NumChunks; ++j) {
-          __m128i row_j = _mm_load_si128(&row[j]);
-          __m128i input_j = _mm_load_si128(&inputVector[j]);
-          __m128i extendedRowLo = _mm_srai_epi16(_mm_unpacklo_epi8(row_j, row_j), 8);
-          __m128i extendedRowHi = _mm_srai_epi16(_mm_unpackhi_epi8(row_j, row_j), 8);
-          __m128i extendedInputLo = _mm_unpacklo_epi8(input_j, Zeros);
-          __m128i extendedInputHi = _mm_unpackhi_epi8(input_j, Zeros);
-          __m128i productLo = _mm_madd_epi16(extendedRowLo, extendedInputLo);
-          __m128i productHi = _mm_madd_epi16(extendedRowHi, extendedInputHi);
-          sumLo = _mm_add_epi32(sumLo, productLo);
-          sumHi = _mm_add_epi32(sumHi, productHi);
-        }
-        __m128i sum = _mm_add_epi32(sumLo, sumHi);
-        __m128i sumHigh_64 = _mm_shuffle_epi32(sum, _MM_SHUFFLE(1, 0, 3, 2));
-        sum = _mm_add_epi32(sum, sumHigh_64);
-        __m128i sum_second_32 = _mm_shufflelo_epi16(sum, _MM_SHUFFLE(1, 0, 3, 2));
-        sum = _mm_add_epi32(sum, sum_second_32);
-        output[i] = _mm_cvtsi128_si32(sum);
-
-#elif defined(USE_MMX)
-        __m64 sumLo = _mm_cvtsi32_si64(biases[i]);
-        __m64 sumHi = Zeros;
-        const auto row = reinterpret_cast<const __m64*>(&weights[offset]);
-        for (IndexType j = 0; j < NumChunks; ++j) {
-          __m64 row_j = row[j];
-          __m64 input_j = inputVector[j];
-          __m64 extendedRowLo = _mm_srai_pi16(_mm_unpacklo_pi8(row_j, row_j), 8);
-          __m64 extendedRowHi = _mm_srai_pi16(_mm_unpackhi_pi8(row_j, row_j), 8);
-          __m64 extendedInputLo = _mm_unpacklo_pi8(input_j, Zeros);
-          __m64 extendedInputHi = _mm_unpackhi_pi8(input_j, Zeros);
-          __m64 productLo = _mm_madd_pi16(extendedRowLo, extendedInputLo);
-          __m64 productHi = _mm_madd_pi16(extendedRowHi, extendedInputHi);
-          sumLo = _mm_add_pi32(sumLo, productLo);
-          sumHi = _mm_add_pi32(sumHi, productHi);
-        }
-        __m64 sum = _mm_add_pi32(sumLo, sumHi);
-        sum = _mm_add_pi32(sum, _mm_unpackhi_pi32(sum, sum));
-        output[i] = _mm_cvtsi64_si32(sum);
-
-#elif defined(USE_NEON)
-        int32x4_t sum = {biases[i]};
-        const auto row = reinterpret_cast<const int8x8_t*>(&weights[offset]);
-        for (IndexType j = 0; j < NumChunks; ++j) {
-          int16x8_t product = vmull_s8(inputVector[j * 2], row[j * 2]);
-          product = vmlal_s8(product, inputVector[j * 2 + 1], row[j * 2 + 1]);
-          sum = vpadalq_s16(sum, product);
-        }
-        output[i] = sum[0] + sum[1] + sum[2] + sum[3];
-
-#else
-        OutputType sum = biases[i];
-        for (IndexType j = 0; j < InputDimensions; ++j) {
-          sum += weights[offset + j] * input[j];
+        constexpr IndexType NumChunks = PaddedInputDimensions / SimdWidth;
+        vec_t sum0 = vec_setzero();
+        const auto row0 = reinterpret_cast<const vec_t*>(&weights[0]);
+
+        for (int j = 0; j < (int)NumChunks; ++j)
+        {
+          const vec_t in = inputVector[j];
+          vec_add_dpbusd_32(sum0, in, row0[j]);
         }
-        output[i] = sum;
-#endif
-
+        output[0] = vec_hadd(sum0, biases[0]);
       }
-#if defined(USE_MMX)
-      _mm_empty();
-#endif
 
-#endif
-
-#if (!defined (USE_SSSE3) && defined (USE_SSE2)) || defined (USE_NEON)
-      static_assert(SimdWidth <= 16, "Otherwise we run outside of the padding for the output.");
-      if constexpr (SimdWidth > OutputDimensions && OutputDimensions != 1)
-          for (IndexType i = OutputDimensions; i < SimdWidth; ++i)
-            output[i] = 0;
+# undef vec_setzero
+# undef vec_set_32
+# undef vec_add_dpbusd_32
+# undef vec_add_dpbusd_32x2
+# undef vec_add_dpbusd_32x4
+# undef vec_hadd
+# undef vec_haddx4
+#else
+      // Use old implementation for the other architectures.
+      affine_transform_non_ssse3<
+        InputDimensions,
+        PaddedInputDimensions,
+        OutputDimensions>(output, weights, biases, input);
 #endif
 
       return output;
index 65455df4944324a12870ca29d68c9ff5e0b379b1..c6f3ccade7db51917dfa3e0bcf88540ffbff25e1 100644 (file)
@@ -35,9 +35,10 @@ namespace Stockfish::Eval::NNUE::Layers {
     static_assert(std::is_same<InputType, std::int32_t>::value, "");
 
     // Number of input/output dimensions
-    static constexpr IndexType InputDimensions =
-        PreviousLayer::OutputDimensions;
+    static constexpr IndexType InputDimensions = PreviousLayer::OutputDimensions;
     static constexpr IndexType OutputDimensions = InputDimensions;
+    static constexpr IndexType PaddedOutputDimensions =
+        ceil_to_multiple<IndexType>(OutputDimensions, 32);
 
     // Size of forward propagation buffer used in this layer
     static constexpr std::size_t SelfBufferSize =
@@ -179,6 +180,15 @@ namespace Stockfish::Eval::NNUE::Layers {
         output[i] = static_cast<OutputType>(
             std::max(0, std::min(127, input[i] >> WeightScaleBits)));
       }
+
+      // Affine transform layers expect that there is at least
+      // ceil_to_multiple(OutputDimensions, 32) initialized values.
+      // We cannot do this in the affine transform because it requires
+      // preallocating space here.
+      for (IndexType i = OutputDimensions; i < PaddedOutputDimensions; ++i) {
+        output[i] = 0;
+      }
+
       return output;
     }
 
diff --git a/src/simd.h b/src/simd.h
new file mode 100644 (file)
index 0000000..584148f
--- /dev/null
@@ -0,0 +1,341 @@
+/*
+  Stockfish, a UCI chess playing engine derived from Glaurung 2.1
+  Copyright (C) 2004-2021 The Stockfish developers (see AUTHORS file)
+
+  Stockfish is free software: you can redistribute it and/or modify
+  it under the terms of the GNU General Public License as published by
+  the Free Software Foundation, either version 3 of the License, or
+  (at your option) any later version.
+
+  Stockfish is distributed in the hope that it will be useful,
+  but WITHOUT ANY WARRANTY; without even the implied warranty of
+  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+  GNU General Public License for more details.
+
+  You should have received a copy of the GNU General Public License
+  along with this program.  If not, see <http://www.gnu.org/licenses/>.
+*/
+
+#ifndef STOCKFISH_SIMD_H_INCLUDED
+#define STOCKFISH_SIMD_H_INCLUDED
+
+#if defined(USE_AVX2)
+# include <immintrin.h>
+
+#elif defined(USE_SSE41)
+# include <smmintrin.h>
+
+#elif defined(USE_SSSE3)
+# include <tmmintrin.h>
+
+#elif defined(USE_SSE2)
+# include <emmintrin.h>
+
+#elif defined(USE_MMX)
+# include <mmintrin.h>
+
+#elif defined(USE_NEON)
+# include <arm_neon.h>
+#endif
+
+// The inline asm is only safe for GCC, where it is necessary to get good codegen.
+// See https://gcc.gnu.org/bugzilla/show_bug.cgi?id=101693
+// Clang does fine without it.
+// Play around here: https://godbolt.org/z/7EWqrYq51
+#if (defined(__GNUC__) && !defined(__clang__) && !defined(__INTEL_COMPILER))
+#define USE_INLINE_ASM
+#endif
+
+namespace Stockfish::Simd {
+
+#if defined (USE_AVX512)
+
+    [[maybe_unused]] static int m512_hadd(__m512i sum, int bias) {
+      return _mm512_reduce_add_epi32(sum) + bias;
+    }
+
+    /*
+      Parameters:
+        sum0 = [zmm0.i128[0], zmm0.i128[1], zmm0.i128[2], zmm0.i128[3]]
+        sum1 = [zmm1.i128[0], zmm1.i128[1], zmm1.i128[2], zmm1.i128[3]]
+        sum2 = [zmm2.i128[0], zmm2.i128[1], zmm2.i128[2], zmm2.i128[3]]
+        sum3 = [zmm3.i128[0], zmm3.i128[1], zmm3.i128[2], zmm3.i128[3]]
+
+      Returns:
+        ret = [
+          reduce_add_epi32(zmm0.i128[0]), reduce_add_epi32(zmm1.i128[0]), reduce_add_epi32(zmm2.i128[0]), reduce_add_epi32(zmm3.i128[0]),
+          reduce_add_epi32(zmm0.i128[1]), reduce_add_epi32(zmm1.i128[1]), reduce_add_epi32(zmm2.i128[1]), reduce_add_epi32(zmm3.i128[1]),
+          reduce_add_epi32(zmm0.i128[2]), reduce_add_epi32(zmm1.i128[2]), reduce_add_epi32(zmm2.i128[2]), reduce_add_epi32(zmm3.i128[2]),
+          reduce_add_epi32(zmm0.i128[3]), reduce_add_epi32(zmm1.i128[3]), reduce_add_epi32(zmm2.i128[3]), reduce_add_epi32(zmm3.i128[3])
+        ]
+    */
+    [[maybe_unused]] static __m512i m512_hadd128x16_interleave(
+        __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3) {
+
+      __m512i sum01a = _mm512_unpacklo_epi32(sum0, sum1);
+      __m512i sum01b = _mm512_unpackhi_epi32(sum0, sum1);
+
+      __m512i sum23a = _mm512_unpacklo_epi32(sum2, sum3);
+      __m512i sum23b = _mm512_unpackhi_epi32(sum2, sum3);
+
+      __m512i sum01 = _mm512_add_epi32(sum01a, sum01b);
+      __m512i sum23 = _mm512_add_epi32(sum23a, sum23b);
+
+      __m512i sum0123a = _mm512_unpacklo_epi64(sum01, sum23);
+      __m512i sum0123b = _mm512_unpackhi_epi64(sum01, sum23);
+
+      return _mm512_add_epi32(sum0123a, sum0123b);
+    }
+
+    [[maybe_unused]] static __m128i m512_haddx4(
+        __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3,
+        __m128i bias) {
+
+      __m512i sum = m512_hadd128x16_interleave(sum0, sum1, sum2, sum3);
+
+      __m256i sum256lo = _mm512_castsi512_si256(sum);
+      __m256i sum256hi = _mm512_extracti64x4_epi64(sum, 1);
+
+      sum256lo = _mm256_add_epi32(sum256lo, sum256hi);
+
+      __m128i sum128lo = _mm256_castsi256_si128(sum256lo);
+      __m128i sum128hi = _mm256_extracti128_si256(sum256lo, 1);
+
+      return _mm_add_epi32(_mm_add_epi32(sum128lo, sum128hi), bias);
+    }
+
+    [[maybe_unused]] static void m512_add_dpbusd_epi32(
+        __m512i& acc,
+        __m512i a,
+        __m512i b) {
+
+# if defined (USE_VNNI)
+#   if defined (USE_INLINE_ASM)
+      asm(
+        "vpdpbusd %[b], %[a], %[acc]\n\t"
+        : [acc]"+v"(acc)
+        : [a]"v"(a), [b]"vm"(b)
+      );
+#   else
+      acc = _mm512_dpbusd_epi32(acc, a, b);
+#   endif
+# else
+#   if defined (USE_INLINE_ASM)
+      __m512i tmp = _mm512_maddubs_epi16(a, b);
+      asm(
+          "vpmaddwd    %[tmp], %[ones], %[tmp]\n\t"
+          "vpaddd      %[acc], %[tmp], %[acc]\n\t"
+          : [acc]"+v"(acc), [tmp]"+&v"(tmp)
+          : [ones]"v"(_mm512_set1_epi16(1))
+      );
+#   else
+      __m512i product0 = _mm512_maddubs_epi16(a, b);
+      product0 = _mm512_madd_epi16(product0, _mm512_set1_epi16(1));
+      acc = _mm512_add_epi32(acc, product0);
+#   endif
+# endif
+    }
+
+    [[maybe_unused]] static void m512_add_dpbusd_epi32x2(
+        __m512i& acc,
+        __m512i a0, __m512i b0,
+        __m512i a1, __m512i b1) {
+
+# if defined (USE_VNNI)
+#   if defined (USE_INLINE_ASM)
+      asm(
+        "vpdpbusd %[b0], %[a0], %[acc]\n\t"
+        "vpdpbusd %[b1], %[a1], %[acc]\n\t"
+        : [acc]"+v"(acc)
+        : [a0]"v"(a0), [b0]"vm"(b0), [a1]"v"(a1), [b1]"vm"(b1)
+      );
+#   else
+      acc = _mm512_dpbusd_epi32(acc, a0, b0);
+      acc = _mm512_dpbusd_epi32(acc, a1, b1);
+#   endif
+# else
+#   if defined (USE_INLINE_ASM)
+      __m512i tmp0 = _mm512_maddubs_epi16(a0, b0);
+      __m512i tmp1 = _mm512_maddubs_epi16(a1, b1);
+      asm(
+          "vpaddsw     %[tmp0], %[tmp1], %[tmp0]\n\t"
+          "vpmaddwd    %[tmp0], %[ones], %[tmp0]\n\t"
+          "vpaddd      %[acc], %[tmp0], %[acc]\n\t"
+          : [acc]"+v"(acc), [tmp0]"+&v"(tmp0)
+          : [tmp1]"v"(tmp1), [ones]"v"(_mm512_set1_epi16(1))
+      );
+#   else
+      __m512i product0 = _mm512_maddubs_epi16(a0, b0);
+      __m512i product1 = _mm512_maddubs_epi16(a1, b1);
+      product0 = _mm512_adds_epi16(product0, product1);
+      product0 = _mm512_madd_epi16(product0, _mm512_set1_epi16(1));
+      acc = _mm512_add_epi32(acc, product0);
+#   endif
+# endif
+    }
+
+#endif
+
+#if defined (USE_AVX2)
+
+    [[maybe_unused]] static int m256_hadd(__m256i sum, int bias) {
+      __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(sum), _mm256_extracti128_si256(sum, 1));
+      sum128 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, _MM_PERM_BADC));
+      sum128 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, _MM_PERM_CDAB));
+      return _mm_cvtsi128_si32(sum128) + bias;
+    }
+
+    [[maybe_unused]] static __m128i m256_haddx4(
+        __m256i sum0, __m256i sum1, __m256i sum2, __m256i sum3,
+        __m128i bias) {
+
+      sum0 = _mm256_hadd_epi32(sum0, sum1);
+      sum2 = _mm256_hadd_epi32(sum2, sum3);
+
+      sum0 = _mm256_hadd_epi32(sum0, sum2);
+
+      __m128i sum128lo = _mm256_castsi256_si128(sum0);
+      __m128i sum128hi = _mm256_extracti128_si256(sum0, 1);
+
+      return _mm_add_epi32(_mm_add_epi32(sum128lo, sum128hi), bias);
+    }
+
+    [[maybe_unused]] static void m256_add_dpbusd_epi32(
+        __m256i& acc,
+        __m256i a,
+        __m256i b) {
+
+# if defined (USE_VNNI)
+#   if defined (USE_INLINE_ASM)
+      asm(
+        "vpdpbusd %[b], %[a], %[acc]\n\t"
+        : [acc]"+v"(acc)
+        : [a]"v"(a), [b]"vm"(b)
+      );
+#   else
+      acc = _mm256_dpbusd_epi32(acc, a, b);
+#   endif
+# else
+#   if defined (USE_INLINE_ASM)
+      __m256i tmp = _mm256_maddubs_epi16(a, b);
+      asm(
+          "vpmaddwd    %[tmp], %[ones], %[tmp]\n\t"
+          "vpaddd      %[acc], %[tmp], %[acc]\n\t"
+          : [acc]"+v"(acc), [tmp]"+&v"(tmp)
+          : [ones]"v"(_mm256_set1_epi16(1))
+      );
+#   else
+      __m256i product0 = _mm256_maddubs_epi16(a, b);
+      product0 = _mm256_madd_epi16(product0, _mm256_set1_epi16(1));
+      acc = _mm256_add_epi32(acc, product0);
+#   endif
+# endif
+    }
+
+    [[maybe_unused]] static void m256_add_dpbusd_epi32x2(
+        __m256i& acc,
+        __m256i a0, __m256i b0,
+        __m256i a1, __m256i b1) {
+
+# if defined (USE_VNNI)
+#   if defined (USE_INLINE_ASM)
+      asm(
+        "vpdpbusd %[b0], %[a0], %[acc]\n\t"
+        "vpdpbusd %[b1], %[a1], %[acc]\n\t"
+        : [acc]"+v"(acc)
+        : [a0]"v"(a0), [b0]"vm"(b0), [a1]"v"(a1), [b1]"vm"(b1)
+      );
+#   else
+      acc = _mm256_dpbusd_epi32(acc, a0, b0);
+      acc = _mm256_dpbusd_epi32(acc, a1, b1);
+#   endif
+# else
+#   if defined (USE_INLINE_ASM)
+      __m256i tmp0 = _mm256_maddubs_epi16(a0, b0);
+      __m256i tmp1 = _mm256_maddubs_epi16(a1, b1);
+      asm(
+          "vpaddsw     %[tmp0], %[tmp1], %[tmp0]\n\t"
+          "vpmaddwd    %[tmp0], %[ones], %[tmp0]\n\t"
+          "vpaddd      %[acc], %[tmp0], %[acc]\n\t"
+          : [acc]"+v"(acc), [tmp0]"+&v"(tmp0)
+          : [tmp1]"v"(tmp1), [ones]"v"(_mm256_set1_epi16(1))
+      );
+#   else
+      __m256i product0 = _mm256_maddubs_epi16(a0, b0);
+      __m256i product1 = _mm256_maddubs_epi16(a1, b1);
+      product0 = _mm256_adds_epi16(product0, product1);
+      product0 = _mm256_madd_epi16(product0, _mm256_set1_epi16(1));
+      acc = _mm256_add_epi32(acc, product0);
+#   endif
+# endif
+    }
+
+#endif
+
+#if defined (USE_SSSE3)
+
+    [[maybe_unused]] static int m128_hadd(__m128i sum, int bias) {
+      sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, 0x4E)); //_MM_PERM_BADC
+      sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, 0xB1)); //_MM_PERM_CDAB
+      return _mm_cvtsi128_si32(sum) + bias;
+    }
+
+    [[maybe_unused]] static __m128i m128_haddx4(
+        __m128i sum0, __m128i sum1, __m128i sum2, __m128i sum3,
+        __m128i bias) {
+
+      sum0 = _mm_hadd_epi32(sum0, sum1);
+      sum2 = _mm_hadd_epi32(sum2, sum3);
+      sum0 = _mm_hadd_epi32(sum0, sum2);
+      return _mm_add_epi32(sum0, bias);
+    }
+
+    [[maybe_unused]] static void m128_add_dpbusd_epi32(
+        __m128i& acc,
+        __m128i a,
+        __m128i b) {
+
+#   if defined (USE_INLINE_ASM)
+      __m128i tmp = _mm_maddubs_epi16(a, b);
+      asm(
+          "pmaddwd    %[ones], %[tmp]\n\t"
+          "paddd      %[tmp], %[acc]\n\t"
+          : [acc]"+v"(acc), [tmp]"+&v"(tmp)
+          : [ones]"v"(_mm_set1_epi16(1))
+      );
+#   else
+      __m128i product0 = _mm_maddubs_epi16(a, b);
+      product0 = _mm_madd_epi16(product0, _mm_set1_epi16(1));
+      acc = _mm_add_epi32(acc, product0);
+#   endif
+    }
+
+    [[maybe_unused]] static void m128_add_dpbusd_epi32x2(
+        __m128i& acc,
+        __m128i a0, __m128i b0,
+        __m128i a1, __m128i b1) {
+
+#   if defined (USE_INLINE_ASM)
+      __m128i tmp0 = _mm_maddubs_epi16(a0, b0);
+      __m128i tmp1 = _mm_maddubs_epi16(a1, b1);
+      asm(
+          "paddsw     %[tmp1], %[tmp0]\n\t"
+          "pmaddwd    %[ones], %[tmp0]\n\t"
+          "paddd      %[tmp0], %[acc]\n\t"
+          : [acc]"+v"(acc), [tmp0]"+&v"(tmp0)
+          : [tmp1]"v"(tmp1), [ones]"v"(_mm_set1_epi16(1))
+      );
+#   else
+      __m128i product0 = _mm_maddubs_epi16(a0, b0);
+      __m128i product1 = _mm_maddubs_epi16(a1, b1);
+      product0 = _mm_adds_epi16(product0, product1);
+      product0 = _mm_madd_epi16(product0, _mm_set1_epi16(1));
+      acc = _mm_add_epi32(acc, product0);
+#   endif
+    }
+
+#endif
+
+}
+
+#endif // STOCKFISH_SIMD_H_INCLUDED