# Llama-3.1-8B-DALv0.1/venv/lib/python3.12/site-packages/torch/distributed/pipelining/microbatch.py
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging | |
from typing import Any, Dict, List, Optional, Tuple | |
import torch | |
from torch.fx.node import map_aggregate | |
from torch.utils._pytree import tree_flatten, tree_unflatten | |
# Public names exported by this module.
__all__ = [
    "TensorChunkSpec",
    "split_args_kwargs_into_chunks",
    "merge_chunks",
]

logger = logging.getLogger(__name__)

# _debug_mask_minibatches specifies to send masked versions of the mini-batch
# through instead of micro-batch slices -- this can be used for more stable
# numerical testing (see [A Note About Correctness Testing]).
_debug_mask_minibatches = False
class _CustomReducer:
    """
    Describe a custom reduction that folds the losses of multiple
    microbatches into one value.

    The reduction is given as a starting value plus a two-argument function
    that combines the running value with the next one.

    Example:
        >>> # xdoctest: +SKIP
        >>> sum_reducer = _CustomReducer(
        ...     torch.tensor(0.0),
        ...     lambda a, b: a + b,
        ... )
    """

    def __init__(self, init_value, reduce_fn):
        # Both pieces are stored as-is; nothing is validated or copied here.
        self.init_value, self.reduce_fn = init_value, reduce_fn
class _LossReducer(_CustomReducer):
    """A `_CustomReducer` subclass that adds no behavior of its own."""


# Ready-made reducer: starts from a scalar zero tensor and adds each value.
sum_reducer = _LossReducer(torch.tensor(0.0), lambda a, b: a + b)
# Chunking dimension used when the user did not specify one: dimension 0.
DEFAULT_CHUNK_DIM: int = 0
class TensorChunkSpec:
    """
    Class used to specify chunking of inputs.

    An instance records the dimension (`split_dim`) along which a tensor
    should be split into chunks.
    """

    # Dimension along which the tensor is chunked.
    split_dim: int

    def __init__(self, split_dim):
        self.split_dim = split_dim

    def __repr__(self):
        cls = self.__class__
        return f"{cls.__module__}.{cls.__name__}({self.split_dim})"

    def __str__(self):
        return f"TensorChunkSpec({self.split_dim})"

    # NOTE: the two helpers below take no `self`, so they are declared as
    # static methods. Without the decorator they only work when accessed on
    # the class; on an instance, the instance would be bound as `chunk_dims`.
    @staticmethod
    def from_tuple(
        chunk_dims: Tuple[int, ...],
    ):
        """
        A helper for creating a tuple of `TensorChunkSpec` from a tuple of chunk
        dimensions (int's).

        Args:
            chunk_dims: chunk dimension for each positional argument

        Returns:
            The same structure as `chunk_dims`, with every dimension wrapped
            in a `TensorChunkSpec`.

        Example:
            >>> # xdoctest: +SKIP
            >>> # There are three positional arguments to the model, and
            >>> # we are chunking them along dimension 0, 0 and 1, respectively
            >>> args_chunk_spec = TensorChunkSpec.from_tuple((0, 0, 1))
        """
        return map_aggregate(chunk_dims, TensorChunkSpec)

    @staticmethod
    def from_dict(
        chunk_dims: Dict[str, int],
    ):
        """
        A helper for creating a dictionary of `TensorChunkSpec` from a
        dictionary of chunk dimensions (int's).

        Args:
            chunk_dims: chunk dimension for each keyword argument, by name

        Returns:
            A mapping with the same keys as `chunk_dims`, with every dimension
            wrapped in a `TensorChunkSpec`.

        Example:
            >>> # xdoctest: +SKIP
            >>> # Chunk dimension 0 for the "id" argument, 1 for the "mask" argument
            >>> kwargs_chunk_spec = TensorChunkSpec.from_dict({"id": 0, "mask": 1})
        """
        return map_aggregate(chunk_dims, TensorChunkSpec)
class _Replicate:
    """Class used to specify replication of inputs."""
