Spaces:
Build error
Build error
File size: 4,315 Bytes
84d2a97 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
#include <stdlib.h>
#include <stdint.h>
#include <immintrin.h>
#include "export_macro.h"
/*
 * HSUM256_PS(X, R): horizontal sum of the eight float lanes of the __m256
 * expression X. Declares `float R` in the enclosing scope and assigns the sum
 * to it, so it must appear where a declaration is allowed and R must be a
 * fresh identifier. X is evaluated exactly once.
 * Reduction: (high 128 + low 128), fold the upper 64 bits onto the lower,
 * then add lane 1 into lane 0.
 */
#define HSUM256_PS(X, R) \
float R = 0.0f; \
{ \
    const __m256 hsum256_ps_in_ = (X); \
    const __m128 hsum256_ps_x128_ = _mm_add_ps(_mm256_extractf128_ps(hsum256_ps_in_, 1), \
                                               _mm256_castps256_ps128(hsum256_ps_in_)); \
    const __m128 hsum256_ps_x64_ = _mm_add_ps(hsum256_ps_x128_, \
                                              _mm_movehl_ps(hsum256_ps_x128_, hsum256_ps_x128_)); \
    const __m128 hsum256_ps_x32_ = _mm_add_ss(hsum256_ps_x64_, \
                                              _mm_shuffle_ps(hsum256_ps_x64_, hsum256_ps_x64_, 0x55)); \
    R = _mm_cvtss_f32(hsum256_ps_x32_); \
}
/*
 * HSUM256_EPI32(X, R): horizontal sum of the eight int32 lanes of the __m256i
 * expression X (wrapping on overflow, as SIMD adds do). Declares `int R` in
 * the enclosing scope and assigns the sum to it, so it must appear where a
 * declaration is allowed and R must be a fresh identifier. X is evaluated
 * exactly once.
 */
#define HSUM256_EPI32(X, R) \
int R = 0; \
{ \
    const __m256i hsum256_epi32_in_ = (X); \
    const __m128i hsum256_epi32_x128_ = _mm_add_epi32(_mm256_extractf128_si256(hsum256_epi32_in_, 1), \
                                                      _mm256_castsi256_si128(hsum256_epi32_in_)); \
    const __m128i hsum256_epi32_x64_ = _mm_add_epi32(hsum256_epi32_x128_, \
                                                     _mm_srli_si128(hsum256_epi32_x128_, 8)); \
    const __m128i hsum256_epi32_x32_ = _mm_add_epi32(hsum256_epi32_x64_, \
                                                     _mm_srli_si128(hsum256_epi32_x64_, 4)); \
    R = _mm_cvtsi128_si32(hsum256_epi32_x32_); \
}
EXPORT float impl_score_dot_avx(
const uint8_t* query_ptr,
const uint8_t* vector_ptr,
uint32_t dim
) {
const __m256i* v_ptr = (const __m256i*)vector_ptr;
const __m256i* q_ptr = (const __m256i*)query_ptr;
__m256i mul1 = _mm256_setzero_si256();
__m256i mask_epu32 = _mm256_set1_epi32(0xFFFF);
for (uint32_t _i = 0; _i < dim / 32; _i++) {
__m256i v = _mm256_loadu_si256(v_ptr);
__m256i q = _mm256_loadu_si256(q_ptr);
v_ptr++;
q_ptr++;
__m256i s = _mm256_maddubs_epi16(v, q);
__m256i s_low = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(s));
__m256i s_high = _mm256_cvtepi16_epi32(_mm256_extractf128_si256(s, 1));
mul1 = _mm256_add_epi32(mul1, s_low);
mul1 = _mm256_add_epi32(mul1, s_high);
}
// the vector sizes are assumed to be multiples of 16, check if one last 16-element part remaining
if (dim % 32 != 0) {
__m128i v_short = _mm_loadu_si128((const __m128i*)v_ptr);
__m128i q_short = _mm_loadu_si128((const __m128i*)q_ptr);
__m256i v1 = _mm256_cvtepu8_epi16(v_short);
__m256i q1 = _mm256_cvtepu8_epi16(q_short);
__m256i s = _mm256_mullo_epi16(v1, q1);
mul1 = _mm256_add_epi32(mul1, _mm256_and_si256(s, mask_epu32));
mul1 = _mm256_add_epi32(mul1, _mm256_srli_epi32(s, 16));
}
__m256 mul_ps = _mm256_cvtepi32_ps(mul1);
HSUM256_PS(mul_ps, mul_scalar);
return mul_scalar;
}
/*
 * L1 (Manhattan) distance between two uint8 vectors of length `dim`:
 * sum(|vector[i] - query[i]|), returned as float.
 *
 * Uses _mm256_sad_epu8 / _mm_sad_epu8, which sum absolute byte differences
 * into 64-bit lanes. Accumulating in 64-bit lanes (rather than 16-bit ones,
 * which wrap past 65535 after a few thousand dims) keeps the integer sum
 * exact for any uint32_t `dim`. Conversion to float may round sums > 2^24.
 *
 * Any `dim` is accepted; no byte at index >= dim is read.
 */
EXPORT float impl_score_l1_avx(
    const uint8_t* query_ptr,
    const uint8_t* vector_ptr,
    uint32_t dim
) {
    // Four uint64 partial sums; each step adds at most 8 * 255 per lane.
    __m256i acc = _mm256_setzero_si256();
    uint32_t i = 0;
    // `dim - i` never underflows (i <= dim), unlike `i + 32` which could wrap.
    for (; dim - i >= 32; i += 32) {
        __m256i v = _mm256_loadu_si256((const __m256i*)(vector_ptr + i));
        __m256i q = _mm256_loadu_si256((const __m256i*)(query_ptr + i));
        acc = _mm256_add_epi64(acc, _mm256_sad_epu8(v, q));
    }
    // Fold the four 64-bit lanes into two, then take one 16-byte step if possible.
    __m128i acc128 = _mm_add_epi64(_mm256_castsi256_si128(acc),
                                   _mm256_extractf128_si256(acc, 1));
    if (dim - i >= 16) {
        __m128i v_short = _mm_loadu_si128((const __m128i*)(vector_ptr + i));
        __m128i q_short = _mm_loadu_si128((const __m128i*)(query_ptr + i));
        acc128 = _mm_add_epi64(acc128, _mm_sad_epu8(v_short, q_short));
        i += 16;
    }
    uint64_t lanes[2];
    _mm_storeu_si128((__m128i*)lanes, acc128);
    uint64_t total = lanes[0] + lanes[1];
    // At most 15 elements remain: scalar absolute differences.
    for (; i < dim; i++) {
        uint8_t a = vector_ptr[i];
        uint8_t b = query_ptr[i];
        total += (uint64_t)(a > b ? a - b : b - a);
    }
    return (float)total;
}
|