Spaces:
Configuration error
Configuration error
File size: 2,105 Bytes
a779273 2e649f6 a779273 ced76fc a779273 ced76fc a779273 ced76fc a779273 ced76fc a779273 ced76fc 2e649f6 ced76fc a779273 164b173 a779273 ced76fc 2e649f6 ced76fc 164b173 a779273 ced76fc a779273 ced76fc a779273 ced76fc a779273 ced76fc a779273 ced76fc a779273 2e649f6 a779273 ced76fc 2e649f6 ced76fc a779273 ced76fc 2e649f6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
import time
from tqdm import tqdm
from annoy import AnnoyIndex
from memory_profiler import profile
from typing import List
class TicToc:
def __init__(
self
) -> None:
self.i = None
def start(
self
) -> None:
self.i = time.time()
def stop(
self
) -> None:
f = time.time()
print(f - self.i, "seg.")
class Ann:
def __init__(
self,
words: List[str],
vectors: List,
coord: List,
) -> None:
self.words = words
self.vectors = vectors
self.coord = coord
self.tree = None
self.tt = TicToc()
self.availables_metrics = ['angular','euclidean','manhattan','hamming','dot']
def init(self,
n_trees: int=10,
metric: str='angular',
n_jobs: int=-1 # n_jobs=-1 Run over all CPU availables
) -> None:
assert(metric in self.availables_metrics), f"Error: The value of the parameter 'metric' can only be {self.availables_metrics}!"
print("\tInit tree...")
self.tt.start()
self.tree = AnnoyIndex(len(self.vectors[0]), metric=metric)
for i, v in tqdm(enumerate(self.vectors), total=len(self.vectors)):
self.tree.add_item(i, v)
self.tt.stop()
print("\tBuild tree...")
self.tt.start()
self.tree.build(n_trees=n_trees, n_jobs=n_jobs)
self.tt.stop()
def __getWordId(
self,
word: str
) -> int:
word_id = None
try:
word_id = self.words.index(word)
except:
pass
return word_id
def get(
self,
word: str,
n_neighbors: int=10
) -> List[str]:
word_id = self.__getWordId(word)
neighbors_list = None
if word_id != None:
neighbords_id = self.tree.get_nns_by_item(word_id, n_neighbors + 1)
neighbors_list = [self.words[idx] for idx in neighbords_id][1:]
else:
print(f"The word '{word}' does not exist")
return neighbors_list |