Use arithmetic right shift for sign extension in MMX and SSE2 paths
authorFanael Linithien <fanael4@gmail.com>
Mon, 7 Dec 2020 13:46:29 +0000 (14:46 +0100)
committerJoost VandeVondele <Joost.VandeVondele@gmail.com>
Sat, 12 Dec 2020 08:20:15 +0000 (09:20 +0100)
This appears to be slightly faster than using a comparison against zero
to compute the high bits, on both old (like Pentium III) and new (like
Zen 2) hardware.

closes https://github.com/official-stockfish/Stockfish/pull/3254

No functional change.

src/nnue/layers/affine_transform.h

index caf315b2792897df8b206c57aa718cb8331ec496..0e0515f932a0773cc82f72c1620bc0a1afe5e5eb 100644 (file)
@@ -680,9 +680,8 @@ namespace Eval::NNUE::Layers {
         for (IndexType j = 0; j < kNumChunks; ++j) {
           __m128i row_j = _mm_load_si128(&row[j]);
           __m128i input_j = _mm_load_si128(&input_vector[j]);
-          __m128i row_signs = _mm_cmpgt_epi8(kZeros, row_j);
-          __m128i extended_row_lo = _mm_unpacklo_epi8(row_j, row_signs);
-          __m128i extended_row_hi = _mm_unpackhi_epi8(row_j, row_signs);
+          __m128i extended_row_lo = _mm_srai_epi16(_mm_unpacklo_epi8(row_j, row_j), 8);
+          __m128i extended_row_hi = _mm_srai_epi16(_mm_unpackhi_epi8(row_j, row_j), 8);
           __m128i extended_input_lo = _mm_unpacklo_epi8(input_j, kZeros);
           __m128i extended_input_hi = _mm_unpackhi_epi8(input_j, kZeros);
           __m128i product_lo = _mm_madd_epi16(extended_row_lo, extended_input_lo);
@@ -704,9 +703,8 @@ namespace Eval::NNUE::Layers {
         for (IndexType j = 0; j < kNumChunks; ++j) {
           __m64 row_j = row[j];
           __m64 input_j = input_vector[j];
-          __m64 row_signs = _mm_cmpgt_pi8(kZeros, row_j);
-          __m64 extended_row_lo = _mm_unpacklo_pi8(row_j, row_signs);
-          __m64 extended_row_hi = _mm_unpackhi_pi8(row_j, row_signs);
+          __m64 extended_row_lo = _mm_srai_pi16(_mm_unpacklo_pi8(row_j, row_j), 8);
+          __m64 extended_row_hi = _mm_srai_pi16(_mm_unpackhi_pi8(row_j, row_j), 8);
           __m64 extended_input_lo = _mm_unpacklo_pi8(input_j, kZeros);
           __m64 extended_input_hi = _mm_unpackhi_pi8(input_j, kZeros);
           __m64 product_lo = _mm_madd_pi16(extended_row_lo, extended_input_lo);