import logging
import operator
from typing import Any, Dict, Optional, Set, TYPE_CHECKING

if TYPE_CHECKING:
    from torch.fx.experimental.symbolic_shapes import ShapeEnv
else:
    ShapeEnv = Any

import torch
import torch.utils._pytree as pytree
from torch import fx
from torch.fx._compatibility import compatibility
from torch.fx._utils import lazy_format_graph_code
from torch.fx.experimental.sym_node import SymNode
from torch.fx.graph_module import GraphModule

log = logging.getLogger(__name__)
graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code")


def _get_example_value(node: fx.Node) -> Optional[Any]:
    """
    Get the example value for a node, since dynamo uses "example_value"
    while non-strict export uses "val".
    """
    if "example_value" in node.meta:
        return node.meta["example_value"]
    elif "val" in node.meta:
        return node.meta["val"]
    else:
        return None


@compatibility(is_backward_compatible=True)
def insert_deferred_runtime_asserts(
    gm: GraphModule,
    shape_env: ShapeEnv,
    name: str,
    export: bool = False,
) -> None:
    """
    During tracing, we may have discovered that some data-dependent values
    had runtime asserts on them; e.g., torch.empty(x.item()) induces a
    runtime assert that x.item() >= 0. These asserts can happen
    unpredictably during fake tensor propagation, so we cannot conveniently
    insert them into the FX graph when they occur. Instead, we accumulate
    them in the ShapeEnv, and in this pass insert them into the graph as
    proper tests.
    """
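    # A minimal sketch of the kind of program that triggers this pass
    # (hypothetical example, not part of this module): tracing
    #
    #     def f(x):
    #         u = x.item()
    #         return torch.empty(u)
    #
    # records a deferred runtime assert such as u0 >= 0 in the ShapeEnv, and
    # this pass then materializes it in the graph as a call to
    # torch.ops.aten._assert_scalar.default.

    # Record the sym_constrain_range(_for_size) calls that are already in the
    # graph, so we do not emit duplicate constraints for them below.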
    nodes_that_already_have_sym_constraint_range = set()
    nodes_that_already_have_sym_constraint_size = set()

    for node in gm.graph.nodes:
        if (
            node.op == "call_function"
            and node.target == torch.ops.aten.sym_constrain_range.default
        ):
            assert len(node.args) == 1
            nodes_that_already_have_sym_constraint_range.add(
                (node.args[0], node.kwargs["min"], node.kwargs["max"])
            )
        if (
            node.op == "call_function"
            and node.target == torch.ops.aten.sym_constrain_range_for_size.default
        ):
            assert len(node.args) == 1
            nodes_that_already_have_sym_constraint_size.add(node.args[0])
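    # sympy and the symbolic-shapes helpers are imported lazily inside the
    # pass, presumably to keep module import time down and avoid import
    # cycles.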
    import sympy

    from torch.fx.experimental.symbolic_shapes import (
        CallMethodKey,
        cast_symbool_to_symint_guardless,
        ConvertIntKey,
        DivideByKey,
        free_symbols,
        InnerTensorKey,
    )
    from torch.utils._sympy.interp import sympy_interp
    from torch.utils._sympy.reference import PythonReferenceAnalysis
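    # Work on a copy, since the mapping is mutated (drained) as asserts are
    # inserted; bail out early if there is nothing to insert.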
    ras_by_symbol = shape_env.deferred_runtime_asserts.copy()
    graph = gm.graph

    if not any(ras for ras in ras_by_symbol.values()):
        return

    graph_code_log.debug(
        "%s",
        lazy_format_graph_code(f"pre insert_deferred_runtime_asserts {name}", gm),
    )
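    # Deduplicate the asserts that are not associated with any symbol (keyed
    # on None) by their sympy expression; no deeper simplification across
    # logically redundant guards is attempted here.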
    new_ras = []
    ras_exprs: Set[sympy.Expr] = set()
    for ras in ras_by_symbol.pop(None, []):
        if ras.expr not in ras_exprs:
            new_ras.append(ras)
            ras_exprs.add(ras.expr)
    ras_by_symbol[None] = new_ras
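    # Collect the placeholders, remembering the last one: sym nodes reified
    # from placeholder example values must be inserted after *all*
    # placeholders, since FX keeps placeholders at the front of the graph.
    # If there are no placeholders, fall back to the graph's first node.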
    symbol_to_proxy: Dict[sympy.Symbol, fx.Proxy] = {}
    placeholders = set()
    last_placeholder = None
    for node in graph.nodes:
        if node.op != "placeholder":
            break
        last_placeholder = node
        placeholders.add(node)
    if last_placeholder is None:
        last_placeholder = next(iter(graph.nodes))
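    # Only symbols that actually appear in some deferred assert need to be
    # reified into the graph.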
    needed_symbols: Set[sympy.Symbol] = set()
    for ras in ras_by_symbol.values():
        for ra in ras:
            needed_symbols.update(free_symbols(ra.expr))

    log.debug("needed_symbols = %s", needed_symbols)
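    # Emit a batch of runtime asserts at the current insertion point.  An
    # assert whose free symbols are not all bound to proxies yet is requeued
    # under its (alphabetically) first missing symbol; otherwise the sympy
    # expression is interpreted into FX nodes and guarded with
    # aten._assert_scalar.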
    def add_runtime_asserts(ras):
        for ra in ras:
            log.debug("inserting runtime assert %s", ra.expr)

            fvs = free_symbols(ra.expr)
            missing = fvs - symbol_to_proxy.keys()
            if missing:
                i1 = min(missing, key=str)
                ras_by_symbol.setdefault(i1, []).append(ra)
            else:
                res = sympy_interp(
                    PythonReferenceAnalysis, symbol_to_proxy, ra.expr
                ).node
                graph.call_function(
                    torch.ops.aten._assert_scalar.default,
                    (
                        res,
                        f"Runtime assertion failed for expression {ra.expr} on node '{res}'",
                    ),
                )
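    # Main pass over the graph.  Each node's insertions go right after it,
    # except that insertions for a placeholder go after the *last*
    # placeholder instead, so the placeholder block stays intact.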
    inserted_sym_nodes = 0
    nodes = list(graph.nodes)
    for i, node in enumerate(nodes[:-1]):
        with graph.inserting_before(
            nodes[i + 1] if node not in placeholders else last_placeholder.next
        ):
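            # Bind input symbols: when a placeholder's example value (or its
            # sizes, strides, or storage offset) is a needed sympy symbol
            # without a proxy yet, reify it with sym_size / sym_stride /
            # sym_storage_offset calls and record the resulting proxy.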
            if (
                node in placeholders
                and (example_value := _get_example_value(node)) is not None
            ):

                def match_symbol(symint, cb):
                    if (
                        isinstance(symint, torch.SymInt)
                        and isinstance(symint.node, SymNode)
                        and isinstance(s := symint.node.expr, sympy.Symbol)
                        and s not in symbol_to_proxy
                        and s in needed_symbols
                    ):
                        symbol_to_proxy[s] = fx.Proxy(cb())
                        log.debug("symbol_to_proxy[%s] = %s", s, symbol_to_proxy[s])
                        nonlocal inserted_sym_nodes
                        inserted_sym_nodes += 1

                match_symbol(example_value, lambda: node)
                if isinstance(t := example_value, torch.Tensor):
                    for i, s in enumerate(t.size()):
                        match_symbol(
                            s,
                            lambda: graph.call_function(
                                torch.ops.aten.sym_size.int, (node, i)
                            ),
                        )
                    for i, s in enumerate(t.stride()):
                        match_symbol(
                            s,
                            lambda: graph.call_function(
                                torch.ops.aten.sym_stride.int, (node, i)
                            ),
                        )
                    match_symbol(
                        t.storage_offset(),
                        lambda: graph.call_function(
                            torch.ops.aten.sym_storage_offset.default, (node,)
                        ),
                    )
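            # Once past the placeholders, flush the asserts that are not tied
            # to any symbol, right after the last placeholder and any sym
            # nodes inserted above.  The pop drains the None entry, so later
            # iterations find nothing to do here.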
            if node not in placeholders:
                last_sym_node = last_placeholder
                for _ in range(inserted_sym_nodes):
                    last_sym_node = last_sym_node.next
                with graph.inserting_before(last_sym_node.next):
                    add_runtime_asserts(ras_by_symbol.pop(None, []))
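            # Unbacked symbols (e.g. from .item() calls) defined by this node
            # come with keypaths describing how to extract their runtime
            # values from the node's output; collect the definitions so their
            # asserts can be emitted below.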
            defs = []

            if unbacked_bindings := node.meta.get("unbacked_bindings"):
                for s, keypath in unbacked_bindings.items():
                    defs.append(s)
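                    # Walk a pytree keypath from the node's output down to the
                    # value that defines the symbol: size/stride indices turn
                    # into sym_size/stride calls, sequence keys into getitem,
                    # SymBool results are cast to SymInt, and DivideByKey
                    # floor-divides by a constant (divisibility is assumed
                    # rather than asserted here).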
                    def go(node, keypath):
                        if keypath == ():
                            return node
                        if (
                            len(keypath) >= 2
                            and isinstance(keypath[0], CallMethodKey)
                            and isinstance(keypath[1], pytree.SequenceKey)
                        ):
                            if keypath[0].name == "size":
                                return go(
                                    graph.call_function(
                                        torch.ops.aten.sym_size.int,
                                        (node, keypath[1].idx),
                                    ),
                                    keypath[2:],
                                )
                            if keypath[0].name == "stride":
                                return go(
                                    graph.call_function(
                                        torch.ops.aten.stride.int,
                                        (node, keypath[1].idx),
                                    ),
                                    keypath[2:],
                                )
                            return go(
                                graph.call_method(
                                    keypath[0].name, (node, keypath[1].idx)
                                ),
                                keypath[2:],
                            )
                        elif isinstance(keypath[0], CallMethodKey):
                            return go(
                                graph.call_method(keypath[0].name, (node,)), keypath[1:]
                            )
                        elif isinstance(keypath[0], pytree.SequenceKey):
                            return go(
                                graph.call_function(
                                    operator.getitem, (node, keypath[0].idx)
                                ),
                                keypath[1:],
                            )
                        elif isinstance(keypath[0], ConvertIntKey):
                            return go(
                                graph.call_function(
                                    cast_symbool_to_symint_guardless, (node,)
                                ),
                                keypath[1:],
                            )
                        elif isinstance(keypath[0], DivideByKey):
                            return go(
                                graph.call_function(
                                    operator.floordiv, (node, keypath[0].divisor)
                                ),
                                keypath[1:],
                            )
                        elif isinstance(keypath[0], InnerTensorKey):
                            return go(
                                graph.call_function(
                                    getattr, (node, keypath[0].inner_name)
                                ),
                                keypath[1:],
                            )
                        else:
                            raise AssertionError(f"unrecognized keypath {keypath}")
                    symbol_to_proxy[s] = fx.Proxy(go(node, keypath))
                    log.debug("symbol_to_proxy[%s] = %s", s, symbol_to_proxy[s])
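            # Emit constraints and asserts for each unbacked symbol this node
            # defines.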
            for i0 in defs:
                ras = ras_by_symbol.pop(i0, [])
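                # Apply range refinement before the asserts themselves,
                # presumably so that a retrace of this graph rediscovers the
                # refined ranges before evaluating any assert.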
                if i0 in shape_env.size_like:
                    if export:
                        if (
                            symbol_to_proxy[i0].node
                            not in nodes_that_already_have_sym_constraint_size
                        ):
                            graph.call_function(
                                torch.ops.aten.sym_constrain_range_for_size.default,
                                (symbol_to_proxy[i0].node,),
                            )
                    else:
                        graph.call_function(
                            torch._check_is_size, (symbol_to_proxy[i0].node,)
                        )
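                # If tracing refined this symbol's range beyond the default
                # unspecified range, reapply that refinement at runtime,
                # dropping bounds that are not concrete ints (e.g. -oo/oo)
                # and skipping constraints already present in the graph.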
                vr = shape_env.var_to_range[i0]
                if not shape_env._default_unspecified_value_range().issubset(vr):

                    def convert(s):
                        try:
                            return int(s)
                        except TypeError:
                            return None

                    min_val = convert(vr.lower)
                    max_val = convert(vr.upper)

                    if (
                        symbol_to_proxy[i0].node,
                        min_val,
                        max_val,
                    ) not in nodes_that_already_have_sym_constraint_range:
                        graph.call_function(
                            torch.ops.aten.sym_constrain_range.default,
                            (symbol_to_proxy[i0].node,),
                            {"min": min_val, "max": max_val},
                        )
                add_runtime_asserts(ras)