File size: 2,105 Bytes
a779273
 
 
 
2e649f6
a779273
 
ced76fc
 
 
 
a779273
ced76fc
 
 
 
 
a779273
ced76fc
 
 
 
 
a779273
 
 
ced76fc
a779273
ced76fc
 
 
2e649f6
 
ced76fc
 
 
 
 
a779273
 
 
164b173
 
a779273
ced76fc
 
 
2e649f6
ced76fc
 
164b173
a779273
ced76fc
a779273
 
ced76fc
 
a779273
 
ced76fc
a779273
 
 
 
ced76fc
 
 
 
 
a779273
 
 
 
 
 
 
ced76fc
 
 
 
 
 
a779273
2e649f6
a779273
 
ced76fc
2e649f6
ced76fc
a779273
 
ced76fc
2e649f6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import time
from typing import List, Optional

from annoy import AnnoyIndex
from memory_profiler import profile
from tqdm import tqdm

class TicToc:
    """Minimal stopwatch: call start(), then stop() to print the elapsed seconds."""

    def __init__(
        self
    ) -> None:
        """Create a stopwatch that has not been started yet."""
        # Timestamp taken by the most recent start(); None until start() runs.
        self.i = None

    def start(
        self
    ) -> None:
        """Record the current instant as the beginning of the measured interval."""
        # perf_counter() is monotonic, so the measured interval cannot be
        # distorted by system clock adjustments the way time.time() can.
        self.i = time.perf_counter()

    def stop(
        self
    ) -> None:
        """Print the number of seconds elapsed since the last start().

        Raises:
            RuntimeError: If start() has not been called on this instance.
        """
        if self.i is None:
            raise RuntimeError("TicToc.stop() called before TicToc.start()")

        elapsed = time.perf_counter() - self.i
        print(elapsed, "seg.")


class Ann:
    """Nearest-neighbour word lookup backed by an Annoy index.

    Typical use: construct with the vocabulary and its vectors, call init()
    once to create and build the index, then call get() for each query word.
    """

    def __init__(
        self,
        words: List[str],
        vectors: List,
        coord: List,
    ) -> None:
        """Store the vocabulary data; the index itself is created by init().

        Args:
            words: Vocabulary. The position of a word in this list is the
                item id used for it in the index.
            vectors: One vector per word, in the same order as ``words``.
            coord: Stored unchanged on the instance; not used by this class.
        """
        self.words = words
        self.vectors = vectors
        self.coord = coord
        # Annoy index; stays None until init() is called.
        self.tree = None

        self.tt = TicToc()
        self.availables_metrics = ['angular','euclidean','manhattan','hamming','dot']


    def init(self,
        n_trees: int=10,
        metric: str='angular',
        n_jobs: int=-1  # n_jobs=-1 Run over all CPU availables
    ) -> None:
        """Create the Annoy index from ``self.vectors`` and build it.

        Progress messages and the time taken by each phase are printed.

        Args:
            n_trees: Forwarded to ``AnnoyIndex.build``.
            metric: Distance metric; must be one of ``self.availables_metrics``.
            n_jobs: Forwarded to ``AnnoyIndex.build``; -1 uses all available CPUs.

        Raises:
            AssertionError: If ``metric`` is not a supported metric.
        """
        if metric not in self.availables_metrics:
            # Raised explicitly instead of via `assert` so the check is not
            # stripped under `python -O`; the exception type and message are
            # kept so existing callers see the same error.
            raise AssertionError(
                f"Error: The value of the parameter 'metric' can only be {self.availables_metrics}!"
            )

        print("\tInit tree...")
        self.tt.start()
        # The index dimensionality is taken from the first vector.
        self.tree = AnnoyIndex(len(self.vectors[0]), metric=metric)
        for i, v in tqdm(enumerate(self.vectors), total=len(self.vectors)):
            self.tree.add_item(i, v)
        self.tt.stop()

        print("\tBuild tree...")
        self.tt.start()
        self.tree.build(n_trees=n_trees, n_jobs=n_jobs)
        self.tt.stop()

    def __getWordId(
        self,
        word: str
    ) -> Optional[int]:
        """Return the position of ``word`` in ``self.words``, or None if absent."""
        try:
            return self.words.index(word)
        except ValueError:
            # list.index signals "not found" with ValueError; anything else
            # is a real error and must propagate.
            return None

    def get(
        self,
        word: str,
        n_neighbors: int=10
    ) -> Optional[List[str]]:
        """Return up to ``n_neighbors`` words closest to ``word``.

        init() must have been called first so that the index exists.

        Args:
            word: Query word; must be present in ``self.words``.
            n_neighbors: Maximum number of neighbours to return.

        Returns:
            The neighbouring words in the order reported by the index, with
            the query word itself excluded, or None (after printing a
            message) when ``word`` is not in the vocabulary.
        """
        word_id = self.__getWordId(word)
        if word_id is None:
            print(f"The word '{word}' does not exist")
            return None

        # Request one extra result because the query word is normally among
        # its own nearest neighbours and has to be removed.
        neighbor_ids = self.tree.get_nns_by_item(word_id, n_neighbors + 1)

        # Exclude the query word by id rather than by position: the index
        # does not guarantee that it is ranked first.
        neighbors_list = [self.words[idx] for idx in neighbor_ids if idx != word_id]
        return neighbors_list[:n_neighbors]