/*
  Stockfish, a UCI chess playing engine derived from Glaurung 2.1
  Copyright (C) 2004-2023 The Stockfish developers (see AUTHORS file)

  Stockfish is free software: you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation, either version 3 of the License, or
  (at your option) any later version.

  Stockfish is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

#ifndef STOCKFISH_SIMD_H_INCLUDED
#define STOCKFISH_SIMD_H_INCLUDED

#if defined(USE_AVX2)
# include <immintrin.h>

#elif defined(USE_SSE41)
# include <smmintrin.h>

#elif defined(USE_SSSE3)
# include <tmmintrin.h>

#elif defined(USE_SSE2)
# include <emmintrin.h>

#elif defined(USE_MMX)
# include <mmintrin.h>

#elif defined(USE_NEON)
# include <arm_neon.h>
#endif

// The inline asm is only safe for GCC, where it is necessary to get good codegen.
// See https://gcc.gnu.org/bugzilla/show_bug.cgi?id=101693
// Clang does fine without it.
// Play around here: https://godbolt.org/z/7EWqrYq51
#if (defined(__GNUC__) && !defined(__clang__) && !defined(__INTEL_COMPILER))
#define USE_INLINE_ASM
#endif

// Use either the AVX512 or AVX-VNNI version of the VNNI instructions.
#if defined(USE_AVXVNNI)
#define VNNI_PREFIX "%{vex%} "
#else
#define VNNI_PREFIX ""
#endif
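
// Note: after asm %-escaping, "%{vex%}" emits the "{vex}" pseudo-prefix, which
// tells the assembler to use the VEX-encoded (AVX-VNNI) form of vpdpbusd
// rather than the EVEX-encoded (AVX512-VNNI) form, so the 256-bit code below
// runs on CPUs without AVX-512.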

namespace Stockfish::Simd {

#if defined (USE_AVX512)

    [[maybe_unused]] static int m512_hadd(__m512i sum, int bias) {
      return _mm512_reduce_add_epi32(sum) + bias;
    }
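
    // For example (a small sanity sketch): all sixteen lanes set to 1 with
    // bias 3 gives 16 + 3 = 19.
    //
    //   int r = m512_hadd(_mm512_set1_epi32(1), 3);  // r == 19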

    /*
      Parameters:
        sum0 = [zmm0.i128[0], zmm0.i128[1], zmm0.i128[2], zmm0.i128[3]]
        sum1 = [zmm1.i128[0], zmm1.i128[1], zmm1.i128[2], zmm1.i128[3]]
        sum2 = [zmm2.i128[0], zmm2.i128[1], zmm2.i128[2], zmm2.i128[3]]
        sum3 = [zmm3.i128[0], zmm3.i128[1], zmm3.i128[2], zmm3.i128[3]]

      Returns:
        ret = [
          reduce_add_epi32(zmm0.i128[0]), reduce_add_epi32(zmm1.i128[0]), reduce_add_epi32(zmm2.i128[0]), reduce_add_epi32(zmm3.i128[0]),
          reduce_add_epi32(zmm0.i128[1]), reduce_add_epi32(zmm1.i128[1]), reduce_add_epi32(zmm2.i128[1]), reduce_add_epi32(zmm3.i128[1]),
          reduce_add_epi32(zmm0.i128[2]), reduce_add_epi32(zmm1.i128[2]), reduce_add_epi32(zmm2.i128[2]), reduce_add_epi32(zmm3.i128[2]),
          reduce_add_epi32(zmm0.i128[3]), reduce_add_epi32(zmm1.i128[3]), reduce_add_epi32(zmm2.i128[3]), reduce_add_epi32(zmm3.i128[3])
        ]
    */
    [[maybe_unused]] static __m512i m512_hadd128x16_interleave(
        __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3) {

      __m512i sum01a = _mm512_unpacklo_epi32(sum0, sum1);
      __m512i sum01b = _mm512_unpackhi_epi32(sum0, sum1);

      __m512i sum23a = _mm512_unpacklo_epi32(sum2, sum3);
      __m512i sum23b = _mm512_unpackhi_epi32(sum2, sum3);

      __m512i sum01 = _mm512_add_epi32(sum01a, sum01b);
      __m512i sum23 = _mm512_add_epi32(sum23a, sum23b);

      __m512i sum0123a = _mm512_unpacklo_epi64(sum01, sum23);
      __m512i sum0123b = _mm512_unpackhi_epi64(sum01, sum23);

      return _mm512_add_epi32(sum0123a, sum0123b);
    }
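
    // m512_haddx4 reduces four 512-bit accumulators to a single __m128i
    // holding [hadd(sum0), hadd(sum1), hadd(sum2), hadd(sum3)] + bias.
    // The interleave above performs the per-128-bit-lane sums; what remains
    // is folding the four 128-bit lanes of the interleaved result on top of
    // each other (512 -> 256 -> 128).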
    [[maybe_unused]] static __m128i m512_haddx4(
        __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3,
        __m128i bias) {

      __m512i sum = m512_hadd128x16_interleave(sum0, sum1, sum2, sum3);

      __m256i sum256lo = _mm512_castsi512_si256(sum);
      __m256i sum256hi = _mm512_extracti64x4_epi64(sum, 1);

      sum256lo = _mm256_add_epi32(sum256lo, sum256hi);

      __m128i sum128lo = _mm256_castsi256_si128(sum256lo);
      __m128i sum128hi = _mm256_extracti128_si256(sum256lo, 1);

      return _mm_add_epi32(_mm_add_epi32(sum128lo, sum128hi), bias);
    }
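
    // m512_add_dpbusd_epi32 computes, for each 32-bit lane i (a scalar
    // sketch, with a holding unsigned bytes and b signed bytes, as in
    // VPDPBUSD):
    //
    //   for (int j = 0; j < 4; ++j)
    //       acc.i32[i] += int(a.u8[4 * i + j]) * int(b.i8[4 * i + j]);
    //
    // Caveat: the non-VNNI fallback goes through maddubs, whose pairwise
    // int16 sums saturate, so it can differ from true VNNI on extreme inputs.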
    [[maybe_unused]] static void m512_add_dpbusd_epi32(
        __m512i& acc,
        __m512i a,
        __m512i b) {
# if defined (USE_VNNI)
#   if defined (USE_INLINE_ASM)
      asm(
        "vpdpbusd %[b], %[a], %[acc]\n\t"
        : [acc]"+v"(acc)
        : [a]"v"(a), [b]"vm"(b)
      );
#   else
      acc = _mm512_dpbusd_epi32(acc, a, b);
#   endif
# else
#   if defined (USE_INLINE_ASM)
      __m512i tmp = _mm512_maddubs_epi16(a, b);
      asm(
          "vpmaddwd %[tmp], %[ones], %[tmp]\n\t"
          "vpaddd %[acc], %[tmp], %[acc]\n\t"
          : [acc]"+v"(acc), [tmp]"+&v"(tmp)
          : [ones]"v"(_mm512_set1_epi16(1))
      );
#   else
      __m512i product0 = _mm512_maddubs_epi16(a, b);
      product0 = _mm512_madd_epi16(product0, _mm512_set1_epi16(1));
      acc = _mm512_add_epi32(acc, product0);
#   endif
# endif
    }
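
    // The x2 variant accumulates two products per call. It matches calling
    // m512_add_dpbusd_epi32 twice, but the fused form lets the fallback share
    // the set1_epi16(1) constant and pair up the madd/add work.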
    [[maybe_unused]] static void m512_add_dpbusd_epi32x2(
        __m512i& acc,
        __m512i a0, __m512i b0,
        __m512i a1, __m512i b1) {

# if defined (USE_VNNI)
#   if defined (USE_INLINE_ASM)
      asm(
        "vpdpbusd %[b0], %[a0], %[acc]\n\t"
        "vpdpbusd %[b1], %[a1], %[acc]\n\t"
        : [acc]"+v"(acc)
        : [a0]"v"(a0), [b0]"vm"(b0), [a1]"v"(a1), [b1]"vm"(b1)
      );
#   else
      acc = _mm512_dpbusd_epi32(acc, a0, b0);
      acc = _mm512_dpbusd_epi32(acc, a1, b1);
#   endif
# else
#   if defined (USE_INLINE_ASM)
      __m512i tmp0 = _mm512_maddubs_epi16(a0, b0);
      __m512i tmp1 = _mm512_maddubs_epi16(a1, b1);
      asm(
          "vpmaddwd %[tmp0], %[ones], %[tmp0]\n\t"
          "vpmaddwd %[tmp1], %[ones], %[tmp1]\n\t"
          "vpaddd %[tmp0], %[tmp1], %[tmp0]\n\t"
          "vpaddd %[acc], %[tmp0], %[acc]\n\t"
          : [acc]"+v"(acc), [tmp0]"+&v"(tmp0), [tmp1]"+&v"(tmp1)
          : [ones]"v"(_mm512_set1_epi16(1))
      );
#   else
      __m512i product0 = _mm512_maddubs_epi16(a0, b0);
      __m512i product1 = _mm512_maddubs_epi16(a1, b1);
      product0 = _mm512_madd_epi16(product0, _mm512_set1_epi16(1));
      product1 = _mm512_madd_epi16(product1, _mm512_set1_epi16(1));
      acc = _mm512_add_epi32(acc, _mm512_add_epi32(product0, product1));
#   endif
# endif
    }

#endif
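
// The AVX2 helpers below mirror the AVX-512 ones above at 256-bit width.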

#if defined (USE_AVX2)

    [[maybe_unused]] static int m256_hadd(__m256i sum, int bias) {
      __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(sum), _mm256_extracti128_si256(sum, 1));
      sum128 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, _MM_PERM_BADC));
      sum128 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, _MM_PERM_CDAB));
      return _mm_cvtsi128_si32(sum128) + bias;
    }
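
    // _MM_PERM_BADC (0x4E) swaps the two 64-bit halves and _MM_PERM_CDAB
    // (0xB1) swaps adjacent 32-bit lanes, so the two shuffle+add rounds fold
    // the four lanes of sum128 into lane 0.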

    [[maybe_unused]] static __m128i m256_haddx4(
        __m256i sum0, __m256i sum1, __m256i sum2, __m256i sum3,
        __m128i bias) {

      sum0 = _mm256_hadd_epi32(sum0, sum1);
      sum2 = _mm256_hadd_epi32(sum2, sum3);

      sum0 = _mm256_hadd_epi32(sum0, sum2);

      __m128i sum128lo = _mm256_castsi256_si128(sum0);
      __m128i sum128hi = _mm256_extracti128_si256(sum0, 1);

      return _mm_add_epi32(_mm_add_epi32(sum128lo, sum128hi), bias);
    }
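
    // Note that _mm256_hadd_epi32 adds adjacent pairs within each 128-bit
    // lane, so after three rounds each of the four partial sums is split
    // between the low and high lane; the final lo+hi add reassembles
    // [hadd(sum0), hadd(sum1), hadd(sum2), hadd(sum3)].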

    [[maybe_unused]] static void m256_add_dpbusd_epi32(
        __m256i& acc,
        __m256i a,
        __m256i b) {
# if defined (USE_VNNI)
#   if defined (USE_INLINE_ASM)
      asm(
        VNNI_PREFIX "vpdpbusd %[b], %[a], %[acc]\n\t"
        : [acc]"+v"(acc)
        : [a]"v"(a), [b]"vm"(b)
      );
#   else
      acc = _mm256_dpbusd_epi32(acc, a, b);
#   endif
# else
#   if defined (USE_INLINE_ASM)
      __m256i tmp = _mm256_maddubs_epi16(a, b);
      asm(
          "vpmaddwd %[tmp], %[ones], %[tmp]\n\t"
          "vpaddd %[acc], %[tmp], %[acc]\n\t"
          : [acc]"+v"(acc), [tmp]"+&v"(tmp)
          : [ones]"v"(_mm256_set1_epi16(1))
      );
#   else
      __m256i product0 = _mm256_maddubs_epi16(a, b);
      product0 = _mm256_madd_epi16(product0, _mm256_set1_epi16(1));
      acc = _mm256_add_epi32(acc, product0);
#   endif
# endif
    }

    [[maybe_unused]] static void m256_add_dpbusd_epi32x2(
        __m256i& acc,
        __m256i a0, __m256i b0,
        __m256i a1, __m256i b1) {

# if defined (USE_VNNI)
#   if defined (USE_INLINE_ASM)
      asm(
        VNNI_PREFIX "vpdpbusd %[b0], %[a0], %[acc]\n\t"
        VNNI_PREFIX "vpdpbusd %[b1], %[a1], %[acc]\n\t"
        : [acc]"+v"(acc)
        : [a0]"v"(a0), [b0]"vm"(b0), [a1]"v"(a1), [b1]"vm"(b1)
      );
#   else
      acc = _mm256_dpbusd_epi32(acc, a0, b0);
      acc = _mm256_dpbusd_epi32(acc, a1, b1);
#   endif
# else
#   if defined (USE_INLINE_ASM)
      __m256i tmp0 = _mm256_maddubs_epi16(a0, b0);
      __m256i tmp1 = _mm256_maddubs_epi16(a1, b1);
      asm(
          "vpmaddwd %[tmp0], %[ones], %[tmp0]\n\t"
          "vpmaddwd %[tmp1], %[ones], %[tmp1]\n\t"
          "vpaddd %[tmp0], %[tmp1], %[tmp0]\n\t"
          "vpaddd %[acc], %[tmp0], %[acc]\n\t"
          : [acc]"+v"(acc), [tmp0]"+&v"(tmp0), [tmp1]"+&v"(tmp1)
          : [ones]"v"(_mm256_set1_epi16(1))
      );
#   else
      __m256i product0 = _mm256_maddubs_epi16(a0, b0);
      __m256i product1 = _mm256_maddubs_epi16(a1, b1);
      product0 = _mm256_madd_epi16(product0, _mm256_set1_epi16(1));
      product1 = _mm256_madd_epi16(product1, _mm256_set1_epi16(1));
      acc = _mm256_add_epi32(acc, _mm256_add_epi32(product0, product1));
#   endif
# endif
    }

#endif
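
// The SSSE3 helpers are 128-bit only. There is no VNNI branch at this width,
// so only the maddubs/madd fallback (and its saturation caveat) applies.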

#if defined (USE_SSSE3)

    [[maybe_unused]] static int m128_hadd(__m128i sum, int bias) {
      sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, 0x4E)); //_MM_PERM_BADC
      sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, 0xB1)); //_MM_PERM_CDAB
      return _mm_cvtsi128_si32(sum) + bias;
    }

    [[maybe_unused]] static __m128i m128_haddx4(
        __m128i sum0, __m128i sum1, __m128i sum2, __m128i sum3,
        __m128i bias) {

      sum0 = _mm_hadd_epi32(sum0, sum1);
      sum2 = _mm_hadd_epi32(sum2, sum3);
      sum0 = _mm_hadd_epi32(sum0, sum2);
      return _mm_add_epi32(sum0, bias);
    }
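
    // At 128 bits _mm_hadd_epi32 has no lane split, so three hadd rounds
    // directly produce [hadd(sum0), hadd(sum1), hadd(sum2), hadd(sum3)]
    // with no lo/hi fixup, unlike the 256-bit version above.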

    [[maybe_unused]] static void m128_add_dpbusd_epi32(
        __m128i& acc,
        __m128i a,
        __m128i b) {

# if defined (USE_INLINE_ASM)
      __m128i tmp = _mm_maddubs_epi16(a, b);
      asm(
          "pmaddwd %[ones], %[tmp]\n\t"
          "paddd %[tmp], %[acc]\n\t"
          : [acc]"+v"(acc), [tmp]"+&v"(tmp)
          : [ones]"v"(_mm_set1_epi16(1))
      );
# else
      __m128i product0 = _mm_maddubs_epi16(a, b);
      product0 = _mm_madd_epi16(product0, _mm_set1_epi16(1));
      acc = _mm_add_epi32(acc, product0);
# endif
    }
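
    // The asm above uses the destructive two-operand SSE forms
    // (pmaddwd/paddd overwrite their last operand), hence the different
    // operand order compared with the three-operand AVX forms earlier.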

    [[maybe_unused]] static void m128_add_dpbusd_epi32x2(
        __m128i& acc,
        __m128i a0, __m128i b0,
        __m128i a1, __m128i b1) {

# if defined (USE_INLINE_ASM)
      __m128i tmp0 = _mm_maddubs_epi16(a0, b0);
      __m128i tmp1 = _mm_maddubs_epi16(a1, b1);
      asm(
          "pmaddwd %[ones], %[tmp0]\n\t"
          "pmaddwd %[ones], %[tmp1]\n\t"
          "paddd %[tmp1], %[tmp0]\n\t"
          "paddd %[tmp0], %[acc]\n\t"
          : [acc]"+v"(acc), [tmp0]"+&v"(tmp0), [tmp1]"+&v"(tmp1)
          : [ones]"v"(_mm_set1_epi16(1))
      );
# else
      __m128i product0 = _mm_maddubs_epi16(a0, b0);
      __m128i product1 = _mm_maddubs_epi16(a1, b1);
      product0 = _mm_madd_epi16(product0, _mm_set1_epi16(1));
      product1 = _mm_madd_epi16(product1, _mm_set1_epi16(1));
      acc = _mm_add_epi32(acc, _mm_add_epi32(product0, product1));
# endif
    }

#endif

#if defined (USE_NEON_DOTPROD)
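
    // vdotq_s32 (the ARMv8.2 dotprod extension) adds four adjacent signed
    // 8-bit products into each 32-bit lane, the NEON analogue of vpdpbusd;
    // note both operands are signed here, unlike the unsigned-by-signed
    // x86 form.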
    [[maybe_unused]] static void dotprod_m128_add_dpbusd_epi32x2(
        int32x4_t& acc,
        int8x16_t a0, int8x16_t b0,
        int8x16_t a1, int8x16_t b1) {

      acc = vdotq_s32(acc, a0, b0);
      acc = vdotq_s32(acc, a1, b1);
    }

#endif

#if defined (USE_NEON)

    [[maybe_unused]] static int neon_m128_reduce_add_epi32(int32x4_t s) {
# if USE_NEON >= 8
      return vaddvq_s32(s);
# else
      return s[0] + s[1] + s[2] + s[3];
# endif
    }

    [[maybe_unused]] static int neon_m128_hadd(int32x4_t sum, int bias) {
      return neon_m128_reduce_add_epi32(sum) + bias;
    }

    [[maybe_unused]] static int32x4_t neon_m128_haddx4(
        int32x4_t sum0, int32x4_t sum1, int32x4_t sum2, int32x4_t sum3,
        int32x4_t bias) {

      int32x4_t hsums {
        neon_m128_reduce_add_epi32(sum0),
        neon_m128_reduce_add_epi32(sum1),
        neon_m128_reduce_add_epi32(sum2),
        neon_m128_reduce_add_epi32(sum3)
      };
      return vaddq_s32(hsums, bias);
    }
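
    // A scalar sketch of the widening accumulate below: vmull_s8 produces
    // eight 16-bit products, vmlal_s8 adds a second set on top, and
    // vpadalq_s16 then adds adjacent 16-bit pairs into the 32-bit lanes of
    // acc. Two i8*i8 products per 16-bit lane peak at 2 * 127 * 127 = 32258,
    // so the int16 intermediate cannot overflow.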
    [[maybe_unused]] static void neon_m128_add_dpbusd_epi32x2(
        int32x4_t& acc,
        int8x8_t a0, int8x8_t b0,
        int8x8_t a1, int8x8_t b1) {

      int16x8_t product = vmull_s8(a0, b0);
      product = vmlal_s8(product, a1, b1);
      acc = vpadalq_s16(acc, product);
    }

#endif
}

#endif // STOCKFISH_SIMD_H_INCLUDED