|
import numpy as np |
|
from sklearn import metrics |
|
from collections import defaultdict |
|
import torch |
|
import torch.nn as nn |
|
|
|
|
|
def get_accracy(output, label):
    """Return the top-1 accuracy of one batch of class scores.

    (The misspelled name is kept because callers import it under this name.)

    Args:
        output: score tensor with classes along dim 1; each row's predicted
            class is the index of its largest entry.
        label: ground-truth class indices, compared element-wise with the
            predictions.

    Returns:
        Number of matching predictions divided by the number of rows of
        ``output``, as a Python float.
    """
    predicted = torch.max(output, dim=1).indices
    num_hits = torch.eq(predicted, label).sum().item()
    return num_hits / predicted.shape[0]
|
|
|
|
|
def get_prediction(output, label):
    """Pair each sample's class-1 probability with its label.

    Args:
        output: logits with classes along dim 1; softmax is taken over dim 1
            and only the column for class 1 is kept.
        label: one label per row of ``output``; reshaped into a column and
            cast to float.

    Returns:
        Tensor of shape (N, 2): column 0 is P(class 1), column 1 is the label.
    """
    positive_prob = torch.softmax(output, dim=1)[:, 1].unsqueeze(1)
    label_column = label.view(label.size(0), 1).float()
    return torch.cat([positive_prob, label_column], dim=1)
|
|
|
|
|
def calculate_metrics_for_train(label, output):
    """Compute AUC, EER, accuracy and average precision for one batch.

    Args:
        label: ground-truth labels; class 1 is the positive class
            (``pos_label=1`` is passed to ``roc_curve``).
        output: model scores. If dim 1 has size 2, the softmax probability of
            column 1 is used as the positive score; otherwise ``output`` itself
            is used as the score.

    Returns:
        Tuple ``(auc, eer, accuracy, ap)``. ``auc`` and ``eer`` are ``None``
        when the ROC curve has NaN rates (e.g. a batch containing only one
        class) or when ``roc_curve`` rejects the inputs with ValueError or
        TypeError. ``accuracy`` is the arg-max accuracy over the batch.
    """
    if output.size(1) == 2:
        prob = torch.softmax(output, dim=1)[:, 1]
    else:
        prob = output

    _, prediction = torch.max(output, 1)
    correct = (prediction == label).sum().item()
    accuracy = correct / prediction.size(0)

    # Detach before .numpy(): .numpy() raises RuntimeError on a tensor that
    # requires grad. Previously the ROC call skipped the detach and a bare
    # `except:` hid the error, so AUC/EER were None for every training batch.
    y_true = label.detach().cpu().numpy()
    y_pred = prob.detach().cpu().numpy()

    ap = metrics.average_precision_score(y_true, y_pred)

    # squeeze() drops a (N, 1) column axis. atleast_1d keeps a batch of one
    # sample 1-D instead of 0-d, because sklearn rejects 0-d arrays.
    roc_true = np.atleast_1d(y_true.squeeze())
    roc_score = np.atleast_1d(y_pred.squeeze())
    try:
        fpr, tpr, _ = metrics.roc_curve(roc_true, roc_score, pos_label=1)
    except (ValueError, TypeError):
        # Inputs sklearn cannot build a ROC curve from (e.g. multi-column
        # scores): report AUC/EER as unavailable rather than failing the batch.
        return None, None, accuracy, ap

    if np.isnan(fpr[0]) or np.isnan(tpr[0]):
        # Undefined rates (e.g. one class absent from the batch).
        return None, None, accuracy, ap

    auc = metrics.auc(fpr, tpr)
    fnr = 1 - tpr
    # EER: FPR at the threshold where FPR and FNR are closest.
    eer = fpr[np.nanargmin(np.absolute(fnr - fpr))]
    return auc, eer, accuracy, ap
