Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import torch | |
from torch import Tensor | |
from torchmetrics.retrieval.base import RetrievalMetric | |
from torchmetrics.utilities.checks import _check_retrieval_functional_inputs | |
def calc_rie(n_total, active_ranks, r_a, exp_a):
    """Evaluate the Robust Initial Enhancement (RIE) ratio.

    The value returned is::

        sum_i exp_a ** (-rank_i / n_total)
        ------------------------------------------------------------
        r_a * (1 - exp_a ** -1) / (exp_a ** (1 / n_total) - 1)

    With ``exp_a = e ** alpha`` this is the usual RIE expression with
    exponential weighting parameter ``alpha``.

    Args:
        n_total: Number of ranked items.
        active_ranks: Array-like (must support ``.sum()`` after arithmetic)
            holding the 1-based ranks of the active items.
        r_a: Fraction of items that are active.
        exp_a: Base of the exponential rank weighting.

    Returns:
        The ratio of the summed rank weights to ``r_a`` times the
        normalisation term above.
    """
    # Summed weight of every active item: exp_a ** (-rank / N).
    rank_exponents = -active_ranks / n_total
    weight_sum = (exp_a ** rank_exponents).sum()

    # Normalisation term, built from its two factors.
    norm_top = 1 - exp_a ** (-1)
    norm_bottom = exp_a ** (1 / n_total) - 1
    normaliser = norm_top / norm_bottom

    return weight_sum / (r_a * normaliser)
class RIE(RetrievalMetric):
    """Robust Initial Enhancement (RIE) as a torchmetrics retrieval metric.

    Items are ranked by descending prediction score and every active
    (non-zero target) item contributes ``exp(-alpha * rank / N)``; the sum is
    normalised by ``calc_rie``. Larger values mean the actives are
    concentrated near the top of the ranking.

    Args:
        alpha: Exponential weighting parameter; must be strictly positive.
            Larger values put more emphasis on the very top of the ranking.
        **kwargs: Forwarded unchanged to ``RetrievalMetric.__init__``.

    Raises:
        ValueError: If ``alpha`` is not strictly positive.
    """

    is_differentiable: bool = False
    higher_is_better: bool = True
    full_state_update: bool = False

    def __init__(
        self,
        alpha: float = 80.5,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # alpha == 0 makes the normalisation 0/0 and a negative alpha inverts
        # the ranking preference, so reject both up front.
        if alpha <= 0:
            raise ValueError(
                f"Argument `alpha` must be a positive number, got {alpha}."
            )
        self.alpha = alpha

    def _metric(self, preds: Tensor, target: Tensor) -> Tensor:
        """Compute RIE for the predictions and targets of one query.

        Args:
            preds: Prediction scores; higher means ranked earlier.
            target: Relevance labels; non-zero entries are treated as actives
                (assumed binary -- TODO confirm against the input check).

        Returns:
            A scalar tensor with the RIE value, or ``0.0`` when the query
            contains no actives.
        """
        preds, target = _check_retrieval_functional_inputs(preds, target)

        n_total = target.size(0)
        n_actives = target.sum()
        if n_actives == 0:
            return torch.tensor(0.0, device=preds.device)

        # Fraction of actives among all items of this query.
        r_a = n_actives / n_total

        # calc_rie raises its base to (-rank / N), so the base has to be
        # e**(+alpha) for early ranks to receive the largest weight. Create it
        # on the same device as the inputs to avoid mixing devices.
        exp_a = torch.exp(torch.tensor(float(self.alpha), device=preds.device))

        # Stable sort keeps the original order among tied scores.
        order = torch.argsort(preds, descending=True, stable=True)
        # 1-based positions of the actives in the ranked list.
        active_ranks = target[order].nonzero() + 1

        return calc_rie(n_total, active_ranks, r_a, exp_a)

    def plot(self, val=None, ax=None):
        """Plot metric value(s) by delegating to ``self._plot(val, ax)``."""
        return self._plot(val, ax)