glenn-jocher commited on
Commit
af41083
·
1 Parent(s): f767023

EMA FP16 fix #279

Browse files
Files changed (1) hide show
  1. utils/torch_utils.py +2 -2
utils/torch_utils.py CHANGED
@@ -176,13 +176,13 @@ class ModelEMA:
176
 
177
  def __init__(self, model, decay=0.9999, device=''):
178
  # Create EMA
179
- self.ema = deepcopy(model.module if is_parallel(model) else model).half() # FP16 EMA
180
  self.ema.eval()
181
  self.updates = 0 # number of EMA updates
182
  self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) # decay exponential ramp (to help early epochs)
183
  self.device = device # perform ema on different device from model if set
184
  if device:
185
- self.ema.to(device=device)
186
  for p in self.ema.parameters():
187
  p.requires_grad_(False)
188
 
 
176
 
177
  def __init__(self, model, decay=0.9999, device=''):
178
  # Create EMA
179
+ self.ema = deepcopy(model.module if is_parallel(model) else model) # FP32 EMA
180
  self.ema.eval()
181
  self.updates = 0 # number of EMA updates
182
  self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) # decay exponential ramp (to help early epochs)
183
  self.device = device # perform ema on different device from model if set
184
  if device:
185
+ self.ema.to(device)
186
  for p in self.ema.parameters():
187
  p.requires_grad_(False)
188