-
- // Determine if eights of weight and input products can be summed using 16bits
- // without saturation. We assume worst case combinations of 0 and 127 for all inputs.
- if (kOutputDimensions > 1 && !stream.fail())
- {
- canSaturate16.count = 0;
-#if !defined(USE_VNNI)
- for (IndexType i = 0; i < kPaddedInputDimensions; i += 16)
- for (IndexType j = 0; j < kOutputDimensions; ++j)
- for (int x = 0; x < 2; ++x)
- {
- WeightType* w = &weights_[i * kOutputDimensions + j * 4 + x * 2];
- int sum[2] = {0, 0};
- for (int k = 0; k < 8; ++k)
- {
- IndexType idx = k / 2 * kOutputDimensions * 4 + k % 2;
- sum[w[idx] < 0] += w[idx];
- }
- for (int sign : {-1, 1})
- while (sign * sum[sign == -1] > 258)
- {
- int maxK = 0, maxW = 0;
- for (int k = 0; k < 8; ++k)
- {
- IndexType idx = k / 2 * kOutputDimensions * 4 + k % 2;
- if (maxW < sign * w[idx])
- maxK = k, maxW = sign * w[idx];
- }
-
- IndexType idx = maxK / 2 * kOutputDimensions * 4 + maxK % 2;
- sum[sign == -1] -= w[idx];
- canSaturate16.add(j, i + maxK / 2 * 4 + maxK % 2 + x * 2, w[idx]);
- w[idx] = 0;
- }
- }
-
- // Non functional optimization for faster more linear access
- std::sort(canSaturate16.ids, canSaturate16.ids + canSaturate16.count,
- [](const typename CanSaturate::Entry& e1, const typename CanSaturate::Entry& e2)
- { return e1.in == e2.in ? e1.out < e2.out : e1.in < e2.in; });