"""Memory-dependency tracking for TorchInductor: defines the MemoryDep /
StarDep / WeakDep dependency types and the helpers that extract the reads,
writes, and index expressions performed by a kernel body."""
import collections | |
import dataclasses | |
import itertools | |
import logging | |
import re | |
import typing | |
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union | |
from unittest.mock import patch | |
import sympy | |
import torch | |
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols | |
from .codegen.common import index_prevent_reordering | |
from .utils import ( | |
get_dtype_size, | |
reduction_num_outputs, | |
sympy_index_symbol, | |
sympy_str, | |
sympy_subs, | |
VarRanges, | |
) | |
from .virtualized import OpsHandler, ReductionType, V | |
log = logging.getLogger(__name__)

# Predicate on a symbol name: true when the name contains "indirect" or "tmp",
# i.e. the symbol came from indirect indexing (see MemoryDep.is_indirect).
is_indirect = re.compile(r"indirect|tmp").search

# Union of all dependency kinds stored in ReadWrites.reads / .writes.
Dep = Union["MemoryDep", "StarDep", "WeakDep"]
class MemoryDep(typing.NamedTuple):
    """A dependency on a specific region of buffer `name`, expressed as a
    symbolic `index` over loop variables `var_names` whose ranges are `size`."""

    name: str
    index: sympy.Expr  # type: ignore[assignment]
    var_names: Tuple[sympy.Symbol, ...]
    size: Tuple[sympy.Expr, ...]

    def __repr__(self):
        return f"MemoryDep({self.name!r}, {self.index}, {self.ranges})"

    @property
    def ranges(self) -> Dict[sympy.Symbol, sympy.Expr]:
        """{c0: 128, c1: 512, ...}"""
        # @property so that __repr__'s f-string renders the dict, not a bound method.
        return dict(zip(self.var_names, self.size))

    def get_numel(self) -> sympy.Expr:
        """Number of elements touched: product of the ranges of the variables
        that actually appear in the index.  An indirect access could touch
        anything, so fall back to the whole buffer's numel in that case."""
        if self.is_indirect():
            numel = V.graph.get_numel(self.name)
        else:
            vars = set(self.index.free_symbols)
            numel = sympy.Integer(1)
            for var, size in zip(self.var_names, self.size):
                if var in vars:
                    numel = numel * size
        return numel

    def rename(self, renames: Dict[str, str]) -> "MemoryDep":
        """Return a copy pointing at the renamed buffer, or self if unaffected."""
        if self.name in renames:
            return MemoryDep(
                renames[self.name], self.index, var_names=self.var_names, size=self.size
            )
        return self

    def numbytes_hint(self):
        # Best-effort byte count: hinted element count times element size.
        return V.graph.sizevars.size_hint(self.get_numel()) * get_dtype_size(
            V.graph.get_dtype(self.name)
        )

    def has_unbacked_symbols(self):
        return len(free_unbacked_symbols(self.get_numel())) > 0

    def is_contiguous(self) -> bool:
        # Contiguous iff the index is exactly one of the iteration variables.
        return isinstance(self.index, sympy.Symbol) and self.index in self.var_names

    def is_scalar(self) -> bool:
        # A single fixed element: either a symbol outside the loop vars
        # (and not indirect), or a constant integer index.
        if isinstance(self.index, sympy.Symbol):
            return self.index not in self.var_names and not self.is_indirect()
        return isinstance(self.index, (int, sympy.Integer))

    def is_indirect(self) -> bool:
        # Uses the module-level `is_indirect` regex matcher on each free symbol.
        return any(is_indirect(v.name) for v in self.index.free_symbols)  # type: ignore[attr-defined]
class StarDep(typing.NamedTuple):
    """A dependency on the entirety of buffer `name`, with no specific index."""

    name: str

    def index(self):
        raise NotImplementedError("StarDep does not have an index")

    def get_numel(self) -> sympy.Expr:
        # The whole buffer is considered touched.
        return V.graph.get_numel(self.name)

    def rename(self, renames: Dict[str, str]) -> "StarDep":
        if self.name not in renames:
            return self
        return StarDep(renames[self.name])

    def numbytes_hint(self):
        numel_hint = V.graph.sizevars.size_hint(self.get_numel())
        return numel_hint * get_dtype_size(V.graph.get_dtype(self.name))

    def has_unbacked_symbols(self):
        return bool(free_unbacked_symbols(self.get_numel()))

    def is_contiguous(self) -> bool:
        return False

    def is_scalar(self) -> bool:
        return False

    def is_indirect(self) -> bool:
        return False