def _shard_dict_of_args(
    args_dict,
    args_chunk_spec,
    num_chunks,
):
    """
    Given a dictionary of args, and a dictionary of chunking specs, shard the
    args according to the chunking specs.

    Every value is flattened with `tree_flatten`; each flat leaf is either
    replicated (when its spec is `_Replicate` or the leaf is not a tensor) or
    split with `torch.tensor_split` along `TensorChunkSpec.split_dim`.

    If the first chunked tensor is smaller than `num_chunks` along its split
    dimension, the number of chunks is reduced to that size (with a warning).
    A later tensor that is too small raises instead.

    Args:
        args_dict: Dictionary of args
        args_chunk_spec: Dictionary of chunking specs, with the same keys as
            `args_dict`
        num_chunks: Number of chunks to shard the args into

    Returns:
        args_split: List of sharded args, one dictionary per chunk, each keyed
            like `args_dict`. Its length may be smaller than `num_chunks`.

    Raises:
        ValueError: an arg and its chunk spec flatten to different lengths.
        RuntimeError: a tensor other than the first chunked one is smaller
            than the number of chunks along its split dimension.
        TypeError: a chunk spec leaf is neither `_Replicate` nor a
            `TensorChunkSpec` (for a tensor value).
    """
    # Stage 1+2: flatten and shard/replicate

    # args_sharded_replicated : [num args, num flat values, num chunks]
    args_sharded_replicated = {}
    arg_specs = []

    real_num_chunks = num_chunks
    first_tensor = True

    assert len(args_dict) == len(
        args_chunk_spec
    ), f"args_dict.keys() = {list(args_dict.keys())} args_chunk_spec.keys() = {list(args_chunk_spec.keys())}"

    for arg_key, arg in args_dict.items():
        flat, spec = tree_flatten(arg)
        arg_specs.append(spec)

        chunk_spec = args_chunk_spec[arg_key]
        assert chunk_spec is not None  # Should have been set by caller
        chunk_spec_flat, _ = tree_flatten(chunk_spec)
        if len(flat) != len(chunk_spec_flat):
            raise ValueError(
                f"Argument value {arg} did not have the same number of "
                f"values as chunk spec {chunk_spec}"
            )

        sharded_arg_flat = []

        for v, chunk_v in zip(flat, chunk_spec_flat):
            if chunk_v is _Replicate or not isinstance(v, torch.Tensor):
                # Replicated leaves may be over-allocated if the chunk count
                # is reduced later; only the first `real_num_chunks` entries
                # are read in stage 3.
                sharded_arg_flat.append([v] * real_num_chunks)
            elif isinstance(chunk_v, TensorChunkSpec):
                # TODO: check type of v. If it's a tensor, use chunk (or debug mask).
                # If it's a collection type, split it as you would expect. Otherwise,
                # Throw an error
                assert isinstance(v, torch.Tensor), f"{v} is not a tensor"

                split_dim = chunk_v.split_dim
                v_split_dim_size = v.size(split_dim)
                if v_split_dim_size < real_num_chunks:
                    if first_tensor:
                        # We can only adjust number of chunks when we hit this
                        # issue at the first tensor encountered
                        logger.warning(
                            "Tensor size on chunking dimension is %s, "
                            "downsizing the number of chunks from %s to %s.",
                            v_split_dim_size,
                            num_chunks,
                            v_split_dim_size,
                        )
                        real_num_chunks = v_split_dim_size
                    else:
                        raise RuntimeError(
                            f"Arg {arg_key} on chunking dimension has a size of {v_split_dim_size}, "
                            f"smaller than the number of chunks {num_chunks}. "
                            "PiPPy cannot reduce the number of chunks because "
                            "other arguments have bigger chunk-dimension sizes. "
                            "Please adjust your num_chunks setting."
                        )

                chunk_tensors = torch.tensor_split(v, real_num_chunks, split_dim)

                if _debug_mask_minibatches:
                    # Build full-size tensors that are zero everywhere except
                    # for the region covered by the corresponding chunk.
                    expanded_chunks = []
                    split_dim_idx = 0
                    for chunk_tensor in chunk_tensors:
                        new_val = torch.zeros_like(v)
                        upper_idx = split_dim_idx + chunk_tensor.size(split_dim)

                        slice_indices = [slice(None, None, None)] * new_val.ndim
                        slice_indices[split_dim] = slice(split_dim_idx, upper_idx)
                        # Index with a tuple: a plain list of slices relies on
                        # the deprecated non-tuple-sequence indexing path.
                        new_val[tuple(slice_indices)] = chunk_tensor
                        expanded_chunks.append(new_val)

                        split_dim_idx = upper_idx

                    sharded_arg_flat.append(expanded_chunks)
                else:
                    sharded_arg_flat.append(chunk_tensors)  # type: ignore[arg-type]

                first_tensor = False
            else:
                raise TypeError(f"Unrecognized chunk spec: {chunk_v}")

        args_sharded_replicated[arg_key] = sharded_arg_flat

    # Stage 3: rotate the nesting so that chunks are the outer dimension
    # chunks_flat : [num chunks, num args, num flat values]
    chunks_flat = [
        {
            key: [per_value[chunk_idx] for per_value in sharded]
            for key, sharded in args_sharded_replicated.items()
        }
        for chunk_idx in range(real_num_chunks)
    ]

    # Stage 4: unflatten every arg of every chunk back into its original shape
    # args_split : [num chunks, num args]
    args_split = []
    for chunk in chunks_flat:
        assert len(arg_specs) == len(chunk)
        # `arg_specs` was filled in `args_dict` iteration order, which is also
        # the key order of `chunk`, so zipping pairs each arg with its spec.
        args_split.append(
            {
                key: tree_unflatten(flat_values, arg_spec)
                for (key, flat_values), arg_spec in zip(chunk.items(), arg_specs)
            }
        )

    return args_split
