Shaltiel commited on
Commit
a039d5e
·
1 Parent(s): 0ca7460

Upload BertForLexPrediction.py

Browse files
Files changed (1) hide show
  1. BertForLexPrediction.py +37 -0
BertForLexPrediction.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import List, Union
3
+ from transformers import BertForMaskedLM, BertTokenizerFast
4
+
5
+ class BertForLexPrediction(BertForMaskedLM):
6
+
7
+ def __init__(self, config):
8
+ super().__init__(config)
9
+
10
+ def predict(self, sentences: Union[str, List[str]], tokenizer: BertTokenizerFast):
11
+ if isinstance(sentences, str):
12
+ sentences = [sentences]
13
+
14
+ # predict the logits for the sentence
15
+ inputs = tokenizer(sentences, padding='longest', truncation=True, return_tensors='pt')
16
+ inputs = {k:v.to(self.device) for k,v in inputs.items()}
17
+ logits = self.forward(**inputs, return_dict=True).logits
18
+
19
+ # for each token, we will take the top 10, and search for one that is appropriate. If none, then
20
+ # return a [BLANK] for that word.
21
+ input_ids = inputs['input_ids']
22
+ batch_ret = []
23
+ for batch_idx in range(len(sentences)):
24
+ ret = []
25
+ batch_ret.append(ret)
26
+ for tok_idx in range(input_ids.shape[1]):
27
+ token_id = input_ids[batch_idx, tok_idx]
28
+ # ignore cls, sep, pad
29
+ if token_id in [tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id]: continue
30
+
31
+ token = tokenizer._convert_id_to_token(token_id)
32
+ # wordpieces should just be appended to the previous word
33
+ if token.startswith('##'):
34
+ ret[-1] = (ret[-1][0] + token[2:], ret[-1][1])
35
+ continue
36
+ ret.append((token, tokenizer._convert_id_to_token(torch.argmax(logits[batch_idx, tok_idx]))))
37
+ return batch_ret