glenn-jocher commited on
Commit
d2e698c
·
unverified ·
1 Parent(s): 23718df

Reduce val device transfers (#7525)

Browse files
Files changed (1) hide show
  1. val.py +6 -8
val.py CHANGED
@@ -220,14 +220,14 @@ def run(
220
  # Metrics
221
  for si, pred in enumerate(out):
222
  labels = targets[targets[:, 0] == si, 1:]
223
- nl = len(labels)
224
- tcls = labels[:, 0].tolist() if nl else [] # target class
225
  path, shape = Path(paths[si]), shapes[si][0]
 
226
  seen += 1
227
 
228
- if len(pred) == 0:
229
  if nl:
230
- stats.append((torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), tcls))
231
  continue
232
 
233
  # Predictions
@@ -244,9 +244,7 @@ def run(
244
  correct = process_batch(predn, labelsn, iouv)
245
  if plots:
246
  confusion_matrix.process_batch(predn, labelsn)
247
- else:
248
- correct = torch.zeros(pred.shape[0], niou, dtype=torch.bool)
249
- stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls)) # (correct, conf, pcls, tcls)
250
 
251
  # Save/log
252
  if save_txt:
@@ -265,7 +263,7 @@ def run(
265
  callbacks.run('on_val_batch_end')
266
 
267
  # Compute metrics
268
- stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy
269
  if len(stats) and stats[0].any():
270
  tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names)
271
  ap50, ap = ap[:, 0], ap.mean(1) # [email protected], [email protected]:0.95
 
220
  # Metrics
221
  for si, pred in enumerate(out):
222
  labels = targets[targets[:, 0] == si, 1:]
223
+ nl, npr = labels.shape[0], pred.shape[0] # number of labels, predictions
 
224
  path, shape = Path(paths[si]), shapes[si][0]
225
+ correct = torch.zeros(npr, niou, dtype=torch.bool, device=device) # init
226
  seen += 1
227
 
228
+ if npr == 0:
229
  if nl:
230
+ stats.append((correct, *torch.zeros((3, 0))))
231
  continue
232
 
233
  # Predictions
 
244
  correct = process_batch(predn, labelsn, iouv)
245
  if plots:
246
  confusion_matrix.process_batch(predn, labelsn)
247
+ stats.append((correct, pred[:, 4], pred[:, 5], labels[:, 0])) # (correct, conf, pcls, tcls)
 
 
248
 
249
  # Save/log
250
  if save_txt:
 
263
  callbacks.run('on_val_batch_end')
264
 
265
  # Compute metrics
266
+ stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)] # to numpy
267
  if len(stats) and stats[0].any():
268
  tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names)
269
  ap50, ap = ap[:, 0], ap.mean(1) # [email protected], [email protected]:0.95