from difflib import Differ
import re

import torch


class PllScore:
    """Pseudo-log-likelihood (PLL) scorer for sentences whose interest words are
    marked between '<' and '>'. Only the words outside the marks are scored."""

    def __init__(
        self,
        language_model  # LanguageModel class instance
    ) -> None:
        self.tokenizer = language_model.initTokenizer()
        self.model = language_model.initModel()
        self.model.eval()
        self.logSoftmax = torch.nn.LogSoftmax(dim=-1)

    def sentIsCorrect(
        self,
        sent: str
    ) -> bool:
        """Validate that the '<...>' interest marks in `sent` are well formed."""
        is_correct = True

        # Marks must exist and come in pairs
        open_mark = sent.count("<")
        close_mark = sent.count(">")
        total_mark = open_mark + close_mark
        if (total_mark == 0) or (open_mark != close_mark):
            is_correct = False

        # Reject twin marks (ie: '<<' or '>>')
        if is_correct:
            left_twin = sent.count("<<")
            right_twin = sent.count(">>")
            if left_twin + right_twin > 0:
                is_correct = False

        if is_correct:
            # Check that the symbols '<' and '>' are balanced
            stack = []
            for c in sent:
                if c == '<':
                    stack.append('<')
                elif c == '>':
                    if len(stack) == 0:
                        is_correct = False
                        break
                    if stack.pop() != '<':
                        is_correct = False
                        break
            if len(stack) > 0:
                is_correct = False

        if is_correct:
            for w in re.findall(r"<.*?>", sent):
                # Reject empty interest words (ie: '< >')
                word = w.replace("<", "").replace(">", "").strip()
                if not word:
                    is_correct = False
                    break
                # Reject marks nested inside another marked span
                word = w.strip()[1:-1]  # Delete the first and last mark
                if '<' in word or '>' in word:
                    is_correct = False
                    break

        if is_correct:
            # Require at least one word outside the marks; a sentence in which
            # every word is marked as an interest word is not allowed.
            outside_words = re.sub(r"<.*?>", "", sent.replace("<", " < ").replace(">", " > "))
            outside_words = [w for w in outside_words.split() if w != ""]
            if not outside_words:
                is_correct = False

        return is_correct

    def compute(
        self,
        sent: str
    ) -> float:
        """Sum the log-probabilities of the tokens outside the interest marks,
        masking one token at a time (pseudo-log-likelihood)."""
        assert self.sentIsCorrect(sent), f"Error: The sentence '{sent}' does not have the correct format!"

        outside_words = re.sub(r"<.*?>", "", sent.replace("<", " < ").replace(">", " > "))
        outside_words = [w for w in outside_words.split() if w != ""]
        all_words = [w.strip() for w in sent.replace("<", " ").replace(">", " ").split() if w != ""]

        tks_id_outside_words = self.tokenizer.encode(
            " ".join(outside_words),
            add_special_tokens=False,
            truncation=True
        )
        tks_id_all_words = self.tokenizer.encode(
            " ".join(all_words),
            add_special_tokens=False,
            truncation=True
        )

        # Differ yields strings like '+ 1012': tk[0] is the diff mark and tk[2:]
        # is the token id. Tokens marked '+' appear only in the full sentence,
        # i.e. they belong to the interest words.
        diff = [(tk[0], tk[2:]) for tk in Differ().compare(tks_id_outside_words, tks_id_all_words)]

        cls_tk_id = self.tokenizer.cls_token_id
        sep_tk_id = self.tokenizer.sep_token_id
        mask_tk_id = self.tokenizer.mask_token_id

        all_sent_masked = []
        all_tks_id_masked = []
        all_tks_position_masked = []

        # Build one copy of the sentence per outside token, with that token
        # replaced by [MASK]; copies that would mask an interest-word token
        # ('+' mark) are skipped.
        for i in range(len(diff)):
            current_sent_masked = [cls_tk_id]
            add_sent = True
            for j, (mark, tk_id) in enumerate(diff):
                if j == i:
                    if mark == '+':
                        add_sent = False
                        break
                    current_sent_masked.append(mask_tk_id)
                    all_tks_id_masked.append(int(tk_id))
                    all_tks_position_masked.append(i + 1)  # +1 for the leading [CLS]
                else:
                    current_sent_masked.append(int(tk_id))

            if add_sent:
                current_sent_masked.append(sep_tk_id)
                all_sent_masked.append(current_sent_masked)

        input_ids = torch.tensor(all_sent_masked)
        attention_mask = torch.ones_like(input_ids)

        with torch.no_grad():
            out = self.model(input_ids=input_ids, attention_mask=attention_mask)
            outputs = self.logSoftmax(out.logits)

        # Sum the log-probability of each original token at its masked position
        pll_score = 0.0
        for log_probs, tk_pos, tk_id in zip(outputs, all_tks_position_masked, all_tks_id_masked):
            pll_score += log_probs[tk_pos, tk_id].item()

        return pll_score
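

# ------------------------------------------------------------------------------
# Usage sketch (an assumption, not part of the original module). PllScore expects
# a wrapper object exposing initTokenizer() and initModel(); a minimal hypothetical
# wrapper around a HuggingFace masked language model (e.g. BERT) could look like
# this, assuming the `transformers` library is installed:

from transformers import AutoModelForMaskedLM, AutoTokenizer


class LanguageModel:
    """Minimal hypothetical wrapper matching the interface PllScore expects."""

    def __init__(self, model_name: str = "bert-base-uncased") -> None:
        self.model_name = model_name

    def initTokenizer(self):
        return AutoTokenizer.from_pretrained(self.model_name)

    def initModel(self):
        return AutoModelForMaskedLM.from_pretrained(self.model_name)


if __name__ == "__main__":
    scorer = PllScore(LanguageModel())
    # Interest words go between '<' and '>'. compute() masks and scores only the
    # words outside the marks, so the result is a sum of log-probabilities (<= 0).
    print(scorer.sentIsCorrect("The <quick> brown fox"))  # True
    print(scorer.sentIsCorrect("<The> <quick>"))          # False: no outside words
    print(scorer.compute("The <quick> brown fox"))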