|
|
|
""" |
|
This file does three things: |
|
- Contains the definition of SymNode |
|
- Installs all the magic methods into SymBool, SymInt, SymFloat at import time
|
- Does not depend on sympy at import time |
|
|
|
As this file is imported from within torch/__init__.py we do not want it to depend on SymPy |
|
to avoid having to load SymPy at import time, as doing so is *very* slow. |
|
""" |
|
|
|
import builtins |
|
import itertools |
|
import logging |
|
import math |
|
import operator |
|
import sys |
|
from functools import lru_cache, update_wrapper |
|
from typing import Optional, Type, TYPE_CHECKING, Union |
|
|
|
import torch |
|
|
|
|
|
from torch import ( |
|
sym_float, |
|
sym_ite, |
|
sym_max, |
|
sym_min, |
|
sym_not, |
|
SymBool, |
|
SymFloat, |
|
SymInt, |
|
) |
|
|
|
from torch.fx.experimental._sym_dispatch_mode import ( |
|
handle_sym_dispatch, |
|
sym_function_mode, |
|
) |
|
|
|
if TYPE_CHECKING: |
|
from torch.fx.experimental.symbolic_shapes import ShapeEnv |
|
|
|
log = logging.getLogger(__name__) |
|
sym_node_log = torch._logging.getArtifactLogger(__name__, "sym_node") |
|
|
|
|
|
__all__ = ["SymNode", "method_to_operator", "magic_methods"] |
|
|
|
|
|
SymTypes = (SymInt, SymFloat, SymBool) |
|
|
|
|
|
def _to_symtype(t): |
|
if t is bool: |
|
return SymBool |
|
if t is int: |
|
return SymInt |
|
if t is float: |
|
return SymFloat |
|
return t |
|
|
|
|
|
|
|
|
|
|
|
class SymNode:
    """
    This is a type erased SymInt/SymFloat which we use to do actual operations.
    End users don't touch this.  Magic methods are NOT defined on this object.

    The ``_<method>`` implementations referenced by the delegator methods
    below are installed on this class at import time by ``_make_node_magic``
    and ``_make_node_sizes_strides`` (defined later in this file).
    """

    def __init__(
        self,
        expr,
        shape_env,
        pytype,
        hint: Optional[Union[int, float, bool]],
        constant=None,
        fx_node=None,
    ):
        # expr: the symbolic (sympy) expression this node represents.
        # shape_env: owner environment used for replacements/guards.
        # pytype: the erased Python type (int, float or bool).
        self._expr = expr
        self.shape_env = shape_env
        self.pytype = pytype
        # A hint is a concrete value this expression is known to take; it
        # must be type-compatible with pytype (or its Sym* counterpart).
        if hint is not None:
            assert type(hint) is pytype or type(hint) is _to_symtype(pytype), (
                "Cannot create SymNode of type "
                f"{pytype} with incompatible hint of type {type(hint)}"
            )
        self._hint = hint
        self.constant: Optional[Union[int, float, bool]] = constant
        # Only keep the fx node when translation validation is enabled;
        # otherwise drop it so no validation graph is built.
        self.fx_node = (
            fx_node if self.shape_env._translation_validation_enabled else None
        )

    def with_shape_env(self, shape_env: "ShapeEnv") -> "SymNode":
        # Rebind this node (same expr/hint/constant/fx_node) to another env.
        return SymNode(
            self._expr, shape_env, self.pytype, self._hint, self.constant, self.fx_node
        )

    @property
    def expr(self):
        # Report the expression with the env's current replacements applied.
        return self.shape_env.replace(self._expr)

    def _update_hint(self):
        # Try to recover a missing hint by statically evaluating the expr.
        r = self.shape_env._maybe_evaluate_static(self.expr, compute_hint=True)
        if r is not None:
            self._hint = self.pytype(r) if not isinstance(r, SymTypes) else r

    @property
    def hint(self):
        if self._hint is None:
            self._update_hint()
        return self._hint

    def has_hint(self):
        if self._hint is None:
            self._update_hint()
        return self._hint is not None

    def require_hint(self, fallback=None):
        # Return the hint; on failure return `fallback` if given, otherwise
        # defer to the shape env's size_hint (which may raise).
        if self._hint is None:
            self._update_hint()
        if self._hint is None:
            if fallback is not None:
                return fallback

            return self.shape_env.size_hint(self.expr)
        return self._hint

    def maybe_as_int(self):
        # Concrete int when the expression is a literal number, else None.
        if self.expr.is_number:
            return int(self.expr)
        else:
            return None

    def maybe_as_float(self):
        import sympy

        if isinstance(self.expr, sympy.Float):
            return float(self.expr)
        else:
            return None

    def maybe_as_bool(self):
        import sympy

        if self.expr is sympy.true:
            return True
        elif self.expr is sympy.false:
            return False
        else:
            return None

    def is_int(self):
        return self.pytype is int

    def is_float(self):
        return self.pytype is float

    def is_bool(self):
        return self.pytype is bool

    def is_nested_int(self):
        # True only when the hint itself is a symbolic nested int.
        return (
            self._hint is not None
            and isinstance(self._hint, SymInt)
            and self._hint.node.is_nested_int()
        )

    def wrap_int(self, num):
        # Wrap a constant int into a SymNode sharing this node's shape env.
        assert type(num) is int
        import sympy

        return SymNode(
            sympy.Integer(num), self.shape_env, int, num, constant=num, fx_node=num
        )

    def wrap_float(self, num):
        assert type(num) is float
        import sympy

        return SymNode(
            sympy.Float(num), self.shape_env, float, num, constant=num, fx_node=num
        )

    def wrap_bool(self, num):
        assert type(num) is bool
        import sympy

        return SymNode(
            sympy.true if num else sympy.false,
            self.shape_env,
            bool,
            num,
            constant=num,
            fx_node=num,
        )

    def clone(self):
        # SymNodes are treated as immutable, so clone is the identity.
        return self

    def str(self):
        return f"{self.expr}"

    def __str__(self):
        return self.str()

    def __repr__(self):
        return self.str()

    # These delegate to the metaprogrammed `_<method>` implementations; they
    # are written out by hand so stack traces stay readable.
    def abs(self) -> "SymNode":
        return self._abs()

    def pos(self) -> "SymNode":
        return self._pos()

    def round(self, ndigits=None) -> "SymNode":
        return self._round(ndigits)

    def trunc(self) -> "SymNode":
        return self._trunc()

    def add(self, other) -> "SymNode":
        return self._add(other)

    def sub(self, other) -> "SymNode":
        return self._sub(other)

    def mul(self, other) -> "SymNode":
        return self._mul(other)

    def mod(self, other) -> "SymNode":
        return self._mod(other)

    def float_pow(self, other) -> "SymNode":
        return self._float_pow(other)

    def pow_by_natural(self, other) -> "SymNode":
        return self._pow_by_natural(other)

    def and_(self, other) -> "SymNode":
        return self._and_(other)

    def or_(self, other) -> "SymNode":
        return self._or_(other)

    def float_truediv(self, other) -> "SymNode":
        return self._float_truediv(other)

    def int_truediv(self, other) -> "SymNode":
        return self._int_truediv(other)

    def int_floordiv(self, other) -> "SymNode":
        return self._int_floordiv(other)

    def lshift(self, other) -> "SymNode":
        return self._lshift(other)

    def rshift(self, other) -> "SymNode":
        return self._rshift(other)

    def sym_not(self) -> "SymNode":
        return self._sym_not()

    def eq(self, other) -> "SymNode":
        return self._eq(other)

    def ne(self, other) -> "SymNode":
        return self._ne(other)

    def gt(self, other) -> "SymNode":
        return self._gt(other)

    def lt(self, other) -> "SymNode":
        return self._lt(other)

    def le(self, other) -> "SymNode":
        return self._le(other)

    def ge(self, other) -> "SymNode":
        return self._ge(other)

    def floor(self) -> "SymNode":
        return self._floor()

    def is_integer(self) -> "SymNode":
        return self._is_integer()

    def sym_float(self) -> "SymNode":
        return self._sym_float()

    def sym_int(self) -> "SymNode":
        return self._sym_int()

    def ceil(self) -> "SymNode":
        return self._ceil()

    def neg(self) -> "SymNode":
        return self._neg()

    def sym_min(self, other) -> "SymNode":
        return self._sym_min(other)

    def sym_max(self, other) -> "SymNode":
        return self._sym_max(other)

    def sym_ite(self, then_val, else_val) -> "SymNode":
        return self._sym_ite(then_val, else_val)

    def is_contiguous(self, sizes, strides) -> "SymNode":
        return self._is_contiguous(sizes, strides)

    def is_channels_last_contiguous_2d(self, sizes, strides) -> "SymNode":
        return self._is_channels_last_contiguous_2d(sizes, strides)

    def is_channels_last_contiguous_3d(self, sizes, strides) -> "SymNode":
        return self._is_channels_last_contiguous_3d(sizes, strides)

    def is_channels_last_strides_2d(self, sizes, strides) -> "SymNode":
        return self._is_channels_last_strides_2d(sizes, strides)

    def is_channels_last_strides_3d(self, sizes, strides) -> "SymNode":
        return self._is_channels_last_strides_3d(sizes, strides)

    def is_non_overlapping_and_dense_indicator(self, sizes, strides) -> "SymNode":
        return self._is_non_overlapping_and_dense_indicator(sizes, strides)

    # Aliases built on top of the primitives above.
    def sym_or(self, other):
        return self.or_(other)

    def sym_and(self, other):
        return self.and_(other)

    def truediv(self, other):
        return self.float_truediv(other)

    def floordiv(self, other) -> "SymNode":
        return self.int_floordiv(other)

    def pow(self, other):
        return self.float_pow(other)

    def is_non_overlapping_and_dense(self, sizes, strides):
        # The indicator returns 1 for "yes"; compare against a wrapped 1.
        return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq(to_node(self, 1))

    def int_(self):
        return self.guard_int("", 0)

    def guard_int(self, file, line):
        # Install a guard forcing this expression to its concrete int value
        # and return that value.  `file`/`line` identify the guard site.
        r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
        try:
            return int(r)
        except Exception:
            log.warning("Failed to convert to int: %s", r)
            raise

    def guard_float(self, file, line):
        # Same as guard_int, but the result may be an irrational float.
        r = self.shape_env.evaluate_expr(
            self.expr, self.hint, fx_node=self.fx_node, expect_rational=False
        )
        try:
            return float(r)
        except Exception:
            log.warning("Failed to convert to float: %s", r)
            raise

    def guard_bool(self, file, line):
        # Install a guard forcing this boolean expression and return it.
        r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
        try:
            return bool(r)
        except Exception:
            log.warning("Failed to convert to bool: %s", r)
            raise

    def expect_true(self, file, line):
        from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols

        # With a hint and no unbacked symbols we can guard eagerly;
        # otherwise fall back to a deferred runtime assert.
        if (
            self.has_hint()
            and not free_unbacked_symbols(self.expr)
            and not self.shape_env.prefer_deferred_runtime_asserts_over_guards
        ):
            return self.guard_bool(file, line)
        return self.shape_env.defer_runtime_assert(
            self.expr, f"{file}:{line}", fx_node=self.fx_node
        )

    def expect_size(self, file, line):
        from torch.fx.experimental.symbolic_shapes import _advise_is_size

        # A size must be >= 0; assert that, and mark hint-less (unbacked)
        # values as size-like so size-oblivious reasoning can use them.
        b = self.ge(self.wrap_int(0))
        r = b.expect_true(file, line)
        if r and not self.has_hint():
            _advise_is_size(SymInt(self))
        return r

    def guard_size_oblivious(self, file, line):
        """
        Like guard_bool, but if we encounter unbacked symbols, if those symbols
        are size-like, we will treat them as >= 2 for the purposes of the analysis.

        This CHANGES the runtime semantics, but all size-oblivious sites have been
        audited to ensure that the runtime semantics don't change in a material way.
        Acceptable runtime semantic changes are, e.g., squeeze() no longer dropping
        an unbacked one size, or a tensor reporting as non-contiguous even if it's
        contiguous if it would have been reported contiguous due to being empty.
        """
        r = self.shape_env.evaluate_expr(
            self.expr, self.hint, fx_node=self.fx_node, size_oblivious=True
        )
        try:
            return bool(r)
        except Exception:
            log.warning("Failed to convert to bool: %s", r)
            raise

    def bool_(self):
        return self.guard_bool("", 0)

    def is_symbolic(self):
        return True

    def nested_int(self):
        # Plain SymNodes carry no nested int; subclasses may override.
        return None

    def is_constant(self):
        return False
