File size: 3,302 Bytes
84d2a97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#include <stdlib.h>
#include <arm_neon.h>

#include "export_macro.h"

EXPORT float impl_score_dot_neon(
    const uint8_t* query_ptr,
    const uint8_t* vector_ptr,
    uint32_t dim
) {
    uint32x4_t mul1 = vdupq_n_u32(0);
    uint32x4_t mul2 = vdupq_n_u32(0);
    for (uint32_t _i = 0; _i < dim / 16; _i++) {
        uint8x16_t q = vld1q_u8(query_ptr);
        uint8x16_t v = vld1q_u8(vector_ptr);
        query_ptr += 16;
        vector_ptr += 16;
        uint16x8_t mul_low = vmull_u8(vget_low_u8(q), vget_low_u8(v));
        uint16x8_t mul_high = vmull_u8(vget_high_u8(q), vget_high_u8(v));
        mul1 = vpadalq_u16(mul1, mul_low);
        mul2 = vpadalq_u16(mul2, mul_high);
    }
    return (float)vaddvq_u32(vaddq_u32(mul1, mul2));
}

EXPORT uint32_t impl_xor_popcnt_neon_uint128(
    const uint8_t* query_ptr,
    const uint8_t* vector_ptr,
    uint32_t count
) {
    uint32x4_t result = vdupq_n_u32(0);
    for (uint32_t _i = 0; _i < count; _i++) {
        uint8x16_t v = vld1q_u8(vector_ptr);
        uint8x16_t q = vld1q_u8(query_ptr);

        uint8x16_t x = veorq_u8(q, v);
        uint8x16_t popcnt = vcntq_u8(x);
        uint8x8_t popcnt_low = vget_low_u8(popcnt);
        uint8x8_t popcnt_high = vget_high_u8(popcnt);
        uint16x8_t sum = vaddl_u8(popcnt_low, popcnt_high);
        result = vpadalq_u16(result, sum);

        query_ptr += 16;
        vector_ptr += 16;
    }
    return (uint32_t)vaddvq_u32(result);
}

EXPORT uint32_t impl_xor_popcnt_neon_uint64(
    const uint8_t* query_ptr,
    const uint8_t* vector_ptr,
    uint32_t count
) {
    uint16x4_t result = vdup_n_u16(0);
    for (uint32_t _i = 0; _i < count; _i++) {
        uint8x8_t v = vld1_u8(vector_ptr);
        uint8x8_t q = vld1_u8(query_ptr);

        uint8x8_t x = veor_u8(q, v);
        uint8x8_t popcnt = vcnt_u8(x);
        result = vpadal_u8(result, popcnt);

        query_ptr += 8;
        vector_ptr += 8;
    }
    return (uint32_t)vaddv_u16(result);
}

EXPORT float impl_score_l1_neon(
   const uint8_t * query_ptr,
   const uint8_t * vector_ptr,
   uint32_t dim
) {
    const uint8_t* v_ptr = (const uint8_t*)vector_ptr;
    const uint8_t* q_ptr = (const uint8_t*)query_ptr;

    uint32_t m = dim - (dim % 16);
    uint16x8_t sum16_low = vdupq_n_u16(0);
    uint16x8_t sum16_high = vdupq_n_u16(0);

    // the vector sizes are assumed to be multiples of 16, no remaining part here
    for (uint32_t i = 0; i < m; i += 16) {
        uint8x16_t vec1 = vld1q_u8(v_ptr);
        uint8x16_t vec2 = vld1q_u8(q_ptr);

        uint8x16_t abs_diff = vabdq_u8(vec1, vec2);
        uint16x8_t abs_diff16_low = vmovl_u8(vget_low_u8(abs_diff));
        uint16x8_t abs_diff16_high = vmovl_u8(vget_high_u8(abs_diff));

        sum16_low = vaddq_u16(sum16_low, abs_diff16_low);
        sum16_high = vaddq_u16(sum16_high, abs_diff16_high);

        v_ptr += 16;
        q_ptr += 16;
    }

    // Horizontal sum of 16-bit integers
    uint32x4_t sum32_low = vpaddlq_u16(sum16_low);
    uint32x4_t sum32_high = vpaddlq_u16(sum16_high);
    uint32x4_t sum32 = vaddq_u32(sum32_low, sum32_high);

    uint32x2_t sum64_low = vadd_u32(vget_low_u32(sum32), vget_high_u32(sum32));
    uint32x2_t sum64_high = vpadd_u32(sum64_low, sum64_low);
    uint32_t sum = vget_lane_u32(sum64_high, 0);

    return (float) sum;
}