Natooz commited on
Commit
65b297d
·
unverified ·
1 Parent(s): 9db8383

fix input format conversion

Browse files
Files changed (1) hide show
  1. ece.py +12 -7
ece.py CHANGED
@@ -21,7 +21,6 @@ from torchmetrics.functional.classification.calibration_error import (
21
  binary_calibration_error,
22
  multiclass_calibration_error,
23
  )
24
- from numpy import ndarray
25
 
26
 
27
  _CITATION = """\
@@ -109,15 +108,21 @@ class ECE(evaluate.Metric):
109
  predictions = Tensor(predictions)
110
  references = LongTensor(references)
111
 
112
- max_label = amax(references, list(range(references.dim())))
113
- if max_label > 1 and "num_classes" not in kwargs:
114
- kwargs["num_classes"] = max_label
 
 
 
 
 
 
115
 
116
  # Compute the calibration
117
- if max_label > 1:
118
- ece = multiclass_calibration_error(predictions, references, **kwargs)
119
- else:
120
  ece = binary_calibration_error(predictions, references, **kwargs)
 
 
121
  return {
122
  "ece": float(ece),
123
  }
 
21
  binary_calibration_error,
22
  multiclass_calibration_error,
23
  )
 
24
 
25
 
26
  _CITATION = """\
 
108
  predictions = Tensor(predictions)
109
  references = LongTensor(references)
110
 
111
+ # Determine number of classes / binary or multiclass
112
+ binary = True
113
+ if "num_classes" not in kwargs:
114
+ max_label = int(amax(references, list(range(references.dim()))))
115
+ if max_label > 1:
116
+ kwargs["num_classes"] = max_label
117
+ binary = False
118
+ elif kwargs["num_classes"] > 1:
119
+ binary = False
120
 
121
  # Compute the calibration
122
+ if binary:
 
 
123
  ece = binary_calibration_error(predictions, references, **kwargs)
124
+ else:
125
+ ece = multiclass_calibration_error(predictions, references, **kwargs)
126
  return {
127
  "ece": float(ece),
128
  }