]> git.sesse.net Git - stockfish/blob - src/nnue/layers/affine_transform.h
Fix compilation after recent merge.
[stockfish] / src / nnue / layers / affine_transform.h
1 /*
2   Stockfish, a UCI chess playing engine derived from Glaurung 2.1
3   Copyright (C) 2004-2020 The Stockfish developers (see AUTHORS file)
4
5   Stockfish is free software: you can redistribute it and/or modify
6   it under the terms of the GNU General Public License as published by
7   the Free Software Foundation, either version 3 of the License, or
8   (at your option) any later version.
9
10   Stockfish is distributed in the hope that it will be useful,
11   but WITHOUT ANY WARRANTY; without even the implied warranty of
12   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13   GNU General Public License for more details.
14
15   You should have received a copy of the GNU General Public License
16   along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 */
18
19 // Definition of layer AffineTransform of NNUE evaluation function
20
21 #ifndef NNUE_LAYERS_AFFINE_TRANSFORM_H_INCLUDED
22 #define NNUE_LAYERS_AFFINE_TRANSFORM_H_INCLUDED
23
24 #include <iostream>
25 #include "../nnue_common.h"
26
27 namespace Eval::NNUE::Layers {
28
29   // Affine transformation layer
30   template <typename PreviousLayer, IndexType OutputDimensions>
31   class AffineTransform {
32    public:
33     // Input/output type
34     using InputType = typename PreviousLayer::OutputType;
35     using OutputType = std::int32_t;
36     static_assert(std::is_same<InputType, std::uint8_t>::value, "");
37
38     // Number of input/output dimensions
39     static constexpr IndexType kInputDimensions =
40         PreviousLayer::kOutputDimensions;
41     static constexpr IndexType kOutputDimensions = OutputDimensions;
42     static constexpr IndexType kPaddedInputDimensions =
43         CeilToMultiple<IndexType>(kInputDimensions, kMaxSimdWidth);
44
45     // Size of forward propagation buffer used in this layer
46     static constexpr std::size_t kSelfBufferSize =
47         CeilToMultiple(kOutputDimensions * sizeof(OutputType), kCacheLineSize);
48
49     // Size of the forward propagation buffer used from the input layer to this layer
50     static constexpr std::size_t kBufferSize =
51         PreviousLayer::kBufferSize + kSelfBufferSize;
52
53     // Hash value embedded in the evaluation file
54     static constexpr std::uint32_t GetHashValue() {
55       std::uint32_t hash_value = 0xCC03DAE4u;
56       hash_value += kOutputDimensions;
57       hash_value ^= PreviousLayer::GetHashValue() >> 1;
58       hash_value ^= PreviousLayer::GetHashValue() << 31;
59       return hash_value;
60     }
61
62    // Read network parameters
63     bool ReadParameters(std::istream& stream) {
64       if (!previous_layer_.ReadParameters(stream)) return false;
65       for (std::size_t i = 0; i < kOutputDimensions; ++i)
66         biases_[i] = read_little_endian<BiasType>(stream);
67       for (std::size_t i = 0; i < kOutputDimensions * kPaddedInputDimensions; ++i)
68         weights_[i] = read_little_endian<WeightType>(stream);
69
70 #if defined (USE_SSSE3)
71       // Determine if quadruplets of weight and input products can be summed using 16bits
72       // without saturation. We assume worst case combinations of 0 and 127 for all inputs.
73       if (!stream.fail())
74       {
75           auto can_saturate = [](const WeightType* w, int idx[4]) {
76               int pSum = 0, nSum = 0;
77               for (int p = 0; p < 4; ++p)
78                   if (w[idx[p]] > 0)
79                       pSum += w[idx[p]];
80                   else
81                       nSum += w[idx[p]];
82
83               return pSum > 258 || nSum < -258;
84           };
85
86           for (IndexType i = 0; i < kOutputDimensions; ++i)
87           {
88               canSaturate16[i] = false;
89               const WeightType* w = &weights_[i * kPaddedInputDimensions];
90 #if defined (USE_AVX512)
91               for (IndexType j = 0; j < (kPaddedInputDimensions & ~127) && !canSaturate16[i]; j += 128)
92                   for (int k = 0; k < 64 && !canSaturate16[i]; k += 2)
93                   {
94                       int spacing[4] = { 0, 1, 64, 65 };
95                       canSaturate16[i] = can_saturate(&w[j + k], spacing);
96                   }
97 #elif defined (USE_AVX2)
98               for (IndexType j = 0; j < (kPaddedInputDimensions & ~63) && !canSaturate16[i]; j += 64)
99                   for (int k = 0; k < 32 && !canSaturate16[i]; k += 2)
100                   {
101                       int spacing[4] = { 0, 1, 32, 33 };
102                       canSaturate16[i] = can_saturate(&w[j + k], spacing);
103                   }
104 #elif defined (USE_SSSE3)
105               for (IndexType j = 0; j < (kPaddedInputDimensions & ~31) && !canSaturate16[i]; j += 32)
106                   for (int k = 0; k < 16 && !canSaturate16[i]; k += 2)
107                   {
108                       int spacing[4] = { 0, 1, 16, 17 };
109                       canSaturate16[i] = can_saturate(&w[j + k], spacing);
110                   }
111 #endif
112           }
113       }
114 #endif
115
116       return !stream.fail();
117     }
118
119     // Forward propagation
120     const OutputType* Propagate(
121         const TransformedFeatureType* transformed_features, char* buffer) const {
122       const auto input = previous_layer_.Propagate(
123           transformed_features, buffer + kSelfBufferSize);
124
125 #if defined (USE_AVX512)
126
127       [[maybe_unused]] const __m512i kOnes512 = _mm512_set1_epi16(1);
128
129       [[maybe_unused]] auto m512_hadd = [](__m512i sum, int bias) -> int {
130         return _mm512_reduce_add_epi32(sum) + bias;
131       };
132
133       // This function takes
134       //   sum0 = [xmm0a, xmm0b, xmm0c, xmm0d]
135       //   sum1 = [xmm1a, xmm1b, xmm1c, xmm1d]
136       //   sum2 = [xmm2a, xmm2b, xmm2c, xmm2d]
137       //   sum3 = [xmm3a, xmm3b, xmm3c, xmm3d]
138       // and returns
139       //   ret = [
140       //     reduce_add_epi32(xmm0a), reduce_add_epi32(xmm1a), reduce_add_epi32(xmm2a), reduce_add_epi32(xmm3a),
141       //     reduce_add_epi32(xmm0b), reduce_add_epi32(xmm1b), reduce_add_epi32(xmm2b), reduce_add_epi32(xmm3b),
142       //     reduce_add_epi32(xmm0c), reduce_add_epi32(xmm1c), reduce_add_epi32(xmm2c), reduce_add_epi32(xmm3c),
143       //     reduce_add_epi32(xmm0d), reduce_add_epi32(xmm1d), reduce_add_epi32(xmm2d), reduce_add_epi32(xmm3d)
144       //   ]
145       [[maybe_unused]] auto m512_hadd128x16_interleave = [](
146         __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3) -> __m512i {
147
148         __m512i sum01a = _mm512_unpacklo_epi32(sum0, sum1);
149         __m512i sum01b = _mm512_unpackhi_epi32(sum0, sum1);
150
151         __m512i sum23a = _mm512_unpacklo_epi32(sum2, sum3);
152         __m512i sum23b = _mm512_unpackhi_epi32(sum2, sum3);
153
154         __m512i sum01 = _mm512_add_epi32(sum01a, sum01b);
155         __m512i sum23 = _mm512_add_epi32(sum23a, sum23b);
156
157         __m512i sum0123a = _mm512_unpacklo_epi64(sum01, sum23);
158         __m512i sum0123b = _mm512_unpackhi_epi64(sum01, sum23);
159
160         return _mm512_add_epi32(sum0123a, sum0123b);
161       };
162
163       [[maybe_unused]] auto m512_haddx4 = [m512_hadd128x16_interleave](
164         __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3, __m128i bias) -> __m128i {
165
166         __m512i sum = m512_hadd128x16_interleave(sum0, sum1, sum2, sum3);
167
168         __m256i sum256lo = _mm512_castsi512_si256(sum);
169         __m256i sum256hi = _mm512_extracti64x4_epi64(sum, 1);
170
171         sum256lo = _mm256_add_epi32(sum256lo, sum256hi);
172
173         __m128i sum128lo = _mm256_castsi256_si128(sum256lo);
174         __m128i sum128hi = _mm256_extracti128_si256(sum256lo, 1);
175
176         return _mm_add_epi32(_mm_add_epi32(sum128lo, sum128hi), bias);
177       };
178
179       [[maybe_unused]] auto m512_haddx8 = [m512_hadd128x16_interleave](
180         __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3,
181         __m512i sum4, __m512i sum5, __m512i sum6, __m512i sum7, __m256i bias) -> __m256i {
182
183         __m512i suma = m512_hadd128x16_interleave(sum0, sum1, sum2, sum3);
184         __m512i sumb = m512_hadd128x16_interleave(sum4, sum5, sum6, sum7);
185
186         __m512i indices0 = _mm512_setr_epi64(0, 1, 8, 9, 4, 5, 12, 13);
187         __m512i indices1 = _mm512_setr_epi64(2, 3, 10, 11, 6, 7, 14, 15);
188         __m512i x = _mm512_add_epi32(
189           _mm512_permutex2var_epi64(suma, indices0, sumb),
190           _mm512_permutex2var_epi64(suma, indices1, sumb));
191
192         __m256i sum256lo = _mm512_castsi512_si256(x);
193         __m256i sum256hi = _mm512_extracti64x4_epi64(x, 1);
194
195         return _mm256_add_epi32(_mm256_add_epi32(sum256lo, sum256hi), bias);
196       };
197
198       [[maybe_unused]] auto m512_hadd256x8 =[m512_hadd128x16_interleave](
199         __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3, __m256i bias) -> __m256i {
200
201         __m512i sum = m512_hadd128x16_interleave(sum0, sum1, sum2, sum3);
202
203         __m512i indices = _mm512_setr_epi32(
204           0, 4, 8, 12, 2, 6, 10, 14,
205           1, 5, 9, 13, 3, 7, 11, 15);
206         sum = _mm512_permutexvar_epi32(indices, sum);
207
208         __m256i sum256lo = _mm512_castsi512_si256(sum);
209         __m256i sum256hi = _mm512_extracti64x4_epi64(sum, 1);
210
211         return _mm256_add_epi32(_mm256_hadd_epi32(sum256lo, sum256hi), bias);
212       };
213
214       [[maybe_unused]] auto m512_hadd256x16 = [m512_hadd128x16_interleave](
215         __m512i sum0, __m512i sum1, __m512i sum2, __m512i sum3,
216         __m512i sum4, __m512i sum5, __m512i sum6, __m512i sum7, __m512i bias) -> __m512i {
217
218         __m512i suma = m512_hadd128x16_interleave(sum0, sum1, sum2, sum3);
219         __m512i sumb = m512_hadd128x16_interleave(sum4, sum5, sum6, sum7);
220
221         __m512i indices0 = _mm512_setr_epi64(0, 1, 8, 9, 4, 5, 12, 13);
222         __m512i indices1 = _mm512_setr_epi64(2, 3, 10, 11, 6, 7, 14, 15);
223         __m512i x = _mm512_add_epi32(
224           _mm512_permutex2var_epi64(suma, indices0, sumb),
225           _mm512_permutex2var_epi64(suma, indices1, sumb));
226
227         __m512i indices = _mm512_setr_epi32(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15);
228         return _mm512_add_epi32(_mm512_permutexvar_epi32(indices, x), bias);
229       };
230
231       [[maybe_unused]] auto m512_add_dpbusd_epi32 = [=](__m512i& acc, __m512i a, __m512i b) {
232 #if defined (USE_VNNI)
233         acc = _mm512_dpbusd_epi32(acc, a, b);
234 #else
235         __m512i product0 = _mm512_maddubs_epi16(a, b);
236         product0 = _mm512_madd_epi16(product0, kOnes512);
237         acc = _mm512_add_epi32(acc, product0);
238 #endif
239       };
240
241       [[maybe_unused]] auto m512_add_dpbusd_epi32x2 = [=](__m512i& acc, __m512i a0, __m512i b0, __m512i a1, __m512i b1) {
242 #if defined (USE_VNNI)
243         acc = _mm512_dpbusd_epi32(acc, a0, b0);
244         acc = _mm512_dpbusd_epi32(acc, a1, b1);
245 #else
246         __m512i product0 = _mm512_maddubs_epi16(a0, b0);
247         __m512i product1 = _mm512_maddubs_epi16(a1, b1);
248         product0 = _mm512_adds_epi16(product0, product1);
249         product0 = _mm512_madd_epi16(product0, kOnes512);
250         acc = _mm512_add_epi32(acc, product0);
251 #endif
252       };
253
254 #endif
255 #if defined (USE_AVX2)
256
257       [[maybe_unused]] const __m256i kOnes256 = _mm256_set1_epi16(1);
258
259       [[maybe_unused]] auto m256_hadd = [](__m256i sum, int bias) -> int {
260         __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(sum), _mm256_extracti128_si256(sum, 1));
261         sum128 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, _MM_PERM_BADC));
262         sum128 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, _MM_PERM_CDAB));
263         return _mm_cvtsi128_si32(sum128) + bias;
264       };
265
266       [[maybe_unused]] auto m256_haddx4 = [](__m256i sum0, __m256i sum1, __m256i sum2, __m256i sum3, __m128i bias) -> __m128i {
267         sum0 = _mm256_hadd_epi32(sum0, sum1);
268         sum2 = _mm256_hadd_epi32(sum2, sum3);
269
270         sum0 = _mm256_hadd_epi32(sum0, sum2);
271
272         __m128i sum128lo = _mm256_castsi256_si128(sum0);
273         __m128i sum128hi = _mm256_extracti128_si256(sum0, 1);
274
275         return _mm_add_epi32(_mm_add_epi32(sum128lo, sum128hi), bias);
276       };
277
278       [[maybe_unused]] auto m256_add_dpbusd_epi32 = [=](__m256i& acc, __m256i a, __m256i b) {
279 #if defined (USE_VNNI)
280         acc = _mm256_dpbusd_epi32(acc, a, b);
281 #else
282         __m256i product0 = _mm256_maddubs_epi16(a, b);
283         product0 = _mm256_madd_epi16(product0, kOnes256);
284         acc = _mm256_add_epi32(acc, product0);
285 #endif
286       };
287
288       [[maybe_unused]] auto m256_add_dpbusd_epi32x2 = [=](__m256i& acc, __m256i a0, __m256i b0, __m256i a1, __m256i b1) {
289 #if defined (USE_VNNI)
290         acc = _mm256_dpbusd_epi32(acc, a0, b0);
291         acc = _mm256_dpbusd_epi32(acc, a1, b1);
292 #else
293         __m256i product0 = _mm256_maddubs_epi16(a0, b0);
294         __m256i product1 = _mm256_maddubs_epi16(a1, b1);
295         product0 = _mm256_adds_epi16(product0, product1);
296         product0 = _mm256_madd_epi16(product0, kOnes256);
297         acc = _mm256_add_epi32(acc, product0);
298 #endif
299       };
300
301 #endif
302
303 #if defined (USE_SSSE3)
304
305       [[maybe_unused]] const __m128i kOnes128 = _mm_set1_epi16(1);
306
307       [[maybe_unused]] auto m128_hadd = [](__m128i sum, int bias) -> int {
308         sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, 0x4E)); //_MM_PERM_BADC
309         sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, 0xB1)); //_MM_PERM_CDAB
310         return _mm_cvtsi128_si32(sum) + bias;
311       };
312
313       [[maybe_unused]] auto m128_haddx4 = [](__m128i sum0, __m128i sum1, __m128i sum2, __m128i sum3, __m128i bias) -> __m128i {
314         sum0 = _mm_hadd_epi32(sum0, sum1);
315         sum2 = _mm_hadd_epi32(sum2, sum3);
316
317         sum0 = _mm_hadd_epi32(sum0, sum2);
318
319         return _mm_add_epi32(sum0, bias);
320       };
321
322       [[maybe_unused]] auto m128_add_dpbusd_epi32 = [=](__m128i& acc, __m128i a, __m128i b) {
323         __m128i product0 = _mm_maddubs_epi16(a, b);
324         product0 = _mm_madd_epi16(product0, kOnes128);
325         acc = _mm_add_epi32(acc, product0);
326       };
327
328       [[maybe_unused]] auto m128_add_dpbusd_epi32x2 = [=](__m128i& acc, __m128i a0, __m128i b0, __m128i a1, __m128i b1) {
329         __m128i product0 = _mm_maddubs_epi16(a0, b0);
330         __m128i product1 = _mm_maddubs_epi16(a1, b1);
331         product0 = _mm_adds_epi16(product0, product1);
332         product0 = _mm_madd_epi16(product0, kOnes128);
333         acc = _mm_add_epi32(acc, product0);
334       };
335
336 #endif
337
338 #if defined (USE_AVX512)
339
340       constexpr IndexType kNumChunks512 = kPaddedInputDimensions / (kSimdWidth * 2);
341       constexpr IndexType kNumChunks256 = kPaddedInputDimensions / kSimdWidth;
342
343       const auto output = reinterpret_cast<OutputType*>(buffer);
344
345       // Since to saturate a zmm register it takes 64 bytes we
346       // cannot use AVX512 for the smaller affine transforms.
347       // Instead we fallback to a AVX2 implementation if the
348       // kInputDimensions isn't a multiple of 64.
349       // Note that this means that for example for
350       // kInputDimensions of 96 we fallback to AVX2 even though
351       // the first 64 elements could be processed with AVX512.
352       // This is caused by mixing the __m256 and __m512 variables
353       // required to better handle that case and it would
354       // require handling more cases statically not to lose performance.
355       // This should be revisited if such input dimensions are to be considered.
356       [[maybe_unused]] const auto input_vector512 = reinterpret_cast<const __m512i*>(input);
357       [[maybe_unused]] const auto input_vector256 = reinterpret_cast<const __m256i*>(input);
358
359       // kOutputDimensions is either 1 or a multiple of kSimdWidth
360       // because then it is also an input dimension.
361       if constexpr (kOutputDimensions % 16 == 0 && kNumChunks256 == 1)
362       {
363         for (IndexType i = 0; i < kOutputDimensions; i += 16)
364         {
365           const IndexType offset01a = (i + 0) * kPaddedInputDimensions;
366           const IndexType offset23a = (i + 2) * kPaddedInputDimensions;
367           const IndexType offset45a = (i + 4) * kPaddedInputDimensions;
368           const IndexType offset67a = (i + 6) * kPaddedInputDimensions;
369           const IndexType offset01b = (i + 8) * kPaddedInputDimensions;
370           const IndexType offset23b = (i + 10) * kPaddedInputDimensions;
371           const IndexType offset45b = (i + 12) * kPaddedInputDimensions;
372           const IndexType offset67b = (i + 14) * kPaddedInputDimensions;
373
374           const __m512i bias = *reinterpret_cast<const __m512i*>(&biases_[i]);
375           __m512i* outptr = reinterpret_cast<__m512i*>(&output[i]);
376
377           __m512i sum01a = _mm512_setzero_si512();
378           __m512i sum23a = _mm512_setzero_si512();
379           __m512i sum45a = _mm512_setzero_si512();
380           __m512i sum67a = _mm512_setzero_si512();
381           __m512i sum01b = _mm512_setzero_si512();
382           __m512i sum23b = _mm512_setzero_si512();
383           __m512i sum45b = _mm512_setzero_si512();
384           __m512i sum67b = _mm512_setzero_si512();
385
386           const auto row01a = *reinterpret_cast<const __m512i*>(&weights_[offset01a]);
387           const auto row23a = *reinterpret_cast<const __m512i*>(&weights_[offset23a]);
388           const auto row45a = *reinterpret_cast<const __m512i*>(&weights_[offset45a]);
389           const auto row67a = *reinterpret_cast<const __m512i*>(&weights_[offset67a]);
390           const auto row01b = *reinterpret_cast<const __m512i*>(&weights_[offset01b]);
391           const auto row23b = *reinterpret_cast<const __m512i*>(&weights_[offset23b]);
392           const auto row45b = *reinterpret_cast<const __m512i*>(&weights_[offset45b]);
393           const auto row67b = *reinterpret_cast<const __m512i*>(&weights_[offset67b]);
394
395           const __m256i in256 = input_vector256[0];
396           const __m512i in = _mm512_inserti64x4(_mm512_castsi256_si512(in256), in256, 1);
397
398           m512_add_dpbusd_epi32(sum01a, in, row01a);
399           m512_add_dpbusd_epi32(sum23a, in, row23a);
400           m512_add_dpbusd_epi32(sum45a, in, row45a);
401           m512_add_dpbusd_epi32(sum67a, in, row67a);
402           m512_add_dpbusd_epi32(sum01b, in, row01b);
403           m512_add_dpbusd_epi32(sum23b, in, row23b);
404           m512_add_dpbusd_epi32(sum45b, in, row45b);
405           m512_add_dpbusd_epi32(sum67b, in, row67b);
406
407           *outptr = m512_hadd256x16(
408             sum01a, sum23a, sum45a, sum67a,
409             sum01b, sum23b, sum45b, sum67b, bias);
410         }
411       }
412       else if constexpr (kOutputDimensions % 4 == 0)
413       {
414         for (IndexType i = 0; i < kOutputDimensions; i += 4)
415         {
416           const IndexType offset0 = (i + 0) * kPaddedInputDimensions;
417           const IndexType offset1 = (i + 1) * kPaddedInputDimensions;
418           const IndexType offset2 = (i + 2) * kPaddedInputDimensions;
419           const IndexType offset3 = (i + 3) * kPaddedInputDimensions;
420
421           const __m128i bias = *reinterpret_cast<const __m128i*>(&biases_[i]);
422           __m128i* outptr = reinterpret_cast<__m128i*>(&output[i]);
423
424           if constexpr (kPaddedInputDimensions % (kSimdWidth * 2) == 0)
425           {
426             __m512i sum0 = _mm512_setzero_si512();
427             __m512i sum1 = _mm512_setzero_si512();
428             __m512i sum2 = _mm512_setzero_si512();
429             __m512i sum3 = _mm512_setzero_si512();
430
431             const auto row0 = reinterpret_cast<const __m512i*>(&weights_[offset0]);
432             const auto row1 = reinterpret_cast<const __m512i*>(&weights_[offset1]);
433             const auto row2 = reinterpret_cast<const __m512i*>(&weights_[offset2]);
434             const auto row3 = reinterpret_cast<const __m512i*>(&weights_[offset3]);
435
436             int j = 0;
437             if (!canSaturate16x4[i / 4])
438             {
439                 for (; j < (int)kNumChunks512 - 1; j += 2)
440                 {
441                     const __m512i in0 = input_vector512[j];
442                     const __m512i in1 = input_vector512[j + 1];
443
444                     m512_add_dpbusd_epi32x2(sum0, in0, row0[j], in1, row0[j + 1]);
445                     m512_add_dpbusd_epi32x2(sum1, in0, row1[j], in1, row1[j + 1]);
446                     m512_add_dpbusd_epi32x2(sum2, in0, row2[j], in1, row2[j + 1]);
447                     m512_add_dpbusd_epi32x2(sum3, in0, row3[j], in1, row3[j + 1]);
448                 }
449             }
450             for (; j < (int)kNumChunks512; ++j)
451             {
452               const __m512i in = input_vector512[j];
453
454               m512_add_dpbusd_epi32(sum0, in, row0[j]);
455               m512_add_dpbusd_epi32(sum1, in, row1[j]);
456               m512_add_dpbusd_epi32(sum2, in, row2[j]);
457               m512_add_dpbusd_epi32(sum3, in, row3[j]);
458             }
459
460             *outptr = m512_haddx4(sum0, sum1, sum2, sum3, bias);
461           }
462           else
463           {
464             __m256i sum0 = _mm256_setzero_si256();
465             __m256i sum1 = _mm256_setzero_si256();
466             __m256i sum2 = _mm256_setzero_si256();
467             __m256i sum3 = _mm256_setzero_si256();
468
469             const auto row0 = reinterpret_cast<const __m256i*>(&weights_[offset0]);
470             const auto row1 = reinterpret_cast<const __m256i*>(&weights_[offset1]);
471             const auto row2 = reinterpret_cast<const __m256i*>(&weights_[offset2]);
472             const auto row3 = reinterpret_cast<const __m256i*>(&weights_[offset3]);
473
474             for (IndexType j = 0; j < kNumChunks256; ++j)
475             {
476               const __m256i in = input_vector256[j];
477
478               m256_add_dpbusd_epi32(sum0, in, row0[j]);
479               m256_add_dpbusd_epi32(sum1, in, row1[j]);
480               m256_add_dpbusd_epi32(sum2, in, row2[j]);
481               m256_add_dpbusd_epi32(sum3, in, row3[j]);
482             }
483
484             *outptr = m256_haddx4(sum0, sum1, sum2, sum3, bias);
485           }
486         }
487       }
488       else if constexpr (kOutputDimensions == 1)
489       {
490         if constexpr (kPaddedInputDimensions % (kSimdWidth * 2) == 0)
491         {
492           __m512i sum0 = _mm512_setzero_si512();
493
494           const auto row0 = reinterpret_cast<const __m512i*>(&weights_[0]);
495
496           for (IndexType j = 0; j < kNumChunks512; ++j)
497           {
498             const __m512i in = input_vector512[j];
499
500             m512_add_dpbusd_epi32(sum0, in, row0[j]);
501           }
502
503           output[0] = m512_hadd(sum0, biases_[0]);
504         }
505         else
506         {
507           __m256i sum0 = _mm256_setzero_si256();
508
509           const auto row0 = reinterpret_cast<const __m256i*>(&weights_[0]);
510
511           for (IndexType j = 0; j < kNumChunks256; ++j)
512           {
513             const __m256i in = input_vector256[j];
514
515             m256_add_dpbusd_epi32(sum0, in, row0[j]);
516           }
517
518           output[0] = m256_hadd(sum0, biases_[0]);
519         }
520       }
521       else
522       {
523         // This case can never happen because kOutputDimensions
524         // is always 1 or a multiple of kSimdWidth.
525         assert(false);
526       }
527
528 #elif defined (USE_AVX2)
529
530       constexpr IndexType kNumChunks = kPaddedInputDimensions / kSimdWidth;
531
532       const auto output = reinterpret_cast<OutputType*>(buffer);
533       const auto input_vector = reinterpret_cast<const __m256i*>(input);
534
535       // kOutputDimensions is either 1 or a multiple of kSimdWidth
536       // because then it is also an input dimension.
537       if constexpr (kOutputDimensions % 4 == 0)
538       {
539         for (IndexType i = 0; i < kOutputDimensions; i += 4)
540         {
541           const IndexType offset0 = (i + 0) * kPaddedInputDimensions;
542           const IndexType offset1 = (i + 1) * kPaddedInputDimensions;
543           const IndexType offset2 = (i + 2) * kPaddedInputDimensions;
544           const IndexType offset3 = (i + 3) * kPaddedInputDimensions;
545
546           const __m128i bias = *reinterpret_cast<const __m128i*>(&biases_[i]);
547           __m128i* outptr = reinterpret_cast<__m128i*>(&output[i]);
548
549           __m256i sum0 = _mm256_setzero_si256();
550           __m256i sum1 = _mm256_setzero_si256();
551           __m256i sum2 = _mm256_setzero_si256();
552           __m256i sum3 = _mm256_setzero_si256();
553
554           const auto row0 = reinterpret_cast<const __m256i*>(&weights_[offset0]);
555           const auto row1 = reinterpret_cast<const __m256i*>(&weights_[offset1]);
556           const auto row2 = reinterpret_cast<const __m256i*>(&weights_[offset2]);
557           const auto row3 = reinterpret_cast<const __m256i*>(&weights_[offset3]);
558
559           int j = 0;
560           if (!canSaturate16x4[i / 4])
561           {
562               for (; j < (int)kNumChunks - 1; j += 2)
563               {
564                   const __m256i in0 = input_vector[j];
565                   const __m256i in1 = input_vector[j + 1];
566
567                   m256_add_dpbusd_epi32x2(sum0, in0, row0[j], in1, row0[j + 1]);
568                   m256_add_dpbusd_epi32x2(sum1, in0, row1[j], in1, row1[j + 1]);
569                   m256_add_dpbusd_epi32x2(sum2, in0, row2[j], in1, row2[j + 1]);
570                   m256_add_dpbusd_epi32x2(sum3, in0, row3[j], in1, row3[j + 1]);
571               }
572           }
573           for (; j < (int)kNumChunks; ++j)
574           {
575                 const __m256i in = input_vector[j];
576
577                 m256_add_dpbusd_epi32(sum0, in, row0[j]);
578                 m256_add_dpbusd_epi32(sum1, in, row1[j]);
579                 m256_add_dpbusd_epi32(sum2, in, row2[j]);
580                 m256_add_dpbusd_epi32(sum3, in, row3[j]);
581           }
582
583           *outptr = m256_haddx4(sum0, sum1, sum2, sum3, bias);
584         }
585       }
586       else if constexpr (kOutputDimensions == 1)
587       {
588         __m256i sum0 = _mm256_setzero_si256();
589
590         const auto row0 = reinterpret_cast<const __m256i*>(&weights_[0]);
591
592         for (IndexType j = 0; j < kNumChunks; ++j)
593         {
594             const __m256i in = input_vector[j];
595
596             m256_add_dpbusd_epi32(sum0, in, row0[j]);
597         }
598
599         output[0] = m256_hadd(sum0, biases_[0]);
600       }
601       else
602       {
603         // This case can never happen because kOutputDimensions
604         // is always 1 or a multiple of kSimdWidth.
605         assert(false);
606       }
607
608 #elif defined (USE_SSSE3)
609
610       constexpr IndexType kNumChunks = kPaddedInputDimensions / kSimdWidth;
611
612       auto output = reinterpret_cast<OutputType*>(buffer);
613       const auto input_vector = reinterpret_cast<const __m128i*>(input);
614
615       // kOutputDimensions is either 1 or a multiple of kSimdWidth
616       // because then it is also an input dimension.
617       if constexpr (kOutputDimensions % 4 == 0)
618       {
619         for (IndexType i = 0; i < kOutputDimensions; i += 4)
620         {
621           const IndexType offset0 = (i + 0) * kPaddedInputDimensions;
622           const IndexType offset1 = (i + 1) * kPaddedInputDimensions;
623           const IndexType offset2 = (i + 2) * kPaddedInputDimensions;
624           const IndexType offset3 = (i + 3) * kPaddedInputDimensions;
625
626           const __m128i bias = *reinterpret_cast<const __m128i*>(&biases_[i]);
627           __m128i* outptr = reinterpret_cast<__m128i*>(&output[i]);
628
629           __m128i sum0 = _mm_setzero_si128();
630           __m128i sum1 = _mm_setzero_si128();
631           __m128i sum2 = _mm_setzero_si128();
632           __m128i sum3 = _mm_setzero_si128();
633
634           const auto row0 = reinterpret_cast<const __m128i*>(&weights_[offset0]);
635           const auto row1 = reinterpret_cast<const __m128i*>(&weights_[offset1]);
636           const auto row2 = reinterpret_cast<const __m128i*>(&weights_[offset2]);
637           const auto row3 = reinterpret_cast<const __m128i*>(&weights_[offset3]);
638
639           int j = 0;
640           if (!canSaturate16x4[i / 4])
641           {
642               for (; j < (int)kNumChunks - 1; j += 2)
643               {
644                   const __m128i in0 = input_vector[j];
645                   const __m128i in1 = input_vector[j + 1];
646
647                   m128_add_dpbusd_epi32x2(sum0, in0, row0[j], in1, row0[j + 1]);
648                   m128_add_dpbusd_epi32x2(sum1, in0, row1[j], in1, row1[j + 1]);
649                   m128_add_dpbusd_epi32x2(sum2, in0, row2[j], in1, row2[j + 1]);
650                   m128_add_dpbusd_epi32x2(sum3, in0, row3[j], in1, row3[j + 1]);
651               }
652           }
653           for (; j < (int)kNumChunks; ++j)
654           {
655               const __m128i in = input_vector[j];
656
657               m128_add_dpbusd_epi32(sum0, in, row0[j]);
658               m128_add_dpbusd_epi32(sum1, in, row1[j]);
659               m128_add_dpbusd_epi32(sum2, in, row2[j]);
660               m128_add_dpbusd_epi32(sum3, in, row3[j]);
661           }
662
663           *outptr = m128_haddx4(sum0, sum1, sum2, sum3, bias);
664         }
665       }
666       else if constexpr (kOutputDimensions == 1)
667       {
668         __m128i sum0 = _mm_setzero_si128();
669
670         const auto row0 = reinterpret_cast<const __m128i*>(&weights_[0]);
671
672         for (int j = 0; j < (int)kNumChunks; ++j)
673         {
674           const __m128i in = input_vector[j];
675
676           m128_add_dpbusd_epi32(sum0, in, row0[j]);
677         }
678
679         output[0] = m128_hadd(sum0, biases_[0]);
680       }
681       else
682       {
683         // This case can never happen because kOutputDimensions
684         // is always 1 or a multiple of kSimdWidth.
685         assert(false);
686       }
687
688 #else
689
690 // Use old implementation for the other architectures.
691
692       auto output = reinterpret_cast<OutputType*>(buffer);
693
694 #if defined(USE_SSE2)
695       constexpr IndexType kNumChunks = kPaddedInputDimensions / kSimdWidth;
696 #ifndef USE_SSSE3
697       const __m128i kZeros = _mm_setzero_si128();
698 #else
699       const __m128i kOnes = _mm_set1_epi16(1);
700 #endif
701       const auto input_vector = reinterpret_cast<const __m128i*>(input);
702
703 #elif defined(USE_MMX)
704       constexpr IndexType kNumChunks = kPaddedInputDimensions / kSimdWidth;
705       const __m64 kZeros = _mm_setzero_si64();
706       const auto input_vector = reinterpret_cast<const __m64*>(input);
707
708 #elif defined(USE_NEON)
709       constexpr IndexType kNumChunks = kPaddedInputDimensions / kSimdWidth;
710       const auto input_vector = reinterpret_cast<const int8x8_t*>(input);
711 #endif
712
713       for (IndexType i = 0; i < kOutputDimensions; ++i) {
714         const IndexType offset = i * kPaddedInputDimensions;
715
716 #if defined(USE_SSE2)
717         __m128i sum_lo = _mm_cvtsi32_si128(biases_[i]);
718         __m128i sum_hi = kZeros;
719         const auto row = reinterpret_cast<const __m128i*>(&weights_[offset]);
720         for (IndexType j = 0; j < kNumChunks; ++j) {
721           __m128i row_j = _mm_load_si128(&row[j]);
722           __m128i input_j = _mm_load_si128(&input_vector[j]);
723           __m128i extended_row_lo = _mm_srai_epi16(_mm_unpacklo_epi8(row_j, row_j), 8);
724           __m128i extended_row_hi = _mm_srai_epi16(_mm_unpackhi_epi8(row_j, row_j), 8);
725           __m128i extended_input_lo = _mm_unpacklo_epi8(input_j, kZeros);
726           __m128i extended_input_hi = _mm_unpackhi_epi8(input_j, kZeros);
727           __m128i product_lo = _mm_madd_epi16(extended_row_lo, extended_input_lo);
728           __m128i product_hi = _mm_madd_epi16(extended_row_hi, extended_input_hi);
729           sum_lo = _mm_add_epi32(sum_lo, product_lo);
730           sum_hi = _mm_add_epi32(sum_hi, product_hi);
731         }
732         __m128i sum = _mm_add_epi32(sum_lo, sum_hi);
733         __m128i sum_high_64 = _mm_shuffle_epi32(sum, _MM_SHUFFLE(1, 0, 3, 2));
734         sum = _mm_add_epi32(sum, sum_high_64);
735         __m128i sum_second_32 = _mm_shufflelo_epi16(sum, _MM_SHUFFLE(1, 0, 3, 2));
736         sum = _mm_add_epi32(sum, sum_second_32);
737         output[i] = _mm_cvtsi128_si32(sum);
738
739 #elif defined(USE_MMX)
740         __m64 sum_lo = _mm_cvtsi32_si64(biases_[i]);
741         __m64 sum_hi = kZeros;
742         const auto row = reinterpret_cast<const __m64*>(&weights_[offset]);
743         for (IndexType j = 0; j < kNumChunks; ++j) {
744           __m64 row_j = row[j];
745           __m64 input_j = input_vector[j];
746           __m64 extended_row_lo = _mm_srai_pi16(_mm_unpacklo_pi8(row_j, row_j), 8);
747           __m64 extended_row_hi = _mm_srai_pi16(_mm_unpackhi_pi8(row_j, row_j), 8);
748           __m64 extended_input_lo = _mm_unpacklo_pi8(input_j, kZeros);
749           __m64 extended_input_hi = _mm_unpackhi_pi8(input_j, kZeros);
750           __m64 product_lo = _mm_madd_pi16(extended_row_lo, extended_input_lo);
751           __m64 product_hi = _mm_madd_pi16(extended_row_hi, extended_input_hi);
752           sum_lo = _mm_add_pi32(sum_lo, product_lo);
753           sum_hi = _mm_add_pi32(sum_hi, product_hi);
754         }
755         __m64 sum = _mm_add_pi32(sum_lo, sum_hi);
756         sum = _mm_add_pi32(sum, _mm_unpackhi_pi32(sum, sum));
757         output[i] = _mm_cvtsi64_si32(sum);
758
759 #elif defined(USE_NEON)
760         int32x4_t sum = {biases_[i]};
761         const auto row = reinterpret_cast<const int8x8_t*>(&weights_[offset]);
762         for (IndexType j = 0; j < kNumChunks; ++j) {
763           int16x8_t product = vmull_s8(input_vector[j * 2], row[j * 2]);
764           product = vmlal_s8(product, input_vector[j * 2 + 1], row[j * 2 + 1]);
765           sum = vpadalq_s16(sum, product);
766         }
767         output[i] = sum[0] + sum[1] + sum[2] + sum[3];
768
769 #else
770         OutputType sum = biases_[i];
771         for (IndexType j = 0; j < kInputDimensions; ++j) {
772           sum += weights_[offset + j] * input[j];
773         }
774         output[i] = sum;
775 #endif
776
777       }
778 #if defined(USE_MMX)
779       _mm_empty();
780 #endif
781
782 #endif
783
784       return output;
785     }
786
787    private:
788     using BiasType = OutputType;
789     using WeightType = std::int8_t;
790
791     PreviousLayer previous_layer_;
792
793     alignas(kCacheLineSize) BiasType biases_[kOutputDimensions];
794     alignas(kCacheLineSize) WeightType weights_[kOutputDimensions * kPaddedInputDimensions];
795     union {
796         uint32_t canSaturate16x4[(kOutputDimensions + 3) / 4];
797         bool canSaturate16[kOutputDimensions];
798     };
799   };
800
801 }  // namespace Eval::NNUE::Layers
802
803 #endif // #ifndef NNUE_LAYERS_AFFINE_TRANSFORM_H_INCLUDED