Spaces:
Build error
Build error
/*
 * Dot product of two uint8 vectors, returned as float.
 *
 * Only the first (dim / 16) * 16 elements are read; any trailing
 * dim % 16 elements are not touched, so callers are expected to pass a
 * dim that is a multiple of 16.
 *
 * Each 16-byte block is split into low/high halves, multiplied with a
 * widening multiply (u8 x u8 -> u16), and the 16-bit products are
 * pairwise-added into two independent 32-bit accumulators. The two
 * accumulators are merged and reduced across lanes once, after the loop.
 * Uses vaddvq_u32, which is an AArch64-only intrinsic.
 */
EXPORT float impl_score_dot_neon(
    const uint8_t* query_ptr,
    const uint8_t* vector_ptr,
    uint32_t dim
) {
    uint32x4_t acc_lo = vdupq_n_u32(0);
    uint32x4_t acc_hi = vdupq_n_u32(0);

    for (uint32_t blocks_left = dim / 16; blocks_left != 0; --blocks_left) {
        const uint8x16_t qv = vld1q_u8(query_ptr);
        const uint8x16_t vv = vld1q_u8(vector_ptr);

        acc_lo = vpadalq_u16(acc_lo, vmull_u8(vget_low_u8(qv), vget_low_u8(vv)));
        acc_hi = vpadalq_u16(acc_hi, vmull_u8(vget_high_u8(qv), vget_high_u8(vv)));

        query_ptr += 16;
        vector_ptr += 16;
    }

    const uint32x4_t acc = vaddq_u32(acc_lo, acc_hi);
    return (float)vaddvq_u32(acc);
}
/*
 * Number of differing bits (popcount of XOR) between two byte buffers,
 * processed in `count` chunks of 16 bytes each (16 * count bytes read
 * from each pointer).
 *
 * Per chunk: XOR the two 16-byte blocks, count set bits per byte
 * (each <= 8), add the low and high 8-byte halves together while widening
 * to 16 bits (each <= 16), then pairwise-accumulate into 32-bit lanes.
 * The four 32-bit lanes are summed once at the end (vaddvq_u32,
 * AArch64-only).
 */
EXPORT uint32_t impl_xor_popcnt_neon_uint128(
    const uint8_t* query_ptr,
    const uint8_t* vector_ptr,
    uint32_t count
) {
    uint32x4_t total = vdupq_n_u32(0);

    for (uint32_t left = count; left > 0; --left) {
        const uint8x16_t diff = veorq_u8(vld1q_u8(query_ptr), vld1q_u8(vector_ptr));
        const uint8x16_t per_byte = vcntq_u8(diff);
        const uint16x8_t per_pair = vaddl_u8(vget_low_u8(per_byte), vget_high_u8(per_byte));
        total = vpadalq_u16(total, per_pair);

        query_ptr += 16;
        vector_ptr += 16;
    }

    return (uint32_t)vaddvq_u32(total);
}
/*
 * Number of differing bits (popcount of XOR) between two byte buffers,
 * processed in `count` chunks of 8 bytes each (8 * count bytes read from
 * each pointer).
 *
 * Per chunk: XOR the two 8-byte blocks and count set bits per byte
 * (each <= 8). These counts are widened pairwise to 16 bits (each <= 16)
 * and pairwise-accumulated straight into 32-bit lanes.
 *
 * The running sum is kept in 32-bit lanes on purpose. A 16-bit
 * accumulator reduced with vaddv_u16 (which returns uint16_t) wraps once
 * the total exceeds 65535 bits, i.e. for count >= 1024. That would
 * silently return a wrong distance from a function whose return type is
 * uint32_t. Requires AArch64 (vaddv_u32).
 */
EXPORT uint32_t impl_xor_popcnt_neon_uint64(
    const uint8_t* query_ptr,
    const uint8_t* vector_ptr,
    uint32_t count
) {
    uint32x2_t result = vdup_n_u32(0);
    for (uint32_t _i = 0; _i < count; _i++) {
        uint8x8_t v = vld1_u8(vector_ptr);
        uint8x8_t q = vld1_u8(query_ptr);
        uint8x8_t popcnt = vcnt_u8(veor_u8(q, v));
        // u8 -> u16 pairwise add, then u16 -> u32 pairwise accumulate.
        result = vpadal_u16(result, vpaddl_u8(popcnt));
        query_ptr += 8;
        vector_ptr += 8;
    }
    return vaddv_u32(result);
}
/*
 * Manhattan (L1) distance between two uint8 vectors, returned as float.
 *
 * dim is assumed to be a multiple of 16. Only the first (dim / 16) * 16
 * elements are read, and any trailing dim % 16 elements are ignored.
 *
 * Per 16-byte block, the absolute differences (each <= 255) are widened
 * pairwise to 16 bits (each <= 510) and immediately pairwise-accumulated
 * into 32-bit lanes. Keeping the running sum in 32 bits avoids the
 * wraparound a 16-bit per-lane accumulator suffers once a lane exceeds
 * 65535 (reachable for dim > 4112). Requires AArch64 (vaddvq_u32).
 */
EXPORT float impl_score_l1_neon(
    const uint8_t * query_ptr,
    const uint8_t * vector_ptr,
    uint32_t dim
) {
    uint32x4_t sum = vdupq_n_u32(0);
    for (uint32_t _i = 0; _i < dim / 16; _i++) {
        uint8x16_t v = vld1q_u8(vector_ptr);
        uint8x16_t q = vld1q_u8(query_ptr);
        uint8x16_t abs_diff = vabdq_u8(v, q);
        // u8 -> u16 pairwise add, then u16 -> u32 pairwise accumulate.
        sum = vpadalq_u16(sum, vpaddlq_u8(abs_diff));
        vector_ptr += 16;
        query_ptr += 16;
    }
    return (float)vaddvq_u32(sum);
}