glenn-jocher commited on
Commit
4aa2959
·
unverified ·
1 Parent(s): af2bc3a

Suppress jit trace warning + graph once (#3454)

Browse files

* Suppress jit trace warning + graph once

Suppress harmless jit trace warning on TensorBoard add_graph call. Also fix multiple add_graph() calls bug, now only on batch 0.

* Update train.py

Files changed (1) hide show
  1. train.py +7 -5
train.py CHANGED
@@ -4,6 +4,7 @@ import math
4
  import os
5
  import random
6
  import time
 
7
  from copy import deepcopy
8
  from pathlib import Path
9
  from threading import Thread
@@ -323,18 +324,19 @@ def train(hyp, opt, device, tb_writer=None):
323
  mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
324
  mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
325
  s = ('%10s' * 2 + '%10.4g' * 6) % (
326
- '%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
327
  pbar.set_description(s)
328
 
329
  # Plot
330
  if plots and ni < 3:
331
  f = save_dir / f'train_batch{ni}.jpg' # filename
332
  Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
333
- if tb_writer:
334
- tb_writer.add_graph(torch.jit.trace(de_parallel(model), imgs, strict=False), []) # model graph
335
- # tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
 
336
  elif plots and ni == 10 and wandb_logger.wandb:
337
- wandb_logger.log({"Mosaics": [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
338
  save_dir.glob('train*.jpg') if x.exists()]})
339
 
340
  # end batch ------------------------------------------------------------------------------------------------
 
4
  import os
5
  import random
6
  import time
7
+ import warnings
8
  from copy import deepcopy
9
  from pathlib import Path
10
  from threading import Thread
 
324
  mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
325
  mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
326
  s = ('%10s' * 2 + '%10.4g' * 6) % (
327
+ f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1])
328
  pbar.set_description(s)
329
 
330
  # Plot
331
  if plots and ni < 3:
332
  f = save_dir / f'train_batch{ni}.jpg' # filename
333
  Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
334
+ if tb_writer and ni == 0:
335
+ with warnings.catch_warnings():
336
+ warnings.simplefilter('ignore') # suppress jit trace warning
337
+ tb_writer.add_graph(torch.jit.trace(de_parallel(model), imgs, strict=False), []) # graph
338
  elif plots and ni == 10 and wandb_logger.wandb:
339
+ wandb_logger.log({'Mosaics': [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
340
  save_dir.glob('train*.jpg') if x.exists()]})
341
 
342
  # end batch ------------------------------------------------------------------------------------------------