|
|
|
|
|
|
|
# Maps a SymNode method name to the plain Python operator/function that
# implements it on concrete values (used for computing hints and for sym
# dispatch).  Note several names map to the same operator (e.g. both
# truediv variants map to operator.truediv).
METHOD_TO_OPERATOR = {
    "pos": operator.pos,
    "abs": operator.abs,
    "add": operator.add,
    "and": operator.and_,
    "ceil": math.ceil,
    "eq": operator.eq,
    "floor": math.floor,
    "trunc": math.trunc,
    "int_floordiv": operator.floordiv,
    "ge": operator.ge,
    "gt": operator.gt,
    "is_integer": lambda x: x.is_integer(),
    "le": operator.le,
    "lshift": operator.lshift,
    "lt": operator.lt,
    "mod": operator.mod,
    "mul": operator.mul,
    "ne": operator.ne,
    "neg": operator.neg,
    "or": operator.or_,
    "float_pow": operator.pow,
    "pow_by_natural": operator.pow,
    "round": builtins.round,
    "rshift": operator.rshift,
    "sub": operator.sub,
    "sym_float": sym_float,
    "sym_ite": sym_ite,
    "sym_max": sym_max,
    "sym_min": sym_min,
    "sym_not": sym_not,
    "float_truediv": operator.truediv,
    "int_truediv": operator.truediv,
}
|
|
|
# Unary methods installed as dunder methods (__abs__, __neg__, ...) on the
# user-facing Sym* types; the sym_<math> ops are added to this set below.
unary_magic_methods = {
    "abs",
    "sym_float",
    "sym_int",
    "ceil",
    "floor",
    "neg",
    "sym_not",
    "pos",
    "trunc",
}
|
|
|
|
|
|
|
def _get_sym_node_fn(name): |
|
def fn(self): |
|
return getattr(self, f"_sym_{name}")() |
|
|
|
return fn |
|
|
|
|
|
# Transcendental math ops that get symbolic support.  For each one we:
# install a `sym_<op>` delegator on SymNode, map it to torch's private
# `_sym_<op>` operator, register it as a unary magic method, and export it
# from this module.
math_op_names = (
    "sqrt",
    "cos",
    "cosh",
    "sin",
    "sinh",
    "tan",
    "tanh",
    "asin",
    "acos",
    "atan",
)
for name in math_op_names:
    sym_name = f"sym_{name}"
    priv_sym_name = f"_{sym_name}"
    setattr(SymNode, sym_name, _get_sym_node_fn(name))
    METHOD_TO_OPERATOR[sym_name] = getattr(torch, priv_sym_name)
    unary_magic_methods.add(sym_name)
    __all__.append(sym_name)
