import copy
import itertools
import logging
from typing import Any, Dict

import lightning.pytorch as pl
from lightning import Callback
from lightning.pytorch.utilities.types import STEP_OUTPUT
from mmengine.logging import print_log
from mmengine.model import is_model_wrapper

from mmpl.registry import HOOKS, MODELS


@HOOKS.register_module()
class EMAHook(Callback):
    """A Hook to apply Exponential Moving Average (EMA) on the model during
    training.

    Note:
        - EMAHook takes priority over CheckpointHook.
        - The original model parameters are actually saved in the ema field
          after training.
        - ``begin_iter`` and ``begin_epoch`` cannot be set at the same time.

    Args:
        ema_type (str): The type of EMA strategy to use. You can find the
            supported strategies in :mod:`mmengine.model.averaged_model`.
            Defaults to 'ExponentialMovingAverage'.
        strict_load (bool): Whether to strictly enforce that the keys of
            ``state_dict`` in the checkpoint match the keys returned by
            ``self.module.state_dict``. Defaults to False.
            Changed in v0.3.0.
        begin_iter (int): The iteration at which to enable ``EMAHook``.
            Defaults to 0.
        begin_epoch (int): The epoch at which to enable ``EMAHook``.
            Defaults to 0.
        **kwargs: Keyword arguments passed to subclasses of
            :obj:`BaseAveragedModel`.
    """

    priority = 'NORMAL'

    def __init__(self,
                 ema_type: str = 'ExponentialMovingAverage',
                 strict_load: bool = False,
                 begin_iter: int = 0,
                 begin_epoch: int = 0,
                 **kwargs):
        self.strict_load = strict_load
        self.ema_cfg = dict(type=ema_type, **kwargs)
        assert not (begin_iter != 0 and begin_epoch != 0), (
            '`begin_iter` and `begin_epoch` should not be both set.')
        assert begin_iter >= 0, (
            '`begin_iter` must be greater than or equal to 0, '
            f'but got begin_iter: {begin_iter}')
        assert begin_epoch >= 0, (
            '`begin_epoch` must be greater than or equal to 0, '
            f'but got begin_epoch: {begin_epoch}')
        self.begin_iter = begin_iter
        self.begin_epoch = begin_epoch
        # If neither `begin_epoch` nor `begin_iter` is set, `EMAHook` is
        # enabled from iteration 0.
        self.enabled_by_epoch = self.begin_epoch > 0

    def on_fit_start(self, trainer: "pl.Trainer",
                     pl_module: "pl.LightningModule") -> None:
        """Create an EMA copy of the model.

        Args:
            trainer (pl.Trainer): The Lightning trainer.
            pl_module (pl.LightningModule): The module being trained.
        """
        model = pl_module
        if is_model_wrapper(model):
            model = model.module
        self.src_model = model
        self.ema_model = MODELS.build(
            self.ema_cfg, default_args=dict(model=self.src_model))

    def on_train_start(self, trainer: "pl.Trainer",
                       pl_module: "pl.LightningModule") -> None:
        """Check that begin_epoch/begin_iter is not larger than
        max_epochs/max_steps.

        Args:
            trainer (pl.Trainer): The Lightning trainer.
            pl_module (pl.LightningModule): The module being trained.
        """
        if self.enabled_by_epoch:
            assert self.begin_epoch <= trainer.max_epochs, (
                'self.begin_epoch should be smaller than or equal to '
                f'trainer.max_epochs: {trainer.max_epochs}, but got '
                f'begin_epoch: {self.begin_epoch}')
        else:
            assert (self.begin_iter <= trainer.max_steps
                    or self.begin_iter <= trainer.max_epochs *
                    len(trainer.train_dataloader)), (
                        'self.begin_iter should be smaller than or equal to '
                        f'trainer.max_steps: {trainer.max_steps}, but got '
                        f'begin_iter: {self.begin_iter}')

    def on_train_batch_end(self, trainer: "pl.Trainer",
                           pl_module: "pl.LightningModule",
                           outputs: STEP_OUTPUT, batch: Any,
                           batch_idx: int) -> None:
        """Update EMA parameters.

        Args:
            trainer (pl.Trainer): The Lightning trainer.
            pl_module (pl.LightningModule): The module being trained.
            outputs (STEP_OUTPUT): Outputs from the training step.
            batch (Any): Data from the dataloader.
            batch_idx (int): The index of the current batch in the train
                loop.
""" if self._ema_started(trainer): self.ema_model.update_parameters(self.src_model) else: ema_params = self.ema_model.module.state_dict() src_params = self.src_model.state_dict() for k, p in ema_params.items(): p.data.copy_(src_params[k].data) def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """We load parameter values from ema model to source model before validation. Args: runner (Runner): The runner of the training process. """ self._swap_ema_parameters() def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """We recover source model's parameter from ema model after validation. Args: runner (Runner): The runner of the validation process. metrics (Dict[str, float], optional): Evaluation results of all metrics on validation dataset. The keys are the names of the metrics, and the values are corresponding results. """ self._swap_ema_parameters() def on_test_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """We load parameter values from ema model to source model before test. Args: runner (Runner): The runner of the training process. """ self._swap_ema_parameters() def on_test_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """We recover source model's parameter from ema model after test. Args: runner (Runner): The runner of the testing process. metrics (Dict[str, float], optional): Evaluation results of all metrics on test dataset. The keys are the names of the metrics, and the values are corresponding results. """ self._swap_ema_parameters() def on_save_checkpoint( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] ) -> None: """Save ema parameters to checkpoint. Args: runner (Runner): The runner of the testing process. """ checkpoint['ema_state_dict'] = self.ema_model.state_dict() # Save ema parameters to the source model's state dict so that we # can directly load the averaged model weights for deployment. # Swapping the state_dict key-values instead of swapping model # parameters because the state_dict is a shallow copy of model # parameters. self._swap_ema_state_dict(checkpoint) def on_load_checkpoint( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] ) -> None: """Resume ema parameters from checkpoint. Args: runner (Runner): The runner of the testing process. """ from mmengine.runner.checkpoint import load_state_dict if 'ema_state_dict' in checkpoint and not trainer._checkpoint_connector._loaded_checkpoint: # The original model parameters are actually saved in ema # field swap the weights back to resume ema state. self._swap_ema_state_dict(checkpoint) self.ema_model.load_state_dict( checkpoint['ema_state_dict'], strict=self.strict_load) # Support load checkpoint without ema state dict. else: if not trainer._checkpoint_connector._loaded_checkpoint: print_log( 'There is no `ema_state_dict` in checkpoint. 
                    '`EMAHook` will make a copy of `state_dict` as the '
                    'initial `ema_state_dict`', 'current', logging.WARNING)
            load_state_dict(
                self.ema_model.module,
                copy.deepcopy(checkpoint['state_dict']),
                strict=self.strict_load)

    def _swap_ema_parameters(self) -> None:
        """Swap the parameters of the source model with those of the EMA
        model."""
        avg_param = (
            itertools.chain(self.ema_model.module.parameters(),
                            self.ema_model.module.buffers())
            if self.ema_model.update_buffers else
            self.ema_model.module.parameters())
        src_param = (
            itertools.chain(self.src_model.parameters(),
                            self.src_model.buffers())
            if self.ema_model.update_buffers else
            self.src_model.parameters())
        for p_avg, p_src in zip(avg_param, src_param):
            tmp = p_avg.data.clone()
            p_avg.data.copy_(p_src.data)
            p_src.data.copy_(tmp)

    def _swap_ema_state_dict(self, checkpoint):
        """Swap the state dict values of the source model with those of the
        EMA model."""
        model_state = checkpoint['state_dict']
        ema_state = checkpoint['ema_state_dict']
        for k in ema_state:
            if k.startswith('module.'):
                tmp = ema_state[k]
                ema_state[k] = model_state[k[7:]]
                model_state[k[7:]] = tmp

    def _ema_started(self, trainer) -> bool:
        """Whether ``EMAHook`` has started at the current iteration or epoch.

        :attr:`ema_model` starts being updated once ``trainer.global_step``
        or ``trainer.current_epoch`` reaches ``begin_iter`` or
        ``begin_epoch`` for the first time.

        Args:
            trainer (pl.Trainer): The Lightning trainer.

        Returns:
            bool: Whether ``EMAHook`` has started.
        """
        if self.enabled_by_epoch:
            return trainer.current_epoch + 1 >= self.begin_epoch
        else:
            return trainer.global_step + 1 >= self.begin_iter
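

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the hook). It shows two ways
# the callback is typically attached: built from a config dict through the
# `HOOKS` registry, as mmpl-style configs would do, or passed directly to the
# Lightning Trainer. The `momentum` value, `begin_epoch`, and the `model` /
# `datamodule` names in the commented Trainer variant are hypothetical
# examples, not defaults mandated by this file.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    # Config-style construction: the dict's `type` key is resolved against
    # the `HOOKS` registry; extra keys are forwarded to `EMAHook.__init__`,
    # and `momentum` is passed on to the EMA model via `**kwargs`.
    ema_hook = HOOKS.build(
        dict(
            type='EMAHook',
            ema_type='ExponentialMovingAverage',
            momentum=0.0002,
            begin_epoch=1))
    print(type(ema_hook).__name__)

    # Direct construction, handing the callback to a Lightning Trainer
    # alongside a LightningModule/DataModule defined elsewhere, e.g.:
    #   trainer = pl.Trainer(max_epochs=12, callbacks=[EMAHook(begin_epoch=1)])
    #   trainer.fit(model, datamodule=datamodule)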