glenn-jocher committed
Commit 0a3ff71 (unverified) · 1 Parent(s): 95fa653

Confusion matrix (#1474)


* initial commit

* add plotting

* matrix to cpu

* bug fix

* update plot

* update plot

* update plot

* update plot

* update plot

* update plot

* update plot

* update plot

* update plot

* update plot

* update plot

* update plot

* cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* seaborn pandas to requirements.txt

* seaborn pandas to requirements.txt

* update wandb plotting

* remove pandas

* if plots

* if plots

* if plots

* if plots

* if plots

* Cat apriori to autolabels

* cleanup

Files changed (4)
  1. requirements.txt +4 -3
  2. test.py +10 -5
  3. train.py +3 -2
  4. utils/metrics.py +81 -0
requirements.txt CHANGED
@@ -16,8 +16,9 @@ tqdm>=4.41.0
 # logging -------------------------------------
 # wandb
 
-# coco ----------------------------------------
-# pycocotools>=2.0
+# plotting ------------------------------------
+seaborn
+pandas
 
 # export --------------------------------------
 # coremltools==4.0
@@ -26,4 +27,4 @@ tqdm>=4.41.0
 
 # extras --------------------------------------
 # thop  # FLOPS computation
-# seaborn  # plotting
+# pycocotools>=2.0  # COCO mAP
test.py CHANGED
@@ -14,7 +14,7 @@ from utils.datasets import create_dataloader
 from utils.general import coco80_to_coco91_class, check_dataset, check_file, check_img_size, box_iou, \
     non_max_suppression, scale_coords, xyxy2xywh, xywh2xyxy, set_logging, increment_path
 from utils.loss import compute_loss
-from utils.metrics import ap_per_class
+from utils.metrics import ap_per_class, ConfusionMatrix
 from utils.plots import plot_images, output_to_target
 from utils.torch_utils import select_device, time_synchronized
 
@@ -89,6 +89,7 @@ def test(data,
         dataloader = create_dataloader(path, imgsz, batch_size, model.stride.max(), opt, pad=0.5, rect=True)[0]
 
     seen = 0
+    confusion_matrix = ConfusionMatrix(nc=nc)
     names = {k: v for k, v in enumerate(model.names if hasattr(model, 'names') else model.module.names)}
     coco91class = coco80_to_coco91_class()
     s = ('%20s' + '%12s' * 6) % ('Class', 'Images', 'Targets', 'P', 'R', 'mAP@.5', 'mAP@.5:.95')
@@ -176,6 +177,8 @@ def test(data,
                 # target boxes
                 tbox = xywh2xyxy(labels[:, 1:5])
                 scale_coords(img[si].shape[1:], tbox, shapes[si][0], shapes[si][1])  # native-space labels
+                if plots:
+                    confusion_matrix.process_batch(pred, torch.cat((labels[:, 0:1], tbox), 1))
 
                 # Per target class
                 for cls in torch.unique(tcls_tensor):
@@ -218,10 +221,12 @@ def test(data,
     else:
         nt = torch.zeros(1)
 
-    # W&B logging
-    if plots and wandb and wandb.run:
-        wandb.log({"Images": wandb_images})
-        wandb.log({"Validation": [wandb.Image(str(x), caption=x.name) for x in sorted(save_dir.glob('test*.jpg'))]})
+    # Plots
+    if plots:
+        confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
+        if wandb and wandb.run:
+            wandb.log({"Images": wandb_images})
+            wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('test*.jpg'))]})
 
     # Print results
     pf = '%20s' + '%12.3g' * 6  # print format
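
Note on the new call above: confusion_matrix.process_batch(pred, torch.cat((labels[:, 0:1], tbox), 1)) is fed per image, with detections as (x1, y1, x2, y2, conf, cls) rows from NMS and ground truth as (cls, x1, y1, x2, y2) rows, both in native image space. A minimal standalone sketch of that call with made-up boxes and class ids (the tensors below are illustrative only, not taken from this PR):

import torch

from utils.metrics import ConfusionMatrix

# Toy example (hypothetical values) mirroring the shapes test.py passes in
cm = ConfusionMatrix(nc=2)

# detections: (N, 6) = x1, y1, x2, y2, confidence, predicted class (post-NMS, native space)
detections = torch.tensor([[10., 10., 50., 50., 0.9, 0.],
                           [60., 60., 90., 90., 0.8, 1.]])

# labels: (M, 5) = class, x1, y1, x2, y2, i.e. what torch.cat((labels[:, 0:1], tbox), 1) builds
labels = torch.tensor([[0., 12., 12., 48., 48.],
                       [1., 58., 58., 92., 92.]])

cm.process_batch(detections, labels)
cm.print()  # (nc + 1) x (nc + 1) counts; the extra row/column is the background class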
train.py CHANGED
@@ -396,8 +396,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
         if plots:
             plot_results(save_dir=save_dir)  # save as results.png
             if wandb:
-                wandb.log({"Results": [wandb.Image(str(save_dir / x), caption=x) for x in
-                                       ['results.png', 'precision_recall_curve.png']]})
+                files = ['results.png', 'precision_recall_curve.png', 'confusion_matrix.png']
+                wandb.log({"Results": [wandb.Image(str(save_dir / f), caption=f) for f in files
+                                       if (save_dir / f).exists()]})
         logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
     else:
         dist.destroy_process_group()
utils/metrics.py CHANGED
@@ -4,6 +4,9 @@ from pathlib import Path
 
 import matplotlib.pyplot as plt
 import numpy as np
+import torch
+
+from . import general
 
 
 def fitness(x):
@@ -102,6 +105,84 @@ def compute_ap(recall, precision):
     return ap, mpre, mrec
 
 
+class ConfusionMatrix:
+    # Updated version of https://github.com/kaanakan/object_detection_confusion_matrix
+    def __init__(self, nc, conf=0.25, iou_thres=0.45):
+        self.matrix = np.zeros((nc + 1, nc + 1))
+        self.nc = nc  # number of classes
+        self.conf = conf
+        self.iou_thres = iou_thres
+
+    def process_batch(self, detections, labels):
+        """
+        Return intersection-over-union (Jaccard index) of boxes.
+        Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
+        Arguments:
+            detections (Array[N, 6]), x1, y1, x2, y2, conf, class
+            labels (Array[M, 5]), class, x1, y1, x2, y2
+        Returns:
+            None, updates confusion matrix accordingly
+        """
+        detections = detections[detections[:, 4] > self.conf]
+        gt_classes = labels[:, 0].int()
+        detection_classes = detections[:, 5].int()
+        iou = general.box_iou(labels[:, 1:], detections[:, :4])
+
+        x = torch.where(iou > self.iou_thres)
+        if x[0].shape[0]:
+            matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
+            if x[0].shape[0] > 1:
+                matches = matches[matches[:, 2].argsort()[::-1]]
+                matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
+                matches = matches[matches[:, 2].argsort()[::-1]]
+                matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
+        else:
+            matches = np.zeros((0, 3))
+
+        n = matches.shape[0] > 0
+        m0, m1, _ = matches.transpose().astype(np.int16)
+        for i, gc in enumerate(gt_classes):
+            j = m0 == i
+            if n and sum(j) == 1:
+                self.matrix[gc, detection_classes[m1[j]]] += 1  # correct
+            else:
+                self.matrix[gc, self.nc] += 1  # background FP
+
+        if n:
+            for i, dc in enumerate(detection_classes):
+                if not any(m1 == i):
+                    self.matrix[self.nc, dc] += 1  # background FN
+
+    def matrix(self):
+        return self.matrix
+
+    def plot(self, save_dir='', names=()):
+        try:
+            import seaborn as sn
+
+            array = self.matrix / (self.matrix.sum(0).reshape(1, self.nc + 1) + 1E-6)  # normalize
+            array[array < 0.005] = np.nan  # don't annotate (would appear as 0.00)
+
+            fig = plt.figure(figsize=(12, 9))
+            sn.set(font_scale=1.0 if self.nc < 50 else 0.8)  # for label size
+            labels = (0 < len(names) < 99) and len(names) == self.nc  # apply names to ticklabels
+            sn.heatmap(array, annot=self.nc < 30, annot_kws={"size": 8}, cmap='Blues', fmt='.2f', square=True,
+                       xticklabels=names + ['background FN'] if labels else "auto",
+                       yticklabels=names + ['background FP'] if labels else "auto").set_facecolor((1, 1, 1))
+            fig.axes[0].set_xlabel('True')
+            fig.axes[0].set_ylabel('Predicted')
+            fig.tight_layout()
+            fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
+        except Exception as e:
+            pass
+
+    def print(self):
+        for i in range(self.nc + 1):
+            print(' '.join(map(str, self.matrix[i])))
+
+
+# Plots ----------------------------------------------------------------------------------------------------------------
+
 def plot_pr_curve(px, py, ap, save_dir='.', names=()):
     fig, ax = plt.subplots(1, 1, figsize=(9, 6))
     py = np.stack(py, axis=1)
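
The subtle step in process_batch is making the IoU matches one-to-one: candidate (label, detection) pairs above iou_thres are sorted by IoU, and np.unique(..., return_index=True) keeps only the highest-IoU occurrence of each detection index, then of each label index. A NumPy-only sketch of that de-duplication with made-up match rows (columns: label index, detection index, IoU; values are illustrative, not from the PR):

import numpy as np

# Hypothetical candidate matches, as built inside process_batch: [label_idx, detection_idx, IoU]
matches = np.array([[0, 0, 0.90],
                    [0, 1, 0.60],   # label 0 also overlaps detection 1
                    [1, 1, 0.75],
                    [1, 0, 0.50]])  # label 1 also overlaps detection 0

matches = matches[matches[:, 2].argsort()[::-1]]                   # sort by IoU, best first
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]  # keep each detection index once
matches = matches[matches[:, 2].argsort()[::-1]]                   # re-sort (unique reorders rows)
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]  # keep each label index once

print(matches)  # rows (0, 0, 0.90) and (1, 1, 0.75): each label and detection kept once, at its best IoU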