from typing import Optional, Tuple

import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.tensor
from diffusers.utils import is_accelerate_available
from torch.distributed._composable.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard
from torch.distributed._composable.replicate import replicate

from ..utils._common import DIFFUSERS_TRANSFORMER_BLOCK_NAMES

if is_accelerate_available():
    from accelerate import Accelerator
    from accelerate.utils import (
        DataLoaderConfiguration,
        DistributedDataParallelKwargs,
        InitProcessGroupKwargs,
        ProjectConfiguration,
    )

def apply_fsdp2_ptd(
    model: torch.nn.Module,
    dp_mesh: torch.distributed.device_mesh.DeviceMesh,
    param_dtype: torch.dtype,
    reduce_dtype: torch.dtype,
    output_dtype: torch.dtype,
    pp_enabled: bool = False,
    cpu_offload: bool = False,
) -> None:
    r"""Apply FSDP2 on a model."""
    mp_policy = MixedPrecisionPolicy(param_dtype, reduce_dtype, output_dtype, cast_forward_inputs=True)
    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}

    if cpu_offload:
        fsdp_config["offload_policy"] = CPUOffloadPolicy(pin_memory=True)

    def apply_fully_shard(blocks):
        for layer_index, block in enumerate(blocks):
            if pp_enabled:
                # For PP, do not reshard after forward to avoid per-microbatch
                # all-gathers, which can be expensive and non-overlapped
                reshard_after_forward = False
            else:
                # As an optimization, do not reshard after forward for the last
                # transformer block since FSDP would prefetch it immediately
                reshard_after_forward = layer_index < len(blocks) - 1
            fully_shard(block, **fsdp_config, reshard_after_forward=reshard_after_forward)

    for transformer_block_name in DIFFUSERS_TRANSFORMER_BLOCK_NAMES:
        blocks = getattr(model, transformer_block_name, None)
        if blocks is not None:
            apply_fully_shard(blocks)

    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
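
# Usage sketch (illustrative only; the mesh setup below is an assumption about
# the caller's environment, not something `apply_fsdp2_ptd` does itself): shard
# a transformer across a 1-D data-parallel mesh with bf16 compute and fp32
# gradient reduction.
#
#   from torch.distributed.device_mesh import init_device_mesh
#
#   dp_mesh = init_device_mesh("cuda", (torch.distributed.get_world_size(),), mesh_dim_names=("dp",))
#   apply_fsdp2_ptd(
#       model=transformer,  # hypothetical diffusers transformer instance
#       dp_mesh=dp_mesh,
#       param_dtype=torch.bfloat16,
#       reduce_dtype=torch.float32,
#       output_dtype=torch.bfloat16,
#   )
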
def apply_ddp_accelerate(
    model: torch.nn.Module,
    project_config: Optional[ProjectConfiguration] = None,
    ddp_kwargs: Optional[DistributedDataParallelKwargs] = None,
    init_process_group_kwargs: Optional[InitProcessGroupKwargs] = None,
    dataloader_config: Optional[DataLoaderConfiguration] = None,
    gradient_accumulation_steps: Optional[int] = None,
    accelerator: Optional[Accelerator] = None,
) -> Tuple[Accelerator, torch.nn.Module]:
    r"""Apply DDP on a model with Accelerate, returning the accelerator and the prepared model."""
    if accelerator is None:
        accelerator = Accelerator(
            project_config=project_config,
            dataloader_config=dataloader_config,
            gradient_accumulation_steps=gradient_accumulation_steps,
            log_with=None,
            # `Accelerator` rejects `None` entries in `kwargs_handlers`, so drop unset handlers
            kwargs_handlers=[handler for handler in (ddp_kwargs, init_process_group_kwargs) if handler is not None],
        )
    if torch.backends.mps.is_available():
        # Accelerate's native AMP is not supported on MPS devices
        accelerator.native_amp = False
    model = accelerator.prepare_model(model)
    return accelerator, model
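
# Usage sketch (illustrative; argument values are placeholders): wrap a model
# in DDP through Accelerate and reuse the returned accelerator for data loading.
#
#   accelerator, model = apply_ddp_accelerate(
#       model=transformer,  # hypothetical model instance
#       ddp_kwargs=DistributedDataParallelKwargs(find_unused_parameters=False),
#       gradient_accumulation_steps=1,
#   )
#   dataloader = accelerator.prepare_data_loader(dataloader)
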
def apply_ddp_ptd(model: torch.nn.Module, dp_mesh: torch.distributed.device_mesh.DeviceMesh) -> None:
    r"""Apply DDP on a model via the composable ``replicate`` API."""
    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
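
# Usage sketch (illustrative; mesh creation mirrors the `apply_fsdp2_ptd`
# example above and is an assumption about the caller):
#
#   dp_mesh = init_device_mesh("cuda", (torch.distributed.get_world_size(),), mesh_dim_names=("dp",))
#   apply_ddp_ptd(model, dp_mesh)
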
def dist_reduce(x: torch.Tensor, reduceOp: str, mesh: torch.distributed.device_mesh.DeviceMesh) -> float:
    if isinstance(x, torch.distributed.tensor.DTensor):
        # functional collectives do not support DTensor inputs
        x = x.full_tensor()
    assert x.numel() == 1  # required by `.item()`
    return funcol.all_reduce(x, reduceOp=reduceOp, group=mesh).item()


def dist_max(x: torch.Tensor, mesh: torch.distributed.device_mesh.DeviceMesh) -> float:
    return dist_reduce(x, reduceOp=torch.distributed.distributed_c10d.ReduceOp.MAX.name, mesh=mesh)


def dist_mean(x: torch.Tensor, mesh: torch.distributed.device_mesh.DeviceMesh) -> float:
    return dist_reduce(x, reduceOp=torch.distributed.distributed_c10d.ReduceOp.AVG.name, mesh=mesh)
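
# Usage sketch (illustrative): aggregate per-rank scalars across the
# data-parallel mesh for logging, e.g. after computing a loss.
#
#   global_avg_loss = dist_mean(loss.detach(), dp_mesh)
#   global_max_loss = dist_max(loss.detach(), dp_mesh)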