# Used for tracking mutation ordering:
# if A reads a buffer and B mutates it, B must be ordered after A.
# It is weak because if it turns out A's read is never used, we can
# still eliminate it.
class WeakDep(typing.NamedTuple):
    """An ordering-only dependency on buffer `name` (no data is actually read)."""

    name: str

    def index(self):
        raise NotImplementedError("WeakDep does not have an index")

    def get_numel(self) -> sympy.Expr:
        # Ordering-only: treated as touching a single element.
        return sympy.Integer(1)

    def rename(self, renames: Dict[str, str]) -> "WeakDep":
        if self.name not in renames:
            return self
        return WeakDep(renames[self.name])

    def numbytes_hint(self):
        return 1  # Purely inserted for ordering, not an actual dep

    def has_unbacked_symbols(self):
        return False

    def is_contiguous(self) -> bool:
        return False
class IndexExprDep(typing.NamedTuple):
    # Records an index expression used as a *value* (via ops.index_expr) rather
    # than as a buffer address; canonicalized the same way as MemoryDep but with
    # no buffer name attached.  Produced by _RecordLoadStoreInner.index_expr.
    index: sympy.Expr  # type: ignore[assignment]
    var_names: Tuple[sympy.Symbol, ...]
    size: Tuple[sympy.Expr, ...]
@dataclasses.dataclass
class ReadWrites:
    """The set of memory dependencies (reads, writes, index expressions)
    extracted from one kernel body, plus the iteration variables they range
    over and a tally of the ops used.

    NOTE: the @dataclasses.dataclass decorator is required — the class has no
    __init__, uses dataclasses.field, and is constructed positionally below.
    """

    reads: Set[Dep]
    writes: Set[Dep]
    index_exprs: Set[IndexExprDep]
    range_vars: Optional[List[sympy.Expr]] = None
    var_ranges: Optional[VarRanges] = None
    # How many times each op (load, store, add, ...) appeared while tracing.
    op_counts: typing.Counter[str] = dataclasses.field(
        default_factory=collections.Counter
    )

    def rename(self, renames: typing.Dict[str, str]) -> "ReadWrites":
        """Apply buffer renames to every read/write dep."""
        return ReadWrites(
            {dep.rename(renames) for dep in self.reads},
            {dep.rename(renames) for dep in self.writes},
            self.index_exprs,
            self.range_vars,
            self.var_ranges,
            op_counts=self.op_counts,
        )

    def with_read(self, dep: Dep) -> "ReadWrites":
        """Return a copy with one extra whole-buffer/ordering read added."""
        assert isinstance(dep, (WeakDep, StarDep))
        return ReadWrites(
            set.union(self.reads, {dep}),
            self.writes,
            self.index_exprs,
            self.range_vars,
            self.var_ranges,
            op_counts=self.op_counts,
        )

    def merge(self, other: "ReadWrites"):
        """Combine two ReadWrites; anything written is removed from reads."""
        reads = set.union(self.reads, other.reads)
        writes = set.union(self.writes, other.writes)
        index_exprs = set.union(self.index_exprs, other.index_exprs)
        op_counts = collections.Counter(self.op_counts)
        op_counts.update(other.op_counts)
        return ReadWrites(reads - writes, writes, index_exprs, op_counts=op_counts)

    @staticmethod
    def merge_list(read_writes: List["ReadWrites"]):
        # Missing @staticmethod restored: the first parameter is the list of
        # ReadWrites to merge, not self.  Like merge(), but n-way; drops
        # range_vars/var_ranges (they are per-kernel).
        all_writes = set.union(*[rw.writes for rw in read_writes])
        all_reads = set.union(*[rw.reads for rw in read_writes]) - all_writes
        all_index_exprs = set.union(*[rw.index_exprs for rw in read_writes])
        op_counts: typing.Counter[Any] = collections.Counter()
        for rw in read_writes:
            op_counts.update(rw.op_counts)
        return ReadWrites(all_reads, all_writes, all_index_exprs, op_counts=op_counts)

    def remove_reads(self, rem_reads):
        """Return a copy with the given read deps removed."""
        return ReadWrites(
            self.reads - rem_reads,
            self.writes,
            self.index_exprs,
            self.range_vars,
            self.var_ranges,
            op_counts=self.op_counts,
        )

    def reads_and_writes(self):
        """Iterate all deps, reads first."""
        return itertools.chain(self.reads, self.writes)
