]> git.sesse.net Git - stockfish/blob - src/nnue/layers/simd.h
remove large input specialization
[stockfish] / src / nnue / layers / simd.h
1 /*
2   Stockfish, a UCI chess playing engine derived from Glaurung 2.1
3   Copyright (C) 2004-2023 The Stockfish developers (see AUTHORS file)
4
5   Stockfish is free software: you can redistribute it and/or modify
6   it under the terms of the GNU General Public License as published by
7   the Free Software Foundation, either version 3 of the License, or
8   (at your option) any later version.
9
10   Stockfish is distributed in the hope that it will be useful,
11   but WITHOUT ANY WARRANTY; without even the implied warranty of
12   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13   GNU General Public License for more details.
14
15   You should have received a copy of the GNU General Public License
16   along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 */
18
19 #ifndef STOCKFISH_SIMD_H_INCLUDED
20 #define STOCKFISH_SIMD_H_INCLUDED
21
22 #if defined(USE_AVX2)
23 # include <immintrin.h>
24
25 #elif defined(USE_SSE41)
26 # include <smmintrin.h>
27
28 #elif defined(USE_SSSE3)
29 # include <tmmintrin.h>
30
31 #elif defined(USE_SSE2)
32 # include <emmintrin.h>
33
34 #elif defined(USE_MMX)
35 # include <mmintrin.h>
36
37 #elif defined(USE_NEON)
38 # include <arm_neon.h>
39 #endif
40
41 // The inline asm is only safe for GCC, where it is necessary to get good codegen.
42 // See https://gcc.gnu.org/bugzilla/show_bug.cgi?id=101693
43 // Clang does fine without it.
44 // Play around here: https://godbolt.org/z/7EWqrYq51
45 #if (defined(__GNUC__) && !defined(__clang__) && !defined(__INTEL_COMPILER))
46 #define USE_INLINE_ASM
47 #endif
48
49 // Use either the AVX512 or AVX-VNNI version of the VNNI instructions.
50 #if defined(USE_AVXVNNI)
51 #define VNNI_PREFIX "%{vex%} "
52 #else
53 #define VNNI_PREFIX ""
54 #endif
55
56 namespace Stockfish::Simd {
57
58 #if defined (USE_AVX512)
59
60     [[maybe_unused]] static int m512_hadd(__m512i sum, int bias) {
61       return _mm512_reduce_add_epi32(sum) + bias;
62     }
63
64     /*
65       Parameters:
66         sum0 = [zmm0.i128[0], zmm0.i128[1], zmm0.i128[2], zmm0.i128[3]]
67         sum1 = [zmm1.i128[0], zmm1.i128[1], zmm1.i128[2], zmm1.i128[3]]
68         sum2 = [zmm2.i128[0], zmm2.i128[1], zmm2.i128[2], zmm2.i128[3]]
69         sum3 = [zmm3.i128[0], zmm3.i128[1], zmm3.i128[2], zmm3.i128[3]]
70
71       Returns:
72         ret = [
73           reduce_add_epi32(zmm0.i128[0]), reduce_add_epi32(zmm1.i128[0]), reduce_add_epi32(zmm2.i128[0]), reduce_add_epi32(zmm3.i128[0]),
74           reduce_add_epi32(zmm0.i128[1]), reduce_add_epi32(zmm1.i128[1]), reduce_add_epi32(zmm2.i128[1]), reduce_add_epi32(zmm3.i128[1]),
75           reduce_add_epi32(zmm0.i128[2]), reduce_add_epi32(zmm1.i128[2]), reduce_add_epi32(zmm2.i128[2]), reduce_add_epi32(zmm3.i128[2]),
76           reduce_add_epi32(zmm0.i128[3]), reduce_add_epi32(zmm1.i128[3]), reduce_add_epi32(zmm2.i128[3]), reduce_add_epi32(zmm3.i128[3])
77         ]
78     */
79     [[maybe_unused]] static __m512i m512_hadd128x16_interleave(
80         __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3) {
81
82       __m512i sum01a = _mm512_unpacklo_epi32(sum0, sum1);
83       __m512i sum01b = _mm512_unpackhi_epi32(sum0, sum1);
84
85       __m512i sum23a = _mm512_unpacklo_epi32(sum2, sum3);
86       __m512i sum23b = _mm512_unpackhi_epi32(sum2, sum3);
87
88       __m512i sum01 = _mm512_add_epi32(sum01a, sum01b);
89       __m512i sum23 = _mm512_add_epi32(sum23a, sum23b);
90
91       __m512i sum0123a = _mm512_unpacklo_epi64(sum01, sum23);
92       __m512i sum0123b = _mm512_unpackhi_epi64(sum01, sum23);
93
94       return _mm512_add_epi32(sum0123a, sum0123b);
95     }
96
97     [[maybe_unused]] static __m128i m512_haddx4(
98         __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3,
99         __m128i bias) {
100
101       __m512i sum = m512_hadd128x16_interleave(sum0, sum1, sum2, sum3);
102
103       __m256i sum256lo = _mm512_castsi512_si256(sum);
104       __m256i sum256hi = _mm512_extracti64x4_epi64(sum, 1);
105
106       sum256lo = _mm256_add_epi32(sum256lo, sum256hi);
107
108       __m128i sum128lo = _mm256_castsi256_si128(sum256lo);
109       __m128i sum128hi = _mm256_extracti128_si256(sum256lo, 1);
110
111       return _mm_add_epi32(_mm_add_epi32(sum128lo, sum128hi), bias);
112     }
113
114     [[maybe_unused]] static void m512_add_dpbusd_epi32(
115         __m512i& acc,
116         __m512i a,
117         __m512i b) {
118
119 # if defined (USE_VNNI)
120 #   if defined (USE_INLINE_ASM)
121       asm(
122         "vpdpbusd %[b], %[a], %[acc]\n\t"
123         : [acc]"+v"(acc)
124         : [a]"v"(a), [b]"vm"(b)
125       );
126 #   else
127       acc = _mm512_dpbusd_epi32(acc, a, b);
128 #   endif
129 # else
130 #   if defined (USE_INLINE_ASM)
131       __m512i tmp = _mm512_maddubs_epi16(a, b);
132       asm(
133           "vpmaddwd    %[tmp], %[ones], %[tmp]\n\t"
134           "vpaddd      %[acc], %[tmp], %[acc]\n\t"
135           : [acc]"+v"(acc), [tmp]"+&v"(tmp)
136           : [ones]"v"(_mm512_set1_epi16(1))
137       );
138 #   else
139       __m512i product0 = _mm512_maddubs_epi16(a, b);
140       product0 = _mm512_madd_epi16(product0, _mm512_set1_epi16(1));
141       acc = _mm512_add_epi32(acc, product0);
142 #   endif
143 # endif
144     }
145
146     [[maybe_unused]] static void m512_add_dpbusd_epi32x2(
147         __m512i& acc,
148         __m512i a0, __m512i b0,
149         __m512i a1, __m512i b1) {
150
151 # if defined (USE_VNNI)
152 #   if defined (USE_INLINE_ASM)
153       asm(
154         "vpdpbusd %[b0], %[a0], %[acc]\n\t"
155         "vpdpbusd %[b1], %[a1], %[acc]\n\t"
156         : [acc]"+&v"(acc)
157         : [a0]"v"(a0), [b0]"vm"(b0), [a1]"v"(a1), [b1]"vm"(b1)
158       );
159 #   else
160       acc = _mm512_dpbusd_epi32(acc, a0, b0);
161       acc = _mm512_dpbusd_epi32(acc, a1, b1);
162 #   endif
163 # else
164 #   if defined (USE_INLINE_ASM)
165       __m512i tmp0 = _mm512_maddubs_epi16(a0, b0);
166       __m512i tmp1 = _mm512_maddubs_epi16(a1, b1);
167       asm(
168           "vpmaddwd    %[tmp0], %[ones], %[tmp0]\n\t"
169           "vpmaddwd    %[tmp1], %[ones], %[tmp1]\n\t"
170           "vpaddd      %[tmp0], %[tmp1], %[tmp0]\n\t"
171           "vpaddd      %[acc], %[tmp0], %[acc]\n\t"
172           : [acc]"+v"(acc), [tmp0]"+&v"(tmp0), [tmp1]"+&v"(tmp1)
173           : [ones]"v"(_mm512_set1_epi16(1))
174       );
175 #   else
176       __m512i product0 = _mm512_maddubs_epi16(a0, b0);
177       __m512i product1 = _mm512_maddubs_epi16(a1, b1);
178       product0 = _mm512_madd_epi16(product0, _mm512_set1_epi16(1));
179       product1 = _mm512_madd_epi16(product1, _mm512_set1_epi16(1));
180       acc = _mm512_add_epi32(acc, _mm512_add_epi32(product0, product1));
181 #   endif
182 # endif
183     }
184
185 #endif
186
187 #if defined (USE_AVX2)
188
189     [[maybe_unused]] static int m256_hadd(__m256i sum, int bias) {
190       __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(sum), _mm256_extracti128_si256(sum, 1));
191       sum128 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, _MM_PERM_BADC));
192       sum128 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, _MM_PERM_CDAB));
193       return _mm_cvtsi128_si32(sum128) + bias;
194     }
195
196     [[maybe_unused]] static __m128i m256_haddx4(
197         __m256i sum0, __m256i sum1, __m256i sum2, __m256i sum3,
198         __m128i bias) {
199
200       sum0 = _mm256_hadd_epi32(sum0, sum1);
201       sum2 = _mm256_hadd_epi32(sum2, sum3);
202
203       sum0 = _mm256_hadd_epi32(sum0, sum2);
204
205       __m128i sum128lo = _mm256_castsi256_si128(sum0);
206       __m128i sum128hi = _mm256_extracti128_si256(sum0, 1);
207
208       return _mm_add_epi32(_mm_add_epi32(sum128lo, sum128hi), bias);
209     }
210
211     [[maybe_unused]] static void m256_add_dpbusd_epi32(
212         __m256i& acc,
213         __m256i a,
214         __m256i b) {
215
216 # if defined (USE_VNNI)
217 #   if defined (USE_INLINE_ASM)
218       asm(
219         VNNI_PREFIX "vpdpbusd %[b], %[a], %[acc]\n\t"
220         : [acc]"+v"(acc)
221         : [a]"v"(a), [b]"vm"(b)
222       );
223 #   else
224       acc = _mm256_dpbusd_epi32(acc, a, b);
225 #   endif
226 # else
227 #   if defined (USE_INLINE_ASM)
228       __m256i tmp = _mm256_maddubs_epi16(a, b);
229       asm(
230           "vpmaddwd    %[tmp], %[ones], %[tmp]\n\t"
231           "vpaddd      %[acc], %[tmp], %[acc]\n\t"
232           : [acc]"+v"(acc), [tmp]"+&v"(tmp)
233           : [ones]"v"(_mm256_set1_epi16(1))
234       );
235 #   else
236       __m256i product0 = _mm256_maddubs_epi16(a, b);
237       product0 = _mm256_madd_epi16(product0, _mm256_set1_epi16(1));
238       acc = _mm256_add_epi32(acc, product0);
239 #   endif
240 # endif
241     }
242
243     [[maybe_unused]] static void m256_add_dpbusd_epi32x2(
244         __m256i& acc,
245         __m256i a0, __m256i b0,
246         __m256i a1, __m256i b1) {
247
248 # if defined (USE_VNNI)
249 #   if defined (USE_INLINE_ASM)
250       asm(
251         VNNI_PREFIX "vpdpbusd %[b0], %[a0], %[acc]\n\t"
252         VNNI_PREFIX "vpdpbusd %[b1], %[a1], %[acc]\n\t"
253         : [acc]"+&v"(acc)
254         : [a0]"v"(a0), [b0]"vm"(b0), [a1]"v"(a1), [b1]"vm"(b1)
255       );
256 #   else
257       acc = _mm256_dpbusd_epi32(acc, a0, b0);
258       acc = _mm256_dpbusd_epi32(acc, a1, b1);
259 #   endif
260 # else
261 #   if defined (USE_INLINE_ASM)
262       __m256i tmp0 = _mm256_maddubs_epi16(a0, b0);
263       __m256i tmp1 = _mm256_maddubs_epi16(a1, b1);
264       asm(
265           "vpmaddwd    %[tmp0], %[ones], %[tmp0]\n\t"
266           "vpmaddwd    %[tmp1], %[ones], %[tmp1]\n\t"
267           "vpaddd      %[tmp0], %[tmp1], %[tmp0]\n\t"
268           "vpaddd      %[acc], %[tmp0], %[acc]\n\t"
269           : [acc]"+v"(acc), [tmp0]"+&v"(tmp0), [tmp1]"+&v"(tmp1)
270           : [ones]"v"(_mm256_set1_epi16(1))
271       );
272 #   else
273       __m256i product0 = _mm256_maddubs_epi16(a0, b0);
274       __m256i product1 = _mm256_maddubs_epi16(a1, b1);
275       product0 = _mm256_madd_epi16(product0, _mm256_set1_epi16(1));
276       product1 = _mm256_madd_epi16(product1, _mm256_set1_epi16(1));
277       acc = _mm256_add_epi32(acc, _mm256_add_epi32(product0, product1));
278 #   endif
279 # endif
280     }
281
282 #endif
283
284 #if defined (USE_SSSE3)
285
286     [[maybe_unused]] static int m128_hadd(__m128i sum, int bias) {
287       sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, 0x4E)); //_MM_PERM_BADC
288       sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, 0xB1)); //_MM_PERM_CDAB
289       return _mm_cvtsi128_si32(sum) + bias;
290     }
291
292     [[maybe_unused]] static __m128i m128_haddx4(
293         __m128i sum0, __m128i sum1, __m128i sum2, __m128i sum3,
294         __m128i bias) {
295
296       sum0 = _mm_hadd_epi32(sum0, sum1);
297       sum2 = _mm_hadd_epi32(sum2, sum3);
298       sum0 = _mm_hadd_epi32(sum0, sum2);
299       return _mm_add_epi32(sum0, bias);
300     }
301
302     [[maybe_unused]] static void m128_add_dpbusd_epi32(
303         __m128i& acc,
304         __m128i a,
305         __m128i b) {
306
307 #   if defined (USE_INLINE_ASM)
308       __m128i tmp = _mm_maddubs_epi16(a, b);
309       asm(
310           "pmaddwd    %[ones], %[tmp]\n\t"
311           "paddd      %[tmp], %[acc]\n\t"
312           : [acc]"+v"(acc), [tmp]"+&v"(tmp)
313           : [ones]"v"(_mm_set1_epi16(1))
314       );
315 #   else
316       __m128i product0 = _mm_maddubs_epi16(a, b);
317       product0 = _mm_madd_epi16(product0, _mm_set1_epi16(1));
318       acc = _mm_add_epi32(acc, product0);
319 #   endif
320     }
321
322     [[maybe_unused]] static void m128_add_dpbusd_epi32x2(
323         __m128i& acc,
324         __m128i a0, __m128i b0,
325         __m128i a1, __m128i b1) {
326
327 #   if defined (USE_INLINE_ASM)
328       __m128i tmp0 = _mm_maddubs_epi16(a0, b0);
329       __m128i tmp1 = _mm_maddubs_epi16(a1, b1);
330       asm(
331           "pmaddwd    %[ones], %[tmp0]\n\t"
332           "pmaddwd    %[ones], %[tmp1]\n\t"
333           "paddd      %[tmp1], %[tmp0]\n\t"
334           "paddd      %[tmp0], %[acc]\n\t"
335           : [acc]"+v"(acc), [tmp0]"+&v"(tmp0), [tmp1]"+&v"(tmp1)
336           : [ones]"v"(_mm_set1_epi16(1))
337       );
338 #   else
339       __m128i product0 = _mm_maddubs_epi16(a0, b0);
340       __m128i product1 = _mm_maddubs_epi16(a1, b1);
341       product0 = _mm_madd_epi16(product0, _mm_set1_epi16(1));
342       product1 = _mm_madd_epi16(product1, _mm_set1_epi16(1));
343       acc = _mm_add_epi32(acc, _mm_add_epi32(product0, product1));
344 #   endif
345     }
346
347 #endif
348
349 #if defined (USE_NEON_DOTPROD)
350
351     [[maybe_unused]] static void dotprod_m128_add_dpbusd_epi32x2(
352         int32x4_t& acc,
353         int8x16_t a0, int8x16_t b0,
354         int8x16_t a1, int8x16_t b1) {
355
356         acc = vdotq_s32(acc, a0, b0);
357         acc = vdotq_s32(acc, a1, b1);
358     }
359
360 #endif
361
362 #if defined (USE_NEON)
363
364     [[maybe_unused]] static int neon_m128_reduce_add_epi32(int32x4_t s) {
365 #   if USE_NEON >= 8
366       return vaddvq_s32(s);
367 #   else
368       return s[0] + s[1] + s[2] + s[3];
369 #   endif
370     }
371
372     [[maybe_unused]] static int neon_m128_hadd(int32x4_t sum, int bias) {
373       return neon_m128_reduce_add_epi32(sum) + bias;
374     }
375
376     [[maybe_unused]] static int32x4_t neon_m128_haddx4(
377         int32x4_t sum0, int32x4_t sum1, int32x4_t sum2, int32x4_t sum3,
378         int32x4_t bias) {
379
380       int32x4_t hsums {
381         neon_m128_reduce_add_epi32(sum0),
382         neon_m128_reduce_add_epi32(sum1),
383         neon_m128_reduce_add_epi32(sum2),
384         neon_m128_reduce_add_epi32(sum3)
385       };
386       return vaddq_s32(hsums, bias);
387     }
388
389     [[maybe_unused]] static void neon_m128_add_dpbusd_epi32x2(
390         int32x4_t& acc,
391         int8x8_t a0, int8x8_t b0,
392         int8x8_t a1, int8x8_t b1) {
393
394       int16x8_t product = vmull_s8(a0, b0);
395       product = vmlal_s8(product, a1, b1);
396       acc = vpadalq_s16(acc, product);
397     }
398
399 #endif
400
401 }
402
403 #endif // STOCKFISH_SIMD_H_INCLUDED