Spaces:
Runtime error
Runtime error
File size: 5,757 Bytes
b546526 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
from modules.module_customPllLabel import CustomPllLabel
from modules.module_pllScore import PllScore
from typing import List, Dict
import torch
class RankSents:
    """Rank candidate words for the ``*`` slot of a sentence.

    The sentence must contain exactly one ``*``. Candidates are either supplied
    by the caller or, when none are given, taken from the masked language
    model's most probable tokens for that slot. Each candidate is inserted as
    ``<word>`` and scored with ``PllScore.compute``.
    """

    def __init__(
        self,
        language_model,  # LanguageModel class instance
        lang: str
    ) -> None:
        """Load tokenizer/model and the function-word lists for ``lang``.

        Args:
            language_model: object exposing ``initTokenizer()`` and
                ``initModel()``; it is also handed to ``PllScore``.
            lang: ``"spanish"`` or ``"english"``. Any other value leaves the
                article/preposition/conjunction lists empty, so the matching
                filters in ``getTop5Predictions`` exclude nothing instead of
                failing with ``AttributeError``.
        """
        self.tokenizer = language_model.initTokenizer()
        self.model = language_model.initModel()
        self.model.eval()  # inference-only use of the model
        self.Label = CustomPllLabel()
        self.pllScore = PllScore(
            language_model=language_model
        )
        self.softmax = torch.nn.Softmax(dim=-1)

        if lang == "spanish":
            self.articles = [
                'un', 'una', 'unos', 'unas', 'el', 'los', 'la', 'las', 'lo'
            ]
            self.prepositions = [
                'a', 'ante', 'bajo', 'cabe', 'con', 'contra', 'de', 'desde', 'en',
                'entre', 'hacia', 'hasta', 'para', 'por', 'según', 'sin', 'so',
                'sobre', 'tras', 'durante', 'mediante', 'vía', 'versus'
            ]
            self.conjunctions = [
                'y', 'o', 'ni', 'que', 'pero', 'si'
            ]
        elif lang == "english":
            self.articles = [
                'a', 'an', 'the'
            ]
            self.prepositions = [
                'above', 'across', 'against', 'along', 'among', 'around', 'at',
                'before', 'behind', 'below', 'beneath', 'beside', 'between', 'by',
                'down', 'from', 'in', 'into', 'near', 'of', 'off', 'on', 'to',
                'toward', 'under', 'upon', 'with', 'within'
            ]
            self.conjunctions = [
                'and', 'or', 'but', 'that', 'if', 'whether'
            ]
        else:
            # Unsupported language: no function-word filtering is available.
            self.articles = []
            self.prepositions = []
            self.conjunctions = []

    def errorChecking(
        self,
        sent: str
    ) -> str:
        """Validate ``sent`` before ranking.

        Returns:
            An error message, or ``""`` when ``sent`` is non-empty, contains
            exactly one ``*`` and has at most
            ``tokenizer.max_len_single_sentence`` tokens once ``*`` is
            replaced by the mask token.
        """
        if not sent:
            return "Error: You must enter a sentence!"

        star_count = sent.count("*")
        if star_count > 1:
            return " Error: The sentence entered must contain only one ' * '!"
        if star_count == 0:
            return " Error: The entered sentence needs to contain a ' * ' in order to predict the word!"

        # ``max_len_single_sentence`` already excludes the special tokens the
        # tokenizer adds ([CLS]/[SEP], ...), so count only the sentence's own
        # tokens; counting the special tokens too rejected sentences that fit.
        masked = sent.replace("*", self.tokenizer.mask_token)
        sent_len = len(self.tokenizer.encode(masked, add_special_tokens=False))
        max_len = self.tokenizer.max_len_single_sentence
        if sent_len > max_len:
            return f"Error: The sentence has more than {max_len} tokens!"
        return ""

    def getTop5Predictions(
        self,
        sent: str,
        banned_wl: List[str],
        articles: bool,
        prepositions: bool,
        conjunctions: bool
    ) -> List[str]:
        """Return up to five of the model's most probable words for the ``*`` slot.

        Vocabulary tokens are visited from most to least probable. A token is
        skipped when it is in ``banned_wl``, is a special token, is not
        alphanumeric (e.g. punctuation), starts with ``##`` (a subword
        continuation), or, when the matching flag is set, is one of the
        configured articles / prepositions / conjunctions.

        Args:
            sent: sentence containing exactly one ``*``.
            banned_wl: words never to return; ``None`` is treated as empty.
            articles: exclude ``self.articles`` when true.
            prepositions: exclude ``self.prepositions`` when true.
            conjunctions: exclude ``self.conjunctions`` when true.

        Returns:
            At most five token strings, most probable first.
        """
        sent_masked = sent.replace("*", self.tokenizer.mask_token)
        inputs = self.tokenizer.encode_plus(
            sent_masked,
            add_special_tokens=True,
            return_tensors='pt',
            return_attention_mask=True,
            truncation=True
        )
        mask_pos = torch.where(
            inputs['input_ids'][0] == self.tokenizer.mask_token_id
        )[0].item()

        with torch.no_grad():
            logits = self.model(**inputs).logits
        # Drop the single-sentence batch dim, then take the mask slot's row.
        probabilities = torch.squeeze(self.softmax(logits), dim=0)[mask_pos]
        ranked_ids = torch.argsort(probabilities, descending=True).tolist()

        # Every word-level exclusion merged into one set, built once.
        excluded = set(banned_wl or [])
        excluded.update(self.tokenizer.all_special_tokens)
        if articles:
            excluded.update(self.articles)
        if prepositions:
            excluded.update(self.prepositions)
        if conjunctions:
            excluded.update(self.conjunctions)

        top5_tks_pred = []
        for tk_id in ranked_ids:
            tk_string = self.tokenizer.decode([tk_id])
            if (
                tk_string in excluded
                or not tk_string.isalnum()
                or tk_string.startswith("##")
            ):
                continue
            top5_tks_pred.append(tk_string)
            if len(top5_tks_pred) == 5:
                break
        return top5_tks_pred

    def rank(self,
             sent: str,
             word_list: List[str],
             banned_word_list: List[str],
             articles: bool,
             prepositions: bool,
             conjunctions: bool
             ) -> Dict[str, float]:
        """Score each candidate word placed in the ``*`` slot of ``sent``.

        Args:
            sent: sentence containing exactly one ``*``.
            word_list: candidate words; when empty (or ``None``) the model's
                top predictions from ``getTop5Predictions`` are used.
            banned_word_list, articles, prepositions, conjunctions: filters
                forwarded to ``getTop5Predictions`` (only used when
                ``word_list`` is empty).

        Returns:
            Mapping from the sentence with the candidate inserted as
            ``<word>`` to ``self.pllScore.compute`` of that sentence.
            Duplicate candidates collapse into a single entry.

        Raises:
            ValueError: if ``errorChecking`` rejects ``sent`` (a subclass of
                ``Exception``, so existing ``except Exception`` callers still
                catch it).
        """
        err = self.errorChecking(sent)
        if err:
            raise ValueError(err)

        if not word_list:
            word_list = self.getTop5Predictions(
                sent,
                banned_word_list,
                articles,
                prepositions,
                conjunctions
            )

        all_plls_scores = {}
        for word in word_list:
            filled_sent = sent.replace("*", f"<{word}>")
            all_plls_scores[filled_sent] = self.pllScore.compute(filled_sent)
        return all_plls_scores