Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import torch | |
from torchmetrics import Metric | |
from torchmetrics.utilities.checks import _check_same_shape | |
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE | |
# Without matplotlib the plotting example cannot run, so list it in
# ``__doctest_skip__`` (presumably consumed by the doctest collector -- confirm
# against the project's test configuration).
if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["ConcordanceIndex.plot"]
class ConcordanceIndex(Metric):
    """Concordance index (C-index) accumulated over ``update`` calls.

    Within each ``update`` call, a pair of samples ``(i, j)`` is counted when
    ``target[i] > target[j]`` and ``j <= i`` (lower triangle of the pairwise
    comparison matrix).  A counted pair scores 1 when ``preds[i] > preds[j]``,
    0.5 when the two predictions are equal, and 0 otherwise.  The metric is
    the total score divided by the number of counted pairs, or ``0.0`` when no
    pair has been counted.

    NOTE(review): because only the lower triangle is kept, pairs whose
    higher-target sample comes *earlier* in the batch are ignored, so the
    value depends on sample order.  This is preserved from the original
    implementation; confirm it is intended before changing it.

    Pairs are only formed between samples passed to the same ``update`` call.

    Example:
        >>> import torch
        >>> metric = ConcordanceIndex()
        >>> metric.update(torch.tensor([0.1, 0.4, 0.35, 0.8]),
        ...               torch.tensor([1.0, 2.0, 3.0, 4.0]))
        >>> metric.compute()
        tensor(0.8333)
    """

    is_differentiable: bool = False
    higher_is_better: bool = True
    full_state_update: bool = False
    # ``compute`` can return any value in [0, 1] (0.0 when there are no valid
    # pairs or when every counted pair is discordant), so the plot range must
    # start at 0 rather than at the chance level of 0.5.
    plot_lower_bound: float = 0.0
    plot_upper_bound: float = 1.0

    def __init__(self, dist_sync_on_step: bool = False, **kwargs):
        """Create the metric and register its accumulator states.

        Args:
            dist_sync_on_step: Forwarded unchanged to ``Metric.__init__``.
            **kwargs: Any further keyword arguments accepted by
                ``Metric.__init__``; forwarded unchanged.
        """
        super().__init__(dist_sync_on_step=dist_sync_on_step, **kwargs)
        # All counters are exact integers.  Ties are tracked separately and
        # weighted by 0.5 only in ``compute``; folding the half-credit into an
        # integer counter would truncate it on every update.
        self.add_state("num_concordant", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("num_tied", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("num_valid", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Accumulate pair counts for one batch.

        Args:
            preds: Predicted scores.
            target: Ground-truth values; must have the same shape as ``preds``.
        """
        _check_same_shape(preds, target)
        # Flatten so that e.g. ``(N, 1)`` model outputs are compared pairwise
        # exactly like a 1-D tensor of length N; without this the broadcast
        # below yields an ``(N, N, 1)`` tensor and ``tril`` acts on the wrong
        # pair of dimensions.
        preds = preds.flatten()
        target = target.flatten()

        # pred_diff[i, j] = preds[i] - preds[j]
        pred_diff = preds.unsqueeze(-1) - preds
        # valid[i, j] is True when target[i] > target[j] and j <= i.
        valid = torch.tril((target.unsqueeze(-1) - target) > 0, diagonal=0)

        # Summing a boolean tensor yields an integer count.
        self.num_concordant += torch.sum((pred_diff > 0) & valid)
        self.num_tied += torch.sum((pred_diff == 0) & valid)
        self.num_valid += torch.sum(valid)

    def compute(self) -> torch.Tensor:
        """Return the concordance index over everything seen so far.

        Returns:
            ``(num_concordant + 0.5 * num_tied) / num_valid`` as a scalar
            tensor, or ``0.0`` when no valid pair has been accumulated.
        """
        score = self.num_concordant + 0.5 * self.num_tied
        # The division produces NaN when ``num_valid`` is 0; ``where`` replaces
        # that case with 0.0.
        return torch.where(self.num_valid == 0, 0.0, score / self.num_valid)

    def plot(self, val=None, ax=None):
        """Plot the metric by delegating to the inherited ``_plot`` helper.

        Args:
            val: Passed through to ``self._plot`` unchanged.
            ax: Passed through to ``self._plot`` unchanged.

        Returns:
            Whatever ``self._plot(val, ax)`` returns.
        """
        return self._plot(val, ax)