|
|
|
|
|
import copy |
|
import logging |
|
import operator |
|
from collections import defaultdict |
|
from enum import Enum |
|
from inspect import Parameter, signature, Signature |
|
from types import MethodType |
|
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union |
|
|
|
import torch |
|
import torch.fx as fx |
|
from torch.distributed import ProcessGroup |
|
from torch.export import ExportedProgram |
|
from torch.export.unflatten import ( |
|
_assign_attr, |
|
_AttrKind, |
|
_sink_params, |
|
InterpreterModule, |
|
) |
|
from torch.fx.node import map_aggregate |
|
from torch.fx.passes.split_module import split_module |
|
from ._backward import _null_coalesce_accumulate, stage_backward |
|
from ._unflatten import _outline_submodules |
|
from ._utils import PipeInfo |
|
from .stage import _PipelineStage |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
|
|
def _find_loss_from_output_and_spec(output_val, spec_val): |
|
if spec_val is False: |
|
return None |
|
if spec_val is True: |
|
if not isinstance(output_val, fx.Node): |
|
raise RuntimeError( |
|
f"Loss spec must specify a dynamic value but got {output_val}" |
|
) |
|
return output_val |
|
|
|
if isinstance(spec_val, (tuple, list)): |
|
if not isinstance(output_val, (tuple, list)): |
|
raise RuntimeError( |
|
f"Output value {output_val} must match type of loss specification " |
|
f"{spec_val}" |
|
) |
|
if len(output_val) != len(spec_val): |
|
raise RuntimeError( |
|
f"Output value {output_val} must match length of loss specification " |
|
f"{spec_val}" |
|
) |
|
for out, spec in zip(output_val, spec_val): |
|
loss_val = _find_loss_from_output_and_spec(out, spec) |
|
if loss_val is not None: |
|
return loss_val |
|
raise RuntimeError(f"Did not find loss value in specification {spec_val}") |
|
|
|
if isinstance(spec_val, dict): |
|
if not isinstance(output_val, dict): |
|
raise RuntimeError( |
|
f"Output value {output_val} must match type of loss specification " |
|
f"{spec_val}" |
|
) |
|
if set(output_val.keys()) != set(spec_val.keys()): |
|
raise RuntimeError( |
|
f"Output value {output_val} must match keys of loss specification " |
|
f"{spec_val}" |
|
) |
|
for k in spec_val: |
|
loss_val = _find_loss_from_output_and_spec(output_val[k], spec_val[k]) |
|
if loss_val is not None: |
|
return loss_val |
|
raise RuntimeError(f"Did not find loss value in specification {spec_val}") |
|
|
|
raise RuntimeError(f"Unsupported type {type(spec_val)} in loss specification") |
|
|
|
|
|
def _find_loss_output(mod: torch.nn.Module, g: fx.Graph, output_loss_value_spec): |
|
output_nodes = [n for n in g.nodes if n.op == "output"] |
|
assert len(output_nodes) == 1 |
|
output_node = output_nodes[0] |
|
output_val = output_node.args[0] |
|
generated_spec: Any = None |
|
|
|
if isinstance(mod, TrivialLossWrapper): |
|
|
|
|
|
assert len(output_node.args) == 1 |
|
loss_node = output_val |
|
generated_spec = TrivialLossWrapper.loss_spec |
|
elif output_loss_value_spec is None: |
|
|
|
if isinstance(output_val, dict) and "loss" in output_val.keys(): |
|
loss_node = output_val["loss"] |
|
generated_spec = {k: k == "loss" for k in output_val} |
|
else: |
|
loss_node = None |
|
generated_spec = None |
|
else: |
|
loss_node = _find_loss_from_output_and_spec(output_val, output_loss_value_spec) |
|
generated_spec = output_loss_value_spec |
|
|
|
return loss_node, output_node, generated_spec |
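# Hedged usage note: when `output_loss_value_spec` is None, the helper above
# falls back to the convention that a dict output with a "loss" key marks the
# loss value, e.g. for a (hypothetical) training wrapper:
#
#   def forward(self, x, target):
#       out = self.model(x)
#       return {"loss": self.loss_fn(out, target), "out": out}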
|
|
|
|
|
def _insert_stage_symbolic_backward( |
|
g: fx.Graph, |
|
loss_node: fx.Node, |
|
output_node: fx.Node, |
|
): |
|
|
|
tuples: Dict[fx.Node, Tuple] = {} |
|
for node in reversed(g.nodes): |
|
if node.op == "call_function": |
|
|
|
|
|
|
|
assert node.target == operator.getitem, ( |
|
"Found non-getitem call in forward pass. " |
|
"Please report a bug to PiPPy" |
|
) |
|
assert ( |
|
len(node.args) == 2 |
|
), "Found malformed getitem call. Please report a bug to PiPPy" |
|
indexed_value, node_idx = tuple(node.args) |
|
|
|
|
|
|
|
|
|
existing_list_size = ( |
|
len(tuples[indexed_value]) if indexed_value in tuples else -1 |
|
) |
|
new_list_size = max(node_idx + 1, existing_list_size) |
|
|
|
reconstructed_list = [None for _ in range(new_list_size)] |
|
|
|
|
|
if indexed_value in tuples: |
|
for i, val in enumerate(tuples[indexed_value]): |
|
reconstructed_list[i] = val |
|
|
|
|
|
reconstructed_list[node_idx] = node |
|
|
|
tuples[indexed_value] = tuple(reconstructed_list) |
|
|
|
|
|
|
|
|
|
live_nodes = {loss_node: None} |
|
val_to_grad: Dict[fx.Node, Optional[fx.Node]] = {loss_node: None} |
|
|
|
def assign_or_accumulate_grad(forward_node, grad_value): |
|
if forward_node in val_to_grad and forward_node.op != "placeholder": |
|
grad_value = g.call_function( |
|
_null_coalesce_accumulate, |
|
(val_to_grad[forward_node], grad_value), |
|
) |
|
val_to_grad[forward_node] = grad_value |
|
|
|
with g.inserting_before(output_node): |
|
for node in reversed(g.nodes): |
|
if node not in live_nodes: |
|
continue |
|
|
|
def add_to_live_nodes(n): |
|
live_nodes.setdefault(n, None) |
|
|
|
fx.node.map_arg(node.args, add_to_live_nodes) |
|
fx.node.map_arg(node.kwargs, add_to_live_nodes) |
|
if node.op == "call_module": |
|
output_grads: Union[Tuple[Optional[fx.Node], ...], Optional[fx.Node]] |
|
if node in tuples: |
|
stage_output = tuples[node] |
|
output_grads = tuple(val_to_grad.get(n, None) for n in tuples[node]) |
|
outputs_with_grads_idxs = [ |
|
i for i, n in enumerate(tuples[node]) if n in live_nodes |
|
] |
|
else: |
|
stage_output = (node,) |
|
output_grads = val_to_grad[node] |
|
outputs_with_grads_idxs = [0] |
|
|
|
output_grads = ( |
|
(output_grads,) |
|
if not isinstance(output_grads, tuple) |
|
else output_grads |
|
) |
|
|
|
grad_call = g.call_function( |
|
stage_backward, |
|
kwargs={ |
|
"stage_output": stage_output, |
|
"output_grads": output_grads, |
|
"input_values": list(node.all_input_nodes), |
|
"outputs_with_grads_idxs": outputs_with_grads_idxs, |
|
}, |
|
) |
|
|
|
kwargs_copy = dict(grad_call.kwargs) |
|
grad_call.kwargs = kwargs_copy |
|
|
|
grad_call_proxy = fx.Proxy(grad_call) |
|
grads = grad_call_proxy.node |
|
|
|
input_nodes = list(node.all_input_nodes) |
|
grads_proxy = fx.Proxy(grads) |
|
for i, input_node in enumerate(input_nodes): |
|
assign_or_accumulate_grad(input_node, grads_proxy[i].node) |
|
|
|
return g |
|
|
|
|
|
class PipeSequential(torch.nn.Sequential): |
|
@staticmethod |
|
def from_sequential(sequential_instance: torch.nn.Sequential): |
|
return PipeSequential(*[copy.copy(m) for m in sequential_instance]) |
|
|
|
def forward(self, input): |
|
for i, module in enumerate(self): |
|
input = module(input) |
|
if i != len(self) - 1: |
|
pipe_split() |
|
return input |
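# Hedged usage sketch (the layer sizes below are made up for illustration): an
# existing nn.Sequential can be converted so that a pipe_split() runs between
# consecutive children during tracing:
#
#   seq = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
#   pipe_seq = PipeSequential.from_sequential(seq)  # stage boundary between consecutive children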
|
|
|
|
|
class LossWrapper(torch.nn.Module): |
|
""" |
|
LossWrapper is a convenient abstract class that allows you to wrap up both |
|
your model as well as its loss function and specify the connectivity between |
|
the inputs, model, loss function, and output value. Example:: |
|
|
|
class MyModelWrapper(LossWrapper): |
|
def forward(self, x, targets): |
|
model_out = self.module(x) |
|
loss_value = self.loss_fn(model_out, targets) |
|
return loss_value |
|
|
|
The above example defines a connectivity where we expect the forward/loss/backward |
|
training procedure to take two arguments (x and targets), pass x into the module |
|
to get the output of the feedforward computation, pass the model output and the |
|
targets value into the loss function, and get and return the loss value, which will |
|
be backpropagated by PiPPy. The above class would then be instantiated like:: |
|
|
|
model = ... # instantiate the model |
|
loss_fn = torch.nn.MSELoss() # for the sake of demonstration |
|
|
|
wrapper = MyModelWrapper(model, loss_fn) |
|
pipe = Pipe.from_tracing(wrapper, ...) |
|
|
|
""" |
|
|
|
def __init__(self, module, loss_fn): |
|
super().__init__() |
|
self.module = module |
|
self.loss_fn = loss_fn |
|
|
|
def forward(self, *args, **kwargs): |
|
raise NotImplementedError( |
|
"This instance of LossWrapper does not have an overridden" |
|
"forward(). Please implement forward() to specify the arguments, " |
|
"connection between the module and loss, and loss output " |
|
"value." |
|
) |
|
|
|
|
|
class TrivialLossWrapper(LossWrapper): |
|
def forward(self, x, targets): |
|
model_out = self.module(x) |
|
return self.loss_fn(model_out, targets) |
|
|
|
loss_spec = True |
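# Hedged usage sketch: TrivialLossWrapper covers the common case where the loss
# function consumes exactly (model_output, targets). The model and loss below
# are placeholders, not part of this file:
#
#   wrapper = TrivialLossWrapper(MyModel(), torch.nn.CrossEntropyLoss())
#   loss = wrapper(inputs, targets)  # forward returns the scalar loss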
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch.library.define("pippy::_pipe_split", "() -> ()") |
|
|
|
|
|
@torch.library.impl("pippy::_pipe_split", "BackendSelect") |
|
def _pipe_split(): |
|
return None |
|
|
|
|
|
@torch.library.register_fake("pippy::_pipe_split") |
|
def _pipe_split(): |
|
return None |
|
|
|
|
|
|
|
aten_pipe_split_alias = torch.ops.pippy._pipe_split.default |
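# Mark the split op as side-effectful so that FX passes such as dead-code
# elimination do not drop pipe_split calls, which have no data dependencies of
# their own.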
|
|
|
|
|
|
|
fx.node._side_effectful_functions.add(aten_pipe_split_alias) |
|
|
|
|
|
|
|
def pipe_split(): |
|
""" |
|
    pipe_split is a special operator used to mark the boundary between stages
    in a module; the tracer splits the module into stages at each call site.
    It is a no-op if the annotated module is run eagerly.
|
|
|
Example: |
|
>>> # xdoctest: +SKIP |
|
>>> def forward(self, x): |
|
>>> x = torch.mm(x, self.mm_param) |
|
>>> x = torch.relu(x) |
|
>>> pipe_split() |
|
>>> x = self.lin(x) |
|
>>> return x |
|
|
|
The above example will be split into two stages. |
|
""" |
|
return torch.ops.pippy._pipe_split() |
|
|
|
|
|
class MultiUseParameterConfig(Enum): |
|
TRANSMIT = 1 |
|
REPLICATE = 2 |
|
|
|
|
|
MultiUseParamSpec = Union[MultiUseParameterConfig, Dict[str, MultiUseParameterConfig]] |
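# Hedged note: a MultiUseParamSpec can be a single policy applied to every
# shared parameter, or a per-parameter mapping, e.g. (hypothetical name):
#
#   spec = {"embedding.weight": MultiUseParameterConfig.REPLICATE}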
|
|
|
|
|
class DetachExecutor(fx.Interpreter): |
|
""" |
|
Special interpreter to run the split_gm in testing that detaches all inputs to |
|
a module invocation. This is needed so that the values at the boundary are |
|
    leaf tensors in autograd execution.
|
""" |
|
|
|
def __init__(self, module, garbage_collect_values=True): |
|
garbage_collect_values = False |
|
super().__init__(module, garbage_collect_values) |
|
self.value_remap = {} |
|
|
|
def run(self, *args, initial_env=None): |
|
self.value_remap = {} |
|
return super().run(*args, initial_env=initial_env) |
|
|
|
def call_module(self, target, args, kwargs): |
|
def detach_tensors(a): |
|
if isinstance(a, torch.Tensor) and a.requires_grad: |
|
if a not in self.value_remap: |
|
new_val = a.detach().requires_grad_(True) |
|
self.value_remap[a] = new_val |
|
return self.value_remap[a] |
|
else: |
|
return a |
|
|
|
""" |
|
def dont_traverse_size(a): |
|
return type(a) != torch.Size |
|
""" |
|
|
|
args = map_aggregate( |
|
args, |
|
detach_tensors, |
|
) |
|
kwargs = map_aggregate( |
|
kwargs, |
|
detach_tensors, |
|
) |
|
|
|
return super().call_module(target, args, kwargs) |
|
|
|
def call_function(self, target, args, kwargs): |
|
|
|
if target == stage_backward: |
|
kwargs = dict(kwargs) |
|
kwargs["input_values"] = [ |
|
self.value_remap.get(v, v) for v in kwargs["input_values"] |
|
] |
|
return super().call_function(target, args, kwargs) |
|
|
|
|
|
class _NodeReference: |
|
def __init__(self, name): |
|
self.name = name |
|
|
|
name: str |
|
|
|
|
|
class _LinearNodeList: |
|
def __init__(self, node_list): |
|
self.serialize_node_list = [] |
|
for node in node_list: |
|
node_args = fx.node.map_arg(node.args, lambda n: _NodeReference(n.name)) |
|
node_kwargs = fx.node.map_arg(node.kwargs, lambda n: _NodeReference(n.name)) |
|
serialize_node = fx.Node( |
|
graph=None, |
|
name=node.name, |
|
op=node.op, |
|
target=node.target, |
|
args=node_args, |
|
kwargs=node_kwargs, |
|
return_type=node.type, |
|
) |
|
serialize_node.meta = copy.copy(node.meta) |
|
self.serialize_node_list.append(serialize_node) |
|
|
|
def to_graph(self): |
|
graph = fx.Graph() |
|
|
|
ref_str_to_node: Dict[str, fx.Node] = {} |
|
|
|
def ref_to_node(arg): |
|
if isinstance(arg, _NodeReference): |
|
return ref_str_to_node[arg.name] |
|
else: |
|
return arg |
|
|
|
for node in self.serialize_node_list: |
|
node_args = map_aggregate(node.args, ref_to_node) |
|
node_kwargs = map_aggregate(node.kwargs, ref_to_node) |
|
deser_node = graph.create_node( |
|
op=node.op, |
|
target=node.target, |
|
args=node_args, |
|
kwargs=node_kwargs, |
|
name=node.name, |
|
type_expr=node.type, |
|
) |
|
ref_str_to_node[node.name] = deser_node |
|
|
|
return graph |
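# Hedged round-trip sketch: _LinearNodeList flattens an fx.Graph into nodes that
# reference one another by name only (via _NodeReference), so the list can be
# pickled and the graph rebuilt on the receiving side:
#
#   node_list = _LinearNodeList(gm.graph.nodes)  # gm: an existing fx.GraphModule
#   rebuilt_graph = node_list.to_graph()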
|
|
|
|
|
def _direct_serialization_deserialize(body, nodes): |
|
""" |
|
Custom `__reduce__` method for serialization. |
|
DO AS I SAY -- NOT AS I DO. This violates the principle that |
|
GraphModules serialize via code export & re-tracing. We allow |
|
for this here because **PIPE STAGES SHOULD NOT BE PERSISTED |
|
TO DISK -- THIS IS ONLY FOR TRANSMISSION VIA RPC**. Persisting |
|
these instances to disk will expose internal implementation |
|
details of `fx.Graph` and related data structures and is |
|
NOT advised. |
|
""" |
|
|
|
class DummyModule(torch.nn.Module): |
|
def __init__(self, body): |
|
super().__init__() |
|
self.__dict__.update(body) |
|
|
|
dummy = DummyModule(body) |
|
|
|
return fx.GraphModule(dummy, nodes.to_graph()) |
|
|
|
|
|
def _direct_serialization_reduce(self): |
|
serialization_dict = dict(self.__dict__) |
|
serialization_dict.pop("_graph") |
|
return ( |
|
_direct_serialization_deserialize, |
|
(serialization_dict, _LinearNodeList(self.graph.nodes)), |
|
) |
|
|
|
|
|
def _modify_graph_op_device( |
|
gm: torch.fx.GraphModule, |
|
new_device: torch.device, |
|
): |
|
""" |
|
Modify the device argument of all "call_function" nodes in the graph. This |
|
    is useful for moving the graph to a different device, in particular for
    generator ops such as torch.ones.
|
""" |
|
modified = False |
|
for node in gm.graph.nodes: |
|
if node.op == "call_function": |
|
if "device" in node.kwargs and node.kwargs["device"] != new_device: |
|
logger.debug( |
|
f"Changing device of Node {node.name} from {node.kwargs['device']} to {new_device}" |
|
) |
|
node.update_kwarg("device", new_device) |
|
modified = True |
|
elif node.op == "call_module": |
|
|
|
submod = gm.get_submodule(node.target) |
|
if isinstance(submod, torch.fx.GraphModule): |
|
_modify_graph_op_device(submod, new_device) |
|
elif isinstance(submod, InterpreterModule): |
|
|
|
_modify_graph_op_device(submod.graph_module, new_device) |
|
else: |
|
logger.warning( |
|
f"Skipping device modification for submodule {node.target} because it is a {type(submod)}" |
|
) |
|
|
|
if modified: |
|
gm.recompile() |
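# Hedged example of the rewrite above (the node shown is illustrative only):
# a traced generator call such as
#   ones = torch.ops.aten.ones.default([4], device=torch.device("cpu"))
# would have its `device` kwarg retargeted to the stage's device, so such ops
# allocate on the correct device after the stage module is moved.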
|
|
|
|
|
class Pipe(torch.nn.Module): |
|
def __init__( |
|
self, |
|
split_gm: fx.GraphModule, |
|
num_stages: int, |
|
has_loss_and_backward: bool, |
|
loss_spec, |
|
): |
|
|
|
torch.nn.Module.__init__(self) |
|
self.split_gm: fx.GraphModule = split_gm |
|
self.executor: DetachExecutor = DetachExecutor(self.split_gm) |
|
self.num_stages: int = num_stages |
|
self.has_loss_and_backward = has_loss_and_backward |
|
self.loss_spec = loss_spec |
|
|
|
for node in split_gm.graph.nodes: |
|
assert ( |
|
node.op in {"call_module", "placeholder", "output"} |
|
or (node.op, node.target) == ("call_function", operator.getitem) |
|
or (node.op, node.target) == ("call_method", "backward") |
|
or (node.op, node.target) == ("call_function", stage_backward) |
|
or (node.op, node.target) |
|
== ("call_function", _null_coalesce_accumulate) |
|
), node |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
params_to_users: Dict[torch.nn.Parameter, Dict[str, str]] = {} |
|
|
|
for m_qualname, mod in self.split_gm.named_children(): |
|
for p_qualname, param in mod.named_parameters(): |
|
params_to_users.setdefault(param, {}) |
|
params_to_users[param][m_qualname] = p_qualname |
|
|
|
self.replicated_params: List[Dict[str, str]] = [ |
|
use_mapping |
|
for _, use_mapping in params_to_users.items() |
|
if len(use_mapping) > 1 |
|
] |
|
|
|
|
|
|
|
|
|
|
|
|
|
for param_mapping in self.replicated_params: |
|
for submod_name, param_qualname in param_mapping.items(): |
|
submod = getattr(self.split_gm, submod_name) |
|
atoms = param_qualname.split(".") |
|
for atom in atoms[:-1]: |
|
submod = getattr(submod, atom) |
|
setattr(submod, atoms[-1], copy.deepcopy(getattr(submod, atoms[-1]))) |
|
|
|
def throw(self, *args, **kwargs): |
|
raise RuntimeError( |
|
"To run pipeline locally, invoke the Pipe object directly, not `split_gm`" |
|
) |
|
|
|
self.split_gm.forward = throw |
|
|
|
|
|
i = 0 |
|
while True: |
|
try: |
|
name = f"submod_{i}" |
|
submod = getattr(self.split_gm, name) |
|
submod.__class__.__reduce__ = _direct_serialization_reduce |
|
i += 1 |
|
except AttributeError: |
|
break |
|
|
|
def forward(self, *args, **kwargs): |
|
executor_args = args |
|
if len(kwargs) > 0: |
|
parameters = [] |
|
for node in self.split_gm.graph.nodes: |
|
if node.op == "placeholder": |
|
if node.args and len(node.args) > 0: |
|
parameters.append( |
|
Parameter( |
|
node.target, |
|
Parameter.POSITIONAL_OR_KEYWORD, |
|
default=node.args[0], |
|
) |
|
) |
|
else: |
|
parameter_kind = Parameter.POSITIONAL_OR_KEYWORD |
|
param_name = node.target |
|
if node.target.startswith("**"): |
|
parameter_kind = Parameter.VAR_KEYWORD |
|
param_name = param_name[2:] |
|
elif node.target.startswith("*"): |
|
parameter_kind = Parameter.VAR_POSITIONAL |
|
param_name = param_name[1:] |
|
parameters.append(Parameter(param_name, parameter_kind)) |
|
signature = Signature(parameters) |
|
ba = signature.bind(*args, **kwargs) |
|
ba.apply_defaults() |
|
executor_args = ba.arguments.values() |
|
|
|
res = self.executor.run(*executor_args) |
|
|
|
return res |
|
|
|
def get_stage_module(self, stage_idx: int) -> torch.nn.Module: |
|
""" |
|
Return a stage module corresponding to `stage_idx` of the `pipe`. |
|
""" |
|
if stage_idx < 0 or stage_idx >= self.num_stages: |
|
raise ValueError(f"Invalid stage index {stage_idx}!") |
|
return getattr(self.split_gm, f"submod_{stage_idx}") |
|
|
|
@staticmethod |
|
def _number_and_count_forward_stages(gm: fx.GraphModule): |
|
num_stages = 0 |
|
found_idxs: Dict[int, None] = {} |
|
for node in gm.graph.nodes: |
|
if node.op == "call_module" and node.target.startswith("submod_"): |
|
node.meta["stage_idx"] = int(node.target[len("submod_") :]) |
|
found_idxs.setdefault(node.meta["stage_idx"]) |
|
num_stages += 1 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return num_stages |
|
|
|
@staticmethod |
|
def _from_traced( |
|
mod: torch.nn.Module, |
|
exported_program: ExportedProgram, |
|
multi_use_param_spec: Optional[MultiUseParamSpec] = None, |
|
output_loss_value_spec=None, |
|
split_policy: Optional[ |
|
Callable[[torch.fx.GraphModule], torch.fx.GraphModule] |
|
] = None, |
|
): |
|
""" |
|
        The ``output_loss_value_spec`` value can be specified to disambiguate
|
which value in the output of `forward` is the loss value on which PiPPy should apply |
|
backpropagation. For example, if your ``forward`` returns a tuple ``(loss, model_out)``, |
|
you can specify ``output_loss_value_spec=(True, False)``. Or, if your ``forward`` returns |
|
a dict ``{'loss': loss_value, 'model_out': model_out}``, you can specify |
|
``output_loss_value_spec={'loss': True, 'model_out': False}`` |
|
""" |
|
|
|
traced = exported_program.module() |
|
|
|
if split_policy is not None: |
|
logger.info("Auto-splitting model") |
|
traced = split_policy(traced) |
|
|
|
logger.debug(traced.print_readable(print_output=False)) |
|
|
|
|
|
|
|
|
|
get_attr_nodes: Dict[str, fx.Node] = {} |
|
for node in traced.graph.nodes: |
|
if node.op == "get_attr": |
|
get_attr_nodes.setdefault(node.target, node) |
|
|
|
if get_attr_nodes[node.target] != node: |
|
node.replace_all_uses_with(get_attr_nodes[node.target]) |
|
traced.graph.erase_node(node) |
|
|
|
|
|
prev_pipe_split_idx = -1 |
|
pipe_split_nodes_to_erase = set() |
|
for i, node in enumerate(traced.graph.nodes): |
|
if (node.op, node.target) == ("call_function", pipe_split): |
|
if prev_pipe_split_idx == i - 1: |
|
pipe_split_nodes_to_erase.add(node) |
|
prev_pipe_split_idx = i |
|
|
|
for node in pipe_split_nodes_to_erase: |
|
traced.graph.erase_node(node) |
|
|
|
traced.recompile() |
|
|
|
part_idx = 0 |
|
|
|
def split_callback(n: fx.Node): |
|
nonlocal part_idx |
|
if (n.op, n.target) == ( |
|
"call_function", |
|
aten_pipe_split_alias, |
|
): |
|
logger.debug(f"Found pipe_split {part_idx}") |
|
part_idx += 1 |
|
return part_idx |
|
|
|
|
|
|
|
split = split_module(traced, mod, split_callback) |
|
|
|
split.graph.eliminate_dead_code() |
|
|
|
|
|
for submodule in split.modules(): |
|
if isinstance(submodule, fx.GraphModule): |
|
for node in submodule.graph.nodes: |
|
if (node.op, node.target) == ( |
|
"call_function", |
|
aten_pipe_split_alias, |
|
): |
|
submodule.graph.erase_node(node) |
|
submodule.recompile() |
|
|
|
for name, submodule in split.named_children(): |
|
if isinstance(submodule, fx.GraphModule): |
|
new_submod = _outline_submodules(submodule.graph) |
|
|
|
split.register_module(name, new_submod) |
|
|
|
|
|
def delete_user_reference(node, user): |
|
""" |
|
Delete reference of `node` from `user`'s arg list. |
|
Args: |
|
- node: a `get_attr` node at root. |
|
- user: a submodule node that uses `node`. |
|
""" |
|
assert len(user.kwargs) == 0 |
|
use_idxs = [i for i, arg in enumerate(user.args) if arg == node] |
|
assert len(use_idxs) == 1 |
|
args_copy = list(user.args) |
|
args_copy.pop(use_idxs[0]) |
|
user.args = tuple(args_copy) |
|
logger.debug( |
|
f"Deleted {node} from user {user}, arg index = {use_idxs[0]}" |
|
) |
|
|
|
|
|
|
|
to_delete = list() |
|
|
|
def _recursive_getattr_with_parent(mod, fqn): |
|
|
|
atoms = fqn.split(".") |
|
for atom in atoms[:-1]: |
|
if not hasattr(mod, atom): |
|
return None, None |
|
mod = getattr(mod, atom) |
|
if not hasattr(mod, atoms[-1]): |
|
return mod, None |
|
attr = getattr(mod, atoms[-1]) |
|
return mod, attr |
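        # Hedged example: for a qualified name such as "layers.0.weight" (made
        # up), the helper above returns (the "layers.0" submodule, its "weight"
        # tensor); it returns (parent, None) or (None, None) when a path
        # segment is missing.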
|
|
|
def move_param_to_callee( |
|
root, |
|
callee_name, |
|
param_fqn, |
|
): |
|
""" |
|
Move a parameter from the root module to a submodule. |
|
Args: |
|
root: The root module. |
|
callee_name: The name of the submodule to move the parameter to. |
|
param_fqn: The fully qualified name of the parameter to move. |
|
""" |
|
|
|
|
|
atoms = param_fqn.split(".") |
|
mod_itr, param_val = _recursive_getattr_with_parent(split, param_fqn) |
|
|
|
is_buffer = atoms[-1] in mod_itr._buffers |
|
|
|
|
|
assert isinstance(param_val, torch.Tensor), ( |
|
f"Expected '{param_fqn}' to be {torch.Tensor} but got {type(param_val)}." |
|
+ ( |
|
f" It might happen if module '{param_fqn}' was passed to some 'leaf function'" |
|
f"(see https://pytorch.org/docs/stable/fx.html#fx.wrap). Please inspect " |
|
f"usages of '{param_fqn}' in the traced graph." |
|
if isinstance(param_val, torch.nn.Module) |
|
else "" |
|
) |
|
) |
|
|
|
|
|
callee = root.get_submodule(callee_name) |
|
assert not hasattr( |
|
callee, param_fqn |
|
), f"Module {callee_name} already has a parameter named {param_fqn}" |
|
|
|
|
|
if is_buffer: |
|
_assign_attr( |
|
param_val, |
|
callee, |
|
param_fqn, |
|
attr_kind=_AttrKind.BUFFER, |
|
persistent=True, |
|
) |
|
else: |
|
_assign_attr( |
|
param_val, |
|
callee, |
|
param_fqn, |
|
attr_kind=_AttrKind.PARAMETER, |
|
) |
|
logger.debug(f"Moved parameter {param_fqn} to {callee_name}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
to_delete.append((mod_itr, atoms[-1])) |
|
|
|
|
|
attr_nodes = list(filter(lambda n: n.op == "get_attr", split.graph.nodes)) |
|
for node in attr_nodes: |
|
|
|
if len(node.users) > 1: |
|
logger.info( |
|
f"Parameter {node.target} used in multiple stages: {node.users}." |
|
) |
|
for user in node.users: |
|
assert user.op == "call_module" |
|
|
|
move_param_to_callee( |
|
split, |
|
user.target, |
|
node.target, |
|
) |
|
|
|
|
|
|
|
id_to_fqns: Dict[int, Set[str]] = defaultdict(set) |
|
for fqn, tensor in mod.state_dict(keep_vars=True).items(): |
|
id_to_fqns[id(tensor)].add(fqn) |
|
for fqn, tensor in mod.named_buffers(): |
|
id_to_fqns[id(tensor)].add(fqn) |
|
|
|
|
|
|
|
|
|
|
|
inputs_to_state: Dict[str, List[str]] = {} |
|
for attr in attr_nodes: |
|
_, tensor = _recursive_getattr_with_parent(mod, attr.target) |
|
fqns = list(id_to_fqns[id(tensor)]) |
|
if fqns: |
|
inputs_to_state[attr.name] = fqns |
|
elif attr.target in exported_program.constants: |
|
inputs_to_state[attr.name] = [attr.target] |
|
|
|
|
|
|
|
|
|
added_attributes: Dict[str, List[str]] = defaultdict(list) |
|
for fqn, tensor in mod.state_dict(keep_vars=True).items(): |
|
for name, submod in split.named_children(): |
|
if isinstance(submod, fx.GraphModule): |
|
parent, child = _recursive_getattr_with_parent(submod, fqn) |
|
if ( |
|
parent and child is None |
|
): |
|
added_attributes[name].append(fqn) |
|
setattr(parent, fqn.split(".")[-1], tensor) |
|
|
|
|
|
|
|
for mod_itr, last_atom in to_delete: |
|
try: |
|
delattr(mod_itr, last_atom) |
|
except AttributeError: |
|
|
|
pass |
|
|
|
|
|
for name, submod in split.named_children(): |
|
if isinstance(submod, fx.GraphModule): |
|
_sink_params(submod, inputs_to_state, []) |
|
submod.graph.lint() |
|
submod.recompile() |
|
|
|
|
|
|
|
|
|
for name, attributes in added_attributes.items(): |
|
submod = getattr(split, name) |
|
unused_attributes = set(attributes) |
|
|
|
stack = [("", submod)] |
|
while stack: |
|
scope, _mod = stack.pop() |
|
if isinstance(_mod, (fx.GraphModule, InterpreterModule)): |
|
for node in _mod.graph.nodes: |
|
if node.op == "get_attr": |
|
|
|
fqn = scope + "." + node.target if scope else node.target |
|
if fqn in unused_attributes: |
|
unused_attributes.remove(fqn) |
|
for _name, _submod in _mod.named_children(): |
|
stack.append((scope + "." + _name if scope else _name, _submod)) |
|
|
|
for attr in unused_attributes: |
|
mod_itr, atoms = submod, attr.split(".") |
|
for atom in atoms[:-1]: |
|
mod_itr = getattr(mod_itr, atom) |
|
delattr(mod_itr, atoms[-1]) |
|
|
|
for node in attr_nodes: |
|
|
|
for user in copy.copy(node.users): |
|
assert user.op == "call_module" |
|
delete_user_reference(node, user) |
|
|
|
split.graph.erase_node(node) |
|
|
|
split.delete_all_unused_submodules() |
|
split.graph.lint() |
|
split.recompile() |
|
|
|
num_stages = Pipe._number_and_count_forward_stages(split) |
|
|
|
has_loss_and_backward = False |
|
generated_loss_spec = output_loss_value_spec |
|
|
|
if output_loss_value_spec is not None: |
|
loss_node, output_node, generated_loss_spec = _find_loss_output( |
|
mod, split.graph, output_loss_value_spec |
|
) |
|
if loss_node is not None: |
|
_insert_stage_symbolic_backward( |
|
split.graph, |
|
loss_node, |
|
output_node, |
|
) |
|
split.recompile() |
|
has_loss_and_backward = True |
|
logger.debug("Pipeline is in training mode, backward pass generated") |
|
else: |
|
raise RuntimeError( |
|
f"Did not find any loss value according to {output_loss_value_spec=}" |
|
) |
|
else: |
|
logger.debug("Pipeline is in inference mode, backward pass not generated") |
|
|
|
logger.debug("Full pipe model:\n" f"{split}") |
|
|
|
return Pipe( |
|
split, |
|
num_stages, |
|
has_loss_and_backward, |
|
generated_loss_spec, |
|
) |
|
|
|
def print_readable(self): |
|
""" |
|
Print the pipe in a human-readable format. |
|
This will print both the root pipe and each stage module. |
|
""" |
|
self.split_gm.print_readable() |
|
|
|
@staticmethod |
|
def _trace_with_export( |
|
mod: torch.nn.Module, |
|
example_args: Tuple[Any, ...], |
|
example_kwargs: Optional[Dict[str, Any]] = None, |
|
) -> ExportedProgram: |
|
logger.info("Tracing model ...") |
|
try: |
|
ep = torch.export.export( |
|
mod, |
|
example_args, |
|
example_kwargs, |
|
) |
|
except Exception as e: |
|
raise RuntimeError( |
|
"It seems that we cannot capture your model as a full graph. " |
|
"Typical reasons include graph breaks, data/shape-dependent " |
|
"control flow, or missing meta kernels for custom operators. " |
|
"You can use our manual pipeline interfaces, or try to fix the " |
|
"graph breaks, see https://pytorch.org/docs/stable/export.html" |
|
) from e |
|
|
|
return ep |
|
|
|
@staticmethod |
|
def from_tracing( |
|
mod: torch.nn.Module, |
|
example_args: Tuple[Any, ...], |
|
example_kwargs: Optional[Dict[str, Any]] = None, |
|
split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None, |
|
): |
|
|
|
|
|
multi_use_param_spec = MultiUseParameterConfig.REPLICATE |
|
|
|
|
|
output_loss_value_spec: Any = None |
|
|
|
""" |
|
if output_chunk_spec is not None: |
|
output_loss_value_spec = map_aggregate( |
|
output_chunk_spec, lambda v: isinstance(v, _LossReducer) |
|
) |
|
""" |
|
|
|
|
|
exported_program = Pipe._trace_with_export( |
|
mod, |
|
example_args, |
|
example_kwargs, |
|
) |
|
|
|
pipe = Pipe._from_traced( |
|
mod, |
|
exported_program, |
|
multi_use_param_spec, |
|
output_loss_value_spec=output_loss_value_spec, |
|
split_policy=split_policy, |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
split = pipe.split_gm |
|
traced = exported_program.module() |
|
submod0 = next(iter(split.children())) |
|
submod0_sign = signature(submod0.forward) |
|
model_sign = signature(traced.forward) |
|
if len(model_sign.parameters) != len(submod0_sign.parameters): |
|
|
|
|
|
logger.info( |
|
f"Original model takes {len(model_sign.parameters)} args but the " |
|
f"first pipeline stage takes {len(submod0_sign.parameters)}. " |
|
"Please provide args to respective pipeline stages." |
|
) |
|
else: |
|
|
|
submod0.graph._codegen = copy.deepcopy(traced.graph._codegen) |
|
|
|
|
|
|
|
submod0.graph._codegen.pytree_info = ( |
|
submod0.graph._codegen.pytree_info._replace(out_spec=None) |
|
) |
|
submod0.recompile() |
|
|
|
return pipe |
|
|
|
def __str__(self): |
|
return self.split_gm.__str__() |
|
|
|
def __repr__(self): |
|
return self.split_gm.__repr__() |
|
|
|
def info(self) -> PipeInfo: |
|
""" |
|
Get information about the pipe. |
|
|
|
Returns |
|
------- |
|
PipeInfo |
|
A dataclass containing information about the pipe. |
|
""" |
|
return PipeInfo( |
|
graph=self.split_gm.graph, |
|
num_stages=self.num_stages, |
|
has_loss_and_backward=self.has_loss_and_backward, |
|
) |
|
|
|
def build_stage( |
|
self, |
|
stage_index: int, |
|
device: torch.device, |
|
group: Optional[ProcessGroup] = None, |
|
) -> _PipelineStage: |
|
""" |
|
Create a `PipelineStage` given a stage index and distributed group. |
|
The `PipelineStage` can run with `PipelineSchedule`s. |
|
""" |
|
|
|
stage_module = self.get_stage_module(stage_index) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if isinstance(stage_module, torch.fx.GraphModule): |
|
_modify_graph_op_device(stage_module, device) |
|
else: |
|
logger.warning( |
|
f"Expected a `torch.fx.GraphModule` but got {type(stage_module)}" |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
pipe_info = self.info() |
|
return _PipelineStage(stage_module, stage_index, pipe_info, device, group) |
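# Hedged per-rank usage sketch (rank, device and group values are placeholders):
#
#   pipe = pipeline(model, mb_args=(example_micro_batch,))
#   stage = pipe.build_stage(stage_index=rank, device=torch.device("cuda", rank))
#   # `stage` can then be driven by a PipelineSchedule.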
|
|
|
|
|
class SplitPoint(Enum): |
|
BEGINNING = 1 |
|
END = 2 |
|
|
|
|
|
|
|
|
|
class PipeSplitWrapper: |
|
|
|
SplitPoint = SplitPoint |
|
|
|
|
|
def _split_before_forward(self, *args, **kwargs): |
|
pipe_split() |
|
return self._orig_forward(*args, **kwargs) |
|
|
|
|
|
def _split_after_forward(self, *args, **kwargs): |
|
try: |
|
return self._orig_forward(*args, **kwargs) |
|
finally: |
|
pipe_split() |
|
|
|
|
|
def annotate_split_points(mod: torch.nn.Module, spec: Dict[str, SplitPoint]): |
|
|
|
for qualname, split_type in spec.items(): |
|
atoms = qualname.split(".") |
|
predecessor_module = mod |
|
for i, atom in enumerate(atoms[:-1]): |
|
try: |
|
predecessor_module = getattr(predecessor_module, atom) |
|
except AttributeError as e: |
|
raise AttributeError( |
|
f'Specified target {qualname} referenced nonexistent module {".".join(atoms[:i+1])}' |
|
) from e |
|
|
|
mod_to_wrap = getattr(predecessor_module, atoms[-1]) |
|
mod_to_wrap._orig_forward = mod_to_wrap.forward |
|
if split_type == SplitPoint.BEGINNING: |
|
mod_to_wrap.forward = MethodType(_split_before_forward, mod_to_wrap) |
|
elif split_type == SplitPoint.END: |
|
mod_to_wrap.forward = MethodType(_split_after_forward, mod_to_wrap) |
|
else: |
|
raise ValueError("Unknown split point type.") |
|
|
|
|
|
def pipeline( |
|
module: torch.nn.Module, |
|
mb_args: Tuple[Any, ...], |
|
mb_kwargs: Optional[Dict[str, Any]] = None, |
|
split_spec: Optional[Dict[str, SplitPoint]] = None, |
|
split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None, |
|
) -> Pipe: |
|
""" |
|
Split a module based on a specification. |
|
|
|
See `Pipe` for more details. |
|
|
|
Arguments |
|
--------- |
|
module: |
|
        The module to be split.
|
mb_args: |
|
Example positional inputs, in micro-batch form. |
|
mb_kwargs: |
|
Example keyword inputs, in micro-batch form. (default: `None`) |
|
split_spec: |
|
        A dictionary mapping submodule names to split points. (default: `None`)
|
split_policy: |
|
The policy to use for splitting the module. (default: `None`) |
|
|
|
Returns |
|
------- |
|
A pipeline representation of class `Pipe`. |
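
    Example (a hedged sketch; the model and the split point name below are
    hypothetical)::

        mod = MyTwoLayerModel()
        pipe = pipeline(
            mod,
            mb_args=(torch.randn(4, 16),),
            split_spec={"layer1": SplitPoint.END},
        )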
|
""" |
|
if split_spec is not None and split_policy is not None: |
|
raise ValueError( |
|
"Cannot specify both `split_spec` and `split_policy`. Please use only one of them." |
|
) |
|
|
|
if split_spec is not None: |
|
|
|
annotate_split_points(module, split_spec) |
|
return Pipe.from_tracing( |
|
mod=module, |
|
example_args=mb_args, |
|
example_kwargs=mb_kwargs, |
|
) |
|
else: |
|
|
|
return Pipe.from_tracing( |
|
mod=module, |
|
example_args=mb_args, |
|
example_kwargs=mb_kwargs, |
|
split_policy=split_policy, |
|
) |
|
|