# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict

import evaluate
import datasets
from torch import Tensor, LongTensor
from torchmetrics.functional.classification.calibration_error import (
    binary_calibration_error,
    multiclass_calibration_error,
)


_CITATION = """\
@InProceedings{huggingface:ece,
title = {Expected calibration error (ECE)},
author={Nathan Fradet},
year={2023}
}
"""

_DESCRIPTION = """\
This metric computes the expected calibration error (ECE).
It directly calls the torchmetrics package:
https://torchmetrics.readthedocs.io/en/stable/classification/calibration_error.html
"""

_KWARGS_DESCRIPTION = """
Computes the expected calibration error (ECE) of predicted probabilities with respect to reference labels.
Args:
    predictions: list of predictions to score. Each prediction is either a list of C per-class
        probabilities (multiclass, shape (N,C)), or a single probability of the positive class
        (binary, shape (N,)).
    references: list of integer reference labels, one for each prediction, with shape (N,).
    **kwargs: forwarded to torchmetrics' `multiclass_calibration_error` (multiclass) or
        `binary_calibration_error` (binary), e.g. `n_bins` or `norm`. For multiclass inputs,
        `num_classes` defaults to C (the size of the second dimension of `predictions`);
        it must not be given for binary inputs.
Returns:
    ece: expected calibration error
Examples:
    >>> import numpy as np
    >>> ece = evaluate.load("Natooz/ece")
    >>> results = ece.compute(
    ...     references=np.array([0, 1, 2, 0]),
    ...     predictions=np.array([[0.25, 0.20, 0.55],
    ...                           [0.55, 0.05, 0.40],
    ...                           [0.10, 0.30, 0.60],
    ...                           [0.90, 0.05, 0.05]]),
    ...     num_classes=3,
    ...     n_bins=3,
    ...     norm="l1",
    ... )
    >>> print(round(results["ece"], 4))
    0.2
"""


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class ECE(evaluate.Metric):
    """
    Proxy to the binary and multiclass calibration error (ECE) functions of the torchmetrics package:
    https://torchmetrics.readthedocs.io/en/stable/classification/calibration_error.html
    """

    def _info(self):
        return evaluate.MetricInfo(
            # This is the description that will appear on the modules page.
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # Accepted formats of each prediction and reference.
            features=[
                # Multiclass: a vector of per-class probabilities for each sample.
                datasets.Features(
                    {
                        "predictions": datasets.Sequence(datasets.Value("float32")),
                        "references": datasets.Value("int64"),
                    }
                ),
                # Binary: a single probability for each sample.
                datasets.Features(
                    {
                        "predictions": datasets.Value("float32"),
                        "references": datasets.Value("int64"),
                    }
                ),
            ],
            # Homepage of the module for documentation
            homepage="https://huggingface.co/spaces/Natooz/ece",
            # Additional links to the codebase or references
            codebase_urls=[
                "https://github.com/Lightning-AI/torchmetrics/blob/v0.11.4/src/torchmetrics/classification/calibration_error.py"
            ],
            reference_urls=[
                "https://torchmetrics.readthedocs.io/en/stable/classification/calibration_error.html"
            ],
        )

    def _compute(self, predictions=None, references=None, **kwargs) -> Dict[str, float]:
        """Returns the ece.

        See the torchmetrics documentation for more information on the arguments to pass.
        https://torchmetrics.readthedocs.io/en/stable/classification/calibration_error.html

        predictions: (N,C,...) if multiclass or (N,...) if binary
        references: (N,...)
        kwargs: forwarded to ``multiclass_calibration_error`` (multiclass) or
            ``binary_calibration_error`` (binary).

        The task is inferred from the number of dimensions of the inputs:
        predictions with one more dimension than references are treated as multiclass, and if
        "num_classes" is not provided, the size of the second dimension of predictions (C) is
        used. Predictions with as many dimensions as references are treated as binary, in which
        case "num_classes" must not be provided.

        Raises:
            ValueError: if "num_classes" is given for binary inputs, or if the numbers of
                dimensions of predictions and references match neither case.
        """
        # Convert the input
        predictions = Tensor(predictions)
        references = LongTensor(references)

        error_msg = (
            "Expected to have predictions with shape (N,C,...) for multiclass or (N,...) for binary, "
            f"and references with shape (N,...), but got {predictions.shape} and {references.shape}"
        )

        if predictions.dim() == references.dim() + 1:
            # Multiclass: dimension 1 of the predictions holds the per-class scores.
            if "num_classes" not in kwargs:
                kwargs["num_classes"] = int(predictions.shape[1])
            ece = multiclass_calibration_error(predictions, references, **kwargs)
        elif predictions.dim() == references.dim():
            # Binary: num_classes is meaningless here, most likely a caller mistake.
            if "num_classes" in kwargs:
                raise ValueError(
                    "You gave the num_classes argument, with predictions and references having the "
                    "same number of dimensions. " + error_msg
                )
            ece = binary_calibration_error(predictions, references, **kwargs)
        else:
            raise ValueError("Bad input shape. " + error_msg)

        return {
            "ece": float(ece),
        }