|
|
|
""" |
|
This is a simple interpreter for Sympy expressions that dispatches to |
|
classes following the torch._inductor.virtualized calling convention. |
|
For directness, the interpreter takes the handler directly rather than |
|
consulting the TLS. It does not use most of the methods on the full |
|
handler; only those with corresponding Sympy expressions. To see an example |
|
of a full handler, see torch.utils._sympy.value_ranges.ValueRangeAnalysis. |
|
""" |
|
|
|
import functools |
|
from typing import Any, Dict, Union |
|
|
|
import sympy |
|
from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom |
|
|
|
import torch |
|
from .functions import ( |
|
CeilToInt, |
|
CleanDiv, |
|
FloatPow, |
|
FloatTrueDiv, |
|
FloorDiv, |
|
FloorToInt, |
|
IntTrueDiv, |
|
IsNonOverlappingAndDenseIndicator, |
|
Mod, |
|
ModularIndexing, |
|
PowByNatural, |
|
PythonMod, |
|
RoundDecimal, |
|
RoundToInt, |
|
ToFloat, |
|
TruncToFloat, |
|
TruncToInt, |
|
Where, |
|
) |
|
|
|
|
|
|
|
|
|
|
|
@functools.lru_cache(None) |
|
def handlers(): |
|
|
|
|
|
|
|
|
|
|
|
HANDLERS = { |
|
sympy.Or: "or_", |
|
sympy.And: "and_", |
|
sympy.Eq: "eq", |
|
sympy.Ne: "ne", |
|
sympy.Lt: "lt", |
|
sympy.Gt: "gt", |
|
sympy.Le: "le", |
|
sympy.Ge: "ge", |
|
sympy.Not: "not_", |
|
IntTrueDiv: "int_truediv", |
|
FloatTrueDiv: "truediv", |
|
FloorDiv: "floordiv", |
|
CleanDiv: "floordiv", |
|
TruncToFloat: "trunc", |
|
Where: "where", |
|
sympy.Add: "add", |
|
sympy.Mul: "mul", |
|
FloatPow: "pow", |
|
PowByNatural: "pow_by_natural", |
|
|
|
|
|
|
|
|
|
|
|
|
|
sympy.Pow: "pow_by_natural", |
|
Mod: "mod", |
|
PythonMod: "mod", |
|
|
|
|
|
|
|
sympy.Mod: "mod", |
|
sympy.Abs: "abs", |
|
sympy.log: "log", |
|
sympy.exp: "exp", |
|
sympy.Min: "minimum", |
|
sympy.Max: "maximum", |
|
ModularIndexing: "modular_indexing", |
|
sympy.functions.elementary.piecewise.ExprCondPair: "expr_cond_pair", |
|
sympy.Piecewise: "piecewise", |
|
IsNonOverlappingAndDenseIndicator: "is_non_overlapping_and_dense_indicator", |
|
RoundDecimal: "round_decimal", |
|
} |
|
for name in ["cos", "sin", "tan", "sinh", "cosh", "tanh", "asin", "acos", "atan"]: |
|
HANDLERS[getattr(sympy, name)] = name |
|
|
|
return HANDLERS |
|
|
|
|
|
ASSOCIATIVE_OPS = {"minimum", "maximum", "mul", "add", "and_", "or_"} |
|
|
|
|
|
def sympy_interp( |
|
analysis, |
|
env: Dict[sympy.Symbol, Any], |
|
expr: Union[sympy.Expr, SympyBoolean], |
|
*, |
|
index_dtype=torch.int64, |
|
): |
|
|
|
dtype = None |
|
if isinstance(expr, BooleanAtom): |
|
dtype = torch.bool |
|
elif isinstance(expr, sympy.Integer): |
|
dtype = torch.int64 |
|
elif isinstance(expr, sympy.Number): |
|
dtype = torch.double |
|
|
|
if dtype is not None: |
|
return analysis.constant(expr, dtype) |
|
elif isinstance(expr, sympy.Symbol): |
|
return env[expr] |
|
|
|
|
|
if isinstance(expr, sympy.Pow) and isinstance( |
|
expr.args[1], sympy.core.numbers.Half |
|
): |
|
return analysis.sqrt(sympy_interp(analysis, env, expr.args[0])) |
|
if isinstance(expr, ToFloat): |
|
return analysis.to_dtype( |
|
sympy_interp(analysis, env, expr.args[0]), torch.float64 |
|
) |
|
|
|
|
|
args = [sympy_interp(analysis, env, arg) for arg in expr.args] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
INDEX_DTYPE_HANDLERS = { |
|
TruncToInt: "trunc_to_int", |
|
sympy.floor: "floor_to_int", |
|
sympy.ceiling: "ceil_to_int", |
|
FloorToInt: "floor_to_int", |
|
CeilToInt: "ceil_to_int", |
|
RoundToInt: "round_to_int", |
|
} |
|
if (handler_name := INDEX_DTYPE_HANDLERS.get(expr.func)) is not None: |
|
return getattr(analysis, handler_name)(*args, index_dtype) |
|
|
|
if hasattr(expr.func, "_torch_handler_name"): |
|
handler_name = expr.func._torch_handler_name |
|
else: |
|
handler_name = handlers()[expr.func] |
|
handler = getattr(analysis, handler_name) |
|
if handler_name in ASSOCIATIVE_OPS: |
|
assert len(args) > 1 |
|
acc = handler(args[0], args[1]) |
|
for i in range(2, len(args)): |
|
acc = handler(acc, args[i]) |
|
return acc |
|
else: |
|
return handler(*args) |
|
|