]> git.sesse.net Git - stockfish/blob - src/nnue/layers/simd.h
349217edb7a607e38dff89eaafc640d04e0674b9
[stockfish] / src / nnue / layers / simd.h
1 /*
2   Stockfish, a UCI chess playing engine derived from Glaurung 2.1
3   Copyright (C) 2004-2023 The Stockfish developers (see AUTHORS file)
4
5   Stockfish is free software: you can redistribute it and/or modify
6   it under the terms of the GNU General Public License as published by
7   the Free Software Foundation, either version 3 of the License, or
8   (at your option) any later version.
9
10   Stockfish is distributed in the hope that it will be useful,
11   but WITHOUT ANY WARRANTY; without even the implied warranty of
12   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13   GNU General Public License for more details.
14
15   You should have received a copy of the GNU General Public License
16   along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 */
18
19 #ifndef STOCKFISH_SIMD_H_INCLUDED
20 #define STOCKFISH_SIMD_H_INCLUDED
21
22 #if defined(USE_AVX2)
23 # include <immintrin.h>
24
25 #elif defined(USE_SSE41)
26 # include <smmintrin.h>
27
28 #elif defined(USE_SSSE3)
29 # include <tmmintrin.h>
30
31 #elif defined(USE_SSE2)
32 # include <emmintrin.h>
33
34 #elif defined(USE_NEON)
35 # include <arm_neon.h>
36 #endif
37
38 namespace Stockfish::Simd {
39
40 #if defined (USE_AVX512)
41
42     [[maybe_unused]] static int m512_hadd(__m512i sum, int bias) {
43       return _mm512_reduce_add_epi32(sum) + bias;
44     }
45
46     /*
47       Parameters:
48         sum0 = [zmm0.i128[0], zmm0.i128[1], zmm0.i128[2], zmm0.i128[3]]
49         sum1 = [zmm1.i128[0], zmm1.i128[1], zmm1.i128[2], zmm1.i128[3]]
50         sum2 = [zmm2.i128[0], zmm2.i128[1], zmm2.i128[2], zmm2.i128[3]]
51         sum3 = [zmm3.i128[0], zmm3.i128[1], zmm3.i128[2], zmm3.i128[3]]
52
53       Returns:
54         ret = [
55           reduce_add_epi32(zmm0.i128[0]), reduce_add_epi32(zmm1.i128[0]), reduce_add_epi32(zmm2.i128[0]), reduce_add_epi32(zmm3.i128[0]),
56           reduce_add_epi32(zmm0.i128[1]), reduce_add_epi32(zmm1.i128[1]), reduce_add_epi32(zmm2.i128[1]), reduce_add_epi32(zmm3.i128[1]),
57           reduce_add_epi32(zmm0.i128[2]), reduce_add_epi32(zmm1.i128[2]), reduce_add_epi32(zmm2.i128[2]), reduce_add_epi32(zmm3.i128[2]),
58           reduce_add_epi32(zmm0.i128[3]), reduce_add_epi32(zmm1.i128[3]), reduce_add_epi32(zmm2.i128[3]), reduce_add_epi32(zmm3.i128[3])
59         ]
60     */
61     [[maybe_unused]] static __m512i m512_hadd128x16_interleave(
62         __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3) {
63
64       __m512i sum01a = _mm512_unpacklo_epi32(sum0, sum1);
65       __m512i sum01b = _mm512_unpackhi_epi32(sum0, sum1);
66
67       __m512i sum23a = _mm512_unpacklo_epi32(sum2, sum3);
68       __m512i sum23b = _mm512_unpackhi_epi32(sum2, sum3);
69
70       __m512i sum01 = _mm512_add_epi32(sum01a, sum01b);
71       __m512i sum23 = _mm512_add_epi32(sum23a, sum23b);
72
73       __m512i sum0123a = _mm512_unpacklo_epi64(sum01, sum23);
74       __m512i sum0123b = _mm512_unpackhi_epi64(sum01, sum23);
75
76       return _mm512_add_epi32(sum0123a, sum0123b);
77     }
78
79     [[maybe_unused]] static void m512_add_dpbusd_epi32(
80         __m512i& acc,
81         __m512i a,
82         __m512i b) {
83
84 # if defined (USE_VNNI)
85       acc = _mm512_dpbusd_epi32(acc, a, b);
86 # else
87       __m512i product0 = _mm512_maddubs_epi16(a, b);
88       product0 = _mm512_madd_epi16(product0, _mm512_set1_epi16(1));
89       acc = _mm512_add_epi32(acc, product0);
90 # endif
91     }
92
93     [[maybe_unused]] static void m512_add_dpbusd_epi32x2(
94         __m512i& acc,
95         __m512i a0, __m512i b0,
96         __m512i a1, __m512i b1) {
97
98 # if defined (USE_VNNI)
99       acc = _mm512_dpbusd_epi32(acc, a0, b0);
100       acc = _mm512_dpbusd_epi32(acc, a1, b1);
101 # else
102       __m512i product0 = _mm512_maddubs_epi16(a0, b0);
103       __m512i product1 = _mm512_maddubs_epi16(a1, b1);
104       product0 = _mm512_madd_epi16(product0, _mm512_set1_epi16(1));
105       product1 = _mm512_madd_epi16(product1, _mm512_set1_epi16(1));
106       acc = _mm512_add_epi32(acc, _mm512_add_epi32(product0, product1));
107 # endif
108     }
109
110 #endif
111
112 #if defined (USE_AVX2)
113
114     [[maybe_unused]] static int m256_hadd(__m256i sum, int bias) {
115       __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(sum), _mm256_extracti128_si256(sum, 1));
116       sum128 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, _MM_PERM_BADC));
117       sum128 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, _MM_PERM_CDAB));
118       return _mm_cvtsi128_si32(sum128) + bias;
119     }
120
121     [[maybe_unused]] static void m256_add_dpbusd_epi32(
122         __m256i& acc,
123         __m256i a,
124         __m256i b) {
125
126 # if defined (USE_VNNI)
127       acc = _mm256_dpbusd_epi32(acc, a, b);
128 # else
129       __m256i product0 = _mm256_maddubs_epi16(a, b);
130       product0 = _mm256_madd_epi16(product0, _mm256_set1_epi16(1));
131       acc = _mm256_add_epi32(acc, product0);
132 # endif
133     }
134
135     [[maybe_unused]] static void m256_add_dpbusd_epi32x2(
136         __m256i& acc,
137         __m256i a0, __m256i b0,
138         __m256i a1, __m256i b1) {
139
140 # if defined (USE_VNNI)
141       acc = _mm256_dpbusd_epi32(acc, a0, b0);
142       acc = _mm256_dpbusd_epi32(acc, a1, b1);
143 # else
144       __m256i product0 = _mm256_maddubs_epi16(a0, b0);
145       __m256i product1 = _mm256_maddubs_epi16(a1, b1);
146       product0 = _mm256_madd_epi16(product0, _mm256_set1_epi16(1));
147       product1 = _mm256_madd_epi16(product1, _mm256_set1_epi16(1));
148       acc = _mm256_add_epi32(acc, _mm256_add_epi32(product0, product1));
149 # endif
150     }
151
152 #endif
153
154 #if defined (USE_SSSE3)
155
156     [[maybe_unused]] static int m128_hadd(__m128i sum, int bias) {
157       sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, 0x4E)); //_MM_PERM_BADC
158       sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, 0xB1)); //_MM_PERM_CDAB
159       return _mm_cvtsi128_si32(sum) + bias;
160     }
161
162     [[maybe_unused]] static void m128_add_dpbusd_epi32(
163         __m128i& acc,
164         __m128i a,
165         __m128i b) {
166
167       __m128i product0 = _mm_maddubs_epi16(a, b);
168       product0 = _mm_madd_epi16(product0, _mm_set1_epi16(1));
169       acc = _mm_add_epi32(acc, product0);
170     }
171
172     [[maybe_unused]] static void m128_add_dpbusd_epi32x2(
173         __m128i& acc,
174         __m128i a0, __m128i b0,
175         __m128i a1, __m128i b1) {
176
177       __m128i product0 = _mm_maddubs_epi16(a0, b0);
178       __m128i product1 = _mm_maddubs_epi16(a1, b1);
179       product0 = _mm_madd_epi16(product0, _mm_set1_epi16(1));
180       product1 = _mm_madd_epi16(product1, _mm_set1_epi16(1));
181       acc = _mm_add_epi32(acc, _mm_add_epi32(product0, product1));
182     }
183
184 #endif
185
186 #if defined (USE_NEON_DOTPROD)
187
188     [[maybe_unused]] static void dotprod_m128_add_dpbusd_epi32x2(
189         int32x4_t& acc,
190         int8x16_t a0, int8x16_t b0,
191         int8x16_t a1, int8x16_t b1) {
192
193         acc = vdotq_s32(acc, a0, b0);
194         acc = vdotq_s32(acc, a1, b1);
195     }
196
197     [[maybe_unused]] static void dotprod_m128_add_dpbusd_epi32(
198         int32x4_t& acc,
199         int8x16_t a, int8x16_t b) {
200
201         acc = vdotq_s32(acc, a, b);
202     }
203 #endif
204
205 #if defined (USE_NEON)
206
207     [[maybe_unused]] static int neon_m128_reduce_add_epi32(int32x4_t s) {
208 #   if USE_NEON >= 8
209       return vaddvq_s32(s);
210 #   else
211       return s[0] + s[1] + s[2] + s[3];
212 #   endif
213     }
214
215     [[maybe_unused]] static int neon_m128_hadd(int32x4_t sum, int bias) {
216       return neon_m128_reduce_add_epi32(sum) + bias;
217     }
218
219     [[maybe_unused]] static void neon_m128_add_dpbusd_epi32x2(
220         int32x4_t& acc,
221         int8x8_t a0, int8x8_t b0,
222         int8x8_t a1, int8x8_t b1) {
223
224       int16x8_t product = vmull_s8(a0, b0);
225       product = vmlal_s8(product, a1, b1);
226       acc = vpadalq_s16(acc, product);
227     }
228 #endif
229
230 #if USE_NEON >= 8
231     [[maybe_unused]] static void neon_m128_add_dpbusd_epi32(
232         int32x4_t& acc,
233         int8x16_t a, int8x16_t b) {
234
235       int16x8_t product0 = vmull_s8(vget_low_s8(a), vget_low_s8(b));
236       int16x8_t product1 = vmull_high_s8(a, b);
237       int16x8_t sum = vpaddq_s16(product0, product1);
238       acc = vpadalq_s16(acc, sum);
239     }
240 #endif
241 }
242
243 #endif // STOCKFISH_SIMD_H_INCLUDED