for (IndexType i = Start; i < InputDimensions; ++i) {
output[i] = static_cast<OutputType>(
- // really should be /127 but we need to make it fast
- // needs to be accounted for in the trainer
- std::min(127ll, (((long long)input[i] * input[i]) >> (2 * WeightScaleBits)) / 128));
+ // Really should be /127 but we need to make it fast so we right shift
+ // by an extra 7 bits instead. Needs to be accounted for in the trainer.
+ std::min(127ll, ((long long)input[i] * input[i]) >> (2 * WeightScaleBits + 7)));
}
}
};