fix input format conversion
Browse files
ece.py
CHANGED
@@ -21,7 +21,6 @@ from torchmetrics.functional.classification.calibration_error import (
|
|
21 |
binary_calibration_error,
|
22 |
multiclass_calibration_error,
|
23 |
)
|
24 |
-
from numpy import ndarray
|
25 |
|
26 |
|
27 |
_CITATION = """\
|
@@ -109,15 +108,21 @@ class ECE(evaluate.Metric):
|
|
109 |
predictions = Tensor(predictions)
|
110 |
references = LongTensor(references)
|
111 |
|
112 |
-
|
113 |
-
|
114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
|
116 |
# Compute the calibration
|
117 |
-
if
|
118 |
-
ece = multiclass_calibration_error(predictions, references, **kwargs)
|
119 |
-
else:
|
120 |
ece = binary_calibration_error(predictions, references, **kwargs)
|
|
|
|
|
121 |
return {
|
122 |
"ece": float(ece),
|
123 |
}
|
|
|
21 |
binary_calibration_error,
|
22 |
multiclass_calibration_error,
|
23 |
)
|
|
|
24 |
|
25 |
|
26 |
_CITATION = """\
|
|
|
108 |
predictions = Tensor(predictions)
|
109 |
references = LongTensor(references)
|
110 |
|
111 |
+
# Determine number of classes / binary or multiclass
|
112 |
+
binary = True
|
113 |
+
if "num_classes" not in kwargs:
|
114 |
+
max_label = int(amax(references, list(range(references.dim()))))
|
115 |
+
if max_label > 1:
|
116 |
+
kwargs["num_classes"] = max_label
|
117 |
+
binary = False
|
118 |
+
elif kwargs["num_classes"] > 1:
|
119 |
+
binary = False
|
120 |
|
121 |
# Compute the calibration
|
122 |
+
if binary:
|
|
|
|
|
123 |
ece = binary_calibration_error(predictions, references, **kwargs)
|
124 |
+
else:
|
125 |
+
ece = multiclass_calibration_error(predictions, references, **kwargs)
|
126 |
return {
|
127 |
"ece": float(ece),
|
128 |
}
|