Spaces:
No application file
No application file
import torch | |
import numpy as np | |
from .base import BaseModule | |
from .utils import mean_accuracy | |
class MultipleChoice(BaseModule):
    """Task module for multiple-choice classification.

    Reduces the per-batch evaluation outputs of an epoch into a mean loss
    plus whatever metrics ``mean_accuracy`` returns, and packages them in
    the dict layout returned by the ``*_epoch_end`` hooks below.
    """

    # Task descriptors. They are only declared here; how they are consumed
    # is decided elsewhere (presumably by BaseModule -- confirm there).
    mode = 'multiple-choice'
    output_mode = 'classification'
    example_type = 'multiple-choice'

    def __init__(self, hparams):
        """Forward ``hparams`` unchanged to ``BaseModule``."""
        super().__init__(hparams)

    def _eval_end(self, outputs):
        """Aggregate per-batch evaluation outputs into epoch-level results.

        Args:
            outputs: Re-iterable collection (it is traversed three times) of
                per-batch dicts with the keys ``'val_loss'`` (a tensor),
                ``'pred'`` and ``'target'`` (arrays joined along axis 0).

        Returns:
            A 3-tuple ``(ret, preds_list, out_label_list)``:

            * ``ret`` -- the metrics dict (``'val_loss'`` plus the entries
              produced by ``mean_accuracy``) with an extra ``'log'`` key
              holding the same metrics without the ``'log'`` entry.
            * ``preds_list`` / ``out_label_list`` -- one empty list per
              evaluated example.
        """
        val_loss_mean = (
            torch.stack([batch['val_loss'] for batch in outputs])
            .mean()
            .detach()
            .cpu()
            .item()
        )

        # The predicted class is the argmax over axis 1 of the stacked
        # predictions (assumes 'pred' is (n_examples, n_choices) -- TODO
        # confirm against the step that produces it).
        scores = np.concatenate([batch['pred'] for batch in outputs], axis=0)
        preds = np.argmax(scores, axis=1)
        out_label_ids = np.concatenate(
            [batch['target'] for batch in outputs], axis=0
        )

        # NOTE(review): these placeholders are never populated in this
        # method and both callers in this class discard them. They are kept
        # so the 3-tuple return shape stays unchanged.
        num_examples = out_label_ids.shape[0]
        out_label_list = [[] for _ in range(num_examples)]
        preds_list = [[] for _ in range(num_examples)]

        results = {
            'val_loss': val_loss_mean,
            **mean_accuracy(preds, out_label_ids),
        }

        # ``ret`` is a shallow copy so that adding 'log' does not make
        # ``results`` contain itself.
        ret = dict(results)
        ret['log'] = results
        return ret, preds_list, out_label_list

    def validation_epoch_end(self, outputs: list) -> dict:
        """Return the validation loss and the metrics to log for the epoch."""
        ret, _, _ = self._eval_end(outputs)
        logs = ret['log']
        return {'val_loss': logs['val_loss'], 'log': logs, 'progress_bar': logs}

    def test_epoch_end(self, outputs):
        """Return the test loss and the metrics to log for the test run."""
        ret, _, _ = self._eval_end(outputs)
        # Converting to the dict layout required by pl.
        logs = ret['log']
        # `val_loss` is the key returned by `self._eval_end()`
        # but here it actually refers to the test loss.
        return {
            'avg_test_loss': logs['val_loss'],
            'log': logs,
            'progress_bar': logs,
        }

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        """Return ``parser`` unchanged; this task registers no extra arguments.

        Declared as a static method because the signature takes no ``self``;
        without the decorator, a call made on an instance would bind the
        instance itself to ``parser``.

        Args:
            parser: The parser object to extend (presumably an
                ``argparse.ArgumentParser`` -- confirm against the caller).
            root_dir: Unused by this implementation.

        Returns:
            The same ``parser`` object that was passed in.
        """
        return parser