ConfusionMatrix `normalize=True` fix (#3587)
Browse files- utils/metrics.py +3 -4
utils/metrics.py
CHANGED
@@ -161,9 +161,8 @@ class ConfusionMatrix:
|
|
161 |
def plot(self, normalize=True, save_dir='', names=()):
|
162 |
try:
|
163 |
import seaborn as sn
|
164 |
-
|
165 |
-
if normalize
|
166 |
-
array = self.matrix / (self.matrix.sum(0).reshape(1, self.nc + 1) + 1E-6) # normalize columns
|
167 |
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
|
168 |
|
169 |
fig = plt.figure(figsize=(12, 9), tight_layout=True)
|
@@ -178,7 +177,7 @@ class ConfusionMatrix:
|
|
178 |
fig.axes[0].set_ylabel('Predicted')
|
179 |
fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
|
180 |
except Exception as e:
|
181 |
-
|
182 |
|
183 |
def print(self):
|
184 |
for i in range(self.nc + 1):
|
|
|
161 |
def plot(self, normalize=True, save_dir='', names=()):
|
162 |
try:
|
163 |
import seaborn as sn
|
164 |
+
|
165 |
+
array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-6) if normalize else 1) # normalize columns
|
|
|
166 |
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
|
167 |
|
168 |
fig = plt.figure(figsize=(12, 9), tight_layout=True)
|
|
|
177 |
fig.axes[0].set_ylabel('Predicted')
|
178 |
fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
|
179 |
except Exception as e:
|
180 |
+
print(f'WARNING: ConfusionMatrix plot failure: {e}')
|
181 |
|
182 |
def print(self):
|
183 |
for i in range(self.nc + 1):
|