Commit df224a0
Parent(s): a9d20eb
EMA bug fix #279
Files changed:
- train.py +2 -2
- utils/torch_utils.py +9 -12
train.py CHANGED
@@ -294,7 +294,7 @@ def train(hyp):
                                  batch_size=batch_size,
                                  imgsz=imgsz_test,
                                  save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'),
-                                 model=ema.ema
+                                 model=ema.ema,
                                  single_cls=opt.single_cls,
                                  dataloader=testloader)

@@ -324,7 +324,7 @@ def train(hyp):
            ckpt = {'epoch': epoch,
                    'best_fitness': best_fitness,
                    'training_results': f.read(),
-                   'model': ema.ema
+                   'model': ema.ema,
                    'optimizer': None if final_epoch else optimizer.state_dict()}

            # Save last, best and delete
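Note (not part of the commit): a rough sketch of how a checkpoint written by the block above might be read back. The 'model' key holds the EMA module (ema.ema), which the torch_utils.py change below keeps in FP16, so it is cast back to FP32 before inference. The filename and input shape are assumptions for illustration, and loading the pickled module requires the repository's model classes on the import path.

```python
import torch

ckpt = torch.load('last.pt', map_location='cpu')  # hypothetical checkpoint path
model = ckpt['model'].float().eval()              # 'model' is the FP16 EMA module saved above

with torch.no_grad():
    out = model(torch.zeros(1, 3, 640, 640))      # dummy input; 640x640 is an assumed image size
```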
utils/torch_utils.py CHANGED
@@ -175,8 +175,8 @@ class ModelEMA:
     """

     def __init__(self, model, decay=0.9999, device=''):
-        #
-        self.ema = deepcopy(model)
+        # Create EMA
+        self.ema = deepcopy(model.module if is_parallel(model) else model).half()  # FP16 EMA
         self.ema.eval()
         self.updates = 0  # number of EMA updates
         self.decay = lambda x: decay * (1 - math.exp(-x / 2000))  # decay exponential ramp (to help early epochs)

@@ -187,22 +187,19 @@ class ModelEMA:
             p.requires_grad_(False)

     def update(self, model):
-        self.updates += 1
-        d = self.decay(self.updates)
+        # Update EMA parameters
         with torch.no_grad():
-            if is_parallel(model):
-                msd, esd = model.module.state_dict(), self.ema.module.state_dict()
-            else:
-                msd, esd = model.state_dict(), self.ema.state_dict()
+            self.updates += 1
+            d = self.decay(self.updates)

-            for k, v in esd.items():
+            msd = model.module.state_dict() if is_parallel(model) else model.state_dict()  # model state_dict
+            for k, v in self.ema.state_dict().items():
                 if v.dtype.is_floating_point:
                     v *= d
                     v += (1. - d) * msd[k].detach()

     def update_attr(self, model):
-        # Update
-        ema = self.ema.module if is_parallel(model) else self.ema
+        # Update EMA attributes
         for k, v in model.__dict__.items():
             if not k.startswith('_') and k != 'module':
-                setattr(ema, k, v)
+                setattr(self.ema, k, v)
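Note (not part of the commit): a standalone sketch of the update rule above. The decay is ramped from roughly 0 toward its nominal value, so the EMA tracks the raw weights closely during the first few thousand updates and smooths more heavily afterwards. Plain tensors stand in for the model and EMA state_dicts.

```python
import math

import torch

decay = 0.9999
ramp = lambda updates: decay * (1 - math.exp(-updates / 2000))  # same ramp as ModelEMA.decay

ema_w = torch.zeros(3)               # stands in for one EMA parameter tensor
for updates in range(1, 3001):
    model_w = torch.ones(3)          # stands in for the live model parameter after an optimizer step
    d = ramp(updates)
    ema_w.mul_(d).add_((1. - d) * model_w)   # same as: v *= d; v += (1 - d) * msd[k]

print(round(ramp(1), 5), round(ramp(2000), 3), round(ramp(10000), 4))
# ≈ 0.0005 at the first update, ≈ 0.632 after 2000 updates, ≈ 0.9932 after 10000
```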