# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict

import datasets
import evaluate
from numpy import ndarray
from torch import amax, as_tensor, from_numpy
from torchmetrics.functional.classification.calibration_error import (
    binary_calibration_error,
    multiclass_calibration_error,
)


# BibTeX entry exposed to the evaluate hub through MetricInfo(citation=...).
_CITATION = (
    "@InProceedings{huggingface:ece,\n"
    "title = {Expected calibration error (ECE)},\n"
    "authors={Nathan Fradet},\n"
    "year={2023}\n"
    "}\n"
)

# Short description shown on the metric's hub page and prepended to the class docstring.
_DESCRIPTION = """\
This metric computes the expected calibration error (ECE).
It directly calls the torchmetrics package:
https://torchmetrics.readthedocs.io/en/stable/classification/calibration_error.html
"""


# Usage documentation for ECE.compute(), shown on the hub and prepended to the class docstring.
# Example check by hand: the arg-max confidences are 0.55, 0.55, 0.60, 0.90 and the correctness
# values are 0, 0, 1, 1. With 3 bins: |0.5667 - 0.3333| * 3/4 + |0.90 - 1| * 1/4 = 0.2.
_KWARGS_DESCRIPTION = """
Computes the expected calibration error (ECE) of predicted probabilities with respect to reference labels.
Args:
    predictions: list of predicted probabilities. They must have a shape (N,C,...) if multiclass, or (N,...) if binary.
    references: list of reference labels (integers) for each prediction, with a shape (N,...).
    **kwargs: additional arguments forwarded to torchmetrics (e.g. num_classes, n_bins, norm).
Returns:
    ece: expected calibration error
Examples:
    >>> import numpy as np
    >>> ece = evaluate.load("Natooz/ece")
    >>> results = ece.compute(
    ...     references=np.array([0, 1, 2, 0]),
    ...     predictions=np.array([[0.25, 0.20, 0.55],
    ...                           [0.55, 0.05, 0.40],
    ...                           [0.10, 0.30, 0.60],
    ...                           [0.90, 0.05, 0.05]]),
    ...     num_classes=3,
    ...     n_bins=3,
    ...     norm="l1",
    ... )
    >>> print(round(results["ece"], 4))
    0.2
"""


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class ECE(evaluate.Metric):
    """
    Proxy to the binary and multiclass calibration error (ECE) functions of the torchmetrics package:
    https://torchmetrics.readthedocs.io/en/stable/classification/calibration_error.html
    """

    def _info(self):
        """Describe the metric (input features, citation, links) for the evaluate library."""
        return evaluate.MetricInfo(
            # This is the description that will appear on the modules page.
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # This defines the format of each prediction and reference:
            # one sequence of floats per prediction, one integer label per reference.
            features=datasets.Features(
                {
                    "predictions": datasets.Sequence(datasets.Value("float32")),
                    "references": datasets.Value("int64"),
                }
            ),
            # Homepage of the module for documentation
            homepage="https://huggingface.co/spaces/Natooz/ece",
            # Additional links to the codebase or references
            codebase_urls=[
                "https://github.com/Lightning-AI/torchmetrics/blob/v0.11.4/src/torchmetrics/classification/calibration_error.py"
            ],
            reference_urls=[
                "https://torchmetrics.readthedocs.io/en/stable/classification/calibration_error.html"
            ],
        )

    def _compute(self, predictions=None, references=None, **kwargs) -> Dict[str, float]:
        """Return the expected calibration error of ``predictions`` w.r.t. ``references``.

        See the torchmetrics documentation for more information on the arguments to pass:
        https://torchmetrics.readthedocs.io/en/stable/classification/calibration_error.html

        Args:
            predictions: predicted probabilities, (N,C,...) if multiclass or (N,...) if binary.
                Binary predictions given as (N,1,...) are also accepted. Python lists,
                numpy arrays and torch tensors are all accepted.
            references: integer labels, (N,...).
            **kwargs: forwarded to the torchmetrics function (e.g. ``n_bins``, ``norm``,
                ``num_classes``).

        Returns:
            ``{"ece": <float>}``.

        The task is inferred from the shapes: ``predictions`` with exactly one more
        dimension than ``references`` (the class dimension, at position 1) is treated as
        multiclass, otherwise as binary. If "num_classes" is not provided in the multiclass
        case, the size of that class dimension is used.
        """
        # evaluate hands the inputs over as Python lists unless a format is configured;
        # as_tensor also accepts numpy arrays (sharing memory) and tensors (returned as-is).
        predictions = as_tensor(predictions)
        references = as_tensor(references)

        has_class_dim = predictions.dim() > 1 and predictions.dim() == references.dim() + 1
        if has_class_dim and predictions.shape[1] == 1:
            # The Sequence feature forces one list per prediction, so binary
            # probabilities arrive as (N, 1, ...): drop the singleton axis.
            predictions = predictions.squeeze(1)
            has_class_dim = False

        # Compute the calibration
        if has_class_dim:
            # The class dimension gives the exact number of classes, whereas the largest
            # label in the batch is 0-indexed and may miss classes absent from the batch.
            # torchmetrics requires a plain int here, not a tensor.
            kwargs.setdefault("num_classes", int(predictions.shape[1]))
            ece = multiclass_calibration_error(predictions, references, **kwargs)
        else:
            # binary_calibration_error does not accept num_classes; it is redundant here.
            kwargs.pop("num_classes", None)
            ece = binary_calibration_error(predictions, references, **kwargs)
        return {
            "ece": float(ece),
        }