File size: 1,063 Bytes
c0ec7e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import math

from torch import Tensor, topk
from torchmetrics.retrieval.base import RetrievalMetric
from torchmetrics.utilities.checks import _check_retrieval_functional_inputs


class HitRate(RetrievalMetric):
    """Computes hit rate for virtual screening.

    The items of a query are ranked by ``preds`` in descending order. The top
    ``ceil(alpha * N)`` of them are "sampled", where ``N`` is the number of
    items passed to ``_metric``. The score is the number of positive targets
    among the sampled items divided by the number of sampled items.

    Args:
        alpha: Fraction of items to sample, in the interval ``(0, 1]``.
            Defaults to ``0.01`` (the top 1%).
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``RetrievalMetric.__init__`` (for example
            ``empty_target_action``). See the installed torchmetrics version
            for the accepted options.

    Raises:
        ValueError: If ``alpha`` is not in ``(0, 1]``, including NaN.
    """

    # Metric metadata flags read by the torchmetrics ``Metric`` machinery.
    is_differentiable: bool = False
    higher_is_better: bool = True
    full_state_update: bool = False

    def __init__(
            self,
            alpha: float = 0.01,
            **kwargs,
    ):
        super().__init__(**kwargs)
        # Written as a chained comparison so that NaN (for which every
        # comparison is False) is rejected here instead of failing later in
        # ``math.ceil`` during compute.
        if not 0 < alpha <= 1:
            raise ValueError(f"Argument ``alpha`` has to be in interval (0, 1] but got {alpha}")
        self.alpha = alpha

    def _metric(self, preds: Tensor, target: Tensor) -> Tensor:
        """Return the fraction of positives among the top-``alpha`` predictions.

        Args:
            preds: Predicted scores for the items being evaluated.
            target: Ground-truth labels aligned element-wise with ``preds``.

        Returns:
            A scalar tensor equal to ``hits / n_sampled``.

        Raises:
            ValueError: Propagated from ``_check_retrieval_functional_inputs``
                when the inputs fail its shape/type validation.
        """
        preds, target = _check_retrieval_functional_inputs(preds, target)

        n_total = target.size(0)
        # Round away floating-point noise before taking the ceiling.
        # ``100 * 0.07`` evaluates to ``7.000000000000001``, which would
        # otherwise be ceiled to 8 instead of 7. ``max(1, ...)`` keeps at
        # least one sample for any positive ``alpha``.
        n_sampled = max(1, math.ceil(round(n_total * self.alpha, 9)))
        _, idx = topk(preds, n_sampled)
        hits_sampled = target[idx].sum()

        # True division, so an integer hit count yields a fractional score.
        return hits_sampled / n_sampled

    def plot(self, val=None, ax=None):
        """Plot metric value(s) by delegating to the inherited ``_plot`` helper.

        Args:
            val: Value(s) to plot. Passed through unchanged to ``_plot``.
            ax: Optional matplotlib axis to draw on. Passed through unchanged.

        Returns:
            Whatever ``self._plot(val, ax)`` returns.
        """
        return self._plot(val, ax)