Suppress jit trace warning + graph once (#3454)
Browse files* Suppress jit trace warning + graph once
Suppress harmless jit trace warning on TensorBoard add_graph call. Also fix multiple add_graph() calls bug, now only on batch 0.
* Update train.py
train.py
CHANGED
@@ -4,6 +4,7 @@ import math
|
|
4 |
import os
|
5 |
import random
|
6 |
import time
|
|
|
7 |
from copy import deepcopy
|
8 |
from pathlib import Path
|
9 |
from threading import Thread
|
@@ -323,18 +324,19 @@ def train(hyp, opt, device, tb_writer=None):
|
|
323 |
mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
|
324 |
mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
|
325 |
s = ('%10s' * 2 + '%10.4g' * 6) % (
|
326 |
-
'
|
327 |
pbar.set_description(s)
|
328 |
|
329 |
# Plot
|
330 |
if plots and ni < 3:
|
331 |
f = save_dir / f'train_batch{ni}.jpg' # filename
|
332 |
Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
|
333 |
-
if tb_writer:
|
334 |
-
|
335 |
-
|
|
|
336 |
elif plots and ni == 10 and wandb_logger.wandb:
|
337 |
-
wandb_logger.log({
|
338 |
save_dir.glob('train*.jpg') if x.exists()]})
|
339 |
|
340 |
# end batch ------------------------------------------------------------------------------------------------
|
|
|
4 |
import os
|
5 |
import random
|
6 |
import time
|
7 |
+
import warnings
|
8 |
from copy import deepcopy
|
9 |
from pathlib import Path
|
10 |
from threading import Thread
|
|
|
324 |
mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
|
325 |
mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
|
326 |
s = ('%10s' * 2 + '%10.4g' * 6) % (
|
327 |
+
f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1])
|
328 |
pbar.set_description(s)
|
329 |
|
330 |
# Plot
|
331 |
if plots and ni < 3:
|
332 |
f = save_dir / f'train_batch{ni}.jpg' # filename
|
333 |
Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
|
334 |
+
if tb_writer and ni == 0:
|
335 |
+
with warnings.catch_warnings():
|
336 |
+
warnings.simplefilter('ignore') # suppress jit trace warning
|
337 |
+
tb_writer.add_graph(torch.jit.trace(de_parallel(model), imgs, strict=False), []) # graph
|
338 |
elif plots and ni == 10 and wandb_logger.wandb:
|
339 |
+
wandb_logger.log({'Mosaics': [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
|
340 |
save_dir.glob('train*.jpg') if x.exists()]})
|
341 |
|
342 |
# end batch ------------------------------------------------------------------------------------------------
|