class _RecordLoadStoreInner(V.MockHandler):  # type: ignore[name-defined]
    """Ops handler that records every load/store/index_expr issued while
    tracing a kernel body as MemoryDep/StarDep/IndexExprDep entries
    (collected later by extract_read_writes)."""

    def __init__(self, var_ranges: VarRanges, normalize: bool):
        super().__init__()
        self._reads: Set[Dep] = set()
        self._writes: Set[MemoryDep] = set()
        self._index_exprs: Set[IndexExprDep] = set()
        self._var_ranges: VarRanges = var_ranges
        # When True, canonicalize() re-numbers loop variables so that
        # equivalent indexing formulas compare equal across kernels.
        self._normalize: bool = normalize

    def canonicalize(
        self, index: sympy.Expr
    ) -> Tuple[sympy.Expr, Tuple[sympy.Symbol, ...], Tuple[sympy.Expr, ...]]:
        # Returns (index, var_names, sizes) with size-1 dimensions dropped;
        # when normalizing, also simplifies the loops and re-numbers the
        # variables with the canonicalization prefix ("c0", "c1", ...).
        if not self._normalize:
            sizes = [V.graph.sizevars.simplify(x) for x in self._var_ranges.values()]
            # Size-1 dimensions cannot affect the index; drop them.
            var_names = tuple(
                k for k, v in zip(self._var_ranges.keys(), sizes) if v != 1
            )
            sizes = tuple(v for v in sizes if v != 1)
            return index, var_names, sizes  # type: ignore[return-value]
        # Try to further simplify the indexes even if simplify_loops didn't
        # convert it to the simplest form because of the interference from
        # different indexing formulas.
        free_symbols = index.free_symbols
        var_ranges = {
            k: V.graph.sizevars.simplify(v)
            for k, v in self._var_ranges.items()
            # TODO(jansel): explore this further normalization
            # if k in free_symbols
        }
        index_vars = [*var_ranges.keys()]
        sizes = tuple(var_ranges.values())
        new_sizes, reindex, prune = V.graph.sizevars._simplify_loops(
            index_vars,
            sizes,
            index_prevent_reordering([index], index_vars, sizes),
        )
        # assign new variables each dimension to deal with numbering mismatches
        # d0, d1, d2 could become d0, d2 -- which won't match d0, d1
        new_vars, add_var = var_builder(canonicalization_prefix())
        replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes])))
        index = sympy_subs(sympy.expand(index), replacement)
        new_vars = [*new_vars.keys()]
        new_sizes = [*new_sizes]
        free_symbols = index.free_symbols
        while new_vars and new_vars[-1] not in free_symbols:
            # Reduction has last (reduced) dim in its sizes, but
            # downstream users won't. Normalize this away.
            new_vars.pop()
            new_sizes.pop()
        return index, tuple(new_vars), tuple(new_sizes)  # type: ignore[arg-type]

    def load(self, name: str, index: sympy.Expr) -> str:
        # Record a read of buffer `name` at `index`.
        self._reads.add(MemoryDep(name, *self.canonicalize(index)))
        return f"load({name}, {sympy_str(index)})"

    def load_seed(self, name: str, index: int):
        # Seeds are loaded at a fixed integer offset.
        assert isinstance(index, int)
        return self.load(name, sympy.Integer(index))

    def store(self, name: str, index: sympy.Expr, value: str, mode=None) -> str:
        # Record a write to buffer `name` at `index`.
        self._writes.add(MemoryDep(name, *self.canonicalize(index)))
        return f"store({name}, {sympy_str(index)}, {value}, {mode})"

    def store_reduction(self, name: str, index, value) -> str:
        return self.store(name, index, f"store_reduction({value})")

    def index_expr(self, index: sympy.Expr, dtype) -> str:
        # Record `index` being used as a value rather than as an address.
        self._index_exprs.add(IndexExprDep(*self.canonicalize(index)))
        return f"index_expr({sympy_str(index)}, {dtype})"

    def bucketize(
        self,
        values,
        offsets_name: str,
        offsets_size: sympy.Expr,
        indexing_dtype: torch.dtype,
        right: bool,
    ):
        # The offsets buffer is searched as a whole, so record a StarDep
        # (dependency on the entire buffer) rather than a MemoryDep.
        self._reads.add(StarDep(offsets_name))
        return f"bucketize({values}, {offsets_name}, {sympy_str(offsets_size)}, {indexing_dtype}, {right})"
