]> git.sesse.net Git - stockfish/blob - src/nnue/layers/affine_transform.h
Add support for ARM dot product instructions
[stockfish] / src / nnue / layers / affine_transform.h
1 /*
2   Stockfish, a UCI chess playing engine derived from Glaurung 2.1
3   Copyright (C) 2004-2023 The Stockfish developers (see AUTHORS file)
4
5   Stockfish is free software: you can redistribute it and/or modify
6   it under the terms of the GNU General Public License as published by
7   the Free Software Foundation, either version 3 of the License, or
8   (at your option) any later version.
9
10   Stockfish is distributed in the hope that it will be useful,
11   but WITHOUT ANY WARRANTY; without even the implied warranty of
12   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13   GNU General Public License for more details.
14
15   You should have received a copy of the GNU General Public License
16   along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 */
18
19 // Definition of layer AffineTransform of NNUE evaluation function
20
21 #ifndef NNUE_LAYERS_AFFINE_TRANSFORM_H_INCLUDED
22 #define NNUE_LAYERS_AFFINE_TRANSFORM_H_INCLUDED
23
24 #include <iostream>
25 #include <algorithm>
26 #include <type_traits>
27 #include "../nnue_common.h"
28 #include "simd.h"
29
30 /*
31   This file contains the definition for a fully connected layer (aka affine transform).
32   Two approaches are employed, depending on the sizes of the transform.
33
34   Approach 1:
35     - used when the PaddedInputDimensions >= 128
36     - uses AVX512 if possible
37     - processes inputs in batches of 2*InputSimdWidth
38       - so in batches of 128 for AVX512
39     - the weight blocks of size InputSimdWidth are transposed such that
40       access is sequential
41     - N columns of the weight matrix are processed a time, where N
42       depends on the architecture (the amount of registers)
43     - accumulate + hadd is used
44
45   Approach 2:
46     - used when the PaddedInputDimensions < 128
47     - does not use AVX512
48     - expected use-case is for when PaddedInputDimensions == 32 and InputDimensions <= 32.
49       - that's why AVX512 is hard to implement
50     - expected use-case is small layers
51       - not optimized as well as the approach 1
52     - inputs are processed in chunks of 4, weights are respectively transposed
53     - accumulation happens directly to int32s
54 */
55
56 namespace Stockfish::Eval::NNUE::Layers {
57
58 // Fallback implementation for older/other architectures.
59 // Identical for both approaches. Requires the input to be padded to at least 16 values.
60 #if !defined(USE_SSSE3)
61   template <IndexType InputDimensions, IndexType PaddedInputDimensions, IndexType OutputDimensions>
62   static void affine_transform_non_ssse3(std::int32_t* output, const std::int8_t* weights, const std::int32_t* biases, const std::uint8_t* input)
63   {
64 # if defined(USE_SSE2)
65     // At least a multiple of 16, with SSE2.
66     constexpr IndexType NumChunks = ceil_to_multiple<IndexType>(InputDimensions, 16) / 16;
67     const __m128i Zeros = _mm_setzero_si128();
68     const auto inputVector = reinterpret_cast<const __m128i*>(input);
69
70 # elif defined(USE_MMX)
71     constexpr IndexType NumChunks = ceil_to_multiple<IndexType>(InputDimensions, 8) / 8;
72     const __m64 Zeros = _mm_setzero_si64();
73     const auto inputVector = reinterpret_cast<const __m64*>(input);
74
75 # elif defined(USE_NEON_DOTPROD)
76     constexpr IndexType NumChunks = ceil_to_multiple<IndexType>(InputDimensions, 16) / 16;
77     const auto inputVector = reinterpret_cast<const int8x16_t*>(input);
78
79 # elif defined(USE_NEON)
80     constexpr IndexType NumChunks = ceil_to_multiple<IndexType>(InputDimensions, 16) / 16;
81     const auto inputVector = reinterpret_cast<const int8x8_t*>(input);
82 # endif
83
84     for (IndexType i = 0; i < OutputDimensions; ++i) {
85       const IndexType offset = i * PaddedInputDimensions;
86
87 # if defined(USE_SSE2)
88       __m128i sumLo = _mm_cvtsi32_si128(biases[i]);
89       __m128i sumHi = Zeros;
90       const auto row = reinterpret_cast<const __m128i*>(&weights[offset]);
91       for (IndexType j = 0; j < NumChunks; ++j) {
92         __m128i row_j = _mm_load_si128(&row[j]);
93         __m128i input_j = _mm_load_si128(&inputVector[j]);
94         __m128i extendedRowLo = _mm_srai_epi16(_mm_unpacklo_epi8(row_j, row_j), 8);
95         __m128i extendedRowHi = _mm_srai_epi16(_mm_unpackhi_epi8(row_j, row_j), 8);
96         __m128i extendedInputLo = _mm_unpacklo_epi8(input_j, Zeros);
97         __m128i extendedInputHi = _mm_unpackhi_epi8(input_j, Zeros);
98         __m128i productLo = _mm_madd_epi16(extendedRowLo, extendedInputLo);
99         __m128i productHi = _mm_madd_epi16(extendedRowHi, extendedInputHi);
100         sumLo = _mm_add_epi32(sumLo, productLo);
101         sumHi = _mm_add_epi32(sumHi, productHi);
102       }
103       __m128i sum = _mm_add_epi32(sumLo, sumHi);
104       __m128i sumHigh_64 = _mm_shuffle_epi32(sum, _MM_SHUFFLE(1, 0, 3, 2));
105       sum = _mm_add_epi32(sum, sumHigh_64);
106       __m128i sum_second_32 = _mm_shufflelo_epi16(sum, _MM_SHUFFLE(1, 0, 3, 2));
107       sum = _mm_add_epi32(sum, sum_second_32);
108       output[i] = _mm_cvtsi128_si32(sum);
109
110 # elif defined(USE_MMX)
111       __m64 sumLo = _mm_cvtsi32_si64(biases[i]);
112       __m64 sumHi = Zeros;
113       const auto row = reinterpret_cast<const __m64*>(&weights[offset]);
114       for (IndexType j = 0; j < NumChunks; ++j) {
115         __m64 row_j = row[j];
116         __m64 input_j = inputVector[j];
117         __m64 extendedRowLo = _mm_srai_pi16(_mm_unpacklo_pi8(row_j, row_j), 8);
118         __m64 extendedRowHi = _mm_srai_pi16(_mm_unpackhi_pi8(row_j, row_j), 8);
119         __m64 extendedInputLo = _mm_unpacklo_pi8(input_j, Zeros);
120         __m64 extendedInputHi = _mm_unpackhi_pi8(input_j, Zeros);
121         __m64 productLo = _mm_madd_pi16(extendedRowLo, extendedInputLo);
122         __m64 productHi = _mm_madd_pi16(extendedRowHi, extendedInputHi);
123         sumLo = _mm_add_pi32(sumLo, productLo);
124         sumHi = _mm_add_pi32(sumHi, productHi);
125       }
126       __m64 sum = _mm_add_pi32(sumLo, sumHi);
127       sum = _mm_add_pi32(sum, _mm_unpackhi_pi32(sum, sum));
128       output[i] = _mm_cvtsi64_si32(sum);
129
130 # elif defined(USE_NEON_DOTPROD)
131       int32x4_t sum = {biases[i]};
132       const auto row = reinterpret_cast<const int8x16_t*>(&weights[offset]);
133       for (IndexType j = 0; j < NumChunks; ++j) {
134         sum = vdotq_s32(sum, inputVector[j], row[j]);
135       }
136       output[i] = vaddvq_s32(sum);
137
138 # elif defined(USE_NEON)
139       int32x4_t sum = {biases[i]};
140       const auto row = reinterpret_cast<const int8x8_t*>(&weights[offset]);
141       for (IndexType j = 0; j < NumChunks; ++j) {
142         int16x8_t product = vmull_s8(inputVector[j * 2], row[j * 2]);
143         product = vmlal_s8(product, inputVector[j * 2 + 1], row[j * 2 + 1]);
144         sum = vpadalq_s16(sum, product);
145       }
146       output[i] = sum[0] + sum[1] + sum[2] + sum[3];
147
148 # else
149       std::int32_t sum = biases[i];
150       for (IndexType j = 0; j < InputDimensions; ++j) {
151         sum += weights[offset + j] * input[j];
152       }
153       output[i] = sum;
154 # endif
155     }
156
157 # if defined(USE_MMX)
158     _mm_empty();
159 # endif
160   }
161 #endif
162
163   template <IndexType InDims, IndexType OutDims, typename Enabled = void>
164   class AffineTransform;
165
166 #if defined (USE_AVX512)
167   constexpr IndexType LargeInputSize = 2 * 64;
168 #else
169   constexpr IndexType LargeInputSize = std::numeric_limits<IndexType>::max();
170 #endif
171
172   // A specialization for large inputs.
173   template <IndexType InDims, IndexType OutDims>
174   class AffineTransform<InDims, OutDims, std::enable_if_t<(ceil_to_multiple<IndexType>(InDims, MaxSimdWidth) >= LargeInputSize)>> {
175    public:
176     // Input/output type
177     using InputType = std::uint8_t;
178     using OutputType = std::int32_t;
179
180     // Number of input/output dimensions
181     static constexpr IndexType InputDimensions = InDims;
182     static constexpr IndexType OutputDimensions = OutDims;
183
184     static constexpr IndexType PaddedInputDimensions =
185       ceil_to_multiple<IndexType>(InputDimensions, MaxSimdWidth);
186     static constexpr IndexType PaddedOutputDimensions =
187       ceil_to_multiple<IndexType>(OutputDimensions, MaxSimdWidth);
188
189     using OutputBuffer = OutputType[PaddedOutputDimensions];
190
191     static_assert(PaddedInputDimensions >= LargeInputSize, "Something went wrong. This specialization should not have been chosen.");
192
193 #if defined (USE_AVX512)
194     static constexpr IndexType InputSimdWidth = 64;
195     static constexpr IndexType MaxNumOutputRegs = 16;
196 #elif defined (USE_AVX2)
197     static constexpr IndexType InputSimdWidth = 32;
198     static constexpr IndexType MaxNumOutputRegs = 8;
199 #elif defined (USE_SSSE3)
200     static constexpr IndexType InputSimdWidth = 16;
201     static constexpr IndexType MaxNumOutputRegs = 8;
202 #elif defined (USE_NEON_DOTPROD)
203     static constexpr IndexType InputSimdWidth = 16;
204     static constexpr IndexType MaxNumOutputRegs = 8;
205 #elif defined (USE_NEON)
206     static constexpr IndexType InputSimdWidth = 8;
207     static constexpr IndexType MaxNumOutputRegs = 8;
208 #else
209     // The fallback implementation will not have permuted weights.
210     // We define these to avoid a lot of ifdefs later.
211     static constexpr IndexType InputSimdWidth = 1;
212     static constexpr IndexType MaxNumOutputRegs = 1;
213 #endif
214
215     // A big block is a region in the weight matrix of the size [PaddedInputDimensions, NumOutputRegs].
216     // A small block is a region of size [InputSimdWidth, 1]
217
218     static constexpr IndexType NumOutputRegs = std::min(MaxNumOutputRegs, OutputDimensions);
219     static constexpr IndexType SmallBlockSize = InputSimdWidth;
220     static constexpr IndexType BigBlockSize = NumOutputRegs * PaddedInputDimensions;
221     static constexpr IndexType NumSmallBlocksInBigBlock = BigBlockSize / SmallBlockSize;
222     static constexpr IndexType NumSmallBlocksPerOutput = PaddedInputDimensions / SmallBlockSize;
223     static constexpr IndexType NumBigBlocks = OutputDimensions / NumOutputRegs;
224
225     static_assert(OutputDimensions % NumOutputRegs == 0);
226
227     // Hash value embedded in the evaluation file
228     static constexpr std::uint32_t get_hash_value(std::uint32_t prevHash) {
229       std::uint32_t hashValue = 0xCC03DAE4u;
230       hashValue += OutputDimensions;
231       hashValue ^= prevHash >> 1;
232       hashValue ^= prevHash << 31;
233       return hashValue;
234     }
235
236     /*
237       Transposes the small blocks within a block.
238       Effectively means that weights can be traversed sequentially during inference.
239     */
240     static IndexType get_weight_index(IndexType i)
241     {
242       const IndexType smallBlock = (i / SmallBlockSize) % NumSmallBlocksInBigBlock;
243       const IndexType smallBlockCol = smallBlock / NumSmallBlocksPerOutput;
244       const IndexType smallBlockRow = smallBlock % NumSmallBlocksPerOutput;
245       const IndexType bigBlock   = i / BigBlockSize;
246       const IndexType rest       = i % SmallBlockSize;
247
248       const IndexType idx =
249           bigBlock * BigBlockSize
250         + smallBlockRow * SmallBlockSize * NumOutputRegs
251         + smallBlockCol * SmallBlockSize
252         + rest;
253
254       return idx;
255     }
256
257     // Read network parameters
258     bool read_parameters(std::istream& stream) {
259       for (IndexType i = 0; i < OutputDimensions; ++i)
260         biases[i] = read_little_endian<BiasType>(stream);
261
262       for (IndexType i = 0; i < OutputDimensions * PaddedInputDimensions; ++i)
263         weights[get_weight_index(i)] = read_little_endian<WeightType>(stream);
264
265       return !stream.fail();
266     }
267
268     // Write network parameters
269     bool write_parameters(std::ostream& stream) const {
270       for (IndexType i = 0; i < OutputDimensions; ++i)
271           write_little_endian<BiasType>(stream, biases[i]);
272
273       for (IndexType i = 0; i < OutputDimensions * PaddedInputDimensions; ++i)
274         write_little_endian<WeightType>(stream, weights[get_weight_index(i)]);
275
276       return !stream.fail();
277     }
278
279     // Forward propagation
280     const OutputType* propagate(
281         const InputType* input, OutputType* output) const {
282
283 #if defined (USE_AVX512)
284       using acc_vec_t = __m512i;
285       using bias_vec_t = __m128i;
286       using weight_vec_t = __m512i;
287       using in_vec_t = __m512i;
288       #define vec_zero _mm512_setzero_si512()
289       #define vec_add_dpbusd_32x2 Simd::m512_add_dpbusd_epi32x2
290       #define vec_hadd Simd::m512_hadd
291       #define vec_haddx4 Simd::m512_haddx4
292 #elif defined (USE_AVX2)
293       using acc_vec_t = __m256i;
294       using bias_vec_t = __m128i;
295       using weight_vec_t = __m256i;
296       using in_vec_t = __m256i;
297       #define vec_zero _mm256_setzero_si256()
298       #define vec_add_dpbusd_32x2 Simd::m256_add_dpbusd_epi32x2
299       #define vec_hadd Simd::m256_hadd
300       #define vec_haddx4 Simd::m256_haddx4
301 #elif defined (USE_SSSE3)
302       using acc_vec_t = __m128i;
303       using bias_vec_t = __m128i;
304       using weight_vec_t = __m128i;
305       using in_vec_t = __m128i;
306       #define vec_zero _mm_setzero_si128()
307       #define vec_add_dpbusd_32x2 Simd::m128_add_dpbusd_epi32x2
308       #define vec_hadd Simd::m128_hadd
309       #define vec_haddx4 Simd::m128_haddx4
310 #elif defined (USE_NEON_DOTPROD)
311       using acc_vec_t = int32x4_t;
312       using bias_vec_t = int32x4_t;
313       using weight_vec_t = int8x16_t;
314       using in_vec_t = int8x16_t;
315       #define vec_zero {0}
316       #define vec_add_dpbusd_32x2 Simd::dotprod_m128_add_dpbusd_epi32x2
317       #define vec_hadd Simd::neon_m128_hadd
318       #define vec_haddx4 Simd::neon_m128_haddx4
319 #elif defined (USE_NEON)
320       using acc_vec_t = int32x4_t;
321       using bias_vec_t = int32x4_t;
322       using weight_vec_t = int8x8_t;
323       using in_vec_t = int8x8_t;
324       #define vec_zero {0}
325       #define vec_add_dpbusd_32x2 Simd::neon_m128_add_dpbusd_epi32x2
326       #define vec_hadd Simd::neon_m128_hadd
327       #define vec_haddx4 Simd::neon_m128_haddx4
328 #endif
329
330 #if defined (USE_SSSE3) || defined (USE_NEON)
331       const in_vec_t* invec = reinterpret_cast<const in_vec_t*>(input);
332
333       // Perform accumulation to registers for each big block
334       for (IndexType bigBlock = 0; bigBlock < NumBigBlocks; ++bigBlock)
335       {
336         acc_vec_t acc[NumOutputRegs] = { vec_zero };
337
338         // Each big block has NumOutputRegs small blocks in each "row", one per register.
339         // We process two small blocks at a time to save on one addition without VNNI.
340         for (IndexType smallBlock = 0; smallBlock < NumSmallBlocksPerOutput; smallBlock += 2)
341         {
342           const weight_vec_t* weightvec =
343             reinterpret_cast<const weight_vec_t*>(
344                 weights
345               + bigBlock * BigBlockSize
346               + smallBlock * SmallBlockSize * NumOutputRegs);
347
348           const in_vec_t in0 = invec[smallBlock + 0];
349           const in_vec_t in1 = invec[smallBlock + 1];
350
351           for (IndexType k = 0; k < NumOutputRegs; ++k)
352             vec_add_dpbusd_32x2(acc[k], in0, weightvec[k], in1, weightvec[k + NumOutputRegs]);
353         }
354
355         // Horizontally add all accumulators.
356         if constexpr (NumOutputRegs % 4 == 0)
357         {
358           bias_vec_t* outputvec = reinterpret_cast<bias_vec_t*>(output);
359           const bias_vec_t* biasvec = reinterpret_cast<const bias_vec_t*>(biases);
360
361           for (IndexType k = 0; k < NumOutputRegs; k += 4)
362           {
363             const IndexType idx = (bigBlock * NumOutputRegs + k) / 4;
364             outputvec[idx] = vec_haddx4(acc[k+0], acc[k+1], acc[k+2], acc[k+3], biasvec[idx]);
365           }
366         }
367         else
368         {
369           for (IndexType k = 0; k < NumOutputRegs; ++k)
370           {
371             const IndexType idx = (bigBlock * NumOutputRegs + k);
372             output[idx] = vec_hadd(acc[k], biases[idx]);
373           }
374         }
375       }
376
377 # undef vec_zero
378 # undef vec_add_dpbusd_32x2
379 # undef vec_hadd
380 # undef vec_haddx4
381 #else
382       // Use old implementation for the other architectures.
383       affine_transform_non_ssse3<
384         InputDimensions,
385         PaddedInputDimensions,
386         OutputDimensions>(output, weights, biases, input);
387
388 #endif
389
390       return output;
391     }
392
393    private:
394     using BiasType = OutputType;
395     using WeightType = std::int8_t;
396
397     alignas(CacheLineSize) BiasType biases[OutputDimensions];
398     alignas(CacheLineSize) WeightType weights[OutputDimensions * PaddedInputDimensions];
399   };
400
401   template <IndexType InDims, IndexType OutDims>
402   class AffineTransform<InDims, OutDims, std::enable_if_t<(ceil_to_multiple<IndexType>(InDims, MaxSimdWidth) < LargeInputSize)>> {
403    public:
404     // Input/output type
405     // Input/output type
406     using InputType = std::uint8_t;
407     using OutputType = std::int32_t;
408
409     // Number of input/output dimensions
410     static constexpr IndexType InputDimensions = InDims;
411     static constexpr IndexType OutputDimensions = OutDims;
412
413     static constexpr IndexType PaddedInputDimensions =
414       ceil_to_multiple<IndexType>(InputDimensions, MaxSimdWidth);
415     static constexpr IndexType PaddedOutputDimensions =
416       ceil_to_multiple<IndexType>(OutputDimensions, MaxSimdWidth);
417
418     using OutputBuffer = OutputType[PaddedOutputDimensions];
419
420     static_assert(PaddedInputDimensions < LargeInputSize, "Something went wrong. This specialization should not have been chosen.");
421
422 #if defined (USE_SSSE3)
423     static constexpr IndexType OutputSimdWidth = SimdWidth / 4;
424     static constexpr IndexType InputSimdWidth = SimdWidth;
425 #endif
426
427     // Hash value embedded in the evaluation file
428     static constexpr std::uint32_t get_hash_value(std::uint32_t prevHash) {
429       std::uint32_t hashValue = 0xCC03DAE4u;
430       hashValue += OutputDimensions;
431       hashValue ^= prevHash >> 1;
432       hashValue ^= prevHash << 31;
433       return hashValue;
434     }
435
436     static IndexType get_weight_index_scrambled(IndexType i)
437     {
438       return
439         (i / 4) % (PaddedInputDimensions / 4) * OutputDimensions * 4 +
440         i / PaddedInputDimensions * 4 +
441         i % 4;
442     }
443
444     static IndexType get_weight_index(IndexType i)
445     {
446 #if defined (USE_SSSE3)
447       return get_weight_index_scrambled(i);
448 #else
449       return i;
450 #endif
451     }
452
453     // Read network parameters
454     bool read_parameters(std::istream& stream) {
455       for (IndexType i = 0; i < OutputDimensions; ++i)
456         biases[i] = read_little_endian<BiasType>(stream);
457       for (IndexType i = 0; i < OutputDimensions * PaddedInputDimensions; ++i)
458         weights[get_weight_index(i)] = read_little_endian<WeightType>(stream);
459
460       return !stream.fail();
461     }
462
463     // Write network parameters
464     bool write_parameters(std::ostream& stream) const {
465       for (IndexType i = 0; i < OutputDimensions; ++i)
466         write_little_endian<BiasType>(stream, biases[i]);
467
468       for (IndexType i = 0; i < OutputDimensions * PaddedInputDimensions; ++i)
469         write_little_endian<WeightType>(stream, weights[get_weight_index(i)]);
470
471       return !stream.fail();
472     }
473     // Forward propagation
474     const OutputType* propagate(
475         const InputType* input, OutputType* output) const {
476
477 #if defined (USE_AVX2)
478       using vec_t = __m256i;
479       #define vec_setzero _mm256_setzero_si256
480       #define vec_set_32 _mm256_set1_epi32
481       #define vec_add_dpbusd_32 Simd::m256_add_dpbusd_epi32
482       #define vec_add_dpbusd_32x2 Simd::m256_add_dpbusd_epi32x2
483       #define vec_add_dpbusd_32x4 Simd::m256_add_dpbusd_epi32x4
484       #define vec_hadd Simd::m256_hadd
485       #define vec_haddx4 Simd::m256_haddx4
486 #elif defined (USE_SSSE3)
487       using vec_t = __m128i;
488       #define vec_setzero _mm_setzero_si128
489       #define vec_set_32 _mm_set1_epi32
490       #define vec_add_dpbusd_32 Simd::m128_add_dpbusd_epi32
491       #define vec_add_dpbusd_32x2 Simd::m128_add_dpbusd_epi32x2
492       #define vec_add_dpbusd_32x4 Simd::m128_add_dpbusd_epi32x4
493       #define vec_hadd Simd::m128_hadd
494       #define vec_haddx4 Simd::m128_haddx4
495 #endif
496
497 #if defined (USE_SSSE3)
498       const auto inputVector = reinterpret_cast<const vec_t*>(input);
499
500       static_assert(OutputDimensions % OutputSimdWidth == 0 || OutputDimensions == 1);
501
502       if constexpr (OutputDimensions % OutputSimdWidth == 0)
503       {
504         constexpr IndexType NumChunks = ceil_to_multiple<IndexType>(InputDimensions, 8) / 4;
505         constexpr IndexType NumRegs = OutputDimensions / OutputSimdWidth;
506
507         const auto input32 = reinterpret_cast<const std::int32_t*>(input);
508         const vec_t* biasvec = reinterpret_cast<const vec_t*>(biases);
509         vec_t acc[NumRegs];
510         for (IndexType k = 0; k < NumRegs; ++k)
511           acc[k] = biasvec[k];
512
513         for (IndexType i = 0; i < NumChunks; i += 2)
514         {
515           const vec_t in0 = vec_set_32(input32[i + 0]);
516           const vec_t in1 = vec_set_32(input32[i + 1]);
517           const auto col0 = reinterpret_cast<const vec_t*>(&weights[(i + 0) * OutputDimensions * 4]);
518           const auto col1 = reinterpret_cast<const vec_t*>(&weights[(i + 1) * OutputDimensions * 4]);
519           for (IndexType k = 0; k < NumRegs; ++k)
520             vec_add_dpbusd_32x2(acc[k], in0, col0[k], in1, col1[k]);
521         }
522
523         vec_t* outptr = reinterpret_cast<vec_t*>(output);
524         for (IndexType k = 0; k < NumRegs; ++k)
525           outptr[k] = acc[k];
526       }
527       else if constexpr (OutputDimensions == 1)
528       {
529         constexpr IndexType NumChunks = PaddedInputDimensions / SimdWidth;
530         vec_t sum0 = vec_setzero();
531         const auto row0 = reinterpret_cast<const vec_t*>(&weights[0]);
532
533         for (int j = 0; j < (int)NumChunks; ++j)
534         {
535           const vec_t in = inputVector[j];
536           vec_add_dpbusd_32(sum0, in, row0[j]);
537         }
538         output[0] = vec_hadd(sum0, biases[0]);
539       }
540
541 # undef vec_setzero
542 # undef vec_set_32
543 # undef vec_add_dpbusd_32
544 # undef vec_add_dpbusd_32x2
545 # undef vec_add_dpbusd_32x4
546 # undef vec_hadd
547 # undef vec_haddx4
548 #else
549       // Use old implementation for the other architectures.
550       affine_transform_non_ssse3<
551         InputDimensions,
552         PaddedInputDimensions,
553         OutputDimensions>(output, weights, biases, input);
554 #endif
555
556       return output;
557     }
558
559    private:
560     using BiasType = OutputType;
561     using WeightType = std::int8_t;
562
563     alignas(CacheLineSize) BiasType biases[OutputDimensions];
564     alignas(CacheLineSize) WeightType weights[OutputDimensions * PaddedInputDimensions];
565   };
566
567 }  // namespace Stockfish::Eval::NNUE::Layers
568
569 #endif // #ifndef NNUE_LAYERS_AFFINE_TRANSFORM_H_INCLUDED