]> git.sesse.net Git - stockfish/blob - src/nnue/layers/simd.h
638e39941a856500bb09151ce64c8133bb04d36a
[stockfish] / src / nnue / layers / simd.h
1 /*
2   Stockfish, a UCI chess playing engine derived from Glaurung 2.1
3   Copyright (C) 2004-2023 The Stockfish developers (see AUTHORS file)
4
5   Stockfish is free software: you can redistribute it and/or modify
6   it under the terms of the GNU General Public License as published by
7   the Free Software Foundation, either version 3 of the License, or
8   (at your option) any later version.
9
10   Stockfish is distributed in the hope that it will be useful,
11   but WITHOUT ANY WARRANTY; without even the implied warranty of
12   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13   GNU General Public License for more details.
14
15   You should have received a copy of the GNU General Public License
16   along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 */
18
19 #ifndef STOCKFISH_SIMD_H_INCLUDED
20 #define STOCKFISH_SIMD_H_INCLUDED
21
22 #if defined(USE_AVX2)
23 # include <immintrin.h>
24
25 #elif defined(USE_SSE41)
26 # include <smmintrin.h>
27
28 #elif defined(USE_SSSE3)
29 # include <tmmintrin.h>
30
31 #elif defined(USE_SSE2)
32 # include <emmintrin.h>
33
34 #elif defined(USE_MMX)
35 # include <mmintrin.h>
36
37 #elif defined(USE_NEON)
38 # include <arm_neon.h>
39 #endif
40
41 namespace Stockfish::Simd {
42
43 #if defined (USE_AVX512)
44
45     [[maybe_unused]] static int m512_hadd(__m512i sum, int bias) {
46       return _mm512_reduce_add_epi32(sum) + bias;
47     }
48
49     /*
50       Parameters:
51         sum0 = [zmm0.i128[0], zmm0.i128[1], zmm0.i128[2], zmm0.i128[3]]
52         sum1 = [zmm1.i128[0], zmm1.i128[1], zmm1.i128[2], zmm1.i128[3]]
53         sum2 = [zmm2.i128[0], zmm2.i128[1], zmm2.i128[2], zmm2.i128[3]]
54         sum3 = [zmm3.i128[0], zmm3.i128[1], zmm3.i128[2], zmm3.i128[3]]
55
56       Returns:
57         ret = [
58           reduce_add_epi32(zmm0.i128[0]), reduce_add_epi32(zmm1.i128[0]), reduce_add_epi32(zmm2.i128[0]), reduce_add_epi32(zmm3.i128[0]),
59           reduce_add_epi32(zmm0.i128[1]), reduce_add_epi32(zmm1.i128[1]), reduce_add_epi32(zmm2.i128[1]), reduce_add_epi32(zmm3.i128[1]),
60           reduce_add_epi32(zmm0.i128[2]), reduce_add_epi32(zmm1.i128[2]), reduce_add_epi32(zmm2.i128[2]), reduce_add_epi32(zmm3.i128[2]),
61           reduce_add_epi32(zmm0.i128[3]), reduce_add_epi32(zmm1.i128[3]), reduce_add_epi32(zmm2.i128[3]), reduce_add_epi32(zmm3.i128[3])
62         ]
63     */
64     [[maybe_unused]] static __m512i m512_hadd128x16_interleave(
65         __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3) {
66
67       __m512i sum01a = _mm512_unpacklo_epi32(sum0, sum1);
68       __m512i sum01b = _mm512_unpackhi_epi32(sum0, sum1);
69
70       __m512i sum23a = _mm512_unpacklo_epi32(sum2, sum3);
71       __m512i sum23b = _mm512_unpackhi_epi32(sum2, sum3);
72
73       __m512i sum01 = _mm512_add_epi32(sum01a, sum01b);
74       __m512i sum23 = _mm512_add_epi32(sum23a, sum23b);
75
76       __m512i sum0123a = _mm512_unpacklo_epi64(sum01, sum23);
77       __m512i sum0123b = _mm512_unpackhi_epi64(sum01, sum23);
78
79       return _mm512_add_epi32(sum0123a, sum0123b);
80     }
81
82     [[maybe_unused]] static __m128i m512_haddx4(
83         __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3,
84         __m128i bias) {
85
86       __m512i sum = m512_hadd128x16_interleave(sum0, sum1, sum2, sum3);
87
88       __m256i sum256lo = _mm512_castsi512_si256(sum);
89       __m256i sum256hi = _mm512_extracti64x4_epi64(sum, 1);
90
91       sum256lo = _mm256_add_epi32(sum256lo, sum256hi);
92
93       __m128i sum128lo = _mm256_castsi256_si128(sum256lo);
94       __m128i sum128hi = _mm256_extracti128_si256(sum256lo, 1);
95
96       return _mm_add_epi32(_mm_add_epi32(sum128lo, sum128hi), bias);
97     }
98
99     [[maybe_unused]] static void m512_add_dpbusd_epi32(
100         __m512i& acc,
101         __m512i a,
102         __m512i b) {
103
104 # if defined (USE_VNNI)
105       acc = _mm512_dpbusd_epi32(acc, a, b);
106 # else
107       __m512i product0 = _mm512_maddubs_epi16(a, b);
108       product0 = _mm512_madd_epi16(product0, _mm512_set1_epi16(1));
109       acc = _mm512_add_epi32(acc, product0);
110 # endif
111     }
112
113     [[maybe_unused]] static void m512_add_dpbusd_epi32x2(
114         __m512i& acc,
115         __m512i a0, __m512i b0,
116         __m512i a1, __m512i b1) {
117
118 # if defined (USE_VNNI)
119       acc = _mm512_dpbusd_epi32(acc, a0, b0);
120       acc = _mm512_dpbusd_epi32(acc, a1, b1);
121 # else
122       __m512i product0 = _mm512_maddubs_epi16(a0, b0);
123       __m512i product1 = _mm512_maddubs_epi16(a1, b1);
124       product0 = _mm512_madd_epi16(product0, _mm512_set1_epi16(1));
125       product1 = _mm512_madd_epi16(product1, _mm512_set1_epi16(1));
126       acc = _mm512_add_epi32(acc, _mm512_add_epi32(product0, product1));
127 # endif
128     }
129
130 #endif
131
132 #if defined (USE_AVX2)
133
134     [[maybe_unused]] static int m256_hadd(__m256i sum, int bias) {
135       __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(sum), _mm256_extracti128_si256(sum, 1));
136       sum128 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, _MM_PERM_BADC));
137       sum128 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, _MM_PERM_CDAB));
138       return _mm_cvtsi128_si32(sum128) + bias;
139     }
140
141     [[maybe_unused]] static __m128i m256_haddx4(
142         __m256i sum0, __m256i sum1, __m256i sum2, __m256i sum3,
143         __m128i bias) {
144
145       sum0 = _mm256_hadd_epi32(sum0, sum1);
146       sum2 = _mm256_hadd_epi32(sum2, sum3);
147
148       sum0 = _mm256_hadd_epi32(sum0, sum2);
149
150       __m128i sum128lo = _mm256_castsi256_si128(sum0);
151       __m128i sum128hi = _mm256_extracti128_si256(sum0, 1);
152
153       return _mm_add_epi32(_mm_add_epi32(sum128lo, sum128hi), bias);
154     }
155
156     [[maybe_unused]] static void m256_add_dpbusd_epi32(
157         __m256i& acc,
158         __m256i a,
159         __m256i b) {
160
161 # if defined (USE_VNNI)
162       acc = _mm256_dpbusd_epi32(acc, a, b);
163 # else
164       __m256i product0 = _mm256_maddubs_epi16(a, b);
165       product0 = _mm256_madd_epi16(product0, _mm256_set1_epi16(1));
166       acc = _mm256_add_epi32(acc, product0);
167 # endif
168     }
169
170     [[maybe_unused]] static void m256_add_dpbusd_epi32x2(
171         __m256i& acc,
172         __m256i a0, __m256i b0,
173         __m256i a1, __m256i b1) {
174
175 # if defined (USE_VNNI)
176       acc = _mm256_dpbusd_epi32(acc, a0, b0);
177       acc = _mm256_dpbusd_epi32(acc, a1, b1);
178 # else
179       __m256i product0 = _mm256_maddubs_epi16(a0, b0);
180       __m256i product1 = _mm256_maddubs_epi16(a1, b1);
181       product0 = _mm256_madd_epi16(product0, _mm256_set1_epi16(1));
182       product1 = _mm256_madd_epi16(product1, _mm256_set1_epi16(1));
183       acc = _mm256_add_epi32(acc, _mm256_add_epi32(product0, product1));
184 # endif
185     }
186
187 #endif
188
189 #if defined (USE_SSSE3)
190
191     [[maybe_unused]] static int m128_hadd(__m128i sum, int bias) {
192       sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, 0x4E)); //_MM_PERM_BADC
193       sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, 0xB1)); //_MM_PERM_CDAB
194       return _mm_cvtsi128_si32(sum) + bias;
195     }
196
197     [[maybe_unused]] static __m128i m128_haddx4(
198         __m128i sum0, __m128i sum1, __m128i sum2, __m128i sum3,
199         __m128i bias) {
200
201       sum0 = _mm_hadd_epi32(sum0, sum1);
202       sum2 = _mm_hadd_epi32(sum2, sum3);
203       sum0 = _mm_hadd_epi32(sum0, sum2);
204       return _mm_add_epi32(sum0, bias);
205     }
206
207     [[maybe_unused]] static void m128_add_dpbusd_epi32(
208         __m128i& acc,
209         __m128i a,
210         __m128i b) {
211
212       __m128i product0 = _mm_maddubs_epi16(a, b);
213       product0 = _mm_madd_epi16(product0, _mm_set1_epi16(1));
214       acc = _mm_add_epi32(acc, product0);
215     }
216
217     [[maybe_unused]] static void m128_add_dpbusd_epi32x2(
218         __m128i& acc,
219         __m128i a0, __m128i b0,
220         __m128i a1, __m128i b1) {
221
222       __m128i product0 = _mm_maddubs_epi16(a0, b0);
223       __m128i product1 = _mm_maddubs_epi16(a1, b1);
224       product0 = _mm_madd_epi16(product0, _mm_set1_epi16(1));
225       product1 = _mm_madd_epi16(product1, _mm_set1_epi16(1));
226       acc = _mm_add_epi32(acc, _mm_add_epi32(product0, product1));
227     }
228
229 #endif
230
231 #if defined (USE_NEON_DOTPROD)
232
233     [[maybe_unused]] static void dotprod_m128_add_dpbusd_epi32x2(
234         int32x4_t& acc,
235         int8x16_t a0, int8x16_t b0,
236         int8x16_t a1, int8x16_t b1) {
237
238         acc = vdotq_s32(acc, a0, b0);
239         acc = vdotq_s32(acc, a1, b1);
240     }
241
242     [[maybe_unused]] static void dotprod_m128_add_dpbusd_epi32(
243         int32x4_t& acc,
244         int8x16_t a, int8x16_t b) {
245
246         acc = vdotq_s32(acc, a, b);
247     }
248 #endif
249
250 #if defined (USE_NEON)
251
252     [[maybe_unused]] static int neon_m128_reduce_add_epi32(int32x4_t s) {
253 #   if USE_NEON >= 8
254       return vaddvq_s32(s);
255 #   else
256       return s[0] + s[1] + s[2] + s[3];
257 #   endif
258     }
259
260     [[maybe_unused]] static int neon_m128_hadd(int32x4_t sum, int bias) {
261       return neon_m128_reduce_add_epi32(sum) + bias;
262     }
263
264     [[maybe_unused]] static int32x4_t neon_m128_haddx4(
265         int32x4_t sum0, int32x4_t sum1, int32x4_t sum2, int32x4_t sum3,
266         int32x4_t bias) {
267
268       int32x4_t hsums {
269         neon_m128_reduce_add_epi32(sum0),
270         neon_m128_reduce_add_epi32(sum1),
271         neon_m128_reduce_add_epi32(sum2),
272         neon_m128_reduce_add_epi32(sum3)
273       };
274       return vaddq_s32(hsums, bias);
275     }
276
277     [[maybe_unused]] static void neon_m128_add_dpbusd_epi32x2(
278         int32x4_t& acc,
279         int8x8_t a0, int8x8_t b0,
280         int8x8_t a1, int8x8_t b1) {
281
282       int16x8_t product = vmull_s8(a0, b0);
283       product = vmlal_s8(product, a1, b1);
284       acc = vpadalq_s16(acc, product);
285     }
286 #endif
287
288 #if USE_NEON >= 8
289     [[maybe_unused]] static void neon_m128_add_dpbusd_epi32(
290         int32x4_t& acc,
291         int8x16_t a, int8x16_t b) {
292
293       int16x8_t product0 = vmull_s8(vget_low_s8(a), vget_low_s8(b));
294       int16x8_t product1 = vmull_high_s8(a, b);
295       int16x8_t sum = vpaddq_s16(product0, product1);
296       acc = vpadalq_s16(acc, sum);
297     }
298 #endif
299 }
300
301 #endif // STOCKFISH_SIMD_H_INCLUDED