# Source: Hugging Face Space `nevikw39/specificity`, file `specificity.py`
# (commit dc76a90, "implement specificity").
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Specificity metric."""
import evaluate
import datasets
from sklearn.metrics import recall_score
# BibTeX entry for scikit-learn; passed as `citation=` to `evaluate.MetricInfo`
# in `Specificity._info`.
_CITATION = """\
@article{scikit-learn, title={Scikit-learn: Machine Learning in {P}ython}, author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V. and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P. and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.}, journal={Journal of Machine Learning Research}, volume={12}, pages={2825--2830}, year={2011}}
"""
# Short user-facing description of the metric; used as the metric description
# and prepended to the class docstring by `add_start_docstrings`.
_DESCRIPTION = """\
Specificity is the fraction of the negative examples that were correctly labeled by the model as negatives. It can be computed with the equation:
Specificity = TN / (TN + FP)
Where TN is the number of true negatives and FP is the number of false positives.
"""
# Argument/return documentation plus doctest-style usage examples; used as the
# metric's inputs description and prepended to the class docstring.
# Blank lines separate each example so that an expected-output line is never
# followed directly by free text (doctest would read that text as output).
_KWARGS_DESCRIPTION = """
Args:
- **predictions** (`list` of `int`): The predicted labels.
- **references** (`list` of `int`): The ground truth labels.
- **labels** (`list` of `int`): The set of labels to include when `average` is not set to `binary`, and their order when average is `None`. Labels present in the data can be excluded in this input, for example to calculate a multiclass average ignoring a majority negative class, while labels not present in the data will result in 0 components in a macro average. For multilabel targets, labels are column indices. By default, all labels in y_true and y_pred are used in sorted order. Defaults to None.
- **neg_label** (`int`): The class label to use as the 'negative class' when calculating the specificity. Defaults to `0`.
- **average** (`string`): This parameter is required for multiclass/multilabel targets. If None, the scores for each class are returned. Otherwise, this determines the type of averaging performed on the data. Defaults to `'binary'`.
    - `'binary'`: Only report results for the class specified by `neg_label`. This is applicable only if the target labels and predictions are binary.
    - `'micro'`: Calculate metrics globally by counting the total true negatives, false positives, and false negatives.
    - `'macro'`: Calculate metrics for each label, and find their unweighted mean. This does not take label imbalance into account.
    - `'weighted'`: Calculate metrics for each label, and find their average weighted by support (the number of true instances for each label). This alters `'macro'` to account for label imbalance. Note that it can result in an F-score that is not between sensitivity and specificity.
    - `'samples'`: Calculate metrics for each instance, and find their average (only meaningful for multilabel classification).
- **sample_weight** (`list` of `float`): Sample weights. Defaults to `None`.
- **zero_division** (`'warn'`, `0`, or `1`): Sets the value to return when there is a zero division. Defaults to `'warn'`.
    - `'warn'`: If there is a zero division, the return value is `0`, but warnings are also raised.
    - `0`: If there is a zero division, the return value is `0`.
    - `1`: If there is a zero division, the return value is `1`.

Returns:
- **specificity** (`float`, or `array` of `float`): Either the general specificity score, or the specificity scores for individual classes, depending on the values input to `labels` and `average`. Minimum possible value is 0. Maximum possible value is 1. A higher specificity means that more of the negative examples have been labeled correctly. Therefore, a higher specificity is generally considered better.

Examples:

Example 1-A simple example with some errors
>>> specificity_metric = evaluate.load('nevikw39/specificity')
>>> results = specificity_metric.compute(references=[0, 0, 1, 1, 1], predictions=[0, 1, 0, 1, 1])
>>> print(results)
{'specificity': 0.5}

Example 2-The same example as Example 1, but with `neg_label=1` instead of the default `neg_label=0`.
>>> specificity_metric = evaluate.load('nevikw39/specificity')
>>> results = specificity_metric.compute(references=[0, 0, 1, 1, 1], predictions=[0, 1, 0, 1, 1], neg_label=1)
>>> print(results)
{'specificity': 0.6666666666666666}