def split_args_kwargs_into_chunks(
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]],
    chunks: int,
    args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
    kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
) -> Tuple[List[Tuple], List[Dict]]:
    """
    Split positional and keyword arguments into chunks, each according to its
    own chunking spec.

    Args:
        args: Tuple of args
        kwargs: Dict of kwargs (``None`` is treated as an empty dict)
        chunks: Number of chunks to split the args and kwargs into
        args_chunk_spec: chunking specs for args, in same shape as args;
            defaults to chunking every arg along ``DEFAULT_CHUNK_DIM``
        kwargs_chunk_spec: chunking specs for kwargs, in same shape as kwargs;
            defaults to chunking every kwarg along ``DEFAULT_CHUNK_DIM``

    Returns:
        args_split: List of sharded args (one tuple per chunk)
        kwargs_split: List of sharded kwargs (one dict per chunk)

    Raises:
        RuntimeError: args and kwargs end up with different numbers of chunks.
    """
    # The goal is to produce `chunks` sets of args/kwargs whose tensor values
    # were sharded or replicated as the chunk specs dictate. Conceptually:
    #
    # 1. Flatten each arg and its spec into a 1d list of values with
    #    pytree.tree_flatten. Running example:
    #
    #       args = ([A, [B, C]], D) args_spec = ([None, [None, TensorChunkSpec]], None)
    #       (kwargs not shown but it's a similar process)
    #
    #    becomes
    #
    #       args = ([A, B, C], D) args_spec = ([None, None, TensorChunkSpec], None)
    #
    # 2. Shard or replicate every value as its spec says. With chunks = 2:
    #
    #       args = ([[A, A], [B, B], [C_1, C_2]], [D, D])
    #
    # 3. Rotate the nesting so that chunks become the outer dimension:
    #
    #       args_chunks = [
    #           ([A, B, C_1], D),
    #           ([A, B, C_2], D),
    #       ]
    #
    # 4. Unflatten each chunk back into the original structure:
    #
    #       args_chunks = [
    #           ([A, [B, C_1]], D),
    #           ([A, [B, C_2]], D),
    #       ]

    # TODO: _debug_mask_minibatches

    kwargs = {} if kwargs is None else kwargs

    # Missing specs default to chunking everything along dim 0.
    if args_chunk_spec is None:
        args_chunk_spec = (TensorChunkSpec(DEFAULT_CHUNK_DIM),) * len(args)

    if kwargs_chunk_spec is None:
        default_kwarg_spec = TensorChunkSpec(DEFAULT_CHUNK_DIM)
        kwargs_chunk_spec = {name: default_kwarg_spec for name in kwargs}

    def _shard_positional(num_chunks):
        # Positional args are sharded as a dict keyed by their position.
        return _shard_dict_of_args(
            dict(enumerate(args)),
            dict(enumerate(args_chunk_spec)),
            num_chunks,
        )

    args_split_dict = _shard_positional(chunks)
    real_num_chunks = len(args_split_dict)

    kwargs_split = _shard_dict_of_args(
        kwargs,
        kwargs_chunk_spec,
        real_num_chunks,
    )

    if len(kwargs_split) < real_num_chunks:
        # In case kwargs are sharded into less chunks
        # e.g. when `args` has no tensor, just values
        real_num_chunks = len(kwargs_split)
        # Re-shard args so they agree with the smaller chunk count
        args_split_dict = _shard_positional(real_num_chunks)

    num_arg_chunks, num_kwarg_chunks = len(args_split_dict), len(kwargs_split)
    if num_arg_chunks != num_kwarg_chunks:
        raise RuntimeError(
            "args and kwargs are split into different number of chunks: "
            f"{num_arg_chunks}, {num_kwarg_chunks}"
        )

    # Turn each position-keyed dict back into a tuple ordered by position.
    args_split = [
        tuple(chunk_args[pos] for pos in range(len(chunk_args)))
        for chunk_args in args_split_dict
    ]

    return args_split, kwargs_split
