From c7f0a768cb9d5972861baae0f215d69f9e86a626 Mon Sep 17 00:00:00 2001 From: Fanael Linithien Date: Mon, 7 Dec 2020 14:46:29 +0100 Subject: [PATCH] Use arithmetic right shift for sign extension in MMX and SSE2 paths This appears to be slightly faster than using a comparison against zero to compute the high bits, on both old (like Pentium III) and new (like Zen 2) hardware. closes https://github.com/official-stockfish/Stockfish/pull/3254 No functional change. --- src/nnue/layers/affine_transform.h | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/nnue/layers/affine_transform.h b/src/nnue/layers/affine_transform.h index caf315b2..0e0515f9 100644 --- a/src/nnue/layers/affine_transform.h +++ b/src/nnue/layers/affine_transform.h @@ -680,9 +680,8 @@ namespace Eval::NNUE::Layers { for (IndexType j = 0; j < kNumChunks; ++j) { __m128i row_j = _mm_load_si128(&row[j]); __m128i input_j = _mm_load_si128(&input_vector[j]); - __m128i row_signs = _mm_cmpgt_epi8(kZeros, row_j); - __m128i extended_row_lo = _mm_unpacklo_epi8(row_j, row_signs); - __m128i extended_row_hi = _mm_unpackhi_epi8(row_j, row_signs); + __m128i extended_row_lo = _mm_srai_epi16(_mm_unpacklo_epi8(row_j, row_j), 8); + __m128i extended_row_hi = _mm_srai_epi16(_mm_unpackhi_epi8(row_j, row_j), 8); __m128i extended_input_lo = _mm_unpacklo_epi8(input_j, kZeros); __m128i extended_input_hi = _mm_unpackhi_epi8(input_j, kZeros); __m128i product_lo = _mm_madd_epi16(extended_row_lo, extended_input_lo); @@ -704,9 +703,8 @@ namespace Eval::NNUE::Layers { for (IndexType j = 0; j < kNumChunks; ++j) { __m64 row_j = row[j]; __m64 input_j = input_vector[j]; - __m64 row_signs = _mm_cmpgt_pi8(kZeros, row_j); - __m64 extended_row_lo = _mm_unpacklo_pi8(row_j, row_signs); - __m64 extended_row_hi = _mm_unpackhi_pi8(row_j, row_signs); + __m64 extended_row_lo = _mm_srai_pi16(_mm_unpacklo_pi8(row_j, row_j), 8); + __m64 extended_row_hi = _mm_srai_pi16(_mm_unpackhi_pi8(row_j, row_j), 8); __m64 extended_input_lo = _mm_unpacklo_pi8(input_j, kZeros); __m64 extended_input_hi = _mm_unpackhi_pi8(input_j, kZeros); __m64 product_lo = _mm_madd_pi16(extended_row_lo, extended_input_lo); -- 2.39.2