Spaces:
Runtime error
Runtime error
import time | |
from tqdm import tqdm | |
from annoy import AnnoyIndex | |
from typing import List | |
class TicToc: | |
def __init__( | |
self | |
) -> None: | |
self.i = None | |
def start( | |
self | |
) -> None: | |
self.i = time.time() | |
def stop( | |
self | |
) -> None: | |
f = time.time() | |
print(f - self.i, "seg.") | |
class Ann: | |
def __init__( | |
self, | |
words: List[str], | |
vectors: List, | |
coord: List, | |
) -> None: | |
self.words = words | |
self.vectors = vectors | |
self.coord = coord | |
self.tree = None | |
self.tt = TicToc() | |
self.availables_metrics = ['angular','euclidean','manhattan','hamming','dot'] | |
def init(self, | |
n_trees: int=10, | |
metric: str='angular', | |
n_jobs: int=-1 # n_jobs=-1 Run over all CPU availables | |
) -> None: | |
assert(metric in self.availables_metrics), f"Error: The value of the parameter 'metric' can only be {self.availables_metrics}!" | |
print("\tInit tree...") | |
self.tt.start() | |
self.tree = AnnoyIndex(len(self.vectors[0]), metric=metric) | |
for i, v in tqdm(enumerate(self.vectors), total=len(self.vectors)): | |
self.tree.add_item(i, v) | |
self.tt.stop() | |
print("\tBuild tree...") | |
self.tt.start() | |
self.tree.build(n_trees=n_trees, n_jobs=n_jobs) | |
self.tt.stop() | |
def __getWordId( | |
self, | |
word: str | |
) -> int: | |
word_id = None | |
try: | |
word_id = self.words.index(word) | |
except: | |
pass | |
return word_id | |
def get( | |
self, | |
word: str, | |
n_neighbors: int=10 | |
) -> List[str]: | |
word_id = self.__getWordId(word) | |
neighbors_list = None | |
if word_id != None: | |
neighbords_id = self.tree.get_nns_by_item(word_id, n_neighbors + 1) | |
neighbors_list = [self.words[idx] for idx in neighbords_id][1:] | |
else: | |
print(f"The word '{word}' does not exist") | |
return neighbors_list |