Implementation of Early Stopping for DDP training (#8345)
* Implementation of Early Stopping for DDP training
This edit correctly uses the broadcast_object_list() function to send the slave processes a boolean, so that they end the training phase when the variable is True, thus allowing the master process to destroy the process group and terminate. (A minimal runnable sketch of this pattern follows the commit notes below.)
* Update train.py
* Update train.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Update train.py
* Update train.py
* Update train.py
* Further cleanup
This cleans up the definition of broadcast_list and removes the need to call clear() on it afterward.
Co-authored-by: Glenn Jocher <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
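
For context, here is a minimal, self-contained sketch of the broadcast pattern the PR adopts. It is illustrative rather than the PR code: the two-process 'gloo' setup, the localhost address/port, and the placeholder stopping condition (epoch >= 3 on rank 0) are assumptions for the demo, standing in for yolov5's EarlyStopping check.

import os

import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank, world_size):
    # Minimal process-group setup; address and port are placeholder values
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group('gloo', rank=rank, world_size=world_size)  # CPU-friendly backend

    for epoch in range(100):
        # ... a training step would run on every rank here ...

        # Rank 0 decides whether to stop (stand-in for the EarlyStopping check)
        stop = epoch >= 3 if rank == 0 else None

        broadcast_list = [stop]
        dist.broadcast_object_list(broadcast_list, 0)  # rank 0's value overwrites every rank's list
        stop = broadcast_list[0]
        if stop:
            break  # must break all DDP ranks on the same epoch

    print(f'rank {rank} stopped after epoch {epoch}')
    dist.destroy_process_group()


if __name__ == '__main__':
    mp.spawn(run, args=(2,), nprocs=2)

Because broadcast_object_list() overwrites the list contents on every non-source rank, all processes observe rank 0's decision and leave the training loop on the same epoch, after which the process group can be torn down cleanly.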
train.py
CHANGED
@@ -294,7 +294,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictionary
     results = (0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
     scheduler.last_epoch = start_epoch - 1  # do not move
     scaler = torch.cuda.amp.GradScaler(enabled=amp)
-    stopper = EarlyStopping(patience=opt.patience)
+    stopper, stop = EarlyStopping(patience=opt.patience), False
     compute_loss = ComputeLoss(model)  # init loss class
     callbacks.run('on_train_start')
     LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n'
@@ -402,6 +402,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictionary
 
             # Update best mAP
             fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
+            stop = stopper(epoch=epoch, fitness=fi)  # early stop check
             if fi > best_fitness:
                 best_fitness = fi
             log_vals = list(mloss) + list(results) + lr
@@ -428,19 +429,14 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictionary
                 del ckpt
                 callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)
 
-            # Stop Single-GPU
-            if RANK == -1 and stopper(epoch=epoch, fitness=fi):
-                break
-
-            # Stop DDP TODO: known issues https://github.com/ultralytics/yolov5/pull/4576
-            # stop = stopper(epoch=epoch, fitness=fi)
-            # if RANK == 0:
-            #    dist.broadcast_object_list([stop], 0)  # broadcast 'stop' to all ranks
-
-            # Stop DPP
-            # with torch_distributed_zero_first(RANK):
-            # if stop:
-            #    break  # must break all DDP ranks
+        # EarlyStopping
+        if RANK != -1:  # if DDP training
+            broadcast_list = [stop if RANK == 0 else None]
+            dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
+            if RANK != 0:
+                stop = broadcast_list[0]
+        if stop:
+            break  # must break all DDP ranks
 
         # end epoch ----------------------------------------------------------------------------------------------------
     # end training -----------------------------------------------------------------------------------------------------
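
The stopper(epoch=epoch, fitness=fi) call added in the second hunk is a patience-based check; the condensed class below illustrates the idea (a sketch of the concept only, not yolov5's exact EarlyStopping implementation).

class PatienceStopper:
    # Illustrative patience-based stopper: returns True once fitness has not
    # improved for `patience` consecutive epochs (the role EarlyStopping plays above)
    def __init__(self, patience=30):
        self.best_fitness = 0.0
        self.best_epoch = 0
        self.patience = patience

    def __call__(self, epoch, fitness):
        if fitness >= self.best_fitness:  # record the best epoch seen so far
            self.best_epoch, self.best_fitness = epoch, fitness
        return epoch - self.best_epoch >= self.patience  # True -> stop training

For reference, a DDP run that exercises this code path is typically launched with torch.distributed.run; the dataset/weights arguments here are illustrative:

python -m torch.distributed.run --nproc_per_node 2 train.py --data coco128.yaml --weights yolov5s.pt --device 0,1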