]> git.sesse.net Git - stockfish/blobdiff - src/nnue/layers/affine_transform.h
Avoid unnecessary stores in the affine transform
[stockfish] / src / nnue / layers / affine_transform.h
index 9a3b778e6bbbedec7cb8b6d409c5d226f7569206..9a5f62c0a098691322f911fee4d212d12a199685 100644 (file)
@@ -251,9 +251,6 @@ namespace Stockfish::Eval::NNUE::Layers {
 #endif
 
 #if defined (USE_SSSE3)
-      // Different layout, we process 4 inputs at a time, always.
-      static_assert(InputDimensions % 4 == 0);
-
       const auto output = reinterpret_cast<OutputType*>(buffer);
       const auto inputVector = reinterpret_cast<const vec_t*>(input);
 
@@ -263,13 +260,18 @@ namespace Stockfish::Eval::NNUE::Layers {
       // because then it is also an input dimension.
       if constexpr (OutputDimensions % OutputSimdWidth == 0)
       {
+          static_assert(InputDimensions % 16 == 0);
+
           constexpr IndexType NumChunks = InputDimensions / 4;
+          constexpr IndexType NumRegs = OutputDimensions / OutputSimdWidth;
 
           const auto input32 = reinterpret_cast<const std::int32_t*>(input);
-          vec_t* outptr = reinterpret_cast<vec_t*>(output);
-          std::memcpy(output, biases, OutputDimensions * sizeof(OutputType));
+          const vec_t* biasvec = reinterpret_cast<const vec_t*>(biases);
+          vec_t outs[NumRegs];
+          for (IndexType k = 0; k < NumRegs; ++k)
+              outs[k] = biasvec[k];
 
-          for (int i = 0; i < (int)NumChunks - 3; i += 4)
+          for (IndexType i = 0; i < NumChunks; i += 4)
           {
               const vec_t in0 = vec_set_32(input32[i + 0]);
               const vec_t in1 = vec_set_32(input32[i + 1]);
@@ -279,12 +281,18 @@ namespace Stockfish::Eval::NNUE::Layers {
               const auto col1 = reinterpret_cast<const vec_t*>(&weights[(i + 1) * OutputDimensions * 4]);
               const auto col2 = reinterpret_cast<const vec_t*>(&weights[(i + 2) * OutputDimensions * 4]);
               const auto col3 = reinterpret_cast<const vec_t*>(&weights[(i + 3) * OutputDimensions * 4]);
-              for (int j = 0; j * OutputSimdWidth < OutputDimensions; ++j)
-                  vec_add_dpbusd_32x4(outptr[j], in0, col0[j], in1, col1[j], in2, col2[j], in3, col3[j]);
+              for (IndexType k = 0; k < NumRegs; ++k)
+                  vec_add_dpbusd_32x4(outs[k], in0, col0[k], in1, col1[k], in2, col2[k], in3, col3[k]);
           }
+
+          vec_t* outptr = reinterpret_cast<vec_t*>(output);
+          for (IndexType k = 0; k < NumRegs; ++k)
+              outptr[k] = outs[k];
       }
       else if constexpr (OutputDimensions == 1)
       {
+          static_assert(InputDimensions % 4 == 0);
+
 #if defined (USE_AVX512)
           if constexpr (PaddedInputDimensions % (SimdWidth * 2) != 0)
           {