glenn-jocher committed
Commit 24c5a94 · 1 Parent(s): 2b6209a

--resume EMA fix #292

Files changed (2):
  1. train.py (+2 -2)
  2. utils/torch_utils.py (+3 -7)
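
Why this matters: ModelEMA ramps its decay with the number of updates it has performed, so a --resume run that restarted the counter at 0 would re-enter the warmup ramp and let the freshly loaded weights overwrite the EMA. A minimal standalone sketch of the ramp (the formula is the one from the diff below; the sample values are illustrative):

import math

def ema_decay(updates, decay=0.9999):
    # decay exponential ramp from ModelEMA (helps early epochs)
    return decay * (1 - math.exp(-updates / 2000))

print(ema_decay(0))      # 0.0     -> EMA is simply overwritten by the model
print(ema_decay(2000))   # ~0.632  -> still ramping up
print(ema_decay(20000))  # ~0.9999 -> mature EMA tracks the model slowly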
train.py CHANGED

@@ -163,6 +163,7 @@ def train(hyp):
     dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
                                             hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect)
     mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
+    nb = len(dataloader)  # number of batches
     assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Correct your labels or your model.' % (mlc, nc, opt.cfg)
 
     # Testloader
@@ -191,11 +192,10 @@ def train(hyp):
     check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
 
     # Exponential moving average
-    ema = torch_utils.ModelEMA(model)
+    ema = torch_utils.ModelEMA(model, updates=start_epoch * nb / accumulate)
 
     # Start training
     t0 = time.time()
-    nb = len(dataloader)  # number of batches
     nw = max(3 * nb, 1e3)  # number of warmup iterations, max(3 epochs, 1k iterations)
     maps = np.zeros(nc)  # mAP per class
     results = (0, 0, 0, 0, 0, 0, 0)  # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
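
For context, start_epoch comes from the loaded checkpoint and accumulate is the gradient-accumulation interval defined earlier in train.py; neither appears in this hunk. The new updates seed approximates how many optimizer (and therefore EMA) steps the interrupted run had already taken. A hedged sketch of the arithmetic with illustrative numbers:

# Illustrative values only; in train.py these come from the checkpoint,
# the dataloader, and the nominal-batch-size accumulation schedule.
start_epoch = 50    # epoch stored in the --resume checkpoint
nb = 1849           # batches per epoch, len(dataloader)
accumulate = 4      # batches accumulated per optimizer step

# The EMA updates once per optimizer step, so a resumed run restarts near:
updates = start_epoch * nb / accumulate
print(updates)  # 23112.5 -> decay resumes near its plateau instead of at 0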
utils/torch_utils.py CHANGED

@@ -191,15 +191,11 @@ class ModelEMA:
     I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
     """
 
-    def __init__(self, model, decay=0.9999, device=''):
+    def __init__(self, model, decay=0.9999, updates=0):
         # Create EMA
-        self.ema = deepcopy(model.module if is_parallel(model) else model)  # FP32 EMA
-        self.ema.eval()
-        self.updates = 0  # number of EMA updates
+        self.ema = deepcopy(model.module if is_parallel(model) else model).eval()  # FP32 EMA
+        self.updates = updates  # number of EMA updates
         self.decay = lambda x: decay * (1 - math.exp(-x / 2000))  # decay exponential ramp (to help early epochs)
-        self.device = device  # perform ema on different device from model if set
-        if device:
-            self.ema.to(device)
         for p in self.ema.parameters():
             p.requires_grad_(False)
 
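
The updates counter only has an effect because ModelEMA.update(), which this commit leaves untouched, feeds it through the decay lambda on every call. As a rough sketch of that dependency (a method fragment of the same class, reconstructed here for illustration rather than quoted from this diff):

def update(self, model):
    # One EMA step: blend the live model into self.ema with the ramped decay.
    with torch.no_grad():
        self.updates += 1
        d = self.decay(self.updates)  # near 0 when young, -> 0.9999 when mature
        msd = model.module.state_dict() if is_parallel(model) else model.state_dict()
        for k, v in self.ema.state_dict().items():
            if v.dtype.is_floating_point:
                v *= d
                v += (1. - d) * msd[k].detach()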