X-Git-Url: https://git.sesse.net/?a=blobdiff_plain;f=src%2Fnnue%2Flayers%2Faffine_transform.h;h=9a5f62c0a098691322f911fee4d212d12a199685;hb=26edf9534ad571a6d26bf9db47d21776cbf45d54;hp=9a3b778e6bbbedec7cb8b6d409c5d226f7569206;hpb=e973eee919a8d450f095d102a0d52c196a8e7793;p=stockfish diff --git a/src/nnue/layers/affine_transform.h b/src/nnue/layers/affine_transform.h index 9a3b778e..9a5f62c0 100644 --- a/src/nnue/layers/affine_transform.h +++ b/src/nnue/layers/affine_transform.h @@ -251,9 +251,6 @@ namespace Stockfish::Eval::NNUE::Layers { #endif #if defined (USE_SSSE3) - // Different layout, we process 4 inputs at a time, always. - static_assert(InputDimensions % 4 == 0); - const auto output = reinterpret_cast(buffer); const auto inputVector = reinterpret_cast(input); @@ -263,13 +260,18 @@ namespace Stockfish::Eval::NNUE::Layers { // because then it is also an input dimension. if constexpr (OutputDimensions % OutputSimdWidth == 0) { + static_assert(InputDimensions % 16 == 0); + constexpr IndexType NumChunks = InputDimensions / 4; + constexpr IndexType NumRegs = OutputDimensions / OutputSimdWidth; const auto input32 = reinterpret_cast(input); - vec_t* outptr = reinterpret_cast(output); - std::memcpy(output, biases, OutputDimensions * sizeof(OutputType)); + const vec_t* biasvec = reinterpret_cast(biases); + vec_t outs[NumRegs]; + for (IndexType k = 0; k < NumRegs; ++k) + outs[k] = biasvec[k]; - for (int i = 0; i < (int)NumChunks - 3; i += 4) + for (IndexType i = 0; i < NumChunks; i += 4) { const vec_t in0 = vec_set_32(input32[i + 0]); const vec_t in1 = vec_set_32(input32[i + 1]); @@ -279,12 +281,18 @@ namespace Stockfish::Eval::NNUE::Layers { const auto col1 = reinterpret_cast(&weights[(i + 1) * OutputDimensions * 4]); const auto col2 = reinterpret_cast(&weights[(i + 2) * OutputDimensions * 4]); const auto col3 = reinterpret_cast(&weights[(i + 3) * OutputDimensions * 4]); - for (int j = 0; j * OutputSimdWidth < OutputDimensions; ++j) - vec_add_dpbusd_32x4(outptr[j], in0, col0[j], in1, col1[j], in2, col2[j], in3, col3[j]); + for (IndexType k = 0; k < NumRegs; ++k) + vec_add_dpbusd_32x4(outs[k], in0, col0[k], in1, col1[k], in2, col2[k], in3, col3[k]); } + + vec_t* outptr = reinterpret_cast(output); + for (IndexType k = 0; k < NumRegs; ++k) + outptr[k] = outs[k]; } else if constexpr (OutputDimensions == 1) { + static_assert(InputDimensions % 4 == 0); + #if defined (USE_AVX512) if constexpr (PaddedInputDimensions % (SimdWidth * 2) != 0) {