File size: 17,303 Bytes
d1ceb73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
# mypy: allow-untyped-defs
import logging
import operator
from typing import Any, Dict, Optional, Set, TYPE_CHECKING

# Import sympy and ShapeEnv during TYPE_CHECKING since importing sympy is slow
if TYPE_CHECKING:
    from torch.fx.experimental.symbolic_shapes import ShapeEnv
else:
    ShapeEnv = Any

import torch
import torch.utils._pytree as pytree
from torch import fx
from torch.fx._compatibility import compatibility
from torch.fx._utils import lazy_format_graph_code
from torch.fx.experimental.sym_node import SymNode
from torch.fx.graph_module import GraphModule

# Module-level logger; the pass reports the symbols it needs to reify and
# each runtime assert it inserts at DEBUG level.
log = logging.getLogger(__name__)
# Artifact logger for the "graph_code" artifact; the pass dumps the graph
# code here (lazily formatted) before inserting any asserts.
graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code")


def _get_example_value(node: fx.Node) -> Optional[str]:
    """
    Get the example value key for a node, since dynamo uses "example_value"
    while non-strict export uses "val.
    """
    if "example_value" in node.meta:
        return node.meta["example_value"]
    elif "val" in node.meta:
        return node.meta["val"]
    else:
        return None