class _OpCounter: | |
"""Shim to count how many times each op is used""" | |
def __init__(self, inner): | |
super().__init__() | |
self.parent_handler = inner | |
self._op_counts: typing.Counter[Any] = collections.Counter() | |
def __getattr__(self, name): | |
self._op_counts[name] += 1 | |
return getattr(self.parent_handler, name) | |
class RecordLoadStore(V.KernelFormatterHandler):  # type: ignore[name-defined]
    """Handler stack used by extract_read_writes: a _RecordLoadStoreInner
    (records deps) wrapped in an _OpCounter (tallies op usage), installed as
    the parent of V.KernelFormatterHandler."""

    def __init__(self, var_ranges: VarRanges, normalize: bool):
        parent_handler = _RecordLoadStoreInner(
            var_ranges=var_ranges, normalize=normalize
        )
        parent_handler = _OpCounter(parent_handler)
        super().__init__(parent_handler=parent_handler)
def var_builder(prefix: str) -> Tuple[VarRanges, Callable[[sympy.Expr], sympy.Symbol]]:
    """Return (var_ranges, add_var), where add_var mints fresh index symbols
    named `{prefix}0`, `{prefix}1`, ... and records each symbol's range in
    the shared var_ranges dict."""
    counter = itertools.count()
    var_ranges: VarRanges = dict()

    def add_var(length: sympy.Expr) -> sympy.Symbol:
        sym = sympy_index_symbol(f"{prefix}{next(counter)}")
        var_ranges[sym] = length
        return sym

    return var_ranges, add_var
def index_vars_no_squeeze(*argsizes: Tuple[sympy.Expr, ...], prefix: str):
    """Create one fresh index variable per dimension of each argsize
    (size-1 dimensions are kept, unlike index_vars_squeeze)."""
    var_ranges, add_var = var_builder(prefix)
    args: List[List[sympy.Symbol]] = [
        [add_var(dim) for dim in size] for size in argsizes
    ]
    return args, var_ranges
def index_vars_squeeze(*argsizes: Tuple[sympy.Expr, ...], prefix: str = "d"):
    """Like index_vars_no_squeeze, but squeeze out size-1 dimensions first
    (via SqueezeView) and reindex the fresh variables back to the original
    rank."""
    from .ir import SqueezeView

    var_ranges, add_var = var_builder(prefix)
    args: List[List[sympy.Expr]] = []
    new_sizes: List[List[sympy.Expr]] = []  # accumulated but not returned
    for size in argsizes:
        squeezed_size, reindex = SqueezeView.squeezer(size)
        new_sizes.append(squeezed_size)
        args.append(reindex([add_var(dim) for dim in squeezed_size]))
    return args, var_ranges
def extract_read_writes(
    fn: Callable[..., Any],
    *argsizes: Tuple[sympy.Expr, ...],
    normalize: bool = False,
    prefix: str = "d",
):
    """Trace `fn` over fresh index variables for `argsizes` and collect the
    resulting ReadWrites (reads, writes, index exprs, op counts)."""
    args, var_ranges = index_vars_squeeze(*argsizes, prefix=prefix)
    recorder = RecordLoadStore(var_ranges, normalize=normalize)
    with V.set_ops_handler(recorder):
        fn(*args)

    if normalize:
        range_vars = []  # Number of vars could differ due to normalization
    else:
        range_vars = [v for arg in args for v in arg]

    inner = recorder.parent_handler.parent_handler
    return ReadWrites(
        set(inner._reads),
        set(inner._writes),
        inner._index_exprs,
        range_vars,
        var_ranges,
        recorder.parent_handler._op_counts,
    )