|
|
|
|
|
|
|
# Unary methods exposed under their plain name (not a dunder) on the
# user-facing Sym* types.
unary_nonmagic_methods = {
    "is_integer",
}

# Everything that takes only `self` at the SymNode level.
unary_methods = unary_magic_methods | unary_nonmagic_methods
|
|
|
|
|
|
|
# Methods that only make sense on SymBool.
only_bool_magic_methods = {"and", "or", "sym_not", "sym_ite"}

# Methods where bool operands are promoted to int first (True + True == 2).
bool_becomes_int_magic_methods = {"add", "sub", "mul"}

# Methods valid on SymBool in addition to SymInt/SymFloat.
also_bool_magic_methods = {"eq"}
bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods


# Methods installed only on SymFloat.
only_float_magic_methods = {"is_integer", "round", "sym_int"}


# These names clash with Python keywords; the operator module (and SymNode)
# spell them with a trailing underscore (and_, or_).
magic_methods_on_operator_with_trailing_underscore = {"and", "or"}


# Result-type overrides used by the metaprogrammed implementations.
always_float_magic_methods = {"int_truediv", "float_truediv", "sym_float", "float_pow"}

for name in math_op_names:
    sym_name = f"sym_{name}"
    always_float_magic_methods.add(sym_name)


always_int_magic_methods = {"ceil", "floor", "trunc", "pow_by_natural"}
always_bool_magic_methods = {
    "eq",
    "ne",
    "gt",
    "lt",
    "le",
    "ge",
    "and",
    "or",
    "sym_not",
    "is_non_overlapping_and_dense",
    "is_integer",
}
|
|
|
|
|
|
|
|
|
def _sympy_float_truediv(a, b):
    """Symbolic float/float true division (FloatTrueDiv)."""
    import torch.utils._sympy.functions as _fns

    return _fns.FloatTrueDiv(a, b)
|
|
|
|
|
def _sympy_int_truediv(a, b):
    """Symbolic int/int true division (IntTrueDiv); result is float-valued."""
    import torch.utils._sympy.functions as _fns

    return _fns.IntTrueDiv(a, b)
