# Copied from a Hugging Face Space page (Space status: Running on CPU Upgrade).
# Original file size: 1,063 bytes; commit c0ec7e6.
import math
from fractions import Fraction
from typing import Any

from torch import Tensor, topk
from torchmetrics.retrieval.base import RetrievalMetric
from torchmetrics.utilities.checks import _check_retrieval_functional_inputs
class HitRate(RetrievalMetric):
    """Computes hit rate for virtual screening.

    For the items passed to ``_metric``, ranks them by ``preds`` and keeps
    the top ``ceil(N * alpha)`` of them, where ``N`` is the number of items.
    The score is the sum of ``target`` over the kept items divided by how
    many items were kept, i.e. the share of hits within the top ``alpha``
    fraction of the ranking.

    Splitting inputs into queries and combining the per-query scores is left
    to the torchmetrics ``RetrievalMetric`` base class; ``_metric`` presumably
    sees one query at a time (confirm against the installed torchmetrics).

    Args:
        alpha: Fraction of the ranked items to select, in ``(0, 1]``.
        kwargs: Forwarded unchanged to ``RetrievalMetric.__init__`` (for
            example ``empty_target_action`` or ``ignore_index``).

    Raises:
        ValueError: If ``alpha`` is not in ``(0, 1]``; NaN is rejected too.
    """

    # Metadata flags declared for the torchmetrics base class.
    is_differentiable: bool = False
    higher_is_better: bool = True
    full_state_update: bool = False

    def __init__(
        self,
        alpha: float = 0.01,
        **kwargs: Any,
    ) -> None:
        # Written as ``not 0 < alpha <= 1`` so that NaN is rejected as well:
        # every comparison with NaN is False, so ``alpha <= 0 or alpha > 1``
        # would let it through and fail later inside ``math.ceil``.
        if not 0 < alpha <= 1:
            raise ValueError(f"Argument ``alpha`` has to be in interval (0, 1] but got {alpha}")
        super().__init__(**kwargs)
        self.alpha = alpha

    @staticmethod
    def _num_sampled(n_total: int, alpha: float) -> int:
        """Return ``ceil(n_total * alpha)`` without binary-float artefacts.

        Plain float arithmetic gives ``100 * 0.07 == 7.000000000000001``, which
        ``math.ceil`` would turn into 8. Converting ``alpha`` through its
        shortest ``repr`` into an exact ``Fraction`` keeps the product exact
        for the decimal value the caller wrote.

        Args:
            n_total: Number of ranked items.
            alpha: Selection fraction, already validated to lie in ``(0, 1]``.

        Returns:
            The number of top-ranked items to keep.
        """
        return math.ceil(Fraction(repr(float(alpha))) * n_total)

    def _metric(self, preds: Tensor, target: Tensor) -> Tensor:
        """Return the hit rate within the top ``alpha`` fraction of ``preds``.

        Args:
            preds: Scores used to rank the items.
            target: Relevance labels aligned with ``preds``.

        Returns:
            ``target`` summed over the ``ceil(N * alpha)`` highest-scoring
            items, divided by that count.
        """
        # The torchmetrics helper validates the pair and returns the tensors
        # to use from here on.
        preds, target = _check_retrieval_functional_inputs(preds, target)
        n_total = target.size(0)
        n_sampled = self._num_sampled(n_total, self.alpha)
        # NOTE: which items win ties in ``preds`` is up to ``torch.topk``.
        _, idx = topk(preds, n_sampled)
        hits_sampled = target[idx].sum()
        return hits_sampled / n_sampled

    def plot(self, val=None, ax=None):
        """Delegate plotting to the base class ``_plot`` helper."""
        return self._plot(val, ax)