Confusion matrix (#1474)
Browse files* initial commit
* add plotting
* matrix to cpu
* bug fix
* update plot
* update plot
* update plot
* update plot
* update plot
* update plot
* update plot
* update plot
* update plot
* update plot
* update plot
* update plot
* cleanup
* cleanup
* cleanup
* cleanup
* cleanup
* cleanup
* cleanup
* cleanup
* cleanup
* cleanup
* cleanup
* cleanup
* cleanup
* seaborn pandas to requirements.txt
* seaborn pandas to requirements.txt
* update wandb plotting
* remove pandas
* if plots
* if plots
* if plots
* if plots
* if plots
* initial commit
* add plotting
* matrix to cpu
* bug fix
* update plot
* update plot
* update plot
* update plot
* update plot
* update plot
* update plot
* update plot
* update plot
* update plot
* update plot
* update plot
* cleanup
* cleanup
* cleanup
* cleanup
* cleanup
* cleanup
* cleanup
* cleanup
* cleanup
* cleanup
* cleanup
* cleanup
* cleanup
* seaborn pandas to requirements.txt
* seaborn pandas to requirements.txt
* update wandb plotting
* remove pandas
* if plots
* if plots
* if plots
* if plots
* if plots
* Cat apriori to autolabels
* cleanup
- requirements.txt +4 -3
- test.py +10 -5
- train.py +3 -2
- utils/metrics.py +81 -0
@@ -16,8 +16,9 @@ tqdm>=4.41.0
|
|
16 |
# logging -------------------------------------
|
17 |
# wandb
|
18 |
|
19 |
-
#
|
20 |
-
|
|
|
21 |
|
22 |
# export --------------------------------------
|
23 |
# coremltools==4.0
|
@@ -26,4 +27,4 @@ tqdm>=4.41.0
|
|
26 |
|
27 |
# extras --------------------------------------
|
28 |
# thop # FLOPS computation
|
29 |
-
#
|
|
|
16 |
# logging -------------------------------------
|
17 |
# wandb
|
18 |
|
19 |
+
# plotting ------------------------------------
|
20 |
+
seaborn
|
21 |
+
pandas
|
22 |
|
23 |
# export --------------------------------------
|
24 |
# coremltools==4.0
|
|
|
27 |
|
28 |
# extras --------------------------------------
|
29 |
# thop # FLOPS computation
|
30 |
+
# pycocotools>=2.0 # COCO mAP
|
@@ -14,7 +14,7 @@ from utils.datasets import create_dataloader
|
|
14 |
from utils.general import coco80_to_coco91_class, check_dataset, check_file, check_img_size, box_iou, \
|
15 |
non_max_suppression, scale_coords, xyxy2xywh, xywh2xyxy, set_logging, increment_path
|
16 |
from utils.loss import compute_loss
|
17 |
-
from utils.metrics import ap_per_class
|
18 |
from utils.plots import plot_images, output_to_target
|
19 |
from utils.torch_utils import select_device, time_synchronized
|
20 |
|
@@ -89,6 +89,7 @@ def test(data,
|
|
89 |
dataloader = create_dataloader(path, imgsz, batch_size, model.stride.max(), opt, pad=0.5, rect=True)[0]
|
90 |
|
91 |
seen = 0
|
|
|
92 |
names = {k: v for k, v in enumerate(model.names if hasattr(model, 'names') else model.module.names)}
|
93 |
coco91class = coco80_to_coco91_class()
|
94 |
s = ('%20s' + '%12s' * 6) % ('Class', 'Images', 'Targets', 'P', 'R', '[email protected]', '[email protected]:.95')
|
@@ -176,6 +177,8 @@ def test(data,
|
|
176 |
# target boxes
|
177 |
tbox = xywh2xyxy(labels[:, 1:5])
|
178 |
scale_coords(img[si].shape[1:], tbox, shapes[si][0], shapes[si][1]) # native-space labels
|
|
|
|
|
179 |
|
180 |
# Per target class
|
181 |
for cls in torch.unique(tcls_tensor):
|
@@ -218,10 +221,12 @@ def test(data,
|
|
218 |
else:
|
219 |
nt = torch.zeros(1)
|
220 |
|
221 |
-
#
|
222 |
-
if plots
|
223 |
-
|
224 |
-
|
|
|
|
|
225 |
|
226 |
# Print results
|
227 |
pf = '%20s' + '%12.3g' * 6 # print format
|
|
|
14 |
from utils.general import coco80_to_coco91_class, check_dataset, check_file, check_img_size, box_iou, \
|
15 |
non_max_suppression, scale_coords, xyxy2xywh, xywh2xyxy, set_logging, increment_path
|
16 |
from utils.loss import compute_loss
|
17 |
+
from utils.metrics import ap_per_class, ConfusionMatrix
|
18 |
from utils.plots import plot_images, output_to_target
|
19 |
from utils.torch_utils import select_device, time_synchronized
|
20 |
|
|
|
89 |
dataloader = create_dataloader(path, imgsz, batch_size, model.stride.max(), opt, pad=0.5, rect=True)[0]
|
90 |
|
91 |
seen = 0
|
92 |
+
confusion_matrix = ConfusionMatrix(nc=nc)
|
93 |
names = {k: v for k, v in enumerate(model.names if hasattr(model, 'names') else model.module.names)}
|
94 |
coco91class = coco80_to_coco91_class()
|
95 |
s = ('%20s' + '%12s' * 6) % ('Class', 'Images', 'Targets', 'P', 'R', '[email protected]', '[email protected]:.95')
|
|
|
177 |
# target boxes
|
178 |
tbox = xywh2xyxy(labels[:, 1:5])
|
179 |
scale_coords(img[si].shape[1:], tbox, shapes[si][0], shapes[si][1]) # native-space labels
|
180 |
+
if plots:
|
181 |
+
confusion_matrix.process_batch(pred, torch.cat((labels[:, 0:1], tbox), 1))
|
182 |
|
183 |
# Per target class
|
184 |
for cls in torch.unique(tcls_tensor):
|
|
|
221 |
else:
|
222 |
nt = torch.zeros(1)
|
223 |
|
224 |
+
# Plots
|
225 |
+
if plots:
|
226 |
+
confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
|
227 |
+
if wandb and wandb.run:
|
228 |
+
wandb.log({"Images": wandb_images})
|
229 |
+
wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('test*.jpg'))]})
|
230 |
|
231 |
# Print results
|
232 |
pf = '%20s' + '%12.3g' * 6 # print format
|
@@ -396,8 +396,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|
396 |
if plots:
|
397 |
plot_results(save_dir=save_dir) # save as results.png
|
398 |
if wandb:
|
399 |
-
|
400 |
-
|
|
|
401 |
logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
|
402 |
else:
|
403 |
dist.destroy_process_group()
|
|
|
396 |
if plots:
|
397 |
plot_results(save_dir=save_dir) # save as results.png
|
398 |
if wandb:
|
399 |
+
files = ['results.png', 'precision_recall_curve.png', 'confusion_matrix.png']
|
400 |
+
wandb.log({"Results": [wandb.Image(str(save_dir / f), caption=f) for f in files
|
401 |
+
if (save_dir / f).exists()]})
|
402 |
logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
|
403 |
else:
|
404 |
dist.destroy_process_group()
|
@@ -4,6 +4,9 @@ from pathlib import Path
|
|
4 |
|
5 |
import matplotlib.pyplot as plt
|
6 |
import numpy as np
|
|
|
|
|
|
|
7 |
|
8 |
|
9 |
def fitness(x):
|
@@ -102,6 +105,84 @@ def compute_ap(recall, precision):
|
|
102 |
return ap, mpre, mrec
|
103 |
|
104 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
def plot_pr_curve(px, py, ap, save_dir='.', names=()):
|
106 |
fig, ax = plt.subplots(1, 1, figsize=(9, 6))
|
107 |
py = np.stack(py, axis=1)
|
|
|
4 |
|
5 |
import matplotlib.pyplot as plt
|
6 |
import numpy as np
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from . import general
|
10 |
|
11 |
|
12 |
def fitness(x):
|
|
|
105 |
return ap, mpre, mrec
|
106 |
|
107 |
|
108 |
+
class ConfusionMatrix:
|
109 |
+
# Updated version of https://github.com/kaanakan/object_detection_confusion_matrix
|
110 |
+
def __init__(self, nc, conf=0.25, iou_thres=0.45):
|
111 |
+
self.matrix = np.zeros((nc + 1, nc + 1))
|
112 |
+
self.nc = nc # number of classes
|
113 |
+
self.conf = conf
|
114 |
+
self.iou_thres = iou_thres
|
115 |
+
|
116 |
+
def process_batch(self, detections, labels):
|
117 |
+
"""
|
118 |
+
Return intersection-over-union (Jaccard index) of boxes.
|
119 |
+
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
|
120 |
+
Arguments:
|
121 |
+
detections (Array[N, 6]), x1, y1, x2, y2, conf, class
|
122 |
+
labels (Array[M, 5]), class, x1, y1, x2, y2
|
123 |
+
Returns:
|
124 |
+
None, updates confusion matrix accordingly
|
125 |
+
"""
|
126 |
+
detections = detections[detections[:, 4] > self.conf]
|
127 |
+
gt_classes = labels[:, 0].int()
|
128 |
+
detection_classes = detections[:, 5].int()
|
129 |
+
iou = general.box_iou(labels[:, 1:], detections[:, :4])
|
130 |
+
|
131 |
+
x = torch.where(iou > self.iou_thres)
|
132 |
+
if x[0].shape[0]:
|
133 |
+
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
|
134 |
+
if x[0].shape[0] > 1:
|
135 |
+
matches = matches[matches[:, 2].argsort()[::-1]]
|
136 |
+
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
|
137 |
+
matches = matches[matches[:, 2].argsort()[::-1]]
|
138 |
+
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
|
139 |
+
else:
|
140 |
+
matches = np.zeros((0, 3))
|
141 |
+
|
142 |
+
n = matches.shape[0] > 0
|
143 |
+
m0, m1, _ = matches.transpose().astype(np.int16)
|
144 |
+
for i, gc in enumerate(gt_classes):
|
145 |
+
j = m0 == i
|
146 |
+
if n and sum(j) == 1:
|
147 |
+
self.matrix[gc, detection_classes[m1[j]]] += 1 # correct
|
148 |
+
else:
|
149 |
+
self.matrix[gc, self.nc] += 1 # background FP
|
150 |
+
|
151 |
+
if n:
|
152 |
+
for i, dc in enumerate(detection_classes):
|
153 |
+
if not any(m1 == i):
|
154 |
+
self.matrix[self.nc, dc] += 1 # background FN
|
155 |
+
|
156 |
+
def matrix(self):
|
157 |
+
return self.matrix
|
158 |
+
|
159 |
+
def plot(self, save_dir='', names=()):
|
160 |
+
try:
|
161 |
+
import seaborn as sn
|
162 |
+
|
163 |
+
array = self.matrix / (self.matrix.sum(0).reshape(1, self.nc + 1) + 1E-6) # normalize
|
164 |
+
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
|
165 |
+
|
166 |
+
fig = plt.figure(figsize=(12, 9))
|
167 |
+
sn.set(font_scale=1.0 if self.nc < 50 else 0.8) # for label size
|
168 |
+
labels = (0 < len(names) < 99) and len(names) == self.nc # apply names to ticklabels
|
169 |
+
sn.heatmap(array, annot=self.nc < 30, annot_kws={"size": 8}, cmap='Blues', fmt='.2f', square=True,
|
170 |
+
xticklabels=names + ['background FN'] if labels else "auto",
|
171 |
+
yticklabels=names + ['background FP'] if labels else "auto").set_facecolor((1, 1, 1))
|
172 |
+
fig.axes[0].set_xlabel('True')
|
173 |
+
fig.axes[0].set_ylabel('Predicted')
|
174 |
+
fig.tight_layout()
|
175 |
+
fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
|
176 |
+
except Exception as e:
|
177 |
+
pass
|
178 |
+
|
179 |
+
def print(self):
|
180 |
+
for i in range(self.nc + 1):
|
181 |
+
print(' '.join(map(str, self.matrix[i])))
|
182 |
+
|
183 |
+
|
184 |
+
# Plots ----------------------------------------------------------------------------------------------------------------
|
185 |
+
|
186 |
def plot_pr_curve(px, py, ap, save_dir='.', names=()):
|
187 |
fig, ax = plt.subplots(1, 1, figsize=(9, 6))
|
188 |
py = np.stack(py, axis=1)
|