Improved model+EMA checkpointing (#2292)
* Enhanced model+EMA checkpointing
* update
* bug fix
* bug fix 2
* always save optimizer
* ema half
* remove model.float()
* model half
* carry ema/model in fp32
* rm model.float()
* both to float always
* cleanup
* cleanup
- test.py +0 -1
- train.py +16 -9
- utils/general.py +2 -2
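In short, this PR reworks the checkpoint format: instead of saving only the EMA weights under 'model', train.py now stores the raw training model, the EMA weights together with their update counter, and the optimizer state on every epoch. Both modules are cast to FP16 on disk and restored to FP32 immediately after saving, which is also why test.py no longer needs to re-float the model after a mid-training evaluation. A minimal sketch of reading one of the new checkpoints back (file name and variable names are illustrative, key names are those written by train.py below):

    import torch

    ckpt = torch.load('last.pt', map_location='cpu')

    model = ckpt['model'].float()         # raw model, stored in FP16; back to FP32 for training
    ema_model, ema_updates = ckpt['ema']  # EMA saved as an (FP16 module, update count) tuple
    ema_model = ema_model.float()
    optimizer_state = ckpt['optimizer']   # now always populated, even on the final epoch
    start_epoch = ckpt['epoch'] + 1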
test.py
CHANGED
@@ -272,7 +272,6 @@ def test(data,
     if not training:
         s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
         print(f"Results saved to {save_dir}{s}")
-        model.float()  # for training
     maps = np.zeros(nc) + map
     for i, c in enumerate(ap_class):
         maps[c] = ap[i]
train.py
CHANGED
@@ -31,7 +31,7 @@ from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights
 from utils.google_utils import attempt_download
 from utils.loss import ComputeLoss
 from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
-from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first
+from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel

 logger = logging.getLogger(__name__)

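The newly imported is_parallel() is used further down to unwrap the raw module from its DDP wrapper before saving. For reference, it is a small type check in utils/torch_utils.py, roughly:

    import torch.nn as nn

    def is_parallel(model):
        # True if model is wrapped in (Distributed)DataParallel
        return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)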
@@ -136,6 +136,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                                id=ckpt.get('wandb_id') if 'ckpt' in locals() else None)
         loggers = {'wandb': wandb}  # loggers dict

+    # EMA
+    ema = ModelEMA(model) if rank in [-1, 0] else None
+
     # Resume
     start_epoch, best_fitness = 0, 0.0
     if pretrained:
@@ -144,6 +147,11 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
             optimizer.load_state_dict(ckpt['optimizer'])
             best_fitness = ckpt['best_fitness']

+        # EMA
+        if ema and ckpt.get('ema'):
+            ema.ema.load_state_dict(ckpt['ema'][0].float().state_dict())
+            ema.updates = ckpt['ema'][1]
+
         # Results
         if ckpt.get('training_results') is not None:
             results_file.write_text(ckpt['training_results'])  # write results.txt
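Restoring ema.updates matters because ModelEMA ramps its decay with the number of updates, so a reset counter would restart the warmup; it also lets a later hunk drop the old start_epoch * nb // accumulate approximation. A condensed sketch of the mechanics (the real class lives in utils/torch_utils.py; treat the details here as illustrative):

    import copy
    import math
    import torch

    class EMASketch:
        def __init__(self, model, decay=0.9999):
            self.ema = copy.deepcopy(model).eval()                    # FP32 copy of the model
            self.updates = 0                                          # restored from ckpt['ema'][1]
            self.decay = lambda x: decay * (1 - math.exp(-x / 2000))  # ramps from 0 toward decay

        def update(self, model):
            # Blend current weights into the EMA copy after each optimizer step
            with torch.no_grad():
                self.updates += 1
                d = self.decay(self.updates)
                msd = model.state_dict()
                for k, v in self.ema.state_dict().items():
                    if v.dtype.is_floating_point:
                        v *= d
                        v += (1.0 - d) * msd[k].detach()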
@@ -173,9 +181,6 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
         logger.info('Using SyncBatchNorm()')

-    # EMA
-    ema = ModelEMA(model) if rank in [-1, 0] else None
-
     # DDP mode
     if cuda and rank != -1:
         model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank)
@@ -191,7 +196,6 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):

     # Process 0
     if rank in [-1, 0]:
-        ema.updates = start_epoch * nb // accumulate  # set EMA updates
         testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, opt,  # testloader
                                        hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
                                        world_size=opt.world_size, workers=opt.workers,
@@ -335,8 +339,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
         # DDP process 0 or single-GPU
         if rank in [-1, 0]:
             # mAP
-            if ema:
-                ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
+            ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
             final_epoch = epoch + 1 == epochs
             if not opt.notest or final_epoch:  # Calculate mAP
                 results, maps, times = test.test(opt.data,
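Since ema is now always created on rank -1/0, the guard around update_attr() can go. update_attr() copies plain attributes such as class count, names, and hyperparameters from the training model onto the EMA module so the saved EMA is self-describing; paraphrased as a free function (signature here is illustrative):

    def update_attr(ema_model, model, include=(), exclude=('process_group', 'reducer')):
        # Copy selected non-private attributes from model to the EMA module
        for k, v in model.__dict__.items():
            if (len(include) and k not in include) or k.startswith('_') or k in exclude:
                continue
            setattr(ema_model, k, v)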
@@ -378,8 +381,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                 ckpt = {'epoch': epoch,
                         'best_fitness': best_fitness,
                         'training_results': results_file.read_text(),
-                        'model': ema.ema,
-                        'optimizer': None if final_epoch else optimizer.state_dict(),
+                        'model': (model.module if is_parallel(model) else model).half(),
+                        'ema': (ema.ema.half(), ema.updates),
+                        'optimizer': optimizer.state_dict(),
                         'wandb_id': wandb_run.id if wandb else None}

             # Save last, best and delete
@@ -387,6 +391,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
             if best_fitness == fi:
                 torch.save(ckpt, best)
             del ckpt
+
+            model.float(), ema.ema.float()
+
         # end epoch ----------------------------------------------------------------------------------------------------
     # end training

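Saving is now a half/float round-trip: .half() casts the modules in place so the pickled checkpoint is roughly half the size, and the paired .float() calls restore FP32 before the next epoch trains. Stripped to its essentials (names as in train.py; file name illustrative):

    model.half()
    ema.ema.half()
    torch.save({'model': model,
                'ema': (ema.ema, ema.updates),
                'optimizer': optimizer.state_dict()}, 'last.pt')  # FP16 weights on disk
    model.float(), ema.ema.float()  # back to FP32 before the next forward pass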
utils/general.py
CHANGED
@@ -484,8 +484,8 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None
 def strip_optimizer(f='weights/best.pt', s=''):  # from utils.general import *; strip_optimizer()
     # Strip optimizer from 'f' to finalize training, optionally save as 's'
     x = torch.load(f, map_location=torch.device('cpu'))
-    for key in 'optimizer', 'training_results', 'wandb_id':
-        x[key] = None
+    for k in 'optimizer', 'training_results', 'wandb_id', 'ema':  # keys
+        x[k] = None
     x['epoch'] = -1
     x['model'].half()  # to FP16
     for p in x['model'].parameters():
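With checkpoints now carrying an 'ema' entry, strip_optimizer() nulls it too when finalizing a run, leaving only the FP16 model for inference. Typical use (path is illustrative):

    from utils.general import strip_optimizer

    strip_optimizer('runs/train/exp/weights/best.pt')  # drop optimizer/EMA/results, keep FP16 model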