import time
from tqdm import tqdm
from annoy import AnnoyIndex
from typing import List

class TicToc:
    """Minimal wall-clock timer: call start(), then stop() to print the elapsed time."""
    def __init__(
        self
    ) -> None:

        self.i = None

    def start(
        self
    ) -> None:

        self.i = time.time()

    def stop(
        self
    ) -> None:

        f = time.time()
        print(f - self.i, "sec.")


class Ann:
    """Approximate nearest-neighbor lookup over word vectors, backed by an Annoy index."""
    def __init__(
        self, 
        words: List[str], 
        vectors: List, 
        coord: List, 
    ) -> None:

        self.words = words
        self.vectors = vectors
        self.coord = coord
        self.tree = None

        self.tt = TicToc()
        self.availables_metrics = ['angular','euclidean','manhattan','hamming','dot']

    def init(self, 
        n_trees: int=10, 
        metric: str='angular', 
        n_jobs: int=-1  # n_jobs=-1 uses all available CPU cores
    ) -> None:

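        # Raise explicitly instead of using assert, which is skipped when Python runs with -O.
        if metric not in self.availables_metrics:
            raise ValueError(f"The 'metric' parameter must be one of {self.availables_metrics}, got '{metric}'.")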

        print("\tInit tree...")
        self.tt.start()
        self.tree = AnnoyIndex(len(self.vectors[0]), metric=metric)
        for i, v in tqdm(enumerate(self.vectors), total=len(self.vectors)):
            self.tree.add_item(i, v)
        self.tt.stop()

        print("\tBuild tree...")
        self.tt.start()
        self.tree.build(n_trees=n_trees, n_jobs=n_jobs)
        self.tt.stop()

    def __getWordId(
        self, 
        word: str
    ) -> int:

        word_id = None
        try:
            word_id = self.words.index(word)
        except ValueError:
            # list.index raises ValueError when the word is not in the vocabulary.
            pass
        return word_id

    def get(
        self, 
        word: str, 
        n_neighbors: int=10
    ) -> List[str]:
        
        word_id = self.__getWordId(word)
        neighbors_list = None

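        if word_id is not None:
            # Request one extra neighbor because the query word itself is normally returned first.
            neighbors_id = self.tree.get_nns_by_item(word_id, n_neighbors + 1)
            # Drop the query word by id rather than by position, then trim to the requested size.
            neighbors_list = [self.words[idx] for idx in neighbors_id if idx != word_id][:n_neighbors]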

        else:
            print(f"The word '{word}' does not exist")

        return neighbors_list
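

# --- Illustrative sketch, not part of the original module ---
# A brute-force baseline can help sanity-check the approximate neighbors that
# Ann.get() returns when the index is built with the 'angular' metric.
# The helper name `exact_neighbors` and its signature are hypothetical.
# It ranks candidates by cosine similarity, which yields the same ordering as
# Annoy's angular distance, sqrt(2 * (1 - cos)). It scans every vector, so it
# is only practical for small vocabularies.
import math
from typing import List, Optional, Sequence


def exact_neighbors(
    words: Sequence[str],
    vectors: Sequence[Sequence[float]],
    word: str,
    n_neighbors: int = 10,
) -> Optional[List[str]]:

    if word not in words:
        return None

    query_id = list(words).index(word)
    query = vectors[query_id]

    def cosine(u: Sequence[float], v: Sequence[float]) -> float:
        dot = sum(a * b for a, b in zip(u, v))
        norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
        # Treat zero-length vectors as unrelated to everything.
        return dot / norm if norm else 0.0

    # Score every other word against the query, most similar first.
    scored = [(cosine(query, v), i) for i, v in enumerate(vectors) if i != query_id]
    scored.sort(key=lambda pair: pair[0], reverse=True)

    return [words[i] for _, i in scored[:n_neighbors]]

# Possible use: compare exact_neighbors(words, vectors, w, k) with
# Ann(words, vectors, coord).get(w, k) on a small sample to estimate recall.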