# Llama-3.1-8B-DALv0.1/venv/lib/python3.12/site-packages/torch/distributed/checkpoint/_checkpointer.py
from concurrent.futures import Future
from typing import Any, Dict, List, Optional

import torch.distributed as dist
import torch.distributed.checkpoint.state_dict_loader as loader
import torch.distributed.checkpoint.state_dict_saver as saver
from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE
from torch.distributed.checkpoint.storage import (
    LoadPlanner,
    SavePlanner,
    StorageReader,
    StorageWriter,
)

__all__: List[str] = []
class _Checkpointer:
    """This base class specifies a high level API for saving and loading
    distributed ``state_dict`` 's. It provides an abstraction over the low-level APIs
    provided by :py:mod:`torch.distributed.checkpoint.storage`, essentially calling
    :py:meth:`torch.distributed.state_dict_saver.save` and
    :py:meth:`torch.distributed.state_dict_loader.load` with the provided storage
    readers and writers.

    .. warning::
        This feature is experimental and subject to removal/change.
    """

    def __init__(
        self,
        storage_writer: StorageWriter,
        storage_reader: StorageReader,
        *,
        process_group: Optional[dist.ProcessGroup] = None,
        coordinator_rank: int = 0,
        no_dist: bool = False,
        load_planner: Optional[LoadPlanner] = None,
        save_planner: Optional[SavePlanner] = None,
    ):
        """Initializes the Checkpointer instance.

        Args:
            storage_writer: Instance of StorageWriter used to perform writes.
            storage_reader: StorageReader used to load data from.
            process_group: ProcessGroup to be used for cross-rank synchronization.
            coordinator_rank: Rank to use to coordinate the checkpoint. rank0 is used by default.
            no_dist: If ``True``, distributed checkpoint will not load in SPMD style. (Default: ``False``)
            load_planner: Instance of LoadPlanner to use when loading.
            save_planner: Instance of SavePlanner to use when saving.
        """
        # All values are stored as-is and forwarded verbatim to the
        # saver/loader entry points on each save()/load() call.
        self.storage_writer = storage_writer
        self.storage_reader = storage_reader
        self.process_group = process_group
        self.coordinator_rank = coordinator_rank
        self.no_dist = no_dist
        self.load_planner = load_planner
        self.save_planner = save_planner

    def save(
        self,
        state_dict: STATE_DICT_TYPE,
    ) -> Metadata:
        """Calls :py:meth:`torch.distributed.state_dict_saver.save`. Utilizing values passed during initialization.

        Args:
            state_dict: The distributed ``state_dict`` to save.

        Returns:
            Metadata object describing the saved checkpoint.
        """
        return saver.save(
            state_dict,
            self.storage_writer,
            process_group=self.process_group,
            coordinator_rank=self.coordinator_rank,
            no_dist=self.no_dist,
            planner=self.save_planner,
        )

    def async_save(
        self,
        state_dict: STATE_DICT_TYPE,
    ) -> Future:
        """
        Calls :py:meth:`torch.distributed.state_dict_saver.async_save`. Utilizing values passed during initialization.

        Note: unlike :py:meth:`save`, ``coordinator_rank`` and ``no_dist`` are
        not forwarded here — ``async_save`` does not accept them.

        Args:
            state_dict: The distributed ``state_dict`` to save.

        Returns:
            Future: A future holding the resultant Metadata object from `save`.
        """
        return saver.async_save(
            state_dict,
            storage_writer=self.storage_writer,
            process_group=self.process_group,
            planner=self.save_planner,
        )

    def load(self, state_dict: Dict[str, Any]) -> None:
        """Calls :py:meth:`torch.distributed.state_dict_loader.load`. Utilizing values passed during initialization.

        Args:
            state_dict: Pre-allocated ``state_dict`` to load values into (mutated in place).
        """
        loader.load(
            state_dict,
            storage_reader=self.storage_reader,
            process_group=self.process_group,
            planner=self.load_planner,
        )