Spaces:
Build error
Build error
EXPORT float impl_score_dot_avx( | |
const uint8_t* query_ptr, | |
const uint8_t* vector_ptr, | |
uint32_t dim | |
) { | |
const __m256i* v_ptr = (const __m256i*)vector_ptr; | |
const __m256i* q_ptr = (const __m256i*)query_ptr; | |
__m256i mul1 = _mm256_setzero_si256(); | |
__m256i mask_epu32 = _mm256_set1_epi32(0xFFFF); | |
for (uint32_t _i = 0; _i < dim / 32; _i++) { | |
__m256i v = _mm256_loadu_si256(v_ptr); | |
__m256i q = _mm256_loadu_si256(q_ptr); | |
v_ptr++; | |
q_ptr++; | |
__m256i s = _mm256_maddubs_epi16(v, q); | |
__m256i s_low = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(s)); | |
__m256i s_high = _mm256_cvtepi16_epi32(_mm256_extractf128_si256(s, 1)); | |
mul1 = _mm256_add_epi32(mul1, s_low); | |
mul1 = _mm256_add_epi32(mul1, s_high); | |
} | |
// the vector sizes are assumed to be multiples of 16, check if one last 16-element part remaining | |
if (dim % 32 != 0) { | |
__m128i v_short = _mm_loadu_si128((const __m128i*)v_ptr); | |
__m128i q_short = _mm_loadu_si128((const __m128i*)q_ptr); | |
__m256i v1 = _mm256_cvtepu8_epi16(v_short); | |
__m256i q1 = _mm256_cvtepu8_epi16(q_short); | |
__m256i s = _mm256_mullo_epi16(v1, q1); | |
mul1 = _mm256_add_epi32(mul1, _mm256_and_si256(s, mask_epu32)); | |
mul1 = _mm256_add_epi32(mul1, _mm256_srli_epi32(s, 16)); | |
} | |
__m256 mul_ps = _mm256_cvtepi32_ps(mul1); | |
HSUM256_PS(mul_ps, mul_scalar); | |
return mul_scalar; | |
} | |
EXPORT float impl_score_l1_avx( | |
const uint8_t* query_ptr, | |
const uint8_t* vector_ptr, | |
uint32_t dim | |
) { | |
const __m256i* v_ptr = (const __m256i*)vector_ptr; | |
const __m256i* q_ptr = (const __m256i*)query_ptr; | |
uint32_t m = dim - (dim % 32); | |
__m256i sum256 = _mm256_setzero_si256(); | |
for (uint32_t i = 0; i < m; i += 32) { | |
__m256i v = _mm256_loadu_si256(v_ptr); | |
__m256i q = _mm256_loadu_si256(q_ptr); | |
v_ptr++; | |
q_ptr++; | |
// Compute the difference in both directions and take the maximum for abs | |
__m256i diff1 = _mm256_subs_epu8(v, q); | |
__m256i diff2 = _mm256_subs_epu8(q, v); | |
__m256i abs_diff = _mm256_max_epu8(diff1, diff2); | |
__m256i abs_diff16_lo = _mm256_unpacklo_epi8(abs_diff, _mm256_setzero_si256()); | |
__m256i abs_diff16_hi = _mm256_unpackhi_epi8(abs_diff, _mm256_setzero_si256()); | |
sum256 = _mm256_add_epi16(sum256, abs_diff16_lo); | |
sum256 = _mm256_add_epi16(sum256, abs_diff16_hi); | |
} | |
// the vector sizes are assumed to be multiples of 16, check if one last 16-element part remaining | |
if (m < dim) { | |
__m128i v_short = _mm_loadu_si128((const __m128i * ) v_ptr); | |
__m128i q_short = _mm_loadu_si128((const __m128i * ) q_ptr); | |
__m128i diff1 = _mm_subs_epu8(v_short, q_short); | |
__m128i diff2 = _mm_subs_epu8(q_short, v_short); | |
__m128i abs_diff = _mm_max_epu8(diff1, diff2); | |
__m128i abs_diff16_lo_128 = _mm_unpacklo_epi8(abs_diff, _mm_setzero_si128()); | |
__m128i abs_diff16_hi_128 = _mm_unpackhi_epi8(abs_diff, _mm_setzero_si128()); | |
__m256i abs_diff16_lo = _mm256_cvtepu16_epi32(abs_diff16_lo_128); | |
__m256i abs_diff16_hi = _mm256_cvtepu16_epi32(abs_diff16_hi_128); | |
sum256 = _mm256_add_epi16(sum256, abs_diff16_lo); | |
sum256 = _mm256_add_epi16(sum256, abs_diff16_hi); | |
} | |
__m256i sum_epi32 = _mm256_add_epi32( | |
_mm256_unpacklo_epi16(sum256, _mm256_setzero_si256()), | |
_mm256_unpackhi_epi16(sum256, _mm256_setzero_si256())); | |
HSUM256_EPI32(sum_epi32, sum); | |
return (float) sum; | |
} | |