|
|
|
import copy |
|
import functools |
|
import heapq |
|
import itertools |
|
import logging |
|
import math |
|
import operator |
|
import os |
|
from collections import defaultdict |
|
from dataclasses import dataclass, replace |
|
from typing import Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union |
|
|
|
import torch |
|
import torch._inductor.inductor_prims |
|
import torch.fx as fx |
|
import torch.utils._pytree as pytree |
|
from torch.fx.experimental._backward_state import BackwardState |
|
from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types |
|
from torch.fx.experimental.sym_node import magic_methods, method_to_operator |
|
from torch.fx.experimental.symbolic_shapes import ( |
|
find_symbol_binding_fx_nodes, |
|
free_symbols, |
|
hint_int, |
|
is_symbol_binding_fx_node, |
|
) |
|
from torch.fx.passes import graph_drawer |
|
from . import config |
|
from ._aot_autograd.logging_utils import get_aot_graph_name |
|
from .compile_utils import fx_graph_cse, get_aten_target |
|
|
|
if TYPE_CHECKING: |
|
import sympy |
|
|
|
|
|
AOT_PARTITIONER_DEBUG = config.debug_partitioner |
|
log = logging.getLogger(__name__) |
|
|
|
aten = torch.ops.aten |
|
prims = torch.ops.prims |
|
|
|
|
|
@dataclass |
|
class OpTypes: |
|
"""Class for keeping track of different operator categories""" |
|
|
|
fusible_ops: Set[Callable] |
|
compute_intensive_ops: Set[Callable] |
|
random_ops: Set[Callable] |
|
view_ops: Set[Callable] |
|
recomputable_ops: Set[Callable] |
|
|
|
def is_fusible(self, node: fx.Node): |
|
return get_aten_target(node) in self.fusible_ops |
|
|
|
def is_compute_intensive(self, node: fx.Node): |
|
return get_aten_target(node) in self.compute_intensive_ops |
|
|
|
def is_random(self, node: fx.Node): |
|
return get_aten_target(node) in self.random_ops |
|
|
|
def is_view(self, node: fx.Node): |
|
return get_aten_target(node) in self.view_ops |
|
|
|
def is_recomputable(self, node: fx.Node): |
|
return get_aten_target(node) in self.recomputable_ops |
|
|
|
|
|
@dataclass |
|
class NodeInfo: |
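    """Classification of the joint graph's nodes into forward/backward passes,
    plus the order in which the required forward nodes appear."""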
|
|
|
|
|
inputs: List[fx.Node] |
|
_required_fw_nodes: Set[fx.Node] |
|
required_bw_nodes: Set[fx.Node] |
|
unclaimed_nodes: Set[fx.Node] |
|
fw_order: Dict[fx.Node, int] |
|
|
|
@functools.cached_property |
|
def required_fw_nodes(self) -> List[fx.Node]: |
|
return sorted( |
|
(n for n in self._required_fw_nodes), key=lambda n: self.fw_order[n] |
|
) |
|
|
|
def is_required_fw(self, n: fx.Node) -> bool: |
|
return n in self._required_fw_nodes |
|
|
|
def is_required_bw(self, n: fx.Node) -> bool: |
|
return n in self.required_bw_nodes |
|
|
|
def is_unclaimed(self, n: fx.Node) -> bool: |
|
return n in self.unclaimed_nodes |
|
|
|
def get_fw_order(self, n: fx.Node) -> int: |
|
assert n in self._required_fw_nodes, f"Node {n} not in fw nodes!" |
|
return self.fw_order[n] |
|
|
|
|
|
@dataclass |
|
class MinCutOptions: |
|
ban_if_used_far_apart: bool |
|
ban_if_long_fusible_chains: bool |
|
ban_if_materialized_backward: bool |
|
ban_if_not_in_allowlist: bool |
|
ban_if_reduction: bool |
|
|
|
|
|
def must_recompute(node: fx.Node) -> bool: |
|
return node.meta.get("recompute", False) |
|
|
|
|
|
def has_recomputable_ops(fx_g: fx.GraphModule) -> bool: |
|
|
for node in fx_g.graph.nodes: |
|
if must_recompute(node): |
|
return True |
|
return False |
|
|
|
|
|
def has_recomputable_rng_ops(fx_g: fx.GraphModule) -> bool: |
|
for node in fx_g.graph.nodes: |
|
if ( |
|
must_recompute(node) |
|
and hasattr(node.target, "tags") |
|
and torch.Tag.nondeterministic_seeded in node.target.tags |
|
): |
|
return True |
|
return False |
|
|
|
|
|
def sym_node_size(node: fx.Node) -> int: |
|
if isinstance(node.meta["val"], (torch.SymInt, torch.SymBool)): |
|
return 1 |
|
assert isinstance(node.meta["val"], torch.SymFloat) |
|
return 4 |
|
|
|
|
|
class InvalidNodeBase: |
|
def __repr__(self): |
|
return "Invalid Node" |
|
|
|
|
|
InvalidNode = InvalidNodeBase() |
|
|
|
|
|
def _extract_graph_with_inputs_outputs( |
|
joint_graph: fx.Graph, inputs: List[fx.Node], outputs: List[fx.Node] |
|
) -> fx.Graph: |
|
""" |
|
    Given a graph, extracts a subgraph that takes the specified nodes as
|
inputs and returns the specified outputs. |
|
|
|
This includes specifying non-placeholder nodes as inputs. |
|
|
|
The general strategy is to initialize all inputs with proxies as we |
|
encounter them, and trace through the graph, only keeping values which take |
|
in valid proxies. Then, all dead code is eliminated. |
|
""" |
|
new_graph = fx.Graph() |
|
env = {} |
|
|
|
|
|
for node in inputs: |
|
new_node = new_graph.placeholder(node.name) |
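        # Copy the metadata over by hand; the original node may not be a
        # placeholder in the joint graph, so node_copy cannot be used here.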
|
|
|
new_node.meta = node.meta |
|
env[node] = new_node |
|
|
|
for node in joint_graph.nodes: |
|
if node in env: |
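            # Already added as an input above; nothing further to copy.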
|
|
|
|
|
|
|
continue |
|
elif node.op == "placeholder": |
|
env[node] = InvalidNode |
|
elif node.op == "call_function": |
|
all_args = pytree.arg_tree_leaves(*node.args, **node.kwargs) |
|
all_args = [ |
|
isinstance(env[x], InvalidNodeBase) |
|
for x in all_args |
|
if isinstance(x, fx.Node) |
|
] |
|
if any(all_args): |
|
env[node] = InvalidNode |
|
continue |
|
env[node] = new_graph.node_copy(node, lambda x: env[x]) |
|
elif node.op == "get_attr": |
|
env[node] = new_graph.node_copy(node, lambda x: env[x]) |
|
elif node.op == "output": |
|
pass |
|
output_values = [] |
|
for x in outputs: |
|
if isinstance(x, fx.Node): |
|
if x not in env: |
|
raise RuntimeError(f"Node {x} couldn't be found in env") |
|
assert not isinstance( |
|
env[x], InvalidNodeBase |
|
), f"Node {x} was invalid, but is output" |
|
output_values.append(env[x]) |
|
else: |
|
output_values.append(x) |
|
new_graph.output(output_values) |
|
|
|
new_graph.eliminate_dead_code() |
|
new_graph.lint() |
|
return new_graph |
|
|
|
|
|
def _is_primal(node: fx.Node) -> bool: |
|
return ( |
|
node.op == "placeholder" |
|
and "tangents" not in str(node.target) |
|
and not _is_bwd_seed_offset(node) |
|
and not _is_fwd_seed_offset(node) |
|
) |
|
|
|
|
|
def _is_tangent(node: fx.Node) -> bool: |
|
return node.op == "placeholder" and "tangents" in str(node.target) |
|
|
|
|
|
def _is_bwd_seed_offset(node: fx.Node) -> bool: |
|
return node.op == "placeholder" and ( |
|
"bwd_seed" in str(node.target) or "bwd_base_offset" in str(node.target) |
|
) |
|
|
|
|
|
def _is_fwd_seed_offset(node: fx.Node) -> bool: |
|
return node.op == "placeholder" and ( |
|
"fwd_seed" in str(node.target) or "fwd_base_offset" in str(node.target) |
|
) |
|
|
|
|
|
def _is_backward_state(node: fx.Node) -> bool: |
|
return node.op == "placeholder" and isinstance(node.meta.get("val"), BackwardState) |
|
|
|
|
|
def _extract_fwd_bwd_outputs( |
|
joint_module: fx.GraphModule, *, num_fwd_outputs |
|
) -> Tuple[List[fx.Node], List[fx.Node]]: |
|
outputs = pytree.arg_tree_leaves( |
|
*(node.args for node in joint_module.graph.find_nodes(op="output")) |
|
) |
|
fwd_outputs = outputs[:num_fwd_outputs] |
|
bwd_outputs = outputs[num_fwd_outputs:] |
|
return fwd_outputs, bwd_outputs |
|
|
|
|
|
def _remove_by_name(saved_values: List[fx.Node], name: str): |
|
for saved_value in saved_values: |
|
if saved_value.name == name: |
|
saved_values.remove(saved_value) |
|
break |
|
|
|
|
|
def _extract_fwd_bwd_modules( |
|
joint_module: fx.GraphModule, |
|
saved_values: List[fx.Node], |
|
saved_sym_nodes: List[fx.Node], |
|
*, |
|
num_fwd_outputs: int, |
|
) -> Tuple[fx.GraphModule, fx.GraphModule]: |
|
fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs( |
|
joint_module, num_fwd_outputs=num_fwd_outputs |
|
) |
|
placeholders = joint_module.graph.find_nodes(op="placeholder") |
|
primal_inputs = [*filter(_is_primal, placeholders)] |
|
tangent_inputs = [*filter(_is_tangent, placeholders)] |
|
fwd_seed_offset_inputs = [*filter(_is_fwd_seed_offset, placeholders)] |
|
bwd_seed_offset_inputs = [*filter(_is_bwd_seed_offset, placeholders)] |
|
backward_state_inputs = [*filter(_is_backward_state, placeholders)] |
|
|
|
bwd_graph = _extract_graph_with_inputs_outputs( |
|
joint_module.graph, |
|
saved_sym_nodes + saved_values + tangent_inputs + bwd_seed_offset_inputs, |
|
bwd_outputs, |
|
) |
|
|
|
for node in bwd_graph.find_nodes(op="placeholder"): |
|
|
|
if not node.users: |
|
_remove_by_name(saved_values, node.name) |
|
_remove_by_name(saved_sym_nodes, node.name) |
|
elif _is_backward_state(node): |
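            # BackwardState is threaded into the backward graph directly as an
            # input, not saved as an activation.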
|
|
|
_remove_by_name(saved_values, node.name) |
|
assert backward_state_inputs |
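
    # Now that the saved values are final, make sure every symbol they (or the
    # tangents) reference has its binding node saved as well, with all binding
    # nodes ordered before the derived sym nodes.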
|
|
|
|
|
|
|
|
|
|
|
saved_symbols: Set[sympy.Symbol] = set() |
|
saved_sym_nodes_binding = [] |
|
saved_sym_nodes_derived = [] |
|
|
|
|
|
|
|
for node in saved_sym_nodes: |
|
symbol = is_symbol_binding_fx_node(node) |
|
if symbol: |
|
saved_symbols.add(symbol) |
|
saved_sym_nodes_binding.append(node) |
|
else: |
|
saved_sym_nodes_derived.append(node) |
|
|
|
|
|
|
|
symbol_bindings = find_symbol_binding_fx_nodes(joint_module.graph) |
|
for node in itertools.chain(saved_sym_nodes_derived, saved_values, tangent_inputs): |
|
if "val" not in node.meta: |
|
continue |
|
new_symbols = free_symbols(node.meta["val"]) - saved_symbols |
|
|
|
for s in sorted(new_symbols, key=lambda s: s.name): |
|
|
|
|
|
|
|
if s not in symbol_bindings: |
|
continue |
|
saved_sym_nodes_binding.append(symbol_bindings[s]) |
|
saved_symbols |= new_symbols |
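
    # Rebuild saved_sym_nodes in place with all symbol-binding nodes at the front.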
|
|
|
|
|
|
|
|
|
saved_sym_nodes.clear() |
|
saved_sym_nodes.extend(saved_sym_nodes_binding + saved_sym_nodes_derived) |
|
|
|
|
|
|
|
fwd_graph = _extract_graph_with_inputs_outputs( |
|
joint_module.graph, |
|
primal_inputs + fwd_seed_offset_inputs, |
|
fwd_outputs + saved_values + saved_sym_nodes, |
|
) |
|
bwd_graph = _extract_graph_with_inputs_outputs( |
|
joint_module.graph, |
|
saved_sym_nodes |
|
+ saved_values |
|
+ tangent_inputs |
|
+ bwd_seed_offset_inputs |
|
+ backward_state_inputs, |
|
bwd_outputs, |
|
) |
|
|
|
fwd_module = fx._lazy_graph_module._make_graph_module(joint_module, fwd_graph) |
|
bwd_module = fx._lazy_graph_module._make_graph_module(joint_module, bwd_graph) |
|
return fwd_module, bwd_module |
|
|
|
|
|
def default_partition( |
|
joint_module: fx.GraphModule, _joint_inputs, *, num_fwd_outputs |
|
) -> Tuple[fx.GraphModule, fx.GraphModule]: |
|
""" |
|
Partitions the :attr:`joint_module` in a manner that closely resembles the |
|
behavior observed in the original ``.forward()`` and ``.backward()`` of the |
|
callable, i.e., the resulting forward graph contains those operators that |
|
are executed in the original ``.forward()`` callable passed to |
|
:func:`aot_function`. |
|
|
|
The default partitioner collects the operators that are between the forward |
|
inputs and the forward outputs. This helps in finding the tensors which have |
|
to be stashed for the backward pass. These stashed tensors become the output |
|
of the generated forward graph. The remaining operators are then placed in |
|
the backward graph. |
|
|
|
.. warning:: |
|
This API is experimental and likely to change. |
|
|
|
Args: |
|
joint_module(fx.GraphModule): The joint forward and backward graph. This |
|
is the result of AOT Autograd tracing. |
|
|
|
Returns: |
|
Returns the generated forward and backward Fx graph modules. |
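
    Example (an illustrative sketch; ``joint_module`` is assumed to come from an
    AOT Autograd trace and ``num_fwd_outputs`` is the number of user-visible
    forward outputs)::

        fw_module, bw_module = default_partition(
            joint_module, joint_inputs, num_fwd_outputs=num_fwd_outputs
        )
        # fw_module returns the forward outputs followed by the stashed
        # activations; bw_module consumes those stashed values plus the tangents.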
|
""" |
|
if has_recomputable_ops(joint_module): |
|
return min_cut_rematerialization_partition( |
|
joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs |
|
) |
|
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) |
|
fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes)) |
|
inputs = primal_inputs + fwd_seed_offset_inputs |
|
fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs( |
|
joint_module, num_fwd_outputs=num_fwd_outputs |
|
) |
|
forward_only_graph = _extract_graph_with_inputs_outputs( |
|
joint_module.graph, inputs, fwd_outputs |
|
) |
|
forward_node_names = { |
|
node.name for node in forward_only_graph.nodes if node.op != "output" |
|
} |
|
saved_values = [] |
|
saved_sym_nodes = [] |
|
|
|
for node in joint_module.graph.nodes: |
|
if node.name not in forward_node_names: |
|
continue |
|
if is_sym_node(node): |
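            # Symbolic values are tracked separately from tensors: autograd saves
            # tensors via save_for_backward and stashes symints on ctx directly.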
|
|
|
|
|
saved_sym_nodes.append(node) |
|
elif "tensor_meta" not in node.meta and node.op == "call_function": |
|
|
|
users = node.users |
|
assert all(user.target == operator.getitem for user in users) |
|
saved_values.extend(users) |
|
else: |
|
backward_usages = [ |
|
n for n in node.users if n.name not in forward_node_names |
|
] |
|
if "tensor_meta" in node.meta and all( |
|
is_sym_node(n) for n in backward_usages |
|
): |
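                # Only the sizes/strides of this tensor are needed in the backward
                # (every backward use is a sym node), so save those sym nodes
                # instead of the tensor itself -- far cheaper.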
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
saved_sym_nodes.extend(backward_usages) |
|
else: |
|
saved_values.append(node) |
|
saved_values = list(dict.fromkeys(saved_values).keys()) |
|
saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys()) |
|
|
|
return _extract_fwd_bwd_modules( |
|
joint_module, |
|
saved_values, |
|
saved_sym_nodes=saved_sym_nodes, |
|
num_fwd_outputs=num_fwd_outputs, |
|
) |
|
|
|
|
|
INT_INF = int(1e6) |
|
|
|
|
|
def _tensor_nbytes(numel: int, dtype) -> int: |
|
return numel * dtype.itemsize |
|
|
|
|
|
def _size_of(node: fx.Node) -> int: |
|
if "val" in node.meta: |
|
val = node.meta["val"] |
|
if isinstance(val, py_sym_types): |
|
return 1 |
|
|
|
|
|
|
|
elif isinstance(val, (list, tuple)): |
|
return sum( |
|
_tensor_nbytes(hint_int(n.numel(), fallback=4096), n.dtype) |
|
for n in val |
|
if isinstance(n, torch.Tensor) |
|
) |
|
elif isinstance(val, torch.Tensor): |
|
return _tensor_nbytes(hint_int(val.numel(), fallback=4096), val.dtype) |
|
|
|
raise RuntimeError(f"Unknown metadata type {type(val)}") |
|
if node.op == "get_attr": |
|
return 0 |
|
raise RuntimeError("We should always have `val` metadata on the nodes") |
|
|
|
|
|
|
|
def _count_ops(graph: fx.Graph): |
|
|
|
|
cnt: Dict[str, int] = defaultdict(int) |
|
for node in graph.nodes: |
|
if node.op == "call_function": |
|
cnt[node.target.__name__] += 1 |
|
print(sorted(cnt.items(), key=lambda x: x[1], reverse=True)) |
|
|
|
|
|
@functools.lru_cache(None) |
|
def pointwise_ops(): |
|
ops = [] |
|
for attr_name in dir(torch.ops.aten): |
|
opoverloadpacket = getattr(torch.ops.aten, attr_name) |
|
if not isinstance(opoverloadpacket, torch._ops.OpOverloadPacket): |
|
continue |
|
|
|
for overload in opoverloadpacket.overloads(): |
|
op_overload = getattr(opoverloadpacket, overload) |
|
if torch.Tag.pointwise in op_overload.tags: |
|
|
|
ops.append(opoverloadpacket) |
|
break |
|
|
|
return ops |
|
|
|
|
|
def sort_depths(args, depth_map: Dict[fx.Node, int]) -> List[Tuple[fx.Node, int]]: |
|
arg_depths = { |
|
arg: depth_map[arg] for arg in args if isinstance(arg, torch.fx.node.Node) |
|
} |
|
return sorted(arg_depths.items(), key=lambda x: x[1], reverse=True) |
|
|
|
|
|
def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule: |
|
""" |
|
This pass finds the first bwd node in the graph (by looking at users of |
|
    tangents) and then reorders the graph by walking from this node all the

    way to the end of the graph. At each op in this traversal, we insert the op

    into a new graph and pull in only the subgraph of its non-bwd producers

    that is actually needed. This closely mimics the behavior of the

    autograd engine.
|
|
|
Why is this pass required in the first place? |
|
|
|
This is an artifact of how partitioners work today. The starting point of |
|
partitioner is a joint graph, which is fwd and then bwd graph. In the case |
|
of checkpointing, we keep portions of fwd graph in their original place in |
|
the joint graph, while obtaining a bwd graph. As a result, the resulting bwd |
|
graph has copies of recomputed fwd subgraphs followed by the original bwd |
|
    graph. If we run this naively, it leads to a bad memory footprint, because

    the fwd subgraphs stay live for much longer than necessary. This pass
|
reorders the operations such that we prioritize the ops for the original bwd |
|
graph while only realizing those ops from the fwd graph that are necessary |
|
at any given point in the graph. |
|
""" |
|
|
|
new_graph = fx.Graph() |
|
env: Dict[fx.Node, fx.Node] = {} |
|
|
|
|
|
for node in gm.graph.find_nodes(op="placeholder"): |
|
env[node] = new_graph.node_copy(node, lambda x: env[x]) |
|
|
|
order = {} |
|
for idx, node in enumerate(gm.graph.nodes): |
|
order[node] = idx |
|
|
|
def insert_node_in_graph(node): |
|
cur_nodes = [node] |
|
insertable_nodes = set() |
|
while len(cur_nodes) > 0: |
|
node = cur_nodes.pop() |
|
if node in insertable_nodes or node in env: |
|
continue |
|
insertable_nodes.add(node) |
|
|
|
|
|
|
|
cur_nodes += node.all_input_nodes |
|
|
|
insertable_nodes = sorted(insertable_nodes, key=lambda n: order[n]) |
|
for node in insertable_nodes: |
|
env[node] = new_graph.node_copy(node, lambda x: env[x]) |
|
|
|
|
|
tangent_inputs = list(filter(_is_tangent, gm.graph.nodes)) |
|
first_node_in_bwd = None |
|
minimum_order = math.inf |
|
for tangent in tangent_inputs: |
|
for user in tangent.users: |
|
if order[user] < minimum_order: |
|
minimum_order = order[user] |
|
first_node_in_bwd = user |
|
|
|
|
|
if first_node_in_bwd is None: |
|
return gm |
|
|
|
|
|
for node in list(gm.graph.nodes)[order[first_node_in_bwd] :]: |
|
insert_node_in_graph(node) |
|
|
|
|
|
new_gm = torch.fx.GraphModule(gm, new_graph) |
|
return new_gm |
|
|
|
|
|
def functionalize_rng_ops( |
|
joint_module: fx.GraphModule, |
|
fw_module: fx.GraphModule, |
|
bw_module: fx.GraphModule, |
|
num_sym_nodes: int, |
|
) -> Tuple[fx.GraphModule, fx.GraphModule]: |
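    # When checkpointing recomputes an RNG op in the backward, it must reproduce
    # exactly the random values observed in the forward. To that end, each such
    # forward RNG op is rewritten to run_and_save_rng_state (which returns the
    # RNG state alongside the result), the saved states are added as extra
    # forward outputs / backward inputs, and the recomputed op in the backward
    # is rewritten to run_with_rng_state so that it replays the stashed state.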
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
uid = itertools.count() |
|
|
|
def get_rng_ops(gmod): |
|
random_nodes = {} |
|
for node in gmod.graph.nodes: |
|
if ( |
|
node.op == "call_function" |
|
and hasattr(node.target, "tags") |
|
and torch.Tag.nondeterministic_seeded in node.target.tags |
|
): |
|
random_nodes[node.name] = node |
|
return random_nodes |
|
|
|
def get_device(node): |
|
""" |
|
Check the example value of the node outputs to find the device type. |
|
""" |
|
if "val" not in node.meta: |
|
return None |
|
|
|
candidates = node.meta["val"] |
|
if not isinstance(candidates, tuple): |
|
candidates = (candidates,) |
|
|
|
for candidate in candidates: |
|
if isinstance(candidate, torch.Tensor): |
|
if candidate.device.type == "cuda": |
|
return "cuda" |
|
|
|
return "cpu" |
|
|
|
def get_sample_rng_state(device): |
|
if device == "cuda": |
|
return torch.cuda.get_rng_state() |
|
return torch.get_rng_state() |
|
|
|
|
|
joint_graph_rng_ops = get_rng_ops(joint_module) |
|
fw_graph_rng_ops = get_rng_ops(fw_module) |
|
bw_graph_rng_ops = get_rng_ops(bw_module) |
|
recomputable_rng_ops_map = dict() |
|
for node in joint_module.graph.nodes: |
|
if ( |
|
must_recompute(node) |
|
and hasattr(node.target, "tags") |
|
and torch.Tag.nondeterministic_seeded in node.target.tags |
|
): |
|
base_node = joint_graph_rng_ops[node.name] |
|
fw_node = fw_graph_rng_ops[node.name] |
|
bw_node = bw_graph_rng_ops[node.name] |
|
recomputable_rng_ops_map[base_node] = {"fwd": fw_node, "bwd": bw_node} |
|
|
|
run_and_save_rng = torch._prims.rng_prims.run_and_save_rng_state |
|
run_with_rng_state = torch._prims.rng_prims.run_with_rng_state |
|
bw_tangent_start_node = None |
|
for node in bw_module.graph.find_nodes(op="placeholder"): |
|
if "tangent" in node.name: |
|
bw_tangent_start_node = node |
|
break |
|
if bw_tangent_start_node is None: |
|
raise RuntimeError( |
|
"Couldn't find tangent node in graph inputs. This is unexpected, please file a bug if you see this" |
|
) |
|
|
|
fw_rng_state_outputs = [] |
|
for base_node, node_pair in recomputable_rng_ops_map.items(): |
|
|
|
fw_node = node_pair["fwd"] |
|
bw_node = node_pair["bwd"] |
|
fw_graph = fw_module.graph |
|
with fw_graph.inserting_before(fw_node): |
|
functional_fw_node = fw_graph.create_node( |
|
"call_function", |
|
run_and_save_rng, |
|
args=(fw_node.target, *fw_node.args), |
|
kwargs=fw_node.kwargs, |
|
) |
|
state = fw_graph.create_node( |
|
"call_function", |
|
operator.getitem, |
|
args=(functional_fw_node, 0), |
|
kwargs={}, |
|
) |
|
rng_output = fw_graph.create_node( |
|
"call_function", |
|
operator.getitem, |
|
args=( |
|
functional_fw_node, |
|
1, |
|
), |
|
kwargs={}, |
|
) |
|
fw_node.replace_all_uses_with(rng_output) |
|
fw_graph.erase_node(fw_node) |
|
fw_rng_state_outputs.append(state) |
|
|
|
|
|
bw_graph = bw_module.graph |
|
with bw_graph.inserting_before(bw_tangent_start_node): |
|
state_name = f"rng_state_output_{next(uid)}" |
|
bw_rng_state_node = bw_graph.placeholder(state_name) |
|
bw_rng_state_node.meta["val"] = get_sample_rng_state(get_device(fw_node)) |
|
|
|
with bw_graph.inserting_before(bw_node): |
|
rng_output = bw_graph.create_node( |
|
"call_function", |
|
run_with_rng_state, |
|
args=(bw_rng_state_node, bw_node.target, *bw_node.args), |
|
kwargs=bw_node.kwargs, |
|
) |
|
|
|
bw_node.replace_all_uses_with(rng_output) |
|
bw_graph.erase_node(bw_node) |
|
|
|
|
|
|
|
|
|
fw_output_node = next(iter(fw_module.graph.find_nodes(op="output"))) |
|
fw_outputs = fw_output_node.args[0] |
|
sym_node_start_idx = len(fw_outputs) - num_sym_nodes |
|
outputs = ( |
|
fw_outputs[:sym_node_start_idx] |
|
+ fw_rng_state_outputs |
|
+ fw_outputs[sym_node_start_idx:] |
|
) |
|
fw_module.graph.output(outputs) |
|
fw_module.graph.erase_node(fw_output_node) |
|
fw_module.recompile() |
|
bw_module.recompile() |
|
return fw_module, bw_module |
|
|
|
|
|
def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule: |
|
""" |
|
If there are two consecutive checkpointed blocks with no operator in |
|
between, we would still want to stash the tensor at the boundary of |
|
checkpointed blocks. The following pass makes the last output node |
|
non-recomputable to allow for that. |
|
""" |
|
for node in joint_module.graph.nodes: |
|
if must_recompute(node): |
|
for user in node.users: |
|
if ( |
|
must_recompute(user) |
|
and user.meta["recompute"] > node.meta["recompute"] |
|
): |
|
node.meta["recompute"] = 0 |
|
return joint_module |
|
|
|
|
|
def solve_min_cut( |
|
joint_graph: fx.Graph, |
|
node_info: NodeInfo, |
|
min_cut_options: MinCutOptions, |
|
dont_ban=None, |
|
): |
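    """
    Decides which values to save for the backward pass by solving a max-flow /
    min-cut problem on the joint graph: every candidate value contributes an
    internal edge whose capacity approximates the cost of saving it, nodes that
    are banned from recomputation are wired to the source with infinite
    capacity, and required backward nodes are wired to the sink. The resulting
    cut separates "recompute in the backward" from "save from the forward".

    Returns the saved values (the cut nodes) and the set of banned nodes.
    """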
|
if dont_ban is None: |
|
dont_ban = set() |
|
op_types = get_default_op_list() |
|
|
|
if AOT_PARTITIONER_DEBUG: |
|
joint_module_ops = { |
|
str(node.target._overloadpacket) |
|
for node in joint_graph.nodes |
|
if node.op == "call_function" and hasattr(node.target, "_overloadpacket") |
|
} |
|
ops_ignored = joint_module_ops - {str(i) for i in op_types.recomputable_ops} |
|
print("Ops banned from rematerialization: ", ops_ignored) |
|
print() |
|
|
|
def is_fusible(a, b): |
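        # A cat consumer can always "fuse" with its producers, even though cat
        # itself is not treated as a fusible producer.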
|
|
|
|
|
if get_aten_target(b) == aten.cat: |
|
return True |
|
return op_types.is_fusible(a) and op_types.is_fusible(b) |
|
|
|
try: |
|
import networkx as nx |
|
except ImportError as e: |
|
raise RuntimeError( |
|
"Need networkx installed to perform smart recomputation " "heuristics" |
|
) from e |
|
|
|
def is_materialized_backwards(node): |
|
if op_types.is_view(node): |
|
return False |
|
cur_nodes = {node} |
|
while len(cur_nodes) > 0: |
|
cur = cur_nodes.pop() |
|
for user in cur.users: |
|
if not node_info.is_required_fw(user) and not is_fusible(cur, user): |
|
return True |
|
if op_types.is_view(user): |
|
cur_nodes.add(user) |
|
|
|
return False |
|
|
|
def should_ban_recomputation(node): |
|
if node.op != "call_function": |
|
return False |
|
if node.target == operator.getitem: |
|
return False |
|
if node.target in [aten.lift_fresh_copy.default, aten.lift_fresh.default]: |
|
return False |
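
        # "recompute" == 0 means activation checkpointing explicitly marked this
        # node as one that must be saved rather than recomputed.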
|
|
|
if node.meta.get("recompute", None) == 0: |
|
return True |
|
|
|
if min_cut_options.ban_if_not_in_allowlist: |
|
if not op_types.is_recomputable(node): |
|
return True |
|
else: |
|
if op_types.is_random(node) or op_types.is_compute_intensive(node): |
|
return True |
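
        # If this node would get materialized in the backward anyway (some backward
        # user cannot fuse with it), recomputing it saves no memory, so ban it.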
|
|
|
|
|
|
|
|
|
|
|
|
|
if min_cut_options.ban_if_materialized_backward and is_materialized_backwards( |
|
node |
|
): |
|
log.info("materialized backwards: %s %s", node, tuple(node.users)) |
|
return True |
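
        # Heuristic: nodes very far from the backward pass would drag a long chain
        # of recomputation with them, so prefer to save them instead.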
|
|
|
|
|
|
|
|
|
|
|
if node.dist_from_bw < 1000 and node.dist_from_bw > config.max_dist_from_bw: |
|
return True |
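
        # Heuristic for reductions: if the output is much smaller than the inputs
        # (4x here), saving the output is cheap and recomputation is not worth it.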
|
|
|
|
|
|
|
|
|
|
|
|
|
if min_cut_options.ban_if_reduction: |
|
input_tensors_size = sum( |
|
_size_of(i) for i in node.args if isinstance(i, fx.Node) |
|
) |
|
output_size = _size_of(node) |
|
return output_size * 4 < input_tensors_size |
|
return False |
|
|
|
def is_materialized(node): |
|
if node.op == "placeholder": |
|
return True |
|
|
|
return not all(is_fusible(node, user) for user in node.users) |
|
|
|
def get_node_weight(node) -> float: |
|
mem_sz = _size_of(node) |
|
|
|
if isinstance(node.meta["val"], py_sym_types): |
|
|
|
if not isinstance(node.meta["val"], torch.SymInt): |
|
return INT_INF |
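
        # Bias the weight towards nodes closer to the backward pass: the further a
        # node is from the backward, the more expensive saving it is made to look.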
|
|
|
|
|
|
|
mem_sz = int(mem_sz * (1.1 ** max(min(node.dist_from_bw, 100), 1))) |
|
if is_materialized(node): |
|
return mem_sz |
|
else: |
|
return mem_sz * 2 |
|
|
|
nx_graph = nx.DiGraph() |
|
banned_nodes = set() |
|
|
|
def ban_recomputation_if_allowed(node): |
|
if op_types.is_view(node): |
|
return False |
|
if node in dont_ban: |
|
return False |
|
|
|
|
|
|
|
|
|
if node.meta.get("recompute", 0) > 0: |
|
return False |
|
|
|
if "val" in node.meta and isinstance(node.meta["val"], torch.SymFloat): |
|
return False |
|
|
|
banned_nodes.add(node) |
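
        # A banned node must land on the "saved" side of the cut, so connect it to
        # the source with infinite capacity; the min cut can then never sever it.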
|
|
|
|
|
|
|
|
|
nx_graph.add_edge("source", node.name + "_in", capacity=math.inf) |
|
return True |
|
|
|
for node in joint_graph.nodes: |
|
if node.op == "output": |
|
continue |
|
|
|
if node in node_info.required_bw_nodes: |
|
if node not in node_info.inputs: |
|
nx_graph.add_edge(node.name + "_in", "sink", capacity=math.inf) |
|
continue |
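
            # A node can be both a required backward node and a graph input (e.g. an
            # input returned unchanged as a gradient). Connect its "_out" side to the
            # sink so it still ends up among the cut (saved) nodes.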
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
nx_graph.add_edge(node.name + "_out", "sink", capacity=math.inf) |
|
|
|
if _is_primal(node) or _is_fwd_seed_offset(node): |
|
ban_recomputation_if_allowed(node) |
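
        # Apply the banning heuristics only to nodes in the forward pass; those are
        # the only nodes that could be recomputed in the first place.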
|
|
|
|
|
|
|
|
|
if node_info.is_required_fw(node) and should_ban_recomputation(node): |
|
ban_recomputation_if_allowed(node) |
|
|
|
|
|
is_non_tensor_node = ( |
|
"val" not in node.meta and "tensor_meta" not in node.meta |
|
) or ("val" in node.meta and not isinstance(node.meta["val"], torch.Tensor)) |
|
|
|
if is_sym_node(node): |
|
weight = float(sym_node_size(node)) |
|
elif is_non_tensor_node: |
|
weight = ( |
|
0.0 if isinstance(node.meta.get("val"), BackwardState) else math.inf |
|
) |
|
else: |
|
weight = get_node_weight(node) |
|
|
|
nx_graph.add_edge(node.name + "_in", node.name + "_out", capacity=weight) |
|
for user in node.users: |
|
nx_graph.add_edge(node.name + "_out", user.name + "_in", capacity=math.inf) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def find_first_unfusible(start_nodes: List[fx.Node], max_range: int) -> int: |
|
""" |
|
Finds the first unfusible node in the chain of nodes starting from |
|
`start_nodes` and returns its position. |
|
""" |
|
sorted_nodes: List[Tuple[int, fx.Node, bool]] = [] |
|
for n in start_nodes: |
|
heapq.heappush(sorted_nodes, (node_info.get_fw_order(n), n, True)) |
|
|
|
while len(sorted_nodes) > 0: |
|
_, node, node_is_fusible = heapq.heappop(sorted_nodes) |
|
if not node_is_fusible: |
|
return node_info.get_fw_order(node) |
|
for user in node.users: |
|
if node_info.is_required_fw(user): |
|
if node_info.get_fw_order(user) > max_range: |
|
continue |
|
heapq.heappush( |
|
sorted_nodes, |
|
(node_info.get_fw_order(user), user, is_fusible(node, user)), |
|
) |
|
return max_range |
|
|
|
if min_cut_options.ban_if_used_far_apart: |
|
for used_node in node_info.required_fw_nodes: |
|
orders = [ |
|
node_info.get_fw_order(user) |
|
for user in used_node.users |
|
if node_info.is_required_fw(user) |
|
] |
|
fw_users = [ |
|
user for user in used_node.users if node_info.is_required_fw(user) |
|
] |
|
if len(orders) > 0: |
|
first_unfusible_use = find_first_unfusible(fw_users, max(orders)) |
|
for user in tuple(used_node.users): |
|
if ( |
|
node_info.is_required_fw(user) |
|
and node_info.get_fw_order(user) > first_unfusible_use |
|
and is_fusible(used_node, user) |
|
): |
|
if user in banned_nodes: |
|
continue |
|
log.info( |
|
"used above/below fusible %s:(%s) -> %s -> %s:(%s)", |
|
used_node, |
|
node_info.get_fw_order(used_node), |
|
first_unfusible_use, |
|
user, |
|
node_info.get_fw_order(user), |
|
) |
|
ban_recomputation_if_allowed(user) |
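
    # Heuristic ("ban_if_long_fusible_chains"): recomputing bandwidth-bound ops is
    # cheap, but we do not want one long fusible chain spanning the entire forward
    # to be recomputed wholesale, so overly long chains are cut by banning a node.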
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if min_cut_options.ban_if_long_fusible_chains: |
|
visited = set() |
|
for start_node in joint_graph.nodes: |
|
if not node_info.is_required_fw(start_node): |
|
continue |
|
fusible = [(node_info.get_fw_order(start_node), start_node)] |
|
start_order = node_info.get_fw_order(start_node) |
|
while len(fusible) > 0: |
|
_, cur = heapq.heappop(fusible) |
|
if cur in visited: |
|
continue |
|
visited.add(cur) |
|
|
|
if ( |
|
node_info.get_fw_order(cur) > start_order + 100 |
|
and len(fusible) == 0 |
|
): |
|
log.info( |
|
"too long %s %s %s %s", |
|
cur, |
|
start_node, |
|
node_info.get_fw_order(cur), |
|
node_info.get_fw_order(start_node), |
|
) |
|
ban_recomputation_if_allowed(cur) |
|
break |
|
|
|
for user in cur.users: |
|
if ( |
|
node_info.is_required_fw(user) |
|
and is_fusible(cur, user) |
|
and user not in banned_nodes |
|
): |
|
heapq.heappush(fusible, (node_info.get_fw_order(user), user)) |
|
|
|
try: |
|
cut_value, partition = nx.minimum_cut(nx_graph, "source", "sink") |
|
except Exception: |
|
print("Failed to compute min-cut on following graph:") |
|
print("\n".join(nx.readwrite.edgelist.generate_edgelist(nx_graph))) |
|
visualize_min_cut_graph(nx_graph) |
|
raise |
|
|
|
reachable, non_reachable = partition |
|
cutset: Set[Tuple[str, str]] = set() |
|
for u, nbrs in ((n, nx_graph[n]) for n in reachable): |
|
cutset.update((u, v) for v in nbrs if v in non_reachable) |
|
|
|
cut_nodes = set() |
|
for node_in, node_out in cutset: |
|
assert node_in[:-3] == node_out[:-4] |
|
node_name = node_in[:-3] |
|
cut_nodes.add(node_name) |
|
|
|
name_to_node = get_name_to_node(joint_graph) |
|
|
|
node_idx = {node: idx for idx, node in enumerate(joint_graph.nodes)} |
|
saved_values = sorted( |
|
(name_to_node[node] for node in cut_nodes), key=lambda x: node_idx[x] |
|
) |
|
return saved_values, banned_nodes |
|
|
|
|
|
def visualize_min_cut_graph(nx_graph): |
|
import networkx as nx |
|
import pydot |
|
|
|
dot_format = nx.nx_pydot.to_pydot(nx_graph).to_string() |
|
dot_graph = pydot.graph_from_dot_data(dot_format)[0] |
|
for edge in dot_graph.get_edges(): |
|
weight = nx_graph[edge.get_source()][edge.get_destination()]["capacity"] |
|
|
|
edge.set_label(str(weight)) |
|
|
|
if weight == float("inf"): |
|
edge.set_color("red") |
|
print("Visualizing the failed graph to min_cut_failed.svg") |
|
dot_graph.write_svg("min_cut_failed.svg") |
|
|
|
|
|
def get_default_op_list() -> OpTypes: |
|
default_recomputable_ops: List[Callable] = [ |
|
aten.add, |
|
aten.sub, |
|
aten.div, |
|
aten.atan2, |
|
aten.mul, |
|
aten.max, |
|
aten.min, |
|
aten.pow, |
|
aten.remainder, |
|
aten.fmod, |
|
aten.__and__, |
|
aten.__or__, |
|
aten.__xor__, |
|
aten.__lshift__, |
|
aten.__rshift__, |
|
aten.eq, |
|
aten.ne, |
|
aten.ge, |
|
aten.gt, |
|
aten.le, |
|
aten.lt, |
|
aten.abs, |
|
aten.bitwise_not, |
|
aten.ceil, |
|
aten.floor, |
|
aten.frac, |
|
aten.neg, |
|
aten.relu, |
|
aten.round, |
|
aten.silu, |
|
aten.trunc, |
|
aten.log, |
|
aten.log10, |
|
aten.log1p, |
|
aten.log2, |
|
aten.lgamma, |
|
aten.exp, |
|
aten.expm1, |
|
aten.erf, |
|
aten.erfc, |
|
aten.cos, |
|
aten.acos, |
|
aten.cosh, |
|
aten.sin, |
|
aten.asin, |
|
aten.sinh, |
|
aten.tan, |
|
aten.atan, |
|
aten.tanh, |
|
aten.atanh, |
|
aten.sqrt, |
|
aten.rsqrt, |
|
aten.reciprocal, |
|
aten.sigmoid, |
|
aten.softplus, |
|
aten.threshold, |
|
aten.threshold_backward, |
|
aten.clamp, |
|
aten.where, |
|
aten.lerp, |
|
aten.addcmul, |
|
aten.gelu, |
|
aten.gelu_backward, |
|
aten.sum, |
|
aten.mean, |
|
aten._grad_sum_to_size, |
|
aten.sum_to_size, |
|
aten.amax, |
|
aten.to, |
|
aten.type_as, |
|
operator.getitem, |
|
aten.squeeze, |
|
aten.unsqueeze, |
|
aten.rsub, |
|
aten._to_copy, |
|
] |
|
recomputable_view_ops = [aten.squeeze, aten.unsqueeze, aten.alias] |
|
recomputable_view_ops += [ |
|
aten.view, |
|
aten.slice, |
|
aten.t, |
|
prims.broadcast_in_dim, |
|
aten.expand, |
|
aten.as_strided, |
|
aten.permute, |
|
] |
|
view_ops = recomputable_view_ops |
|
default_recomputable_ops += [ |
|
prims.div, |
|
prims.convert_element_type, |
|
aten.clone, |
|
aten._to_copy, |
|
aten.full_like, |
|
prims.var, |
|
prims.sum, |
|
aten.var, |
|
aten.std, |
|
prims.broadcast_in_dim, |
|
aten.select, |
|
aten._unsafe_view, |
|
aten.view, |
|
aten.expand, |
|
aten.slice, |
|
aten.reshape, |
|
aten.broadcast_tensors, |
|
aten.scalar_tensor, |
|
aten.ones, |
|
aten.new_zeros, |
|
aten.lift_fresh_copy, |
|
aten.arange, |
|
aten.triu, |
|
aten.var_mean, |
|
aten.isinf, |
|
aten.any, |
|
aten.full, |
|
aten.as_strided, |
|
aten.zeros, |
|
aten.argmax, |
|
aten.maximum, |
|
prims.iota, |
|
prims._low_memory_max_pool2d_offsets_to_indices, |
|
] |
|
|
|
default_recomputable_ops += [aten.index, aten.gather] |
|
default_recomputable_ops += view_ops |
|
|
|
default_recomputable_ops += pointwise_ops() |
|
|
|
default_recomputable_ops += [ |
|
aten.zeros_like, |
|
] |
|
|
|
default_recomputable_ops += [method_to_operator(m) for m in magic_methods] |
|
recomputable_ops = set(default_recomputable_ops) |
|
|
|
random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like] |
|
compute_intensive_ops = [ |
|
aten.mm, |
|
aten.convolution, |
|
aten.convolution_backward, |
|
aten.bmm, |
|
aten.addmm, |
|
aten._scaled_dot_product_flash_attention, |
|
aten._scaled_dot_product_efficient_attention, |
|
aten.upsample_bilinear2d, |
|
] |
|
|
|
fusible_ops = recomputable_ops | set(random_ops) |
|
return OpTypes( |
|
set(fusible_ops), |
|
set(compute_intensive_ops), |
|
set(random_ops), |
|
set(view_ops), |
|
set(recomputable_ops), |
|
) |
|
|
|
|
|
def get_name_to_node(graph: fx.Graph): |
|
name_to_node = {} |
|
for node in graph.nodes: |
|
name_to_node[node.name] = node |
|
return name_to_node |
|
|
|
|
|
def greedy_knapsack( |
|
memory: List[float], runtimes: List[float], max_memory: float |
|
) -> Tuple[float, List[int], List[int]]: |
|
n = len(runtimes) |
|
items = list(range(n)) |
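
    # Greedily save the items with the best runtime-saved-per-byte ratio first.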
|
|
|
|
|
items = sorted(items, key=lambda i: runtimes[i] / memory[i], reverse=True) |
|
|
|
total_memory = 0.0 |
|
total_runtime = 0.0 |
|
items_to_save = [] |
|
items_to_allow_recomputing = [] |
|
|
|
for i in items: |
|
if total_memory + memory[i] <= max_memory: |
|
total_memory += memory[i] |
|
total_runtime += runtimes[i] |
|
items_to_save.append(i) |
|
else: |
|
items_to_allow_recomputing.append(i) |
|
return total_runtime, items_to_save, items_to_allow_recomputing |
|
|
|
|
|
def ilp_knapsack( |
|
memory: List[float], runtimes: List[float], max_memory: float |
|
) -> Tuple[float, List[int], List[int]]: |
|
import numpy as np |
|
|
|
try: |
|
from scipy.optimize import Bounds, LinearConstraint, milp |
|
except ImportError: |
|
raise RuntimeError( |
|
"To use the ILP for memory budget checkpointing you need to install scipy" |
|
) from None |
|
|
|
np_memory = np.array(memory) |
|
np_runtimes = np.array(runtimes) |
|
c = -np_runtimes |
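
    # scipy's milp minimizes, so negate the runtimes to maximize total saved runtime.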
|
|
|
memory_constraint = LinearConstraint(A=np_memory, ub=np.array(max_memory)) |
|
constraints = [memory_constraint] |
|
|
|
integrality = np.ones_like(c) |
|
res = milp( |
|
c=c, constraints=constraints, integrality=integrality, bounds=Bounds(0, 1) |
|
) |
|
if not res.success: |
|
raise RuntimeError("Somehow scipy solving failed") |
|
|
|
items_to_save = [] |
|
items_to_allow_recomputing = [] |
|
for idx, i in enumerate(res.x): |
|
if i == 1: |
|
items_to_save.append(idx) |
|
else: |
|
items_to_allow_recomputing.append(idx) |
|
return -res.fun, items_to_save, items_to_allow_recomputing |
|
|
|
|
|
def dp_knapsack( |
|
memory: List[float], runtimes: List[float], max_memory: float |
|
) -> Tuple[float, List[int], List[int]]: |
|
|
|
S = 10000 |
|
|
|
|
|
quantized_memory = torch.tensor( |
|
[int(round(m * S)) for m in memory], dtype=torch.long, device="cpu" |
|
) |
|
runtimes = torch.tensor(runtimes, dtype=torch.float32, device="cpu") |
|
|
|
|
|
quantized_max_memory = int(round(max_memory * S)) |
|
|
|
n = len(memory) |
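
    # dp[i][j] = best total runtime achievable using the first i items within a
    # (quantized) memory budget of j.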
|
|
|
|
|
|
|
|
|
|
|
dp = torch.zeros( |
|
(n + 1, quantized_max_memory + 1), dtype=torch.float32, device="cpu" |
|
) |
|
|
|
for i in range(1, n + 1): |
|
current_memory = quantized_memory[i - 1] |
|
current_runtime = runtimes[i - 1] |
|
|
|
|
|
dp[i, :] = dp[i - 1, :] |
|
|
|
|
|
if current_memory == 0: |
|
dp[i, :] = dp[i - 1, :] + current_runtime |
|
else: |
|
dp[i, current_memory:] = torch.maximum( |
|
dp[i - 1, current_memory:], |
|
dp[i - 1, :-current_memory] + current_runtime, |
|
) |
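
    # Backtrack through the DP table to recover which items were saved.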
|
|
|
|
|
saved_items = [] |
|
recomputable_items = [] |
|
j: int = quantized_max_memory |
|
for i in range(n, 0, -1): |
|
if dp[i][j] != dp[i - 1][j]: |
|
saved_items.append(i - 1) |
|
j -= int(quantized_memory[i - 1].item()) |
|
else: |
|
recomputable_items.append(i - 1) |
|
|
|
saved_items.reverse() |
|
|
|
|
|
max_runtime = dp[n][quantized_max_memory].item() |
|
|
|
return max_runtime, saved_items, recomputable_items |
|
|
|
|
|
def _optimize_runtime_with_given_memory( |
|
memory: List[float], |
|
runtimes: List[float], |
|
max_memory: float, |
|
) -> Tuple[float, List[int], List[int]]: |
|
SOLVER = config.activation_memory_budget_solver |
|
if SOLVER == "greedy": |
|
return greedy_knapsack(memory, runtimes, max_memory) |
|
elif SOLVER == "ilp": |
|
return ilp_knapsack(memory, runtimes, max_memory) |
|
elif SOLVER == "dp": |
|
return dp_knapsack(memory, runtimes, max_memory) |
|
else: |
|
raise RuntimeError(f"Not aware of memory budget knapsack solver: {SOLVER}") |
|
|
|
|
|
from torch.utils._mode_utils import no_dispatch |
|
|
|
|
|
def estimate_runtime(node): |
|
RUNTIME_MODE = config.activation_memory_budget_runtime_estimator |
|
|
|
def materialize_arg(x): |
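        # Turn an fx.Node argument into a concrete stand-in (zeros tensor, hinted
        # int, ...) so that the op can actually be executed or flop-counted.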
|
if isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.Tensor): |
|
shape = list(x.meta["val"].shape) |
|
|
|
def realize_symbol(d): |
|
return hint_int(d, fallback=4096) |
|
|
|
shape = [realize_symbol(s) for s in shape] |
|
return x.meta["val"].new_zeros(shape) |
|
elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymInt): |
|
return hint_int(x.meta["val"], fallback=4096) |
|
elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymFloat): |
|
return 1.0 |
|
elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymBool): |
|
return True |
|
else: |
|
return x |
|
|
|
if RUNTIME_MODE == "testing": |
|
return 1 |
|
|
|
elif RUNTIME_MODE == "profile": |
|
from triton.testing import do_bench |
|
|
|
with no_dispatch(): |
|
args, kwargs = pytree.tree_map(materialize_arg, (node.args, node.kwargs)) |
|
ms = do_bench(lambda: node.target(*args, **kwargs)) |
|
return ms |
|
|
|
elif RUNTIME_MODE == "flops": |
|
|
|
from torch.utils.flop_counter import FlopCounterMode |
|
|
|
args, kwargs = pytree.tree_map(materialize_arg, (node.args, node.kwargs)) |
|
with FlopCounterMode(display=False) as mode: |
|
node.target(*args, **kwargs) |
|
counted_flops = mode.get_total_flops() |
|
return max(counted_flops, 1) |
|
else: |
|
raise RuntimeError(f"Not aware of runtime estimator: {RUNTIME_MODE}") |
|
|
|
|
|
def choose_saved_values_set( |
|
joint_graph: fx.Graph, node_info: NodeInfo, memory_budget=1 |
|
) -> List[fx.Node]: |
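    """
    Runs the min-cut solver under progressively more aggressive recomputation
    settings and, when a memory budget below 1.0 is requested, solves a knapsack
    problem over the banned nodes to decide which of them to keep saved so that
    the estimated activation memory stays within the budget.
    """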
|
if memory_budget > 1 or memory_budget < 0: |
|
raise RuntimeError( |
|
f"The valid ranges for memory budget are 0 <= m <= 1. The provided value is {memory_budget}" |
|
) |
|
min_cut_options = MinCutOptions( |
|
ban_if_used_far_apart=config.ban_recompute_used_far_apart, |
|
ban_if_long_fusible_chains=config.ban_recompute_long_fusible_chains, |
|
ban_if_materialized_backward=config.ban_recompute_materialized_backward, |
|
ban_if_not_in_allowlist=config.ban_recompute_not_in_allowlist, |
|
ban_if_reduction=config.ban_recompute_reductions, |
|
) |
|
|
|
if config.aggressive_recomputation: |
|
min_cut_options = replace( |
|
min_cut_options, |
|
ban_if_used_far_apart=False, |
|
ban_if_long_fusible_chains=False, |
|
ban_if_materialized_backward=False, |
|
ban_if_not_in_allowlist=False, |
|
) |
|
if memory_budget == 0: |
|
return node_info.inputs |
|
|
|
runtime_optimized_saved_values, _ = solve_min_cut( |
|
joint_graph, |
|
node_info, |
|
min_cut_options, |
|
) |
|
|
|
if memory_budget == 1: |
|
return runtime_optimized_saved_values |
|
|
|
def estimate_activations_size(saved_values: List[fx.Node]) -> float: |
|
return sum([_size_of(i) for i in saved_values]) / 1e9 |
|
|
|
min_act_size = estimate_activations_size(node_info.inputs) |
|
max_act_size = estimate_activations_size(runtime_optimized_saved_values) |
|
|
|
if max_act_size <= min_act_size: |
|
return runtime_optimized_saved_values |
|
|
|
def get_normalized_size(sz): |
|
return (sz / 1e9) / (max_act_size - min_act_size) |
|
|
|
def get_mem_ratio(activations: List[fx.Node]): |
|
return (estimate_activations_size(activations) - min_act_size) / ( |
|
max_act_size - min_act_size |
|
) |
|
|
|
more_aggressive_options = replace( |
|
min_cut_options, |
|
ban_if_used_far_apart=False, |
|
ban_if_long_fusible_chains=False, |
|
ban_if_materialized_backward=False, |
|
) |
|
more_aggressive_saved_values, _ = solve_min_cut( |
|
joint_graph, node_info, more_aggressive_options |
|
) |
|
if get_mem_ratio(more_aggressive_saved_values) < memory_budget: |
|
return more_aggressive_saved_values |
|
|
|
aggressive_options = replace( |
|
more_aggressive_options, |
|
ban_if_not_in_allowlist=False, |
|
) |
|
aggressive_recomputation_saved_values, banned_nodes = solve_min_cut( |
|
joint_graph, node_info, aggressive_options |
|
) |
|
|
|
if get_mem_ratio(aggressive_recomputation_saved_values) < memory_budget: |
|
return aggressive_recomputation_saved_values |
|
|
|
from torch._inductor.fx_utils import get_node_storage |
|
|
|
input_storages = {get_node_storage(node) for node in node_info.inputs} |
|
|
|
def get_recomputable_banned_nodes(banned_nodes: List[fx.Node]) -> List[fx.Node]: |
|
return [ |
|
i |
|
for i in banned_nodes |
|
if ( |
|
|
|
i.dist_from_bw < int(1e9) |
|
and get_node_storage(i) not in input_storages |
|
) |
|
] |
|
|
|
recomputable_banned_nodes = get_recomputable_banned_nodes(banned_nodes) |
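
    # Sort candidates by size (largest first) before running the knapsack over their
    # memory cost and estimated recomputation runtime.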
|
|
|
|
|
|
|
|
|
|
|
all_recomputable_banned_nodes = sorted( |
|
recomputable_banned_nodes, key=_size_of, reverse=True |
|
) |
|
if len(all_recomputable_banned_nodes) == 0: |
|
return node_info.inputs |
|
memories_banned_nodes = [ |
|
get_normalized_size(_size_of(i)) for i in all_recomputable_banned_nodes |
|
] |
|
runtimes_banned_nodes = [ |
|
estimate_runtime(node) for node in all_recomputable_banned_nodes |
|
] |
|
from torch.utils._mode_utils import no_dispatch |
|
|
|
def get_saved_values_knapsack(memory_budget): |
|
with no_dispatch(): |
|
( |
|
expected_runtime, |
|
saved_node_idxs, |
|
recomputable_node_idxs, |
|
) = _optimize_runtime_with_given_memory( |
|
memories_banned_nodes, runtimes_banned_nodes, max(memory_budget, 0) |
|
) |
|
dont_ban = set() |
|
for idx in recomputable_node_idxs: |
|
dont_ban.add(all_recomputable_banned_nodes[idx]) |
|
assert dont_ban.issubset(all_recomputable_banned_nodes) |
|
|
|
saved_values, _ = solve_min_cut( |
|
joint_graph, |
|
node_info, |
|
aggressive_options, |
|
dont_ban, |
|
) |
|
return saved_values, expected_runtime |
|
|
|
if config.visualize_memory_budget_pareto: |
|
options = [] |
|
for sweep_memory_budget in range(100, -1, -5): |
|
saved_values, expected_runtime = get_saved_values_knapsack( |
|
sweep_memory_budget / 100 |
|
) |
|
options.append( |
|
( |
|
sweep_memory_budget, |
|
sum(runtimes_banned_nodes) - expected_runtime, |
|
get_mem_ratio(saved_values), |
|
) |
|
) |
|
|
|
import matplotlib.pyplot as plt |
|
|
|
x_values = [item[2] for item in options] |
|
y_values = [item[1] for item in options] |
|
|
|
|
|
plt.figure(figsize=(10, 6)) |
|
plt.plot(x_values, y_values, marker="o") |
|
|
|
|
|
for i, txt in enumerate(x_values): |
|
plt.annotate( |
|
f"{txt:.2f}", |
|
(x_values[i], y_values[i]), |
|
textcoords="offset points", |
|
xytext=(0, 10), |
|
ha="center", |
|
) |
|
|
|
plt.xlabel("Memory Budget") |
|
plt.ylabel("Runtime of Recomputed Components") |
|
plt.title("Pareto Frontier of Memory Budget vs. Recomputation Runtime") |
|
plt.grid(True) |
|
fig = plt.gcf() |
|
plt.show() |
|
fig_name = f"memory_budget_pareto_{get_aot_graph_name()}.png" |
|
fig.savefig(fig_name) |
|
log.warning("Generated Pareto frontier curve at %s", fig_name) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return get_saved_values_knapsack(memory_budget=memory_budget)[0] |
|
|
|
|
|
def min_cut_rematerialization_partition( |
|
joint_module: fx.GraphModule, |
|
_joint_inputs, |
|
compiler="inductor", |
|
*, |
|
num_fwd_outputs, |
|
) -> Tuple[fx.GraphModule, fx.GraphModule]: |
|
""" |
|
Partitions the joint graph such that the backward recomputes the forward. |
|
Recomputing helps in trading off memory bandwidth with computation. |
|
|
|
    To create the fwd and bwd graphs, we copy the joint graph and manually set

    the outputs to just the original forward or backward outputs. We then run

    the resulting graphs through dead code elimination.
|
|
|
.. warning:: |
|
This API is experimental and likely to change. |
|
|
|
Args: |
|
joint_module(fx.GraphModule): The joint forward and backward graph. This |
|
is the result of AOT Autograd tracing. |
|
_joint_inputs: The inputs to the joint graph. This is unused. |
|
compiler: This option determines the default set of recomputable ops. |
|
Currently, there are two options: ``nvfuser`` and ``inductor``. |
|
|
num_fwd_outputs: The number of outputs from the forward graph. |
|
|
|
Returns: |
|
Returns the generated forward and backward Fx graph modules. |
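
    Example (an illustrative sketch; ``joint_module`` is assumed to come from an
    AOT Autograd trace)::

        fw_module, bw_module = min_cut_rematerialization_partition(
            joint_module, joint_inputs, num_fwd_outputs=num_fwd_outputs
        )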
|
""" |
|
|
|
joint_module.graph.eliminate_dead_code() |
|
joint_module.recompile() |
|
|
|
fx_g = joint_module.graph |
|
|
|
|
|
if config.cse: |
|
cse_graph = fx_graph_cse(fx_g) |
|
joint_module.graph = cse_graph |
|
joint_graph = joint_module.graph |
|
|
|
graph_has_recomputable_ops = has_recomputable_ops(joint_module) |
|
graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module) |
|
if graph_has_recomputable_ops: |
|
joint_module = cleanup_recompute_tags(joint_module) |
|
|
|
def classify_nodes(joint_module): |
|
name_to_node = get_name_to_node(joint_module.graph) |
|
required_bw_nodes = set() |
|
for node in joint_module.graph.nodes: |
|
if node.op == "placeholder" and "tangents" in node.target: |
|
required_bw_nodes.add(node) |
|
if node in required_bw_nodes: |
|
for user in node.users: |
|
required_bw_nodes.add(user) |
|
|
|
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) |
|
fwd_seed_offset_inputs = list( |
|
filter(_is_fwd_seed_offset, joint_module.graph.nodes) |
|
) |
|
inputs = primal_inputs + fwd_seed_offset_inputs |
|
fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs( |
|
joint_module, num_fwd_outputs=num_fwd_outputs |
|
) |
|
required_bw_nodes.update( |
|
o for o in bwd_outputs if o is not None and o.op != "output" |
|
) |
|
forward_only_graph = _extract_graph_with_inputs_outputs( |
|
joint_module.graph, inputs, fwd_outputs |
|
) |
|
required_fw_nodes: Set[fx.Node] = { |
|
name_to_node[node.name] |
|
for node in forward_only_graph.nodes |
|
if node.op != "output" |
|
} |
|
unclaimed_nodes = { |
|
node |
|
for node in joint_module.graph.nodes |
|
if node not in required_fw_nodes and node not in required_bw_nodes |
|
} |
|
fw_cnt = 0 |
|
fw_order = {} |
|
for node in joint_module.graph.nodes: |
|
if node in required_fw_nodes: |
|
fw_order[node] = fw_cnt |
|
fw_cnt += 1 |
|
return NodeInfo( |
|
inputs, required_fw_nodes, required_bw_nodes, unclaimed_nodes, fw_order |
|
) |
|
|
|
node_info = classify_nodes(joint_module) |
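
    # If nothing is required for the backward there is nothing to partition (and the
    # min-cut machinery cannot handle it), so fall back to the default partitioner.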
|
|
|
|
|
|
|
|
|
if len(node_info.required_bw_nodes) == 0: |
|
return default_partition( |
|
joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs |
|
) |
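
    # Annotate each node with its distance from the backward pass; the banning
    # heuristics and node weights in solve_min_cut rely on this.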
|
|
|
for node in reversed(joint_module.graph.nodes): |
|
if node.op == "output": |
|
node.dist_from_bw = int(1e9) |
|
elif not node_info.is_required_fw(node): |
|
node.dist_from_bw = 0 |
|
else: |
|
node.dist_from_bw = int(1e9) |
|
for user in node.users: |
|
node.dist_from_bw = min(node.dist_from_bw, user.dist_from_bw + 1) |
|
|
|
memory_budget = config.activation_memory_budget |
|
for node in joint_graph.nodes: |
|
if isinstance(node.meta.get("memory_budget", None), float): |
|
memory_budget = node.meta["memory_budget"] |
|
break |
|
|
|
saved_values = choose_saved_values_set( |
|
joint_graph, node_info, memory_budget=memory_budget |
|
) |
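
    # Keep symbolic saved values separate from tensors: autograd saves tensors via
    # save_for_backward and stashes symints on ctx directly.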
|
|
|
saved_sym_nodes = list(filter(is_sym_node, saved_values)) |
|
saved_values = list(filter(lambda n: not is_sym_node(n), saved_values)) |
|
|
|
|
|
fw_module, bw_module = _extract_fwd_bwd_modules( |
|
joint_module, |
|
saved_values, |
|
saved_sym_nodes=saved_sym_nodes, |
|
num_fwd_outputs=num_fwd_outputs, |
|
) |
|
|
|
if graph_has_recomputable_ops: |
|
if graph_has_recomputable_rng_ops: |
|
fw_module, bw_module = functionalize_rng_ops( |
|
joint_module, fw_module, bw_module, len(saved_sym_nodes) |
|
) |
|
bw_module = reordering_to_mimic_autograd_engine(bw_module) |
|
|
|
if AOT_PARTITIONER_DEBUG: |
|
from torch._inductor.fx_utils import get_node_storage |
|
|
|
storages = {get_node_storage(node) for node in saved_values} |
|
print( |
|
"Theoretical Activations Stored: ", |
|
sum(_size_of(i) for i in saved_values) / 1e9, |
|
) |
|
sorted_sizes = sorted([(_size_of(i), str(i)) for i in saved_values]) |
|
fw_module_nodes = { |
|
node.name for node in fw_module.graph.nodes if node.op == "call_function" |
|
} |
|
bw_module_nodes = { |
|
node.name for node in bw_module.graph.nodes if node.op == "call_function" |
|
} |
|
remat_nodes = fw_module_nodes & bw_module_nodes |
|
|
|
counts: Dict[str, int] = defaultdict(int) |
|
for node in fw_module.graph.nodes: |
|
if node.name in remat_nodes and hasattr(node.target, "_overloadpacket"): |
|
counts[str(node.target._overloadpacket)] += 1 |
|
print( |
|
f"# remat/fw/bw: {len(remat_nodes)}/{len(fw_module_nodes)}/{len(bw_module_nodes)}" |
|
) |
|
print( |
|
"Count of Ops Rematerialized: ", |
|
sorted(counts.items(), key=lambda x: x[1], reverse=True), |
|
) |
|
return fw_module, bw_module |
|
|
|
|
|
def draw_graph( |
|
traced: torch.fx.GraphModule, |
|
fname: str, |
|
figname: str = "fx_graph", |
|
clear_meta: bool = True, |
|
prog: Optional[Union[str, List[str]]] = None, |
|
parse_stack_trace: bool = False, |
|
dot_graph_shape: Optional[str] = None, |
|
) -> None: |
|
if clear_meta: |
|
new_graph = copy.deepcopy(traced.graph) |
|
traced = fx.GraphModule(traced, new_graph) |
|
for node in traced.graph.nodes: |
|
node.meta = {} |
|
base, ext = os.path.splitext(fname) |
|
if not ext: |
|
ext = "." + config.torch_compile_graph_format |
|
print(f"Writing FX graph to file: {base}{ext}") |
|
g = graph_drawer.FxGraphDrawer( |
|
traced, |
|
figname, |
|
parse_stack_trace=parse_stack_trace, |
|
dot_graph_shape=dot_graph_shape, |
|
) |
|
x = g.get_main_dot_graph() |
|
write_method = getattr(x, "write_" + ext.lstrip(".")) |
|
fname = f"{base}{ext}" |
|
if prog is None: |
|
write_method(fname) |
|
else: |
|
write_method(fname, prog=prog) |
|
|