|
|
|
|
|
def _sympy_floordiv(a, b):
    """Symbolic floor division (FloorDiv)."""
    import torch.utils._sympy.functions as _fns

    return _fns.FloorDiv(a, b)
|
|
|
|
|
def _sympy_mod(a, b):
    """Symbolic modulo.

    When both operands are provably nonnegative the C-style ``Mod`` and
    Python semantics coincide, so use the simpler ``Mod``; otherwise use
    ``PythonMod`` to match Python's sign behavior.
    """
    import torch.utils._sympy.functions as _fns

    if a.is_nonnegative and b.is_nonnegative:
        return _fns.Mod(a, b)
    return _fns.PythonMod(a, b)
|
|
|
|
|
def _sympy_pow_by_natural(a, b):
    """Symbolic power with a natural-number exponent (PowByNatural)."""
    import torch.utils._sympy.functions as _fns

    return _fns.PowByNatural(a, b)
|
|
|
|
|
def _sympy_float_pow(a, b):
    """Symbolic float power (FloatPow)."""
    import torch.utils._sympy.functions as _fns

    return _fns.FloatPow(a, b)
|
|
|
|
|
def _sympy_and(a, b): |
|
import sympy |
|
|
|
return sympy.And(a, b) |
|
|
|
|
|
def _sympy_or(a, b): |
|
import sympy |
|
|
|
return sympy.Or(a, b) |
|
|
|
|
|
def _sympy_lshift(a, b):
    """Symbolic left shift (LShift)."""
    import torch.utils._sympy.functions as _fns

    return _fns.LShift(a, b)
|
|
|
|
|
def _sympy_rshift(a, b):
    """Symbolic right shift (RShift)."""
    import torch.utils._sympy.functions as _fns

    return _fns.RShift(a, b)
|
|
|
|
|
# Binary methods that also get a reflected (__radd__ etc.) variant on the
# user types; values are the sympy-level lowering functions.
reflectable_magic_methods = {
    "add": operator.add,
    "sub": operator.sub,
    "mul": operator.mul,
    "mod": _sympy_mod,
    "pow_by_natural": _sympy_pow_by_natural,
    "float_pow": _sympy_float_pow,
    "and": _sympy_and,
    "or": _sympy_or,
    "float_truediv": _sympy_float_truediv,
    "int_truediv": _sympy_int_truediv,
    "int_floordiv": _sympy_floordiv,
    "lshift": _sympy_lshift,
    "rshift": _sympy_rshift,
}
|
|
|
|
|
def _floor_ceil_helper(a, fn):
    """Apply floor/ceil lowering `fn` to `a`, short-circuiting cases that
    are already integral.

    - ``float_coef * integer_expr`` where the float coefficient is exactly
      an integer folds to ``Integer(coef) * expr`` with no rounding node.
    - An integral ``Float`` (or an ``Integer``) folds directly to Integer.
    """
    import sympy

    if isinstance(a, sympy.Mul):
        aa = a.args
        if len(aa) == 2 and isinstance(aa[0], sympy.Float) and aa[1].is_integer:
            coef = sympy.Integer(aa[0])
            if aa[0] == coef:
                # Coefficient is exactly integral, so the product already is.
                return coef * aa[1]
    if (
        isinstance(a, sympy.Float)
        and a == sympy.Integer(a)
        or isinstance(a, sympy.Integer)
    ):
        return sympy.Integer(a)
    return fn(a)
|
|
|
|
|
def _sympy_floor(a):
    """Symbolic floor with an int-typed result (FloorToInt)."""
    import torch.utils._sympy.functions as _fns

    return _fns.FloorToInt(a)
|
|
|
|
|
|
|
|
|
def _sympy_trunc(a):
    """Symbolic truncation toward zero with int-typed result (TruncToInt)."""
    import torch.utils._sympy.functions as _fns

    return _fns.TruncToInt(a)
|
|
|
|
|
def _sympy_ceil(a):
    """Symbolic ceiling with an int-typed result (CeilToInt)."""
    import torch.utils._sympy.functions as _fns

    return _fns.CeilToInt(a)
|
|
|
|
|
def _sympy_eq(a, b): |
|
import sympy |
|
|
|
return sympy.Eq(a, b) |
|
|
|
|
|
def _sympy_ne(a, b): |
|
import sympy |
|
|
|
return sympy.Ne(a, b) |
|
|
|
|
|
def _sympy_gt(a, b): |
|
import sympy |
|
|
|
return sympy.Gt(a, b) |
|
|
|
|
|
def _sympy_lt(a, b): |
|
import sympy |
|
|
|
return sympy.Lt(a, b) |
|
|
|
|
|
def _sympy_le(a, b): |
|
import sympy |
|
|
|
return sympy.Le(a, b) |
|
|
|
|
|
def _sympy_ge(a, b): |
|
import sympy |
|
|
|
return sympy.Ge(a, b) |
|
|
|
|
|
def _sympy_min(a, b): |
|
import sympy |
|
|
|
return sympy.Min(a, b) |
|
|
|
|
|
def _sympy_max(a, b): |
|
import sympy |
|
|
|
return sympy.Max(a, b) |
|
|
|
|
|
def _sympy_ite(a, t, f): |
|
import sympy |
|
|
|
return sympy.Piecewise((t, a), (f, True)) |
|
|
|
|
|
# Handle to this module, so the metaprogramming below can setattr onto it.
current_module = sys.modules[__name__]
|
|
|
|
|
def _get_sym_math_fn(name): |
|
def fn(a): |
|
import torch.utils._sympy.functions |
|
|
|
return getattr(torch.utils._sympy.functions, f"OpaqueUnaryFn_{name}")(a) |
|
|
|
return fn |
|
|
|
|
|
# Install a `_sympy_<op>` lowering on this module for each math op; each one
# dispatches to torch's OpaqueUnaryFn_<op>.
for name in math_op_names:
    priv_sympy_name = f"_sympy_{name}"
    fn = _get_sym_math_fn(name)
    fn.__qualname__ = fn.__name__ = priv_sympy_name
    setattr(current_module, priv_sympy_name, fn)

