#if defined(USE_SSE2)
constexpr IndexType NumChunks = InputDimensions / 16;
- #ifdef USE_SSE41
- const __m128i Zero = _mm_setzero_si128();
- #else
- const __m128i k0x80s = _mm_set1_epi8(-128);
- #endif
-
static_assert(WeightScaleBits == 6);
const auto in = reinterpret_cast<const __m128i*>(input);
const auto out = reinterpret_cast<__m128i*>(output);
_mm_load_si128(&in[i * 4 + 2]),
_mm_load_si128(&in[i * 4 + 3]));
- // Not sure if
+ // We shift by WeightScaleBits * 2 = 12 and divide by 128
+ // which is an additional shift-right of 7, meaning 19 in total.
+ // MulHi strips the lower 16 bits so we need to shift out 3 more to match.
words0 = _mm_srli_epi16(_mm_mulhi_epi16(words0, words0), 3);
words1 = _mm_srli_epi16(_mm_mulhi_epi16(words1, words1), 3);
- const __m128i packedbytes = _mm_packs_epi16(words0, words1);
-
- _mm_store_si128(&out[i],
-
- #ifdef USE_SSE41
- _mm_max_epi8(packedbytes, Zero)
- #else
- _mm_subs_epi8(_mm_adds_epi8(packedbytes, k0x80s), k0x80s)
- #endif
-
- );
+ _mm_store_si128(&out[i], _mm_packs_epi16(words0, words1));
}
constexpr IndexType Start = NumChunks * 16;
output[i] = static_cast<OutputType>(
// really should be /127 but we need to make it fast
// needs to be accounted for in the trainer
- std::max(0ll, std::min(127ll, (((long long)input[i] * input[i]) >> (2 * WeightScaleBits)) / 128)));
+ std::min(127ll, (((long long)input[i] * input[i]) >> (2 * WeightScaleBits)) / 128));
}
}
};