|
|
|
|
|
|
|
class Metrics_batch:
    """Accumulate per-batch ROC/AP statistics plus a pooled accuracy.

    Each ``update`` scores one batch. ``get_mean_metrics`` then reports:
    - accuracy pooled over every sample seen;
    - the AUC of the mean ROC curve (per-batch TPRs interpolated onto a
      common 100-point FPR grid);
    - the mean of the per-batch EERs and average precisions.
    """

    def __init__(self):
        self.tprs = []                          # per-batch TPR interpolated onto mean_fpr
        self.mean_fpr = np.linspace(0, 1, 100)  # common FPR grid for averaging ROC curves
        self.aucs = []                          # per-batch AUC
        self.eers = []                          # per-batch EER
        self.aps = []                           # per-batch average precision

        self.correct = 0                        # correctly classified samples so far
        self.total = 0                          # samples seen so far
        self.losses = []                        # never appended here; only cleared by clear()

    def update(self, label, output):
        """Score one batch and fold it into the running statistics.

        Args:
            label: ground-truth labels; class 1 is the positive class.
            output: model scores. With exactly 2 columns, the softmax
                probability of column 1 is the positive score; otherwise
                ``output`` is used as the score directly.

        Returns:
            ``(acc, auc, eer, ap)`` for this batch. ``auc`` and ``eer`` are
            -1 when the batch's ROC curve has NaN rates.
        """
        acc = self._update_acc(label, output)
        if output.size(1) == 2:
            prob = torch.softmax(output, dim=1)[:, 1]
        else:
            prob = output

        auc, eer = self._update_auc(label, prob)
        ap = self._update_ap(label, prob)
        return acc, auc, eer, ap

    def _update_auc(self, lab, prob):
        """Record the batch's ROC curve, AUC and EER; return ``(auc, eer)``.

        Returns ``(-1, -1)`` and records nothing when the curve has NaN
        rates (e.g. only one class in the batch).
        """
        # Detach: .numpy() raises on tensors that require grad. atleast_1d
        # keeps a one-sample batch 1-D after squeeze(), since sklearn rejects
        # 0-d input.
        y_true = np.atleast_1d(lab.detach().cpu().numpy().squeeze())
        y_score = np.atleast_1d(prob.detach().cpu().numpy().squeeze())
        fpr, tpr, _ = metrics.roc_curve(y_true, y_score, pos_label=1)
        if np.isnan(fpr[0]) or np.isnan(tpr[0]):
            return -1, -1

        auc = metrics.auc(fpr, tpr)
        interp_tpr = np.interp(self.mean_fpr, fpr, tpr)
        interp_tpr[0] = 0.0  # anchor every averaged curve at the origin
        self.tprs.append(interp_tpr)
        self.aucs.append(auc)

        fnr = 1 - tpr
        # EER: FPR at the threshold where FPR and FNR are closest.
        eer = fpr[np.nanargmin(np.absolute(fnr - fpr))]
        self.eers.append(eer)
        return auc, eer

    def _update_acc(self, lab, output):
        """Add the batch's arg-max hits to the running totals; return batch accuracy."""
        _, prediction = torch.max(output, 1)
        correct = (prediction == lab).sum().item()
        accuracy = correct / prediction.size(0)

        self.correct = self.correct + correct
        self.total = self.total + lab.size(0)
        return accuracy

    def _update_ap(self, label, prob):
        """Record and return the batch's average precision."""
        y_true = label.cpu().detach().numpy()
        y_pred = prob.cpu().detach().numpy()
        ap = metrics.average_precision_score(y_true, y_pred)
        self.aps.append(ap)
        return np.mean(ap)  # scalar AP, returned as a NumPy float

    def get_mean_metrics(self):
        """Return ``{'acc', 'auc', 'eer', 'ap'}`` aggregated over all updates.

        A metric with no recorded data is reported as NaN instead of raising
        (no updates yet, or every batch had an undefined ROC curve).
        """
        mean_acc = self.correct / self.total if self.total else np.nan
        mean_auc, _ = self._mean_auc()
        mean_eer = np.mean(self.eers) if self.eers else np.nan
        mean_ap = np.mean(self.aps) if self.aps else np.nan
        return {'acc': mean_acc, 'auc': mean_auc, 'eer': mean_eer, 'ap': mean_ap}

    def _mean_auc(self):
        """Return (AUC of the mean interpolated ROC curve, std of per-batch AUCs).

        Returns ``(nan, nan)`` when no ROC curve has been recorded.
        """
        if not self.tprs:
            return np.nan, np.nan
        mean_tpr = np.mean(self.tprs, axis=0)
        mean_tpr[-1] = 1.0  # anchor the averaged curve at (1, 1)
        mean_auc = metrics.auc(self.mean_fpr, mean_tpr)
        std_auc = np.std(self.aucs)
        return mean_auc, std_auc

    def clear(self):
        """Reset all accumulated statistics (``mean_fpr`` is kept)."""
        self.tprs.clear()
        self.aucs.clear()
        self.correct = 0
        self.total = 0
        self.eers.clear()
        self.aps.clear()
        self.losses.clear()
