import math
from copy import deepcopy

import torch
import torch.nn as nn

__all__ = ["ModelEMA", "is_parallel"]


def is_parallel(model):
    """Return True if the model is wrapped in DataParallel or DistributedDataParallel."""
    parallel_type = (
        nn.parallel.DataParallel,
        nn.parallel.DistributedDataParallel,
    )
    return isinstance(model, parallel_type)
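

# Example (hypothetical): unwrap a parallel-wrapped model before saving, so
# checkpoint keys do not carry the "module." prefix added by DataParallel/DDP:
#
#   raw_model = model.module if is_parallel(model) else model
#   torch.save(raw_model.state_dict(), "checkpoint.pth")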


class ModelEMA:
    """
    Model Exponential Moving Average, from
    https://github.com/rwightman/pytorch-image-models

    Keeps a moving average of everything in the model state_dict (parameters
    and buffers), updating each floating-point entry as
    ema = d * ema + (1 - d) * model. This is intended to allow functionality
    like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage

    A smoothed version of the weights is necessary for some training schemes
    to perform well. This class is sensitive to where it is initialized in the
    sequence of model init, GPU assignment, and distributed training wrappers.
    """

    def __init__(self, model, decay=0.9999, updates=0):
        """
        Args:
            model (nn.Module): model to apply EMA to.
            decay (float): EMA decay rate.
            updates (int): counter of EMA updates.
        """
        # Keep EMA weights on a frozen, eval-mode copy of the unwrapped model.
        self.ema = deepcopy(model.module if is_parallel(model) else model).eval()
        self.updates = updates
        # Effective decay ramps up from 0 toward `decay` as updates accumulate.
        self.decay = lambda x: decay * (1 - math.exp(-x / 2000))
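        # Worked values for the default decay=0.9999: d(1000) = 0.9999 *
        # (1 - e**-0.5) ~= 0.39, d(2000) ~= 0.63, and d -> 0.9999 as the
        # update count grows, so early EMA steps track the raw weights
        # closely before settling into a long-horizon average.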
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        """Update EMA weights toward the current weights of `model`."""
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)

            msd = (
                model.module.state_dict() if is_parallel(model) else model.state_dict()
            )
            # In-place EMA update; non-floating-point entries (e.g. the
            # num_batches_tracked buffers of BatchNorm) are left as-is.
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1.0 - d) * msd[k].detach()
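

# Minimal usage sketch (hypothetical training loop; `model`, `optimizer`,
# `loader`, and `compute_loss` are placeholders, not part of this module):
#
#   ema = ModelEMA(model)
#   for inputs, targets in loader:
#       loss = compute_loss(model(inputs), targets)
#       optimizer.zero_grad()
#       loss.backward()
#       optimizer.step()
#       ema.update(model)  # refresh EMA weights after each optimizer step
#
# Evaluate or checkpoint with the smoothed weights via `ema.ema`, e.g.
#   torch.save(ema.ema.state_dict(), "ema_checkpoint.pth")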