Nanobit committed
Commit 886b984 (unverified) · Parent: 7eaf225

Add Multi-Node support for DDP Training (#504)


* Add support for multi-node DDP (typical launch commands shown below)

* Remove local_rank confusion (see the rank-bookkeeping sketch after the diff)

* Fix spacing
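
For reference: with init_method='env://' (used in the diff below), each process reads MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE from its environment. torch.distributed.launch sets these and also passes --local_rank to train.py, so a two-node run would typically be started as follows (addresses, port and GPU counts are illustrative only, not part of this commit):

    # On node 0 (the master):
    python -m torch.distributed.launch --nproc_per_node 4 --nnodes 2 --node_rank 0 \
        --master_addr "192.168.1.1" --master_port 1234 train.py --batch-size 64

    # On node 1:
    python -m torch.distributed.launch --nproc_per_node 4 --nnodes 2 --node_rank 1 \
        --master_addr "192.168.1.1" --master_port 1234 train.py --batch-size 64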

Files changed (1): train.py (+8 -7)
train.py CHANGED
@@ -62,9 +62,9 @@ def train(hyp, opt, device, tb_writer=None):
     best = wdir + 'best.pt'
     results_file = log_dir + os.sep + 'results.txt'
     epochs, batch_size, total_batch_size, weights, rank = \
-        opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.local_rank
+        opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
+
     # TODO: Use DDP logging. Only the first process is allowed to log.
-
     # Save run settings
     with open(Path(log_dir) / 'hyp.yaml', 'w') as f:
         yaml.dump(hyp, f, sort_keys=False)
@@ -184,7 +184,7 @@ def train(hyp, opt, device, tb_writer=None):
 
     # DDP mode
     if cuda and rank != -1:
-        model = DDP(model, device_ids=[rank], output_device=rank)
+        model = DDP(model, device_ids=[opt.local_rank], output_device=(opt.local_rank))
 
     # Trainloader
     dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt, hyp=hyp, augment=True,
@@ -441,8 +441,7 @@ if __name__ == '__main__':
     if last and not opt.weights:
         print(f'Resuming training from {last}')
     opt.weights = last if opt.resume and not opt.weights else opt.weights
-
-    if opt.local_rank in [-1, 0]:
+    if opt.local_rank == -1 or ("RANK" in os.environ and os.environ["RANK"] == "0"):
         check_git_status()
     opt.cfg = check_file(opt.cfg)  # check file
     opt.data = check_file(opt.data)  # check file
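
Note on the guard above: it runs before dist.init_process_group() is called, so dist.get_rank() is not yet available and the global rank has to be read from the RANK environment variable exported by the launcher. A standalone sketch of the same check (the helper name is mine, not the commit's):

    import os

    def is_global_master(local_rank: int) -> bool:
        # local_rank == -1 means non-distributed, single-process mode.
        # Otherwise, before init_process_group() runs, the only hint of
        # the global rank is the RANK env var set by torch.distributed.launch.
        return local_rank == -1 or ("RANK" in os.environ and os.environ["RANK"] == "0")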
@@ -454,7 +453,8 @@ if __name__ == '__main__':
     device = select_device(opt.device, batch_size=opt.batch_size)
     opt.total_batch_size = opt.batch_size
     opt.world_size = 1
-
+    opt.global_rank = -1
+
     # DDP mode
     if opt.local_rank != -1:
         assert torch.cuda.device_count() > opt.local_rank
@@ -462,6 +462,7 @@ if __name__ == '__main__':
         device = torch.device('cuda', opt.local_rank)
         dist.init_process_group(backend='nccl', init_method='env://')  # distributed backend
         opt.world_size = dist.get_world_size()
+        opt.global_rank = dist.get_rank()
         assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
         opt.batch_size = opt.total_batch_size // opt.world_size
 
@@ -470,7 +471,7 @@ if __name__ == '__main__':
 
     # Train
     if not opt.evolve:
         tb_writer = None
-        if opt.local_rank in [-1, 0]:
+        if opt.global_rank in [-1, 0]:
             print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
             tb_writer = SummaryWriter(log_dir=increment_dir('runs/exp', opt.name))
 
 
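Taken together, the patch keeps two distinct ranks: opt.local_rank (the GPU index within one node, supplied by the launcher) and opt.global_rank (the process index across all nodes, from dist.get_rank()). On a single node the two coincide, which is why the old code got away with using local_rank for both. A minimal standalone sketch of the same bookkeeping under the env:// convention (the toy model and argument handling are illustrative, not the commit's code):

    import argparse

    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    if __name__ == '__main__':
        parser = argparse.ArgumentParser()
        parser.add_argument('--local_rank', type=int, default=-1)  # filled in by torch.distributed.launch
        opt = parser.parse_args()

        opt.world_size = 1
        opt.global_rank = -1  # -1 = not distributed
        if opt.local_rank != -1:
            torch.cuda.set_device(opt.local_rank)         # local_rank selects the GPU on this node
            device = torch.device('cuda', opt.local_rank)
            dist.init_process_group(backend='nccl', init_method='env://')
            opt.world_size = dist.get_world_size()        # total processes across all nodes
            opt.global_rank = dist.get_rank()             # unique id across all nodes
        else:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        model = torch.nn.Linear(10, 10).to(device)        # toy model for illustration
        if opt.local_rank != -1:
            # DDP is constructed with the *local* rank; only rank filtering
            # (logging, checkpointing) uses the *global* rank.
            model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank)

        if opt.global_rank in [-1, 0]:                    # log only once per job
            print(f'world_size={opt.world_size}, global_rank={opt.global_rank}')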