import logging

from pytorch_lightning.callbacks import Callback
from timm.utils.model import get_state_dict, unwrap_model
from timm.utils.model_ema import ModelEmaV2

logger = logging.getLogger(__name__)


# Cell
class EMACallback(Callback):
    """
    Model Exponential Moving Average.

    Empirically it has been found that using the moving average of the trained
    parameters of a deep network is better than using its trained parameters
    directly.

    A timm `ModelEmaV2` copy of the LightningModule is updated after every
    training batch. At the start of each validation epoch the EMA parameters
    are copied into the LightningModule; the live parameters are restored when
    validation ends. If `use_ema_weights`, the EMA parameters are copied into
    the network when training ends.

    Args:
        decay: decay factor forwarded to `ModelEmaV2`.
        use_ema_weights: if True, overwrite the model's trainable parameters
            with their EMA version in `on_train_end`.
    """

    def __init__(self, decay=0.9999, use_ema_weights: bool = True):
        super().__init__()
        self.decay = decay
        # Created in `on_fit_start`; stays None for runs that never call it
        # (e.g. a validate-only run).
        self.ema = None
        self.use_ema_weights = use_ema_weights
        # Snapshot of the live parameters taken before validation swaps in the
        # EMA weights; None when no snapshot is waiting to be restored.
        self.collected_params = None

    def on_fit_start(self, trainer, pl_module, *args):
        "Initialize `ModelEmaV2` from timm to keep a copy of the moving average of the weights"
        self.ema = ModelEmaV2(pl_module, decay=self.decay, device=None)

    def on_train_batch_end(self, trainer, pl_module, *args):
        "Update the stored parameters using a moving average"
        self.ema.update(pl_module)

    def on_validation_epoch_start(self, trainer, pl_module, *args):
        "Do validation using the EMA parameters."
        # Without `on_fit_start` there are no EMA weights to swap in, so
        # validate with the live weights instead of crashing on `None.module`.
        if self.ema is None:
            return
        # Save the live parameters before replacing them with the EMA version.
        self.store(pl_module.parameters())
        self.copy_to(self.ema.module.parameters(), pl_module.parameters())

    def on_validation_end(self, trainer, pl_module, *args):
        "Restore original parameters to resume training later"
        # Only restore if `on_validation_epoch_start` actually took a snapshot.
        if self.collected_params is None:
            return
        self.restore(pl_module.parameters())
        # Drop the model-sized snapshot so it is not held between validations.
        self.collected_params = None

    def on_train_end(self, trainer, pl_module, *args):
        "Optionally replace the model's parameters with their EMA version."
        if self.use_ema_weights and self.ema is not None:
            self.copy_to(self.ema.module.parameters(), pl_module.parameters())
            logger.info("Model weights replaced with the EMA version.")

    def on_save_checkpoint(self, trainer, pl_module, checkpoint, *args):
        "Return the EMA model's state dict (under `state_dict_ema`) to be saved as callback state."
        if self.ema is not None:
            return {"state_dict_ema": get_state_dict(self.ema, unwrap_model)}

    def on_load_checkpoint(self, callback_state, *args):
        "Load the EMA weights saved by `on_save_checkpoint`, if the EMA model exists."
        # NOTE(review): the state is ignored when `self.ema` is still None at
        # load time; confirm hook ordering for the Lightning version in use.
        if self.ema is not None:
            self.ema.module.load_state_dict(callback_state["state_dict_ema"])

    def store(self, parameters):
        "Save the current parameters for restoring later."
        # detach() so the snapshot is a plain copy, not tracked by autograd.
        self.collected_params = [param.detach().clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)

    def copy_to(self, shadow_parameters, parameters):
        """
        Copy `shadow_parameters` into `parameters`, pairwise and in order.

        Only parameters with `requires_grad` are overwritten; frozen ones are left as-is.
        """
        for s_param, param in zip(shadow_parameters, parameters):
            if param.requires_grad:
                param.data.copy_(s_param.data)