# Don't leak loop temporaries at module scope.
del fn, name, priv_sympy_name
|
|
|
|
|
def _sympy_abs(a): |
|
import sympy |
|
|
|
return sympy.Abs(a) |
|
|
|
|
|
def _sympy_round(number, ndigits=None):
    """round() lowering: RoundToInt when ndigits is omitted (int result),
    otherwise RoundDecimal (same-type result)."""
    import torch.utils._sympy.functions as _fns

    if ndigits is None:
        return _fns.RoundToInt(number)
    return _fns.RoundDecimal(number, ndigits)
|
|
|
|
|
def _sympy_sym_float(a):
    """Lower float(x) to an explicit ToFloat cast node."""
    import torch.utils._sympy.functions as _fns

    return _fns.ToFloat(a)
|
|
|
|
|
def _sympy_is_integer(a):
    """float.is_integer() lowering: x equals the float cast of floor(x)."""
    import sympy
    import torch.utils._sympy.functions as _fns

    return sympy.Eq(_fns.ToFloat(sympy.floor(a)), a)
|
|
|
|
|
# Full map of SymNode method name -> sympy-level lowering used by
# _make_node_magic below; the sym_<math> lowerings are added right after.
magic_methods = {
    **reflectable_magic_methods,
    "sym_not": operator.invert,
    "pos": operator.pos,
    "eq": _sympy_eq,
    "ne": _sympy_ne,
    "gt": _sympy_gt,
    "lt": _sympy_lt,
    "le": _sympy_le,
    "ge": _sympy_ge,
    "floor": _sympy_floor,
    "trunc": _sympy_trunc,
    "sym_float": _sympy_sym_float,
    "ceil": _sympy_ceil,
    "neg": operator.neg,
    "sym_min": _sympy_min,
    "sym_max": _sympy_max,
    "sym_ite": _sympy_ite,
    "abs": _sympy_abs,
    "round": _sympy_round,
    "is_integer": _sympy_is_integer,
}
|
|
|
|
|
# Register the sympy lowerings for the transcendental ops installed above.
for name in math_op_names:
    sym_name = f"sym_{name}"
    magic_methods[sym_name] = getattr(current_module, f"_sympy_{name}")

# Clean up module-level temporaries now that registration is done.
del name, sym_name, math_op_names, current_module
|
|
|
|
|
def sympy_is_contiguous(sizes, strides):
    """Row-major contiguity: check dims from innermost to outermost."""
    order = list(reversed(range(len(sizes))))
    return sympy_is_contiguous_generic(sizes, strides, order)
|
|
|
|
|
def sympy_is_contiguous_generic(sizes, strides, dim_order):
    """Build a sympy boolean asserting `strides` describe a contiguous layout
    of `sizes` when dims are walked in `dim_order` (innermost first)."""
    import sympy

    dim = len(sizes)

    # A dim_order of the wrong length can never describe this layout.
    if len(dim_order) != dim:
        return sympy.false

    is_contiguous = sympy.true
    z = sympy.Integer(1)

    # Walking innermost-out, each dim must either have size 1 (stride is
    # irrelevant) or exactly the running product of inner sizes as stride.
    for d in dim_order:
        is_contiguous &= sympy.Eq(sizes[d], sympy.Integer(1)) | sympy.Eq(strides[d], z)
        z *= sizes[d]

    # An empty tensor (any zero-size dim) is contiguous regardless of strides.
    for d in range(dim):
        is_contiguous |= sympy.Eq(sizes[d], sympy.Integer(0))
    return is_contiguous
|
|
|
|
|
|
|
|
|
|
|
|
|
def sympy_is_channels_last_contiguous_2d(sizes, strides):
    """NHWC contiguity check: channel dim is innermost."""
    nhwc_order = [1, 3, 2, 0]
    return sympy_is_contiguous_generic(sizes, strides, nhwc_order)
|
|
|
|
|
def sympy_is_channels_last_contiguous_3d(sizes, strides):
    """NDHWC contiguity check: channel dim is innermost."""
    ndhwc_order = [1, 4, 3, 2, 0]
    return sympy_is_contiguous_generic(sizes, strides, ndhwc_order)
|
|
|
|
|
def sympy_is_channels_last_strides_generic(sizes, strides, dim_order):
    """Build a sympy boolean asserting `strides` follow a channels-last
    ordering of `sizes` when walked in `dim_order` (innermost first)."""
    import sympy

    dim = len(sizes)

    if dim != len(dim_order):
        return sympy.false

    # m tracks the minimum stride the next (outer) dim must meet.
    m = sympy.Integer(0)
    r = sympy.true

    # The channel dim must have a nonzero stride for channels-last.
    r &= sympy.Ne(strides[1], 0)

    for d in dim_order:
        r &= sympy.Ne(sizes[d], 0) & (strides[d] >= m)
        # The batch dim's stride must differ from the channel stride to
        # disambiguate from other layouts.
        if d == 0:
            r &= sympy.Ne(m, strides[1])
        # Next dim must stride past this dim's full extent (size clamped to
        # at least 1 so size-1 dims don't zero the requirement).
        m = strides[d] * sympy.Max(sizes[d], 1)

    return r
|
|
|
|
|
def sympy_is_channels_last_strides_2d(sizes, strides):
    """NHWC stride-order check."""
    nhwc_order = [1, 3, 2, 0]
    return sympy_is_channels_last_strides_generic(sizes, strides, nhwc_order)
|
|
|
|
|
def sympy_is_channels_last_strides_3d(sizes, strides):
    """NDHWC stride-order check."""
    ndhwc_order = [1, 4, 3, 2, 0]
    return sympy_is_channels_last_strides_generic(sizes, strides, ndhwc_order)
