|
|
|
from abc import ABC, abstractmethod |
|
from contextlib import contextmanager, nullcontext |
|
from copy import copy |
|
from dataclasses import dataclass |
|
from functools import partial, wraps |
|
from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Union |
|
|
|
import torch |
|
import torch.distributed as dist |
|
|
|
|
|
import torch.distributed._functional_collectives |
|
import torch.nn as nn |
|
import torch.utils._pytree as pytree |
|
|
|
from functorch import make_fx |
|
|
|
from torch import fx |
|
from torch._decomp.decompositions import native_layer_norm_backward |
|
|
|
from torch._subclasses.fake_tensor import FakeTensorMode |
|
from torch.distributed._spmd.data_parallel import gradients_tagging |
|
from torch.distributed._spmd.parallel_mode import ( |
|
DataParallel, |
|
DTensorExpandMode, |
|
ParallelMode, |
|
) |
|
from torch.distributed._tensor import Placement |
|
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo, CodeGen |
|
from torch.nn.utils import stateless |
|
from torch.nn.utils._named_member_accessor import NamedMemberAccessor |
|
|
|
|
|
class Override(ABC): |
|
r"""Override the tracing and transformation behavior of :meth:`~torch.distributed._spmd.compile`. |
|
|
|
    This is useful when any part of the model is not traceable or if you prefer
    not to trace it for any reason. More specifically, users can implement
    :meth:`torch.distributed._spmd.Override.replacement` to replace an original
    submodule with a new submodule. The new submodule contains the operations
    that users would like to have traced, or it can simply be a dummy
    placeholder operator. After tracing, users can implement
    :meth:`torch.distributed._spmd.Override.transform` to transform the traced
    graph, where the dummy placeholder operator serves as an anchor to insert
    new sub-graphs.
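
    Example (a minimal sketch of a possible subclass; the replacement and the
    transformation below are intentionally trivial and only illustrate the
    call signatures)::

        class SkipTracing(Override):
            def replacement(
                self, fqn: str, orig_submodule: torch.nn.Module
            ) -> torch.nn.Module:
                # Swap an untraceable submodule for a traceable stand-in.
                return torch.nn.Identity()

            def transform(
                self, gm: fx.GraphModule, flat_state: List[torch.Tensor]
            ) -> fx.GraphModule:
                # Rewrite the traced graph if needed; returning it unchanged
                # is a valid no-op transformation.
                return gm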
|
""" |
|
|
|
@abstractmethod |
|
def replacement(self, fqn: str, orig_submodule: torch.nn.Module) -> torch.nn.Module: |
|
r"""Implement this method to return a new :class:`nn.Module` instance to replace the ``orig_submodule`` |
|
argument in the model. |
|
|
|
This helps if ``orig_submodule`` is not traceable or should not be traced. |
|
|
|
Args: |
|
            fqn (str): fully qualified name of the submodule.

            orig_submodule (:class:`nn.Module`): original submodule instance to replace.
|
|
|
Returns: |
|
A new :class:`nn.Module` instance to replace the original one. |
|
|
|
""" |
|
pass |
|
|
|
@abstractmethod |
|
def transform( |
|
self, |
|
gm: fx.GraphModule, |
|
flat_state: List[torch.Tensor], |
|
) -> fx.GraphModule: |
|
r""" |
|
        Given a DTensor-expanded graph and the sharding schema for every node,
        conduct additional transformations on the sub-graph from the
        :class:`nn.Module` returned by
        :meth:`torch.distributed._spmd.Override.replacement` if necessary.
|
|
|
Args: |
|
            gm (:class:`fx.GraphModule`): a DTensor-expanded graph.

            flat_state (List[:class:`Tensor`]): a reference to the list of
                flattened state. The elements in ``flat_state`` map to the first
                ``len(flat_state)`` placeholders in the graph. The transformation
                can add state to or remove state from ``flat_state`` as long as
                it keeps ``flat_state`` and the placeholders consistent.
|
|
|
Returns: |
|
            The :class:`fx.GraphModule` after transformation.
|
|
|
""" |
|
pass |
|
|
|
|
|
class _PyTreeCodeGenOutputsOnly(_PyTreeCodeGen): |
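    """A codegen that only pytree-processes graph outputs.

    Inputs are passed through unchanged, so the caller is responsible for
    flattening them before invoking the graph module.
    """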
|
|
|
def process_inputs(self, *args: Any) -> Any: |
|
return args |
|
|
|
|
|
def gen_fn_def(self, free_vars, maybe_return_annotation): |
|
return CodeGen.gen_fn_def(self, free_vars, maybe_return_annotation) |
|
|
|
|
|
def _to_caller_flattened_graph_module(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: |
|
"""Move the responsibility of flattening the input arguments from the graph module to the caller. |
|
|
|
Example: |
|
|
|
        # Before:
        output = gm(my_struct)

        # After:
        gm = _to_caller_flattened_graph_module(gm)
        output = gm(*pytree.tree_flatten(my_struct)[0])
|
|
|
""" |
|
|
|
    # Keep the output spec so that outputs are still unflattened on the way
    # out, but drop the input spec so that the graph module consumes
    # already-flattened arguments from the caller.
    gm._graph._codegen = _PyTreeCodeGenOutputsOnly(
        pytree_info=_PyTreeInfo(
            orig_args=None,
            in_spec=None,
            out_spec=gm._graph._codegen.pytree_info.out_spec,
        )
    )
|
gm.recompile() |
|
return gm |
|
|
|
|
|
|
|
|
|
dtensor_expand_mode = DTensorExpandMode() |
|
|
|
|
|
def _override_placements(t: torch.Tensor, placements: List[Placement]): |
|
global dtensor_expand_mode |
|
dtensor_expand_mode._placements_override[id(t)] = placements |
|
|
|
|
|
@contextmanager |
|
def _rematerialize_optimizer( |
|
opt: torch.optim.Optimizer, |
|
named_states: Dict[str, Any], |
|
params: Dict[str, nn.Parameter], |
|
): |
|
assert opt is not None |
|
|
|
|
|
    # Swap in the passed-in states so that the optimizer step operates on the
    # proxy tensors during tracing.
    orig_states = copy(opt.state)
    for n in named_states:
        # opt.state is keyed by nn.Parameter, while named_states is keyed by
        # the parameter's fully qualified name.
        opt.state[params[n]] = named_states[n]

    # NOTE: only the first parameter group is handled here.
    param_group = opt.param_groups[0]
    orig_params = param_group["params"]
    param_group["params"] = params.values()
|
|
|
try: |
|
yield |
|
finally: |
|
param_group["params"] = orig_params |
|
opt.state = orig_states |
|
|
|
|
|
aten = torch.ops.aten |
|
|
|
|
|
@contextmanager |
|
def _enable_compile(): |
|
|
|
|
|
|
|
    def f_true():
        return True

    # torch._utils.is_compiling gates the optimizer's graph-friendly code path.
    # Temporarily patch it to return True so that the optimizer step can be
    # captured in the traced graph.
    orig_is_compiling_code = torch._utils.is_compiling.__code__
    torch._utils.is_compiling.__code__ = f_true.__code__
|
try: |
|
yield |
|
finally: |
|
torch._utils.is_compiling.__code__ = orig_is_compiling_code |
|
|
|
|
|
# The decompositions below rewrite in-place foreach/fused optimizer ops as
# their out-of-place counterparts followed by explicit copy_ calls, so that
# make_fx can trace them without relying on functionalization.
def _foreach_add_decomp(self, other, alpha=1):
    self_updated = aten._foreach_add.List(self, other, alpha=alpha)
    for s, s_u in zip(self, self_updated):
        s.copy_(s_u)
|
|
|
|
|
def _foreach_unaop_decomp(op, self): |
|
self_updated = op(self) |
|
for s, s_u in zip(self, self_updated): |
|
s.copy_(s_u) |
|
|
|
|
|
def _foreach_binop_list_decomp(op, self, other): |
|
self_updated = op(self, other) |
|
for s, s_u in zip(self, self_updated): |
|
s.copy_(s_u) |
|
|
|
|
|
def _foreach_binop_scalar_decomp(op, self, scalar=1): |
|
self_updated = op(self, scalar) |
|
for s, s_u in zip(self, self_updated): |
|
s.copy_(s_u) |
|
|
|
|
|
def _foreach_addcop_scalar_decomp(op, self, tensor1, tensor2, scalar=1): |
|
self_updated = op(self, tensor1, tensor2, scalar) |
|
for s, s_u in zip(self, self_updated): |
|
s.copy_(s_u) |
|
|
|
|
|
def _fused_adam_decomp( |
|
self, |
|
grads, |
|
exp_avgs, |
|
exp_avg_sqs, |
|
max_exp_avg_sqs, |
|
state_steps, |
|
*, |
|
lr=1, |
|
beta1=1, |
|
beta2=1, |
|
weight_decay=1, |
|
eps=1, |
|
amsgrad=True, |
|
maximize=True, |
|
grad_scale=None, |
|
found_inf=None, |
|
): |
|
orig_tuple = (self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs) |
|
updated_tuple = aten._fused_adam.default( |
|
self, |
|
grads, |
|
exp_avgs, |
|
exp_avg_sqs, |
|
max_exp_avg_sqs, |
|
state_steps, |
|
lr=lr, |
|
beta1=beta1, |
|
beta2=beta2, |
|
weight_decay=weight_decay, |
|
eps=eps, |
|
amsgrad=amsgrad, |
|
maximize=maximize, |
|
grad_scale=grad_scale, |
|
found_inf=found_inf, |
|
) |
|
|
|
    for idx, (orig, updated) in enumerate(zip(orig_tuple, updated_tuple)):
        if idx == 1:
            # Skip gradients: there is no need to copy the updated gradients
            # back into the original gradient tensors.
            continue
        for o, u in zip(orig, updated):
            o.copy_(u)
|
|
|
|
|
SPMD_DECOMP_TABLE = { |
|
aten._foreach_add_.List: _foreach_add_decomp, |
|
aten._foreach_add_.Scalar: partial( |
|
_foreach_binop_scalar_decomp, aten._foreach_add.Scalar |
|
), |
|
aten._foreach_addcdiv_.Scalar: partial( |
|
_foreach_addcop_scalar_decomp, aten._foreach_addcdiv.Scalar |
|
), |
|
aten._foreach_addcmul_.Scalar: partial( |
|
_foreach_addcop_scalar_decomp, aten._foreach_addcmul.Scalar |
|
), |
|
aten._foreach_div_.List: partial( |
|
_foreach_binop_list_decomp, aten._foreach_div.List |
|
), |
|
aten._foreach_mul_.Scalar: partial( |
|
_foreach_binop_scalar_decomp, aten._foreach_mul.Scalar |
|
), |
|
aten._foreach_div_.Scalar: partial( |
|
_foreach_binop_scalar_decomp, aten._foreach_div.Scalar |
|
), |
|
aten._foreach_neg_.default: partial( |
|
_foreach_unaop_decomp, aten._foreach_neg.default |
|
), |
|
aten._foreach_reciprocal_.default: partial( |
|
_foreach_unaop_decomp, aten._foreach_reciprocal.default |
|
), |
|
aten._foreach_sqrt_.default: partial( |
|
_foreach_unaop_decomp, aten._foreach_sqrt.default |
|
), |
|
aten._foreach_sub_.Scalar: partial( |
|
_foreach_binop_scalar_decomp, aten._foreach_sub.Scalar |
|
), |
|
aten._fused_adam_.default: _fused_adam_decomp, |
|
aten.native_layer_norm_backward.default: native_layer_norm_backward, |
|
} |
|
|
|
|
|
DEDUP_TARGETS: Set[torch._ops.OpOverload] = { |
|
torch.ops._c10d_functional.all_reduce.default, |
|
torch.ops._c10d_functional.wait_tensor.default, |
|
} |
|
|
|
|
|
def _dedup_collectives(gm: fx.GraphModule) -> fx.GraphModule: |
|
args_to_node: Dict[Tuple[Any, ...], fx.Node] = {} |
|
|
|
    for node in gm.graph.nodes:
        # Flatten the args to build a lookup key for this node.
        args = pytree.arg_tree_leaves(*node.args)

        if node.target in DEDUP_TARGETS:
            args_key = (node.target, *args)
            unique_node = args_to_node.get(args_key, None)
            if unique_node is None:
                # First occurrence of this (op, args) combination; remember it.
                args_to_node[args_key] = node
            else:
                # Duplicate collective; redirect its users to the first one.
                node.replace_all_uses_with(unique_node)
                gm.graph.erase_node(node)
|
|
|
gm.recompile() |
|
|
|
return gm |
|
|
|
|
|
@dataclass |
|
class _CompiledResult: |
|
gm: fx.GraphModule |
|
mod: nn.Module |
|
opt: Optional[torch.optim.Optimizer] |
|
flat_state: List[torch.Tensor] |
|
|
|
|
|
def _compile( |
|
func: Callable, |
|
module_override: Optional[List[Override]], |
|
parallel_mode: ParallelMode, |
|
*args: Any, |
|
**kwargs: Any, |
|
) -> _CompiledResult: |
|
|
|
|
|
|
|
|
|
    # 1. Extract the single nn.Module and (optional) Optimizer from args/kwargs.
    mod, opt = None, None
    for arg in pytree.arg_tree_leaves(*args, **kwargs):
        if isinstance(arg, nn.Module):
            assert mod is None, "Only support single nn.Module for now"
            mod = arg
        if isinstance(arg, torch.optim.Optimizer):
            assert opt is None, "Only support single Optimizer for now"
            opt = arg

    assert mod is not None, "Couldn't find an nn.Module instance in the arguments."
|
|
|
|
|
    # 2. Swap targeted submodules for the user-provided replacements before tracing.
    if module_override:
        accessor = NamedMemberAccessor(mod)
|
|
|
def swap(fqn_prefix: str, module: torch.nn.Module) -> None: |
|
for override in module_override: |
|
for name, child in module.named_children(): |
|
if len(name) == 0: |
|
continue |
|
fqn = fqn_prefix + "." + name if fqn_prefix != "" else name |
|
new_child = override.replacement(fqn, child) |
|
if id(new_child) == id(child): |
|
swap(fqn, new_child) |
|
else: |
|
accessor.swap_submodule(fqn, new_child) |
|
|
|
swap("", mod) |
|
|
|
|
|
    # 3. Trace a stateless version of the train step.
    params = dict(mod.named_parameters(remove_duplicate=False))
    buffers = dict(mod.named_buffers(remove_duplicate=False))
|
|
|
    named_states = {}
    if opt is not None:
        # Pass named_states instead of opt.state to stateless_func, because the
        # latter is keyed by nn.Parameter. During tracing, optimizers must be
        # able to find their states through proxy tensors.
        for n, p in params.items():
            if p in opt.state:
                # opt.state is keyed by Parameter, while named_states is keyed
                # by the parameter's fully qualified name.
                named_states[n] = opt.state[p]
|
|
|
is_data_parallel_mode = isinstance(parallel_mode, DataParallel) |
|
|
|
|
|
|
|
    # Lift the module states and parameters to function arguments so that
    # make_fx can trace the operations applied to them.
    def stateless_func(func, params, buffers, named_states, args, kwargs):
        with stateless._reparametrize_module(
            mod, {**params, **buffers}
        ), _rematerialize_optimizer(
            opt, named_states, params
        ) if opt else nullcontext():
            # For DataParallel mode, install hooks first to tag the gradients.
            with gradients_tagging(params) if is_data_parallel_mode else nullcontext():
                ret = func(*args, **kwargs)

            # Make sure the updated parameters and optimizer states are returned.
            return ret, list(mod.parameters()), list(named_states.values())
|
|
|
|
|
|
|
|
|
|
|
tracing_mode = "fake" if is_data_parallel_mode else "symbolic" |
|
|
|
if is_data_parallel_mode: |
|
fake_mode = FakeTensorMode() |
|
data_parallel_mode = cast(DataParallel, parallel_mode) |
|
|
|
        def _get_full_batch_arg(arg: torch.Tensor) -> torch.Tensor:
            # Compilation happens on the first iteration with a mini-batch
            # input, so convert it to a full-batch fake tensor first for the
            # data parallel sharding propagation.
            fake_arg = fake_mode.from_tensor(arg)
            arg_dims = [1] * arg.ndim
            # Expand the tensor to the full batch size along its batch dimension.
            arg_dims[data_parallel_mode.input_batch_dim] *= dist.get_world_size()
            return fake_arg.repeat(arg_dims)
|
|
|
args = pytree.tree_map_only( |
|
torch.Tensor, |
|
_get_full_batch_arg, |
|
args, |
|
) |
|
kwargs = pytree.tree_map_only( |
|
torch.Tensor, |
|
_get_full_batch_arg, |
|
kwargs, |
|
) |
|
|
|
    with _enable_compile(), torch.autograd.detect_anomaly(check_nan=False):
        # Functionalization is not used here, so in-place foreach/fused ops are
        # handled via the explicit decompositions in SPMD_DECOMP_TABLE.
        gm = make_fx(
            partial(stateless_func, func),
            tracing_mode=tracing_mode,
            decomposition_table=SPMD_DECOMP_TABLE,
            _allow_non_fake_inputs=False,
        )(params, buffers, named_states, args, kwargs)
|
|
|
params_and_buffers: Dict[str, Union[torch.Tensor, nn.Parameter]] = { |
|
**params, |
|
**buffers, |
|
} |
|
|
|
|
|
    # 4. Use the parallel mode to expand the single-device graph into a
    #    distributed graph.
    gm = parallel_mode.partition(
|
gm, |
|
mod, |
|
opt, |
|
params_and_buffers, |
|
named_states, |
|
args, |
|
kwargs, |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    # 5. Move the responsibility of flattening the input arguments from the
    #    graph module to the caller. This keeps the flat state list and the
    #    graph placeholders in sync for later transformations, and the state
    #    only needs to be flattened once up front.
    flat_state = pytree.tree_leaves([params_and_buffers, named_states])
    gm = _to_caller_flattened_graph_module(gm)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    # 6. Deduplicate communication operators. Duplicates can come from DTensor
    #    args/kwargs redistribution, since DTensor only has per-tensor and
    #    per-operator information and cannot detect graph-level duplication.
    gm = _dedup_collectives(gm)
|
|
|
|
|
    # 7. Replace the previously inserted dummy placeholders with the real sub-graphs.
    if module_override:
        for override in module_override:
            gm = override.transform(gm, flat_state)
|
|
|
return _CompiledResult(gm, mod, opt, flat_state) |
|
|
|
|
|
|
|
|
|
COMPILED_OBJECT_KEY = "_compiled_obj" |
|
|
|
|
|
def compile( |
|
module_override: Optional[List[Override]] = None, |
|
gm_transformation: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None, |
|
parallel_mode: Optional[ParallelMode] = None, |
|
): |
|
r"""Compile and optimize a callable, which can be a train step within a training loop. |
|
|
|
This method will extract :class:`nn.Module` and :class:`torch.optim.Optimizer` |
|
instances from the input arguments and trace operations applied to their |
|
parameters and states. |
|
|
|
Args: |
|
module_override (Optional[List[Override]]): a list of Override instances |
|
that will be applied to the module in order. The :class:`Override` |
|
objects provide :class:`nn.Module` replacements during tracing and a |
|
graph transformation function after tracing. (Default: ``None``) |
|
        gm_transformation (Optional[Callable[[fx.GraphModule], fx.GraphModule]]):
|
a callback that will be called after the original callable is |
|
compiled and distributed (usually after the first iteration) to |
|
transform the compiled GraphModule into a new optimized one. |
|
        parallel_mode (Optional[ParallelMode]): a :class:`ParallelMode` object
            that specifies how to parallelize the callable. Each ``ParallelMode``
            has its own strategy to partition the model and the captured
            graph. (Default: ``None``)
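
    Example (a minimal sketch of the intended usage; ``train_step``, ``model``,
    ``optimizer``, and ``batch`` are user-defined names, not part of this API)::

        @compile()
        def train_step(model, optimizer, batch):
            loss = model(batch).sum()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            return loss

        # The first call traces and expands the graph and caches the result;
        # subsequent calls reuse the cached graph module.
        loss = train_step(model, optimizer, batch)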
|
|
|
""" |
|
|
|
def inner(func: Callable): |
|
@wraps(func) |
|
def wrapper(*args, **kwargs): |
|
last_train_step = kwargs.pop("last_train_step", False) if kwargs else False |
|
first_iter = False |
|
|
|
|
|
            # Store the compiled object on ``wrapper`` rather than ``func``,
            # since ``wrapper`` is the callable that users end up holding.
            compiled_obj = wrapper.__dict__.get(COMPILED_OBJECT_KEY, None)
|
if compiled_obj is None: |
|
first_iter = True |
|
global dtensor_expand_mode |
|
mode: ParallelMode = ( |
|
dtensor_expand_mode if parallel_mode is None else parallel_mode |
|
) |
|
|
|
compiled_obj = _compile(func, module_override, mode, *args, **kwargs) |
|
wrapper.__dict__[COMPILED_OBJECT_KEY] = compiled_obj |
|
|
|
flat_inps = compiled_obj.flat_state + pytree.arg_tree_leaves( |
|
*args, **kwargs |
|
) |
|
|
|
            # Autograd is not needed at runtime: the backward pass has already
            # been captured in the compiled graph.
            with torch.no_grad():
|
|
|
|
|
                if first_iter and gm_transformation:
                    # Apply the user-provided graph transformation once, right
                    # after the first-iteration compilation.
                    compiled_obj.gm = gm_transformation(compiled_obj.gm)
|
if not last_train_step: |
|
output = compiled_obj.gm(*flat_inps)[0] |
|
                else:
                    # This is the last train step: call the graph with the
                    # ``last_iter`` argument, and fall back to a plain call if
                    # the compiled graph module does not accept it.
                    try:
                        output = compiled_obj.gm(*flat_inps, last_iter=last_train_step)[
                            0
                        ]
|
except TypeError as e: |
|
if "last_iter" not in str(e): |
|
raise e |
|
output = compiled_obj.gm(*flat_inps)[0] |
|
|
|
return output |
|
|
|
return wrapper |
|
|
|
return inner |
|
|