- const auto iv_256 = reinterpret_cast<const __m256i*>(input);
- const auto row_256 = reinterpret_cast<const __m256i*>(&weights_[offset]);
- int j = kNumChunks * 2;
- __m256i sum256 = _mm256_maddubs_epi16(_mm256_loadA_si256(&iv_256[j]), _mm256_load_si256(&row_256[j]));
- sum256 = _mm256_madd_epi16(sum256, _mm256_set1_epi16(1));
- sum256 = _mm256_hadd_epi32(sum256, sum256);
- sum256 = _mm256_hadd_epi32(sum256, sum256);
- const __m128i lo = _mm256_extracti128_si256(sum256, 0);
- const __m128i hi = _mm256_extracti128_si256(sum256, 1);
- output[i] += _mm_cvtsi128_si32(lo) + _mm_cvtsi128_si32(hi);
+ const auto iv256 = reinterpret_cast<const __m256i*>(&input_vector[kNumChunks]);
+ const auto row256 = reinterpret_cast<const __m256i*>(&row[kNumChunks]);
+ __m256i product256 = _mm256_maddubs_epi16(_mm256_loadA_si256(&iv256[0]), _mm256_load_si256(&row256[0]));
+ product256 = _mm256_madd_epi16(product256, _mm256_set1_epi16(1));
+ sum = _mm512_add_epi32(sum, _mm512_zextsi256_si512(product256));