def extract_input_node_reduction_ranges(
    input_node: "torch._inductor.ir.TensorBox",
) -> Tuple[Optional[List[sympy.Expr]], Optional[List[sympy.Expr]]]:
    """
    Returns the size and reduction size of all inputs, if the sizes and reduction_sizes (if exist) are all the same.
    It's possible that a node has multiple inputs, some are Reduction nodes and others are Pointwise nodes.
    In this case, reduction_sizes of the Reduction nodes need to be the same.
    Otherwise returns (None, None).
    """
    from .ir import ComputedBuffer, Loops

    if isinstance(input_node.data, ComputedBuffer):
        # Input node has already been realized. Return its size and reduction_size.
        size = input_node.get_size()
        reduction_size = input_node.get_reduction_size()
        if len(reduction_size) > 0:
            return (size, reduction_size)
        else:
            # Realized but not a reduction: nothing to report.
            return (None, None)

    if not isinstance(input_node.data.data, Loops):  # type: ignore[attr-defined]
        # Other IRNodes do not have reduction_ranges.
        return (None, None)

    # There is one issue: what if there are views / permutations between the input node and its dependent realized nodes?
    # The current method still uses reduction ranges from the dependent realized node, which is not ideal.
    # Is there a way to check whether there are permutations inbetween?
    reads = input_node.get_reads()
    reduction_size = None
    size = None
    # Walk the read frontier level by level until a realized reduction
    # buffer is found, shapes conflict, or the frontier stops growing.
    while reduction_size is None and len(reads) > 0:
        seen = set()
        new_reads = []
        for read in reads:
            if not isinstance(read, MemoryDep):
                # StarDep/WeakDep carry no index information; skip.
                continue
            if read.name in seen:
                continue
            seen.add(read.name)
            buffer = V.graph.get_buffer(read.name)
            if buffer is None:
                continue
            if (
                isinstance(buffer, ComputedBuffer)
                and len(buffer.get_reduction_size()) > 0
            ):
                if reduction_size is None:
                    reduction_size = buffer.get_reduction_size()
                    size = buffer.get_size()
                elif (
                    reduction_size != buffer.get_reduction_size()
                    or size != buffer.get_size()
                ):
                    # Conflicting reduction shapes among the inputs: give up.
                    return (None, None)
            else:
                # Not a realized reduction; expand the frontier with its reads.
                new_reads.extend(buffer.get_reads())
        if reads == new_reads:
            # No progress was made; stop to avoid looping forever.
            return (size, reduction_size)
        else:
            reads = new_reads
    return (size, reduction_size)
def canonicalization_prefix():
    """Prefix for the fresh variables minted during index canonicalization
    (see _RecordLoadStoreInner.canonicalize)."""
    return "c"
# ops handler which computes all the free unbacked symbols for an IR | |
class FreeUnbackedSymbolsOpsHandler: | |
symbols: Set[sympy.Symbol] | |
def __init__(self): | |
self.symbols = set() | |
def __getattr__(self, name: str) -> Callable[..., Any]: | |
def inner(*args, **kwargs): | |
for a in itertools.chain(args, kwargs.values()): | |
if isinstance(a, (sympy.Expr, sympy.logic.boolalg.Boolean)): | |
self.symbols |= free_unbacked_symbols(a) | |
return inner | |
def indirect_indexing(self, index_var, size, check=True) -> sympy.Symbol: | |
assert not isinstance(index_var, (sympy.Expr, sympy.logic.boolalg.Boolean)) | |
self.symbols |= free_unbacked_symbols(size) | |
return sympy_index_symbol(f"({str(index_var)})") | |
def frexp(self, x): | |
return (None,) * 2 | |
def reduction( | |
self, | |
dtype: torch.dtype, | |
src_dtype: torch.dtype, | |
reduction_type: ReductionType, | |
value: Union[None, Tuple[None, ...]], | |
) -> Union[None, Tuple[None, ...]]: | |
num_values = reduction_num_outputs(reduction_type) | |
return (None,) * num_values if num_values > 1 else None | |
def _typecheck_FreeUnbackedSymbolsOpsHandler(
    h: FreeUnbackedSymbolsOpsHandler,
) -> OpsHandler[None]:
    # Static-typing aid only: lets a type checker verify that
    # FreeUnbackedSymbolsOpsHandler structurally satisfies OpsHandler[None].
    # Never called for its runtime effect.
    return h
def extract_free_unbacked_symbols(fn: Callable[..., Any], index, rindex=None):
    """Trace `fn` under FreeUnbackedSymbolsOpsHandler and return the set of
    free unbacked symbols it encountered."""
    from .ir import FlexibleLayout

    call_args = [index] if rindex is None else [index, rindex]
    handler = FreeUnbackedSymbolsOpsHandler()
    # NB: I cargo culted the allow_indexing patch here, I don't understand why
    # people do this all over
    with V.set_ops_handler(handler), patch.object(
        FlexibleLayout, "allow_indexing", True
    ):
        fn(*call_args)
    return handler.symbols