Reduce val device transfers (#7525)
Browse files
val.py
CHANGED
@@ -220,14 +220,14 @@ def run(
|
|
220 |
# Metrics
|
221 |
for si, pred in enumerate(out):
|
222 |
labels = targets[targets[:, 0] == si, 1:]
|
223 |
-
nl =
|
224 |
-
tcls = labels[:, 0].tolist() if nl else [] # target class
|
225 |
path, shape = Path(paths[si]), shapes[si][0]
|
|
|
226 |
seen += 1
|
227 |
|
228 |
-
if
|
229 |
if nl:
|
230 |
-
stats.append((torch.zeros(
|
231 |
continue
|
232 |
|
233 |
# Predictions
|
@@ -244,9 +244,7 @@ def run(
|
|
244 |
correct = process_batch(predn, labelsn, iouv)
|
245 |
if plots:
|
246 |
confusion_matrix.process_batch(predn, labelsn)
|
247 |
-
|
248 |
-
correct = torch.zeros(pred.shape[0], niou, dtype=torch.bool)
|
249 |
-
stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls)) # (correct, conf, pcls, tcls)
|
250 |
|
251 |
# Save/log
|
252 |
if save_txt:
|
@@ -265,7 +263,7 @@ def run(
|
|
265 |
callbacks.run('on_val_batch_end')
|
266 |
|
267 |
# Compute metrics
|
268 |
-
stats = [
|
269 |
if len(stats) and stats[0].any():
|
270 |
tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names)
|
271 |
ap50, ap = ap[:, 0], ap.mean(1) # [email protected], [email protected]:0.95
|
|
|
220 |
# Metrics
|
221 |
for si, pred in enumerate(out):
|
222 |
labels = targets[targets[:, 0] == si, 1:]
|
223 |
+
nl, npr = labels.shape[0], pred.shape[0] # number of labels, predictions
|
|
|
224 |
path, shape = Path(paths[si]), shapes[si][0]
|
225 |
+
correct = torch.zeros(npr, niou, dtype=torch.bool, device=device) # init
|
226 |
seen += 1
|
227 |
|
228 |
+
if npr == 0:
|
229 |
if nl:
|
230 |
+
stats.append((correct, *torch.zeros((3, 0))))
|
231 |
continue
|
232 |
|
233 |
# Predictions
|
|
|
244 |
correct = process_batch(predn, labelsn, iouv)
|
245 |
if plots:
|
246 |
confusion_matrix.process_batch(predn, labelsn)
|
247 |
+
stats.append((correct, pred[:, 4], pred[:, 5], labels[:, 0])) # (correct, conf, pcls, tcls)
|
|
|
|
|
248 |
|
249 |
# Save/log
|
250 |
if save_txt:
|
|
|
263 |
callbacks.run('on_val_batch_end')
|
264 |
|
265 |
# Compute metrics
|
266 |
+
stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)] # to numpy
|
267 |
if len(stats) and stats[0].any():
|
268 |
tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names)
|
269 |
ap50, ap = ap[:, 0], ap.mean(1) # [email protected], [email protected]:0.95
|