davda54 committed
Commit 3cd41b8
1 Parent(s): e836ed3

Update modeling_deberta.py

Files changed (1):
  modeling_deberta.py  +49 -0
modeling_deberta.py CHANGED
@@ -1158,6 +1158,55 @@ class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel):
             attentions=outputs.attentions,
         )
 
+    @torch.no_grad()
+    def score(self, sequence: str, scored_length: int, tokenizer, device, batch_size):
+        mask_index = tokenizer.mask_token_id
+        cls_index = torch.tensor([tokenizer.cls_token_id])
+        sep_index = torch.tensor([tokenizer.sep_token_id])
+
+        encoding = tokenizer(sequence, add_special_tokens=False, return_tensors="pt")
+        num_words = max(i for i in encoding.word_ids() if i is not None) + 1
+        scored_mask = [i and i >= num_words - scored_length for i in encoding.word_ids()]
+        num_to_score = sum(scored_mask)
+
+        tokens = encoding.input_ids.squeeze(0)
+
+        tokens = torch.cat([cls_index, tokens, sep_index]).to(device)
+        tokens = tokens.repeat(num_to_score, 1)
+        mask = torch.eye(tokens.size(1), device=device).bool()[-(num_to_score+1):-1, :]
+        input_ids = tokens.masked_fill(mask, value=mask_index)
+        if num_to_score > 1:
+            mask = torch.eye(tokens.size(1), device=device).bool()[-(num_to_score+1)+1:-1, :]
+            input_ids[:-1, :] = input_ids[:-1, :].masked_fill(mask, value=mask_index)
+        if num_to_score > 2:
+            mask = torch.eye(tokens.size(1), device=device).bool()[-(num_to_score+1)+2:-1, :]
+            input_ids[:-2, :] = input_ids[:-2, :].masked_fill(mask, value=mask_index)
+
+        indices = torch.arange(input_ids.size(1) - num_to_score - 1, input_ids.size(1) - 1, device=device)
+        total_score = []
+
+        for b in range((input_ids.size(0) - 1) // batch_size + 1):
+            logits = self(
+                input_ids[b * batch_size : (b+1) * batch_size, :].contiguous(),
+            ).logits
+
+            logits = torch.gather(
+                logits,
+                dim=1,
+                index=indices[b * batch_size : (b+1) * batch_size].reshape(-1, 1, 1).expand(-1, -1, logits.size(-1))
+            ).squeeze(1)
+            log_p = F.log_softmax(logits, dim=-1)
+
+            log_p = log_p.gather(
+                index=tokens[0, -(num_to_score+1):-1][b * batch_size : (b+1) * batch_size].unsqueeze(-1),
+                dim=-1
+            ).squeeze(-1)
+            total_score.append(log_p)
+
+        total_score = torch.cat(total_score)
+        return total_score.sum().item()
+
+
 @add_start_docstrings("""DeBERTa Model with a `language modeling` head on top.""", DEBERTA_START_DOCSTRING)
 class DebertaV2ForCausalLM(DebertaV2ForMaskedLM):
     _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
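
For readers trying out this commit: the new score method computes a pseudo-log-likelihood. It masks each of the last scored_length words of the input (together with up to the next two scored positions), runs the masked copies through the masked-LM head in batches, and sums the log-probabilities of the original tokens at the masked positions. A minimal usage sketch is below; the checkpoint name is a placeholder, and loading through AutoModelForMaskedLM with trust_remote_code=True is an assumption about how this custom modeling file is registered, not something stated in the commit.

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Placeholder checkpoint name; substitute the ltg repository that ships this modeling_deberta.py.
checkpoint = "ltg/<model-name>"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForMaskedLM.from_pretrained(checkpoint, trust_remote_code=True).to(device)
model.eval()

# Score the last three words of the sentence, processing the masked copies eight at a time.
sentence = "The quick brown fox jumps over the lazy dog."
pll = model.score(sentence, scored_length=3, tokenizer=tokenizer, device=device, batch_size=8)
print(pll)  # summed log-probability of the scored tokens; higher (less negative) means more likely

Note that score relies on encoding.word_ids(), so it needs a fast tokenizer, and it returns a single Python float; comparing candidate continuations of the same prefix then reduces to comparing these sums.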