Fully Sharded Data Parallel utilities
enable_fsdp_ram_efficient_loading
Enables RAM efficient loading of Hugging Face models for FSDP in the environment.
disable_fsdp_ram_efficient_loading
Disables RAM efficient loading of Hugging Face models for FSDP in the environment.
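Both helpers take no arguments and toggle the corresponding flag in the process environment. A minimal usage sketch; the model-loading step in the middle is illustrative, not prescribed by this API:

```python
from accelerate.utils import (
    enable_fsdp_ram_efficient_loading,
    disable_fsdp_ram_efficient_loading,
)

enable_fsdp_ram_efficient_loading()   # turn RAM-efficient loading on for FSDP
# ... load a Hugging Face model here, e.g. AutoModel.from_pretrained(...)
disable_fsdp_ram_efficient_loading()  # restore the default behavior
```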
merge_fsdp_weights
accelerate.utils.merge_fsdp_weights
< source >( checkpoint_dir: str output_path: str safe_serialization: bool = True remove_checkpoint_dir: bool = False )
Parameters

- checkpoint_dir (`str`) — The directory containing the FSDP checkpoints (can be either the model or optimizer).
- output_path (`str`) — The path to save the merged checkpoint.
- safe_serialization (`bool`, optional, defaults to `True`) — Whether to save the merged weights with safetensors (recommended).
- remove_checkpoint_dir (`bool`, optional, defaults to `False`) — Whether to remove the checkpoint directory after merging.
Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if `SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}/model.safetensors` if `safe_serialization`, else to `pytorch_model.bin`.
Note: this is a CPU-bound process.
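A minimal usage sketch; the checkpoint and output paths below are hypothetical:

```python
from accelerate.utils import merge_fsdp_weights

merge_fsdp_weights(
    checkpoint_dir="outputs/pytorch_model_fsdp_0",  # hypothetical sharded checkpoint directory
    output_path="outputs/merged",                   # hypothetical output location
    safe_serialization=True,                        # writes {output_path}/model.safetensors
    remove_checkpoint_dir=False,                    # keep the sharded files around
)
```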
FullyShardedDataParallelPlugin
class accelerate.FullyShardedDataParallelPlugin
< source >( fsdp_version: int = None sharding_strategy: typing.Union[str, ForwardRef('torch.distributed.fsdp.ShardingStrategy')] = None reshard_after_forward: typing.Union[str, ForwardRef('torch.distributed.fsdp.ShardingStrategy'), bool] = None backward_prefetch: typing.Union[str, ForwardRef('torch.distributed.fsdp.BackwardPrefetch'), NoneType] = None mixed_precision_policy: typing.Union[dict, ForwardRef('torch.distributed.fsdp.MixedPrecision'), ForwardRef('torch.distributed.fsdp.MixedPrecisionPolicy'), NoneType] = None auto_wrap_policy: typing.Union[typing.Callable, typing.Literal['transformer_based_wrap', 'size_based_wrap', 'no_wrap'], NoneType] = None cpu_offload: typing.Union[bool, ForwardRef('torch.distributed.fsdp.CPUOffload'), ForwardRef('torch.distributed.fsdp.CPUOffloadPolicy')] = None ignored_modules: typing.Optional[collections.abc.Iterable[torch.nn.modules.module.Module]] = None state_dict_type: typing.Union[str, ForwardRef('torch.distributed.fsdp.StateDictType')] = None state_dict_config: typing.Union[ForwardRef('torch.distributed.fsdp.FullStateDictConfig'), ForwardRef('torch.distributed.fsdp.ShardedStateDictConfig'), NoneType] = None optim_state_dict_config: typing.Union[ForwardRef('torch.distributed.fsdp.FullOptimStateDictConfig'), ForwardRef('torch.distributed.fsdp.ShardedOptimStateDictConfig'), NoneType] = None limit_all_gathers: bool = True use_orig_params: typing.Optional[bool] = None param_init_fn: typing.Optional[typing.Callable[[torch.nn.modules.module.Module], NoneType]] = None sync_module_states: typing.Optional[bool] = None forward_prefetch: bool = None activation_checkpointing: bool = None cpu_ram_efficient_loading: bool = None transformer_cls_names_to_wrap: typing.Optional[list[str]] = None min_num_params: typing.Optional[int] = None )
Parameters

- fsdp_version (`int`, defaults to `1`) — The version of FSDP to use. If set to 2, the launcher expects the config to be converted to FSDP2 format.
- sharding_strategy (`Union[str, torch.distributed.fsdp.ShardingStrategy]`, defaults to `'FULL_SHARD'`) — Sharding strategy to use. Should be either a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`. Deprecated in favor of `reshard_after_forward`.
- reshard_after_forward (`Union[str, torch.distributed.fsdp.ShardingStrategy, bool]`, defaults to `'FULL_SHARD'` for `fsdp_version=1` and `True` for `fsdp_version=2`) — Sharding strategy to use. Should be a `bool` if `fsdp_version` is set to 2, else a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`.
- backward_prefetch (`Union[str, torch.distributed.fsdp.BackwardPrefetch]`, defaults to `'NO_PREFETCH'`) — Backward prefetch strategy to use. Should be either a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch`.
- mixed_precision_policy (`Optional[Union[dict, torch.distributed.fsdp.MixedPrecision, torch.distributed.fsdp.MixedPrecisionPolicy]]`, defaults to `None`) — A config to enable mixed precision training with FullyShardedDataParallel. If passing in a `dict`, it should have the keys `param_dtype`, `reduce_dtype`, and `buffer_dtype`. Can be an instance of `torch.distributed.fsdp.MixedPrecisionPolicy` if `fsdp_version` is set to 2.
- auto_wrap_policy (`Optional[Union[Callable, Literal["transformer_based_wrap", "size_based_wrap", "no_wrap"]]]`, defaults to `NO_WRAP`) — A callable or string specifying a policy to recursively wrap layers with FSDP. If a string, it must be one of `transformer_based_wrap`, `size_based_wrap`, or `no_wrap`. See `torch.distributed.fsdp.wrap.size_based_auto_wrap_policy` for a direction on what it should look like.
- cpu_offload (`Union[bool, torch.distributed.fsdp.CPUOffload, torch.distributed.fsdp.CPUOffloadPolicy]`, defaults to `False`) — Whether to offload parameters to CPU. Should be either a `bool`, an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload`, or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffloadPolicy` if `fsdp_version` is set to 2.
- ignored_modules (`Optional[Iterable[torch.nn.Module]]`, defaults to `None`) — A list of modules to ignore when wrapping with FSDP.
- state_dict_type (`Union[str, torch.distributed.fsdp.StateDictType]`, defaults to `'FULL_STATE_DICT'`) — State dict type to use. If a string, it must be one of `full_state_dict`, `local_state_dict`, or `sharded_state_dict`.
- state_dict_config (`Optional[Union[torch.distributed.fsdp.FullStateDictConfig, torch.distributed.fsdp.ShardedStateDictConfig]]`, defaults to `None`) — State dict config to use. Determined based on the `state_dict_type` if not passed in.
- optim_state_dict_config (`Optional[Union[torch.distributed.fsdp.FullOptimStateDictConfig, torch.distributed.fsdp.ShardedOptimStateDictConfig]]`, defaults to `None`) — Optim state dict config to use. Determined based on the `state_dict_type` if not passed in.
- limit_all_gathers (`bool`, defaults to `True`) — Whether to have FSDP explicitly synchronize the CPU thread to prevent too many in-flight all-gathers. This bool only affects the sharded strategies that schedule all-gathers. Enabling this can help lower the number of CUDA malloc retries.
- use_orig_params (`bool`, defaults to `False`) — Whether to use the original parameters for the optimizer.
- param_init_fn (`Optional[Callable[[torch.nn.Module], None]]`, defaults to `None`) — A `Callable[torch.nn.Module] -> None` that specifies how modules that are currently on the meta device should be initialized onto an actual device. Only applicable when `sync_module_states` is `True`. By default it is a `lambda` which calls `to_empty` on the module.
- sync_module_states (`bool`, defaults to `False`) — Whether each individually wrapped FSDP unit should broadcast module parameters from rank 0 to ensure they are the same across all ranks after initialization. Defaults to `False` unless `cpu_ram_efficient_loading` is `True`, in which case it is forcibly enabled.
- forward_prefetch (`bool`, defaults to `False`) — Whether to have FSDP explicitly prefetch the next upcoming all-gather while executing in the forward pass. Only use with static graphs.
- activation_checkpointing (`bool`, defaults to `False`) — A technique to reduce memory usage by clearing activations of certain layers and recomputing them during a backward pass. Effectively, this trades extra computation time for reduced memory usage.
- cpu_ram_efficient_loading (`bool`, defaults to `None`) — If `True`, only the first process loads the pretrained model checkpoint while all other processes have empty weights. Only applicable for Transformers. When using this, `sync_module_states` needs to be `True`.
- transformer_cls_names_to_wrap (`Optional[List[str]]`, defaults to `None`) — A list of transformer layer class names to wrap. Only applicable when `auto_wrap_policy` is `transformer_based_wrap`.
- min_num_params (`Optional[int]`, defaults to `None`) — The minimum number of parameters a module must have to be wrapped. Only applicable when `auto_wrap_policy` is `size_based_wrap`.
This plugin is used to enable fully sharded data parallelism.

Given `model`, creates an `auto_wrap_policy` based on the passed-in policy and, where applicable, the `transformer_cls_to_wrap`.

Sets the mixed precision policy for FSDP.

Sets the state dict config based on the `StateDictType`.

Validates the mixed precision policy; this is abstracted away so the imports are only pulled in when needed.
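A minimal construction sketch, assuming the usual pattern of handing the plugin to `Accelerator`; the argument values (including the transformer layer class name) are illustrative only:

```python
from accelerate import Accelerator, FullyShardedDataParallelPlugin

fsdp_plugin = FullyShardedDataParallelPlugin(
    reshard_after_forward="FULL_SHARD",            # FSDP1-style sharding strategy
    auto_wrap_policy="transformer_based_wrap",
    transformer_cls_names_to_wrap=["GPT2Block"],   # illustrative transformer layer class
    state_dict_type="SHARDED_STATE_DICT",
    cpu_ram_efficient_loading=True,                # forces sync_module_states to True
)

accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
# then call accelerator.prepare(...) on the model, optimizer, and dataloaders as usual
```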
fsdp2_load_full_state_dict
accelerate.utils.fsdp2_load_full_state_dict
< source >( accelerator model: Module full_sd: dict )
Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the parameters from rank 0 to all other ranks. This function modifies the model in-place.
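A minimal sketch, assuming the full state dict only needs to be materialized on rank 0 (other ranks pass an empty dict) and that the model has already been sharded, e.g. by `Accelerator.prepare` under an FSDP2 config; the checkpoint path is hypothetical:

```python
import torch
import torch.nn as nn
from accelerate import Accelerator
from accelerate.utils import fsdp2_load_full_state_dict

accelerator = Accelerator()                    # assumed launched with an FSDP2 config
model = accelerator.prepare(nn.Linear(8, 8))   # toy model, sharded by prepare()

full_sd = {}
if accelerator.is_main_process:
    full_sd = torch.load("full_checkpoint.pt", map_location="cpu")  # hypothetical file

# Broadcasts parameters from rank 0 and loads them into the sharded model in place.
fsdp2_load_full_state_dict(accelerator, model, full_sd)
```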
fsdp2_switch_optimizer_parameters
accelerate.utils.fsdp2_switch_optimizer_parameters
< source >( optimizer: Optimizer mapping: dict )
Parameters

- optimizer (`torch.optim.Optimizer`) — Optimizer instance which contains the original model parameters.
- mapping (`dict`) — Mapping from the original parameter (specified by `data_ptr`) to the sharded parameter.

Raises

- `KeyError` — If a parameter in the optimizer couldn't be switched to its sharded version. This should never happen and indicates a bug. If we kept the original params instead of raising, the training wouldn't be numerically correct and weights wouldn't get updated.
Switches the parameters of the optimizer to new ones (sharded parameters in usual case). This function modifies the optimizer in-place.
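A minimal sketch of the mapping this helper expects, assuming keys are the original parameters' `data_ptr()` values and values are the replacement (sharded) parameters; the "sharded" tensors here are plain clones used only for illustration:

```python
import torch
import torch.nn as nn
from accelerate.utils import fsdp2_switch_optimizer_parameters

model = nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Map each original parameter's data_ptr to its replacement parameter.
mapping = {p.data_ptr(): nn.Parameter(p.detach().clone()) for p in model.parameters()}

# The optimizer's param groups now reference the replacement parameters.
fsdp2_switch_optimizer_parameters(optimizer, mapping)
```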
fsdp2_prepare_model
accelerate.utils.fsdp2_prepare_model
< source >( accelerator model: Module ) → torch.nn.Module
Prepares the model for FSDP2 in-place. Also returns the model to avoid misuse of the original model.
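A minimal sketch, assuming an `Accelerator` launched with an FSDP2 config; keep using the returned module rather than the original reference:

```python
import torch.nn as nn
from accelerate import Accelerator
from accelerate.utils import fsdp2_prepare_model

accelerator = Accelerator()   # assumed launched with an FSDP2 config
model = nn.Linear(8, 8)       # toy model for illustration

# Shards/wraps the model for FSDP2 in place and returns it.
model = fsdp2_prepare_model(accelerator, model)
```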