# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from contextlib import contextmanager
from typing import Union

import torch
import torch.nn as nn

from mmengine.device import (is_cuda_available, is_mlu_available,
                             is_npu_available)
from mmengine.registry import OPTIM_WRAPPERS
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
from .optimizer_wrapper import OptimWrapper

if is_npu_available():
    from torch.npu.amp import GradScaler
elif is_mlu_available():
    from torch.mlu.amp import GradScaler
else:
    from torch.cuda.amp import GradScaler


@OPTIM_WRAPPERS.register_module()
class AmpOptimWrapper(OptimWrapper):
    """A subclass of :class:`OptimWrapper` that supports automatic mixed
    precision training based on torch.cuda.amp.

    ``AmpOptimWrapper`` provides a unified interface with ``OptimWrapper``,
    so ``AmpOptimWrapper`` can be used in the same way as ``OptimWrapper``.

    Warnings:
        ``AmpOptimWrapper`` requires PyTorch >= 1.6.

    Args:
        loss_scale (float or str or dict): The initial configuration of
            `torch.cuda.amp.GradScaler`. See more specific arguments
            introduction at `PyTorch AMP <https://pytorch.org/docs/stable/amp.html>`_.  # noqa: E501
            Defaults to ``dynamic``.

            - "dynamic": Initialize GradScaler without any arguments.
            - float: Initialize GradScaler with ``init_scale``.
            - dict: Initialize GradScaler with a more detailed configuration.

        dtype (str or torch.dtype, optional): The data type to autocast in
            amp. If a ``str`` is given, it will be converted to
            ``torch.dtype``. Valid ``str`` formats are `'float16'`,
            `'bfloat16'`, `'float32'` and `'float64'`. If set to ``None``,
            the default data type will be used. Defaults to None.
            `New in version 0.6.1.`
        use_fsdp (bool): Use ``ShardedGradScaler`` when it is True. It should
            be enabled when using ``FullyShardedDataParallel``.
            Defaults to False.
            `New in version 0.8.0.`
        **kwargs: Keyword arguments passed to OptimWrapper.

    Warnings:
        ``dtype`` argument is only available with PyTorch version >= 1.10.0.
        If you use PyTorch of an older version, it will be ignored.

    Note:
        If you use ``IterBasedRunner`` and enable gradient accumulation,
        the original `max_iters` should be multiplied by
        ``accumulative_counts``.
    """

    valid_dtypes = ('float16', 'bfloat16', 'float32', 'float64')

    def __init__(self,
                 loss_scale: str = 'dynamic',
                 dtype: Union[str, torch.dtype] = None,
                 use_fsdp: bool = False,
                 **kwargs):
        assert digit_version(TORCH_VERSION) >= digit_version('1.6.0'), (
            '`torch.cuda.amp` is only available when pytorch version >= 1.6')
        assert is_cuda_available() or is_npu_available() or is_mlu_available(
        ), ('``AmpOptimWrapper`` is only available for training '
            'on gpu, npu or mlu')
        super().__init__(**kwargs)
        self._scale_update_param = None

        if use_fsdp:
            if digit_version(torch.__version__) >= digit_version('2.0.0'):
                from torch.distributed.fsdp.sharded_grad_scaler import \
                    ShardedGradScaler
                scaler_type = ShardedGradScaler
            else:
                raise RuntimeError(
                    'PyTorch>=2.0.0 is required when `use_fsdp=True` is set')
        else:
            scaler_type = GradScaler

        if loss_scale == 'dynamic':
            # If loss_scale is a string, it must be 'dynamic', then dynamic
            # loss scaling will be used.
            self.loss_scaler = scaler_type()
        elif isinstance(loss_scale, float):
            # Static loss scaling
            self._scale_update_param = loss_scale
            self.loss_scaler = scaler_type(init_scale=loss_scale)
        elif isinstance(loss_scale, dict):
            # More specific configuration.
            self.loss_scaler = scaler_type(**loss_scale)
        else:
            raise TypeError('loss_scale must be of type float, dict, or '
                            f'"dynamic", but got {loss_scale}')

        # convert string value to torch.dtype
        if isinstance(dtype, str):
            assert dtype in self.valid_dtypes, (
                f'dtype should be any of {self.valid_dtypes}, got {dtype}')
            dtype = getattr(torch, dtype)

        assert dtype is None or isinstance(dtype, torch.dtype), (
            f'dtype should be None or an instance of torch.dtype, '
            f'got {dtype}')
        self.cast_dtype = dtype

    def backward(self, loss: torch.Tensor, **kwargs):
        """Perform gradient back propagation with :attr:`loss_scaler`.

        Args:
            loss (torch.Tensor): The loss of current iteration.
            kwargs: Keyword arguments passed to :meth:`torch.Tensor.backward`
        """
        self.loss_scaler.scale(loss).backward(**kwargs)
        self._inner_count += 1

    def step(self, **kwargs):
        """Update parameters with :attr:`loss_scaler`.

        Args:
            kwargs: Keyword arguments passed to
                :meth:`torch.optim.Optimizer.step`.
        """
        # Zero out NaN and Inf gradient entries so that a few corrupted
        # values do not poison the optimizer state on this step.
        params = [
            p for pg in self.optimizer.param_groups for p in pg['params']
        ]
        for p in params:
            if hasattr(p, 'grad') and p.grad is not None:
                p.grad.data[torch.isnan(p.grad.data)] = 0
                p.grad.data[torch.isinf(p.grad.data)] = 0

        if self.clip_grad_kwargs:
            self.loss_scaler.unscale_(self.optimizer)
            self._clip_grad()
        self.loss_scaler.step(self.optimizer, **kwargs)
        self.loss_scaler.update(self._scale_update_param)

    def state_dict(self) -> dict:
        """Get the state dictionary of :attr:`optimizer` and
        :attr:`loss_scaler`.

        Based on the state dictionary of the optimizer, the returned state
        dictionary will add a key named "loss_scaler".

        Returns:
            dict: The merged state dict of :attr:`loss_scaler` and
            :attr:`optimizer`.
        """
        # save state_dict of loss_scaler
        state_dict = super().state_dict()
        state_dict['loss_scaler'] = self.loss_scaler.state_dict()
        return state_dict

    def load_state_dict(self, state_dict: dict):
        """Load and parse the state dictionary of :attr:`optimizer` and
        :attr:`loss_scaler`.

        If state_dict contains a "loss_scaler" key, the :attr:`loss_scaler`
        will load the corresponding entries. Otherwise, only the
        :attr:`optimizer` will load the state dictionary.

        Args:
            state_dict (dict): The state dict of :attr:`optimizer` and
                :attr:`loss_scaler`.
        """
        if 'loss_scaler' in state_dict:
            self.loss_scaler.load_state_dict(state_dict.pop('loss_scaler'))

        if 'base_param_settings' in state_dict:
            self.base_param_settings = state_dict.pop('base_param_settings')

        # load state_dict of optimizer
        self.optimizer.load_state_dict(state_dict)

    @contextmanager
    def optim_context(self, model: nn.Module):
        """Enable the context for mixed precision training and the context
        for disabling gradient synchronization during gradient accumulation.

        Args:
            model (nn.Module): The training model.
        """
        from mmengine.runner.amp import autocast
        with super().optim_context(model), autocast(dtype=self.cast_dtype):
            yield
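

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library API). It shows how
# the wrapper is typically built and driven for one AMP training step. The
# toy model, optimizer, inputs and loss below are assumptions made for this
# example; ``AmpOptimWrapper``, ``optim_context`` and ``update_params`` are
# the MMEngine interfaces used by this module.
#
#   from torch.optim import SGD
#
#   model = nn.Linear(4, 2).cuda()
#   optimizer = SGD(model.parameters(), lr=0.01)
#
#   # ``loss_scale`` may also be a float (static scale) or a dict of
#   # GradScaler keyword arguments, e.g. dict(init_scale=512., growth_factor=2.)
#   optim_wrapper = AmpOptimWrapper(
#       optimizer=optimizer, loss_scale='dynamic', dtype='float16')
#
#   inputs = torch.randn(8, 4).cuda()
#   with optim_wrapper.optim_context(model):  # enables autocast
#       loss = model(inputs).sum()
#   optim_wrapper.update_params(loss)  # scaled backward + step + zero_grad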