# Meant to work with Apex's DistributedFusedAdam
from typing import Any, Callable, Dict, List, Optional, Union
from pathlib import Path
import types

import torch
from torch.optim.optimizer import Optimizer
from torch.optim import LBFGS

from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam

import pytorch_lightning as pl
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.plugins.precision import PrecisionPlugin, NativeMixedPrecisionPlugin
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:  # pytorch_lightning <= 1.7
    from pytorch_lightning.utilities.types import _PATH
except ImportError:  # pytorch_lightning >= 1.8
    try:
        from lightning_lite.utilities.types import _PATH
    except ImportError:  # pytorch_lightning >= 1.9
        from lightning_fabric.utilities.types import _PATH


class DistAdamNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):

    def optimizer_step(  # type: ignore[override]
        self,
        model: "pl.LightningModule",
        optimizer,
        optimizer_idx: int,
        closure: Callable[[], Any],
        **kwargs: Any,
    ) -> Any:
        if self.scaler is None:
            # skip scaler logic, as bfloat16 does not require a scaler
            return NativeMixedPrecisionPlugin.optimizer_step(
                self, optimizer, model=model, optimizer_idx=optimizer_idx, closure=closure, **kwargs
            )
        if isinstance(optimizer, LBFGS):
            raise MisconfigurationException(
                f"Native AMP and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})."
            )
        closure_result = closure()
        # HACK: we don't call self.scaler.unscale_ here, because the DistributedFusedAdam
        # optimizer already takes the loss scale into account internally.
        # Calling unscale_ here would be equivalent to unscaling the gradients twice.
        # Side effect of not unscaling: the NormMonitor callback will report a gradient
        # norm that is much larger than the true (unscaled) norm.
        # The base plugin would `unscale_` after the closure is executed but before the
        # `on_before_optimizer_step` hook:
        # self.scaler.unscale_(optimizer)
        # This will call gradient clipping
        self._after_closure(model, optimizer, optimizer_idx)
        skipped_backward = closure_result is None
        # in manual optimization, the closure does not return a value
        if not model.automatic_optimization or not skipped_backward:
            # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
            step_output = self.scaler.step(optimizer, **kwargs)
            self.scaler.update()
            return step_output
        return closure_result

    def clip_grad_by_norm(self, optimizer: DistributedFusedAdam, clip_val: Union[int, float]) -> None:
        """Clip gradients by norm."""
        # Gradients have not been unscaled, so we need to scale up the clip_val accordingly
        if self.scaler is not None:
            clip_val *= self.scaler.get_scale()
        return optimizer.clip_grad_norm(clip_val)
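
    # Note (sketch of the reasoning): with loss scale S, the gradients held by the optimizer
    # are S * g, so clipping them at S * clip_val is equivalent to clipping the true
    # gradients g at clip_val, since ||S * g|| <= S * clip_val  <=>  ||g|| <= clip_val.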


class DDPStrategyZero2(DDPStrategy):
    """To use Apex's DistributedFusedAdam, we need to shard the optimizer states when
    saving/loading checkpoints.
    """

    strategy_name = "ddp_zero2"

    def __init__(
        self,
        *args,
        # The precision plugin assigned by Lightning is patched by the setter below,
        # so no special default is needed here.
        precision_plugin: Optional[PrecisionPlugin] = None,
        **kwargs: Union[Any, Dict[str, Any]],
    ) -> None:
        super().__init__(*args, precision_plugin=precision_plugin, **kwargs)

    @property
    def precision_plugin(self) -> PrecisionPlugin:
        return self._precision_plugin if self._precision_plugin is not None else PrecisionPlugin()

    @precision_plugin.setter
    def precision_plugin(self, precision_plugin: Optional[PrecisionPlugin]) -> None:
        self._precision_plugin = precision_plugin
        if precision_plugin is not None:
            # Patch the plugin instance so it uses the DistributedFusedAdam-aware
            # optimizer_step and clip_grad_by_norm defined above.
            # https://stackoverflow.com/questions/972/adding-a-method-to-an-existing-object-instance
            self._precision_plugin.optimizer_step = types.MethodType(
                DistAdamNativeMixedPrecisionPlugin.optimizer_step, self._precision_plugin
            )
            self._precision_plugin.clip_grad_by_norm = types.MethodType(
                DistAdamNativeMixedPrecisionPlugin.clip_grad_by_norm, self._precision_plugin
            )

    def optimizer_state(self, optimizer: Optimizer) -> Optional[dict]:
        if isinstance(optimizer, LightningOptimizer):
            optimizer = optimizer._optimizer
        if isinstance(optimizer, DistributedFusedAdam):
            # Keep the optimizer state sharded: each rank returns only its own shard
            # instead of gathering everything on rank 0.
            return optimizer.state_dict(gather_on_root=False)
        else:
            return optimizer.state_dict()

    def save_checkpoint(
        self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
    ) -> None:
        """Save model/training states as a checkpoint file through state-dump and file-write.

        Args:
            checkpoint: dict containing model and trainer state
            filepath: write-target file's path
            storage_options: parameter for how to save to storage, passed to ``CheckpointIO`` plugin
        """
        filepath = Path(filepath)
        filepath.mkdir(parents=True, exist_ok=True)
        local_optimizer_states = checkpoint.pop('optimizer_states')
        if self.is_global_zero:
            self.checkpoint_io.save_checkpoint(checkpoint, filepath / 'model_states.pt',
                                               storage_options=storage_options)
        self.checkpoint_io.save_checkpoint(local_optimizer_states,
                                           filepath / f'{self.global_rank:03d}_optim_states.pt',
                                           storage_options=storage_options)
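
    # Resulting on-disk layout (sketch; number of ranks illustrative):
    #   <filepath>/                  # `filepath` is created as a directory
    #       model_states.pt          # model / trainer state, written by rank 0 only
    #       000_optim_states.pt      # optimizer state shard written by rank 0
    #       001_optim_states.pt      # optimizer state shard written by rank 1
    #       ...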

    def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
        torch.cuda.empty_cache()
        checkpoint_path = Path(checkpoint_path)
        if checkpoint_path.is_file():
            return super().load_checkpoint(str(checkpoint_path))
        else:
            assert checkpoint_path.is_dir()
            global_states = self.checkpoint_io.load_checkpoint(checkpoint_path / 'model_states.pt')
            local_optimizer_states = self.checkpoint_io.load_checkpoint(
                checkpoint_path / f'{self.global_rank:03d}_optim_states.pt',
                map_location='cuda'
            )
            global_states['optimizer_states'] = local_optimizer_states
            return global_states
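

# Minimal usage sketch (not part of the original module): how this strategy might be wired
# into a Trainer. Assumes 1+ CUDA GPUs, Apex built with the distributed_adam extension, and
# a toy dataset; the model, data, and hyperparameters below are illustrative only.
if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset

    class ToyModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.mse_loss(self.layer(x), y)

        def configure_optimizers(self):
            # DistributedFusedAdam shards optimizer state across data-parallel ranks
            return DistributedFusedAdam(self.parameters(), lr=1e-3)

    dataset = TensorDataset(torch.randn(64, 32), torch.randn(64, 2))
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=-1,
        strategy=DDPStrategyZero2(),
        precision=16,  # fp16 exercises the GradScaler path handled by the patched plugin
        max_epochs=1,
    )
    trainer.fit(ToyModel(), DataLoader(dataset, batch_size=8))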