File size: 5,528 Bytes
d1ceb73 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
# mypy: allow-untyped-defs
"""
This is a simple interpreter for SymPy expressions that dispatches to
classes following the torch._inductor.virtualized calling convention.
For directness, the interpreter takes the handler directly rather than
consulting the thread-local state (TLS). It does not use most of the methods
on a full handler; only those with corresponding SymPy expressions. For an
example of a full handler, see torch.utils._sympy.value_ranges.ValueRangeAnalysis.
"""
import functools
from typing import Any, Dict, Union
import sympy
from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom
import torch
from .functions import (
CeilToInt,
CleanDiv,
FloatPow,
FloatTrueDiv,
FloorDiv,
FloorToInt,
IntTrueDiv,
IsNonOverlappingAndDenseIndicator,
Mod,
ModularIndexing,
PowByNatural,
PythonMod,
RoundDecimal,
RoundToInt,
ToFloat,
TruncToFloat,
TruncToInt,
Where,
)
# TODO: Dedupe this with SYMPY_INTERP
@functools.lru_cache(None)
def handlers():
    """Return the table mapping SymPy function classes to handler method names.

    ``sympy_interp`` looks up ``expr.func`` in this table and calls the
    analysis method with the returned name. Because of ``lru_cache`` the very
    same dict object is handed back on every call, so callers should treat it
    as read-only.
    """
    # TODO add CeilDiv (it doesn't appear in the index_expr)
    # TODO default to some decompositions if the interpreter doesn't have them
    # like decomposing ModularIndexing or implementing Le(a,b) as Ge(b, a)
    table = {
        # --- boolean connectives and relations ---
        sympy.Or: "or_",
        sympy.And: "and_",
        sympy.Eq: "eq",
        sympy.Ne: "ne",
        sympy.Lt: "lt",
        sympy.Gt: "gt",
        sympy.Le: "le",
        sympy.Ge: "ge",
        sympy.Not: "not_",
        # --- division / truncation / selection ---
        IntTrueDiv: "int_truediv",
        FloatTrueDiv: "truediv",
        FloorDiv: "floordiv",
        CleanDiv: "floordiv",  # TODO: hmm?
        TruncToFloat: "trunc",
        Where: "where",
        # --- arithmetic ---
        sympy.Add: "add",
        sympy.Mul: "mul",
        FloatPow: "pow",
        PowByNatural: "pow_by_natural",
        # sympy simplifies x * x into Pow(x, 2), so we need to handle this.
        # Do NOT use builtin Pow for floats
        # TODO: There is a hazard here, if we have float * float it will
        # also get turned into Pow(float, 2) but we don't want this because
        # pow_by_natural is assumed to only be integers. Probably the fix is
        # to add a FloatMul to impede this optimization
        sympy.Pow: "pow_by_natural",
        # --- modulus ---
        Mod: "mod",
        PythonMod: "mod",  # TODO: this is wrong
        # TODO: Inductor can generate these, but it's ill-specified which
        # semantics were intended here. Needs to be cleaned up along with
        # FloorDiv in a bigger cleanup
        sympy.Mod: "mod",
        # --- elementwise math ---
        sympy.Abs: "abs",
        sympy.log: "log",
        sympy.exp: "exp",
        sympy.Min: "minimum",
        sympy.Max: "maximum",
        # --- indexing and piecewise helpers ---
        ModularIndexing: "modular_indexing",
        sympy.functions.elementary.piecewise.ExprCondPair: "expr_cond_pair",
        sympy.Piecewise: "piecewise",
        IsNonOverlappingAndDenseIndicator: "is_non_overlapping_and_dense_indicator",
        RoundDecimal: "round_decimal",
    }
    # Trigonometric / hyperbolic functions share their SymPy attribute name
    # with the handler method name, so register them programmatically.
    trig_names = ("cos", "sin", "tan", "sinh", "cosh", "tanh", "asin", "acos", "atan")
    table.update((getattr(sympy, fn_name), fn_name) for fn_name in trig_names)
    return table
