File size: 4,900 Bytes
42185a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from typing import Optional, Any, Callable, List

import torch
import torchmetrics

from torchmetrics.metric import Metric
from torchmetrics import AUROC, PrecisionRecallCurve
from torchmetrics.functional import auroc
from torchmetrics.utilities.data import dim_zero_cat
import logging
import numpy as np


class PR_AUC(Metric):
    """Mean of per-batch, macro-averaged areas under the precision-recall curve.

    On every ``update`` the PR curve of each class is computed on that batch
    alone and integrated with ``auc``. Classes whose area is NaN are dropped
    before averaging, and the per-batch average is appended to the ``prauc``
    list state. ``compute`` returns the mean of the stored per-batch values.

    Args:
        num_classes: Number of classes (prediction columns) handed to the
            precision-recall curve.
        compute_on_step: Forwarded to :class:`torchmetrics.Metric`.
        dist_sync_on_step: Forwarded to :class:`torchmetrics.Metric`.
    """

    def __init__(self, num_classes, compute_on_step=False, dist_sync_on_step=False):
        super().__init__(compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step)
        self.add_state("prauc", default=[], dist_reduce_fx='cat')
        self._pr_num_classes = num_classes
        # Kept so existing code that reads ``pr_curve`` / ``auc`` keeps working.
        # ``update`` does not call them: invoking a module metric also appends the
        # batch to that metric's own accumulated state, which ``self.reset()``
        # never clears, so memory grew with every batch.
        self.pr_curve = PrecisionRecallCurve(num_classes=num_classes).to(self.device)
        self.auc = torchmetrics.AUC().to(self.device)

    def update(self, prediction: torch.Tensor, target: torch.Tensor):
        """Record the NaN-filtered mean PR-AUC of a single batch.

        A batch in which every class yields a NaN area is skipped. Storing NaN
        would turn the final ``compute`` result into NaN.
        """
        from torchmetrics.functional import auc, precision_recall_curve

        precision, recall, _ = precision_recall_curve(prediction, target, num_classes=self._pr_num_classes)
        if isinstance(precision, torch.Tensor):
            # Binary case (num_classes == 1): a single curve instead of a list of curves.
            precision, recall = [precision], [recall]

        valid = [v for v in (auc(r, p) for r, p in zip(recall, precision)) if not v.isnan()]
        if not valid:
            return
        pr_auc = torch.stack(valid).mean().to(self.device)
        self.prauc.append(pr_auc.detach())

    def compute(self):
        """Return the mean of all recorded per-batch PR-AUC values (NaN if there are none)."""
        values = self.prauc
        # Before any distributed sync this state is a Python list of 0-d tensors
        # (a list has no ``.detach()``); after a 'cat' sync it may already be a tensor.
        if isinstance(values, list):
            if not values:
                return torch.tensor(float("nan"), device=self.device)
            values = torch.stack(values)
        return torch.mean(values.detach())


class PR_AUCPerBucket(PR_AUC):
    """PR-AUC restricted to the class columns whose index is in ``bucket``.

    Args:
        num_classes: Total number of class columns in ``prediction``/``target``.
        bucket: Class indices to evaluate. Indices outside ``range(num_classes)``
            are ignored, and duplicates count once.
        compute_on_step: Forwarded to :class:`torchmetrics.Metric`.
        dist_sync_on_step: Forwarded to :class:`torchmetrics.Metric`.
    """

    def __init__(self, num_classes, bucket, compute_on_step=False, dist_sync_on_step=False):
        super().__init__(num_classes=len(bucket), compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step)
        self.bucket = set(bucket)
        self.num_classes = num_classes

    def update(self, prediction: torch.Tensor, target: torch.Tensor):
        """Record the NaN-filtered mean PR-AUC over the bucket's columns for one batch.

        A batch is skipped when no bucket column has a positive target. It is
        also skipped when every bucket class yields a NaN area.
        """
        from torchmetrics.functional import auc, precision_recall_curve

        # Ascending column order, identical to the former boolean-mask selection.
        columns = [c for c in range(self.num_classes) if c in self.bucket]
        if not columns:
            return
        filtered_target = target[:, columns]
        filtered_preds = prediction[:, columns]
        if not bool((filtered_target > 0).any()):
            return

        # Size the curve by the columns actually selected. ``len(bucket)`` (used by
        # the constructor) overcounts duplicates and out-of-range indices.
        precision, recall, _ = precision_recall_curve(filtered_preds, filtered_target, num_classes=len(columns))
        if isinstance(precision, torch.Tensor):
            # Single-column bucket: the binary curve is returned directly, not as a list.
            precision, recall = [precision], [recall]

        valid = [v for v in (auc(r, p) for r, p in zip(recall, precision)) if not v.isnan()]
        if valid:
            self.prauc.append(torch.stack(valid).mean().to(self.device).detach())


