import torch
from torch.fx import Node
from torch.fx._compatibility import compatibility
from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor
from torch.utils._pytree import tree_map_only
from torch.utils import _pytree as pytree
from torch.multiprocessing.reductions import StorageWeakRef

import _operator
from enum import Enum
import itertools
from typing import Set, Dict
from collections import defaultdict

__all__ = ['reinplace']


class _ViewType(Enum):
    NonView = 0
    SingleOutputView = 1
    MultiOutputView = 2


def _is_view_op(tgt):
    if tgt is not None and isinstance(tgt, torch._ops.OpOverload):
        schema = tgt._schema
        if len(schema.arguments) > 0:
            first_arg = schema.arguments[0]
            # check if op is a view
            return first_arg.alias_info is not None and not first_arg.alias_info.is_write


def _get_view_type(tgt) -> _ViewType:
    if tgt is not None and isinstance(tgt, torch._ops.OpOverload):
        schema = tgt._schema
        if len(schema.arguments) > 0:
            first_arg = schema.arguments[0]
            # check if op is a view
            if first_arg.alias_info is not None and not first_arg.alias_info.is_write:
                # check if op is a multi-output view
                if '*' in first_arg.alias_info.after_set:
                    return _ViewType.MultiOutputView
                else:
                    return _ViewType.SingleOutputView
    return _ViewType.NonView
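

# Illustrative examples of how the alias annotations in an op's schema drive the
# classification above (schema strings shown roughly as printed by `op._schema`;
# the exact formatting may vary across PyTorch versions):
#
#   aten::diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a)
#       -> first arg has alias_info with is_write=False            => SingleOutputView
#   aten::split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[]
#       -> first arg's alias_info.after_set contains '*'           => MultiOutputView
#   aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
#       -> first arg has no alias_info                             => NonView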


# Stores a bunch of functionalization-related metadata on each node.
# Relevant metadata:
# n.meta['fake_result']: FakeTensor (same type as the output of the node, but with FakeTensors instead of Tensors)
#   The fake tensor output from running the current node.
# n.meta['view_of']: Node
#   If the current node n is a view of some base tensor, the 'view_of' field tells us which
#   node produced the tensor that n is a view of.
#   This information actually makes `fake_result` redundant, but we can use `fake_result`
#   to sanity check that our aliasing information is correct.
class _FunctionalizationMetadataProp(torch.fx.Interpreter):

    def run_node(self, node: Node):
        self.node_counter += 1
        result = super().run_node(node)
        node.meta['fake_result'] = result
        node.meta['node_idx'] = self.node_counter

        # (1) Update metadata with the list of nodes that are used by this node
        # copy_() doesn't read from its first argument; it writes to it, overwriting previous data.
        # We don't want to treat it as "being used as an input".
        node_args = node.args
        if node.target is torch.ops.aten.copy_.default:
            node_args = node_args[1:]

        # (2) Update metadata to track aliasing information about view tensor nodes.
        if node.op == 'call_function':
            view_type = _get_view_type(node.target)
            if view_type == _ViewType.SingleOutputView:
                assert isinstance(node.args[0], Node)
                node.meta['view_of'] = node.args[0]
            elif view_type == _ViewType.MultiOutputView:
                self.multi_output_view_nodes[node] = node.args[0]
            # Check if we returned a multi-output view,
            # and we're now grabbing the individual views from the output.
            #
            # For multi-output views, we want to map each output view to the base,
            # but this mapping involves two separate nodes in FX IR.
            # e.g. "a, b = x_1.split(...)" becomes:
            # %split_tensor : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%x_1, 2), kwargs = {})
            # %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%split_tensor, 0), kwargs = {})
            # %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%split_tensor, 1), kwargs = {})
            # And we'd like to set:
            # getitem.meta['view_of'] = x_1
            # getitem_1.meta['view_of'] = x_1
            elif node.target is _operator.getitem:
                list_arg = node.args[0]
                maybe_base_of_view = self.multi_output_view_nodes.get(list_arg, None)
                if maybe_base_of_view is not None:
                    # Note: we could also track indexing info here for multi-output views.
                    # I don't think this metadata is strictly needed for de-functionalization.
                    assert isinstance(maybe_base_of_view, Node)
                    node.meta['view_of'] = maybe_base_of_view

        if 'view_of' in node.meta:
            # We're linking the current node with its first argument as views.
            # Assert here that this is actually the case, and that their storages are the same.
            assert isinstance(node.meta['fake_result'], FakeTensor)
            assert isinstance(node.meta['view_of'].meta['fake_result'], FakeTensor)
            view_storage = StorageWeakRef(node.meta['fake_result']._typed_storage())
            base_storage = StorageWeakRef(node.meta['view_of'].meta['fake_result']._typed_storage())
            assert view_storage == base_storage
        return result

    def propagate(self, *args):
        self.multi_output_view_nodes = {}
        self.node_counter = -1

        with FakeTensorMode() as mode:
            fake_args = [mode.from_tensor(a) for a in args]
            return super().run(*fake_args)
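

# For intuition, after `_FunctionalizationMetadataProp(gm).propagate(*sample_args)` runs on a
# graph like the following (an illustrative sketch; names depend on how the graph was traced):
#
#   %x_1      = placeholder
#   %add      = call_function[target=torch.ops.aten.add.Tensor](args = (%x_1, 1))
#   %diagonal = call_function[target=torch.ops.aten.diagonal.default](args = (%add,))
#
# every node ends up with meta['node_idx'] (its position in the graph) and meta['fake_result']
# (a FakeTensor with the right sizes/strides/storage), and %diagonal additionally gets
# meta['view_of'] = %add, since diagonal is a single-output view op.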


def _schemas_match(functional_schema, inplace_schema):
    names_match = inplace_schema.name.endswith("_") and inplace_schema.name[:-1] == functional_schema.name
    arg_types_match = len(functional_schema.arguments) == len(inplace_schema.arguments) and all(
        a1.type == a2.type for a1, a2 in zip(functional_schema.arguments, inplace_schema.arguments))
    # for the inplace op, its first argument should be mutable
    assert inplace_schema.arguments[0].alias_info is not None and inplace_schema.arguments[0].alias_info.is_write
    # and its remaining arguments shouldn't be.
    assert all(a.alias_info is None for a in inplace_schema.arguments[1:])
    return names_match and arg_types_match
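
# As a concrete, illustrative example, the following pair of schemas matches:
#
#   aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
#   aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
#
# the names differ only by a trailing underscore and the argument types line up one-to-one,
# so _schemas_match(torch.ops.aten.add.Tensor._schema, torch.ops.aten.add_.Tensor._schema)
# returns True (alias annotations like (a!) don't affect the type comparison).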


# TODO: this should be beefed up to be able to properly re-inplace with:
# - mutating ops (e.g. _fused_moving_avg_obs_fq_helper)
# - out= ops (e.g. angle -> angle.out)
# TODO: we should also figure this info out using torchgen.
def _maybe_get_inplace_op(op):
    # __module__ seems broken; it returns torch._ops.aten which doesn't exist
    if not isinstance(op, torch._ops.OpOverload):
        return None
    # Some view ops have inplace variants (as_strided_, etc),
    # but we do NOT want the reinplacing pass to directly add these into the program.
    # (they'll require extra special handling, and aren't really useful for perf anyway)
    if _is_view_op(op):
        return None
    op_namespace = op.__module__.split(".")[-1]
    op_base_name = op.overloadpacket.__name__
    maybe_namespace_module = getattr(torch.ops, op_namespace)
    maybe_inplace_op = None if maybe_namespace_module is None else getattr(maybe_namespace_module, f'{op_base_name}_', None)
    if maybe_inplace_op is None:
        return None

    inplace_overloads = [
        getattr(maybe_inplace_op, overload_name) for overload_name in maybe_inplace_op.overloads()
    ]
    inplace_overloads_with_matching_schemas = [
        f
        for f in inplace_overloads
        if _schemas_match(op._schema, f._schema)
    ]
    # Just because foo() and foo_() are both existing operators,
    # they aren't guaranteed to have compatible schemas.
    # For example, pow.Scalar(Scalar self, Tensor exponent) has no valid inplace variant,
    # even though several overloads of pow_ exist.
    if len(inplace_overloads_with_matching_schemas) == 0:
        return None
    assert len(inplace_overloads_with_matching_schemas) == 1
    inplace_op = inplace_overloads_with_matching_schemas[0]
    return inplace_op
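
# Illustrative usage (a sketch; the exact results depend on the installed PyTorch version):
#
#   _maybe_get_inplace_op(torch.ops.aten.add.Tensor)        # -> torch.ops.aten.add_.Tensor
#   _maybe_get_inplace_op(torch.ops.aten.diagonal.default)  # -> None (view ops are skipped)
#   _maybe_get_inplace_op(torch.ops.aten.pow.Scalar)        # -> None (no schema-compatible pow_ overload)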


_VIEW_INVERSE_MAP = {
    torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default,
    torch.ops.aten.select_scatter.default: torch.ops.aten.select.int,
    torch.ops.aten.slice_scatter.default: torch.ops.aten.slice.Tensor,
    torch.ops.aten.as_strided_scatter.default: torch.ops.aten.as_strided.default,
}
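
# For intuition, each {view}_scatter op is the functional counterpart of mutating through a view.
# An illustrative sketch of the relationship this map encodes:
#
#   s  = torch.select(a, 0, 0)                      # a view into `a`
#   a2 = torch.select_scatter(a, s_updated, 0, 0)   # a copy of `a` with that view replaced by s_updated
#
# so select_scatter(a, s_updated, args...) behaves like select(a, args...).copy_(s_updated),
# but without mutating `a`.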


# This function, given a set of (aliased) tensor nodes,
# returns any nodes in the graph that *use* any of the aliases and that occur *after* op_index
# in the node ordering.
def _get_all_later_node_usages(tensor_aliases: Set[Node], op_index: int):
    def _add_if_tensor(x, set_):
        if isinstance(x, FakeTensor):
            set_.add(StorageWeakRef(x._typed_storage()))

    nodes_used_after = set()
    for t in tensor_aliases:
        # get all nodes that use the current alias
        usage_nodes = t.users
        for n in usage_nodes:
            # We only care about usages after the current node
            if 'node_idx' not in n.meta or n.meta['node_idx'] <= op_index:
                continue
            # We also don't care about intermediate view ops.
            # They only matter if their output is then used elsewhere
            # (either in an out-of-place op, or as an output to the function).
            if n in tensor_aliases:
                if isinstance(n.target, torch._ops.OpOverload) or n.target == _operator.getitem:
                    continue
            nodes_used_after.add(n)
    return nodes_used_after
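
# For intuition (illustrative): given `b = foo(a)` followed by `c = torch.diagonal(a)` and then
# `d = torch.add(c, 1)`, the diagonal node is itself an alias of `a` and is ignored here,
# but the later add node counts as a "later usage" of the alias set and would block re-inplacing foo.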


# Given an op that we're trying to re-inplace, "b = foo(a)",
# and given a {view}_scatter op that shows up later in the graph, "y = {view}_scatter(base, x, args...)",
# re-inplacing `foo()` would allow us to remove the `{view}_scatter` op entirely, IF
# there is some alias in alias_set(a) that satisfies:
#   (1) The base of "alias", "alias_base", has the same size/stride/offset metadata as "base"
#   (2) The output of running {view}(alias_base, args...) has the same size/stride/offset metadata
#       as "alias"
def _get_view_inverse_node_usages(later_node_usages: Set[Node], self_aliases: Set[Node]) -> Set[Node]:
    def matching_view_metadata(a, b):
        return a.size() == b.size() and \
            a.stride() == b.stride() and \
            a.storage_offset() == b.storage_offset()

    view_inverse_nodes = set()
    # Go through them in node order, so we can see chains of view_scatter ops.
    for n in sorted(later_node_usages, key=lambda x: x.meta['node_idx']):
        if n.target not in _VIEW_INVERSE_MAP:
            continue
        base = n.args[0]
        mutated_view = n.args[1]
        assert isinstance(base, Node)
        assert isinstance(base.meta['fake_result'], FakeTensor)
        assert isinstance(mutated_view, Node)
        assert isinstance(mutated_view.meta['fake_result'], FakeTensor)
        # Check that this view_inverse op actually corresponds to taking the inverse
        # of one of our existing self_alias nodes.
        original_view = _VIEW_INVERSE_MAP[n.target]
        for self_alias in self_aliases:
            # We're looking for some alias of the self arg, "alias",
            # that was created from some op `alias = foo(base, args...)`
            # such that the current _scatter op "inverts" that foo call.
            # We can check that by running the original op again, and checking that the strides match.
            if 'view_of' not in self_alias.meta:
                continue
            self_alias_base = self_alias.meta['view_of']
            try:
                # Here we're trying to re-use the args from the view_scatter call inside of the corresponding
                # view op, which might throw. That just indicates that this view_scatter op isn't a valid inverse
                # of the current alias we're looking at.
                view_replay_metadata = original_view(self_alias_base.meta['fake_result'], *n.args[2:], **n.kwargs)
                expected_metadata = self_alias.meta['fake_result']
                # If the alias and its base both have matching metadata, then this view_scatter op is valid to re-inplace.
                if matching_view_metadata(self_alias_base.meta['fake_result'], base.meta['fake_result']) and \
                        matching_view_metadata(view_replay_metadata, expected_metadata):
                    view_inverse_nodes.add(n)
            except Exception:
                continue

    return view_inverse_nodes


def reinplace(gm, *sample_args):
    """
    Given an fx.GraphModule, modifies it to perform "reinplacing",
    mutating the nodes of the graph.
    We look for out-of-place op call sites like `b = a.add(...)`,
    and convert them to be inplace (`b = a.add_(...)`),
    as long as the input to the current operator ("a") isn't re-used
    anywhere later in the graph.

    This pass currently expects to operate on a **functional, ATen** graph.
    This can be obtained by running `make_fx(functionalize(f))`.

    Sample inputs are needed to determine aliasing relationships of the inputs.
    In general, we can't reinplace node `b = a.add(...)` if "a" aliases any of the
    inputs to the program.
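
    A rough sketch of how the pass might be invoked (illustrative only; `functionalize`
    is assumed here to come from `torch.func`, and the exact traced graph depends on
    the PyTorch version):

        import torch
        from torch.fx.experimental.proxy_tensor import make_fx
        from torch.func import functionalize

        def f(x):
            a = x.clone()
            a.add_(1)
            return a

        inp = torch.ones(4)
        gm = make_fx(functionalize(f))(inp)   # functional ATen graph: clone, then an out-of-place add
        gm = reinplace(gm, inp)               # the add becomes add_, since the clone is safe to mutate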

    Given a node "b = foo(a, args...)", the algorithm for re-inplacing is as follows:

    (1) Perform some initial checks on the metadata of "a" and "args..."
        that can disqualify them from being reinplaced.

        (1a) Check that the self argument we're attempting to reinplace
             has acceptable dtype/size metadata to reinplace with.

             For example, if we have:
               a = torch.ones(1)
               b = torch.ones(10)
               out = torch.add(a, b)
             we can't turn that into
               a.add_(b)
             because that would require resizing "a".

             Similarly, we can't convert torch.ge(a, b) into a.ge_(b),
             because that would require changing a's dtype (from e.g. float32 to bool).
             Note that in this specific example, we could technically do better:
             if we see the pattern
               a_1 = a.ge(b)
               a_2 = aten._to_copy(a_1, a.dtype)
             then it should be valid to completely re-inplace
             (this is exactly what functionalization will emit when it sees a.ge_(b)).
             This optimization is only really important for user programs
             that directly use inplace comparison ops though.

             We also cannot re-inplace on tensors that have overlapping memory,
             e.g. torch.ones(1).expand(4, 4).add_(1)

        (1b) Check if "a" is an alias of any of the program inputs.

             If it is, skip and move to the next node.
             Inplace'ing an op that would cause it to mutate a program input is not sound,
             because that would be a side effect visible to the user.

             NOTE: there's a future optimization that we should make:
             if "a" is a (alias of a) program input, but later in the program
             there is a node that looks like "a.copy_(...)",
             then re-inplacing is ok to do - we are temporarily re-using a's buffer,
             which will later be overwritten by the copy_() call.
             This will be an important optimization to have for programs that mutate
             their inputs. It currently isn't implemented though.

        (1c) Check if "a" and "args..." alias.

             For example, re-inplacing to create code like the below
             isn't guaranteed to be sound:
               aten.mul_(a, a)

    (2) Check that "a" and all of its outstanding aliases are not used anywhere
        later in the graph. If this is the case, then it's safe to re-inplace
        to "b = foo_(a)".

        There are a few caveats to this, explained in more detail below:
        (a) If "a" is used later as an argument to a view op, that is okay.
            It's only a problem if "a" (or that view) is later passed
            into a normal operator, or if it is returned as the program output.
        (b) If "a" is a repeat argument in `foo()`, then don't reinplace.
            Most ATen kernels don't make any guarantees that this is sound,
            e.g. if you do aten.mul_(a, a).
            So we'll just ban re-inplacing in this case.
        (c) If "a" is used as an input into a view "inverse" / "scatter"
            operator, it is potentially fine to re-inplace
            (and remove that scatter operator from the graph).
            See below for a more detailed example.

        NOTE: there is an optimization in this step that is crucial
        to fully recovering performance from functionalization.

        Given this program:
          def f(x):
              a = torch.ops.aten.add(x, x)
              b = torch.ops.aten.diagonal(a)
              torch.ops.aten.fill_(b, 0)
              return a

        Functionalization will emit the following:
          def f(x):
              a = torch.ops.aten.add(x, x)
              b = torch.ops.aten.diagonal(a, 0, 1)
              b_updated = torch.ops.aten.fill(b, 0)
              a_updated = torch.ops.aten.diagonal_scatter(a, b_updated, 0, 1)
              return a_updated

        Ordinarily, we would not be able to reinplace the fill,
        because "b" aliases with "a", which is used by the diagonal_scatter call.
        The re-inplacing pass is on the hook for figuring out that it is ok to
        completely remove the expensive diagonal_scatter call if we re-inplace the fill().

        So, for every `alias in alias_set(a)`, instead of checking
        that "alias" is not used anywhere later in the graph,
        we check that
          EITHER:
            (a) alias is not used anywhere later in the graph
          OR:
            (b) alias is used exactly once later on in the graph,
                in the following op:

                  out = foo_scatter(alias, x, args...)

                where the following must hold:
                  (i) "foo_scatter" is the "inverse" operator for foo.
                      This only applies to "foo" ops that are view operators,
                      which view into a subset of the original tensor's memory.
                      In practice, there are ~4 operators where this applies:
                        diagonal -> diagonal_scatter
                        slice -> slice_scatter
                        select -> select_scatter
                        as_strided -> as_strided_scatter
                  (ii) "args..." are the same between the foo() and foo_scatter() calls.

    (3) Perform the actual re-inplacing on foo!
        (3b) is the common case, but special care is needed for {view}_scatter ops (3a).

        (3a) {view}_scatter ops.

             Consider this program:
               a = torch.zeros(2, 2)
               b = torch.ones(2)
               a[0] = b

             Post functionalization, that will look like:
               a = torch.zeros(2, 2)
               b = torch.ones(2)
               a_updated = torch.select_scatter(a, b, 0, 0)

             In this case though, there is no "functional" op to re-inplace!
             Instead, we'd like to directly remove the select_scatter call.
             We already know from (2) that this is valid,
             because "a" has no later usages in the graph.

             We perform the re-inplacing on the {view}_scatter op like so:
             Before:
               a_updated = torch.select_scatter(a, b, args...)
             After:
               a_slice = torch.select(a, args...)
               a_slice.copy_(b)

        (3b) Otherwise, replace the functional op with its inplace variant.
             Before:
               b = foo(a, args...)
             After:
               a.foo_(args...)

    (4) Finally, after converting either:
          Before:
            b = foo(a)
          After:
            foo_(a)
        or
          Before:
            b = {slice}_scatter(a, mutated_slice, args...)
          After:
            slice = {slice}(a, args...)
            slice.copy_(mutated_slice)

        we now need to find all later nodes that use "b" as an argument
        and update them to take in "a" instead.

        Note that for the majority of inplace ops, this isn't actually necessary
        (because most inplace ops return "self" as their output).
        This isn't generally true for all mutable ops though, which is why
        we need to actually replace all of the arguments.

        We also need to update our metadata Dict[StorageWeakRef, Set[Node]],
        which maps a given tensor storage to the set of all nodes that take in that storage
        as an input.
        Specifically, re-inplacing `b = foo(a)` causes "a" and "b"'s sets to get fused
        together.

    (5) Any "view_inverse/scatter" nodes that were identified as "it's ok to ignore them"
        during step (2) get manually deleted from the graph.
        Their outputs are no longer used, so technically standard DCE would be able
        to do this, but we can no longer run FX's DCE pass now that we have mutable
        ops in the graph.
    """
    _FunctionalizationMetadataProp(gm).propagate(*sample_args)

    # Useful debug printing
    # def _print(x):
    #     if isinstance(x, FakeTensor):
    #         print(f'fake_result: {StorageWeakRef(x._typed_storage()).cdata}')

    # for n in gm.graph.nodes:
    #     print(n.format_node())
    #     if hasattr(n, 'meta'):
    #         print(f'node_idx: {n.meta["node_idx"]}')
    #         if 'fake_result' in n.meta:
    #             tree_map(_print, n.meta['fake_result'])
    #         if 'view_of' in n.meta:
    #             print(f'view_of: {str(n.meta["view_of"])}')
    #     print()

    # We need to know which nodes correspond to inputs (or their aliases)
    # so we know not to re-inplace them.
    # NOTE: later, we'll need to add an optimization for fully recovering performance
    # on programs that mutate inputs.
    input_storages = {
        StorageWeakRef(
            node.meta['fake_result']._typed_storage()
        ) for node in gm.graph.nodes if node.op == 'placeholder'}

    # We also need to know, for a given node, what all of its aliasing nodes are.
    storage_to_nodes: Dict[StorageWeakRef, Set[Node]] = defaultdict(set)
    for n in gm.graph.nodes:
        if 'fake_result' in n.meta:
            # Tree-mapping because some ops can return lists of tensors.
            def _add_to_map(x):
                if isinstance(x, FakeTensor):
                    storage_to_nodes[StorageWeakRef(x._typed_storage())].add(n)
            pytree.tree_map_(_add_to_map, n.meta['fake_result'])

    # inplace-ify functional ops, subject to the constraints written below.
    all_later_view_inverse_nodes_to_delete = set()
    for idx, node in enumerate(gm.graph.nodes):
        if node.op == 'call_function':

            # Today, the re-inplace pass directly acts on:
            # - functional ops with an inplace variant
            # - {view}_scatter ops that can potentially be removed from the graph.
            # Both of these ops take in tensor first args, so filtering on this condition
            # makes the later code simpler.
            # We should revisit this at some point though, particularly when we also want
            # the reinplacer to be able to handle out= and mutable operators
            # and tensorlist first args (like `_foreach_` ops).
            if not isinstance(node.target, torch._ops.OpOverload):
                continue
            if len(node.target._schema.arguments) < 1:
                continue
            if type(node.target._schema.arguments[0].type) != torch.TensorType:
                continue

            # Step 1a: Check that the self argument we're attempting to reinplace
            # has the same size/stride as the output.
            # For example, we shouldn't try to reinplace torch.add(scalar_tensor, larger_tensor),
            # as it would require resizing scalar_tensor.
            # (We could potentially swizzle this into larger_tensor.add_(scalar_tensor);
            # this is probably an optimization to revisit later.)
            self_arg = node.args[0]
            self_flattened = pytree.tree_leaves(self_arg.meta['fake_result'])
            node_flattened = pytree.tree_leaves(node.meta['fake_result'])
            self_has_wrong_metadata = False
            if len(self_flattened) == len(node_flattened):
                for self_meta, node_meta in zip(self_flattened, node_flattened):
                    if self_meta.numel() != node_meta.numel():
                        self_has_wrong_metadata = True
                    if self_meta.dtype != node_meta.dtype:
                        self_has_wrong_metadata = True
                    # We also cannot re-inplace on tensors that have internal memory overlap.
                    # e.g. torch.ones(1).expand(4, 4).add_(1)
                    if torch._debug_has_internal_overlap(self_meta) == 1:
                        self_has_wrong_metadata = True
            # Here, we (optimistically) assume that a.resize(b) is valid to re-inplace,
            # since users should never really be calling the functional "torch.ops.aten.resize"
            # op directly in their programs.
            if self_has_wrong_metadata and node.target != torch.ops.aten.resize.default:
                continue

            # Step 1b: ensure that the op we're trying to re-inplace isn't a program input
            self_arg_name = self_arg.name
            self_arg_storage = StorageWeakRef(self_arg.meta['fake_result']._typed_storage())
            if self_arg_storage in input_storages:
                # TODO: later, add the optimization for handling `copy_()` calls in the graph.
                continue
            if len([x for x in node.args if x is self_arg]) > 1:
                # Step 1c:
                # Calling stuff like aten.mul_(a, a) isn't guaranteed to be sound,
                # so we prevent re-inplacing in this case.
                continue

            self_arg_storage = StorageWeakRef(self_arg.meta['fake_result']._typed_storage())
            self_aliases = storage_to_nodes[self_arg_storage]

            # First, we find all later usages of any of the aliases of self_arg.
            later_node_usages = _get_all_later_node_usages(self_aliases, node.meta['node_idx'])
            # Then, we check if any of those later usages are actually view_scatter ops
            # that are safe to fully remove.
            later_view_inverse_node_usages = _get_view_inverse_node_usages(later_node_usages, self_aliases)

            # Step 2: Check to see if the input to the op is re-used later in the graph.
            # If not (same goes for its aliases), then this op is safe to re-inplace.
            # This is a slightly roundabout way to check that there are no later usages of the current self argument.
            # (later_view_inverse_node_usages corresponds to "view_scatter" nodes that we are allowed to delete)
            can_reinplace = len(later_node_usages - later_view_inverse_node_usages) == 0
            if not can_reinplace:
                continue

            # Step 3a: Special handling for when we see *_scatter operators.
            # When we see an operator like `b = torch.slice_scatter(a, ...)`,
            # instead of trying to "inplace" it into a.slice_scatter_(...),
            # we would prefer to remove it from the graph entirely,
            # and instead copy_() the slice directly into the larger tensor.
            # See the description of the algorithm for a full example.
            if node.target in _VIEW_INVERSE_MAP and node not in all_later_view_inverse_nodes_to_delete:
                view_op = _VIEW_INVERSE_MAP[node.target]
                # Before:
                #   base_updated = torch.ops.aten.slice_scatter.default(base, mutated_slice, args...)
                # After:
                #   slice = torch.ops.aten.slice.default(base, args...)
                #   slice.copy_(mutated_slice)
                with gm.graph.inserting_before(node):
                    mutated_slice_node = node.args[1]
                    remaining_slice_args = node.args[2:]
                    slice_node = gm.graph.create_node(
                        'call_function', view_op, (self_arg,) + tuple(remaining_slice_args), node.kwargs)
                    copy_node = gm.graph.create_node(
                        'call_function', torch.ops.aten.copy_.default, (slice_node, mutated_slice_node,), {})
                # Add the slice_scatter node to our "nodes to delete" list.
                all_later_view_inverse_nodes_to_delete.add(node)
            else:
                # Step 3b: Check to see if this operator has an inplace variant.
                maybe_inplace_op = _maybe_get_inplace_op(node.target)
                if maybe_inplace_op is None:
                    continue
                # And if so, replace it with its inplace variant.
                node.target = maybe_inplace_op

            # At this point, 'storage_to_nodes' will be stale.
            # Now that we're inplacing `b = foo(a)`, we need to effectively
            # union together the dict values for b and a's storage.
            # Hmm... morally I think we also want to keep the `fake_result` metadata
            # up to date here, but I'm not sure how easy it is to do.
            # Maybe it's fine to wait until the end of the pass to update it.
            curr_node_storage = StorageWeakRef(node.meta['fake_result']._typed_storage())
            storage_to_nodes[self_arg_storage].update(storage_to_nodes[curr_node_storage])
            storage_to_nodes[curr_node_storage].update(storage_to_nodes[self_arg_storage])

            # Need to remember the view_scatter nodes we found so we can remove them later.
            all_later_view_inverse_nodes_to_delete.update(later_view_inverse_node_usages)

            # Step 4:
            # Now that we've replaced b = a.foo() with a.foo_(),
            # we need to replace any later usages of "b" with "a"
            for old in itertools.chain([node], later_view_inverse_node_usages):
                new = old.args[0]
                nodes_to_update = [n for n in old.users if n.meta['node_idx'] > node.meta['node_idx']]
                for node_to_update in nodes_to_update:
                    new_args = []
                    args = node_to_update.args

                    def replace_arg(a):
                        if a == old:
                            return new
                        return a

                    # First, replace usages of "b" with "a"
                    node_to_update.args = tree_map_only(Node, replace_arg, node_to_update.args)
                    node_to_update.kwargs = tree_map_only(Node, replace_arg, node_to_update.kwargs)

                    # Second, update our storage_to_nodes data structure.
                    old_flattened_res = pytree.tree_leaves(old.meta['fake_result'])
                    node_flattened_res = pytree.tree_leaves(node_to_update.meta['fake_result'])

                    old_res_storage = {
                        StorageWeakRef(
                            x._typed_storage()
                        ) for x in old_flattened_res if isinstance(x, FakeTensor)}
                    node_res_storage = {
                        StorageWeakRef(
                            x._typed_storage()
                        ) for x in node_flattened_res if isinstance(x, FakeTensor)}

                    # This will happen if we're updating a view op, e.g. replacing
                    #   x = view(old)
                    # with
                    #   x = view(new)
                    # When that happens, we need to make sure to keep our
                    # storage mapping up to date.
                    #
                    # We're checking for len(...) == 1 here because all view ops are guaranteed to return either a single tensor,
                    # or multiple tensors that all share the same storage.
                    # We can't just check equality because we might encounter FX nodes that return zero tensor outputs.
                    if len(old_res_storage) == 1 and len(node_res_storage) == 1 and old_res_storage == node_res_storage:
                        new_flattened_res = pytree.tree_leaves(new.meta['fake_result'])
                        new_res_storage = {
                            StorageWeakRef(
                                x._typed_storage()
                            ) for x in new_flattened_res if isinstance(x, FakeTensor)}
                        assert len(new_res_storage) == 1
                        (old_ref,) = old_res_storage
                        (new_ref,) = new_res_storage
                        (node_ref,) = node_res_storage
                        # Technically, "old_ref" and all its aliases will remain
                        # in our mapping.
                        # That should be fine though, since we deleted "old"
                        # from the graph at this point.
                        storage_to_nodes[node_ref].update(storage_to_nodes[new_ref])
                        storage_to_nodes[new_ref].update(storage_to_nodes[node_ref])

    # Step 5: delete any _scatter nodes that we de-functionalized.
    # Need to take care not to delete any of these nodes until after *all* modifications
    # to the graph are finished.
    for to_delete in all_later_view_inverse_nodes_to_delete:
        gm.graph.erase_node(to_delete)

    gm.recompile()
    return gm
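

# A rough, illustrative sketch of the {view}_scatter handling on the example from the
# docstring above (the exact traced IR depends on the PyTorch version):
#
#   Functionalized input graph:
#       %a     = aten.add.Tensor(%x, %x)
#       %b     = aten.diagonal.default(%a)
#       %b_new = aten.fill.Scalar(%b, 0)
#       %a_new = aten.diagonal_scatter.default(%a, %b_new)
#       return %a_new
#
#   After reinplace(gm, x), roughly:
#       %a     = aten.add.Tensor(%x, %x)     (not reinplaced: %x is a program input)
#       %b     = aten.diagonal.default(%a)
#       aten.fill_.Scalar(%b, 0)             (reinplaced: the diagonal_scatter was its only blocker)
#       return %a                            (the diagonal_scatter node is deleted entirely)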