glenn-jocher committed
Commit 045d5d8 · unverified · 1 parent: fa201f9

Update TensorBoard (#3669)

Files changed (1)
  1. train.py +17 -16
train.py CHANGED
@@ -42,7 +42,6 @@ logger = logging.getLogger(__name__)
 def train(hyp,  # path/to/hyp.yaml or hyp dictionary
           opt,
           device,
-          tb_writer=None
           ):
     save_dir, epochs, batch_size, total_batch_size, weights, rank, single_cls = \
         Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, \
@@ -74,9 +73,16 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     with open(opt.data) as f:
         data_dict = yaml.safe_load(f)  # data dict
 
-    # Logging- Doing this before checking the dataset. Might update data_dict
-    loggers = {'wandb': None}  # loggers dict
+    # Loggers
+    loggers = {'wandb': None, 'tb': None}  # loggers dict
     if rank in [-1, 0]:
+        # TensorBoard
+        if not opt.evolve:
+            prefix = colorstr('tensorboard: ')
+            logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
+            loggers['tb'] = SummaryWriter(opt.save_dir)
+
+        # W&B
         opt.hyp = hyp  # add hyperparameters
         run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
         wandb_logger = WandbLogger(opt, save_dir.stem, run_id, data_dict)
@@ -219,8 +225,8 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         # model._initialize_biases(cf.to(device))
         if plots:
             plot_labels(labels, names, save_dir, loggers)
-            if tb_writer:
-                tb_writer.add_histogram('classes', c, 0)
+            if loggers['tb']:
+                loggers['tb'].add_histogram('classes', c, 0)  # TensorBoard
 
         # Anchors
         if not opt.noautoanchor:
@@ -341,10 +347,10 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             if plots and ni < 3:
                 f = save_dir / f'train_batch{ni}.jpg'  # filename
                 Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
-                if tb_writer and ni == 0:
+                if loggers['tb'] and ni == 0:  # TensorBoard
                     with warnings.catch_warnings():
                         warnings.simplefilter('ignore')  # suppress jit trace warning
-                        tb_writer.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
+                        loggers['tb'].add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
             elif plots and ni == 10 and wandb_logger.wandb:
                 wandb_logger.log({'Mosaics': [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
                                               save_dir.glob('train*.jpg') if x.exists()]})
@@ -352,7 +358,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         # end batch ------------------------------------------------------------------------------------------------
 
         # Scheduler
-        lr = [x['lr'] for x in optimizer.param_groups]  # for tensorboard
+        lr = [x['lr'] for x in optimizer.param_groups]  # for loggers
         scheduler.step()
 
         # DDP process 0 or single-GPU
@@ -385,8 +391,8 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                     'val/box_loss', 'val/obj_loss', 'val/cls_loss',  # val loss
                     'x/lr0', 'x/lr1', 'x/lr2']  # params
             for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
-                if tb_writer:
-                    tb_writer.add_scalar(tag, x, epoch)  # tensorboard
+                if loggers['tb']:
+                    loggers['tb'].add_scalar(tag, x, epoch)  # TensorBoard
                 if wandb_logger.wandb:
                     wandb_logger.log({tag: x})  # W&B
 
@@ -537,12 +543,7 @@ if __name__ == '__main__':
     # Train
     logger.info(opt)
     if not opt.evolve:
-        tb_writer = None  # init loggers
-        if opt.global_rank in [-1, 0]:
-            prefix = colorstr('tensorboard: ')
-            logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
-            tb_writer = SummaryWriter(opt.save_dir)  # Tensorboard
-        train(opt.hyp, opt, device, tb_writer)
+        train(opt.hyp, opt, device)
 
     # Evolve hyperparameters (optional)
     else:
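
In short, the commit drops the `tb_writer` argument (and the TensorBoard setup in `__main__`) and instead creates the `SummaryWriter` inside `train()` itself, stored under `loggers['tb']` next to the W&B entry, on rank -1/0 only and skipped during hyperparameter evolution. Below is a minimal standalone sketch of the resulting guarded loggers-dict pattern. It is not part of the commit: the model, log directory, and tag name are illustrative stand-ins, and it assumes the torch and tensorboard packages are installed.

import warnings

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter  # requires the tensorboard package

loggers = {'wandb': None, 'tb': None}  # loggers dict, as in train.py
loggers['tb'] = SummaryWriter('runs/exp')  # train.py guards this with rank in [-1, 0] and not opt.evolve

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())  # stand-in for the YOLO model
imgs = torch.zeros(1, 3, 64, 64)  # stand-in for the first training batch

if loggers['tb']:  # every call site checks the dict entry, so runs work with TensorBoard disabled
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')  # suppress jit trace warning, as in train.py
        loggers['tb'].add_graph(torch.jit.trace(model, imgs, strict=False), [])  # model graph, first batch only
    for epoch in range(3):
        loggers['tb'].add_scalar('train/box_loss', 1.0 / (epoch + 1), epoch)  # one scalar per tag per epoch
    loggers['tb'].close()

Folding the writer into `loggers` means `train()` no longer needs a TensorBoard-specific parameter: downstream code checks a single `loggers['tb']` truthiness instead of a separately threaded `tb_writer`.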