Spaces:
Runtime error
Runtime error
from pytorch_lightning.callbacks import Callback | |
from timm.utils.model import get_state_dict, unwrap_model | |
from timm.utils.model_ema import ModelEmaV2 | |
# Cell | |
class EMACallback(Callback):
    """
    Maintain an Exponential Moving Average (EMA) of the model weights.

    Empirically, the moving average of the trained parameters of a deep
    network often performs better than the final trained parameters.

    During validation the LightningModule's trainable parameters are
    temporarily replaced by the EMA parameters and restored afterwards.
    If `use_ema_weights` is true, the EMA parameters are copied into the
    LightningModule when training ends.

    Args:
        decay: decay factor forwarded to timm's `ModelEmaV2`.
        use_ema_weights: whether to overwrite the module's weights with the
            EMA weights at the end of training.
    """

    def __init__(self, decay=0.9999, use_ema_weights: bool = True):
        super().__init__()
        self.decay = decay
        self.ema = None
        self.use_ema_weights = use_ema_weights
        # Snapshot of the training weights taken by `store`; None until then.
        self.collected_params = None
        # EMA state received from a checkpoint before `self.ema` existed.
        self._pending_ema_state = None

    def on_fit_start(self, trainer, pl_module, *args):
        "Initialize `ModelEmaV2` from timm to keep a copy of the moving average of the weights"
        self.ema = ModelEmaV2(pl_module, decay=self.decay, device=None)
        # A checkpoint may have been loaded before the EMA wrapper was
        # created (hook order depends on the Lightning version - TODO confirm
        # for the version in use); apply the deferred state now.
        if self._pending_ema_state is not None:
            self.ema.module.load_state_dict(self._pending_ema_state)
            self._pending_ema_state = None

    def on_train_batch_end(self, trainer, pl_module, *args):
        "Update the stored parameters using a moving average"
        self.ema.update(pl_module)

    def on_validation_epoch_start(self, trainer, pl_module, *args):
        "Do validation using the stored (EMA) parameters"
        if self.ema is None:
            # No EMA was built (e.g. `on_fit_start` never ran), so there is
            # nothing to swap in; validate with the current weights.
            return
        # Save the original parameters before replacing them with the EMA version.
        self.store(pl_module.parameters())
        # Copy the EMA parameters into the LightningModule.
        self.copy_to(self.ema.module.parameters(), pl_module.parameters())

    def on_validation_end(self, trainer, pl_module, *args):
        "Restore original parameters to resume training later"
        if self.collected_params is None:
            # Nothing was stored, so the weights were never swapped.
            return
        self.restore(pl_module.parameters())

    def on_train_end(self, trainer, pl_module, *args):
        "Optionally replace the LightningModule weights with the EMA weights"
        if self.use_ema_weights and self.ema is not None:
            import logging

            self.copy_to(self.ema.module.parameters(), pl_module.parameters())
            msg = "Model weights replaced with the EMA version."
            logging.getLogger(__name__).info(msg)

    def on_save_checkpoint(self, trainer, pl_module, checkpoint, *args):
        "Return the EMA weights so they are saved with the callback state"
        if self.ema is not None:
            return {"state_dict_ema": get_state_dict(self.ema, unwrap_model)}

    def on_load_checkpoint(self, callback_state, *args):
        "Load the EMA weights saved by `on_save_checkpoint`"
        if self.ema is not None:
            self.ema.module.load_state_dict(callback_state["state_dict_ema"])
            return
        # The EMA wrapper does not exist yet: keep the state so that
        # `on_fit_start` can apply it instead of silently dropping it.
        if isinstance(callback_state, dict) and "state_dict_ema" in callback_state:
            self._pending_ema_state = callback_state["state_dict_ema"]

    def store(self, parameters):
        "Save the current parameters for restoring later."
        # Detach so the snapshot is plain data and is not tracked by autograd.
        self.collected_params = [param.detach().clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.

        Useful to validate the model with EMA parameters without affecting the
        original optimization process.

        Raises:
            RuntimeError: if `store` has not been called beforehand.
        """
        if self.collected_params is None:
            raise RuntimeError("`restore` called before `store`: no parameters saved.")
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)

    def copy_to(self, shadow_parameters, parameters):
        "Copy the shadow parameters into the given parameters (trainable ones only)."
        for s_param, param in zip(shadow_parameters, parameters):
            # Parameters with `requires_grad=False` keep their current values.
            if param.requires_grad:
                param.data.copy_(s_param.data)