libokj commited on
Commit
ba64d48
·
verified ·
1 Parent(s): 7471e37

Delete deepscreen/models/metrics

Browse files
deepscreen/models/metrics/__init__.py DELETED
File without changes
deepscreen/models/metrics/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (179 Bytes)
 
deepscreen/models/metrics/__pycache__/__init__.cpython-39.pyc DELETED
Binary file (171 Bytes)
 
deepscreen/models/metrics/__pycache__/bedroc.cpython-311.pyc DELETED
Binary file (2.87 kB)
 
deepscreen/models/metrics/__pycache__/bedroc.cpython-39.pyc DELETED
Binary file (1.58 kB)
 
deepscreen/models/metrics/__pycache__/ci.cpython-311.pyc DELETED
Binary file (3.22 kB)
 
deepscreen/models/metrics/__pycache__/ef.cpython-311.pyc DELETED
Binary file (2.38 kB)
 
deepscreen/models/metrics/__pycache__/hit_rate.cpython-311.pyc DELETED
Binary file (2.37 kB)
 
deepscreen/models/metrics/__pycache__/hit_rate.cpython-39.pyc DELETED
Binary file (1.42 kB)
 
deepscreen/models/metrics/__pycache__/rie.cpython-311.pyc DELETED
Binary file (2.82 kB)
 
deepscreen/models/metrics/__pycache__/rie.cpython-39.pyc DELETED
Binary file (1.58 kB)
 
deepscreen/models/metrics/__pycache__/sensitivity.cpython-311.pyc DELETED
Binary file (17.6 kB)
 
