File size: 2,265 Bytes
107cd34 677d4b0 107cd34 677d4b0 107cd34 677d4b0 107cd34 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
import re
import os
import fire
import torch
from functools import partial
from transformers import AutoTokenizer
from transformers import AutoModelForPreTraining
from pya0.preprocess import preprocess_for_transformer
def highlight_masked(txt):
    """Return ``txt`` with every literal ``[MASK]`` wrapped in ANSI green codes.

    Text that contains no ``[MASK]`` is returned unchanged.
    """
    def _paint(match):
        # Surround the matched placeholder with the "green on" / "reset" escapes.
        return '\033[92m' + match.group(1) + '\033[0m'

    return re.sub(r"(\[MASK\])", _paint, txt)
def classifier_hook(tokenizer, tokens, topk, module, inputs, outputs):
    """Forward hook that prints the top-``topk`` token candidates for each mask.

    Meant to be bound with ``functools.partial`` (first three arguments) and
    registered via ``register_forward_hook``, which supplies ``module``,
    ``inputs`` and ``outputs``.

    Args:
        tokenizer: tokenizer that produced ``tokens``. Its ``mask_token_id``
            (when present) identifies mask positions, and its
            ``convert_ids_to_tokens`` turns candidate ids back into strings.
        tokens: the encoded batch given to the model. Only the first sequence,
            ``tokens['input_ids'][0]``, is inspected.
        topk: number of candidates printed per mask.
        module: the hooked module (unused).
        inputs: the hooked module's inputs (unused).
        outputs: pair ``(unmask_scores, seq_rel_scores)``. ``unmask_scores[0]``
            is assumed to hold one row of vocabulary scores per input
            position. TODO: confirm against the model head.

    Returns:
        None, so PyTorch leaves the hooked module's output unchanged.
    """
    unmask_scores, _ = outputs
    # 103 is the [MASK] id of the standard BERT uncased vocabulary only; ask
    # the tokenizer for its real mask id and fall back to 103 if it has none.
    mask_id = getattr(tokenizer, 'mask_token_id', None)
    if mask_id is None:
        mask_id = 103
    token_ids = tokens['input_ids'][0]
    # Boolean selector over positions, placed on the same device as the scores.
    masked_idx = (token_ids == mask_id).to(unmask_scores.device)
    scores = unmask_scores[0][masked_idx]
    cands = torch.argsort(scores, dim=1, descending=True)
    for i, mask_cands in enumerate(cands):
        top_cands = mask_cands[:topk].detach().cpu()
        print(f'MASK[{i}] top candidates: ' +
              str(tokenizer.convert_ids_to_tokens(top_cands)))
def test(model_name_or_path, tokenizer_name_or_path, test_file='test.txt'):
    """Print a model's top masked-token predictions for each line of a test file.

    Each non-blank line of ``test_file`` has the form
    ``<positions>\\t<sentence>``. ``<positions>`` is a comma-separated list of
    1-based word positions to replace with ``[MASK]``; ``0`` entries are
    ignored. ``<sentence>`` is passed through ``preprocess_for_transformer``
    and split on whitespace.

    For every line, the masked sentence is echoed with masks highlighted, and
    ``classifier_hook`` prints the top 3 candidates for each mask.

    Args:
        model_name_or_path: passed to ``AutoModelForPreTraining.from_pretrained``.
        tokenizer_name_or_path: passed to ``AutoTokenizer.from_pretrained``.
        test_file: path to the tab-separated test file.

    Raises:
        ValueError: if an entry of the position field is not an integer.
        IndexError: if a line has no tab-separated sentence field, or a
            position exceeds the number of words.
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
    model = AutoModelForPreTraining.from_pretrained(
        model_name_or_path, tie_word_embeddings=True)
    # Predictions are printed by a hook on this head module; look it up once.
    classifier = model.cls

    with open(test_file, 'r', encoding='utf-8') as fh:
        for line in fh:
            line = line.rstrip()
            # Skip blank lines (e.g. a trailing empty line) instead of
            # crashing on int('') below.
            if not line:
                continue
            fields = line.split('\t')
            maskpos = [int(p) for p in fields[0].split(',')]

            # Preprocess, then mask the requested 1-based word positions.
            words = preprocess_for_transformer(fields[1]).split()
            for pos in maskpos:
                if pos != 0:
                    words[pos - 1] = '[MASK]'
            sentence = ' '.join(words)
            encoded = tokenizer(sentence, padding=True, truncation=True,
                                return_tensors="pt")
            print('*', highlight_masked(sentence))

            # The forward pass is run only to trigger the hook. Remove the
            # hook even if the pass raises, so hooks never accumulate.
            hook = classifier.register_forward_hook(
                partial(classifier_hook, tokenizer, encoded, 3))
            try:
                with torch.no_grad():
                    model(**encoded)
            finally:
                hook.remove()
if __name__ == '__main__':
    # Presumably makes pager-based output (e.g. Fire's --help text) print
    # straight to the terminal instead of opening an interactive pager.
    os.environ.update(PAGER='cat')
    fire.Fire(test)
|