Improved model+EMA checkpointing 2 (#2295)
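This PR hardens checkpoint saving: train.py now writes deepcopy'd FP16 snapshots of the model and EMA into the checkpoint instead of halving the live modules and casting them back afterwards, and test.py casts the model back to FP32 after validation so mid-training evaluation cannot leave the training model in half precision.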
test.py CHANGED

@@ -269,6 +269,7 @@ def test(data,
             print(f'pycocotools unable to run: {e}')
 
     # Return results
+    model.float()  # for training
     if not training:
        s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
        print(f"Results saved to {save_dir}{s}")
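For context: test.py runs inference in half precision on CUDA, and `nn.Module.half()` converts parameters in place. When `test()` is called from the training loop for mid-epoch validation, the added `model.float()` therefore restores the shared model to FP32 before control returns to training. A minimal sketch of that round-trip (the `nn.Linear` layer is an illustrative stand-in, not from the repo):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for the shared training model

model.half()   # what test() does for FP16 inference: converts in place
assert next(model.parameters()).dtype is torch.float16

model.float()  # the added line: restore FP32 before training resumes
assert next(model.parameters()).dtype is torch.float32
```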
train.py CHANGED

@@ -4,6 +4,7 @@ import math
 import os
 import random
 import time
+from copy import deepcopy
 from pathlib import Path
 from threading import Thread
 
@@ -381,8 +382,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                 ckpt = {'epoch': epoch,
                         'best_fitness': best_fitness,
                         'training_results': results_file.read_text(),
-                        'model': (model.module if is_parallel(model) else model).half(),
-                        'ema': (ema.ema.half(), ema.updates),
+                        'model': deepcopy(model.module if is_parallel(model) else model).half(),
+                        'ema': (deepcopy(ema.ema).half(), ema.updates),
                         'optimizer': optimizer.state_dict(),
                         'wandb_id': wandb_run.id if wandb else None}
 
@@ -392,8 +393,6 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                     torch.save(ckpt, best)
                 del ckpt
 
-            model.float(), ema.ema.float()
-
         # end epoch ----------------------------------------------------------------------------------------------------
     # end training
 
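The train.py side is the core of the change. Because `nn.Module.half()` mutates the module in place, the old code halved the live model and EMA just to build the checkpoint dict, then had to undo the damage with `model.float(), ema.ema.float()` after saving. Snapshotting with `deepcopy(...).half()` leaves the live FP32 modules untouched, so that cleanup line can be deleted and the training model is never in the wrong dtype between saves. A minimal sketch of the difference (again using `nn.Linear` as a stand-in for the real model and `ema.ema`):

```python
from copy import deepcopy

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for the training model / ema.ema

# Before: half() converts the live module in place, so after building the
# checkpoint the training model itself is FP16 and needs model.float().
ckpt_model = model.half()
assert next(model.parameters()).dtype is torch.float16  # live model mutated
model.float()  # the cleanup line this PR deletes

# After: snapshot with deepcopy, then halve only the copy. The live model
# stays FP32 throughout, so no cleanup is needed.
ckpt_model = deepcopy(model).half()
assert next(model.parameters()).dtype is torch.float32
assert next(ckpt_model.parameters()).dtype is torch.float16
```

The EMA entry follows the same pattern: `(deepcopy(ema.ema).half(), ema.updates)` stores an FP16 copy of the EMA weights together with the update counter, rather than a reference to the live EMA module.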