# libokj's picture
# Upload 110 files
# c0ec7e6
# raw
# history blame
# 1.3 kB
import torch
from torch import Tensor
from torchmetrics.retrieval.base import RetrievalMetric
from torchmetrics.utilities.checks import _check_retrieval_functional_inputs
def calc_rie(n_total, active_ranks, r_a, exp_a):
    """Compute the Robust Initial Enhancement (RIE) score from ranked actives.

    Each active at 1-based rank ``r`` contributes ``exp_a ** (-r / n_total)``.
    When ``exp_a == exp(alpha)`` this is ``exp(-alpha * r / n_total)``. The
    summed contributions are divided by
    ``r_a * (1 - exp_a ** -1) / (exp_a ** (1 / n_total) - 1)``. That expression
    is ``r_a`` times the same weight summed over every rank ``1..n_total``,
    i.e. the expected sum for randomly placed actives.

    Args:
        n_total: Total number of ranked items (N).
        active_ranks: Tensor of 1-based ranks of the active items.
        r_a: Fraction of items that are active (n_actives / N). If it is 0,
            the result is 0/0, so callers should handle the no-actives case.
        exp_a: Base of the exponential weighting. It should be ``exp(alpha)``
            with ``alpha > 0`` so that earlier ranks weigh more.

    Returns:
        The RIE value (a 0-dim tensor when ``active_ranks`` is a tensor).
    """
    # Per-active exponential weights, summed over all actives.
    weight_sum = (exp_a ** (-active_ranks / n_total)).sum()
    # Closed form of the geometric series sum_{r=1..N} exp_a ** (-r / N).
    random_norm = (1 - exp_a ** (-1)) / (exp_a ** (1 / n_total) - 1)
    return weight_sum / (r_a * random_norm)
class RIE(RetrievalMetric):
    """Robust Initial Enhancement (RIE) retrieval metric.

    Within each query group, items are ranked by descending ``preds``. Every
    active (positive-target) item at 1-based rank ``r`` contributes
    ``exp(-alpha * r / N)``, where ``N`` is the group size. ``calc_rie``
    divides this sum by its expected value for randomly placed actives. A
    random ranking therefore scores about 1 in expectation, and rankings that
    put actives first score higher.

    Args:
        alpha: Positive weighting factor. Larger values focus the metric more
            strongly on the top of the ranking.
        **kwargs: Forwarded unchanged to ``RetrievalMetric.__init__`` (see the
            torchmetrics ``RetrievalMetric`` documentation for the options).

    Raises:
        ValueError: If ``alpha`` is not positive.
    """

    is_differentiable: bool = False
    higher_is_better: bool = True
    full_state_update: bool = False

    def __init__(
        self,
        alpha: float = 80.5,
        **kwargs,
    ):
        # alpha == 0 makes calc_rie's normalizer 0/0. A negative alpha makes the
        # weights grow with rank, i.e. it rewards actives ranked last.
        if not alpha > 0:
            raise ValueError(f"Expected argument `alpha` to be a positive number, but got {alpha}")
        super().__init__(**kwargs)
        self.alpha = alpha

    def _metric(self, preds: Tensor, target: Tensor) -> Tensor:
        """Return the RIE of a single query group.

        Args:
            preds: Prediction scores for the group's items.
            target: Relevance labels for the same items (validated by
                ``_check_retrieval_functional_inputs``).

        Returns:
            A 0-dim tensor holding the RIE, or ``0.0`` when the group has no
            actives.
        """
        preds, target = _check_retrieval_functional_inputs(preds, target)
        n_total = target.size(0)
        n_actives = target.sum()
        if n_actives == 0:
            return torch.tensor(0.0, device=preds.device)
        r_a = n_actives / n_total
        # calc_rie weights each active by exp_a ** (-rank / N), so exp_a must be
        # exp(+alpha). Passing exp(-alpha) would reward actives ranked *last*.
        # NOTE(review): in the default float32 dtype, exp(alpha) overflows for
        # alpha above roughly 88.
        exp_a = torch.exp(torch.tensor(self.alpha, device=preds.device))
        # Stable descending sort, so tied scores keep their input order.
        order = torch.argsort(preds, descending=True, stable=True)
        # 1-based positions of the actives in the sorted list.
        active_ranks = torch.take(target, order).nonzero() + 1
        return calc_rie(n_total, active_ranks, r_a, exp_a)

    def plot(self, val=None, ax=None):
        """Delegate plotting to the base-class ``_plot`` helper."""
        return self._plot(val, ax)