import functools
import pathlib
import shutil
import time
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

import torch
import torch.distributed.checkpoint
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
    set_model_state_dict,
)
from torch.distributed.checkpoint.stateful import Stateful

from ..logging import get_logger


if TYPE_CHECKING:
    from .. import optimizer


logger = get_logger()
class ModelWrapper(Stateful):
    """Wraps one or more model parts so DCP can save/load their state dicts as a single flat mapping."""

    def __init__(self, model: Union[torch.nn.Module, List[torch.nn.Module]]) -> None:
        self.model = [model] if isinstance(model, torch.nn.Module) else model

    def state_dict(self) -> Dict[str, Any]:
        # Merge the state dicts of all model parts into one flat dict
        return {k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Each model part picks up the keys it owns; strict=False tolerates keys belonging to other parts
        func = functools.partial(
            set_model_state_dict,
            model_state_dict=state_dict,
            options=StateDictOptions(strict=False),
        )
        list(map(func, self.model))
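

# Illustrative sketch (not part of the library): ModelWrapper flattens the state dicts of all parts
# into one mapping, so parts are expected to expose non-overlapping parameter names (e.g. distinct
# pipeline stages). The modules below are hypothetical stand-ins:
#
#   stage0 = torch.nn.Sequential(torch.nn.Linear(8, 8))                        # keys: "0.weight", "0.bias"
#   stage1 = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Linear(8, 4))   # keys: "1.weight", "1.bias"
#   merged = ModelWrapper([stage0, stage1]).state_dict()                        # contains all four keys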
class PTDCheckpointManager:
    def __init__(
        self,
        dataloader: torch.utils.data.DataLoader,
        model_parts: List[torch.nn.Module],
        optimizers: "optimizer.OptimizerWrapper",
        schedulers: "optimizer.SchedulerWrapper",
        states: Dict[str, Any],
        checkpointing_steps: int,
        checkpointing_limit: Optional[int],
        output_dir: str,
        enable: bool = True,
        _callback_fn: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
        _prefix: str = "finetrainers_step",
    ) -> None:
        self.states = states
        self.states.update(
            {
                "model": ModelWrapper(model_parts),
                "optimizer": optimizers,
                "dataloader": dataloader,
            }
        )
        self.states.update(schedulers.get_lr_scheduler_state())

        self.checkpointing_steps = checkpointing_steps
        self.checkpointing_limit = checkpointing_limit
        self.output_dir = pathlib.Path(output_dir)
        self.enable = enable
        self._callback_fn = _callback_fn
        self._prefix = _prefix

        if self.enable:
            logger.info(f"Checkpointing enabled. Checkpoints will be stored in '{self.output_dir}'")
    def save(self, step: int = -1, force: bool = False, *, _device: torch.device, _is_main_process: bool) -> Optional[str]:
        if not self._should_checkpoint(step, force):
            return None

        checkpoint_dir = self._get_checkpoint_dir(step)
        begin_time = time.monotonic()
        torch.distributed.checkpoint.save(self.states, checkpoint_id=checkpoint_dir.as_posix())
        end_time = time.monotonic()

        logger.info(
            f"Saved checkpoint in {end_time - begin_time:.2f} seconds at step {step}. Directory: {checkpoint_dir}"
        )
        self._purge_stale_checkpoints()

        # Gather full, unsharded state dicts on CPU (rank 0 only) and hand them to the optional callback,
        # e.g. to export an inference-ready copy of the weights alongside the DCP checkpoint.
        state_dicts = [
            gather_state_dict_on_cpu_rank0(model, _device, is_main_process=_is_main_process)
            for model in self.states["model"].model
        ]
        if self._callback_fn is not None:
            list(map(self._callback_fn, state_dicts))

        return checkpoint_dir.as_posix()
    def load(self, step: int = -1) -> bool:
        if not self.enable:
            return False
        if not self.output_dir.exists():
            return False
        if step != -1 and not self._get_checkpoint_dir(step).exists():
            return False

        if step == -1:
            latest_checkpoint_dir = self._find_latest_checkpoint_dir()
            if latest_checkpoint_dir is None:
                return False
            step = int(latest_checkpoint_dir.name.split("_")[-1])

        checkpoint_dir = self._get_checkpoint_dir(step)
        logger.info(f"Loading checkpoint from '{checkpoint_dir}' at step {step}")

        # For step 0, optimizers/schedulers are not available as they are created during training after first step
        states = {"model": self.states["model"]} if step == 0 else self.states

        # See bug: https://github.com/pytorch/pytorch/pull/138575
        original_stateful_states = {k: v for k, v in states.items() if isinstance(v, Stateful)}

        begin_time = time.monotonic()
        torch.distributed.checkpoint.load(states, checkpoint_id=checkpoint_dir.as_posix())
        end_time = time.monotonic()

        logger.info(f"Loaded checkpoint in {end_time - begin_time:.2f} seconds.")

        # bugfix from above: restore the original stateful objects, whose states were already updated in-place by dcp.load()
        states.update(original_stateful_states)
        return True
    def _should_checkpoint(self, step: int, force: bool) -> bool:
        if not self.enable:
            return False
        if not force:
            if step % self.checkpointing_steps != 0:
                return False
        return True

    def _get_checkpoint_dir(self, step: int) -> pathlib.Path:
        return self.output_dir / f"{self._prefix}_{step}"

    def _find_latest_checkpoint_dir(self) -> Union[pathlib.Path, None]:
        checkpoints = sorted(self.output_dir.glob(f"{self._prefix}_*"), key=lambda x: int(x.name.split("_")[-1]))
        return checkpoints[-1] if len(checkpoints) > 0 else None

    def _purge_stale_checkpoints(self) -> None:
        if self.checkpointing_limit is None or self.checkpointing_limit <= 0:
            return
        checkpoints = sorted(
            self.output_dir.glob(f"{self._prefix}_*"), key=lambda x: int(x.name.split("_")[-1]), reverse=True
        )
        for checkpoint in checkpoints[self.checkpointing_limit :]:
            logger.info(f"Deleting stale checkpoint: {checkpoint}")
            shutil.rmtree(checkpoint, ignore_errors=True)
def gather_state_dict_on_cpu_rank0(
    model: torch.nn.Module, device: Optional[torch.device] = None, *, is_main_process: bool
) -> Dict[str, Any]:
    cpu_state_dict = {}
    sharded_sd = model.state_dict()
    for param_name, param in sharded_sd.items():
        if param.is_cpu:
            # Move back to device if offloaded to CPU
            param = param.to(device)
        if hasattr(param, "_local_tensor"):
            # Gather DTensor shards from all ranks into a full tensor
            param = param.full_tensor()
        if is_main_process:
            cpu_state_dict[param_name] = param.cpu()
        torch.distributed.barrier()
    return cpu_state_dict
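

# Illustrative sketch (not part of the library): a hypothetical `_callback_fn` that receives each
# gathered CPU state dict from `save()` and writes an inference-ready copy next to the DCP
# checkpoint. The callback name and output path are assumptions for the example.
#
#   def save_consolidated_weights(state_dict: Dict[str, Any]) -> Dict[str, Any]:
#       torch.save(state_dict, "consolidated_model.pt")
#       return state_dict
#
#   manager = PTDCheckpointManager(..., _callback_fn=save_consolidated_weights)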
# # Copied from pytorch (torch/distributed/checkpoint/format_utils.py) to support callbacks to modify state_dict
# def dcp_to_torch_save(
#     dcp_checkpoint_dir: Union[str, os.PathLike],
#     torch_save_path: Union[str, os.PathLike],
#     callback_fn: Callable[[Dict[str, Any]], Dict[str, Any]] = None,
# ):
#     """
#     Given a directory containing a DCP checkpoint, this function will convert it into a
#     Torch save file.
#
#     Args:
#         dcp_checkpoint_dir: Directory containing the DCP checkpoint.
#         torch_save_path: Filename to store the converted Torch save file.
#         callback_fn: Optional callback function that takes the state_dict as input and returns a modified state_dict.
#
#     .. warning::
#         To avoid OOM, it's recommended to only run this function on a single rank.
#     """
#     state_dict = {}
#     _load_state_dict(
#         state_dict,
#         storage_reader=FileSystemReader(dcp_checkpoint_dir),
#         planner=_EmptyStateDictLoadPlanner(),
#         no_dist=True,
#     )
#     if callback_fn is not None:
#         state_dict = callback_fn(state_dict)
#     torch.save(state_dict, torch_save_path)
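

# Illustrative usage sketch (assumptions, not part of the library API): `transformer`, `optimizers`,
# `schedulers`, `dataloader`, `train_state`, `num_train_steps`, `device`, and `rank` are stand-ins
# for objects created elsewhere in the training setup; only the manager calls mirror this module.
#
#   manager = PTDCheckpointManager(
#       dataloader=dataloader,
#       model_parts=[transformer],
#       optimizers=optimizers,
#       schedulers=schedulers,
#       states=train_state,
#       checkpointing_steps=500,
#       checkpointing_limit=3,
#       output_dir="outputs/checkpoints",
#   )
#   manager.load(step=-1)  # resume from the latest checkpoint, if any
#   for step in range(1, num_train_steps + 1):
#       ...  # forward/backward/optimizer step
#       manager.save(step, force=False, _device=device, _is_main_process=(rank == 0))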