dinhdat1110's picture
Upload folder using huggingface_hub
9457143 verified
from pytorch_lightning.callbacks import Callback
from timm.utils.model import get_state_dict, unwrap_model
from timm.utils.model_ema import ModelEmaV2
# Cell
class EMACallback(Callback):
"""
Model Exponential Moving Average. Empirically it has been found that using the moving average
of the trained parameters of a deep network is better than using its trained parameters directly.
If `use_ema_weights`, then the ema parameters of the network is set after training end.
"""
def __init__(self, decay=0.9999, use_ema_weights: bool = True):
self.decay = decay
self.ema = None
self.use_ema_weights = use_ema_weights
def on_fit_start(self, trainer, pl_module, *args):
"Initialize `ModelEmaV2` from timm to keep a copy of the moving average of the weights"
self.ema = ModelEmaV2(pl_module, decay=self.decay, device=None)
def on_train_batch_end(
self, trainer, pl_module, *args
):
"Update the stored parameters using a moving average"
# Update currently maintained parameters.
self.ema.update(pl_module)
def on_validation_epoch_start(self, trainer, pl_module, *args):
"do validation using the stored parameters"
# save original parameters before replacing with EMA version
self.store(pl_module.parameters())
# update the LightningModule with the EMA weights
# ~ Copy EMA parameters to LightningModule
self.copy_to(self.ema.module.parameters(), pl_module.parameters())
def on_validation_end(self, trainer, pl_module, *args):
"Restore original parameters to resume training later"
self.restore(pl_module.parameters())
def on_train_end(self, trainer, pl_module, *args):
# update the LightningModule with the EMA weights
if self.use_ema_weights:
self.copy_to(self.ema.module.parameters(), pl_module.parameters())
msg = "Model weights replaced with the EMA version."
def on_save_checkpoint(self, trainer, pl_module, checkpoint, *args):
if self.ema is not None:
return {"state_dict_ema": get_state_dict(self.ema, unwrap_model)}
def on_load_checkpoint(self, callback_state, *args):
if self.ema is not None:
self.ema.module.load_state_dict(callback_state["state_dict_ema"])
def store(self, parameters):
"Save the current parameters for restoring later."
self.collected_params = [param.clone() for param in parameters]
def restore(self, parameters):
"""
Restore the parameters stored with the `store` method.
Useful to validate the model with EMA parameters without affecting the
original optimization process.
"""
for c_param, param in zip(self.collected_params, parameters):
param.data.copy_(c_param.data)
def copy_to(self, shadow_parameters, parameters):
"Copy current parameters into given collection of parameters."
for s_param, param in zip(shadow_parameters, parameters):
if param.requires_grad:
param.data.copy_(s_param.data)