|
|
|
|
|
def _sympy_is_non_overlapping_and_dense_indicator(sizes, strides):
    """0/1 indicator for the non-overlapping-and-dense layout predicate."""
    import torch.utils._sympy.functions as _fns

    return _fns.IsNonOverlappingAndDenseIndicator(*sizes, *strides)
|
|
|
|
|
# Layout predicates over full (sizes, strides) lists; installed on SymNode
# by _make_node_sizes_strides below.
sizes_strides_methods = {
    "is_contiguous": sympy_is_contiguous,
    "is_channels_last_contiguous_2d": sympy_is_channels_last_contiguous_2d,
    "is_channels_last_contiguous_3d": sympy_is_channels_last_contiguous_3d,
    "is_channels_last_strides_2d": sympy_is_channels_last_strides_2d,
    "is_channels_last_strides_3d": sympy_is_channels_last_strides_3d,
    "is_non_overlapping_and_dense_indicator": _sympy_is_non_overlapping_and_dense_indicator,
}

# When every operand has a concrete hint, these methods can be computed with
# the plain builtin instead of building a symbolic expression.
alternate_impl_if_hinted_methods = {
    "sym_min": builtins.min,
    "sym_max": builtins.max,
}
|
|
|
|
|
def to_node(self, num):
    """Coerce `num` into a SymNode living in `self`'s shape env.

    Symbolic values contribute their node directly; plain bool/int/float
    constants are wrapped.  Anything else yields NotImplemented so the
    calling magic method can defer to the other operand.
    """
    if isinstance(num, SymTypes):
        return num.node
    wrappers = {bool: self.wrap_bool, int: self.wrap_int, float: self.wrap_float}
    wrap = wrappers.get(type(num))
    if wrap is None:
        # Let e.g. torch.Tensor handle the binary op instead.
        return NotImplemented
    return wrap(num)
|
|
|
|
|
def wrap_node(x):
    """Wrap a SymNode back into the user-facing Sym* type matching its
    pytype; constant-backed nodes unwrap straight to their Python value."""
    if isinstance(x, SymNode) and x.constant is not None:
        return x.constant
    for is_type, user_type in (
        (x.is_int, SymInt),
        (x.is_float, SymFloat),
        (x.is_bool, SymBool),
    ):
        if is_type():
            return user_type(x)
    raise AssertionError(f"unrecognized return type {x}")
|
|
|
|
|
def method_to_operator(method):
    """Return the concrete Python operator/function implementing `method`."""
    return METHOD_TO_OPERATOR[method]
|
|
|
|
|
def _make_node_magic(method, func):
    """Install the `_<method>` implementation on SymNode, where `func` is the
    sympy-level lowering for `method`."""
    # Cache the lowering: shape expressions repeat heavily during tracing.
    func = lru_cache(256)(func)

    if method in magic_methods_on_operator_with_trailing_underscore:
        method_attr = f"{method}_"
    else:
        method_attr = method

    def binary_magic_impl(self, other):
        from torch.fx.experimental.symbolic_shapes import safe_expand

        op = method_to_operator(method)

        # Compute the result hint eagerly on the concrete operands.
        out_hint = None
        if self.hint is not None and other.hint is not None:
            out_hint = op(self.hint, other.hint)

        # With concrete hints, some methods (sym_min/sym_max) can use the
        # plain builtin instead of building a symbolic expression.
        alternate_impl = alternate_impl_if_hinted_methods.get(method)
        if alternate_impl and out_hint is not None:
            return to_node(self, alternate_impl(wrap_node(self), wrap_node(other)))

        if sym_function_mode():
            return to_node(
                self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {})
            )
        assert isinstance(other, SymNode)
        try:
            if method == "mod":
                from torch.utils._sympy.functions import Mod, PythonMod

                # Special-case mod: value-range analysis may prove both
                # operands nonnegative even when sympy assumptions cannot,
                # letting us use the simpler Mod instead of PythonMod.
                shape_env = self.shape_env
                if (
                    self.expr.is_nonnegative
                    or shape_env.bound_sympy(self.expr).lower >= 0
                ) and (
                    other.expr.is_nonnegative
                    or shape_env.bound_sympy(other.expr).lower >= 0
                ):
                    out = Mod(self.expr, other.expr)
                else:
                    out = PythonMod(self.expr, other.expr)
            else:
                out = func(self.expr, other.expr)
        except Exception:
            log.warning("failed to eval %s(%s, %s)", method, self.expr, other.expr)
            raise
        out = safe_expand(out)
        sym_node_log.debug("%s %s %s -> %s", func, self.expr, other.expr, out)
        pytype: Type
        # Result type: method-specific overrides first, then float promotion.
        if method in always_float_magic_methods:
            pytype = float
        elif method in always_bool_magic_methods:
            pytype = bool
        elif self.pytype is float or other.pytype is float:
            pytype = float
        else:
            pytype = self.pytype

        # Coerce the hint to the result type unless it is itself symbolic.
        if (
            pytype is not None
            and out_hint is not None
            and not isinstance(out_hint, SymTypes)
        ):
            out_hint = pytype(out_hint)

        fx_node, _ = self.shape_env._create_fx_call_function(
            op, (self.fx_node, other.fx_node)
        )
        return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node)

    def unary_magic_impl(self):
        from torch.fx.experimental.symbolic_shapes import safe_expand

        op = method_to_operator(method)
        if sym_function_mode():
            return to_node(self, handle_sym_dispatch(op, (wrap_node(self),), {}))

        expr = self.expr
        # NOTE(review): magic_methods registers this op under the key
        # "ceil", so the "ceiling" comparison appears to never match and
        # ceil never gets the floor-div simplification -- confirm intended.
        if method == "floor" or method == "ceiling":
            expr = self.shape_env._simplify_floor_div(expr)

        try:
            out = func(expr)
        except Exception:
            log.warning("failed to eval %s(%s)", method, expr)
            raise
        sym_node_log.debug("%s %s -> %s", func, expr, out)
        out_hint = None
        if self.hint is not None:
            out_hint = op(self.hint)
        out = safe_expand(out)
        pytype: Type
        if method in always_int_magic_methods:
            pytype = int
        elif method in always_bool_magic_methods:
            pytype = bool
        elif method in always_float_magic_methods:
            pytype = float
        else:
            pytype = self.pytype

        fx_node, _ = self.shape_env._create_fx_call_function(op, (self.fx_node,))
        return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node)

    if method in unary_methods:
        setattr(SymNode, f"_{method_attr}", unary_magic_impl)
    elif method == "sym_ite":

        def sym_ite_impl(pred_node, then_node, else_node):
            from torch.fx.experimental.symbolic_shapes import safe_expand

            # The hint follows whichever branch the predicate hint selects.
            out_hint = then_node.hint if pred_node.hint else else_node.hint
            if sym_function_mode():
                return to_node(
                    pred_node,
                    handle_sym_dispatch(
                        sym_ite,
                        (
                            wrap_node(pred_node),
                            wrap_node(then_node),
                            wrap_node(else_node),
                        ),
                        {},
                    ),
                )

            try:
                out = func(pred_node.expr, then_node.expr, else_node.expr)
            except Exception:
                log.warning(
                    "failed to eval %s(%s, %s, %s)",
                    method,
                    pred_node.expr,
                    then_node.expr,
                    else_node.expr,
                )
                raise

            out = safe_expand(out)
            fx_node, _ = pred_node.shape_env._create_fx_call_function(
                sym_ite, (pred_node.fx_node, then_node.fx_node, else_node.fx_node)
            )
            return SymNode(
                out, pred_node.shape_env, then_node.pytype, out_hint, fx_node=fx_node
            )

        setattr(SymNode, f"_{method_attr}", sym_ite_impl)
    elif method == "round":

        def round_impl(self, ndigits=None):
            from torch.fx.experimental.symbolic_shapes import safe_expand

            op = builtins.round
            if sym_function_mode():
                return to_node(
                    self, handle_sym_dispatch(op, (wrap_node(self), ndigits), {})
                )

            expr = self.expr
            try:
                out = func(expr, ndigits)
            except Exception:
                log.warning("failed to eval %s(%s, ndigits=%s)", method, expr, ndigits)
                raise

            out = safe_expand(out)

            # round() with no ndigits yields int; with ndigits it keeps the
            # input type (matching Python's builtins.round).
            if ndigits is None:
                pytype = int
            else:
                pytype = self.pytype

            out_hint = None
            if self.hint is not None:
                out_hint = op(self.hint, ndigits)

            # Only record ndigits on the fx node when it was actually given,
            # so validation sees the same call shape the user made.
            args = [self.fx_node]
            if ndigits is not None:
                args.append(ndigits)
            fx_node, _ = self.shape_env._create_fx_call_function(op, tuple(args))
            return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node)

        setattr(SymNode, f"_{method_attr}", round_impl)
    else:
        setattr(SymNode, f"_{method_attr}", binary_magic_impl)
