code style update
Browse files
ece.py
CHANGED
@@ -17,7 +17,10 @@ from typing import Dict
|
|
17 |
import evaluate
|
18 |
import datasets
|
19 |
from torch import from_numpy, amax
|
20 |
-
from torchmetrics.functional.classification.calibration_error import
|
|
|
|
|
|
|
21 |
from numpy import ndarray
|
22 |
|
23 |
|
@@ -75,15 +78,21 @@ class ECE(evaluate.Metric):
|
|
75 |
citation=_CITATION,
|
76 |
inputs_description=_KWARGS_DESCRIPTION,
|
77 |
# This defines the format of each prediction and reference
|
78 |
-
features=datasets.Features(
|
79 |
-
|
80 |
-
|
81 |
-
|
|
|
|
|
82 |
# Homepage of the module for documentation
|
83 |
homepage="https://huggingface.co/spaces/Natooz/ece",
|
84 |
# Additional links to the codebase or references
|
85 |
-
codebase_urls=[
|
86 |
-
|
|
|
|
|
|
|
|
|
87 |
)
|
88 |
|
89 |
def _compute(self, predictions=None, references=None, **kwargs) -> Dict[str, float]:
|
|
|
17 |
import evaluate
|
18 |
import datasets
|
19 |
from torch import from_numpy, amax
|
20 |
+
from torchmetrics.functional.classification.calibration_error import (
|
21 |
+
binary_calibration_error,
|
22 |
+
multiclass_calibration_error,
|
23 |
+
)
|
24 |
from numpy import ndarray
|
25 |
|
26 |
|
|
|
78 |
citation=_CITATION,
|
79 |
inputs_description=_KWARGS_DESCRIPTION,
|
80 |
# This defines the format of each prediction and reference
|
81 |
+
features=datasets.Features(
|
82 |
+
{
|
83 |
+
"predictions": datasets.Sequence(datasets.Value("float32")),
|
84 |
+
"references": datasets.Value("int64"),
|
85 |
+
}
|
86 |
+
),
|
87 |
# Homepage of the module for documentation
|
88 |
homepage="https://huggingface.co/spaces/Natooz/ece",
|
89 |
# Additional links to the codebase or references
|
90 |
+
codebase_urls=[
|
91 |
+
"https://github.com/Lightning-AI/torchmetrics/blob/v0.11.4/src/torchmetrics/classification/calibration_error.py"
|
92 |
+
],
|
93 |
+
reference_urls=[
|
94 |
+
"https://torchmetrics.readthedocs.io/en/stable/classification/calibration_error.html"
|
95 |
+
],
|
96 |
)
|
97 |
|
98 |
def _compute(self, predictions=None, references=None, **kwargs) -> Dict[str, float]:
|