# Handler names whose SymPy nodes may carry more than two arguments.
# sympy_interp folds such nodes pairwise, left to right, through the
# binary handler of this name.
ASSOCIATIVE_OPS = {
    "add",
    "mul",
    "and_",
    "or_",
    "minimum",
    "maximum",
}
def sympy_interp(
    analysis,
    env: Dict[sympy.Symbol, Any],
    expr: Union[sympy.Expr, SympyBoolean],
    *,
    index_dtype=torch.int64,
):
    """Interpret ``expr`` by dispatching each node to a method on ``analysis``.

    Args:
        analysis: handler object following the
            ``torch._inductor.virtualized`` calling convention. Only the
            methods required by the nodes in ``expr`` are looked up.
        env: value to return for each ``sympy.Symbol`` leaf of ``expr``.
        expr: SymPy expression or boolean to interpret.
        index_dtype: dtype passed as the trailing argument to the
            ``trunc_to_int`` / ``floor_to_int`` / ``ceil_to_int`` /
            ``round_to_int`` handlers. It is forwarded to every recursive
            call, so nested conversions use the same dtype as the root.

    Returns:
        The value produced by ``analysis`` for the root node of ``expr``.

    Raises:
        KeyError: if a symbol is missing from ``env``, or if ``expr.func``
            has neither a ``_torch_handler_name`` attribute nor an entry in
            ``handlers()``.
    """
    # Handle base cases
    dtype = None
    if isinstance(expr, BooleanAtom):
        dtype = torch.bool
    elif isinstance(expr, sympy.Integer):
        dtype = torch.int64
    elif isinstance(expr, sympy.Number):
        dtype = torch.double

    if dtype is not None:
        return analysis.constant(expr, dtype)
    elif isinstance(expr, sympy.Symbol):
        return env[expr]

    # Special cases
    if isinstance(expr, sympy.Pow) and isinstance(
        expr.args[1], sympy.core.numbers.Half
    ):
        # x ** (1/2) is lowered to sqrt rather than a generic pow.
        return analysis.sqrt(
            sympy_interp(analysis, env, expr.args[0], index_dtype=index_dtype)
        )
    if isinstance(expr, ToFloat):
        return analysis.to_dtype(
            sympy_interp(analysis, env, expr.args[0], index_dtype=index_dtype),
            torch.float64,
        )

    # Recursive case. index_dtype must be forwarded; otherwise any *_to_int
    # conversion below the root would silently fall back to the int64 default.
    args = [
        sympy_interp(analysis, env, arg, index_dtype=index_dtype)  # type: ignore[arg-type]
        for arg in expr.args
    ]

    # These handlers are special because they take an extra dtype argument
    # specifying what they should convert to, and we need to appropriately set
    # this up when we convert from Sympy. A reasonable default when you
    # are translating is to conservatively do int64, and then narrow these
    # arguments later when you discover you can narrow the index range. But
    # if you already know that 32-bit indexing is OK, you can directly do the
    # sympy translation with index_dtype=torch.int32
    INDEX_DTYPE_HANDLERS = {
        TruncToInt: "trunc_to_int",
        sympy.floor: "floor_to_int",
        sympy.ceiling: "ceil_to_int",
        FloorToInt: "floor_to_int",
        CeilToInt: "ceil_to_int",
        RoundToInt: "round_to_int",
    }
    if (handler_name := INDEX_DTYPE_HANDLERS.get(expr.func)) is not None:
        return getattr(analysis, handler_name)(*args, index_dtype)

    # A function class may advertise its own handler name; otherwise fall
    # back to the shared lookup table.
    if hasattr(expr.func, "_torch_handler_name"):
        handler_name = expr.func._torch_handler_name
    else:
        handler_name = handlers()[expr.func]
    handler = getattr(analysis, handler_name)

    if handler_name in ASSOCIATIVE_OPS:
        # Handlers are binary, so fold n-ary nodes left to right:
        # Add(a, b, c) -> handler(handler(a, b), c).
        assert len(args) > 1
        return functools.reduce(handler, args)
    return handler(*args)
|