- void propagate(
- const InputType* input, OutputType* output) const {
-
- #if defined(USE_SSE2)
- constexpr IndexType NumChunks = InputDimensions / 16;
-
- static_assert(WeightScaleBits == 6);
- const auto in = reinterpret_cast<const __m128i*>(input);
- const auto out = reinterpret_cast<__m128i*>(output);
- for (IndexType i = 0; i < NumChunks; ++i) {
- __m128i words0 = _mm_packs_epi32(
- _mm_load_si128(&in[i * 4 + 0]),
- _mm_load_si128(&in[i * 4 + 1]));
- __m128i words1 = _mm_packs_epi32(
- _mm_load_si128(&in[i * 4 + 2]),
- _mm_load_si128(&in[i * 4 + 3]));
-
- // We shift by WeightScaleBits * 2 = 12 and divide by 128
- // which is an additional shift-right of 7, meaning 19 in total.
- // MulHi strips the lower 16 bits so we need to shift out 3 more to match.
- words0 = _mm_srli_epi16(_mm_mulhi_epi16(words0, words0), 3);
- words1 = _mm_srli_epi16(_mm_mulhi_epi16(words1, words1), 3);
-
- _mm_store_si128(&out[i], _mm_packs_epi16(words0, words1));
- }
- constexpr IndexType Start = NumChunks * 16;
-
- #else
- constexpr IndexType Start = 0;
- #endif
-
- for (IndexType i = Start; i < InputDimensions; ++i) {
- output[i] = static_cast<OutputType>(
- // really should be /127 but we need to make it fast
- // needs to be accounted for in the trainer
- std::min(127ll, (((long long)input[i] * input[i]) >> (2 * WeightScaleBits)) / 128));
- }
+ void propagate(const InputType* input, OutputType* output) const {
+
+#if defined(USE_SSE2)
+ constexpr IndexType NumChunks = InputDimensions / 16;
+
+ static_assert(WeightScaleBits == 6);
+ const auto in = reinterpret_cast<const __m128i*>(input);
+ const auto out = reinterpret_cast<__m128i*>(output);
+ for (IndexType i = 0; i < NumChunks; ++i)
+ {
+ __m128i words0 =
+ _mm_packs_epi32(_mm_load_si128(&in[i * 4 + 0]), _mm_load_si128(&in[i * 4 + 1]));
+ __m128i words1 =
+ _mm_packs_epi32(_mm_load_si128(&in[i * 4 + 2]), _mm_load_si128(&in[i * 4 + 3]));
+
+ // We shift by WeightScaleBits * 2 = 12 and divide by 128
+ // which is an additional shift-right of 7, meaning 19 in total.
+ // MulHi strips the lower 16 bits so we need to shift out 3 more to match.
+ words0 = _mm_srli_epi16(_mm_mulhi_epi16(words0, words0), 3);
+ words1 = _mm_srli_epi16(_mm_mulhi_epi16(words1, words1), 3);
+
+ _mm_store_si128(&out[i], _mm_packs_epi16(words0, words1));
+ }
+ constexpr IndexType Start = NumChunks * 16;
+
+#else
+ constexpr IndexType Start = 0;
+#endif
+
+ for (IndexType i = Start; i < InputDimensions; ++i)
+ {
+ output[i] = static_cast<OutputType>(
+ // Really should be /127 but we need to make it fast so we right-shift
+ // by an extra 7 bits instead. Needs to be accounted for in the trainer.
+ std::min(127ll, ((long long) (input[i]) * input[i]) >> (2 * WeightScaleBits + 7)));
+ }