import math
import os
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.backends
import torch.distributed as dist
import torch.distributed.tensor
from accelerate import Accelerator
from diffusers.utils.torch_utils import is_compiled_module

from ..logging import get_logger

logger = get_logger()

_STRING_TO_DTYPE = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}

_DTYPE_TO_STRING = {v: k for k, v in _STRING_TO_DTYPE.items()}

_HAS_ERRORED_CLIP_GRAD_NORM_WHILE_HANDLING_FAILING_DTENSOR_CASES = False


def align_device_and_dtype(
    x: Union[torch.Tensor, Dict[str, torch.Tensor]],
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    if isinstance(x, torch.Tensor):
        if device is not None:
            x = x.to(device)
        if dtype is not None:
            x = x.to(dtype)
    elif isinstance(x, dict):
        # Recurse once; the nested call already applies both device and dtype to each value.
        x = {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()}
    return x


def _clip_grad_norm_while_handling_failing_dtensor_cases(
    parameters: Union[torch.Tensor, List[torch.Tensor]],
    max_norm: float,
    norm_type: float = 2.0,
    error_if_nonfinite: bool = False,
    foreach: Optional[bool] = None,
    pp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
) -> Optional[torch.Tensor]:
    global _HAS_ERRORED_CLIP_GRAD_NORM_WHILE_HANDLING_FAILING_DTENSOR_CASES

    if not _HAS_ERRORED_CLIP_GRAD_NORM_WHILE_HANDLING_FAILING_DTENSOR_CASES:
        try:
            return clip_grad_norm_(parameters, max_norm, norm_type, error_if_nonfinite, foreach, pp_mesh)
        except NotImplementedError as e:
            if "DTensor does not support cross-mesh operation" in str(e):
                # https://github.com/pytorch/pytorch/issues/134212
                logger.warning(
                    "DTensor does not support cross-mesh operation. If the model is not fully tensor-parallelized "
                    "while other parallelisms such as FSDP are in use, that may be the cause of this error. "
                    "Gradient clipping will be skipped and the gradient norm will not be logged."
                )
        except Exception as e:
            logger.warning(
                f"An error occurred while clipping gradients: {e}. Gradient clipping will be skipped and the "
                "gradient norm will not be logged."
            )
        # Only warn once; subsequent steps skip clipping silently.
        _HAS_ERRORED_CLIP_GRAD_NORM_WHILE_HANDLING_FAILING_DTENSOR_CASES = True
    return None
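
# Illustrative usage sketch (an assumption about the surrounding training loop, not code from this module): the
# wrapper above is meant as a drop-in replacement for `clip_grad_norm_` inside a training step, so that the
# cross-mesh DTensor failure degrades to "no clipping" instead of crashing the run. `transformer`, `optimizer`,
# and `parallel_state` are hypothetical names.
#
#     grad_norm = _clip_grad_norm_while_handling_failing_dtensor_cases(
#         [p for p in transformer.parameters() if p.requires_grad],
#         max_norm=1.0,
#         norm_type=2.0,
#         pp_mesh=parallel_state.pp_mesh,  # None when pipeline parallelism is not used
#     )
#     if grad_norm is not None:
#         logger.debug(f"Gradient norm before clipping: {grad_norm.item()}")
#     optimizer.step()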


# Copied from https://github.com/pytorch/torchtitan/blob/4a169701555ab9bd6ca3769f9650ae3386b84c6e/torchtitan/utils.py#L362
def clip_grad_norm_(
    parameters: Union[torch.Tensor, List[torch.Tensor]],
    max_norm: float,
    norm_type: float = 2.0,
    error_if_nonfinite: bool = False,
    foreach: Optional[bool] = None,
    pp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
) -> torch.Tensor:
    r"""
    Clip the gradient norm of parameters.

    Gradient norm clipping requires computing the gradient norm over the entire model.
    `torch.nn.utils.clip_grad_norm_` only computes the gradient norm along the DP/FSDP/TP dimensions.
    We need to manually reduce the gradient norm across PP stages.
    See https://github.com/pytorch/torchtitan/issues/596 for details. An illustrative usage sketch follows this
    function definition.

    Args:
        parameters (`torch.Tensor` or `List[torch.Tensor]`):
            Tensors whose gradients will be normalized.
        max_norm (`float`):
            Maximum allowed norm of the gradients.
        norm_type (`float`, defaults to `2.0`):
            Type of p-norm to use. Can be `inf` for the infinity norm.
        error_if_nonfinite (`bool`, defaults to `False`):
            If `True`, an error is raised if the total norm of the gradients from `parameters` is `nan`, `inf`, or
            `-inf`.
        foreach (`bool`, defaults to `None`):
            Use the faster foreach-based implementation. If `None`, use the foreach implementation for CUDA and CPU
            native tensors and silently fall back to the slow implementation for other device types.
        pp_mesh (`torch.distributed.device_mesh.DeviceMesh`, defaults to `None`):
            Pipeline parallel device mesh. If not `None`, the gradient norm is reduced across PP stages.

    Returns:
        `torch.Tensor`:
            Total norm of the gradients.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad for p in parameters if p.grad is not None]

    # TODO(aryan): Wait for the next PyTorch release to use `torch.nn.utils.get_total_norm`
    # total_norm = torch.nn.utils.get_total_norm(grads, norm_type, error_if_nonfinite, foreach)
    total_norm = _get_total_norm(grads, norm_type, error_if_nonfinite, foreach)

    # If total_norm is a DTensor, the placements must be `torch.distributed._tensor.ops.math_ops._NormPartial`.
    # We can simply reduce the DTensor to get the total norm in this tensor's process group
    # and then convert it to a local tensor.
    # This serves two purposes:
    #   1. to make sure the total norm is computed correctly when PP is used (see below)
    #   2. to return a reduced total_norm tensor whose .item() would return the correct value
    if isinstance(total_norm, torch.distributed.tensor.DTensor):
        # Will reach here if any non-PP parallelism is used.
        # If only using PP, total_norm will be a local tensor.
        total_norm = total_norm.full_tensor()

    if pp_mesh is not None:
        if math.isinf(norm_type):
            dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_mesh.get_group())
        else:
            total_norm **= norm_type
            dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group())
            total_norm **= 1.0 / norm_type

    _clip_grads_with_norm_(parameters, max_norm, total_norm, foreach)
    return total_norm
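
# Illustrative usage sketch: calling `clip_grad_norm_` when pipeline parallelism is active. The mesh construction
# and the `pp_degree` / `model_parts` names below are assumptions for illustration; the point is that the same
# "pp" mesh used to split the model across stages is passed in so the total norm is all-reduced across PP ranks
# before clipping.
#
#     from torch.distributed.device_mesh import init_device_mesh
#
#     pp_mesh = init_device_mesh("cuda", (pp_degree,), mesh_dim_names=("pp",))
#     params = [p for part in model_parts for p in part.parameters()]
#     total_norm = clip_grad_norm_(params, max_norm=1.0, norm_type=2.0, pp_mesh=pp_mesh)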


def enable_determinism(
    seed: int,
    world_mesh: Optional[torch.distributed.DeviceMesh] = None,
    deterministic: bool = False,
) -> None:
    r"""
    For all ranks within the same DTensor SPMD group, the same seed will be set.
    For PP groups, different seeds will be set. An illustrative usage sketch follows this function.
    """
    if deterministic:
        logger.info("Deterministic algorithms are enabled (expect performance degradation).")
        torch.use_deterministic_algorithms(True)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        # https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

    if not world_mesh:
        if seed is not None:
            torch.manual_seed(seed)
            os.environ["PYTHONHASHSEED"] = str(seed % 2**32)
            logger.debug(f"Single-process job using seed: {seed}")
        return

    # For PP + SPMD cases, we want to separate the world into the SPMD mesh and the PP mesh,
    # and choose a unique seed for each rank on the PP mesh.
    if torch.distributed.distributed_c10d.get_world_size() > 1 and "pp" in world_mesh.mesh_dim_names:
        pp_mesh = world_mesh["pp"]
        seed += pp_mesh.get_local_rank()
        seed %= 2**64

        info = {
            "pp_rank": pp_mesh.get_local_rank(),
            "global_rank": torch.distributed.distributed_c10d.get_rank(),
            "seed": seed,
        }
        logger.debug(f"Enabling determinism: {info}")

        spmd_mesh_dims = list(filter(lambda name: name != "pp", world_mesh.mesh_dim_names))
        spmd_mesh = world_mesh[spmd_mesh_dims] if len(spmd_mesh_dims) else None
    else:
        spmd_mesh = world_mesh
        info = {"global_rank": torch.distributed.distributed_c10d.get_rank(), "seed": seed}
        logger.debug(f"Enabling determinism: {info}")

    # The native RNGs and the Python RNG may not be important, except for the 1-D PP case, but we seed them for
    # consistency.
    torch.manual_seed(seed)
    # PYTHONHASHSEED can be a decimal number in the range [0, 2**32 - 1]
    os.environ["PYTHONHASHSEED"] = str(seed % 2**32)

    # As long as we are not in the 1-D (PP-only) case, we will have a seed to use for all ranks of the SPMD mesh.
    # If PP is also used, this seed is unique per PP rank.
    if spmd_mesh and spmd_mesh.get_coordinate() is not None:
        torch.distributed.tensor._random.manual_seed(seed, spmd_mesh)
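
# Illustrative usage sketch: a typical call site right after distributed initialization. The 2-D (pp, dp) mesh
# below is an assumption for illustration; any world mesh whose dim names include "pp" gets a distinct seed per
# pipeline stage, while all ranks within one SPMD (dp/tp/...) group share the same seed.
#
#     from torch.distributed.device_mesh import init_device_mesh
#
#     world_mesh = init_device_mesh("cuda", (pp_degree, dp_degree), mesh_dim_names=("pp", "dp"))
#     enable_determinism(seed=42, world_mesh=world_mesh, deterministic=False)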


def expand_tensor_dims(tensor: torch.Tensor, ndim: int) -> torch.Tensor:
    # Right-pad the shape with singleton dimensions until the tensor has `ndim` dimensions.
    assert len(tensor.shape) <= ndim
    return tensor.reshape(tensor.shape + (1,) * (ndim - len(tensor.shape)))


def get_device_info():
    from torch._utils import _get_available_device_type, _get_device_module

    device_type = _get_available_device_type()
    if device_type is None:
        device_type = "cuda"
    device_module = _get_device_module(device_type)
    return device_type, device_module


def get_dtype_from_string(dtype: str):
    return _STRING_TO_DTYPE[dtype]


def get_string_from_dtype(dtype: torch.dtype):
    return _DTYPE_TO_STRING[dtype]


def set_requires_grad(models: Union[torch.nn.Module, List[torch.nn.Module]], value: bool) -> None:
    if isinstance(models, torch.nn.Module):
        models = [models]
    for model in models:
        if model is not None:
            model.requires_grad_(value)


def synchronize_device() -> None:
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    elif torch.backends.mps.is_available():
        torch.mps.synchronize()


def unwrap_model(accelerator: Accelerator, model):
    model = accelerator.unwrap_model(model)
    model = model._orig_mod if is_compiled_module(model) else model
    return model


# TODO(aryan): remove everything below this after the next torch release
def _get_total_norm(
    tensors: Union[torch.Tensor, List[torch.Tensor]],
    norm_type: float = 2.0,
    error_if_nonfinite: bool = False,
    foreach: Optional[bool] = None,
) -> torch.Tensor:
    if isinstance(tensors, torch.Tensor):
        tensors = [tensors]
    else:
        tensors = list(tensors)
    norm_type = float(norm_type)
    if len(tensors) == 0:
        return torch.tensor(0.0)

    first_device = tensors[0].device
    grouped_tensors: Dict[
        Tuple[torch.device, torch.dtype], Tuple[List[List[torch.Tensor]], List[int]]
    ] = _group_tensors_by_device_and_dtype(
        [tensors]  # type: ignore[list-item]
    )  # type: ignore[assignment]

    norms: List[torch.Tensor] = []
    for (device, _), ([device_tensors], _) in grouped_tensors.items():
        if (foreach is None and _has_foreach_support(device_tensors, device)) or (
            foreach and _device_has_foreach_support(device)
        ):
            norms.extend(torch._foreach_norm(device_tensors, norm_type))
        elif foreach:
            raise RuntimeError(f"foreach=True was passed, but can't use the foreach API on {device.type} tensors")
        else:
            norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_tensors])

    total_norm = torch.linalg.vector_norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)

    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
        raise RuntimeError(
            f"The total norm of order {norm_type} for gradients from "
            "`parameters` is non-finite, so it cannot be clipped. To disable "
            "this error and scale the gradients by the non-finite norm anyway, "
            "set `error_if_nonfinite=False`"
        )
    return total_norm


def _clip_grads_with_norm_(
    parameters: Union[torch.Tensor, List[torch.Tensor]],
    max_norm: float,
    total_norm: torch.Tensor,
    foreach: Optional[bool] = None,
) -> None:
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad for p in parameters if p.grad is not None]
    max_norm = float(max_norm)
    if len(grads) == 0:
        return

    grouped_grads: Dict[
        Tuple[torch.device, torch.dtype], Tuple[List[List[torch.Tensor]], List[int]]
    ] = _group_tensors_by_device_and_dtype([grads])  # type: ignore[assignment]

    clip_coef = max_norm / (total_norm + 1e-6)
    # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
    # avoids an `if clip_coef < 1:` conditional, which can require a CPU <=> device synchronization
    # when the gradients do not reside in CPU memory. (A worked numeric example follows this function.)
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    for (device, _), ([device_grads], _) in grouped_grads.items():
        if (foreach is None and _has_foreach_support(device_grads, device)) or (
            foreach and _device_has_foreach_support(device)
        ):
            torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
        elif foreach:
            raise RuntimeError(f"foreach=True was passed, but can't use the foreach API on {device.type} tensors")
        else:
            clip_coef_clamped_device = clip_coef_clamped.to(device)
            for g in device_grads:
                g.mul_(clip_coef_clamped_device)
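
# Worked example of the clipping math above (illustrative numbers, not taken from a real run): with max_norm=1.0
# and total_norm=4.0, clip_coef = 1.0 / (4.0 + 1e-6) ~= 0.25, so every gradient is scaled by ~0.25 and the
# post-clip norm is ~1.0. With total_norm=0.5, clip_coef = 2.0 but is clamped to 1.0, so the multiply is a no-op
# rather than scaling gradients up; the unconditional multiply simply avoids the device-to-host sync that an
# `if clip_coef < 1:` branch would require.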


def _get_foreach_kernels_supported_devices() -> List[str]:
    r"""Return the device type list that supports foreach kernels."""
    return ["cuda", "xpu", torch._C._get_privateuse1_backend_name()]


def _group_tensors_by_device_and_dtype(
    tensorlistlist: List[List[Optional[torch.Tensor]]],
    with_indices: bool = False,
) -> Dict[Tuple[torch.device, torch.dtype], Tuple[List[List[Optional[torch.Tensor]]], List[int]]]:
    return torch._C._group_tensors_by_device_and_dtype(tensorlistlist, with_indices)


def _device_has_foreach_support(device: torch.device) -> bool:
    return device.type in (_get_foreach_kernels_supported_devices() + ["cpu"]) and not torch.jit.is_scripting()


def _has_foreach_support(tensors: List[torch.Tensor], device: torch.device) -> bool:
    return _device_has_foreach_support(device) and all(t is None or type(t) in [torch.Tensor] for t in tensors)