# NOTE(review): removed web-scrape residue ("Spaces:" / "No application file")
# that preceded the module docstring and was not valid Python.
""" | |
Based on https://github.com/huggingface/transformers/issues/80 | |
""" | |
import argparse
import glob
import json
import logging
import os
import pickle
import string
import sys
import time
from collections import ChainMap

import numpy as np
import torch
from filelock import FileLock
from torch.utils.data import DataLoader, TensorDataset

from .base import BaseModule, create_trainer
from ..data.examples import InputFeatures
logger = logging.getLogger(__name__) | |
class MaskedLM(BaseModule):
    """Zero-shot multiple-choice evaluation with a masked language model.

    Each question contains a mask slot; the model's token logits at the
    first mask position are restricted to one representative token per
    answer candidate, and the candidate with the highest logit is taken
    as the prediction.

    Based on https://github.com/huggingface/transformers/issues/80
    """

    mode = 'language-modeling'
    output_mode = 'classification'
    example_type = 'multiple-choice'

    def __init__(self, hparams):
        """Initialize tokenizer-derived state and reset the results file.

        Args:
            hparams: hyper-parameter mapping consumed by BaseModule; this
                class reads 'max_seq_length', 'overwrite_cache',
                'eval_batch_size' and 'output_dir'.
        """
        super().__init__(hparams)
        # Generalization: use the tokenizer's own mask id instead of the
        # hard-coded BERT-specific '[MASK]' string; fall back to the old
        # lookup for tokenizers without mask_token_id. For BERT both paths
        # yield the same id.
        mask_id = getattr(self.tokenizer, 'mask_token_id', None)
        if mask_id is None:
            mask_id = self.tokenizer.convert_tokens_to_ids('[MASK]')
        self.mask_id = mask_id
        self.test_results_fpath = 'test_results'
        # Results are accumulated into this pickle across test workers;
        # start every run from a clean slate.
        if os.path.exists(self.test_results_fpath):
            os.remove(self.test_results_fpath)

    def convert_examples_to_features(self, examples):
        """Tokenize questions and map each answer candidate to one token id.

        A multi-token candidate is represented by its longest sub-token
        (punctuation stripped before measuring length), since only a single
        mask position is scored; a candidate that tokenizes to nothing falls
        back to the UNK token so candidate/label alignment is preserved.

        Args:
            examples: multiple-choice examples exposing .question,
                .endings and .label.

        Returns:
            list[InputFeatures]: encoded inputs plus per-example
            `candidates` (one token id per ending) and the gold `label`.
        """
        batch_encoding = self.tokenizer(
            [example.question for example in examples],
            max_length=self.hparams['max_seq_length'],
            padding='max_length',
            truncation=True,
        )
        features = []
        for i, example in enumerate(examples):
            inputs = {k: batch_encoding[k][i] for k in batch_encoding}
            tokens = [self.tokenizer.tokenize(cand) for cand in example.endings]
            token_candidates = []
            for toks in tokens:
                if not toks:
                    token_candidates.append(self.tokenizer.unk_token)
                else:
                    # Keep the most contentful sub-token of the candidate.
                    token_candidates.append(
                        max(toks, key=lambda t: len(t.strip(string.punctuation))))
            candidate_ids = self.tokenizer.convert_tokens_to_ids(token_candidates)
            features.append(InputFeatures(**inputs,
                                          candidates=candidate_ids,
                                          label=example.label))
        return features

    def test_dataloader(self):
        """Build the test DataLoader, caching featurized examples on disk."""
        mode = 'test'
        cached_features_file = self._feature_file(mode)
        if os.path.exists(cached_features_file) and not self.hparams['overwrite_cache']:
            features = torch.load(cached_features_file)
        else:
            features = self.load_features(mode)
            torch.save(features, cached_features_file)
        all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        all_attention_mask = torch.tensor([f.attention_mask for f in features],
                                          dtype=torch.long)
        # Fix: the previous `f.token_type_ids or 0` collapsed a missing
        # segment-id list to the scalar 0, yielding a malformed 1-D tensor
        # that cannot serve as a per-token `token_type_ids` batch. Use a
        # zero vector of the right length instead.
        all_token_type_ids = torch.tensor(
            [f.token_type_ids if f.token_type_ids else [0] * len(f.input_ids)
             for f in features],
            dtype=torch.long)
        all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
        all_cands = torch.tensor([f.candidates for f in features], dtype=torch.long)
        # Identical to all_labels; kept so the dataset retains its historical
        # 6-slot layout (test_step reads slots 0-4; slot 5 is unused).
        all_answers = all_labels.clone()
        return DataLoader(
            TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids,
                          all_labels, all_cands, all_answers),
            batch_size=self.hparams['eval_batch_size'],
        )

    def test_step(self, batch, batch_idx):
        """Score one batch: count right/wrong answers at the mask position.

        Args:
            batch: (input_ids, attention_mask, token_type_ids, labels,
                candidate token ids, answers) as built by test_dataloader.
            batch_idx: index of the batch (unused).

        Returns:
            dict with integer 'right' and 'wrong' counts for this batch.
        """
        inputs = {'input_ids': batch[0], 'token_type_ids': batch[2],
                  'attention_mask': batch[1]}
        answers = batch[3].detach().cpu().numpy()
        candidates = batch[4].detach().cpu().numpy()
        # Locate the first mask token in every row (argmax on a boolean
        # array returns the first True position).
        input_ids = batch[0].detach().cpu().numpy()
        mask_ids = torch.from_numpy((input_ids == self.mask_id).argmax(axis=1))
        predictions = self(**inputs)[0]
        rows = torch.arange(0, predictions.shape[0], dtype=torch.int64)
        # Logits at the mask position only: shape (batch, vocab).
        predictions = predictions[rows, mask_ids].detach().cpu().numpy()
        right, wrong = 0, 0
        for i, pred in enumerate(predictions):
            # Restrict vocabulary logits to this example's candidate tokens;
            # the predicted answer is the index of the best candidate.
            pred_answer = int(np.argmax(pred[candidates[i]]))
            if answers[i] == pred_answer:
                right += 1
            else:
                wrong += 1
        return {"right": right, "wrong": wrong}

    def test_epoch_end(self, outputs):
        """Aggregate per-batch counts into the shared results pickle.

        The pickle plus FileLock lets multiple test processes (e.g. DDP
        workers) accumulate counts without clobbering each other.
        """
        merged = {'right': sum(output['right'] for output in outputs),
                  'wrong': sum(output['wrong'] for output in outputs)}
        with FileLock(self.test_results_fpath + '.lock'):
            if os.path.exists(self.test_results_fpath):
                with open(self.test_results_fpath, 'rb') as fp:
                    data = pickle.load(fp)
                data = {'right': data['right'] + merged['right'],
                        'wrong': data['wrong'] + merged['wrong']}
            else:
                data = merged
            with open(self.test_results_fpath, 'wb') as fp:
                pickle.dump(data, fp)
        return data

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        """CLI-argument hook; this module adds no extra arguments.

        Fix: declared @staticmethod — the original unbound definition would
        have mis-bound `parser` to the instance if called on one. Class-level
        calls (MaskedLM.add_model_specific_args(...)) behave as before.
        """
        return parser

    def run_module(self):
        """Run the test loop and write overall accuracy to output_dir."""
        self.eval()
        self.freeze()
        # Fix: dropped the original bare `torch.no_grad()` call — it built a
        # context manager without entering it, a pure no-op; freeze() already
        # disables gradients for the test run.
        trainer = create_trainer(self, self.hparams)
        trainer.test(self)
        # Fix: close the results file instead of leaking the handle.
        with open(self.test_results_fpath, 'rb') as fp:
            preds = pickle.load(fp)
        correct, wrong = preds['right'], preds['wrong']
        total = correct + wrong
        with open(os.path.join(self.hparams['output_dir'], 'test_results.txt'), 'w') as fp:
            # Guard against an empty test set to avoid ZeroDivisionError.
            json.dump({'test_acc': correct / total if total else 0.0}, fp)