|
|
|
|
|
def _make_node_sizes_strides(method, func):
    """Install the `_<method>` layout predicate on SymNode (and a module-level
    user entry point), where `func` is the sympy-level lowering."""

    def sizes_strides_impl(self, sizes, strides):
        # Dispatch through the module-level user function so sym dispatch
        # mode sees the public name.
        op = getattr(sys.modules[__name__], method)
        if sym_function_mode():
            return to_node(
                self,
                handle_sym_dispatch(
                    op,
                    ([wrap_node(s) for s in sizes], [wrap_node(s) for s in strides]),
                    {},
                ),
            )
        size_exprs = [s.expr for s in sizes]
        stride_exprs = [s.expr for s in strides]
        try:
            out = func(size_exprs, stride_exprs)
        except Exception:
            log.warning("failed to eval %s(%s, %s)", method, size_exprs, stride_exprs)
            raise

        # Only compute a hint when every size AND stride has one; the
        # for/else falls through to the hint computation only if no break.
        size_hints = []
        out_hint = None
        for s in sizes:
            if s.hint is None:
                break
            size_hints.append(s.hint)
        else:
            stride_hints = []
            for s in strides:
                if s.hint is None:
                    break
                stride_hints.append(s.hint)
            else:
                out_hint = op(size_hints, stride_hints)

        # Indicator methods return 0/1 ints; the rest are booleans.
        pytype: Type
        if method.endswith("_indicator"):
            pytype = int
        else:
            pytype = bool
        return SymNode(out, self.shape_env, pytype, out_hint)

    setattr(SymNode, f"_{method}", sizes_strides_impl)

    def sizes_strides_user(sizes, strides):
        # Module-level entry point accepting mixed SymInt/int lists.
        import sympy

        from torch.fx.experimental.symbolic_shapes import (
            eval_is_non_overlapping_and_dense,
        )

        # If any operand is symbolic, route through its SymNode.
        for a in itertools.chain(sizes, strides):
            if isinstance(a, SymInt):
                return wrap_node(
                    getattr(a.node, method)(
                        [to_node(a.node, b) for b in sizes],
                        [to_node(a.node, b) for b in strides],
                    )
                )
        if method == "is_non_overlapping_and_dense_indicator":
            return eval_is_non_overlapping_and_dense(sizes, strides)
        else:
            # All-concrete fallback: evaluate the sympy predicate directly.
            return bool(
                func(
                    [sympy.sympify(a) for a in sizes],
                    [sympy.sympify(a) for a in strides],
                )
            )

    # Don't clobber an existing module-level definition of the same name.
    if not hasattr(sys.modules[__name__], method):
        setattr(sys.modules[__name__], method, sizes_strides_user)