def calculate_pr_auc(prediction: torch.Tensor, target: torch.Tensor, num_classes, device):
    """Return the macro-averaged area under the precision-recall curve.

    Each class's precision-recall curve is integrated with ``auc``. Classes
    whose area is NaN are ignored when averaging.

    Args:
        prediction: Prediction scores passed to the precision-recall curve.
        target: Ground-truth labels passed to the precision-recall curve.
        num_classes: Number of classes handed to the precision-recall curve.
        device: Device the returned tensor is moved to.

    Returns:
        A detached 0-d tensor on ``device``. It is NaN when every class yields NaN.
    """
    # The functional API avoids constructing (and device-moving) two stateful
    # metric modules on every call.
    from torchmetrics.functional import auc, precision_recall_curve

    precision, recall, _ = precision_recall_curve(prediction, target, num_classes=num_classes)
    if isinstance(precision, torch.Tensor):
        # Binary case (num_classes == 1): a single curve instead of a list of curves.
        precision, recall = [precision], [recall]

    valid = [v for v in (auc(r, p) for r, p in zip(recall, precision)) if not v.isnan()]
    if not valid:
        return torch.tensor(float("nan"), device=device)
    return torch.stack(valid).mean().to(device).detach()


class FilteredAUROC(AUROC):
    """AUROC that ignores class columns whose targets never exceed zero.

    Before scoring, every column whose maximum target is 0 (no positive
    sample) is dropped, and the number of dropped columns is logged at INFO
    level. Columns are selected as ``target[:, c]``, so this assumes 2-D
    ``(samples, classes)`` targets, e.g. multilabel data.
    """

    def compute(self) -> torch.Tensor:
        # ``_auroc_compute`` was referenced but never imported, so this method
        # raised NameError. It is a private torchmetrics helper, hence the local import.
        from torchmetrics.functional.classification.auroc import _auroc_compute

        preds = dim_zero_cat(self.preds)
        target = dim_zero_cat(self.target)

        # Keep a column iff its maximum target is non-zero, i.e. it has a positive.
        keep = target.max(dim=0).values != 0
        filtered_target = target[:, keep]
        filtered_preds = preds[:, keep]

        num_filtered_cols = int((~keep).sum())
        logging.info("%d columns not considered for ROC AUC calculation!", num_filtered_cols)

        return _auroc_compute(
            filtered_preds,
            filtered_target,
            self.mode,
            int(keep.sum()),  # number of columns still being scored
            self.pos_label,
            self.average,
            self.max_fpr,
        )


class FilteredAUROCPerBucket(AUROC):
    """AUROC over the columns in ``bucket`` that contain at least one positive target.

    Columns that are outside the bucket, or whose maximum target is not > 0,
    are dropped before scoring. The number of dropped columns is logged at
    INFO level. Columns are selected as ``target[:, c]``, so this assumes 2-D
    ``(samples, classes)`` targets, e.g. multilabel data.

    Args:
        bucket: Class indices to evaluate.
        num_classes, pos_label, average, max_fpr, compute_on_step,
        dist_sync_on_step, process_group, dist_sync_fn: Forwarded unchanged
            to :class:`torchmetrics.AUROC`.
    """

    def __init__(
            self,
            bucket: List[int],
            num_classes: Optional[int] = None,
            pos_label: Optional[int] = None,
            average: Optional[str] = "macro",
            max_fpr: Optional[float] = None,
            compute_on_step: bool = True,
            dist_sync_on_step: bool = False,
            process_group: Optional[Any] = None,
            dist_sync_fn: Optional[Callable] = None
    ):
        # Keyword arguments, so the call does not depend on AUROC.__init__'s
        # positional parameter order.
        super().__init__(
            num_classes=num_classes,
            pos_label=pos_label,
            average=average,
            max_fpr=max_fpr,
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
            dist_sync_fn=dist_sync_fn,
        )
        self.bucket = set(bucket)

    def compute(self) -> torch.Tensor:
        # ``_auroc_compute`` was referenced but never imported, so this method
        # raised NameError. It is a private torchmetrics helper, hence the local import.
        from torchmetrics.functional.classification.auroc import _auroc_compute

        preds = dim_zero_cat(self.preds)
        target = dim_zero_cat(self.target)

        has_positive = target.max(dim=0).values > 0
        in_bucket = torch.tensor(
            [c in self.bucket for c in range(target.shape[1])],
            dtype=torch.bool,
            device=target.device,
        )
        keep = has_positive & in_bucket
        filtered_target = target[:, keep]
        filtered_preds = preds[:, keep]

        num_filtered_cols = int((~keep).sum())
        logging.info("%d columns not considered for ROC AUC calculation!", num_filtered_cols)

        return _auroc_compute(
            filtered_preds,
            filtered_target,
            self.mode,
            int(keep.sum()),  # number of columns still being scored
            self.pos_label,
            self.average,
            self.max_fpr,
        )