File size: 4,900 Bytes
42185a2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
from typing import Optional, Any, Callable, List
import torch
import torchmetrics
from torchmetrics.metric import Metric
from torchmetrics import AUROC, PrecisionRecallCurve
from torchmetrics.functional import auroc
from torchmetrics.utilities.data import dim_zero_cat
import logging
import numpy as np
class PR_AUC(Metric):
    """Mean per-class area under the precision-recall curve, averaged over batches.

    Each ``update`` computes one PR-AUC per class for the given batch, drops
    NaN per-class values, and stores the mean of the rest. ``compute`` returns
    the mean of those stored per-batch values. This is a mean of batch-level
    scores, not the PR-AUC of all batches concatenated.
    """

    def __init__(self, num_classes, compute_on_step=False, dist_sync_on_step=False):
        super().__init__(compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step)
        # One scalar per batch; concatenated across processes on sync.
        self.add_state("prauc", default=[], dist_reduce_fx='cat')
        # Helper metrics used only for per-batch calculations in ``update``;
        # their internal state is cleared after every use there.
        self.pr_curve = PrecisionRecallCurve(num_classes=num_classes).to(self.device)
        self.auc = torchmetrics.AUC().to(self.device)

    def update(self, prediction: torch.Tensor, target: torch.Tensor):
        """Compute the mean per-class PR-AUC of one batch and store it.

        Args:
            prediction: model scores, forwarded unchanged to ``PrecisionRecallCurve``.
            target: ground-truth labels, forwarded unchanged to ``PrecisionRecallCurve``.
        """
        precision, recall, _ = self.pr_curve(prediction, target)
        if isinstance(precision, torch.Tensor):
            # Single-class (binary) curves come back as plain tensors rather
            # than per-class lists; wrap so the loop below sees one class.
            precision, recall = [precision], [recall]
        auc_values = [self.auc(r, p) for r, p in zip(recall, precision)]
        # Calling the helper metrics also accumulates every input in their own
        # state, which is never read; reset so memory does not grow per batch.
        self.pr_curve.reset()
        self.auc.reset()
        valid = [v for v in auc_values if not v.isnan()]
        if not valid:
            # No class has a defined PR-AUC in this batch. Storing NaN would
            # turn the mean in ``compute`` into NaN, so skip the batch.
            return
        pr_auc = torch.mean(torch.stack(valid)).to(self.device)
        self.prauc.append(pr_auc.detach())

    def compute(self):
        """Return the mean of the stored per-batch PR-AUC values (NaN if none)."""
        values = self.prauc
        # In a single process ``prauc`` is a list of 0-d tensors (a list has no
        # ``.detach``); after a distributed 'cat' sync it may already be a tensor.
        if isinstance(values, list):
            if not values:
                return torch.tensor(float("nan"), device=self.device)
            values = torch.stack(values)
        return torch.mean(values.detach())
class PR_AUCPerBucket(PR_AUC):
    """PR_AUC restricted to the class columns listed in ``bucket``.

    Args:
        num_classes: number of columns in the full prediction/target tensors.
        bucket: class (column) indices to evaluate. Duplicates and indices
            outside ``[0, num_classes)`` are ignored.
        compute_on_step: forwarded to ``PR_AUC``.
        dist_sync_on_step: forwarded to ``PR_AUC``.
    """

    def __init__(self, num_classes, bucket, compute_on_step=False, dist_sync_on_step=False):
        bucket_set = set(bucket)
        # Only de-duplicated, in-range indices become columns of the filtered
        # tensors in ``update``, so the inner PR curve is sized to that count.
        num_bucket_columns = sum(1 for c in bucket_set if 0 <= c < num_classes)
        super().__init__(num_classes=num_bucket_columns, compute_on_step=compute_on_step,
                         dist_sync_on_step=dist_sync_on_step)
        self.bucket = bucket_set
        self.num_classes = num_classes

    def update(self, prediction: torch.Tensor, target: torch.Tensor):
        """Store the mean per-class PR-AUC of the bucket columns of one batch.

        The batch is skipped when no bucket column has a positive target, or
        when every per-class PR-AUC is NaN.
        """
        # Ascending column order, matching a boolean mask over range(num_classes).
        columns = sorted(c for c in self.bucket if 0 <= c < self.num_classes)
        if not columns:
            return
        filtered_target = target[:, columns]
        filtered_preds = prediction[:, columns]
        if not bool((filtered_target > 0).any()):
            return
        precision, recall, _ = self.pr_curve(filtered_preds, filtered_target)
        if isinstance(precision, torch.Tensor):
            # A single-column bucket yields plain tensors instead of per-class lists.
            precision, recall = [precision], [recall]
        auc_values = [self.auc(r, p) for r, p in zip(recall, precision)]
        # Calling the helper metrics also accumulates every input in their own
        # state, which is never read; reset so memory does not grow per batch.
        self.pr_curve.reset()
        self.auc.reset()
        valid = [v for v in auc_values if not v.isnan()]
        if not valid:
            return
        pr_auc = torch.mean(torch.stack(valid)).to(self.device)
        self.prauc.append(pr_auc.detach())