|
|
|
|
|
# Metaprogram the `_<method>` implementations onto SymNode.
for method, func in magic_methods.items():
    _make_node_magic(method, func)

for method, func in sizes_strides_methods.items():
    _make_node_sizes_strides(method, func)
|
|
|
|
|
def _make_user_magic(method, user_type):
    """Install the user-facing magic method for `method` on `user_type`
    (SymInt/SymFloat/SymBool), forwarding to the SymNode implementation."""
    # "and"/"or" can't be dunder-only names on SymNode; the node-level
    # implementation is spelled sym_and/sym_or.
    if method in magic_methods_on_operator_with_trailing_underscore:
        method_attr = f"sym_{method}"
    else:
        method_attr = method

    def get_constant(x: Union[SymInt, int, SymFloat, float, SymBool, bool]):
        if isinstance(x, (int, float, bool)):
            return x
        if isinstance(x, SymBool):
            return x.node.guard_bool("", 0)
        raise AssertionError("expect to be called with constant SymBools")

    def is_constant(x):
        if isinstance(x, (int, float, bool)):
            return True
        if isinstance(x, (SymInt, SymFloat, SymBool)):
            return x.node.is_constant()
        return False

    if method in bool_becomes_int_magic_methods:

        def promote(x):
            """Implements True+True=2, which works in python but not sympy"""
            if isinstance(x, SymBool):
                return SymInt(x.node.wrap_int(int(x)))
            return x

    else:

        def promote(x):
            return x

    def promote2(self, other):
        # Promote int operands to float for mixed int/float binary ops, so
        # the node-level implementation sees uniform types.
        if method not in [
            "add",
            "sub",
            "mul",
            "mod",
            "float_pow",
            "float_truediv",
            "int_floordiv",
            "sym_min",
            "sym_max",
            "eq",
            "ne",
            "gt",
            "lt",
            "le",
            "ge",
        ]:
            return self, other
        f_self = isinstance(self, (float, torch.SymFloat))
        f_other = isinstance(other, (float, torch.SymFloat))
        if f_self or f_other:
            if not f_self:
                self = torch.sym_float(self)
            if not f_other:
                other = torch.sym_float(other)
        return self, other

    def unary_magic_impl(self):
        self = promote(self)
        # Constants short-circuit to the plain Python operator.
        if is_constant(self):
            return (method_to_operator(method))(get_constant(self))
        return wrap_node(getattr(self.node, method_attr)())

    def binary_magic_impl(self, other):
        if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)):
            return NotImplemented
        sym_node_log.debug("MAGIC %s %s %s", method, self, other)
        self = promote(self)
        other = promote(other)
        self, other = promote2(self, other)
        if is_constant(self):
            return (method_to_operator(method))(get_constant(self), other)
        if is_constant(other):
            other = get_constant(other)
        other_node = to_node(self.node, other)
        if other_node is NotImplemented:
            return NotImplemented
        ret = wrap_node(getattr(self.node, method_attr)(other_node))
        return get_constant(ret) if is_constant(ret) else ret

    def rbinary_magic_impl(self, other):
        # Reflected variant: operands swap roles when calling the node impl.
        if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)):
            return NotImplemented
        self = promote(self)
        other = promote(other)
        self, other = promote2(self, other)
        if is_constant(self):
            return (method_to_operator(method))(get_constant(self), other)
        if is_constant(other):
            other = get_constant(other)
        other_node = to_node(self.node, other)
        if other_node is NotImplemented:
            return NotImplemented
        ret = wrap_node(getattr(other_node, method_attr)(self.node))
        return get_constant(ret) if is_constant(ret) else ret

    if method in unary_magic_methods:
        setattr(user_type, f"__{method}__", unary_magic_impl)
    elif method in unary_nonmagic_methods:
        # Plain-name method (e.g. is_integer): preserve the original's
        # metadata via update_wrapper.
        orig = getattr(user_type, method)
        setattr(user_type, method, update_wrapper(unary_magic_impl, orig))
    elif method == "sym_ite":

        def sym_ite_magic_impl(pred, then_val, else_val):
            pred_node = pred.node
            then_node = to_node(pred_node, then_val)
            else_node = to_node(pred_node, else_val)
            if then_node is NotImplemented or else_node is NotImplemented:
                return NotImplemented
            # Both branches must produce the same erased type.
            assert (
                isinstance(then_node, SymNode)
                and isinstance(else_node, SymNode)
                and then_node.pytype == else_node.pytype
            )
            ret = wrap_node(getattr(pred.node, method_attr)(then_node, else_node))
            return get_constant(ret) if ret.node.is_constant() else ret

        setattr(user_type, f"__{method}__", sym_ite_magic_impl)
    elif method == "round":

        def round_magic_impl(self, ndigits=None):
            if is_constant(self):
                return builtins.round(get_constant(self), ndigits)

            return wrap_node(getattr(self.node, method)(ndigits))

        setattr(user_type, f"__{method}__", round_magic_impl)
    else:
        setattr(user_type, f"__{method}__", binary_magic_impl)
        if method in reflectable_magic_methods:
            setattr(user_type, f"__r{method}__", rbinary_magic_impl)
|
|
|
|
|
# Install the user-facing magic methods: bool-only methods go just to
# SymBool, float-only just to SymFloat; methods that additionally accept
# bools also go to SymBool; everything gets SymInt and SymFloat.
for method, func in magic_methods.items():
    if method in only_bool_magic_methods:
        _make_user_magic(method, SymBool)
        continue
    if method in only_float_magic_methods:
        _make_user_magic(method, SymFloat)
        continue
    if method in also_bool_magic_methods or method in bool_becomes_int_magic_methods:
        _make_user_magic(method, SymBool)
    _make_user_magic(method, SymInt)
    _make_user_magic(method, SymFloat)

# Don't leak loop temporaries at module scope.
del method
del func
|
|