Example 3-The same example as Example 1, but with `sample_weight` included.
>>> specificity_metric = evaluate.load('nevikw39/specificity')
>>> sample_weight = [0.9, 0.2, 0.9, 0.3, 0.8]
>>> results = specificity_metric.compute(references=[0, 0, 1, 1, 1], predictions=[0, 1, 0, 1, 1], sample_weight=sample_weight)
>>> print(results)
{'specificity': 0.8181818181818181}

Example 4-A multiclass example, using different averages.
>>> specificity_metric = evaluate.load('nevikw39/specificity')
>>> predictions = [0, 2, 1, 0, 0, 1]
>>> references = [0, 1, 2, 0, 1, 2]
>>> results = specificity_metric.compute(predictions=predictions, references=references, average='macro')
>>> print(results)
{'specificity': 0.3333333333333333}
>>> results = specificity_metric.compute(predictions=predictions, references=references, average='micro')
>>> print(results)
{'specificity': 0.3333333333333333}
>>> results = specificity_metric.compute(predictions=predictions, references=references, average='weighted')
>>> print(results)
{'specificity': 0.3333333333333333}
>>> results = specificity_metric.compute(predictions=predictions, references=references, average=None)
>>> print(results)
{'specificity': array([1., 0., 0.])}
"""
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Specificity(evaluate.Metric):
    """Specificity metric."""

    def _info(self):
        """Return the `evaluate.MetricInfo` describing this metric.

        The expected input features depend on the configuration name: the
        "multilabel" configuration takes a sequence of int32 labels per
        example, every other configuration takes a single int32 label.
        """
        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Sequence(datasets.Value("int32")),
                    "references": datasets.Sequence(datasets.Value("int32")),
                }
                if self.config_name == "multilabel"
                else {
                    "predictions": datasets.Value("int32"),
                    "references": datasets.Value("int32"),
                }
            ),
            reference_urls=[
                "https://scikit-learn.org/stable/modules/generated/sklearn.metrics.recall_score.html"
            ],
        )

    # The parameters `predictions` and `references` are deliberately left
    # unannotated: an annotation of the form
    # `datasets.Sequence(...) | datasets.Value(...)` is an ordinary expression
    # applying `|` to two feature *instances*, which is evaluated when the
    # `def` statement runs (unless annotations are deferred) and fails with
    # TypeError because those objects do not implement `__or__`.
    def _compute(
        self,
        predictions,
        references,
        labels=None,
        neg_label=0,
        average="binary",
        sample_weight=None,
        zero_division="warn",
    ):
        """Compute specificity as the recall of the negative class.

        Delegates to `sklearn.metrics.recall_score`, passing `neg_label` as
        `pos_label`, so that in the binary case the result is TN / (TN + FP).

        Args:
            predictions: Predicted labels (one int per example, or one
                sequence of ints per example in the "multilabel" config).
            references: Ground-truth labels, same layout as `predictions`.
            labels: Forwarded to `recall_score` as `labels`.
            neg_label: Label treated as the negative class; forwarded to
                `recall_score` as `pos_label`.
            average: Forwarded to `recall_score` as `average`.
            sample_weight: Forwarded to `recall_score` as `sample_weight`.
            zero_division: Forwarded to `recall_score` as `zero_division`.

        Returns:
            A dict with the single key "specificity". The value is a Python
            `float` when the score is a single number, otherwise the array
            returned by `recall_score` (e.g. with `average=None`).

        NOTE(review): scikit-learn documents `pos_label` as applying only to
        `average='binary'`, so for the other averaging modes the value
        returned here is per-class recall rather than a one-vs-rest
        TN / (TN + FP) -- confirm that this is the intended definition.
        """
        score = recall_score(
            references,
            predictions,
            labels=labels,
            pos_label=neg_label,
            average=average,
            sample_weight=sample_weight,
            zero_division=zero_division,
        )
        # `recall_score` may hand back an ndarray, a NumPy scalar, or a plain
        # Python float depending on `average` and the scikit-learn version.
        # Objects without a `.size` attribute are scalars, so treat them as
        # size 1 instead of raising AttributeError.
        if getattr(score, "size", 1) == 1:
            return {"specificity": float(score)}
        return {"specificity": score}