File size: 1,289 Bytes
6227608 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
class Kelm_Wrapper:
    """Choose the best phrase for each slot of a sentence with a language model.

    The wrapped ``model`` only needs a ``score(text)`` method returning a
    number; every selection in this class keeps the candidate with the
    highest score (the first one wins on ties).
    """

    def __init__(self, model):
        # Presumably a ``kenlm.Model`` -- only its ``.score`` method is used.
        self.model = model

    def _best_index(self, texts):
        """Return the index of the highest-scoring text (first one on ties).

        Raises:
            ValueError: if ``texts`` is empty.
        """
        scores = [self.model.score(text) for text in texts]
        if not scores:
            raise ValueError("cannot choose a best phrase from an empty candidate list")
        return scores.index(max(scores))

    def get_best_candidate_word(self, default_phrases, candidate_phrases, index):
        """Return the position of the candidate that best fills slot ``index``.

        Each candidate is substituted for ``default_phrases[index]`` and the
        complete sentence is scored, so the choice takes the surrounding
        phrases into account.

        Args:
            default_phrases: sequence of phrases forming the context sentence.
            candidate_phrases: iterable of alternatives for slot ``index``.
            index: the slot being filled.

        Returns:
            Index into ``candidate_phrases`` of the highest-scoring candidate.

        Raises:
            ValueError: if ``candidate_phrases`` is empty.
        """
        left = list(default_phrases[:index])
        right = list(default_phrases[index + 1:])
        # Join the whole token list so no stray leading/trailing space is
        # produced when ``index`` is the first or the last slot.
        texts = [' '.join(left + [cnd] + right) for cnd in candidate_phrases]
        return self._best_index(texts)

    def get_best_ongram_phrases(self, candidates_list):
        """Pick each slot's best candidate, scoring every phrase on its own.

        Args:
            candidates_list: sequence of candidate groups, one per slot; each
                group is an indexable sequence of phrases.

        Returns:
            List with the context-free best phrase of every group.

        Raises:
            ValueError: if any group is empty.
        """
        return [group[self._best_index(group)] for group in candidates_list]

    def get_best(self, candidates_list):
        """Return the best sentence built from one phrase per slot.

        Slots are first filled with their context-free best phrase. Every slot
        with more than one candidate is then rescored in context; note that
        the context is always the context-free picks for the other slots, not
        the contextual choices made for earlier slots.

        Args:
            candidates_list: sequence of candidate groups, one per slot.

        Returns:
            The chosen phrases joined by single spaces ('' for no slots).

        Raises:
            ValueError: if any group is empty.
        """
        default_phrases = self.get_best_ongram_phrases(candidates_list)
        bests = []
        for index, group in enumerate(candidates_list):
            if len(group) > 1:
                best = self.get_best_candidate_word(default_phrases, group, index)
                bests.append(group[best])
            else:
                # A single candidate needs no rescoring.
                bests.append(group[0])
        return ' '.join(bests)
|