Giacomo Guiduzzi, glenn-jocher and pre-commit-ci[bot] committed
Commit 6935a54 · unverified · 1 Parent(s): f76a78e

Implementation of Early Stopping for DDP training (#8345)


* Implementation of Early Stopping for DDP training

This edit correctly uses the broadcast_object_list() function to send the worker (non-zero-rank) processes a boolean so that they end the training phase when the variable is True, allowing the master process to destroy the process group and terminate. A minimal sketch of this pattern follows the commit message.

* Update train.py

* Update train.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update train.py

* Update train.py

* Update train.py

* Further cleanup

This cleans up the definition of broadcast_list and removes the need to call clear() afterwards.

Co-authored-by: Glenn Jocher <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Files changed (1)
  1. train.py +10 -14
train.py CHANGED
@@ -294,7 +294,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictio
     results = (0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
     scheduler.last_epoch = start_epoch - 1  # do not move
     scaler = torch.cuda.amp.GradScaler(enabled=amp)
-    stopper = EarlyStopping(patience=opt.patience)
+    stopper, stop = EarlyStopping(patience=opt.patience), False
     compute_loss = ComputeLoss(model)  # init loss class
     callbacks.run('on_train_start')
     LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n'
@@ -402,6 +402,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictio
 
             # Update best mAP
             fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
+            stop = stopper(epoch=epoch, fitness=fi)  # early stop check
             if fi > best_fitness:
                 best_fitness = fi
             log_vals = list(mloss) + list(results) + lr
@@ -428,19 +429,14 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictio
                 del ckpt
                 callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)
 
-        # Stop Single-GPU
-        if RANK == -1 and stopper(epoch=epoch, fitness=fi):
-            break
-
-        # Stop DDP TODO: known issues https://github.com/ultralytics/yolov5/pull/4576
-        # stop = stopper(epoch=epoch, fitness=fi)
-        # if RANK == 0:
-        #    dist.broadcast_object_list([stop], 0)  # broadcast 'stop' to all ranks
-
-        # Stop DDP
-        # with torch_distributed_zero_first(RANK):
-        # if stop:
-        #    break  # must break all DDP ranks
+        # EarlyStopping
+        if RANK != -1:  # if DDP training
+            broadcast_list = [stop if RANK == 0 else None]
+            dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
+            if RANK != 0:
+                stop = broadcast_list[0]
+        if stop:
+            break  # must break all DDP ranks
 
         # end epoch ----------------------------------------------------------------------------------------------------
     # end training -----------------------------------------------------------------------------------------------------