from typing import Dict

import evaluate
import datasets
from torch import Tensor, LongTensor, amax
from torchmetrics.functional.classification.calibration_error import (
    binary_calibration_error,
    multiclass_calibration_error,
)


_CITATION = """\
@InProceedings{huggingface:ece,
  title = {Expected calibration error (ECE)},
  authors = {Nathan Fradet},
  year = {2023}
}
"""

_DESCRIPTION = """\
This metric computes the expected calibration error (ECE).
It directly calls the torchmetrics package:
https://torchmetrics.readthedocs.io/en/stable/classification/calibration_error.html
"""

_KWARGS_DESCRIPTION = """
Computes the expected calibration error (ECE) of predictions given references.
Args:
    predictions: list of predictions to score, of shape (N,C,...) if multiclass, or (N,...) if binary.
    references: list of references, one per prediction, of shape (N,...).
Returns:
    ece: expected calibration error
Examples:
    >>> import numpy as np
    >>> ece = evaluate.load("Natooz/ece")
    >>> results = ece.compute(
    ...     predictions=np.array([[0.25, 0.20, 0.55],
    ...                           [0.55, 0.05, 0.40],
    ...                           [0.10, 0.30, 0.60],
    ...                           [0.90, 0.05, 0.05]]),
    ...     references=np.array([2, 1, 1, 0]),
    ...     num_classes=3,
    ...     n_bins=3,
    ...     norm="l1",
    ... )
    >>> print(results)
    {'ece': 0.2000}
"""


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class ECE(evaluate.Metric):
    """
    Proxy to the binary and multiclass calibration error (ECE) metrics of the
    torchmetrics package:
    https://torchmetrics.readthedocs.io/en/stable/classification/calibration_error.html
    """

    def _info(self):
        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Sequence(datasets.Value("float32")),
                    "references": datasets.Value("int64"),
                }
            ),
            homepage="https://huggingface.co/spaces/Natooz/ece",
            codebase_urls=[
                "https://github.com/Lightning-AI/torchmetrics/blob/v0.11.4/src/torchmetrics/classification/calibration_error.py"
            ],
            reference_urls=[
                "https://torchmetrics.readthedocs.io/en/stable/classification/calibration_error.html"
            ],
        )

    def _compute(self, predictions=None, references=None, **kwargs) -> Dict[str, float]:
        """Returns the ECE.

        See the torchmetrics documentation for more information on the arguments to pass:
        https://torchmetrics.readthedocs.io/en/stable/classification/calibration_error.html
        predictions: (N,C,...) if multiclass, or (N,...) if binary
        references: (N,...)

        If "num_classes" is not provided in a multiclass setting, the maximum label index
        plus one will be used as "num_classes".
        """
        # Convert the inputs to torch tensors, as expected by torchmetrics.
        predictions = Tensor(predictions)
        references = LongTensor(references)

        # Choose between the binary and multiclass cases. When "num_classes" is not
        # given, infer it from the highest label found in the references.
        binary = True
        if "num_classes" not in kwargs:
            max_label = int(amax(references, list(range(references.dim()))))
            if max_label > 1:
                # Labels range from 0 to max_label, hence max_label + 1 classes.
                kwargs["num_classes"] = max_label + 1
                binary = False
        elif kwargs["num_classes"] > 1:
            binary = False

        # Dispatch to the corresponding torchmetrics functional.
        if binary:
            ece = binary_calibration_error(predictions, references, **kwargs)
        else:
            ece = multiclass_calibration_error(predictions, references, **kwargs)
        return {
            "ece": float(ece),
        }
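

# Optional, minimal smoke test: a sketch of the two torchmetrics functionals this
# wrapper dispatches to, assuming torch and torchmetrics are installed. It is not
# part of the metric module itself, and the tensors below are made-up example data
# chosen only to illustrate the expected input shapes.
if __name__ == "__main__":
    # Binary case: predictions are probabilities of the positive class, shape (N,).
    binary_preds = Tensor([0.25, 0.55, 0.60, 0.90])
    binary_refs = LongTensor([0, 1, 1, 1])
    print(binary_calibration_error(binary_preds, binary_refs, n_bins=3, norm="l1"))

    # Multiclass case: predictions are per-class probabilities, shape (N, C).
    mc_preds = Tensor(
        [
            [0.25, 0.20, 0.55],
            [0.55, 0.05, 0.40],
            [0.10, 0.30, 0.60],
            [0.90, 0.05, 0.05],
        ]
    )
    mc_refs = LongTensor([2, 1, 1, 0])
    print(multiclass_calibration_error(mc_preds, mc_refs, num_classes=3, n_bins=3, norm="l1"))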