# mypy: allow-untyped-defs | |
""" | |
This file contains canonical definitions for our symbol naming conventions,
across torch.fx.experimental.symbolic_shapes and torch._inductor. The
intention is:
1. To make it easy to grep for every site where a prefix is used.
2. To make it easy to tell whether a new prefix can be introduced without
   creating a conflict.
From time to time, you can check whether prefixes have been hardcoded
elsewhere by renaming the prefixes in this file and seeing what breaks.
""" | |
from enum import auto, Enum | |
from typing import Sequence, Union | |
import sympy | |
class SymT(Enum):
    """Kinds of sympy symbols; each kind has a name prefix in ``prefix_str``.

    Member values are assigned by ``auto()`` in declaration order.
    """

    # Symbolic integer size (prefix "s").
    SIZE = auto()
    # Symbolic float (prefix "zf").
    FLOAT = auto()
    # Unbacked symbolic integer (prefix "u").
    UNBACKED_INT = auto()
    # Unbacked symbolic float (prefix "zuf").
    UNBACKED_FLOAT = auto()
    # Inductor: The intermediates in inner_fn tmp0, one generated per ops call.
    # If one of these shows up in an indexing expression, that means an
    # indirect load is happening.
    TMP = auto()
    # Inductor: Placeholder variable that is later replaced with TMP
    INDIRECT = auto()
    # Inductor: Some size expressions are replaced with a precomputed size ps0
    # which is computed host side, and then directly reused in the kernel, so
    # we don't repeatedly recompute it on device.
    PRECOMPUTED_SIZE = auto()
    # Inductor: An indexing variable i0 in loops IR which ranges over non-reduced
    # dim in the loop
    INDEX = auto()
    # Inductor: A reduction indexing r0 variable in loops IR which ranges over
    # reduced dim in the loop
    RINDEX = auto()
    # Inductor: In templated kernels torch._inductor.kernel, we have a hook to
    # store the final output and append epilogue fusions. To do this, we must
    # know which indexes the outputs range over. NB: These will also
    # advertise as INDEX (their "idx" prefix starts with INDEX's "i"), this
    # is... probably OK?
    TEMPLATE_INDEX = auto()
    # Inductor: iteration domain for blockIdx.x/blockIdx.y
    XBLOCK = auto()
    YBLOCK = auto()
    # Inductor: this is used solely for dynamic_reshape_indexer
    VIEW = auto()
# Maps each SymT to the name prefix used for symbols of that kind. make_symbol
# builds names as prefix + index, and symbol_is_type classifies a symbol with
# str.startswith against these prefixes.
#
# Intended invariant: no prefix should be a prefix of another one, since that
# makes classification ambiguous. Known violation: "i" (INDEX) is a prefix of
# both "idx" (TEMPLATE_INDEX) and "indirect" (INDIRECT), so symbols of those
# kinds also test positive for SymT.INDEX. Check this table before adding a
# new prefix.
prefix_str = {
    SymT.SIZE: "s",  # integer
    SymT.UNBACKED_INT: "u",  # integer
    # Prefix z here is chosen to avoid false aliasing in the symbol_is_type
    # test: a plain "uf" would start with UNBACKED_INT's "u". DO NOT add a
    # "z" type, since "z" would be a prefix of both entries below. You also
    # need to avoid conflicts on these prefixes, but this is somewhat easier
    # to manage.
    SymT.FLOAT: "zf",
    SymT.UNBACKED_FLOAT: "zuf",
    SymT.TMP: "tmp",
    SymT.PRECOMPUTED_SIZE: "ps",
    SymT.INDEX: "i",
    SymT.RINDEX: "r",
    SymT.TEMPLATE_INDEX: "idx",  # also matches INDEX ("i")
    SymT.XBLOCK: "x",
    SymT.YBLOCK: "y",
    SymT.INDIRECT: "indirect",  # also matches INDEX ("i")
    SymT.VIEW: "view",
}
def make_symbol(prefix: SymT, idx: int, **kwargs) -> sympy.Symbol:
    """Create a sympy Symbol named ``<prefix string><idx>``.

    For example, ``make_symbol(SymT.SIZE, 0)`` yields a symbol named ``s0``.
    Any extra keyword arguments are forwarded unchanged to ``sympy.Symbol``.
    """
    # TODO: maybe put the assumptions here directly
    symbol_name = "{}{}".format(prefix_str[prefix], idx)
    return sympy.Symbol(symbol_name, **kwargs)
# ``sym`` is typed as sympy.Basic rather than sympy.Symbol because
# ``free_symbols`` is declared to contain Basic; the assert below narrows it.
def symbol_is_type(sym: sympy.Basic, prefix: Union[SymT, Sequence[SymT]]) -> bool:
    """Return True if the name of ``sym`` starts with the prefix of ``prefix``.

    ``prefix`` may be a single SymT, or a sequence of them, in which case the
    check succeeds if the name starts with any of their prefixes.

    Matching is a plain ``str.startswith``. A prefix that is itself a prefix
    of another therefore aliases: names built for SymT.TEMPLATE_INDEX ("idx")
    or SymT.INDIRECT ("indirect") also match SymT.INDEX ("i").

    Raises:
        AssertionError: if ``sym`` is not a sympy.Symbol (when asserts are on).
    """
    assert isinstance(sym, sympy.Symbol)
    if not isinstance(prefix, SymT):
        # Several kinds: str.startswith accepts a tuple and matches any entry.
        return sym.name.startswith(tuple(prefix_str[kind] for kind in prefix))
    return sym.name.startswith(prefix_str[prefix])
def free_symbol_is_type(e: sympy.Expr, prefix: Union[SymT, Sequence[SymT]]) -> bool:
    """Return True if any free symbol of ``e`` is of the given kind(s).

    ``prefix`` accepts the same forms as ``symbol_is_type``: a single SymT,
    or a sequence of them (a symbol matches if it has any of their prefixes).
    """
    if not isinstance(prefix, SymT):
        # The same prefixes are checked against every free symbol, so
        # materialize them once; a one-shot iterable would otherwise be
        # exhausted after the first symbol.
        prefix = tuple(prefix)
    return any(symbol_is_type(v, prefix) for v in e.free_symbols)