def merge_chunks(
    chunks: List[Any],
    chunk_spec,
):
    """
    Given a list of chunks, merge them into a single value according to
    the chunk spec.

    Each flattened leaf of the spec decides how the matching leaves of all
    chunks are combined:

    * `TensorChunkSpec`: concatenated with `torch.cat` along `split_dim`.
    * `_CustomReducer`: folded with `reduce_fn`, starting from `init_value`.
    * anything else: treated as replicated; every chunk must hold an equal
      value, and the first chunk's value is used.

    Args:
        chunks: list of chunks
        chunk_spec: Chunking spec for the chunks. If ``None``, every leaf of
            the first chunk is merged along ``DEFAULT_CHUNK_DIM``.

    Returns:
        value: Merged value, in the structure of `chunk_spec` (or of the first
            chunk when `chunk_spec` is ``None``)

    Raises:
        ValueError: a chunk flattens to a different number of leaves than the
            chunk spec.
        AssertionError: a replicated leaf differs between chunks (or, in debug
            mask mode, chunk shapes differ).
    """
    # This is essentially the inverse of `split_args_kwargs_into_chunks`, so the
    # steps are similar to the steps in that function but in reverse. Given the
    # input values:
    #
    #       chunks = [
    #           ([A, [B, C_1]], D),
    #           ([A, [B, C_2]], D),
    #       ]
    #       args_spec = ([None, [None, TensorChunkSpec]], None)
    #
    # 1. Flatten the chunks according to the chunk_spec
    #
    #       chunks_flat = [
    #           ([A, B, C_1], D),
    #           ([A, B, C_2], D),
    #       ]
    #
    # 2. Rotate the nesting order such that chunks are the inner dimension
    #
    #       value_inner = ([A, B, [C_1, C_2]], D)
    #
    # 3. Concatenate sharded arguments
    #
    #       value_combined = ([A, B, C], D)
    #
    # 4. Unflatten the combined args given the spec
    #
    #       value = ([A, [B, C]], D)

    # Preliminary: flatten the chunk spec
    if chunk_spec is not None:
        spec_flattened, flatten_spec = tree_flatten(chunk_spec)
    else:
        # If chunk_spec is not provided, we will merge chunks along the default dimension (0), for all output fields
        # We obtain the output structure by flattening chunk 0 and generate the chunk_spec
        chunk0_flat, flatten_spec = tree_flatten(chunks[0])
        spec_flattened = [TensorChunkSpec(DEFAULT_CHUNK_DIM)] * len(chunk0_flat)

    # Stage 1: flatten chunks
    # chunks_flattened : [num chunks, num args]
    chunks_flattened = []

    for chunk in chunks:
        chunk_flattened, _ = tree_flatten(chunk)
        if len(chunk_flattened) != len(spec_flattened):
            raise ValueError(f"Chunk {chunk} did not match chunk spec {chunk_spec}")

        chunks_flattened.append(chunk_flattened)

    # Stage 2 and 3: Rotate nesting order s.t. chunks are inner dimension and
    # concatenate sharded operands
    # args_flattened : [num args]
    args_flattened = []
    for arg_idx, arg in enumerate(spec_flattened):
        # The same leaf position, taken from every chunk in order.
        partial_values = [chunk_flat[arg_idx] for chunk_flat in chunks_flattened]

        if isinstance(arg, TensorChunkSpec):
            if _debug_mask_minibatches:
                # Infer size of individual chunks by running `tensor_split` again
                overall_shape = partial_values[0].shape
                for val in partial_values[1:]:
                    assert val.shape == overall_shape
                meta_chunks = torch.tensor_split(
                    torch.empty(*overall_shape, device="meta"),
                    sections=len(partial_values),
                    dim=arg.split_dim,
                )

                values_to_cat = []
                chunk_start_idx = 0
                assert len(partial_values) == len(meta_chunks)
                for partial_value, meta_chunk in zip(partial_values, meta_chunks):
                    chunk_end_idx = chunk_start_idx + meta_chunk.size(arg.split_dim)

                    slice_indices = [slice(None, None, None)] * partial_value.ndim
                    slice_indices[arg.split_dim] = slice(chunk_start_idx, chunk_end_idx)
                    # Index with a tuple: a plain list of slices relies on the
                    # deprecated non-tuple-sequence indexing path.
                    sliced = partial_value[tuple(slice_indices)]
                    values_to_cat.append(sliced)

                    chunk_start_idx = chunk_end_idx
            else:
                values_to_cat = partial_values

            args_flattened.append(torch.cat(values_to_cat, dim=arg.split_dim))
        elif isinstance(arg, _CustomReducer):
            reduced_val = arg.init_value

            for partial_value in partial_values:
                reduced_val = arg.reduce_fn(reduced_val, partial_value)

            args_flattened.append(reduced_val)
        else:
            # Replicated leaf: all chunks must agree; keep the first value.
            value = partial_values[0]
            for other in partial_values[1:]:
                matches = other == value
                if isinstance(matches, torch.Tensor):
                    # Tensor comparison is elementwise; reduce it so that
                    # multi-element tensors do not raise an ambiguous-bool error.
                    matches = bool(matches.all())
                assert matches, (
                    f"Replicated value at flattened position {arg_idx} "
                    "differs between chunks"
                )
            args_flattened.append(value)

    # Stage 4: Unflatten combined args
    return tree_unflatten(args_flattened, flatten_spec)