]> git.sesse.net Git - stockfish/blob - src/nnue/nnue_feature_transformer.h
Reduce SIMD register count from 32 to 16
[stockfish] / src / nnue / nnue_feature_transformer.h
1 /*
2   Stockfish, a UCI chess playing engine derived from Glaurung 2.1
3   Copyright (C) 2004-2023 The Stockfish developers (see AUTHORS file)
4
5   Stockfish is free software: you can redistribute it and/or modify
6   it under the terms of the GNU General Public License as published by
7   the Free Software Foundation, either version 3 of the License, or
8   (at your option) any later version.
9
10   Stockfish is distributed in the hope that it will be useful,
11   but WITHOUT ANY WARRANTY; without even the implied warranty of
12   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13   GNU General Public License for more details.
14
15   You should have received a copy of the GNU General Public License
16   along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 */
18
19 // A class that converts the input features of the NNUE evaluation function
20
21 #ifndef NNUE_FEATURE_TRANSFORMER_H_INCLUDED
22 #define NNUE_FEATURE_TRANSFORMER_H_INCLUDED
23
24 #include <algorithm>
25 #include <cassert>
26 #include <cstdint>
27 #include <cstring>
28 #include <iosfwd>
29 #include <utility>
30
31 #include "../position.h"
32 #include "../types.h"
33 #include "nnue_accumulator.h"
34 #include "nnue_architecture.h"
35 #include "nnue_common.h"
36
37 namespace Stockfish::Eval::NNUE {
38
39   using BiasType       = std::int16_t;
40   using WeightType     = std::int16_t;
41   using PSQTWeightType = std::int32_t;
42
43   // If vector instructions are enabled, we update and refresh the
44   // accumulator tile by tile such that each tile fits in the CPU's
45   // vector registers.
46   #define VECTOR
47
48   static_assert(PSQTBuckets % 8 == 0,
49     "Per feature PSQT values cannot be processed at granularity lower than 8 at a time.");
50
51   #ifdef USE_AVX512
52   using vec_t = __m512i;
53   using psqt_vec_t = __m256i;
54   #define vec_load(a) _mm512_load_si512(a)
55   #define vec_store(a,b) _mm512_store_si512(a,b)
56   #define vec_add_16(a,b) _mm512_add_epi16(a,b)
57   #define vec_sub_16(a,b) _mm512_sub_epi16(a,b)
58   #define vec_mul_16(a,b) _mm512_mullo_epi16(a,b)
59   #define vec_zero() _mm512_setzero_epi32()
60   #define vec_set_16(a) _mm512_set1_epi16(a)
61   #define vec_max_16(a,b) _mm512_max_epi16(a,b)
62   #define vec_min_16(a,b) _mm512_min_epi16(a,b)
63   inline vec_t vec_msb_pack_16(vec_t a, vec_t b){
64     vec_t compacted = _mm512_packs_epi16(_mm512_srli_epi16(a,7),_mm512_srli_epi16(b,7));
65     return _mm512_permutexvar_epi64(_mm512_setr_epi64(0, 2, 4, 6, 1, 3, 5, 7), compacted);
66   }
67   #define vec_load_psqt(a) _mm256_load_si256(a)
68   #define vec_store_psqt(a,b) _mm256_store_si256(a,b)
69   #define vec_add_psqt_32(a,b) _mm256_add_epi32(a,b)
70   #define vec_sub_psqt_32(a,b) _mm256_sub_epi32(a,b)
71   #define vec_zero_psqt() _mm256_setzero_si256()
72   #define NumRegistersSIMD 16
73   #define MaxChunkSize 64
74
75   #elif USE_AVX2
76   using vec_t = __m256i;
77   using psqt_vec_t = __m256i;
78   #define vec_load(a) _mm256_load_si256(a)
79   #define vec_store(a,b) _mm256_store_si256(a,b)
80   #define vec_add_16(a,b) _mm256_add_epi16(a,b)
81   #define vec_sub_16(a,b) _mm256_sub_epi16(a,b)
82   #define vec_mul_16(a,b) _mm256_mullo_epi16(a,b)
83   #define vec_zero() _mm256_setzero_si256()
84   #define vec_set_16(a) _mm256_set1_epi16(a)
85   #define vec_max_16(a,b) _mm256_max_epi16(a,b)
86   #define vec_min_16(a,b) _mm256_min_epi16(a,b)
87   inline vec_t vec_msb_pack_16(vec_t a, vec_t b){
88     vec_t compacted = _mm256_packs_epi16(_mm256_srli_epi16(a,7), _mm256_srli_epi16(b,7));
89     return _mm256_permute4x64_epi64(compacted, 0b11011000);
90   }
91   #define vec_load_psqt(a) _mm256_load_si256(a)
92   #define vec_store_psqt(a,b) _mm256_store_si256(a,b)
93   #define vec_add_psqt_32(a,b) _mm256_add_epi32(a,b)
94   #define vec_sub_psqt_32(a,b) _mm256_sub_epi32(a,b)
95   #define vec_zero_psqt() _mm256_setzero_si256()
96   #define NumRegistersSIMD 16
97   #define MaxChunkSize 32
98
99   #elif USE_SSE2
100   using vec_t = __m128i;
101   using psqt_vec_t = __m128i;
102   #define vec_load(a) (*(a))
103   #define vec_store(a,b) *(a)=(b)
104   #define vec_add_16(a,b) _mm_add_epi16(a,b)
105   #define vec_sub_16(a,b) _mm_sub_epi16(a,b)
106   #define vec_mul_16(a,b) _mm_mullo_epi16(a,b)
107   #define vec_zero() _mm_setzero_si128()
108   #define vec_set_16(a) _mm_set1_epi16(a)
109   #define vec_max_16(a,b) _mm_max_epi16(a,b)
110   #define vec_min_16(a,b) _mm_min_epi16(a,b)
111   #define vec_msb_pack_16(a,b) _mm_packs_epi16(_mm_srli_epi16(a,7),_mm_srli_epi16(b,7))
112   #define vec_load_psqt(a) (*(a))
113   #define vec_store_psqt(a,b) *(a)=(b)
114   #define vec_add_psqt_32(a,b) _mm_add_epi32(a,b)
115   #define vec_sub_psqt_32(a,b) _mm_sub_epi32(a,b)
116   #define vec_zero_psqt() _mm_setzero_si128()
117   #define NumRegistersSIMD (Is64Bit ? 16 : 8)
118   #define MaxChunkSize 16
119
120   #elif USE_MMX
121   using vec_t = __m64;
122   using psqt_vec_t = __m64;
123   #define vec_load(a) (*(a))
124   #define vec_store(a,b) *(a)=(b)
125   #define vec_add_16(a,b) _mm_add_pi16(a,b)
126   #define vec_sub_16(a,b) _mm_sub_pi16(a,b)
127   #define vec_mul_16(a,b) _mm_mullo_pi16(a,b)
128   #define vec_zero() _mm_setzero_si64()
129   #define vec_set_16(a) _mm_set1_pi16(a)
130   inline vec_t vec_max_16(vec_t a,vec_t b){
131     vec_t comparison = _mm_cmpgt_pi16(a,b);
132     return _mm_or_si64(_mm_and_si64(comparison, a), _mm_andnot_si64(comparison, b));
133   }
134   inline vec_t vec_min_16(vec_t a,vec_t b){
135     vec_t comparison = _mm_cmpgt_pi16(a,b);
136     return _mm_or_si64(_mm_and_si64(comparison, b), _mm_andnot_si64(comparison, a));
137   }
138   #define vec_msb_pack_16(a,b) _mm_packs_pi16(_mm_srli_pi16(a,7),_mm_srli_pi16(b,7))
139   #define vec_load_psqt(a) (*(a))
140   #define vec_store_psqt(a,b) *(a)=(b)
141   #define vec_add_psqt_32(a,b) _mm_add_pi32(a,b)
142   #define vec_sub_psqt_32(a,b) _mm_sub_pi32(a,b)
143   #define vec_zero_psqt() _mm_setzero_si64()
144   #define vec_cleanup() _mm_empty()
145   #define NumRegistersSIMD 8
146   #define MaxChunkSize 8
147
148   #elif USE_NEON
149   using vec_t = int16x8_t;
150   using psqt_vec_t = int32x4_t;
151   #define vec_load(a) (*(a))
152   #define vec_store(a,b) *(a)=(b)
153   #define vec_add_16(a,b) vaddq_s16(a,b)
154   #define vec_sub_16(a,b) vsubq_s16(a,b)
155   #define vec_mul_16(a,b) vmulq_s16(a,b)
156   #define vec_zero() vec_t{0}
157   #define vec_set_16(a) vdupq_n_s16(a)
158   #define vec_max_16(a,b) vmaxq_s16(a,b)
159   #define vec_min_16(a,b) vminq_s16(a,b)
160   inline vec_t vec_msb_pack_16(vec_t a, vec_t b){
161     const int8x8_t shifta = vshrn_n_s16(a, 7);
162     const int8x8_t shiftb = vshrn_n_s16(b, 7);
163     const int8x16_t compacted = vcombine_s8(shifta,shiftb);
164     return *reinterpret_cast<const vec_t*> (&compacted);
165   }
166   #define vec_load_psqt(a) (*(a))
167   #define vec_store_psqt(a,b) *(a)=(b)
168   #define vec_add_psqt_32(a,b) vaddq_s32(a,b)
169   #define vec_sub_psqt_32(a,b) vsubq_s32(a,b)
170   #define vec_zero_psqt() psqt_vec_t{0}
171   #define NumRegistersSIMD 16
172   #define MaxChunkSize 16
173
174   #else
175   #undef VECTOR
176
177   #endif
178
179
180   #ifdef VECTOR
181
182       // Compute optimal SIMD register count for feature transformer accumulation.
183
184       // We use __m* types as template arguments, which causes GCC to emit warnings
185       // about losing some attribute information. This is irrelevant to us as we
186       // only take their size, so the following pragma are harmless.
187       #if defined(__GNUC__)
188       #pragma GCC diagnostic push
189       #pragma GCC diagnostic ignored "-Wignored-attributes"
190       #endif
191
192       template <typename SIMDRegisterType,
193                 typename LaneType,
194                 int      NumLanes,
195                 int      MaxRegisters>
196       static constexpr int BestRegisterCount()
197       {
198           #define RegisterSize  sizeof(SIMDRegisterType)
199           #define LaneSize      sizeof(LaneType)
200
201           static_assert(RegisterSize >= LaneSize);
202           static_assert(MaxRegisters <= NumRegistersSIMD);
203           static_assert(MaxRegisters > 0);
204           static_assert(NumRegistersSIMD > 0);
205           static_assert(RegisterSize % LaneSize == 0);
206           static_assert((NumLanes * LaneSize) % RegisterSize == 0);
207
208           const int ideal = (NumLanes * LaneSize) / RegisterSize;
209           if (ideal <= MaxRegisters)
210             return ideal;
211
212           // Look for the largest divisor of the ideal register count that is smaller than MaxRegisters
213           for (int divisor = MaxRegisters; divisor > 1; --divisor)
214             if (ideal % divisor == 0)
215               return divisor;
216
217           return 1;
218       }
219
220       static constexpr int NumRegs     = BestRegisterCount<vec_t, WeightType, TransformedFeatureDimensions, NumRegistersSIMD>();
221       static constexpr int NumPsqtRegs = BestRegisterCount<psqt_vec_t, PSQTWeightType, PSQTBuckets, NumRegistersSIMD>();
222       #if defined(__GNUC__)
223       #pragma GCC diagnostic pop
224       #endif
225   #endif
226
227
228
229   // Input feature converter
230   class FeatureTransformer {
231
232    private:
233     // Number of output dimensions for one side
234     static constexpr IndexType HalfDimensions = TransformedFeatureDimensions;
235
236     #ifdef VECTOR
237     static constexpr IndexType TileHeight = NumRegs * sizeof(vec_t) / 2;
238     static constexpr IndexType PsqtTileHeight = NumPsqtRegs * sizeof(psqt_vec_t) / 4;
239     static_assert(HalfDimensions % TileHeight == 0, "TileHeight must divide HalfDimensions");
240     static_assert(PSQTBuckets % PsqtTileHeight == 0, "PsqtTileHeight must divide PSQTBuckets");
241     #endif
242
243    public:
244     // Output type
245     using OutputType = TransformedFeatureType;
246
247     // Number of input/output dimensions
248     static constexpr IndexType InputDimensions = FeatureSet::Dimensions;
249     static constexpr IndexType OutputDimensions = HalfDimensions;
250
251     // Size of forward propagation buffer
252     static constexpr std::size_t BufferSize =
253         OutputDimensions * sizeof(OutputType);
254
255     // Hash value embedded in the evaluation file
256     static constexpr std::uint32_t get_hash_value() {
257       return FeatureSet::HashValue ^ (OutputDimensions * 2);
258     }
259
260     // Read network parameters
261     bool read_parameters(std::istream& stream) {
262
263       read_leb_128<BiasType      >(stream, biases     , HalfDimensions                  );
264       read_leb_128<WeightType    >(stream, weights    , HalfDimensions * InputDimensions);
265       read_leb_128<PSQTWeightType>(stream, psqtWeights, PSQTBuckets    * InputDimensions);
266
267       return !stream.fail();
268     }
269
270     // Write network parameters
271     bool write_parameters(std::ostream& stream) const {
272
273       write_leb_128<BiasType      >(stream, biases     , HalfDimensions                  );
274       write_leb_128<WeightType    >(stream, weights    , HalfDimensions * InputDimensions);
275       write_leb_128<PSQTWeightType>(stream, psqtWeights, PSQTBuckets    * InputDimensions);
276
277       return !stream.fail();
278     }
279
280     // Convert input features
281     std::int32_t transform(const Position& pos, OutputType* output, int bucket) const {
282       update_accumulator<WHITE>(pos);
283       update_accumulator<BLACK>(pos);
284
285       const Color perspectives[2] = {pos.side_to_move(), ~pos.side_to_move()};
286       const auto& accumulation = pos.state()->accumulator.accumulation;
287       const auto& psqtAccumulation = pos.state()->accumulator.psqtAccumulation;
288
289       const auto psqt = (
290             psqtAccumulation[perspectives[0]][bucket]
291           - psqtAccumulation[perspectives[1]][bucket]
292         ) / 2;
293
294
295       for (IndexType p = 0; p < 2; ++p)
296       {
297           const IndexType offset = (HalfDimensions / 2) * p;
298
299 #if defined(VECTOR)
300
301           constexpr IndexType OutputChunkSize = MaxChunkSize;
302           static_assert((HalfDimensions / 2) % OutputChunkSize == 0);
303           constexpr IndexType NumOutputChunks = HalfDimensions / 2 / OutputChunkSize;
304
305           vec_t Zero = vec_zero();
306           vec_t One = vec_set_16(127);
307
308           const vec_t* in0 = reinterpret_cast<const vec_t*>(&(accumulation[perspectives[p]][0]));
309           const vec_t* in1 = reinterpret_cast<const vec_t*>(&(accumulation[perspectives[p]][HalfDimensions / 2]));
310                 vec_t* out = reinterpret_cast<      vec_t*>(output + offset);
311
312           for (IndexType j = 0; j < NumOutputChunks; j += 1)
313           {
314               const vec_t sum0a = vec_max_16(vec_min_16(in0[j * 2 + 0], One), Zero);
315               const vec_t sum0b = vec_max_16(vec_min_16(in0[j * 2 + 1], One), Zero);
316               const vec_t sum1a = vec_max_16(vec_min_16(in1[j * 2 + 0], One), Zero);
317               const vec_t sum1b = vec_max_16(vec_min_16(in1[j * 2 + 1], One), Zero);
318
319               const vec_t pa = vec_mul_16(sum0a, sum1a);
320               const vec_t pb = vec_mul_16(sum0b, sum1b);
321
322               out[j] = vec_msb_pack_16(pa, pb);
323           }
324
325 #else
326
327           for (IndexType j = 0; j < HalfDimensions / 2; ++j) {
328               BiasType sum0 = accumulation[static_cast<int>(perspectives[p])][j + 0];
329               BiasType sum1 = accumulation[static_cast<int>(perspectives[p])][j + HalfDimensions / 2];
330               sum0 = std::clamp<BiasType>(sum0, 0, 127);
331               sum1 = std::clamp<BiasType>(sum1, 0, 127);
332               output[offset + j] = static_cast<OutputType>(unsigned(sum0 * sum1) / 128);
333           }
334
335 #endif
336       }
337
338 #if defined(vec_cleanup)
339       vec_cleanup();
340 #endif
341
342       return psqt;
343     } // end of function transform()
344
345     void hint_common_access(const Position& pos) const {
346       hint_common_access_for_perspective<WHITE>(pos);
347       hint_common_access_for_perspective<BLACK>(pos);
348     }
349
350    private:
351     template<Color Perspective>
352     [[nodiscard]] std::pair<StateInfo*, StateInfo*> try_find_computed_accumulator(const Position& pos) const {
353       // Look for a usable accumulator of an earlier position. We keep track
354       // of the estimated gain in terms of features to be added/subtracted.
355       StateInfo *st = pos.state(), *next = nullptr;
356       int gain = FeatureSet::refresh_cost(pos);
357       while (st->previous && !st->accumulator.computed[Perspective])
358       {
359         // This governs when a full feature refresh is needed and how many
360         // updates are better than just one full refresh.
361         if (   FeatureSet::requires_refresh(st, Perspective)
362             || (gain -= FeatureSet::update_cost(st) + 1) < 0)
363           break;
364         next = st;
365         st = st->previous;
366       }
367       return { st, next };
368     }
369
370     // NOTE: The parameter states_to_update is an array of position states, ending with nullptr.
371     //       All states must be sequential, that is states_to_update[i] must either be reachable
372     //       by repeatedly applying ->previous from states_to_update[i+1] or states_to_update[i] == nullptr.
373     //       computed_st must be reachable by repeatedly applying ->previous on states_to_update[0], if not nullptr.
374     template<Color Perspective, size_t N>
375     void update_accumulator_incremental(const Position& pos, StateInfo* computed_st, StateInfo* states_to_update[N]) const {
376       static_assert(N > 0);
377       assert(states_to_update[N-1] == nullptr);
378
379   #ifdef VECTOR
380       // Gcc-10.2 unnecessarily spills AVX2 registers if this array
381       // is defined in the VECTOR code below, once in each branch
382       vec_t acc[NumRegs];
383       psqt_vec_t psqt[NumPsqtRegs];
384   #endif
385
386       if (states_to_update[0] == nullptr)
387         return;
388
389       // Update incrementally going back through states_to_update.
390
391       // Gather all features to be updated.
392       const Square ksq = pos.square<KING>(Perspective);
393
394       // The size must be enough to contain the largest possible update.
395       // That might depend on the feature set and generally relies on the
396       // feature set's update cost calculation to be correct and never
397       // allow updates with more added/removed features than MaxActiveDimensions.
398       FeatureSet::IndexList removed[N-1], added[N-1];
399
400       {
401         int i = N-2; // last potential state to update. Skip last element because it must be nullptr.
402         while (states_to_update[i] == nullptr)
403           --i;
404
405         StateInfo *st2 = states_to_update[i];
406
407         for (; i >= 0; --i)
408         {
409           states_to_update[i]->accumulator.computed[Perspective] = true;
410
411           StateInfo* end_state = i == 0 ? computed_st : states_to_update[i - 1];
412
413           for (; st2 != end_state; st2 = st2->previous)
414             FeatureSet::append_changed_indices<Perspective>(
415               ksq, st2->dirtyPiece, removed[i], added[i]);
416         }
417       }
418
419       StateInfo* st = computed_st;
420
421       // Now update the accumulators listed in states_to_update[], where the last element is a sentinel.
422 #ifdef VECTOR
423       for (IndexType j = 0; j < HalfDimensions / TileHeight; ++j)
424       {
425         // Load accumulator
426         auto accTile = reinterpret_cast<vec_t*>(
427           &st->accumulator.accumulation[Perspective][j * TileHeight]);
428         for (IndexType k = 0; k < NumRegs; ++k)
429           acc[k] = vec_load(&accTile[k]);
430
431         for (IndexType i = 0; states_to_update[i]; ++i)
432         {
433           // Difference calculation for the deactivated features
434           for (const auto index : removed[i])
435           {
436             const IndexType offset = HalfDimensions * index + j * TileHeight;
437             auto column = reinterpret_cast<const vec_t*>(&weights[offset]);
438             for (IndexType k = 0; k < NumRegs; ++k)
439               acc[k] = vec_sub_16(acc[k], column[k]);
440           }
441
442           // Difference calculation for the activated features
443           for (const auto index : added[i])
444           {
445             const IndexType offset = HalfDimensions * index + j * TileHeight;
446             auto column = reinterpret_cast<const vec_t*>(&weights[offset]);
447             for (IndexType k = 0; k < NumRegs; ++k)
448               acc[k] = vec_add_16(acc[k], column[k]);
449           }
450
451           // Store accumulator
452           accTile = reinterpret_cast<vec_t*>(
453             &states_to_update[i]->accumulator.accumulation[Perspective][j * TileHeight]);
454           for (IndexType k = 0; k < NumRegs; ++k)
455             vec_store(&accTile[k], acc[k]);
456         }
457       }
458
459       for (IndexType j = 0; j < PSQTBuckets / PsqtTileHeight; ++j)
460       {
461         // Load accumulator
462         auto accTilePsqt = reinterpret_cast<psqt_vec_t*>(
463           &st->accumulator.psqtAccumulation[Perspective][j * PsqtTileHeight]);
464         for (std::size_t k = 0; k < NumPsqtRegs; ++k)
465           psqt[k] = vec_load_psqt(&accTilePsqt[k]);
466
467         for (IndexType i = 0; states_to_update[i]; ++i)
468         {
469           // Difference calculation for the deactivated features
470           for (const auto index : removed[i])
471           {
472             const IndexType offset = PSQTBuckets * index + j * PsqtTileHeight;
473             auto columnPsqt = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offset]);
474             for (std::size_t k = 0; k < NumPsqtRegs; ++k)
475               psqt[k] = vec_sub_psqt_32(psqt[k], columnPsqt[k]);
476           }
477
478           // Difference calculation for the activated features
479           for (const auto index : added[i])
480           {
481             const IndexType offset = PSQTBuckets * index + j * PsqtTileHeight;
482             auto columnPsqt = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offset]);
483             for (std::size_t k = 0; k < NumPsqtRegs; ++k)
484               psqt[k] = vec_add_psqt_32(psqt[k], columnPsqt[k]);
485           }
486
487           // Store accumulator
488           accTilePsqt = reinterpret_cast<psqt_vec_t*>(
489             &states_to_update[i]->accumulator.psqtAccumulation[Perspective][j * PsqtTileHeight]);
490           for (std::size_t k = 0; k < NumPsqtRegs; ++k)
491             vec_store_psqt(&accTilePsqt[k], psqt[k]);
492         }
493       }
494
495 #else
496       for (IndexType i = 0; states_to_update[i]; ++i)
497       {
498         std::memcpy(states_to_update[i]->accumulator.accumulation[Perspective],
499             st->accumulator.accumulation[Perspective],
500             HalfDimensions * sizeof(BiasType));
501
502         for (std::size_t k = 0; k < PSQTBuckets; ++k)
503           states_to_update[i]->accumulator.psqtAccumulation[Perspective][k] = st->accumulator.psqtAccumulation[Perspective][k];
504
505         st = states_to_update[i];
506
507         // Difference calculation for the deactivated features
508         for (const auto index : removed[i])
509         {
510           const IndexType offset = HalfDimensions * index;
511
512           for (IndexType j = 0; j < HalfDimensions; ++j)
513             st->accumulator.accumulation[Perspective][j] -= weights[offset + j];
514
515           for (std::size_t k = 0; k < PSQTBuckets; ++k)
516             st->accumulator.psqtAccumulation[Perspective][k] -= psqtWeights[index * PSQTBuckets + k];
517         }
518
519         // Difference calculation for the activated features
520         for (const auto index : added[i])
521         {
522           const IndexType offset = HalfDimensions * index;
523
524           for (IndexType j = 0; j < HalfDimensions; ++j)
525             st->accumulator.accumulation[Perspective][j] += weights[offset + j];
526
527           for (std::size_t k = 0; k < PSQTBuckets; ++k)
528             st->accumulator.psqtAccumulation[Perspective][k] += psqtWeights[index * PSQTBuckets + k];
529         }
530       }
531 #endif
532
533   #if defined(USE_MMX)
534       _mm_empty();
535   #endif
536     }
537
538     template<Color Perspective>
539     void update_accumulator_refresh(const Position& pos) const {
540   #ifdef VECTOR
541       // Gcc-10.2 unnecessarily spills AVX2 registers if this array
542       // is defined in the VECTOR code below, once in each branch
543       vec_t acc[NumRegs];
544       psqt_vec_t psqt[NumPsqtRegs];
545   #endif
546
547       // Refresh the accumulator
548       // Could be extracted to a separate function because it's done in 2 places,
549       // but it's unclear if compilers would correctly handle register allocation.
550       auto& accumulator = pos.state()->accumulator;
551       accumulator.computed[Perspective] = true;
552       FeatureSet::IndexList active;
553       FeatureSet::append_active_indices<Perspective>(pos, active);
554
555 #ifdef VECTOR
556       for (IndexType j = 0; j < HalfDimensions / TileHeight; ++j)
557       {
558         auto biasesTile = reinterpret_cast<const vec_t*>(
559             &biases[j * TileHeight]);
560         for (IndexType k = 0; k < NumRegs; ++k)
561           acc[k] = biasesTile[k];
562
563         for (const auto index : active)
564         {
565           const IndexType offset = HalfDimensions * index + j * TileHeight;
566           auto column = reinterpret_cast<const vec_t*>(&weights[offset]);
567
568           for (unsigned k = 0; k < NumRegs; ++k)
569             acc[k] = vec_add_16(acc[k], column[k]);
570         }
571
572         auto accTile = reinterpret_cast<vec_t*>(
573             &accumulator.accumulation[Perspective][j * TileHeight]);
574         for (unsigned k = 0; k < NumRegs; k++)
575           vec_store(&accTile[k], acc[k]);
576       }
577
578       for (IndexType j = 0; j < PSQTBuckets / PsqtTileHeight; ++j)
579       {
580         for (std::size_t k = 0; k < NumPsqtRegs; ++k)
581           psqt[k] = vec_zero_psqt();
582
583         for (const auto index : active)
584         {
585           const IndexType offset = PSQTBuckets * index + j * PsqtTileHeight;
586           auto columnPsqt = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offset]);
587
588           for (std::size_t k = 0; k < NumPsqtRegs; ++k)
589             psqt[k] = vec_add_psqt_32(psqt[k], columnPsqt[k]);
590         }
591
592         auto accTilePsqt = reinterpret_cast<psqt_vec_t*>(
593           &accumulator.psqtAccumulation[Perspective][j * PsqtTileHeight]);
594         for (std::size_t k = 0; k < NumPsqtRegs; ++k)
595           vec_store_psqt(&accTilePsqt[k], psqt[k]);
596       }
597
598 #else
599       std::memcpy(accumulator.accumulation[Perspective], biases,
600           HalfDimensions * sizeof(BiasType));
601
602       for (std::size_t k = 0; k < PSQTBuckets; ++k)
603         accumulator.psqtAccumulation[Perspective][k] = 0;
604
605       for (const auto index : active)
606       {
607         const IndexType offset = HalfDimensions * index;
608
609         for (IndexType j = 0; j < HalfDimensions; ++j)
610           accumulator.accumulation[Perspective][j] += weights[offset + j];
611
612         for (std::size_t k = 0; k < PSQTBuckets; ++k)
613           accumulator.psqtAccumulation[Perspective][k] += psqtWeights[index * PSQTBuckets + k];
614       }
615 #endif
616
617   #if defined(USE_MMX)
618       _mm_empty();
619   #endif
620     }
621
622     template<Color Perspective>
623     void hint_common_access_for_perspective(const Position& pos) const {
624
625       // Works like update_accumulator, but performs less work.
626       // Updates ONLY the accumulator for pos.
627
628       // Look for a usable accumulator of an earlier position. We keep track
629       // of the estimated gain in terms of features to be added/subtracted.
630       // Fast early exit.
631       if (pos.state()->accumulator.computed[Perspective])
632         return;
633
634       auto [oldest_st, _] = try_find_computed_accumulator<Perspective>(pos);
635
636       if (oldest_st->accumulator.computed[Perspective])
637       {
638         // Only update current position accumulator to minimize work.
639         StateInfo* states_to_update[2] = { pos.state(), nullptr };
640         update_accumulator_incremental<Perspective, 2>(pos, oldest_st, states_to_update);
641       }
642       else
643       {
644         update_accumulator_refresh<Perspective>(pos);
645       }
646     }
647
648     template<Color Perspective>
649     void update_accumulator(const Position& pos) const {
650
651       auto [oldest_st, next] = try_find_computed_accumulator<Perspective>(pos);
652
653       if (oldest_st->accumulator.computed[Perspective])
654       {
655         if (next == nullptr)
656           return;
657
658         // Now update the accumulators listed in states_to_update[], where the last element is a sentinel.
659         // Currently we update 2 accumulators.
660         //     1. for the current position
661         //     2. the next accumulator after the computed one
662         // The heuristic may change in the future.
663         StateInfo *states_to_update[3] =
664           { next, next == pos.state() ? nullptr : pos.state(), nullptr };
665
666         update_accumulator_incremental<Perspective, 3>(pos, oldest_st, states_to_update);
667       }
668       else
669       {
670         update_accumulator_refresh<Perspective>(pos);
671       }
672     }
673
674     alignas(CacheLineSize) BiasType biases[HalfDimensions];
675     alignas(CacheLineSize) WeightType weights[HalfDimensions * InputDimensions];
676     alignas(CacheLineSize) PSQTWeightType psqtWeights[InputDimensions * PSQTBuckets];
677   };
678
679 }  // namespace Stockfish::Eval::NNUE
680
681 #endif // #ifndef NNUE_FEATURE_TRANSFORMER_H_INCLUDED