from typing import Optional, Tuple

import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.tensor
from diffusers.utils import is_accelerate_available
from torch.distributed._composable.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard
from torch.distributed._composable.replicate import replicate

from ..utils._common import DIFFUSERS_TRANSFORMER_BLOCK_NAMES


if is_accelerate_available():
    from accelerate import Accelerator
    from accelerate.utils import (
        DataLoaderConfiguration,
        DistributedDataParallelKwargs,
        InitProcessGroupKwargs,
        ProjectConfiguration,
    )


def apply_fsdp2_ptd(
    model: torch.nn.Module,
    dp_mesh: torch.distributed.device_mesh.DeviceMesh,
    param_dtype: torch.dtype,
    reduce_dtype: torch.dtype,
    output_dtype: torch.dtype,
    pp_enabled: bool = False,
    cpu_offload: bool = False,
) -> None:
    r"""Apply FSDP2 to a model: each transformer block is fully sharded individually, then the model root is sharded."""
    mp_policy = MixedPrecisionPolicy(param_dtype, reduce_dtype, output_dtype, cast_forward_inputs=True)
    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}

    if cpu_offload:
        fsdp_config["offload_policy"] = CPUOffloadPolicy(pin_memory=True)

    def apply_fully_shard(blocks):
        for layer_index, block in enumerate(blocks):
            if pp_enabled:
                # For PP, do not reshard after forward to avoid per-microbatch
                # all-gathers, which can be expensive and non-overlapped
                reshard_after_forward = False
            else:
                # As an optimization, do not reshard after forward for the last
                # transformer block since FSDP would prefetch it immediately
                reshard_after_forward = layer_index < len(blocks) - 1
            fully_shard(block, **fsdp_config, reshard_after_forward=reshard_after_forward)

    for transformer_block_name in DIFFUSERS_TRANSFORMER_BLOCK_NAMES:
        blocks = getattr(model, transformer_block_name, None)
        if blocks is not None:
            apply_fully_shard(blocks)

    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)


def apply_ddp_accelerate(
    model: torch.nn.Module,
    project_config: Optional[ProjectConfiguration] = None,
    ddp_kwargs: Optional[DistributedDataParallelKwargs] = None,
    init_process_group_kwargs: Optional[InitProcessGroupKwargs] = None,
    dataloader_config: Optional[DataLoaderConfiguration] = None,
    gradient_accumulation_steps: Optional[int] = None,
    accelerator: Optional[Accelerator] = None,
) -> Tuple[Accelerator, torch.nn.Module]:
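    r"""Prepare a model for DDP training via Accelerate.

    Creates an ``Accelerator`` (unless one is provided) and returns the ``(accelerator, prepared_model)`` pair.
    """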
    if accelerator is None:
        accelerator = Accelerator(
            project_config=project_config,
            dataloader_config=dataloader_config,
            gradient_accumulation_steps=gradient_accumulation_steps,
            log_with=None,
            # `Accelerator` only accepts valid kwargs handlers, so drop any that were left as `None`
            kwargs_handlers=[handler for handler in (ddp_kwargs, init_process_group_kwargs) if handler is not None],
        )
    if torch.backends.mps.is_available():
        accelerator.native_amp = False
    # `prepare_model` returns the (possibly DDP-wrapped) module, so keep the returned reference
    model = accelerator.prepare_model(model)
    return accelerator, model


def apply_ddp_ptd(model: torch.nn.Module, dp_mesh: torch.distributed.device_mesh.DeviceMesh) -> None:
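    r"""Replicate a model (DDP) across the ranks of ``dp_mesh`` using the composable ``replicate`` API."""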
    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)


def dist_reduce(x: torch.Tensor, reduceOp: str, mesh: torch.distributed.device_mesh.DeviceMesh) -> float:
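    r"""All-reduce a single-element tensor across ``mesh`` with the given ``reduceOp`` and return it as a Python float."""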
    if isinstance(x, torch.distributed.tensor.DTensor):
        # functional collectives do not support DTensor inputs
        x = x.full_tensor()
    assert x.numel() == 1  # required by `.item()`
    return funcol.all_reduce(x, reduceOp=reduceOp, group=mesh).item()


def dist_max(x: torch.Tensor, mesh: torch.distributed.device_mesh.DeviceMesh) -> float:
    return dist_reduce(x, reduceOp=torch.distributed.distributed_c10d.ReduceOp.MAX.name, mesh=mesh)


def dist_mean(x: torch.Tensor, mesh: torch.distributed.device_mesh.DeviceMesh) -> float:
    return dist_reduce(x, reduceOp=torch.distributed.distributed_c10d.ReduceOp.AVG.name, mesh=mesh)
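

# Example usage (illustrative sketch, not part of this module's API). Assuming the default process
# group is already initialized and `model`, `world_size`, and a scalar `loss` exist on every rank:
#
#     from torch.distributed.device_mesh import init_device_mesh
#
#     mesh = init_device_mesh("cuda", (world_size,))
#     apply_fsdp2_ptd(model, mesh, torch.bfloat16, torch.float32, torch.bfloat16)
#     ...
#     avg_loss = dist_mean(loss.detach(), mesh)  # average the scalar loss across all ranks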