File size: 1,289 Bytes
6227608
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32

class Kelm_Wrapper:
    """Choose the highest-scoring phrase for every slot of a sentence.

    The wrapped ``model`` only needs a ``score(text)`` method whose return
    values can be compared with ``max``.
    NOTE(review): the class name suggests a KenLM language model -- confirm.
    """

    def __init__(self, model):
        self.model = model

    def _best_index(self, texts):
        """Return the position of the text the model scores highest.

        Ties are resolved in favour of the earliest text.

        Raises:
            ValueError: if ``texts`` is empty.
        """
        scores = [self.model.score(text) for text in texts]
        if not scores:
            raise ValueError("cannot pick a best candidate from an empty list")
        return scores.index(max(scores))

    def get_best_candidate_word(self, default_phrases, candidate_phrases, index):
        """Return the index of the candidate that fits best at slot ``index``.

        Each candidate is substituted for ``default_phrases[index]`` and the
        resulting full text is scored by the model.

        Args:
            default_phrases: one phrase per slot, used as surrounding context.
            candidate_phrases: alternatives to try at slot ``index``.
            index: position in ``default_phrases`` that is being replaced.

        Returns:
            Index into ``candidate_phrases`` of the best-scoring candidate.

        Raises:
            ValueError: if ``candidate_phrases`` is empty.
        """
        left_context = default_phrases[:index]
        right_context = default_phrases[index + 1:]
        # Join a single token list so that the first and last slots do not
        # produce texts with a stray leading or trailing space.
        candidate_texts = [
            ' '.join([*left_context, cnd, *right_context])
            for cnd in candidate_phrases
        ]
        return self._best_index(candidate_texts)

    def get_best_ongram_phrases(self, candidates_list):
        """Return, for each slot, the candidate that scores best on its own.

        Args:
            candidates_list: one sequence of candidate phrases per slot.

        Returns:
            A list with one phrase per slot.

        Raises:
            ValueError: if any slot has no candidates.
        """
        return [
            candidates[self._best_index(candidates)]
            for candidates in candidates_list
        ]

    def get_best(self, candidates_list):
        """Build the best sentence from per-slot candidate phrases.

        Every slot is first filled with its best stand-alone candidate; each
        slot with more than one candidate is then re-decided by scoring the
        whole text with that slot swapped. The context used for re-scoring is
        always the stand-alone choice, not the choices made so far.

        Args:
            candidates_list: one sequence of candidate phrases per slot.

        Returns:
            The chosen phrases joined by single spaces ('' for no slots).

        Raises:
            ValueError: if any slot has no candidates.
        """
        default_phrases = self.get_best_ongram_phrases(candidates_list)
        bests = []
        for index, candidates in enumerate(candidates_list):
            if len(candidates) > 1:
                chosen = self.get_best_candidate_word(default_phrases, candidates, index)
                bests.append(candidates[chosen])
            else:
                # A single candidate needs no scoring.
                bests.append(candidates[0])
        return ' '.join(bests)