/*
  Stockfish, a UCI chess playing engine derived from Glaurung 2.1
  Copyright (C) 2004-2020 The Stockfish developers (see AUTHORS file)

  Stockfish is free software: you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation, either version 3 of the License, or
  (at your option) any later version.

  Stockfish is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/
// Definition of layer AffineTransform of NNUE evaluation function

#ifndef NNUE_LAYERS_AFFINE_TRANSFORM_H_INCLUDED
#define NNUE_LAYERS_AFFINE_TRANSFORM_H_INCLUDED

#include <iostream>  // std::istream, used by ReadParameters()
#include "../nnue_common.h"

namespace Eval::NNUE::Layers {
  // Affine transformation layer
  template <typename PreviousLayer, IndexType OutputDimensions>
  class AffineTransform {
   public:
    // Input/output types
    using InputType = typename PreviousLayer::OutputType;
    using OutputType = std::int32_t;
    static_assert(std::is_same<InputType, std::uint8_t>::value, "");

    // Number of input/output dimensions
    static constexpr IndexType kInputDimensions =
        PreviousLayer::kOutputDimensions;
    static constexpr IndexType kOutputDimensions = OutputDimensions;
    static constexpr IndexType kPaddedInputDimensions =
        CeilToMultiple<IndexType>(kInputDimensions, kMaxSimdWidth);
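    // Example: with kMaxSimdWidth == 32 (see nnue_common.h), 512 inputs stay
    // at 512, while a hypothetical 40-input layer would be padded up to 64 so
    // that whole SIMD registers can be processed. Note that weights_ below is
    // sized kOutputDimensions * kPaddedInputDimensions, so the padding region
    // is also present in the network file.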
    // Size of forward propagation buffer used in this layer
    static constexpr std::size_t kSelfBufferSize =
        CeilToMultiple(kOutputDimensions * sizeof(OutputType), kCacheLineSize);

    // Size of the forward propagation buffer used from the input layer to this layer
    static constexpr std::size_t kBufferSize =
        PreviousLayer::kBufferSize + kSelfBufferSize;

    // Hash value embedded in the evaluation file
    static constexpr std::uint32_t GetHashValue() {
      std::uint32_t hash_value = 0xCC03DAE4u;
      hash_value += kOutputDimensions;
      hash_value ^= PreviousLayer::GetHashValue() >> 1;
      hash_value ^= PreviousLayer::GetHashValue() << 31;
      return hash_value;
    }
    // Read network parameters
    bool ReadParameters(std::istream& stream) {
      if (!previous_layer_.ReadParameters(stream)) return false;
      for (std::size_t i = 0; i < kOutputDimensions; ++i)
        biases_[i] = read_little_endian<BiasType>(stream);
      for (std::size_t i = 0; i < kOutputDimensions * kPaddedInputDimensions; ++i)
        weights_[i] = read_little_endian<WeightType>(stream);
#if defined (USE_SSSE3)
      // Determine if quadruplets of weight and input products can be summed using 16 bits
      // without saturation. We assume worst case combinations of 0 and 127 for all inputs.
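      // Why 258: the x2 helpers below fold four byte products into a single
      // int16 lane before widening to 32 bits. With every input at its maximum
      // of 127, that lane holds 127 * (sum of the four same-signed weights),
      // and 127 * 258 = 32766 is the largest such value that still fits in an
      // int16; a quadruplet summing past 258 (or below -258) can saturate.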
      auto can_saturate = [](const WeightType* w, int idx[4]) {
          int pSum = 0, nSum = 0;
          for (int p = 0; p < 4; ++p)
              if (w[idx[p]] > 0)
                  pSum += w[idx[p]];
              else
                  nSum += w[idx[p]];

          return pSum > 258 || nSum < -258;
      };
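
      // The spacing patterns below mirror the lanes that the *_add_dpbusd_epi32x2
      // helpers combine: bytes k and k+1 of one SIMD chunk together with the
      // same pair of the next chunk, hence { 0, 1, W, W+1 } where W is the
      // chunk width in bytes (64 for AVX512, 32 for AVX2, 16 for SSSE3).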
      for (IndexType i = 0; i < kOutputDimensions; ++i)
      {
          canSaturate16[i] = false;
          const WeightType* w = &weights_[i * kPaddedInputDimensions];
#if defined (USE_AVX512)
          for (IndexType j = 0; j < (kPaddedInputDimensions & ~127) && !canSaturate16[i]; j += 128)
              for (int k = 0; k < 64 && !canSaturate16[i]; k += 2)
              {
                  int spacing[4] = { 0, 1, 64, 65 };
                  canSaturate16[i] = can_saturate(&w[j + k], spacing);
              }
#elif defined (USE_AVX2)
          for (IndexType j = 0; j < (kPaddedInputDimensions & ~63) && !canSaturate16[i]; j += 64)
              for (int k = 0; k < 32 && !canSaturate16[i]; k += 2)
              {
                  int spacing[4] = { 0, 1, 32, 33 };
                  canSaturate16[i] = can_saturate(&w[j + k], spacing);
              }
#elif defined (USE_SSSE3)
          for (IndexType j = 0; j < (kPaddedInputDimensions & ~31) && !canSaturate16[i]; j += 32)
              for (int k = 0; k < 16 && !canSaturate16[i]; k += 2)
              {
                  int spacing[4] = { 0, 1, 16, 17 };
                  canSaturate16[i] = can_saturate(&w[j + k], spacing);
              }
#endif
      }
#endif

      return !stream.fail();
    }
    // Forward propagation
    const OutputType* Propagate(
        const TransformedFeatureType* transformed_features, char* buffer) const {
      const auto input = previous_layer_.Propagate(
          transformed_features, buffer + kSelfBufferSize);

#if defined (USE_AVX512)

      [[maybe_unused]] const __m512i kOnes512 = _mm512_set1_epi16(1);

      [[maybe_unused]] auto m512_hadd = [](__m512i sum, int bias) -> int {
        return _mm512_reduce_add_epi32(sum) + bias;
      };

      // This function takes
      //   sum0 = [xmm0a, xmm0b, xmm0c, xmm0d]
      //   sum1 = [xmm1a, xmm1b, xmm1c, xmm1d]
      //   sum2 = [xmm2a, xmm2b, xmm2c, xmm2d]
      //   sum3 = [xmm3a, xmm3b, xmm3c, xmm3d]
      // and returns
      //   ret = [
      //     reduce_add_epi32(xmm0a), reduce_add_epi32(xmm1a), reduce_add_epi32(xmm2a), reduce_add_epi32(xmm3a),
      //     reduce_add_epi32(xmm0b), reduce_add_epi32(xmm1b), reduce_add_epi32(xmm2b), reduce_add_epi32(xmm3b),
      //     reduce_add_epi32(xmm0c), reduce_add_epi32(xmm1c), reduce_add_epi32(xmm2c), reduce_add_epi32(xmm3c),
      //     reduce_add_epi32(xmm0d), reduce_add_epi32(xmm1d), reduce_add_epi32(xmm2d), reduce_add_epi32(xmm3d)
      //   ]
      [[maybe_unused]] auto m512_hadd128x16_interleave = [](
          __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3) -> __m512i {

        __m512i sum01a = _mm512_unpacklo_epi32(sum0, sum1);
        __m512i sum01b = _mm512_unpackhi_epi32(sum0, sum1);

        __m512i sum23a = _mm512_unpacklo_epi32(sum2, sum3);
        __m512i sum23b = _mm512_unpackhi_epi32(sum2, sum3);

        __m512i sum01 = _mm512_add_epi32(sum01a, sum01b);
        __m512i sum23 = _mm512_add_epi32(sum23a, sum23b);

        __m512i sum0123a = _mm512_unpacklo_epi64(sum01, sum23);
        __m512i sum0123b = _mm512_unpackhi_epi64(sum01, sum23);

        return _mm512_add_epi32(sum0123a, sum0123b);
      };

      [[maybe_unused]] auto m512_haddx4 = [m512_hadd128x16_interleave](
          __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3, __m128i bias) -> __m128i {

        __m512i sum = m512_hadd128x16_interleave(sum0, sum1, sum2, sum3);

        __m256i sum256lo = _mm512_castsi512_si256(sum);
        __m256i sum256hi = _mm512_extracti64x4_epi64(sum, 1);

        sum256lo = _mm256_add_epi32(sum256lo, sum256hi);

        __m128i sum128lo = _mm256_castsi256_si128(sum256lo);
        __m128i sum128hi = _mm256_extracti128_si256(sum256lo, 1);

        return _mm_add_epi32(_mm_add_epi32(sum128lo, sum128hi), bias);
      };

      [[maybe_unused]] auto m512_haddx8 = [m512_hadd128x16_interleave](
          __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3,
          __m512i sum4, __m512i sum5, __m512i sum6, __m512i sum7, __m256i bias) -> __m256i {

        __m512i suma = m512_hadd128x16_interleave(sum0, sum1, sum2, sum3);
        __m512i sumb = m512_hadd128x16_interleave(sum4, sum5, sum6, sum7);

        __m512i indices0 = _mm512_setr_epi64(0, 1, 8, 9, 4, 5, 12, 13);
        __m512i indices1 = _mm512_setr_epi64(2, 3, 10, 11, 6, 7, 14, 15);
        __m512i x = _mm512_add_epi32(
            _mm512_permutex2var_epi64(suma, indices0, sumb),
            _mm512_permutex2var_epi64(suma, indices1, sumb));

        __m256i sum256lo = _mm512_castsi512_si256(x);
        __m256i sum256hi = _mm512_extracti64x4_epi64(x, 1);

        return _mm256_add_epi32(_mm256_add_epi32(sum256lo, sum256hi), bias);
      };

      [[maybe_unused]] auto m512_hadd256x8 = [m512_hadd128x16_interleave](
          __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3, __m256i bias) -> __m256i {

        __m512i sum = m512_hadd128x16_interleave(sum0, sum1, sum2, sum3);

        __m512i indices = _mm512_setr_epi32(
            0, 4, 8, 12, 2, 6, 10, 14,
            1, 5, 9, 13, 3, 7, 11, 15);
        sum = _mm512_permutexvar_epi32(indices, sum);

        __m256i sum256lo = _mm512_castsi512_si256(sum);
        __m256i sum256hi = _mm512_extracti64x4_epi64(sum, 1);

        return _mm256_add_epi32(_mm256_hadd_epi32(sum256lo, sum256hi), bias);
      };

      [[maybe_unused]] auto m512_hadd256x16 = [m512_hadd128x16_interleave](
          __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3,
          __m512i sum4, __m512i sum5, __m512i sum6, __m512i sum7, __m512i bias) -> __m512i {

        __m512i suma = m512_hadd128x16_interleave(sum0, sum1, sum2, sum3);
        __m512i sumb = m512_hadd128x16_interleave(sum4, sum5, sum6, sum7);

        __m512i indices0 = _mm512_setr_epi64(0, 1, 8, 9, 4, 5, 12, 13);
        __m512i indices1 = _mm512_setr_epi64(2, 3, 10, 11, 6, 7, 14, 15);
        __m512i x = _mm512_add_epi32(
            _mm512_permutex2var_epi64(suma, indices0, sumb),
            _mm512_permutex2var_epi64(suma, indices1, sumb));

        __m512i indices = _mm512_setr_epi32(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15);
        return _mm512_add_epi32(_mm512_permutexvar_epi32(indices, x), bias);
      };
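
      // acc += 32-bit dot products of groups of four unsigned bytes of a with
      // the corresponding four signed bytes of b. With VNNI this maps to a
      // single dpbusd instruction; otherwise it is emulated with maddubs/madd
      // against a vector of 16-bit ones.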
      [[maybe_unused]] auto m512_add_dpbusd_epi32 = [=](__m512i& acc, __m512i a, __m512i b) {
#if defined (USE_VNNI)
        acc = _mm512_dpbusd_epi32(acc, a, b);
#else
        __m512i product0 = _mm512_maddubs_epi16(a, b);
        product0 = _mm512_madd_epi16(product0, kOnes512);
        acc = _mm512_add_epi32(acc, product0);
#endif
      };
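
      // Same as above, but folds two products per 16-bit lane before widening.
      // The intermediate _mm512_adds_epi16 is a saturating add, which is what
      // the canSaturate16 bookkeeping in ReadParameters() guards against: rows
      // flagged there must stay on the one-product-at-a-time path.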
      [[maybe_unused]] auto m512_add_dpbusd_epi32x2 = [=](__m512i& acc, __m512i a0, __m512i b0, __m512i a1, __m512i b1) {
#if defined (USE_VNNI)
        acc = _mm512_dpbusd_epi32(acc, a0, b0);
        acc = _mm512_dpbusd_epi32(acc, a1, b1);
#else
        __m512i product0 = _mm512_maddubs_epi16(a0, b0);
        __m512i product1 = _mm512_maddubs_epi16(a1, b1);
        product0 = _mm512_adds_epi16(product0, product1);
        product0 = _mm512_madd_epi16(product0, kOnes512);
        acc = _mm512_add_epi32(acc, product0);
#endif
      };

#endif

#if defined (USE_AVX2)

      [[maybe_unused]] const __m256i kOnes256 = _mm256_set1_epi16(1);

      [[maybe_unused]] auto m256_hadd = [](__m256i sum, int bias) -> int {
        __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(sum), _mm256_extracti128_si256(sum, 1));
        sum128 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, _MM_PERM_BADC));
        sum128 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, _MM_PERM_CDAB));
        return _mm_cvtsi128_si32(sum128) + bias;
      };

      [[maybe_unused]] auto m256_haddx4 = [](__m256i sum0, __m256i sum1, __m256i sum2, __m256i sum3, __m128i bias) -> __m128i {
        sum0 = _mm256_hadd_epi32(sum0, sum1);
        sum2 = _mm256_hadd_epi32(sum2, sum3);

        sum0 = _mm256_hadd_epi32(sum0, sum2);

        __m128i sum128lo = _mm256_castsi256_si128(sum0);
        __m128i sum128hi = _mm256_extracti128_si256(sum0, 1);

        return _mm_add_epi32(_mm_add_epi32(sum128lo, sum128hi), bias);
      };

      [[maybe_unused]] auto m256_add_dpbusd_epi32 = [=](__m256i& acc, __m256i a, __m256i b) {
#if defined (USE_VNNI)
        acc = _mm256_dpbusd_epi32(acc, a, b);
#else
        __m256i product0 = _mm256_maddubs_epi16(a, b);
        product0 = _mm256_madd_epi16(product0, kOnes256);
        acc = _mm256_add_epi32(acc, product0);
#endif
      };

      [[maybe_unused]] auto m256_add_dpbusd_epi32x2 = [=](__m256i& acc, __m256i a0, __m256i b0, __m256i a1, __m256i b1) {
#if defined (USE_VNNI)
        acc = _mm256_dpbusd_epi32(acc, a0, b0);
        acc = _mm256_dpbusd_epi32(acc, a1, b1);
#else
        __m256i product0 = _mm256_maddubs_epi16(a0, b0);
        __m256i product1 = _mm256_maddubs_epi16(a1, b1);
        product0 = _mm256_adds_epi16(product0, product1);
        product0 = _mm256_madd_epi16(product0, kOnes256);
        acc = _mm256_add_epi32(acc, product0);
#endif
      };

#endif

#if defined (USE_SSSE3)

      [[maybe_unused]] const __m128i kOnes128 = _mm_set1_epi16(1);

      [[maybe_unused]] auto m128_hadd = [](__m128i sum, int bias) -> int {
        sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, 0x4E)); //_MM_PERM_BADC
        sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, 0xB1)); //_MM_PERM_CDAB
        return _mm_cvtsi128_si32(sum) + bias;
      };

      [[maybe_unused]] auto m128_haddx4 = [](__m128i sum0, __m128i sum1, __m128i sum2, __m128i sum3, __m128i bias) -> __m128i {
        sum0 = _mm_hadd_epi32(sum0, sum1);
        sum2 = _mm_hadd_epi32(sum2, sum3);

        sum0 = _mm_hadd_epi32(sum0, sum2);

        return _mm_add_epi32(sum0, bias);
      };

      [[maybe_unused]] auto m128_add_dpbusd_epi32 = [=](__m128i& acc, __m128i a, __m128i b) {
        __m128i product0 = _mm_maddubs_epi16(a, b);
        product0 = _mm_madd_epi16(product0, kOnes128);
        acc = _mm_add_epi32(acc, product0);
      };

      [[maybe_unused]] auto m128_add_dpbusd_epi32x2 = [=](__m128i& acc, __m128i a0, __m128i b0, __m128i a1, __m128i b1) {
        __m128i product0 = _mm_maddubs_epi16(a0, b0);
        __m128i product1 = _mm_maddubs_epi16(a1, b1);
        product0 = _mm_adds_epi16(product0, product1);
        product0 = _mm_madd_epi16(product0, kOnes128);
        acc = _mm_add_epi32(acc, product0);
      };

#endif

#if defined (USE_AVX512)

      constexpr IndexType kNumChunks512 = kPaddedInputDimensions / (kSimdWidth * 2);
      constexpr IndexType kNumChunks256 = kPaddedInputDimensions / kSimdWidth;

      const auto output = reinterpret_cast<OutputType*>(buffer);

      // Since it takes 64 bytes to fill a zmm register, we cannot use AVX512
      // for the smaller affine transforms. Instead we fall back to an AVX2
      // implementation if kInputDimensions isn't a multiple of 64.
      // Note that this means that, for example, for kInputDimensions of 96 we
      // fall back to AVX2 even though the first 64 elements could be processed
      // with AVX512. Handling that case better would require mixing __m256 and
      // __m512 variables and handling more cases statically so as not to lose
      // performance. This should be revisited if such input dimensions are to
      // be considered.
      [[maybe_unused]] const auto input_vector512 = reinterpret_cast<const __m512i*>(input);
      [[maybe_unused]] const auto input_vector256 = reinterpret_cast<const __m256i*>(input);

      // kOutputDimensions is either 1 or a multiple of kSimdWidth
      // because it is also an input dimension (of the next layer).
      if constexpr (kOutputDimensions % 16 == 0 && kNumChunks256 == 1)
      {
          for (IndexType i = 0; i < kOutputDimensions; i += 16)
          {
              const IndexType offset01a = (i + 0) * kPaddedInputDimensions;
              const IndexType offset23a = (i + 2) * kPaddedInputDimensions;
              const IndexType offset45a = (i + 4) * kPaddedInputDimensions;
              const IndexType offset67a = (i + 6) * kPaddedInputDimensions;
              const IndexType offset01b = (i + 8) * kPaddedInputDimensions;
              const IndexType offset23b = (i + 10) * kPaddedInputDimensions;
              const IndexType offset45b = (i + 12) * kPaddedInputDimensions;
              const IndexType offset67b = (i + 14) * kPaddedInputDimensions;

              const __m512i bias = *reinterpret_cast<const __m512i*>(&biases_[i]);
              __m512i* outptr = reinterpret_cast<__m512i*>(&output[i]);

              __m512i sum01a = _mm512_setzero_si512();
              __m512i sum23a = _mm512_setzero_si512();
              __m512i sum45a = _mm512_setzero_si512();
              __m512i sum67a = _mm512_setzero_si512();
              __m512i sum01b = _mm512_setzero_si512();
              __m512i sum23b = _mm512_setzero_si512();
              __m512i sum45b = _mm512_setzero_si512();
              __m512i sum67b = _mm512_setzero_si512();

              const auto row01a = *reinterpret_cast<const __m512i*>(&weights_[offset01a]);
              const auto row23a = *reinterpret_cast<const __m512i*>(&weights_[offset23a]);
              const auto row45a = *reinterpret_cast<const __m512i*>(&weights_[offset45a]);
              const auto row67a = *reinterpret_cast<const __m512i*>(&weights_[offset67a]);
              const auto row01b = *reinterpret_cast<const __m512i*>(&weights_[offset01b]);
              const auto row23b = *reinterpret_cast<const __m512i*>(&weights_[offset23b]);
              const auto row45b = *reinterpret_cast<const __m512i*>(&weights_[offset45b]);
              const auto row67b = *reinterpret_cast<const __m512i*>(&weights_[offset67b]);
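
              // Here kNumChunks256 == 1, so one row is only 32 bytes and each
              // 64-byte load above picks up two consecutive rows of weights_.
              // Broadcasting the input into both 256-bit halves of one zmm
              // register lets a single dpbusd accumulate two rows at once,
              // i.e. 16 outputs per loop iteration.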
              const __m256i in256 = input_vector256[0];
              const __m512i in = _mm512_inserti64x4(_mm512_castsi256_si512(in256), in256, 1);

              m512_add_dpbusd_epi32(sum01a, in, row01a);
              m512_add_dpbusd_epi32(sum23a, in, row23a);
              m512_add_dpbusd_epi32(sum45a, in, row45a);
              m512_add_dpbusd_epi32(sum67a, in, row67a);
              m512_add_dpbusd_epi32(sum01b, in, row01b);
              m512_add_dpbusd_epi32(sum23b, in, row23b);
              m512_add_dpbusd_epi32(sum45b, in, row45b);
              m512_add_dpbusd_epi32(sum67b, in, row67b);

              *outptr = m512_hadd256x16(
                  sum01a, sum23a, sum45a, sum67a,
                  sum01b, sum23b, sum45b, sum67b, bias);
          }
      }
      else if constexpr (kOutputDimensions % 4 == 0)
      {
          for (IndexType i = 0; i < kOutputDimensions; i += 4)
          {
              const IndexType offset0 = (i + 0) * kPaddedInputDimensions;
              const IndexType offset1 = (i + 1) * kPaddedInputDimensions;
              const IndexType offset2 = (i + 2) * kPaddedInputDimensions;
              const IndexType offset3 = (i + 3) * kPaddedInputDimensions;

              const __m128i bias = *reinterpret_cast<const __m128i*>(&biases_[i]);
              __m128i* outptr = reinterpret_cast<__m128i*>(&output[i]);

              if constexpr (kPaddedInputDimensions % (kSimdWidth * 2) == 0)
              {
                  __m512i sum0 = _mm512_setzero_si512();
                  __m512i sum1 = _mm512_setzero_si512();
                  __m512i sum2 = _mm512_setzero_si512();
                  __m512i sum3 = _mm512_setzero_si512();

                  const auto row0 = reinterpret_cast<const __m512i*>(&weights_[offset0]);
                  const auto row1 = reinterpret_cast<const __m512i*>(&weights_[offset1]);
                  const auto row2 = reinterpret_cast<const __m512i*>(&weights_[offset2]);
                  const auto row3 = reinterpret_cast<const __m512i*>(&weights_[offset3]);

                  int j = 0;
                  if (!canSaturate16x4[i / 4])
                  {
                      for (; j < (int)kNumChunks512 - 1; j += 2)
                      {
                          const __m512i in0 = input_vector512[j];
                          const __m512i in1 = input_vector512[j + 1];

                          m512_add_dpbusd_epi32x2(sum0, in0, row0[j], in1, row0[j + 1]);
                          m512_add_dpbusd_epi32x2(sum1, in0, row1[j], in1, row1[j + 1]);
                          m512_add_dpbusd_epi32x2(sum2, in0, row2[j], in1, row2[j + 1]);
                          m512_add_dpbusd_epi32x2(sum3, in0, row3[j], in1, row3[j + 1]);
                      }
                  }
                  for (; j < (int)kNumChunks512; ++j)
                  {
                      const __m512i in = input_vector512[j];

                      m512_add_dpbusd_epi32(sum0, in, row0[j]);
                      m512_add_dpbusd_epi32(sum1, in, row1[j]);
                      m512_add_dpbusd_epi32(sum2, in, row2[j]);
                      m512_add_dpbusd_epi32(sum3, in, row3[j]);
                  }

                  *outptr = m512_haddx4(sum0, sum1, sum2, sum3, bias);
              }
              else
              {
                  __m256i sum0 = _mm256_setzero_si256();
                  __m256i sum1 = _mm256_setzero_si256();
                  __m256i sum2 = _mm256_setzero_si256();
                  __m256i sum3 = _mm256_setzero_si256();

                  const auto row0 = reinterpret_cast<const __m256i*>(&weights_[offset0]);
                  const auto row1 = reinterpret_cast<const __m256i*>(&weights_[offset1]);
                  const auto row2 = reinterpret_cast<const __m256i*>(&weights_[offset2]);
                  const auto row3 = reinterpret_cast<const __m256i*>(&weights_[offset3]);

                  for (IndexType j = 0; j < kNumChunks256; ++j)
                  {
                      const __m256i in = input_vector256[j];

                      m256_add_dpbusd_epi32(sum0, in, row0[j]);
                      m256_add_dpbusd_epi32(sum1, in, row1[j]);
                      m256_add_dpbusd_epi32(sum2, in, row2[j]);
                      m256_add_dpbusd_epi32(sum3, in, row3[j]);
                  }

                  *outptr = m256_haddx4(sum0, sum1, sum2, sum3, bias);
              }
          }
      }
      else if constexpr (kOutputDimensions == 1)
      {
          if constexpr (kPaddedInputDimensions % (kSimdWidth * 2) == 0)
          {
              __m512i sum0 = _mm512_setzero_si512();

              const auto row0 = reinterpret_cast<const __m512i*>(&weights_[0]);

              for (IndexType j = 0; j < kNumChunks512; ++j)
              {
                  const __m512i in = input_vector512[j];

                  m512_add_dpbusd_epi32(sum0, in, row0[j]);
              }

              output[0] = m512_hadd(sum0, biases_[0]);
          }
          else
          {
              __m256i sum0 = _mm256_setzero_si256();

              const auto row0 = reinterpret_cast<const __m256i*>(&weights_[0]);

              for (IndexType j = 0; j < kNumChunks256; ++j)
              {
                  const __m256i in = input_vector256[j];

                  m256_add_dpbusd_epi32(sum0, in, row0[j]);
              }

              output[0] = m256_hadd(sum0, biases_[0]);
          }
      }
      else
      {
          // This case can never happen because kOutputDimensions
          // is always 1 or a multiple of kSimdWidth.
          assert(false);
      }

#elif defined (USE_AVX2)

      constexpr IndexType kNumChunks = kPaddedInputDimensions / kSimdWidth;

      const auto output = reinterpret_cast<OutputType*>(buffer);
      const auto input_vector = reinterpret_cast<const __m256i*>(input);

      // kOutputDimensions is either 1 or a multiple of kSimdWidth
      // because it is also an input dimension (of the next layer).
      if constexpr (kOutputDimensions % 4 == 0)
      {
          for (IndexType i = 0; i < kOutputDimensions; i += 4)
          {
              const IndexType offset0 = (i + 0) * kPaddedInputDimensions;
              const IndexType offset1 = (i + 1) * kPaddedInputDimensions;
              const IndexType offset2 = (i + 2) * kPaddedInputDimensions;
              const IndexType offset3 = (i + 3) * kPaddedInputDimensions;

              const __m128i bias = *reinterpret_cast<const __m128i*>(&biases_[i]);
              __m128i* outptr = reinterpret_cast<__m128i*>(&output[i]);

              __m256i sum0 = _mm256_setzero_si256();
              __m256i sum1 = _mm256_setzero_si256();
              __m256i sum2 = _mm256_setzero_si256();
              __m256i sum3 = _mm256_setzero_si256();

              const auto row0 = reinterpret_cast<const __m256i*>(&weights_[offset0]);
              const auto row1 = reinterpret_cast<const __m256i*>(&weights_[offset1]);
              const auto row2 = reinterpret_cast<const __m256i*>(&weights_[offset2]);
              const auto row3 = reinterpret_cast<const __m256i*>(&weights_[offset3]);

              int j = 0;
              if (!canSaturate16x4[i / 4])
              {
                  for (; j < (int)kNumChunks - 1; j += 2)
                  {
                      const __m256i in0 = input_vector[j];
                      const __m256i in1 = input_vector[j + 1];

                      m256_add_dpbusd_epi32x2(sum0, in0, row0[j], in1, row0[j + 1]);
                      m256_add_dpbusd_epi32x2(sum1, in0, row1[j], in1, row1[j + 1]);
                      m256_add_dpbusd_epi32x2(sum2, in0, row2[j], in1, row2[j + 1]);
                      m256_add_dpbusd_epi32x2(sum3, in0, row3[j], in1, row3[j + 1]);
                  }
              }
              for (; j < (int)kNumChunks; ++j)
              {
                  const __m256i in = input_vector[j];

                  m256_add_dpbusd_epi32(sum0, in, row0[j]);
                  m256_add_dpbusd_epi32(sum1, in, row1[j]);
                  m256_add_dpbusd_epi32(sum2, in, row2[j]);
                  m256_add_dpbusd_epi32(sum3, in, row3[j]);
              }

              *outptr = m256_haddx4(sum0, sum1, sum2, sum3, bias);
          }
      }
      else if constexpr (kOutputDimensions == 1)
      {
          __m256i sum0 = _mm256_setzero_si256();

          const auto row0 = reinterpret_cast<const __m256i*>(&weights_[0]);

          for (IndexType j = 0; j < kNumChunks; ++j)
          {
              const __m256i in = input_vector[j];

              m256_add_dpbusd_epi32(sum0, in, row0[j]);
          }

          output[0] = m256_hadd(sum0, biases_[0]);
      }
      else
      {
          // This case can never happen because kOutputDimensions
          // is always 1 or a multiple of kSimdWidth.
          assert(false);
      }

#elif defined (USE_SSSE3)

      constexpr IndexType kNumChunks = kPaddedInputDimensions / kSimdWidth;

      auto output = reinterpret_cast<OutputType*>(buffer);
      const auto input_vector = reinterpret_cast<const __m128i*>(input);

      // kOutputDimensions is either 1 or a multiple of kSimdWidth
      // because it is also an input dimension (of the next layer).
      if constexpr (kOutputDimensions % 4 == 0)
      {
          for (IndexType i = 0; i < kOutputDimensions; i += 4)
          {
              const IndexType offset0 = (i + 0) * kPaddedInputDimensions;
              const IndexType offset1 = (i + 1) * kPaddedInputDimensions;
              const IndexType offset2 = (i + 2) * kPaddedInputDimensions;
              const IndexType offset3 = (i + 3) * kPaddedInputDimensions;

              const __m128i bias = *reinterpret_cast<const __m128i*>(&biases_[i]);
              __m128i* outptr = reinterpret_cast<__m128i*>(&output[i]);

              __m128i sum0 = _mm_setzero_si128();
              __m128i sum1 = _mm_setzero_si128();
              __m128i sum2 = _mm_setzero_si128();
              __m128i sum3 = _mm_setzero_si128();

              const auto row0 = reinterpret_cast<const __m128i*>(&weights_[offset0]);
              const auto row1 = reinterpret_cast<const __m128i*>(&weights_[offset1]);
              const auto row2 = reinterpret_cast<const __m128i*>(&weights_[offset2]);
              const auto row3 = reinterpret_cast<const __m128i*>(&weights_[offset3]);

              int j = 0;
              if (!canSaturate16x4[i / 4])
              {
                  for (; j < (int)kNumChunks - 1; j += 2)
                  {
                      const __m128i in0 = input_vector[j];
                      const __m128i in1 = input_vector[j + 1];

                      m128_add_dpbusd_epi32x2(sum0, in0, row0[j], in1, row0[j + 1]);
                      m128_add_dpbusd_epi32x2(sum1, in0, row1[j], in1, row1[j + 1]);
                      m128_add_dpbusd_epi32x2(sum2, in0, row2[j], in1, row2[j + 1]);
                      m128_add_dpbusd_epi32x2(sum3, in0, row3[j], in1, row3[j + 1]);
                  }
              }
              for (; j < (int)kNumChunks; ++j)
              {
                  const __m128i in = input_vector[j];

                  m128_add_dpbusd_epi32(sum0, in, row0[j]);
                  m128_add_dpbusd_epi32(sum1, in, row1[j]);
                  m128_add_dpbusd_epi32(sum2, in, row2[j]);
                  m128_add_dpbusd_epi32(sum3, in, row3[j]);
              }

              *outptr = m128_haddx4(sum0, sum1, sum2, sum3, bias);
          }
      }
      else if constexpr (kOutputDimensions == 1)
      {
          __m128i sum0 = _mm_setzero_si128();

          const auto row0 = reinterpret_cast<const __m128i*>(&weights_[0]);

          for (int j = 0; j < (int)kNumChunks; ++j)
          {
              const __m128i in = input_vector[j];

              m128_add_dpbusd_epi32(sum0, in, row0[j]);
          }

          output[0] = m128_hadd(sum0, biases_[0]);
      }
      else
      {
          // This case can never happen because kOutputDimensions
          // is always 1 or a multiple of kSimdWidth.
          assert(false);
      }

#else

      // Use old implementation for the other architectures.

      auto output = reinterpret_cast<OutputType*>(buffer);

#if defined(USE_SSE2)
      constexpr IndexType kNumChunks = kPaddedInputDimensions / kSimdWidth;
#ifndef USE_SSSE3
      const __m128i kZeros = _mm_setzero_si128();
#else
      const __m128i kOnes = _mm_set1_epi16(1);
#endif
      const auto input_vector = reinterpret_cast<const __m128i*>(input);

#elif defined(USE_MMX)
      constexpr IndexType kNumChunks = kPaddedInputDimensions / kSimdWidth;
      const __m64 kZeros = _mm_setzero_si64();
      const auto input_vector = reinterpret_cast<const __m64*>(input);

#elif defined(USE_NEON)
      constexpr IndexType kNumChunks = kPaddedInputDimensions / kSimdWidth;
      const auto input_vector = reinterpret_cast<const int8x8_t*>(input);
#endif

      for (IndexType i = 0; i < kOutputDimensions; ++i) {
        const IndexType offset = i * kPaddedInputDimensions;

#if defined(USE_SSE2)
        __m128i sum_lo = _mm_cvtsi32_si128(biases_[i]);
        __m128i sum_hi = kZeros;
        const auto row = reinterpret_cast<const __m128i*>(&weights_[offset]);
        for (IndexType j = 0; j < kNumChunks; ++j) {
          __m128i row_j = _mm_load_si128(&row[j]);
          __m128i input_j = _mm_load_si128(&input_vector[j]);
          __m128i extended_row_lo = _mm_srai_epi16(_mm_unpacklo_epi8(row_j, row_j), 8);
          __m128i extended_row_hi = _mm_srai_epi16(_mm_unpackhi_epi8(row_j, row_j), 8);
          __m128i extended_input_lo = _mm_unpacklo_epi8(input_j, kZeros);
          __m128i extended_input_hi = _mm_unpackhi_epi8(input_j, kZeros);
          __m128i product_lo = _mm_madd_epi16(extended_row_lo, extended_input_lo);
          __m128i product_hi = _mm_madd_epi16(extended_row_hi, extended_input_hi);
          sum_lo = _mm_add_epi32(sum_lo, product_lo);
          sum_hi = _mm_add_epi32(sum_hi, product_hi);
        }
        __m128i sum = _mm_add_epi32(sum_lo, sum_hi);
        __m128i sum_high_64 = _mm_shuffle_epi32(sum, _MM_SHUFFLE(1, 0, 3, 2));
        sum = _mm_add_epi32(sum, sum_high_64);
        __m128i sum_second_32 = _mm_shufflelo_epi16(sum, _MM_SHUFFLE(1, 0, 3, 2));
        sum = _mm_add_epi32(sum, sum_second_32);
        output[i] = _mm_cvtsi128_si32(sum);

#elif defined(USE_MMX)
        __m64 sum_lo = _mm_cvtsi32_si64(biases_[i]);
        __m64 sum_hi = kZeros;
        const auto row = reinterpret_cast<const __m64*>(&weights_[offset]);
        for (IndexType j = 0; j < kNumChunks; ++j) {
          __m64 row_j = row[j];
          __m64 input_j = input_vector[j];
          __m64 extended_row_lo = _mm_srai_pi16(_mm_unpacklo_pi8(row_j, row_j), 8);
          __m64 extended_row_hi = _mm_srai_pi16(_mm_unpackhi_pi8(row_j, row_j), 8);
          __m64 extended_input_lo = _mm_unpacklo_pi8(input_j, kZeros);
          __m64 extended_input_hi = _mm_unpackhi_pi8(input_j, kZeros);
          __m64 product_lo = _mm_madd_pi16(extended_row_lo, extended_input_lo);
          __m64 product_hi = _mm_madd_pi16(extended_row_hi, extended_input_hi);
          sum_lo = _mm_add_pi32(sum_lo, product_lo);
          sum_hi = _mm_add_pi32(sum_hi, product_hi);
        }
        __m64 sum = _mm_add_pi32(sum_lo, sum_hi);
        sum = _mm_add_pi32(sum, _mm_unpackhi_pi32(sum, sum));
        output[i] = _mm_cvtsi64_si32(sum);

#elif defined(USE_NEON)
        int32x4_t sum = {biases_[i]};
        const auto row = reinterpret_cast<const int8x8_t*>(&weights_[offset]);
        for (IndexType j = 0; j < kNumChunks; ++j) {
          int16x8_t product = vmull_s8(input_vector[j * 2], row[j * 2]);
          product = vmlal_s8(product, input_vector[j * 2 + 1], row[j * 2 + 1]);
          sum = vpadalq_s16(sum, product);
        }
        output[i] = sum[0] + sum[1] + sum[2] + sum[3];

#else
        OutputType sum = biases_[i];
        for (IndexType j = 0; j < kInputDimensions; ++j) {
          sum += weights_[offset + j] * input[j];
        }
        output[i] = sum;
#endif

      }
#if defined(USE_MMX)
      _mm_empty();
#endif

#endif

      return output;
    }

   private:
    using BiasType = OutputType;
    using WeightType = std::int8_t;

    PreviousLayer previous_layer_;

    alignas(kCacheLineSize) BiasType biases_[kOutputDimensions];
    alignas(kCacheLineSize) WeightType weights_[kOutputDimensions * kPaddedInputDimensions];
    // canSaturate16[i] records whether row i of weights_ may overflow the
    // 16-bit intermediate sums of the x2 helpers; canSaturate16x4 aliases the
    // same bytes so Propagate() can test four flags with one 32-bit load.
    union {
        uint32_t canSaturate16x4[(kOutputDimensions + 3) / 4];
        bool canSaturate16[kOutputDimensions];
    };
  };

}  // namespace Eval::NNUE::Layers

#endif // #ifndef NNUE_LAYERS_AFFINE_TRANSFORM_H_INCLUDED