import copy
import itertools
import logging
from typing import Any, Dict

import lightning.pytorch as pl
from lightning import Callback
from lightning.pytorch.utilities.types import STEP_OUTPUT
from mmengine.logging import print_log
from mmengine.model import is_model_wrapper
from mmpl.registry import HOOKS, MODELS
@HOOKS.register_module()
class EMAHook(Callback):
"""A Hook to apply Exponential Moving Average (EMA) on the model during
training.
Note:
- EMAHook takes priority over CheckpointHook.
- The original model parameters are actually saved in ema field after
train.
- ``begin_iter`` and ``begin_epoch`` cannot be set at the same time.
Args:
ema_type (str): The type of EMA strategy to use. You can find the
supported strategies in :mod:`mmengine.model.averaged_model`.
Defaults to 'ExponentialMovingAverage'.
strict_load (bool): Whether to strictly enforce that the keys of
``state_dict`` in checkpoint match the keys returned by
``self.module.state_dict``. Defaults to False.
Changed in v0.3.0.
begin_iter (int): The number of iteration to enable ``EMAHook``.
Defaults to 0.
begin_epoch (int): The number of epoch to enable ``EMAHook``.
Defaults to 0.
**kwargs: Keyword arguments passed to subclasses of
:obj:`BaseAveragedModel`
"""
priority = 'NORMAL'
def __init__(self,
ema_type: str = 'ExponentialMovingAverage',
strict_load: bool = False,
begin_iter: int = 0,
begin_epoch: int = 0,
**kwargs):
self.strict_load = strict_load
self.ema_cfg = dict(type=ema_type, **kwargs)
        assert not (begin_iter != 0 and begin_epoch != 0), (
            '`begin_iter` and `begin_epoch` should not both be set.')
        assert begin_iter >= 0, (
            '`begin_iter` must be larger than or equal to 0, '
            f'but got begin_iter: {begin_iter}')
        assert begin_epoch >= 0, (
            '`begin_epoch` must be larger than or equal to 0, '
            f'but got begin_epoch: {begin_epoch}')
self.begin_iter = begin_iter
self.begin_epoch = begin_epoch
        # If neither `begin_epoch` nor `begin_iter` is set, `EMAHook` is
        # enabled from iteration 0.
self.enabled_by_epoch = self.begin_epoch > 0
def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Create an ema copy of the model.
Args:
runner (Runner): The runner of the training process.
"""
model = pl_module
if is_model_wrapper(model):
model = model.module
self.src_model = model
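        # Build the averaged copy from the registry. The EMA model starts as
        # a clone of the source weights; for the default
        # ExponentialMovingAverage each update blends in the source
        # parameters, roughly ema = (1 - momentum) * ema + momentum * src
        # (a description of the mmengine averaging strategy, added here for
        # context).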
self.ema_model = MODELS.build(
self.ema_cfg, default_args=dict(model=self.src_model))
def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Check the begin_epoch/iter is smaller than max_epochs/iters.
Args:
runner (Runner): The runner of the training process.
"""
if self.enabled_by_epoch:
            assert self.begin_epoch <= trainer.max_epochs, (
                'self.begin_epoch should be smaller than or equal to '
                f'trainer.max_epochs: {trainer.max_epochs}, but got '
                f'begin_epoch: {self.begin_epoch}')
else:
            assert (self.begin_iter <= trainer.max_steps or self.begin_iter <=
                    trainer.max_epochs * len(trainer.train_dataloader)), (
                'self.begin_iter should be smaller than or equal to the '
                'total number of training iterations, but got '
                f'begin_iter: {self.begin_iter}')
def on_train_batch_end(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
) -> None:
"""Update ema parameter.
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the train loop.
data_batch (Sequence[dict], optional): Data from dataloader.
Defaults to None.
outputs (dict, optional): Outputs from model. Defaults to None.
"""
if self._ema_started(trainer):
self.ema_model.update_parameters(self.src_model)
else:
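            # Before the EMA start point, mirror the source weights into the
            # EMA copy so that, once averaging begins, it starts from the
            # current model state rather than from stale initial weights.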
ema_params = self.ema_model.module.state_dict()
src_params = self.src_model.state_dict()
for k, p in ema_params.items():
p.data.copy_(src_params[k].data)
def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""We load parameter values from ema model to source model before
validation.
Args:
runner (Runner): The runner of the training process.
"""
self._swap_ema_parameters()
def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""We recover source model's parameter from ema model after validation.
Args:
runner (Runner): The runner of the validation process.
metrics (Dict[str, float], optional): Evaluation results of all
metrics on validation dataset. The keys are the names of the
metrics, and the values are corresponding results.
"""
self._swap_ema_parameters()
def on_test_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""We load parameter values from ema model to source model before test.
Args:
runner (Runner): The runner of the training process.
"""
self._swap_ema_parameters()
def on_test_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""We recover source model's parameter from ema model after test.
Args:
runner (Runner): The runner of the testing process.
metrics (Dict[str, float], optional): Evaluation results of all
metrics on test dataset. The keys are the names of the
metrics, and the values are corresponding results.
"""
self._swap_ema_parameters()
def on_save_checkpoint(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
) -> None:
"""Save ema parameters to checkpoint.
Args:
runner (Runner): The runner of the testing process.
"""
checkpoint['ema_state_dict'] = self.ema_model.state_dict()
# Save ema parameters to the source model's state dict so that we
# can directly load the averaged model weights for deployment.
# Swapping the state_dict key-values instead of swapping model
# parameters because the state_dict is a shallow copy of model
# parameters.
self._swap_ema_state_dict(checkpoint)
def on_load_checkpoint(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
) -> None:
"""Resume ema parameters from checkpoint.
Args:
runner (Runner): The runner of the testing process.
"""
from mmengine.runner.checkpoint import load_state_dict
if 'ema_state_dict' in checkpoint and not trainer._checkpoint_connector._loaded_checkpoint:
            # The original model parameters are actually saved in the ema
            # field; swap the weights back to resume the EMA state.
self._swap_ema_state_dict(checkpoint)
self.ema_model.load_state_dict(
checkpoint['ema_state_dict'], strict=self.strict_load)
        # Support loading a checkpoint without an ema state dict.
else:
if not trainer._checkpoint_connector._loaded_checkpoint:
print_log(
'There is no `ema_state_dict` in checkpoint. '
'`EMAHook` will make a copy of `state_dict` as the '
'initial `ema_state_dict`', 'current', logging.WARNING)
load_state_dict(
self.ema_model.module,
copy.deepcopy(checkpoint['state_dict']),
strict=self.strict_load)
def _swap_ema_parameters(self) -> None:
"""Swap the parameter of model with ema_model."""
avg_param = (
itertools.chain(self.ema_model.module.parameters(),
self.ema_model.module.buffers())
if self.ema_model.update_buffers else
self.ema_model.module.parameters())
src_param = (
itertools.chain(self.src_model.parameters(),
self.src_model.buffers())
if self.ema_model.update_buffers else self.src_model.parameters())
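        # The exchange happens in place; calling this method a second time
        # (after validation/testing) restores the original assignment of
        # weights to the two models.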
for p_avg, p_src in zip(avg_param, src_param):
tmp = p_avg.data.clone()
p_avg.data.copy_(p_src.data)
p_src.data.copy_(tmp)
def _swap_ema_state_dict(self, checkpoint):
"""Swap the state dict values of model with ema_model."""
model_state = checkpoint['state_dict']
ema_state = checkpoint['ema_state_dict']
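        # Keys in the EMA state dict carry a 'module.' prefix because the
        # averaged model wraps the source model, so stripping the first seven
        # characters maps each EMA entry onto the matching key in the plain
        # model state dict.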
for k in ema_state:
if k[:7] == 'module.':
tmp = ema_state[k]
ema_state[k] = model_state[k[7:]]
model_state[k[7:]] = tmp
def _ema_started(self, trainer) -> bool:
"""Whether ``EMAHook`` has been initialized at current iteration or
epoch.
:attr:`ema_model` will be initialized when ``runner.iter`` or
``runner.epoch`` is greater than ``self.begin`` for the first time.
Args:
runner (Runner): Runner of the training, validation process.
Returns:
bool: Whether ``EMAHook`` has been initialized.
"""
if self.enabled_by_epoch:
return trainer.current_epoch + 1 >= self.begin_epoch
else:
return trainer.global_step + 1 >= self.begin_iter
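

# Usage sketch (an illustration, not part of the original training entry
# point): the hook is registered in the mmpl ``HOOKS`` registry, so it would
# typically be built from a config dict and handed to the Lightning trainer's
# callback list, e.g.
#
#   ema_hook = HOOKS.build(dict(type='EMAHook', momentum=0.0002, begin_epoch=1))
#   trainer = pl.Trainer(max_epochs=12, callbacks=[ema_hook])
#
# ``momentum`` is forwarded through ``**kwargs`` to the averaged-model class
# (``ExponentialMovingAverage`` by default); the exact trainer arguments
# depend on the surrounding mmpl configuration.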