EarlyStopper updates (#4679)
Browse files- train.py +3 -3
- utils/torch_utils.py +5 -2
train.py
CHANGED
@@ -344,7 +344,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
|
344 |
# mAP
|
345 |
callbacks.on_train_epoch_end(epoch=epoch)
|
346 |
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
|
347 |
-
final_epoch = epoch + 1 == epochs
|
348 |
if not noval or final_epoch: # Calculate mAP
|
349 |
results, maps, _ = val.run(data_dict,
|
350 |
batch_size=batch_size // WORLD_SIZE * 2,
|
@@ -384,7 +384,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
|
384 |
callbacks.on_model_save(last, epoch, final_epoch, best_fitness, fi)
|
385 |
|
386 |
# Stop Single-GPU
|
387 |
-
if stopper(epoch=epoch, fitness=fi):
|
388 |
break
|
389 |
|
390 |
# Stop DDP TODO: known issues shttps://github.com/ultralytics/yolov5/pull/4576
|
@@ -462,7 +462,7 @@ def parse_opt(known=False):
|
|
462 |
parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
|
463 |
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
|
464 |
parser.add_argument('--freeze', type=int, default=0, help='Number of layers to freeze. backbone=10, all=24')
|
465 |
-
parser.add_argument('--patience', type=int, default=
|
466 |
opt = parser.parse_known_args()[0] if known else parser.parse_args()
|
467 |
return opt
|
468 |
|
|
|
344 |
# mAP
|
345 |
callbacks.on_train_epoch_end(epoch=epoch)
|
346 |
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
|
347 |
+
final_epoch = (epoch + 1 == epochs) or stopper.possible_stop
|
348 |
if not noval or final_epoch: # Calculate mAP
|
349 |
results, maps, _ = val.run(data_dict,
|
350 |
batch_size=batch_size // WORLD_SIZE * 2,
|
|
|
384 |
callbacks.on_model_save(last, epoch, final_epoch, best_fitness, fi)
|
385 |
|
386 |
# Stop Single-GPU
|
387 |
+
if RANK == -1 and stopper(epoch=epoch, fitness=fi):
|
388 |
break
|
389 |
|
390 |
# Stop DDP TODO: known issues shttps://github.com/ultralytics/yolov5/pull/4576
|
|
|
462 |
parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
|
463 |
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
|
464 |
parser.add_argument('--freeze', type=int, default=0, help='Number of layers to freeze. backbone=10, all=24')
|
465 |
+
parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)')
|
466 |
opt = parser.parse_known_args()[0] if known else parser.parse_args()
|
467 |
return opt
|
468 |
|
utils/torch_utils.py
CHANGED
@@ -298,13 +298,16 @@ class EarlyStopping:
|
|
298 |
def __init__(self, patience=30):
|
299 |
self.best_fitness = 0.0 # i.e. mAP
|
300 |
self.best_epoch = 0
|
301 |
-
self.patience = patience # epochs to wait after fitness stops improving to stop
|
|
|
302 |
|
303 |
def __call__(self, epoch, fitness):
|
304 |
if fitness >= self.best_fitness: # >= 0 to allow for early zero-fitness stage of training
|
305 |
self.best_epoch = epoch
|
306 |
self.best_fitness = fitness
|
307 |
-
|
|
|
|
|
308 |
if stop:
|
309 |
LOGGER.info(f'EarlyStopping patience {self.patience} exceeded, stopping training.')
|
310 |
return stop
|
|
|
298 |
def __init__(self, patience=30):
|
299 |
self.best_fitness = 0.0 # i.e. mAP
|
300 |
self.best_epoch = 0
|
301 |
+
self.patience = patience or float('inf') # epochs to wait after fitness stops improving to stop
|
302 |
+
self.possible_stop = False # possible stop may occur next epoch
|
303 |
|
304 |
def __call__(self, epoch, fitness):
|
305 |
if fitness >= self.best_fitness: # >= 0 to allow for early zero-fitness stage of training
|
306 |
self.best_epoch = epoch
|
307 |
self.best_fitness = fitness
|
308 |
+
delta = epoch - self.best_epoch # epochs without improvement
|
309 |
+
self.possible_stop = delta >= (self.patience - 1) # possible stop may occur next epoch
|
310 |
+
stop = delta >= self.patience # stop training if patience exceeded
|
311 |
if stop:
|
312 |
LOGGER.info(f'EarlyStopping patience {self.patience} exceeded, stopping training.')
|
313 |
return stop
|