Update modeling_deberta.py
Browse files- modeling_deberta.py +49 -0
modeling_deberta.py
CHANGED
@@ -1158,6 +1158,55 @@ class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel):
|
|
1158 |
attentions=outputs.attentions,
|
1159 |
)
|
1160 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1161 |
@add_start_docstrings("""DeBERTa Model with a `language modeling` head on top.""", DEBERTA_START_DOCSTRING)
|
1162 |
class DebertaV2ForCausalLM(DebertaV2ForMaskedLM):
|
1163 |
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
|
|
|
1158 |
attentions=outputs.attentions,
|
1159 |
)
|
1160 |
|
1161 |
+
@torch.no_grad()
def score(self, sequence: str, scored_length: int, tokenizer, device, batch_size):
    """Pseudo-log-likelihood score of the last ``scored_length`` words of ``sequence``.

    Each token of the scored suffix is masked in its own copy of the input
    (together with up to the next two tokens, so the model cannot peek at the
    immediate right context), the model predicts the masked position, and the
    log-probabilities of the true tokens are summed.

    Args:
        sequence: Raw text to tokenize and score.
        scored_length: Number of trailing *words* (per ``tokenizer`` word ids)
            whose tokens are scored.
        tokenizer: Hugging Face tokenizer providing ``mask_token_id``,
            ``cls_token_id``, ``sep_token_id`` and fast-tokenizer ``word_ids()``.
        device: Device the model lives on.
        batch_size: Number of masked copies run through the model per forward.

    Returns:
        float: Sum of log-probabilities of the scored tokens (0.0 if the
        suffix is empty).
    """
    mask_index = tokenizer.mask_token_id
    cls_index = torch.tensor([tokenizer.cls_token_id])
    sep_index = torch.tensor([tokenizer.sep_token_id])

    # Tokenize without special tokens so word_ids() covers only real words;
    # [CLS]/[SEP] are re-added manually below.
    encoding = tokenizer(sequence, add_special_tokens=False, return_tensors="pt")
    word_ids = encoding.word_ids()
    num_words = max(i for i in word_ids if i is not None) + 1
    # Token-level mask of tokens belonging to the last `scored_length` words.
    # The original `i and i >= ...` silently excluded word 0 (0 is falsy);
    # an explicit None check scores word 0 when scored_length covers it.
    scored_mask = [i is not None and i >= num_words - scored_length for i in word_ids]
    num_to_score = sum(scored_mask)
    if num_to_score == 0:
        # Nothing to score (e.g. scored_length == 0); torch.cat below would
        # otherwise raise on an empty list.
        return 0.0

    tokens = encoding.input_ids.squeeze(0)
    tokens = torch.cat([cls_index, tokens, sep_index]).to(device)
    # One copy of the sequence per scored token; row r masks the r-th scored
    # token. Scored tokens occupy the last num_to_score positions before [SEP],
    # assuming the scored words' tokens are contiguous at the end of the input.
    tokens = tokens.repeat(num_to_score, 1)

    seq_len = tokens.size(1)
    eye = torch.eye(seq_len, device=device).bool()  # hoisted: reused for every offset
    mask = eye[-(num_to_score + 1):-1, :]
    input_ids = tokens.masked_fill(mask, value=mask_index)
    # Also mask the next one and two tokens of each row when they exist, so
    # each prediction sees no near-right context (windowed PLL variant).
    for offset in (1, 2):
        if num_to_score > offset:
            mask = eye[-(num_to_score + 1) + offset:-1, :]
            input_ids[:-offset, :] = input_ids[:-offset, :].masked_fill(mask, value=mask_index)

    # Position of the masked (target) token in each row, and the true token ids.
    indices = torch.arange(seq_len - num_to_score - 1, seq_len - 1, device=device)
    targets = tokens[0, -(num_to_score + 1):-1]
    total_score = []

    num_batches = (input_ids.size(0) - 1) // batch_size + 1
    for b in range(num_batches):
        lo, hi = b * batch_size, (b + 1) * batch_size
        logits = self(input_ids[lo:hi, :].contiguous()).logits

        # For each row, keep only the logits at that row's masked position.
        logits = torch.gather(
            logits,
            dim=1,
            index=indices[lo:hi].reshape(-1, 1, 1).expand(-1, -1, logits.size(-1)),
        ).squeeze(1)
        log_p = F.log_softmax(logits, dim=-1)

        # Log-probability of the true token at each masked position.
        log_p = log_p.gather(index=targets[lo:hi].unsqueeze(-1), dim=-1).squeeze(-1)
        total_score.append(log_p)

    return torch.cat(total_score).sum().item()
|
1208 |
+
|
1209 |
+
|
1210 |
@add_start_docstrings("""DeBERTa Model with a `language modeling` head on top.""", DEBERTA_START_DOCSTRING)
|
1211 |
class DebertaV2ForCausalLM(DebertaV2ForMaskedLM):
|
1212 |
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
|