|
|
|
|
|
|
|
class Metrics_all:
    """Collect scores and labels across batches; compute metrics once over all of them."""

    def __init__(self):
        self.probs = []   # one 1-D array of P(class 1) per stored batch
        self.labels = []  # one 1-D array of labels per stored batch
        self.correct = 0  # arg-max hits so far
        self.total = 0    # samples seen so far

    def store(self, label, output):
        """Append one batch's class-1 probabilities and labels, and update accuracy counts.

        ``output`` is expected to hold two-class logits along dim 1 (column 1
        is taken after softmax).
        """
        prob = torch.softmax(output, dim=1)[:, 1]
        _, prediction = torch.max(output, 1)
        self.correct += (prediction == label).sum().item()
        self.total += label.size(0)

        # detach(): .numpy() raises on tensors that require grad.
        # reshape(-1) instead of squeeze(): squeeze turns a batch of one into
        # a 0-d array, which np.concatenate in get_metrics cannot join.
        self.labels.append(label.detach().cpu().numpy().reshape(-1))
        self.probs.append(prob.detach().cpu().numpy().reshape(-1))

    def get_metrics(self):
        """Return ``{'acc', 'auc', 'eer', 'ap'}`` over every stored sample.

        ``auc`` and ``eer`` are NaN when the ROC curve has NaN rates (e.g. only
        one class was stored), instead of raising from ``np.nanargmin``.
        """
        y_pred = np.concatenate(self.probs)
        y_true = np.concatenate(self.labels)

        fpr, tpr, _ = metrics.roc_curve(y_true, y_pred, pos_label=1)
        if np.isnan(fpr[0]) or np.isnan(tpr[0]):
            auc, eer = np.nan, np.nan
        else:
            auc = metrics.auc(fpr, tpr)
            fnr = 1 - tpr
            # EER: FPR at the threshold where FPR and FNR are closest.
            eer = fpr[np.nanargmin(np.absolute(fnr - fpr))]

        ap = metrics.average_precision_score(y_true, y_pred)
        acc = self.correct / self.total
        return {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap}

    def clear(self):
        """Drop all stored batches and reset the accuracy counters."""
        self.probs.clear()
        self.labels.clear()
        self.correct = 0
        self.total = 0
|
|
|
|
|
|
|
class Recorder:
    """Weighted running mean of scalar values; ``None`` samples are skipped."""

    def __init__(self):
        self.sum, self.num = 0, 0

    def update(self, item, num=1):
        """Add ``item`` with weight ``num``; a ``None`` item is ignored entirely."""
        if item is None:
            return
        self.sum += item * num
        self.num += num

    def average(self):
        """Return the weighted mean so far, or ``None`` if nothing has been added."""
        return None if self.num == 0 else self.sum / self.num

    def clear(self):
        """Reset to the empty state."""
        self.sum, self.num = 0, 0
|
|