import numpy as np
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import f1_score, matthews_corrcoef

from .build import EVALUATION_REGISTRY

def simple_accuracy(preds, labels):
    """Fraction of predictions that exactly match the labels."""
    return (preds == labels).mean()

@EVALUATION_REGISTRY.register()
class GLUEEvaler(object):
    """Evaluator for GLUE benchmark tasks (classification tasks and STS-B regression)."""

    def __init__(self, cfg, *args, **kwargs):
        super(GLUEEvaler, self).__init__()
        # The task name (e.g. 'CoLA', 'SST-2', 'STS-B') selects which metrics are computed.
        self.task_name = cfg.DATASETS.DATASET_NAME
        self.tasks = [""]
    def eval(self, results, epoch):
        preds = []
        labels = []
        for result in results:
            if self.task_name != 'STS-B':
                # classification task: take the argmax over the class logits
                preds.append(result["pred"].argmax().item())
                labels.append(int(result["label"]))
            else:
                # regression task (STS-B): squash the predicted score with a sigmoid
                preds.append(float(result["pred"].sigmoid().item()))
                labels.append(float(result["label"]))
        preds = np.array(preds)
        labels = np.array(labels)
        # Per-task metrics, following the standard GLUE evaluation protocol.
        if self.task_name == 'CoLA':
            acc = simple_accuracy(preds, labels)
            matthewscorr = matthews_corrcoef(labels, preds)
            result = {
                "accuracy": acc,
                "matthews_corrcoef": matthewscorr,
            }
        elif self.task_name in ['QNLI', 'RTE', 'SST-2'] or self.task_name.startswith("MNLI"):
            acc = simple_accuracy(preds, labels)
            result = {
                "accuracy": acc,
            }
        elif self.task_name in ['MRPC', 'QQP']:
            acc = simple_accuracy(preds, labels)
            f1 = f1_score(y_true=labels, y_pred=preds)
            result = {
                "accuracy": acc,
                "f1_score": f1,
            }
        elif self.task_name in ['STS-B']:
            pearson_corr = pearsonr(preds, labels)[0]
            spearman_corr = spearmanr(preds, labels)[0]
            result = {
                "pearson_corr": pearson_corr,
                "spearman_corr": spearman_corr,
            }
        else:
            raise NotImplementedError(f"No GLUE metrics defined for task '{self.task_name}'")
        return result
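

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original evaluator). The cfg object and
# the result dicts below are hypothetical stand-ins: in the real pipeline, cfg
# comes from the uniperceiver config system and each result dict is produced by
# the inference loop, where "pred" is typically a torch tensor of class logits
# (the STS-B branch additionally requires a .sigmoid() method). numpy arrays are
# used here so the classification path runs without torch. Because of the
# relative import above, this only runs when the module is executed as part of
# its package.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(DATASETS=SimpleNamespace(DATASET_NAME='SST-2'))
    evaler = GLUEEvaler(cfg)
    results = [
        {"pred": np.array([0.1, 0.9]), "label": 1},  # predicted class 1, correct
        {"pred": np.array([0.8, 0.2]), "label": 1},  # predicted class 0, wrong
    ]
    # Expected output: {'accuracy': 0.5}
    print(evaler.eval(results, epoch=0))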