File size: 771 Bytes
a006afd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from typing import Dict, Callable
import numpy as np
import numpy.typing as npt


# These match what the spec of hnswlib is
# This epsilon is used to prevent division by zero and the value is the same
# https://github.com/nmslib/hnswlib/blob/359b2ba87358224963986f709e593d799064ace6/python_bindings/bindings.cpp#L238
NORM_EPS = 1e-30
distance_functions: Dict[str, Callable[[npt.ArrayLike, npt.ArrayLike], float]] = {
    "l2": lambda x, y: np.linalg.norm(x - y) ** 2,  # type: ignore
    "cosine": lambda x, y: 1 - np.dot(x, y) / ((np.linalg.norm(x) + NORM_EPS) * (np.linalg.norm(y) + NORM_EPS)),  # type: ignore
    "ip": lambda x, y: 1 - np.dot(x, y),  # type: ignore
}

l2 = distance_functions["l2"]
cosine = distance_functions["cosine"]
ip = distance_functions["ip"]