def calculate_pr_auc(prediction: torch.Tensor, target: torch.Tensor, num_classes, device):
    """Return the mean per-class area under the precision-recall curve.

    Fresh ``PrecisionRecallCurve`` and ``AUC`` metric objects are created on
    ``device``. One PR-AUC is computed per class, NaN values are discarded and
    the remaining values are averaged. The result is detached from autograd.

    Args:
        prediction: model scores, forwarded unchanged to ``PrecisionRecallCurve``.
        target: ground-truth labels, forwarded unchanged to ``PrecisionRecallCurve``.
        num_classes: class count passed to ``PrecisionRecallCurve``.
        device: device for the metric objects and the returned tensor.

    Returns:
        A 0-d tensor on ``device``; NaN when every per-class value is NaN.
    """
    curve_metric = PrecisionRecallCurve(num_classes=num_classes).to(device)
    area_metric = torchmetrics.AUC().to(device)
    precisions, recalls, _thresholds = curve_metric(prediction, target)

    defined_areas = []
    for rec, prec in zip(recalls, precisions):
        area = area_metric(rec, prec)
        if not area.isnan():
            defined_areas.append(area)

    mean_area = torch.tensor(defined_areas).mean().to(device)
    return mean_area.detach()
class FilteredAUROC(AUROC):
    """AUROC that ignores target columns whose maximum value is 0.

    Expects a 2-D (samples, columns) target, as required by the column
    indexing in ``compute``. Columns without any non-zero target are dropped
    from both predictions and targets before the score is computed.
    """

    def compute(self) -> torch.Tensor:
        """Compute AUROC over the retained columns (NaN if none remain)."""
        preds = dim_zero_cat(self.preds)
        target = dim_zero_cat(self.target)
        # Size the mask from the data itself: ``self.num_classes`` may be None.
        num_columns = target.shape[1]
        keep = target.max(dim=0).values != 0
        num_filtered_cols = int((~keep).sum())
        logging.info("%d columns not considered for ROC AUC calculation!", num_filtered_cols)
        num_kept = num_columns - num_filtered_cols
        if num_kept == 0:
            logging.warning("No columns left for ROC AUC calculation; returning NaN.")
            return torch.tensor(float("nan"), device=preds.device)
        # Public functional API (the private ``_auroc_compute`` used previously
        # was never imported, so compute raised NameError).
        return auroc(
            preds[:, keep],
            target[:, keep],
            num_classes=num_kept,
            pos_label=self.pos_label,
            average=self.average,
            max_fpr=self.max_fpr,
        )
class FilteredAUROCPerBucket(AUROC):
    """AUROC restricted to bucket columns that contain a positive target.

    A column is kept only if its index is in ``bucket`` and its maximum target
    value is greater than 0. Expects a 2-D (samples, columns) target, as
    required by the column indexing in ``compute``.
    """

    def __init__(
        self,
        bucket: List[int],
        num_classes: Optional[int] = None,
        pos_label: Optional[int] = None,
        average: Optional[str] = "macro",
        max_fpr: Optional[float] = None,
        compute_on_step: bool = True,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
        dist_sync_fn: Optional[Callable] = None,
    ):
        # Forward by keyword so this does not depend on AUROC's positional
        # parameter order.
        super().__init__(
            num_classes=num_classes,
            pos_label=pos_label,
            average=average,
            max_fpr=max_fpr,
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
            dist_sync_fn=dist_sync_fn,
        )
        self.bucket = set(bucket)

    def compute(self) -> torch.Tensor:
        """Compute AUROC over the retained bucket columns (NaN if none remain)."""
        preds = dim_zero_cat(self.preds)
        target = dim_zero_cat(self.target)
        # Size the mask from the data itself: ``self.num_classes`` may be None.
        num_columns = target.shape[1]
        in_bucket = torch.tensor(
            [c in self.bucket for c in range(num_columns)], dtype=torch.bool, device=target.device
        )
        keep = in_bucket & (target.max(dim=0).values > 0)
        num_filtered_cols = int((~keep).sum())
        logging.info("%d columns not considered for ROC AUC calculation!", num_filtered_cols)
        num_kept = num_columns - num_filtered_cols
        if num_kept == 0:
            logging.warning("No columns left for ROC AUC calculation; returning NaN.")
            return torch.tensor(float("nan"), device=preds.device)
        # Public functional API (the private ``_auroc_compute`` used previously
        # was never imported, so compute raised NameError).
        return auroc(
            preds[:, keep],
            target[:, keep],
            num_classes=num_kept,
            pos_label=self.pos_label,
            average=self.average,
            max_fpr=self.max_fpr,
        )
|