Commit · 24c5a94
1 Parent(s): 2b6209a
--resume EMA fix #292
Browse files
- train.py +2 -2
- utils/torch_utils.py +3 -7
train.py
CHANGED
@@ -163,6 +163,7 @@ def train(hyp):
|
|
163 |
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
|
164 |
hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect)
|
165 |
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
|
|
|
166 |
assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Correct your labels or your model.' % (mlc, nc, opt.cfg)
|
167 |
|
168 |
# Testloader
|
@@ -191,11 +192,10 @@ def train(hyp):
|
|
191 |
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
|
192 |
|
193 |
# Exponential moving average
|
194 |
-
ema = torch_utils.ModelEMA(model)
|
195 |
|
196 |
# Start training
|
197 |
t0 = time.time()
|
198 |
-
nb = len(dataloader) # number of batches
|
199 |
nw = max(3 * nb, 1e3) # number of warmup iterations, max(3 epochs, 1k iterations)
|
200 |
maps = np.zeros(nc) # mAP per class
|
201 |
results = (0, 0, 0, 0, 0, 0, 0) # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
|
|
|
163 |
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
|
164 |
hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect)
|
165 |
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
|
166 |
+
nb = len(dataloader) # number of batches
|
167 |
assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Correct your labels or your model.' % (mlc, nc, opt.cfg)
|
168 |
|
169 |
# Testloader
|
|
|
192 |
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
|
193 |
|
194 |
# Exponential moving average
|
195 |
+
ema = torch_utils.ModelEMA(model, updates=start_epoch * nb / accumulate)
|
196 |
|
197 |
# Start training
|
198 |
t0 = time.time()
|
|
|
199 |
nw = max(3 * nb, 1e3) # number of warmup iterations, max(3 epochs, 1k iterations)
|
200 |
maps = np.zeros(nc) # mAP per class
|
201 |
results = (0, 0, 0, 0, 0, 0, 0) # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
|
utils/torch_utils.py
CHANGED
@@ -191,15 +191,11 @@ class ModelEMA:
|
|
191 |
I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
|
192 |
"""
|
193 |
|
194 |
-
def __init__(self, model, decay=0.9999,
|
195 |
# Create EMA
|
196 |
-
self.ema = deepcopy(model.module if is_parallel(model) else model) # FP32 EMA
|
197 |
-
self.
|
198 |
-
self.updates = 0 # number of EMA updates
|
199 |
self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) # decay exponential ramp (to help early epochs)
|
200 |
-
self.device = device # perform ema on different device from model if set
|
201 |
-
if device:
|
202 |
-
self.ema.to(device)
|
203 |
for p in self.ema.parameters():
|
204 |
p.requires_grad_(False)
|
205 |
|
|
|
191 |
I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
|
192 |
"""
|
193 |
|
194 |
+
def __init__(self, model, decay=0.9999, updates=0):
|
195 |
# Create EMA
|
196 |
+
self.ema = deepcopy(model.module if is_parallel(model) else model).eval() # FP32 EMA
|
197 |
+
self.updates = updates # number of EMA updates
|
|
|
198 |
self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) # decay exponential ramp (to help early epochs)
|
|
|
|
|
|
|
199 |
for p in self.ema.parameters():
|
200 |
p.requires_grad_(False)
|
201 |
|