deepscreen/models/metrics/bedroc.py DELETED
@@ -1,45 +0,0 @@
1
- import torch
2
- from torch import Tensor
3
- from torchmetrics.retrieval.base import RetrievalMetric
4
- from torchmetrics.utilities.checks import _check_retrieval_functional_inputs
5
-
6
- from deepscreen.models.metrics.rie import calc_rie
7
-
8
-
9
- class BEDROC(RetrievalMetric):
10
- is_differentiable: bool = False
11
- higher_is_better: bool = True
12
- full_state_update: bool = False
13
-
14
- def __init__(
15
- self,
16
- alpha: float = 80.5,
17
- ):
18
- super().__init__()
19
- self.alpha = alpha
20
-
21
- def _metric(self, preds: Tensor, target: Tensor) -> Tensor:
22
- preds, target = _check_retrieval_functional_inputs(preds, target)
23
-
24
- n_total = target.size(0)
25
- n_actives = target.sum()
26
-
27
- if n_actives == 0:
28
- return torch.tensor(0.0, device=preds.device)
29
- elif n_actives == n_total:
30
- return torch.tensor(1.0, device=preds.device)
31
-
32
- r_a = n_actives / n_total
33
- exp_a = torch.exp(torch.tensor(self.alpha))
34
-
35
- idx = torch.argsort(preds, descending=True, stable=True)
36
- active_ranks = torch.take(target, idx).nonzero() + 1
37
-
38
- rie = calc_rie(n_total, active_ranks, r_a, exp_a)
39
- rie_min = (1 - exp_a ** r_a) / (r_a * (1 - exp_a))
40
- rie_max = (1 - exp_a ** (-r_a)) / (r_a * (1 - exp_a ** (-1)))
41
-
42
- return (rie - rie_min) / (rie_max - rie_min)
43
-
44
- def plot(self, val=None, ax=None):
45
- return self._plot(val, ax)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deepscreen/models/metrics/ci.py DELETED
@@ -1,39 +0,0 @@
1
- import torch
2
- from torchmetrics import Metric
3
- from torchmetrics.utilities.checks import _check_same_shape
4
- from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
5
-
6
- if not _MATPLOTLIB_AVAILABLE:
7
- __doctest_skip__ = ["ConcordanceIndex.plot"]
8
-
9
-
10
- class ConcordanceIndex(Metric):
11
- is_differentiable: bool = False
12
- higher_is_better: bool = True
13
- full_state_update: bool = False
14
- plot_lower_bound: float = 0.5
15
- plot_upper_bound: float = 1.0
16
-
17
- def __init__(self, dist_sync_on_step=False):
18
- super().__init__(dist_sync_on_step=dist_sync_on_step)
19
-
20
- self.add_state("num_concordant", default=torch.tensor(0), dist_reduce_fx="sum")
21
- self.add_state("num_valid", default=torch.tensor(0), dist_reduce_fx="sum")
22
-
23
- def update(self, preds: torch.Tensor, target: torch.Tensor):
24
- _check_same_shape(preds, target)
25
-
26
- g = preds.unsqueeze(-1) - preds
27
- g = (g == 0) * 0.5 + (g > 0)
28
-
29
- f = (target.unsqueeze(-1) - target) > 0
30
- f = torch.tril(f, diagonal=0)
31
-
32
- self.num_concordant += torch.sum(torch.mul(g, f)).long()
33
- self.num_valid += torch.sum(f).long()
34
-
35
- def compute(self):
36
- return torch.where(self.num_valid == 0, 0.0, self.num_concordant / self.num_valid)
37
-
38
- def plot(self, val=None, ax=None):
39
- return self._plot(val, ax)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deepscreen/models/metrics/ef.py DELETED
@@ -1,34 +0,0 @@
1
- import math
2
-
3
- from torch import Tensor, topk
4
- from torchmetrics.retrieval.base import RetrievalMetric
5
- from torchmetrics.utilities.checks import _check_retrieval_functional_inputs
6
-
7
-
8
- class EnrichmentFactor(RetrievalMetric):
9
- is_differentiable: bool = False
10
- higher_is_better: bool = True
11
- full_state_update: bool = False
12
-
13
- def __init__(
14
- self,
15
- alpha: float,
16
- ):
17
- super().__init__()
18
- if alpha <= 0 or alpha > 1:
19
- raise ValueError(f"Argument ``alpha`` has to be in interval (0, 1] but got {alpha}")
20
- self.alpha = alpha
21
-
22
- def _metric(self, preds: Tensor, target: Tensor) -> Tensor:
23
- preds, target = _check_retrieval_functional_inputs(preds, target)
24
-
25
- n_total = target.size(0)
26
- n_sampled = math.ceil(n_total * self.alpha)
27
- _, idx = topk(preds, n_sampled)
28
- hits_sampled = target[idx].sum()
29
- hits_total = target.sum()
30
-
31
- return hits_sampled / (hits_total * self.alpha)
32
-
33
- def plot(self, val=None, ax=None):
34
- return self._plot(val, ax)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deepscreen/models/metrics/hit_rate.py DELETED
@@ -1,36 +0,0 @@
1
- import math
2
-
3
- from torch import Tensor, topk
4
- from torchmetrics.retrieval.base import RetrievalMetric
5
- from torchmetrics.utilities.checks import _check_retrieval_functional_inputs
6
-
7
-
8
- class HitRate(RetrievalMetric):
9
- """
10
- Computes hit rate for virtual screening.
11
- """
12
- is_differentiable: bool = False
13
- higher_is_better: bool = True
14
- full_state_update: bool = False
15
-
16
- def __init__(
17
- self,
18
- alpha: float = 0.01,
19
- ):
20
- super().__init__()
21
- if alpha <= 0 or alpha > 1:
22
- raise ValueError(f"Argument ``alpha`` has to be in interval (0, 1] but got {alpha}")
23
- self.alpha = alpha
24
-
25
- def _metric(self, preds: Tensor, target: Tensor) -> Tensor:
26
- preds, target = _check_retrieval_functional_inputs(preds, target)
27
-
28
- n_total = target.size(0)
29
- n_sampled = math.ceil(n_total * self.alpha)
30
- _, idx = topk(preds, n_sampled)
31
- hits_sampled = target[idx].sum()
32
-
33
- return hits_sampled / n_sampled
34
-
35
- def plot(self, val=None, ax=None):
36
- return self._plot(val, ax)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deepscreen/models/metrics/rie.py DELETED
@@ -1,44 +0,0 @@
1
- import torch
2
- from torch import Tensor
3
- from torchmetrics.retrieval.base import RetrievalMetric
4
- from torchmetrics.utilities.checks import _check_retrieval_functional_inputs
5
-
6
-
7
- def calc_rie(n_total, active_ranks, r_a, exp_a):
8
- numerator = (exp_a ** (- active_ranks / n_total)).sum()
9
- denominator = (1 - exp_a ** (-1)) / (exp_a ** (1 / n_total) - 1)
10
-
11
- return numerator / (r_a * denominator)
12
-
13
-
14
- class RIE(RetrievalMetric):
15
- is_differentiable: bool = False
16
- higher_is_better: bool = True
17
- full_state_update: bool = False
18
-
19
- def __init__(
20
- self,
21
- alpha: float = 80.5,
22
- ):
23
- super().__init__()
24
- self.alpha = alpha
25
-
26
- def _metric(self, preds: Tensor, target: Tensor) -> Tensor:
27
- preds, target = _check_retrieval_functional_inputs(preds, target)
28
-
29
- n_total = target.size(0)
30
- n_actives = target.sum()
31
-
32
- if n_actives == 0:
33
- return torch.tensor(0.0, device=preds.device)
34
-
35
- r_a = n_actives / n_total
36
- exp_a = torch.exp(torch.tensor(-self.alpha))
37
-
38
- idx = torch.argsort(preds, descending=True, stable=True)
39
- active_ranks = torch.take(target, idx).nonzero() + 1
40
-
41
- return calc_rie(n_total, active_ranks, r_a, exp_a)
42
-
43
- def plot(self, val=None, ax=None):
44
- return self._plot(val, ax)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deepscreen/models/metrics/sensitivity.py DELETED
@@ -1,337 +0,0 @@
1
- # Copyright The Lightning team.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from typing import Any, Optional, Sequence, Union
15
-
16
- from torch import Tensor
17
- from torchmetrics.utilities.compute import _safe_divide, _adjust_weights_safe_divide
18
- from typing_extensions import Literal
19
-
20
- from torchmetrics.classification.base import _ClassificationTaskWrapper
21
- from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores
22
- from torchmetrics.metric import Metric
23
- from torchmetrics.utilities.enums import ClassificationTask
24
- from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
25
- from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
26
-
27
- if not _MATPLOTLIB_AVAILABLE:
28
- __doctest_skip__ = ["BinarySensitivity.plot", "MulticlassSensitivity.plot", "MultilabelSensitivity.plot"]
29
-
30
-
31
- class BinarySensitivity(BinaryStatScores):
32
- r"""Compute `Sensitivity`_ for binary tasks.
33
-
34
- .. math:: \text{Sensitivity} = \frac{\text{TN}}{\text{TN} + \text{FP}}
35
-
36
- Where :math:`\text{TN}` and :math:`\text{FP}` represent the number of true negatives and false positives
37
- respectively. The metric is only proper defined when :math:`\text{TN} + \text{FP} \neq 0`. If this case is
38
- encountered a score of 0 is returned.
39
-
40
- As input to ``forward`` and ``update`` the metric accepts the following input:
41
-
42
- - ``preds`` (:class:`~torch.Tensor`): An int or float tensor of shape ``(N, ...)``. If preds is a floating point
43
- tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per
44
- element. Addtionally, we convert to int tensor with thresholding using the value in ``threshold``.
45
- - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``
46
-
47
- As output to ``forward`` and ``compute`` the metric returns the following output:
48
-
49
- - ``bs`` (:class:`~torch.Tensor`): If ``multidim_average`` is set to ``global``, the metric returns a scalar value.
50
- If ``multidim_average`` is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value
51
- per sample.
52
-
53
- Args:
54
- threshold: Threshold for transforming probability to binary {0,1} predictions
55
- multidim_average:
56
- Defines how additionally dimensions ``...`` should be handled. Should be one of the following:
57
-
58
- - ``global``: Additional dimensions are flatted along the batch dimension
59
- - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
60
- The statistics in this case are calculated over the additional dimensions.
61
-
62
- ignore_index:
63
- Specifies a target value that is ignored and does not contribute to the metric calculation
64
- validate_args: bool indicating if input arguments and tensors should be validated for correctness.
65
- Set to ``False`` for faster computations.
66
- """
67
- plot_lower_bound: float = 0.0
68
- plot_upper_bound: float = 1.0
69
-
70
- def compute(self) -> Tensor:
71
- """Compute metric."""
72
- tp, fp, tn, fn = self._final_state()
73
- return _sensitivity_reduce(tp, fp, tn, fn, average="binary", multidim_average=self.multidim_average)
74
-
75
- def plot(
76
- self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
77
- ) -> _PLOT_OUT_TYPE:
78
- """Plot a single or multiple values from the metric.
79
-
80
- Args:
81
- val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
82
- If no value is provided, will automatically call `metric.compute` and plot that result.
83
- ax: An matplotlib axis object. If provided will add plot to that axis
84
-
85
- Returns:
86
- Figure object and Axes object
87
-
88
- Raises:
89
- ModuleNotFoundError:
90
- If `matplotlib` is not installed
91
- """
92
- return self._plot(val, ax)
93
-
94
-
95
- class MulticlassSensitivity(MulticlassStatScores):
96
- r"""Compute `Sensitivity`_ for multiclass tasks.
97
-
98
- .. math:: \text{Sensitivity} = \frac{\text{TN}}{\text{TN} + \text{FP}}
99
-
100
- Where :math:`\text{TN}` and :math:`\text{FP}` represent the number of true negatives and false positives
101
- respectively. The metric is only proper defined when :math:`\text{TN} + \text{FP} \neq 0`. If this case is
102
- encountered for any class, the metric for that class will be set to 0 and the overall metric may therefore be
103
- affected in turn.
104
-
105
- As input to ``forward`` and ``update`` the metric accepts the following input:
106
-
107
- - ``preds`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)`` or float tensor of shape ``(N, C, ..)``.
108
- If preds is a floating point we apply ``torch.argmax`` along the ``C`` dimension to automatically convert
109
- probabilities/logits into an int tensor.
110
- - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``
111
-
112
- As output to ``forward`` and ``compute`` the metric returns the following output:
113
-
114
- - ``mcs`` (:class:`~torch.Tensor`): The returned shape depends on the ``average`` and ``multidim_average``
115
- arguments:
116
-
117
- - If ``multidim_average`` is set to ``global``:
118
-
119
- - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor
120
- - If ``average=None/'none'``, the shape will be ``(C,)``
121
-
122
- - If ``multidim_average`` is set to ``samplewise``:
123
-
124
- - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
125
- - If ``average=None/'none'``, the shape will be ``(N, C)``
126
-
127
- Args:
128
- num_classes: Integer specifing the number of classes
129
- average:
130
- Defines the reduction that is applied over labels. Should be one of the following:
131
-
132
- - ``micro``: Sum statistics over all labels
133
- - ``macro``: Calculate statistics for each label and average them
134
- - ``weighted``: calculates statistics for each label and computes weighted average using their support
135
- - ``"none"`` or ``None``: calculates statistic for each label and applies no reduction
136
-
137
- top_k:
138
- Number of highest probability or logit score predictions considered to find the correct label.
139
- Only works when ``preds`` contain probabilities/logits.
140
- multidim_average:
141
- Defines how additionally dimensions ``...`` should be handled. Should be one of the following:
142
-
143
- - ``global``: Additional dimensions are flatted along the batch dimension
144
- - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
145
- The statistics in this case are calculated over the additional dimensions.
146
-
147
- ignore_index:
148
- Specifies a target value that is ignored and does not contribute to the metric calculation
149
- validate_args: bool indicating if input arguments and tensors should be validated for correctness.
150
- Set to ``False`` for faster computations.
151
- """
152
- plot_lower_bound: float = 0.0
153
- plot_upper_bound: float = 1.0
154
- plot_legend_name: str = "Class"
155
-
156
- def compute(self) -> Tensor:
157
- """Compute metric."""
158
- tp, fp, tn, fn = self._final_state()
159
- return _sensitivity_reduce(tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average)
160
-
161
- def plot(
162
- self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
163
- ) -> _PLOT_OUT_TYPE:
164
- """Plot a single or multiple values from the metric.
165
-
166
- Args:
167
- val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
168
- If no value is provided, will automatically call `metric.compute` and plot that result.
169
- ax: An matplotlib axis object. If provided will add plot to that axis
170
-
171
- Returns:
172
- Figure object and Axes object
173
-
174
- Raises:
175
- ModuleNotFoundError:
176
- If `matplotlib` is not installed
177
- """
178
- return self._plot(val, ax)
179
-
180
-
181
- class MultilabelSensitivity(MultilabelStatScores):
182
- r"""Compute `Sensitivity`_ for multilabel tasks.
183
-
184
- .. math:: \text{Sensitivity} = \frac{\text{TN}}{\text{TN} + \text{FP}}
185
-
186
- Where :math:`\text{TN}` and :math:`\text{FP}` represent the number of true negatives and false positives
187
- respectively. The metric is only proper defined when :math:`\text{TN} + \text{FP} \neq 0`. If this case is
188
- encountered for any label, the metric for that label will be set to 0 and the overall metric may therefore be
189
- affected in turn.
190
-
191
- As input to ``forward`` and ``update`` the metric accepts the following input:
192
-
193
- - ``preds`` (:class:`~torch.Tensor`): An int or float tensor of shape ``(N, C, ...)``. If preds is a floating
194
- point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid
195
- per element. Addtionally, we convert to int tensor with thresholding using the value in ``threshold``.
196
- - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, C, ...)``
197
-
198
-
199
- As output to ``forward`` and ``compute`` the metric returns the following output:
200
-
201
- - ``mls`` (:class:`~torch.Tensor`): The returned shape depends on the ``average`` and ``multidim_average``
202
- arguments:
203
-
204
- - If ``multidim_average`` is set to ``global``
205
-
206
- - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor
207
- - If ``average=None/'none'``, the shape will be ``(C,)``
208
-
209
- - If ``multidim_average`` is set to ``samplewise``
210
-
211
- - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
212
- - If ``average=None/'none'``, the shape will be ``(N, C)``
213
-
214
- Args:
215
- num_labels: Integer specifing the number of labels
216
- threshold: Threshold for transforming probability to binary (0,1) predictions
217
- average:
218
- Defines the reduction that is applied over labels. Should be one of the following:
219
-
220
- - ``micro``: Sum statistics over all labels
221
- - ``macro``: Calculate statistics for each label and average them
222
- - ``weighted``: calculates statistics for each label and computes weighted average using their support
223
- - ``"none"`` or ``None``: calculates statistic for each label and applies no reduction
224
-
225
- multidim_average: Defines how additionally dimensions ``...`` should be handled. Should be one of the following:
226
-
227
- - ``global``: Additional dimensions are flatted along the batch dimension
228
- - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
229
- The statistics in this case are calculated over the additional dimensions.
230
-
231
- ignore_index:
232
- Specifies a target value that is ignored and does not contribute to the metric calculation
233
- validate_args: bool indicating if input arguments and tensors should be validated for correctness.
234
- Set to ``False`` for faster computations.
235
- """
236
- plot_lower_bound: float = 0.0
237
- plot_upper_bound: float = 1.0
238
- plot_legend_name: str = "Label"
239
-
240
- def compute(self) -> Tensor:
241
- """Compute metric."""
242
- tp, fp, tn, fn = self._final_state()
243
- return _sensitivity_reduce(
244
- tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, multilabel=True
245
- )
246
-
247
- def plot(
248
- self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
249
- ) -> _PLOT_OUT_TYPE:
250
- """Plot a single or multiple values from the metric.
251
-
252
- Args:
253
- val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
254
- If no value is provided, will automatically call `metric.compute` and plot that result.
255
- ax: An matplotlib axis object. If provided will add plot to that axis
256
-
257
- Returns:
258
- Figure object and Axes object
259
-
260
- Raises:
261
- ModuleNotFoundError:
262
- If `matplotlib` is not installed
263
- """
264
- return self._plot(val, ax)
265
-
266
-
267
- class Sensitivity(_ClassificationTaskWrapper):
268
- r"""Compute `Sensitivity`_.
269
-
270
- .. math:: \text{Sensitivity} = \frac{\text{TN}}{\text{TN} + \text{FP}}
271
-
272
- Where :math:`\text{TN}` and :math:`\text{FP}` represent the number of true negatives and false positives
273
- respectively. The metric is only proper defined when :math:`\text{TP} + \text{FP} \neq 0`. If this case is
274
- encountered for any class/label, the metric for that class/label will be set to 0 and the overall metric may
275
- therefore be affected in turn.
276
-
277
- This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the
278
- ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``multilabel``. See the documentation of
279
- :class:`~torchmetrics.classification.BinarySensitivity`, :class:`~torchmetrics.classification.MulticlassSensitivity`
280
- and :class:`~torchmetrics.classification.MultilabelSensitivity` for the specific details of each argument influence
281
- and examples.
282
-
283
- Legacy Example:
284
- """
285
-
286
- def __new__( # type: ignore[misc]
287
- cls,
288
- task: Literal["binary", "multiclass", "multilabel"],
289
- threshold: float = 0.5,
290
- num_classes: Optional[int] = None,
291
- num_labels: Optional[int] = None,
292
- average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
293
- multidim_average: Optional[Literal["global", "samplewise"]] = "global",
294
- top_k: Optional[int] = 1,
295
- ignore_index: Optional[int] = None,
296
- validate_args: bool = True,
297
- **kwargs: Any,
298
- ) -> Metric:
299
- """Initialize task metric."""
300
- task = ClassificationTask.from_str(task)
301
- assert multidim_average is not None # noqa: S101 # needed for mypy
302
- kwargs.update(
303
- {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args}
304
- )
305
- if task == ClassificationTask.BINARY:
306
- return BinarySensitivity(threshold, **kwargs)
307
- if task == ClassificationTask.MULTICLASS:
308
- if not isinstance(num_classes, int):
309
- raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
310
- if not isinstance(top_k, int):
311
- raise ValueError(f"`top_k` is expected to be `int` but `{type(top_k)} was passed.`")
312
- return MulticlassSensitivity(num_classes, top_k, average, **kwargs)
313
- if task == ClassificationTask.MULTILABEL:
314
- if not isinstance(num_labels, int):
315
- raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
316
- return MultilabelSensitivity(num_labels, threshold, average, **kwargs)
317
- raise ValueError(f"Task {task} not supported!")
318
-
319
-
320
- def _sensitivity_reduce(
321
- tp: Tensor,
322
- fp: Tensor,
323
- tn: Tensor,
324
- fn: Tensor,
325
- average: Optional[Literal["binary", "micro", "macro", "weighted", "none"]],
326
- multidim_average: Literal["global", "samplewise"] = "global",
327
- multilabel: bool = False,
328
- ) -> Tensor:
329
- if average == "binary":
330
- return _safe_divide(tp, tp + fn)
331
- if average == "micro":
332
- tp = tp.sum(dim=0 if multidim_average == "global" else 1)
333
- fn = fn.sum(dim=0 if multidim_average == "global" else 1)
334
- return _safe_divide(tp, tp + fn)
335
-
336
- sensitivity_score = _safe_divide(tp, tp + fn)
337
- return _adjust_weights_safe_divide(sensitivity_score, average, multilabel, tp, fp, fn)