# NOTE(review): removed copy-paste artifacts ("Spaces:" / "Runtime error"
# banners) that were not part of the source code.
import re
from difflib import Differ

import torch
class PllScore: | |
def __init__( | |
self, | |
language_model # LanguageModel class instance | |
) -> None: | |
self.tokenizer = language_model.initTokenizer() | |
self.model = language_model.initModel() | |
_ = self.model.eval() | |
self.logSoftmax = torch.nn.LogSoftmax(dim=-1) | |
def sentIsCorrect( | |
self, | |
sent: str | |
) -> bool: | |
# Mod | |
is_correct = True | |
# Check mark existence | |
open_mark = sent.count("<") | |
close_mark = sent.count(">") | |
total_mark = open_mark + close_mark | |
if (total_mark == 0) or (open_mark != close_mark): | |
is_correct = False | |
# Check existence of twin marks (ie: '<<' or '>>') | |
if is_correct: | |
left_twin = sent.count("<<") | |
rigth_twin = sent.count(">>") | |
if left_twin + rigth_twin > 0: | |
is_correct = False | |
if is_correct: | |
# Check balanced symbols '<' and '>' | |
stack = [] | |
for c in sent: | |
if c == '<': | |
stack.append('<') | |
elif c == '>': | |
if len(stack) == 0: | |
is_correct = False | |
break | |
if stack.pop() != "<": | |
is_correct = False | |
break | |
if len(stack) > 0: | |
is_correct = False | |
if is_correct: | |
for w in re.findall("\<.*?\>", sent): | |
# Check empty interest words | |
word = w.replace("<","").replace(">","").strip() | |
if not word: | |
is_correct = False | |
break | |
# Check if there are any marks inside others (ie: <this is a <sentence>>) | |
word = w.strip()[1:-1] #Delete the first and last mark | |
if '<' in word or '>' in word: | |
is_correct = False | |
break | |
if is_correct: | |
# Check that there is at least one uninteresting word. The next examples should not be allowed | |
# (ie: <this is a sent>, <this> <is a sent>) | |
outside_words = re.sub("\<.*?\>", "", sent.replace("<", " < ").replace(">", " > ")) | |
outside_words = [w for w in outside_words.split() if w != ""] | |
if not outside_words: | |
is_correct = False | |
return is_correct | |
def compute( | |
self, | |
sent: str | |
) -> float: | |
assert(self.sentIsCorrect(sent)), f"Error: The sentence '{sent}' does not have the correct format!" | |
outside_words = re.sub("\<.*?\>", "", sent.replace("<", " < ").replace(">", " > ")) | |
outside_words = [w for w in outside_words.split() if w != ""] | |
all_words = [w.strip() for w in sent.replace("<"," ").replace(">"," ").split() if w != ""] | |
tks_id_outside_words = self.tokenizer.encode( | |
" ".join(outside_words), | |
add_special_tokens=False, | |
truncation=True | |
) | |
tks_id_all_words = self.tokenizer.encode( | |
" ".join(all_words), | |
add_special_tokens=False, | |
truncation=True | |
) | |
diff = [(tk[0], tk[2:]) for tk in Differ().compare(tks_id_outside_words, tks_id_all_words)] | |
cls_tk_id = self.tokenizer.cls_token_id | |
sep_tk_id = self.tokenizer.sep_token_id | |
mask_tk_id = self.tokenizer.mask_token_id | |
all_sent_masked = [] | |
all_tks_id_masked = [] | |
all_tks_position_masked = [] | |
for i in range(0, len(diff)): | |
current_sent_masked = [cls_tk_id] | |
add_sent = True | |
for j, (mark, tk_id) in enumerate(diff): | |
if j == i: | |
if mark == '+': | |
add_sent = False | |
break | |
else: | |
current_sent_masked.append(mask_tk_id) | |
all_tks_id_masked.append(int(tk_id)) | |
all_tks_position_masked.append(i+1) | |
else: | |
current_sent_masked.append(int(tk_id)) | |
if add_sent: | |
current_sent_masked.append(sep_tk_id) | |
all_sent_masked.append(current_sent_masked) | |
inputs_ids = torch.tensor(all_sent_masked) | |
attention_mask = torch.ones_like(inputs_ids) | |
with torch.no_grad(): | |
out = self.model(inputs_ids, attention_mask) | |
logits = out.logits | |
outputs = self.logSoftmax(logits) | |
pll_score = 0 | |
for out, tk_pos, tk_id in zip(outputs, all_tks_position_masked, all_tks_id_masked): | |
probabilities = out[tk_pos] | |
tk_prob = probabilities[tk_id] | |
pll_score += tk_prob.item() | |
return pll_score |