libokj's picture
Upload 110 files
c0ec7e6
raw
history blame
1.32 kB
import torch
from torchmetrics import Metric
from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
# The plotting doctest needs matplotlib; tell the doctest collector to skip
# it when torchmetrics reports that matplotlib is not available.
if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["ConcordanceIndex.plot"]
class ConcordanceIndex(Metric):
    """Concordance index (C-index) accumulated over batches.

    For 1-D ``preds`` and ``target`` of length ``N``, a pair ``(i, j)`` is
    *valid* when ``j < i`` and ``target[i] > target[j]``. Each valid pair
    contributes:

    * ``1``   if ``preds[i] > preds[j]`` (concordant),
    * ``0.5`` if ``preds[i] == preds[j]`` (tied prediction),
    * ``0``   otherwise.

    The metric is the sum of contributions divided by the number of valid
    pairs, accumulated over every ``update`` call. Pairs are only formed
    within a single ``update`` call, never across calls.

    NOTE(review): because of the lower-triangular mask, a pair whose
    higher-target sample comes *earlier* in the batch is not counted. This
    is kept as-is to preserve the existing metric values.

    NOTE(review): inputs of rank > 1 are not reshaped here; the pairwise
    difference and ``torch.tril`` then act on the trailing dimensions.
    Presumably callers pass 1-D tensors -- confirm against call sites.
    """

    is_differentiable: bool = False
    higher_is_better: bool = True
    full_state_update: bool = False
    plot_lower_bound: float = 0.5
    plot_upper_bound: float = 1.0

    def __init__(self, dist_sync_on_step: bool = False) -> None:
        """Register the integer pair counters as sum-reduced metric states.

        Args:
            dist_sync_on_step: Forwarded unchanged to ``Metric.__init__``.
        """
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        # All counters are exact integers. Tied predictions get their own
        # counter so the 0.5 weight is applied once in ``compute`` instead of
        # being truncated away when accumulating into an integer state.
        self.add_state("num_concordant", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("num_tied", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("num_valid", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Accumulate pair counts from one batch.

        Args:
            preds: Predicted values.
            target: Ground-truth values; must have the same shape as
                ``preds`` (enforced by ``_check_same_shape``).
        """
        _check_same_shape(preds, target)

        # pred_diff[i, j] = preds[i] - preds[j]
        pred_diff = preds.unsqueeze(-1) - preds

        # valid[i, j] is True when target[i] > target[j]; the lower-triangular
        # mask then keeps only entries with j <= i (the diagonal is always
        # False because the difference there is zero).
        valid = (target.unsqueeze(-1) - target) > 0
        valid = torch.tril(valid, diagonal=0)

        # Summing boolean masks yields exact integer counts, so no
        # fractional value is ever cast to an integer.
        self.num_concordant += torch.sum((pred_diff > 0) & valid)
        self.num_tied += torch.sum((pred_diff == 0) & valid)
        self.num_valid += torch.sum(valid)

    def compute(self) -> torch.Tensor:
        """Return the accumulated concordance index.

        Returns:
            ``(num_concordant + 0.5 * num_tied) / num_valid`` as a floating
            point tensor, or ``0.0`` when no valid pair has been seen.
        """
        score = (self.num_concordant + 0.5 * self.num_tied) / self.num_valid
        # With zero valid pairs the division yields NaN; report 0.0 instead.
        return torch.where(self.num_valid == 0, torch.zeros_like(score), score)

    def plot(self, val=None, ax=None):
        """Plot the metric by delegating to the base class ``_plot`` helper.

        Args:
            val: Passed through to ``self._plot`` unchanged.
            ax: Passed through to ``self._plot`` unchanged.

        Returns:
            Whatever ``self._plot(val, ax)`` returns.
        """
        return self._plot(val, ax)