|
import os |
|
import sys |
|
import pickle |
|
import json |
|
from json import encoder |
|
from uniperceiver.config import configurable |
|
from .build import EVALUATION_REGISTRY |
|
|
|
import numpy as np |
|
from sklearn.metrics import f1_score, matthews_corrcoef |
|
from scipy.stats import pearsonr, spearmanr |
|
|
|
def simple_accuracy(preds, labels):
    """Return the fraction of positions where ``preds`` equals ``labels``.

    Both arguments are coerced with ``np.asarray`` so that plain Python
    sequences (lists, tuples) work as well as NumPy arrays; without the
    coercion, comparing two lists yields a single ``bool`` that has no
    ``.mean()``.

    Args:
        preds: Predicted values (array-like).
        labels: Reference values (array-like), compared element-wise
            against ``preds``.

    Returns:
        The mean of the element-wise equality mask.
    """
    preds = np.asarray(preds)
    labels = np.asarray(labels)
    return (preds == labels).mean()
|
|
|
@EVALUATION_REGISTRY.register()
class GLUEEvaler(object):
    """Evaluator that computes task-specific metrics for a GLUE dataset.

    The task is taken from ``cfg.DATASETS.DATASET_NAME`` and decides both
    how raw predictions are decoded and which metrics are reported:

    * ``CoLA``: accuracy and Matthews correlation coefficient.
    * ``QNLI``, ``RTE``, ``SST-2`` and any name starting with ``MNLI``:
      accuracy.
    * ``MRPC``, ``QQP``: accuracy and F1 score.
    * ``STS-B``: Pearson and Spearman correlation.
    """

    def __init__(self, cfg, *args, **kwargs):
        """Store the task name read from ``cfg.DATASETS.DATASET_NAME``.

        Extra positional and keyword arguments are accepted and ignored.
        """
        super(GLUEEvaler, self).__init__()
        self.task_name = cfg.DATASETS.DATASET_NAME
        # Not read anywhere in this class; kept for external consumers.
        self.tasks = [""]

    def eval(self, results, epoch):
        """Compute the metrics of the configured task.

        Args:
            results: Iterable of mappings, each with a ``"pred"`` and a
                ``"label"`` entry. ``"pred"`` must support
                ``.argmax().item()`` (classification tasks) or
                ``.sigmoid().item()`` (``STS-B``) -- presumably a torch
                tensor; confirm against the caller.
            epoch: Accepted for interface compatibility; not used here.

        Returns:
            dict mapping metric name to its value.

        Raises:
            NotImplementedError: If the configured task is not one of the
                supported GLUE tasks.
        """
        task = self.task_name
        # STS-B is the only task scored as regression; decide once
        # instead of re-checking for every result.
        is_regression = task == 'STS-B'

        preds = []
        labels = []
        for item in results:
            if is_regression:
                preds.append(float(item["pred"].sigmoid().item()))
                labels.append(float(item["label"]))
            else:
                preds.append(item["pred"].argmax().item())
                labels.append(int(item["label"]))

        preds = np.array(preds)
        labels = np.array(labels)

        if task == 'CoLA':
            metrics = {
                "accuracy": simple_accuracy(preds, labels),
                "matthews_corrcoef": matthews_corrcoef(labels, preds),
            }
        elif task in ['QNLI', 'RTE', 'SST-2'] or task.startswith("MNLI"):
            metrics = {
                "accuracy": simple_accuracy(preds, labels),
            }
        elif task in ['MRPC', 'QQP']:
            metrics = {
                "accuracy": simple_accuracy(preds, labels),
                "f1_score": f1_score(y_true=labels, y_pred=preds),
            }
        elif is_regression:
            metrics = {
                "pearson_corr": pearsonr(preds, labels)[0],
                "spearman_corr": spearmanr(preds, labels)[0],
            }
        else:
            # Name the offending task so a misconfiguration is diagnosable.
            raise NotImplementedError(
                "GLUEEvaler does not support task {!r}".format(task)
            )

        return metrics
|
|