@compatibility(is_backward_compatible=True)
def insert_deferred_runtime_asserts(
    gm: GraphModule,
    shape_env: ShapeEnv,
    name: str,
    export: bool = False,
) -> None:
    """
    Insert the runtime asserts deferred in ``shape_env`` into ``gm.graph``.

    During tracing, we may have discovered that some data-dependent values
    had runtime asserts on them; e.g., ``torch.empty(x.item())`` induces a
    runtime assert that ``x.item() >= 0``.  These asserts can happen
    unpredictably during fake tensor propagation, so we cannot conveniently
    insert them into the FX graph when they occur.  Instead, we accumulate
    them in the ShapeEnv, and in this pass insert them into the graph as
    proper tests.

    Args:
        gm: Graph module whose graph is mutated in place.
        shape_env: Source of ``deferred_runtime_asserts``, ``var_to_range``
            and ``size_like``.  Its assert lists are not modified.
        name: Label used only in the debug graph dump.
        export: When True, size-like symbols get
            ``aten.sym_constrain_range_for_size`` nodes instead of
            ``torch._check_is_size`` calls.
    """

    # Constraint calls already present in the graph, so we do not emit
    # duplicates.  Range constraints are keyed by (node, min_val, max_val).
    nodes_that_already_have_sym_constraint_range = set()

    # Size constraints are keyed by the node alone because
    # sym_constrain_range_for_size takes no min/max.
    nodes_that_already_have_sym_constraint_size = set()
    # TODO this only works for top-level nodes today, also
    # we should potentially use it not create duplicate
    # assert_async nodes
    for node in gm.graph.nodes:
        if (
            node.op == "call_function"
            and node.target == torch.ops.aten.sym_constrain_range.default
        ):
            assert len(node.args) == 1
            nodes_that_already_have_sym_constraint_range.add(
                (node.args[0], node.kwargs["min"], node.kwargs["max"])
            )
        if (
            node.op == "call_function"
            and node.target == torch.ops.aten.sym_constrain_range_for_size.default
        ):
            assert len(node.args) == 1
            nodes_that_already_have_sym_constraint_size.add(node.args[0])

    # Import sympy locally
    import sympy

    from torch.fx.experimental.symbolic_shapes import (
        CallMethodKey,
        cast_symbool_to_symint_guardless,
        ConvertIntKey,
        DivideByKey,
        free_symbols,
        InnerTensorKey,
    )
    from torch.utils._sympy.interp import sympy_interp
    from torch.utils._sympy.reference import PythonReferenceAnalysis

    # TODO: Request simplification on runtime asserts before emitting them
    # We mutate both this dict and the lists inside it (asserts whose symbols
    # are not yet available get re-associated with other symbols below), so
    # copy every list too: a shallow dict copy would leak those appends back
    # into shape_env.deferred_runtime_asserts.
    ras_by_symbol = {
        sym: list(ras) for sym, ras in shape_env.deferred_runtime_asserts.items()
    }
    graph = gm.graph

    if not any(ras_by_symbol.values()):
        return

    graph_code_log.debug(
        "%s",
        lazy_format_graph_code(f"pre insert_deferred_runtime_asserts {name}", gm),
    )

    # deduplicate unassociated runtime assertions
    # we could do better, some guards might be redundant,
    # e.g. Eq(s0, 4) & Eq(2*s0, 8)
    # but unclear how to handle all of that right now.
    # TODO(pianpwk): better way of doing this
    new_ras = []
    ras_exprs: Set[sympy.Expr] = set()
    for ras in ras_by_symbol.pop(None, []):  # type: ignore[call-overload]
        if ras.expr not in ras_exprs:
            new_ras.append(ras)
            ras_exprs.add(ras.expr)
    ras_by_symbol[None] = new_ras  # type: ignore[index]

    # Maps each reified symbol to a proxy for the FX node that computes it.
    symbol_to_proxy: Dict[sympy.Symbol, fx.Proxy] = {}
    placeholders = set()
    last_placeholder = None
    for node in graph.nodes:
        if node.op != "placeholder":
            break
        last_placeholder = node
        placeholders.add(node)
    if last_placeholder is None:  # no placeholders: anchor on the first node
        last_placeholder = next(iter(graph.nodes))

    # Identify what symbols we need to reify.  This isn't strictly needed
    # but helps reduce churn on the graph
    needed_symbols: Set[sympy.Symbol] = set()
    for ras in ras_by_symbol.values():
        for ra in ras:
            needed_symbols.update(free_symbols(ra.expr))

    log.debug("needed_symbols = %s", needed_symbols)

    def add_runtime_asserts(ras):
        """Emit an ``_assert_scalar`` node for every assert in ``ras`` whose
        free symbols are all reified; re-queue the others under their
        smallest (by name) missing symbol so they are emitted later."""
        for ra in ras:
            log.debug("inserting runtime assert %s", ra.expr)
            # Need to process ALL free symbols, not just unbacked ones
            fvs = free_symbols(ra.expr)
            missing = fvs - symbol_to_proxy.keys()
            if missing:
                i1 = min(missing, key=str)
                # TODO: Remove relaxing assert on unbacked_symint https://github.com/pytorch/pytorch/issues/119689
                # assert shape_env.is_unbacked_symint(i1), i1
                ras_by_symbol.setdefault(i1, []).append(ra)
            else:
                # Convert the sympy expression into a sequence of FX
                # nodes
                res = sympy_interp(
                    PythonReferenceAnalysis, symbol_to_proxy, ra.expr
                ).node
                graph.call_function(
                    torch.ops.aten._assert_scalar.default,
                    # TODO: use ra.msg here, but it's pretty
                    # useless right now
                    (
                        res,
                        f"Runtime assertion failed for expression {ra.expr} on node '{res}'",
                    ),
                )

    # TODO: some CSE when generating these nodes can probably
    # help reduce graph size and improve compile time
    def go(node, keypath):
        """Emit FX nodes that extract the value at ``keypath`` (taken from a
        node's ``unbacked_bindings``) out of ``node``; return the last one."""
        if keypath == ():
            return node
        if (
            len(keypath) >= 2
            and isinstance(keypath[0], CallMethodKey)
            and isinstance(keypath[1], pytree.SequenceKey)
        ):
            if keypath[0].name == "size":
                return go(
                    graph.call_function(
                        torch.ops.aten.sym_size.int,
                        (node, keypath[1].idx),
                    ),
                    keypath[2:],
                )
            if keypath[0].name == "stride":
                # Use sym_stride, as for placeholders below; aten.stride.int's
                # schema returns a plain int rather than a SymInt.
                return go(
                    graph.call_function(
                        torch.ops.aten.sym_stride.int,
                        (node, keypath[1].idx),
                    ),
                    keypath[2:],
                )
            return go(
                graph.call_method(keypath[0].name, (node, keypath[1].idx)),
                keypath[2:],
            )
        elif isinstance(keypath[0], CallMethodKey):
            return go(graph.call_method(keypath[0].name, (node,)), keypath[1:])
        elif isinstance(keypath[0], pytree.SequenceKey):
            return go(
                graph.call_function(operator.getitem, (node, keypath[0].idx)),
                keypath[1:],
            )
        elif isinstance(keypath[0], ConvertIntKey):
            return go(
                graph.call_function(cast_symbool_to_symint_guardless, (node,)),
                keypath[1:],
            )
        elif isinstance(keypath[0], DivideByKey):
            # TODO: need to assert divisibility
            return go(
                graph.call_function(operator.floordiv, (node, keypath[0].divisor)),
                keypath[1:],
            )
        elif isinstance(keypath[0], InnerTensorKey):
            return go(
                graph.call_function(getattr, (node, keypath[0].inner_name)),
                keypath[1:],
            )
        else:
            raise AssertionError(f"unrecognized keypath {keypath}")

    # Number of nodes actually inserted right after the last placeholder;
    # used to place the unassociated runtime asserts after them.
    inserted_sym_nodes = 0
    nodes = list(graph.nodes)
    for i, node in enumerate(nodes[:-1]):
        # Placeholders can match symbols, but when we destructure them
        # with size we have to make sure we insert the nodes after all
        # the placeholders
        with graph.inserting_before(
            nodes[i + 1] if node not in placeholders else last_placeholder.next
        ):
            # Unfortunately, this logic still must remain because manual
            # make_fx calls may not explicitly bind all symbolic ints as
            # arguments to the function, so we must infer it from the other
            # arguments
            if (
                node in placeholders
                and (example_value := _get_example_value(node)) is not None
            ):

                def match_symbol(symint, cb):
                    nonlocal inserted_sym_nodes
                    if (
                        isinstance(symint, torch.SymInt)
                        and isinstance(symint.node, SymNode)
                        and isinstance(s := symint.node.expr, sympy.Symbol)
                        and s not in symbol_to_proxy
                        and s in needed_symbols
                    ):
                        proxy_node = cb()
                        symbol_to_proxy[s] = fx.Proxy(proxy_node)
                        log.debug("symbol_to_proxy[%s] = %s", s, symbol_to_proxy[s])
                        # A SymInt placeholder is its own proxy and inserts
                        # no node, so it must not advance the counter.
                        if proxy_node is not node:
                            inserted_sym_nodes += 1

                match_symbol(example_value, lambda: node)
                if isinstance(t := example_value, torch.Tensor):
                    # ``dim`` (not ``i``) avoids clobbering the outer loop
                    # index; the default argument binds it eagerly.
                    for dim, s in enumerate(t.size()):
                        match_symbol(
                            s,
                            lambda dim=dim: graph.call_function(
                                torch.ops.aten.sym_size.int, (node, dim)
                            ),
                        )
                    for dim, s in enumerate(t.stride()):
                        match_symbol(
                            s,
                            lambda dim=dim: graph.call_function(
                                torch.ops.aten.sym_stride.int, (node, dim)
                            ),
                        )
                    match_symbol(
                        t.storage_offset(),
                        lambda: graph.call_function(
                            torch.ops.aten.sym_storage_offset.default, (node,)
                        ),
                    )

            # Handle asserts that aren't associated with any symbol.  This
            # only has work to do on the first non-placeholder node (the None
            # entry is popped then and never re-added); it inserts after the
            # placeholders & added sym nodes, and before non-placeholders.
            if node not in placeholders and None in ras_by_symbol:
                last_sym_node = last_placeholder
                for _ in range(inserted_sym_nodes):
                    last_sym_node = last_sym_node.next
                with graph.inserting_before(last_sym_node.next):
                    add_runtime_asserts(ras_by_symbol.pop(None, []))  # type: ignore[call-overload]

            defs = []

            if unbacked_bindings := node.meta.get("unbacked_bindings"):
                for s, keypath in unbacked_bindings.items():
                    defs.append(s)
                    symbol_to_proxy[s] = fx.Proxy(go(node, keypath))
                    log.debug("symbol_to_proxy[%s] = %s", s, symbol_to_proxy[s])

            for i0 in defs:
                ras = ras_by_symbol.pop(i0, [])
                # Before we perform any asserts, first apply range
                # refinement.  This is important, because if we are going
                # to retrace the graph (and we typically are if we send
                # the graph to AOTAutograd), we need to make sure we apply
                # range refinement (ala _check_is_size) first, BEFORE we
                # run any of the asserts.  Otherwise, we may decide to
                # perform substitutions based on the asserts which we then
                # can't back out, because value ranges can only be applied
                # to asserts.
                #
                # A perhaps better long term plan is to avoid this order
                # dependence by making it possible to refine ranges on
                # arbitrary expressions, not just symbols.  But it is not
                # so easy to make use of this information, see
                # https://twitter.com/ezyang/status/1745801370299482492
                # We actually made an attempt at this in
                # https://github.com/pytorch/pytorch/pull/119043
                # which didn't work.
                #
                # Other ideas for how to do this:
                # - Have bound_sympy be the source of truth of the ranges of any expression
                # - Cache intermediate results for every subexpression of bound_sympy
                # - This cache should be possible to edit to refine ranges
                #
                # One issue with this proposal is that if
                # we have a bound on 2x, we are not going to be able to
                # apply it for 4x.  Similarly, we may have bounds for an
                # equivalent expression that we are not applying because
                # it's not a perfect match (e.g. x < y vs y > x).
                #
                # We already have the first issue and it's impossible
                # to solve in general, so any implementation on a best
                # effort basis should do.
                #
                # The second issue is a preexisting one. It can be mitigated
                # with a normalisation algorithm. In general, it may also
                # be on a best effort basis, but since our grammar is not
                # terribly difficult, chances are we could even fully
                # normalise SymPy expressions... who knows.

                if i0 in shape_env.size_like:
                    if export:
                        if (
                            symbol_to_proxy[i0].node
                            not in nodes_that_already_have_sym_constraint_size
                        ):
                            graph.call_function(
                                torch.ops.aten.sym_constrain_range_for_size.default,
                                (symbol_to_proxy[i0].node,),
                            )
                    else:
                        graph.call_function(
                            torch._check_is_size, (symbol_to_proxy[i0].node,)
                        )

                vr = shape_env.var_to_range[i0]
                if not shape_env._default_unspecified_value_range().issubset(vr):
                    # The runtime range is constrained, so add a runtime
                    # assert and also explicitly refine the range
                    # (refinement should not be necessary once runtime
                    # asserts cause refinement, but that's NYI)
                    def convert(s):
                        # Non-finite bounds (e.g. sympy infinities) raise
                        # TypeError in int() and become None (unbounded).
                        try:
                            return int(s)
                        except TypeError:
                            return None

                    min_val = convert(vr.lower)
                    max_val = convert(vr.upper)

                    if (
                        symbol_to_proxy[i0].node,
                        min_val,
                        max_val,
                    ) not in nodes_that_already_have_sym_constraint_range:
                        graph.call_function(
                            torch.ops.aten.sym_constrain_range.default,
                            (symbol_to_proxy[i0].node,),
                            {
                                "min": min_val,
                                "max": max_val,
                            },
                        )

                add_runtime_asserts(ras)

    # If every node before the last one was a placeholder, the loop above
    # never reached a non-placeholder node and the unassociated asserts were
    # never emitted.  Emit them before the last node (normally the output)
    # instead of silently dropping them.
    if leftover := ras_by_symbol.pop(None, []):  # type: ignore[call-overload]
        with graph.inserting_before(nodes[-1]):
            add_runtime_asserts(leftover)