glenn-jocher commited on
Commit
5389300
·
unverified ·
1 Parent(s): a45e472

Fix ConfusionMatrix scale `vmin=0.0` (#6638)

Browse files

Fix attempt for https://github.com/ultralytics/yolov5/issues/6626

Files changed (1) hide show
  1. utils/metrics.py +5 -4
utils/metrics.py CHANGED
@@ -175,15 +175,16 @@ class ConfusionMatrix:
175
  try:
176
  import seaborn as sn
177
 
178
- array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-6) if normalize else 1) # normalize columns
179
  array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
180
 
181
  fig = plt.figure(figsize=(12, 9), tight_layout=True)
182
- sn.set(font_scale=1.0 if self.nc < 50 else 0.8) # for label size
183
- labels = (0 < len(names) < 99) and len(names) == self.nc # apply names to ticklabels
 
184
  with warnings.catch_warnings():
185
  warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
186
- sn.heatmap(array, annot=self.nc < 30, annot_kws={"size": 8}, cmap='Blues', fmt='.2f', square=True,
187
  xticklabels=names + ['background FP'] if labels else "auto",
188
  yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
189
  fig.axes[0].set_xlabel('True')
 
175
  try:
176
  import seaborn as sn
177
 
178
+ array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns
179
  array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
180
 
181
  fig = plt.figure(figsize=(12, 9), tight_layout=True)
182
+ nc, nn = self.nc, len(names) # number of classes, names
183
+ sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size
184
+ labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
185
  with warnings.catch_warnings():
186
  warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
187
+ sn.heatmap(array, annot=nc < 30, annot_kws={"size": 8}, cmap='Blues', fmt='.2f', square=True, vmin=0.0,
188
  xticklabels=names + ['background FP'] if labels else "auto",
189
  yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
190
  fig.axes[0].set_xlabel('True')