|
|
|
from __future__ import annotations |
|
|
|
import collections |
|
import contextlib |
|
import dataclasses |
|
import enum |
|
import functools |
|
import inspect |
|
import io |
|
import itertools |
|
import json |
|
import logging |
|
import math |
|
import operator |
|
import os |
|
import platform |
|
import shutil |
|
import sys |
|
import tempfile |
|
import textwrap |
|
import time |
|
import unittest |
|
from datetime import datetime |
|
from io import StringIO |
|
from pathlib import Path |
|
from typing import ( |
|
Any, |
|
Callable, |
|
Dict, |
|
Generic, |
|
Iterable, |
|
List, |
|
NamedTuple, |
|
Optional, |
|
Protocol, |
|
Set, |
|
Tuple, |
|
TypeVar, |
|
Union, |
|
ValuesView, |
|
) |
|
from typing_extensions import Concatenate, ParamSpec |
|
from unittest import mock |
|
|
|
import sympy |
|
|
|
import torch |
|
import torch._export |
|
import torch.utils._pytree as pytree |
|
from torch._dynamo.device_interface import get_interface_for_device |
|
from torch._dynamo.utils import detect_fake_mode |
|
from torch.autograd import DeviceType |
|
from torch.autograd.profiler_util import EventList |
|
from torch.fx.passes.shape_prop import ShapeProp |
|
from torch.utils._sympy.functions import CeilDiv, CleanDiv, FloorDiv, ModularIndexing |
|
from torch.utils._sympy.symbol import make_symbol, SymT |
|
from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges |
|
from . import config |
|
from .runtime.runtime_utils import cache_dir, ceildiv as runtime_ceildiv |
|
|
|
log = logging.getLogger(__name__) |
|
|
|
_T = TypeVar("_T") |
|
VarRanges = Dict[sympy.Expr, sympy.Expr] |
|
|
|
GPU_ALIGN_BYTES = 16 |
|
|
|
ALIGN_BYTES = 64 |
|
assert (ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0 and ALIGN_BYTES >= 8, "must be a power of 2 and at least 8"
|
|
|
|
|
def _align(nbytes): |
|
"""Round up to the nearest multiple of ALIGN_BYTES""" |
|
return (nbytes + ALIGN_BYTES - 1) & -ALIGN_BYTES |
|
|
|
|
|
def _is_aligned(v: sympy.Expr): |
|
"""v can be statically proven to be a multiple of ALIGN_BYTES""" |
|
if isinstance(v, (sympy.Add, sympy.Max)): |
|
return all(map(_is_aligned, v.args)) |
|
return isinstance(v, align) or sympy.gcd(v, ALIGN_BYTES) == ALIGN_BYTES |
|
|
|
|
|
class align(sympy.Function): |
|
"""Symbolically round up to the nearest multiple of ALIGN_BYTES""" |
|
|
|
nargs = (1,) |
|
is_integer = True |
|
|
|
@classmethod |
|
def eval(cls, value): |
|
if isinstance(value, (int, sympy.Integer)): |
|
return _align(int(value)) |
|
if _is_aligned(value): |
|
return value |
|
|
|
|
|
def do_bench_using_profiling(fn: Callable[[], Any], warmup=25, rep=100) -> float: |
|
""" |
|
Returns benchmark results by examining torch profiler events. |
|
    This can be more accurate as it does not count CPU-side overhead.

    However, it also requires manually excluding irrelevant events, e.g.

    the vectorized_elementwise_kernel used to fill the L2 cache and

    various CUDA events, so it can also be fragile.
|
""" |
|
|
|
fn() |
|
torch.cuda.synchronize() |
|
cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda") |
|
|
|
|
|
start_event = torch.cuda.Event(enable_timing=True) |
|
end_event = torch.cuda.Event(enable_timing=True) |
|
start_event.record() |
|
for _ in range(5): |
|
cache.zero_() |
|
fn() |
|
end_event.record() |
|
torch.cuda.synchronize() |
|
estimate_ms = start_event.elapsed_time(end_event) / 5 |
|
|
|
|
|
n_warmup = max(1, int(warmup / estimate_ms)) |
|
n_repeat = max(1, int(rep / estimate_ms)) |
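    # scale the iteration counts so the warmup and measurement phases run for
    # roughly `warmup` and `rep` milliseconds respectively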
|
|
|
|
|
for _ in range(n_warmup): |
|
fn() |
|
|
|
with torch.profiler.profile( |
|
activities=[ |
|
torch.profiler.ProfilerActivity.CUDA, |
|
] |
|
) as p: |
|
|
|
for i in range(n_repeat): |
|
|
|
cache.zero_() |
|
|
|
fn() |
|
|
|
torch.cuda.synchronize() |
|
|
|
log.debug("raw events") |
|
log.debug(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) |
|
|
|
filtered_events = EventList( |
|
[ |
|
event |
|
for event in p.events() |
|
if event.device_type == DeviceType.CUDA and event.name != "Context Sync" |
|
] |
|
) |
|
if len(filtered_events) % n_repeat != 0: |
|
        raise RuntimeError(

            "Failed to divide all profiling events into #repeat groups. "

            f"#CUDA events: {len(filtered_events)}, #repeats: {n_repeat}"

        )
|
    num_event_per_group = len(filtered_events) // n_repeat
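    # the first event in each group is the cache.zero_() fill kernel used to
    # flush the L2 cache; drop it so only the workload's kernels are measured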
|
actual_events = EventList( |
|
[ |
|
event |
|
for i, event in enumerate(filtered_events) |
|
if i % num_event_per_group != 0 |
|
] |
|
) |
|
actual_events._build_tree() |
|
actual_events = actual_events.key_averages() |
|
|
|
log.debug("profiling time breakdown") |
|
log.debug(actual_events.table(row_limit=-1)) |
|
|
|
res = sum(event.device_time_total for event in actual_events) / 1000.0 / n_repeat |
|
log.debug("profiling results: %s ms", res) |
|
return res |
|
|
|
|
|
@functools.lru_cache(None) |
|
def has_torchvision_roi_align() -> bool: |
|
try: |
|
from torchvision.ops import roi_align |
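        # querying the dispatcher raises RuntimeError("torchvision::nms does not exist")
        # when the torchvision C++ ops are not registered; that case is handled below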
|
|
|
torch._C._dispatch_has_kernel_for_dispatch_key("torchvision::nms", "Meta") |
|
return roi_align is not None and hasattr( |
|
getattr(torch.ops, "torchvision", None), "roi_align" |
|
) |
|
except ImportError: |
|
return False |
|
except RuntimeError as e: |
|
assert "torchvision::nms does not exist" in str(e) |
|
return False |
|
|
|
|
|
def decode_device(device: Union[Optional[torch.device], str]) -> torch.device: |
|
if device is None: |
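        # fall back to the device a freshly created tensor lands on (the current default device)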
|
return torch.tensor(0.0).device |
|
if isinstance(device, str): |
|
device = torch.device(device) |
|
if device.type not in ("cpu", "meta") and device.index is None: |
|
device_interface = get_interface_for_device(device.type) |
|
return torch.device(device.type, index=device_interface.Worker.current_device()) |
|
return device |
|
|
|
|
|
def sympy_product(it): |
|
return functools.reduce(operator.mul, it, sympy.Integer(1)) |
|
|
|
|
|
def sympy_dot(seq1, seq2): |
|
assert len(seq1) == len(seq2) |
|
return sympy.expand(sum(a * b for a, b in zip(seq1, seq2))) |
|
|
|
|
|
def unique(it: Iterable[_T]) -> ValuesView[_T]: |
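    # dedupe by object identity while preserving insertion order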
|
return {id(x): x for x in it}.values() |
|
|
|
|
|
def ceildiv( |
|
numer: Union[int, sympy.Expr], denom: Union[int, sympy.Expr] |
|
) -> Union[int, sympy.Expr]: |
|
if isinstance(numer, sympy.Expr) or isinstance(denom, sympy.Expr): |
|
return CeilDiv(sympy.sympify(numer), sympy.sympify(denom)) |
|
|
|
|
|
|
|
assert isinstance(numer, int) and isinstance( |
|
denom, int |
|
), f"{numer}: {type(numer)}, {denom}: {type(denom)}" |
|
return runtime_ceildiv(numer, denom) |
|
|
|
|
|
def _type_of(key): |
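    # Map a torch dtype (or a dtype-like key) to the Triton pointer type string
    # used in kernel signatures; None maps to *i8.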
|
|
|
|
|
|
|
|
|
if key is None: |
|
return "*i8" |
|
dtype_str = str(key).split(".")[-1] |
|
tys = { |
|
"bool": "i1", |
|
"float8e4nv": "fp8e4nv", |
|
"float8e5": "fp8e5", |
|
"float8e4b15": "fp8e4b15", |
|
"float8e4b15x4": "fp8e4b15x4", |
|
"float8_e4m3fn": "fp8e4nv", |
|
"float8_e5m2": "fp8e5", |
|
"float16": "fp16", |
|
"bfloat16": "bf16", |
|
"float32": "fp32", |
|
"float64": "fp64", |
|
"int8": "i8", |
|
"int16": "i16", |
|
"int32": "i32", |
|
"int64": "i64", |
|
"uint8": "u8", |
|
"uint16": "u16", |
|
"uint32": "u32", |
|
"uint64": "u64", |
|
} |
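    # also accept Triton type names directly, e.g. "fp16" -> "fp16"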
|
|
|
for v in list(tys.values()): |
|
tys[v] = v |
|
return key if isinstance(key, str) else f"*{tys[dtype_str]}" |
|
|
|
|
|
def convert_shape_to_inductor( |
|
lst: Iterable[Union[int, torch.SymInt]] |
|
) -> List[sympy.Expr]: |
|
""" |
|
    Converts a sequence of ints/SymInts (e.g. a shape or stride) into sympy

    expressions. For static values this is trivial, but symbolic values need

    to be mapped from their SymIntNode into the underlying sympy.Expr.
|
""" |
|
return [ |
|
i.node.expr if isinstance(i, torch.SymInt) else sympy.Integer(i) for i in lst |
|
] |
|
|
|
|
|
def convert_shape_to_symint( |
|
lst: Iterable[Union[int, sympy.Expr]] |
|
) -> List[Union[int, torch.SymInt]]: |
|
""" |
|
Takes a list of shapes from Inductor and converts them into symints (or just |
|
ints if all shapes are static). |
|
""" |
|
from .virtualized import V |
|
|
|
return [ |
|
i |
|
if isinstance(i, int) |
|
else int(i) |
|
if isinstance(i, sympy.Integer) |
|
else V.graph.sizevars.shape_env.create_symintnode(i, hint=None) |
|
for i in lst |
|
] |
|
|
|
|
|
def is_view(op: torch._ops.OpOverload): |
|
""" |
|
    Does this op overload have aliasing (i.e. can it return a view of its input)?
|
""" |
|
assert isinstance(op, torch._ops.OpOverload) |
|
return any(a.alias_info is not None for a in op._schema.arguments) |
|
|
|
|
|
def is_pointwise_use(use): |
|
if not use.op == "call_function": |
|
return False |
|
|
|
if not ( |
|
isinstance(use.target, torch._ops.OpOverload) or use.target is operator.getitem |
|
): |
|
return False |
|
|
|
if use.target is operator.getitem or is_view(use.target): |
|
return all(is_pointwise_use(u) for u in use.users) |
|
|
|
return torch.Tag.pointwise in use.target.tags |
|
|
|
|
|
def gen_gm_and_inputs(target, args, kwargs): |
|
g = torch.fx.Graph() |
|
g_args = [] |
|
a_args = [] |
|
for n, arg in enumerate(args): |
|
if isinstance(arg, torch.Tensor): |
|
g_args.append(g.placeholder(f"arg{n}")) |
|
a_args.append(arg) |
|
else: |
|
g_args.append(arg) |
|
assert all(not isinstance(x, torch.Tensor) for x in kwargs.values()) |
|
node = g.call_function(target, tuple(g_args), kwargs) |
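    # ops that return a single Tensor are wrapped in a tuple so the graph
    # output always has tuple structure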
|
if ( |
|
len(target._schema.returns) == 1 |
|
and str(target._schema.returns[0].type) == "Tensor" |
|
): |
|
node = (node,) |
|
g.output(node) |
|
|
|
gm = torch.fx.GraphModule({}, g) |
|
return gm, a_args |
|
|
|
|
|
def synchronize(device: str = "cuda"): |
|
if device == "cpu": |
|
return |
|
device_interface = get_interface_for_device(device) |
|
if device_interface.is_available(): |
|
device_interface.synchronize() |
|
|
|
|
|
def timed( |
|
model: Callable[..., Any], example_inputs, times: int = 1, device: str = "cuda" |
|
) -> float: |
|
synchronize(device) |
|
torch.manual_seed(1337) |
|
t0 = time.perf_counter() |
|
for _ in range(times): |
|
result = model(*example_inputs) |
|
synchronize(device) |
|
t1 = time.perf_counter() |
|
|
|
assert result is not None |
|
return t1 - t0 |
|
|
|
|
|
def print_performance( |
|
fn, args=(), times=10, repeat=10, baseline=1.0, device: str = "cuda" |
|
): |
|
timings = torch.tensor([timed(fn, args, times, device) for _ in range(repeat)]) |
|
took = torch.median(timings) / times |
|
print(f"{took / baseline:.6f}") |
|
return took |
|
|
|
|
|
def precompute_method(obj: Any, method: str): |
|
"""Replace obj.method() with a new method that returns a precomputed constant.""" |
|
result = getattr(obj, method)() |
|
setattr(obj, method, lambda: result) |
|
|
|
|
|
def precompute_methods(obj: Any, methods: List[str]): |
|
"""Replace methods with new methods that returns a precomputed constants.""" |
|
for method in methods: |
|
precompute_method(obj, method) |
|
|
|
|
|
def cmp(a, b) -> int: |
|
return int(a > b) - int(a < b) |
|
|
|
|
|
def pad_listlike(x, size): |
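    # broadcast a length-1 list/tuple to `size`, e.g. pad_listlike((3,), 2) -> (3, 3)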
|
if len(x) == 1: |
|
return type(x)([x[0]]) * size |
|
else: |
|
return x |
|
|
|
|
|
|
|
def tuple_sorted(x): |
|
if len(x) == 0: |
|
return [] |
|
|
|
def sort_func(elem): |
|
if isinstance(elem, str): |
|
return elem |
|
else: |
|
|
|
|
|
return elem.get_name() |
|
|
|
return sorted(x, key=sort_func) |
|
|
|
|
|
P = ParamSpec("P") |
|
RV = TypeVar("RV", covariant=True) |
|
|
|
|
|
class CachedMethod(Protocol, Generic[P, RV]): |
|
@staticmethod |
|
def clear_cache(self) -> None: |
|
... |
|
|
|
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> RV: |
|
... |
|
|
|
|
|
|
|
def cache_on_self(fn: Callable[Concatenate[Any, P], RV]) -> CachedMethod[P, RV]: |
|
key = f"__{fn.__name__}_cache" |
|
|
|
@functools.wraps(fn) |
|
def wrapper(self): |
|
if not hasattr(self, key): |
|
setattr(self, key, fn(self)) |
|
return getattr(self, key) |
|
|
|
def clear_cache(self): |
|
if hasattr(self, key): |
|
delattr(self, key) |
|
|
|
wrapper.clear_cache = clear_cache |
|
return wrapper |
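
# Illustrative usage of cache_on_self (hypothetical example, not part of this module):
#
#   class Foo:
#       @cache_on_self
#       def expensive(self) -> int:
#           ...  # computed once per instance; the result is stored on `self`
#
#   Foo.expensive.clear_cache(instance)  # drops the cached value for `instance`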
|
|
|
|
|
def aggregate_origins(node_schedule): |
|
from . import ir |
|
|
|
if isinstance(node_schedule, list): |
|
return functools.reduce( |
|
operator.or_, |
|
[ |
|
node.node.origins |
|
for node in node_schedule |
|
if hasattr(node, "node") and node.node |
|
], |
|
set(), |
|
) |
|
elif isinstance(node_schedule, ir.ExternKernel): |
|
return node_schedule.origins |
|
else: |
|
return set() |
|
|
|
|
|
def get_fused_kernel_name(node_schedule, descriptive_names): |
|
all_origins = aggregate_origins(node_schedule) |
|
if descriptive_names == "original_aten": |
|
|
|
sources = [ |
|
origin.meta["original_aten"]._overloadpacket.__name__ |
|
for origin in all_origins |
|
if origin.op == "call_function" |
|
and "original_aten" in origin.meta |
|
and origin.meta["original_aten"] is not None |
|
] |
|
sources = sorted(set(sources)) |
|
elif descriptive_names == "torch": |
|
|
|
sources = [] |
|
for origin in all_origins: |
|
if origin.op == "call_function" and "source_fn_stack" in origin.meta: |
|
source_fn = origin.meta["source_fn_stack"][-1] |
|
if isinstance(source_fn[1], str): |
|
sources.append(source_fn[1]) |
|
else: |
|
sources.append(source_fn[1].__name__) |
|
sources = sorted(set(sources)) |
|
elif descriptive_names == "inductor_node": |
|
sources = [ |
|
origin.name for origin in all_origins if origin.op == "call_function" |
|
] |
|
else: |
|
raise NotImplementedError |
|
|
return "_".join(["fused"] + sources) |
|
|
|
|
|
def get_kernel_metadata(node_schedule, wrapper): |
|
all_origins = aggregate_origins(node_schedule) |
|
inductor_nodes = [origin for origin in all_origins if origin.op == "call_function"] |
|
|
|
from_node_dict = collections.defaultdict(list) |
|
original_aten_dict = collections.defaultdict(list) |
|
for node in inductor_nodes: |
|
if "original_aten" in node.meta and node.meta["original_aten"] is not None: |
|
key = str(node.meta["original_aten"]._overloadpacket) |
|
original_aten_dict[key].append(node.name) |
|
if "from_node" in node.meta: |
|
key = node.meta["from_node"][0][0] |
|
from_node_dict[key].append(node.name) |
|
metadata = ( |
|
f"{wrapper.comment} Source Nodes: [{', '.join(sorted(from_node_dict.keys()))}], " |
|
f"Original ATen: [{', '.join(sorted(original_aten_dict.keys()))}]" |
|
) |
|
|
|
detailed_metadata = [] |
|
for original_node, nodes in sorted(from_node_dict.items()): |
|
detailed_metadata.append( |
|
f"{wrapper.comment} {original_node} => {', '.join(sorted(nodes))}" |
|
) |
|
return metadata, "\n".join(detailed_metadata) |
|
|
|
|
|
def dominated_nodes( |
|
initial_queue: Iterable[torch.fx.Node], skip_filter=None |
|
) -> Set[torch.fx.Node]: |
|
"""Returns the set of nodes whose values depend on those within initial_queue""" |
|
initial_queue = list(initial_queue) |
|
dominated_set = set(initial_queue) |
|
|
|
while initial_queue: |
|
node = initial_queue.pop() |
|
for user in node.users: |
|
if skip_filter and skip_filter(user): |
|
continue |
|
if user not in dominated_set: |
|
dominated_set.add(user) |
|
initial_queue.append(user) |
|
|
|
return dominated_set |
|
|
|
|
|
def gather_origins(args, kwargs): |
|
import itertools |
|
|
|
from . import ir |
|
|
|
def is_unrealized_node(n): |
|
if isinstance(n, ir.TensorBox): |
|
return is_unrealized_node(n.data) |
|
if isinstance(n, ir.StorageBox): |
|
return is_unrealized_node(n.data) |
|
return isinstance(n, ir.IRNode) and isinstance(n, ir.Pointwise) |
|
|
|
kwarg_origins = [val.origins for val in kwargs.values() if is_unrealized_node(val)] |
|
arg_origins = [arg.origins for arg in args if is_unrealized_node(arg)] |
|
return set(itertools.chain(*arg_origins, *kwarg_origins)) |
|
|
|
|
|
def sympy_str(expr: sympy.Expr) -> str: |
|
""" |
|
    Normal sympy str is very slow; this is a lot faster. The results are

    somewhat worse, as it does not do as much simplification. So don't
|
use this for final codegen. |
|
""" |
|
if isinstance(expr, sympy.Symbol): |
|
return expr.name |
|
if isinstance(expr, sympy.Add): |
|
return " + ".join(map(sympy_str, expr.args)) |
|
if isinstance(expr, sympy.Mul): |
|
return " * ".join(map(sympy_str, expr.args)) |
|
|
|
if isinstance(expr, (ModularIndexing, CleanDiv, FloorDiv)): |
|
return f"{expr.func.__name__}({', '.join(map(sympy_str, expr.args))})" |
|
return str(expr) |
|
|
|
|
|
def get_bounds_index_expr(index): |
|
from .virtualized import V |
|
|
|
|
|
if ( |
|
config.compute_all_bounds |
|
and (fx_node := getattr(V.interpreter, "current_node", None)) |
|
and fx_node.target != "index_expr" |
|
): |
|
return bound_sympy(index) |
|
else: |
|
return ValueRanges.unknown() |
|
|
|
|
|
def sympy_index_symbol_with_prefix(prefix: SymT, idx: int) -> sympy.Symbol: |
|
""" |
|
Used to generate an integer-nonnegative symbol. |
|
""" |
|
|
|
|
|
assert prefix != SymT.SIZE |
|
|
|
|
|
return make_symbol(prefix, idx, integer=True, nonnegative=True) |
|
|
|
|
|
def generate_assert(check): |
|
return (check or config.debug_index_asserts) and config.assert_indirect_indexing |
|
|
|
|
|
def sympy_index_symbol(name: str) -> sympy.Symbol: |
|
""" |
|
Used to generate an integer-nonnegative symbol. |
|
""" |
|
|
|
|
|
assert name[0] != "s" |
|
|
|
|
|
return sympy.Symbol(name, integer=True, nonnegative=True) |
|
|
|
|
|
def sympy_subs(expr: sympy.Expr, replacements: Dict[sympy.Expr, Any]) -> sympy.Expr: |
|
""" |
|
    When a replacement value v is a string, it is converted to a symbol named v that

    has the same integer and nonnegative properties as the expression it replaces.
|
""" |
|
|
|
def to_symbol(replaced, replacement): |
|
assert isinstance(replaced, sympy.Expr) |
|
if isinstance(replacement, str): |
|
return sympy.Symbol( |
|
replacement, |
|
integer=replaced.is_integer, |
|
nonnegative=replaced.is_nonnegative, |
|
) |
|
else: |
|
return replacement |
|
|
|
|
|
return sympy.sympify(expr).xreplace( |
|
{k: to_symbol(k, v) for k, v in replacements.items()} |
|
) |
|
|
|
|
|
def is_symbolic(a: Any) -> bool: |
|
return isinstance(a, torch.SymInt) or ( |
|
isinstance(a, torch.Tensor) |
|
and any(is_symbolic(x) for x in itertools.chain(a.size(), a.stride())) |
|
) |
|
|
|
|
|
def any_is_symbolic(*args: Any) -> bool: |
|
return any(is_symbolic(a) for a in args) |
|
|
|
|
|
def get_first_incompatible_cudagraph_node(gm): |
|
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols |
|
|
|
forbidden_set = { |
|
"aten._fused_moving_avg_obs_fq_helper.default", |
|
"aten._fused_moving_avg_obs_fq_helper_functional.default", |
|
"aten.multinomial.default", |
|
"fbgemm.dense_to_jagged.default", |
|
"fbgemm.jagged_to_padded_dense.default", |
|
"run_and_save_rng_state", |
|
"run_with_rng_state", |
|
"aten._local_scalar_dense", |
|
|
|
|
|
|
|
|
|
"aten._assert_scalar", |
|
} |
|
if torch.are_deterministic_algorithms_enabled(): |
|
forbidden_set.update( |
|
{ |
|
"aten._unsafe_index_put.default", |
|
"aten.index_put.default", |
|
"aten.index_put_.default", |
|
"aten.scatter.src", |
|
"aten.scatter.reduce", |
|
"aten.scatter.value_reduce", |
|
"aten.scatter_add_", |
|
"aten.scatter_add.default", |
|
"aten.scatter_reduce.two", |
|
"aten.scatter_reduce_.two", |
|
"aten.scatter_reduce.two_out", |
|
} |
|
) |
|
for node in gm.graph.nodes: |
|
if str(node.target) in forbidden_set: |
|
return node |
|
if (val := node.meta.get("val")) is not None and free_unbacked_symbols(val): |
|
return node |
|
return None |
|
|
|
|
|
def has_incompatible_cudagraph_ops(gm): |
|
return get_first_incompatible_cudagraph_node(gm) is not None |
|
|
|
|
|
def output_node(gm: torch.fx.GraphModule): |
|
"""Get the output node from an FX graph""" |
|
last_node = next(iter(reversed(gm.graph.nodes))) |
|
assert last_node.op == "output" |
|
return last_node |
|
|
|
|
|
_registered_caches: List[Any] = [] |
|
|
|
|
|
def clear_on_fresh_inductor_cache(obj: Any): |
|
""" |
|
Use this decorator to register any caches that should be cache_clear'd |
|
with fresh_inductor_cache(). |
|
""" |
|
if not hasattr(obj, "cache_clear") or not callable(obj.cache_clear): |
|
raise AttributeError(f"{obj} does not have a cache_clear method") |
|
|
|
_registered_caches.append(obj) |
|
return obj |
|
|
|
|
|
def clear_inductor_caches(): |
|
""" |
|
Clear all registered caches. |
|
""" |
|
for obj in _registered_caches: |
|
obj.cache_clear() |
|
|
|
|
|
@contextlib.contextmanager |
|
def fresh_inductor_cache(cache_entries=None): |
|
""" |
|
Contextmanager that provides a clean tmp cachedir for inductor. |
|
|
|
    Optionally, pass a dict as 'cache_entries' to collect a mapping from generated

    filenames to their sizes for this cache instance.
|
""" |
|
clear_inductor_caches() |
|
|
|
inductor_cache_dir = tempfile.mkdtemp() |
|
try: |
|
with mock.patch.dict( |
|
os.environ, {"TORCHINDUCTOR_CACHE_DIR": inductor_cache_dir} |
|
): |
|
triton_cache_dir = os.path.join(inductor_cache_dir, "triton") |
|
with mock.patch.dict(os.environ, {"TRITON_CACHE_DIR": triton_cache_dir}): |
|
yield |
|
if isinstance(cache_entries, dict): |
|
assert len(cache_entries) == 0, "expected empty cache_entries dict" |
|
if os.path.exists(triton_cache_dir): |
|
files = os.listdir(triton_cache_dir) |
|
cache_entries.update( |
|
{ |
|
f: os.path.getsize(os.path.join(triton_cache_dir, f)) |
|
for f in files |
|
if ".lock" not in f |
|
} |
|
) |
|
shutil.rmtree(inductor_cache_dir) |
|
except Exception: |
|
log.warning("on error, temporary cache dir kept at %s", inductor_cache_dir) |
|
raise |
|
finally: |
|
clear_inductor_caches() |
|
|
|
|
|
def argsort(seq) -> List[int]: |
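    # ascending argsort; the stable descending sort followed by a reversal
    # controls how ties (e.g. equal strides) are ordered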
|
|
|
getter = seq.__getitem__ |
|
a_r = range(len(seq)) |
|
return list(reversed(sorted(a_r, key=getter, reverse=True))) |
|
|
|
|
|
@functools.lru_cache(8) |
|
def get_dtype_size(dtype): |
|
return torch.empty((), dtype=dtype).element_size() |
|
|
|
|
|
class LineContext(NamedTuple): |
|
context: Any |
|
|
|
|
|
class IndentedBuffer: |
|
tabwidth = 4 |
|
|
|
def __init__(self, initial_indent=0): |
|
self._lines = [] |
|
self._indent = initial_indent |
|
|
|
def getvaluewithlinemap(self) -> tuple[str, list[tuple[int, LineContext]]]: |
|
buf = StringIO() |
|
p = 1 |
|
linemap = [] |
|
for line in self._lines: |
|
if isinstance(line, DeferredLineBase): |
|
line = line() |
|
if line is None: |
|
continue |
|
elif isinstance(line, LineContext): |
|
linemap.append((p, line.context)) |
|
continue |
|
assert isinstance(line, str) |
|
buf.write(line) |
|
buf.write("\n") |
|
p += 1 + line.count("\n") |
|
return buf.getvalue(), linemap |
|
|
|
def getvalue(self) -> str: |
|
v, _ = self.getvaluewithlinemap() |
|
return v |
|
|
|
def getrawvalue(self) -> str: |
|
buf = StringIO() |
|
for line in self._lines: |
|
if isinstance(line, DeferredLineBase): |
|
line = line() |
|
if line is None: |
|
continue |
|
elif isinstance(line, LineContext): |
|
continue |
|
assert isinstance(line, str) |
|
|
|
if line.endswith("\\"): |
|
buf.write(line[:-1]) |
|
else: |
|
buf.write(line) |
|
buf.write("\n") |
|
return buf.getvalue() |
|
|
|
def clear(self): |
|
self._lines.clear() |
|
|
|
def __bool__(self): |
|
return bool(self._lines) |
|
|
|
def prefix(self): |
|
return " " * (self._indent * self.tabwidth) |
|
|
|
def newline(self): |
|
self.writeline("\n") |
|
|
|
def writeline(self, line): |
|
if isinstance(line, LineContext): |
|
self._lines.append(line) |
|
elif isinstance(line, DeferredLineBase): |
|
self._lines.append(line.with_prefix(self.prefix())) |
|
elif line.strip(): |
|
self._lines.append(f"{self.prefix()}{line}") |
|
else: |
|
self._lines.append("") |
|
|
|
def writelines(self, lines): |
|
for line in lines: |
|
self.writeline(line) |
|
|
|
def indent(self, offset=1): |
|
@contextlib.contextmanager |
|
def ctx(): |
|
self._indent += offset |
|
try: |
|
yield |
|
finally: |
|
self._indent -= offset |
|
|
|
return ctx() |
|
|
|
def do_indent(self, offset=1): |
|
self._indent += offset |
|
|
|
def do_unindent(self, offset=1): |
|
self._indent -= offset |
|
|
|
def splice(self, other_code, strip=False): |
|
if isinstance(other_code, IndentedBuffer): |
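            # find the common leading indentation of the incoming buffer so its
            # lines can be re-indented to this buffer's current level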
|
dedent = float("inf") |
|
for line in other_code._lines: |
|
if not isinstance(line, LineContext) and line: |
|
dedent = min(dedent, len(line) - len(line.lstrip())) |
|
if math.isinf(dedent): |
|
dedent = 0 |
|
for line in other_code._lines: |
|
if isinstance(line, LineContext): |
|
self._lines.append(line) |
|
else: |
|
IndentedBuffer.writeline(self, line[int(dedent) :]) |
|
else: |
|
other_code = textwrap.dedent(other_code) |
|
if strip: |
|
other_code = other_code.lstrip() |
|
if not other_code: |
|
return |
|
other_code = other_code.rstrip() |
|
for line in other_code.split("\n"): |
|
self.writeline(line) |
|
|
|
def map(self, func: Callable[[Any], Any]) -> IndentedBuffer: |
|
res = IndentedBuffer(initial_indent=self._indent) |
|
res._lines = [func(line) for line in self._lines] |
|
return res |
|
|
|
def __repr__(self): |
|
return f"{type(self)}({self.getvalue()})" |
|
|
|
def __add__(self, other): |
|
assert self._indent == other._indent |
|
res = IndentedBuffer(initial_indent=self._indent) |
|
res.writelines(self._lines) |
|
res.writelines(other._lines) |
|
return res |
|
|
|
|
|
class FakeIndentedBuffer(IndentedBuffer): |
|
def __init__(self): |
|
super().__init__() |
|
|
|
def __getattribute__(self, name): |
|
if name == "__class__": |
|
return object.__getattribute__(self, name) |
|
raise RuntimeError( |
|
f"Tried to call self.{name} on FakeIndentedBuffer. This buffer" |
|
"is currently used on TritonTemplateKernel to prevent actual" |
|
"writes to the body without explicitly specifying the body with" |
|
"`TritonTemplateKernel.set_subgraph_body(name)`" |
|
) |
|
|
|
|
|
@contextlib.contextmanager |
|
def restore_stdout_stderr(initial_stdout, initial_stderr): |
|
try: |
|
yield |
|
finally: |
|
sys.stdout = initial_stdout |
|
sys.stderr = initial_stderr |
|
|
|
|
|
class DeferredLineBase: |
|
"""A line that can be 'unwritten' at a later time""" |
|
|
|
def __init__(self, line): |
|
if not line.strip(): |
|
line = "" |
|
self.line = line |
|
|
|
def __call__(self) -> Optional[str]: |
|
"""Returns either self.line or None to indicate the line has been 'unwritten'""" |
|
raise NotImplementedError |
|
|
|
def _new_line(self, line: str) -> DeferredLineBase: |
|
"""Returns a new deferred line with the same condition""" |
|
raise NotImplementedError |
|
|
|
def with_prefix(self, prefix): |
|
return self._new_line(f"{prefix}{self.line}") |
|
|
|
def lstrip(self): |
|
return self._new_line(self.line.lstrip()) |
|
|
|
def __getitem__(self, index): |
|
return self._new_line(self.line[index]) |
|
|
|
def __bool__(self): |
|
return bool(self.line) |
|
|
|
def __len__(self): |
|
return len(self.line) |
|
|
|
|
|
@functools.lru_cache(None) |
|
def is_big_gpu(index) -> bool: |
|
min_sms = 68 |
|
avail_sms = torch.cuda.get_device_properties(index).multi_processor_count |
|
if avail_sms < min_sms: |
|
log.warning( |
|
"Not enough SMs to use max_autotune_gemm mode", |
|
extra={"min_sms": min_sms, "avail_sms": avail_sms}, |
|
) |
|
return False |
|
return True |
|
|
|
|
|
def use_max_autotune() -> bool: |
|
return ( |
|
config.max_autotune or config.max_autotune_gemm or config.search_autotune_cache |
|
) |
|
|
|
|
|
def _use_template_for_cuda(layout, allowed_layout_dtypes: List[torch.dtype]) -> bool: |
|
return ( |
|
use_max_autotune() |
|
and layout.device.type == "cuda" |
|
and layout.dtype in allowed_layout_dtypes |
|
and is_big_gpu(layout.device.index or 0) |
|
) |
|
|
|
|
|
def _use_autotune_backend(backend: str) -> bool: |
|
return backend.upper() in [ |
|
x.strip() for x in config.max_autotune_gemm_backends.upper().split(",") |
|
] |
|
|
|
|
|
def use_triton_template(layout, *, enable_int32=False): |
|
layout_dtypes = [torch.float16, torch.bfloat16, torch.float32] |
|
if enable_int32: |
|
layout_dtypes = [torch.float16, torch.bfloat16, torch.float32, torch.int32] |
|
return _use_template_for_cuda(layout, layout_dtypes) and _use_autotune_backend( |
|
"TRITON" |
|
) |
|
|
|
|
|
def use_cutlass_template(layout, m, n, k): |
|
from .virtualized import V |
|
|
|
gemm_size = V.graph.sizevars.size_hint(m * n * k, fallback=-1) |
|
if gemm_size <= 0 or gemm_size < config.cuda.cutlass_backend_min_gemm_size: |
|
return False |
|
from .codegen.cuda.cutlass_utils import try_import_cutlass |
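    # the CUTLASS backend targets CUDA only; bail out on ROCm builds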
|
|
|
|
|
if torch.version.hip: |
|
return False |
|
|
|
layout_dtypes = [torch.float16, torch.bfloat16, torch.float32, torch.int32] |
|
res = _use_template_for_cuda(layout, layout_dtypes) and _use_autotune_backend( |
|
"CUTLASS" |
|
) |
|
|
|
if res: |
|
if not try_import_cutlass(): |
|
log.warning( |
|
"Failed to import CUTLASS lib. Please check whether " |
|
"_inductor.config.cuda.cutlass_dir is set correctly. " |
|
"Skipping CUTLASS backend for now." |
|
) |
|
return False |
|
return res |
|
|
|
|
|
def _use_template_for_cpu(layout): |
|
return use_max_autotune() and layout.device.type == "cpu" |
|
|
|
|
|
def use_cpp_packed_gemm_template(layout, mat1, mat2): |
|
from . import ir |
|
from .codegen.cpp_micro_gemm import create_micro_gemm |
|
from .kernel.mm_common import mm_args |
|
|
|
if not _use_template_for_cpu(layout) or not _use_autotune_backend("CPP"): |
|
return False |
|
|
|
if not config.cpp.weight_prepack: |
|
return False |
|
|
|
layout_dtypes = [torch.float32] |
|
m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2) |
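    # the packed GEMM template requires static n and k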
|
|
|
if has_free_symbols((n, k)): |
|
return False |
|
if isinstance(mat2, ir.BaseView): |
|
mat2 = mat2.unwrap_view() |
|
micro_gemm = create_micro_gemm( |
|
"micro_gemm", m, n, k, layout.dtype, num_threads=parallel_num_threads() |
|
) |
|
|
|
return ( |
|
layout.dtype in layout_dtypes |
|
and micro_gemm is not None |
|
and n % micro_gemm.register_blocking[1] == 0 |
|
and mat1.get_stride()[-1] == 1 |
|
and isinstance(mat2, ir.StorageBox) |
|
and mat2.is_module_buffer() |
|
) |
|
|
|
|
|
def use_aten_gemm_kernels(): |
|
return not use_max_autotune() or _use_autotune_backend("ATEN") |
|
|
|
|
|
class DebugDirManager: |
|
counter = itertools.count(0) |
|
prev_debug_name: str |
|
|
|
def __init__(self): |
|
self.id = next(DebugDirManager.counter) |
|
|
|
def __enter__(self): |
|
self.prev_debug_name = torch._dynamo.config.debug_dir_root |
|
self.new_name = f"{self.prev_debug_name}_tmp_{self.id}" |
|
torch._dynamo.config.debug_dir_root = self.new_name |
|
|
|
def __exit__(self, *args): |
|
shutil.rmtree(self.new_name) |
|
torch._dynamo.config.debug_dir_root = self.prev_debug_name |
|
|
|
|
|
def run_and_get_code(fn, *args, **kwargs): |
|
from .graph import GraphLowering |
|
|
|
compile_to_module = GraphLowering.compile_to_module |
|
source_codes: List[str] = [] |
|
|
|
def patched_compile_to_module(self): |
|
mod = compile_to_module(self) |
|
with open(mod.__file__) as f: |
|
source_codes.append(f.read()) |
|
return mod |
|
|
|
|
|
with config.patch({"fx_graph_cache": False}): |
|
with mock.patch.object( |
|
GraphLowering, "compile_to_module", patched_compile_to_module |
|
): |
|
torch._dynamo.reset() |
|
result = fn(*args, **kwargs) |
|
return result, source_codes |
|
|
|
|
|
def get_code(fn, *args, **kwargs): |
|
"""Get the inductor-generated code, but skip any actual compilation or running.""" |
|
from .graph import GraphLowering |
|
|
|
source_codes: List[str] = [] |
|
|
|
def patched_compile_to_module(self: GraphLowering): |
|
class DummyModule: |
|
"""This is empty to replace the generated triton module""" |
|
|
|
def __init__(self): |
|
pass |
|
|
|
def call(self, *args, **kwargs): |
|
|
|
pass |
|
|
|
code, _ = ( |
|
self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen() |
|
) |
|
|
|
|
|
source_codes.append(code) |
|
return DummyModule() |
|
|
|
|
|
with config.patch({"fx_graph_cache": False}): |
|
with mock.patch.object( |
|
GraphLowering, "compile_to_module", patched_compile_to_module |
|
): |
|
torch._dynamo.reset() |
|
|
|
_ = fn(*args, **kwargs) |
|
|
|
return source_codes |
|
|
|
|
|
def get_triton_code(fn, *args, **kwargs): |
|
source_codes = get_code(fn, *args, **kwargs) |
|
|
|
assert ( |
|
1 <= len(source_codes) <= 2 |
|
), f"expected one or two code outputs got {len(source_codes)}" |
|
return source_codes[0] |
|
|
|
|
|
def run_and_get_triton_code(fn, *args, **kwargs): |
|
_, source_codes = run_and_get_code(fn, *args, **kwargs) |
|
|
|
assert ( |
|
1 <= len(source_codes) <= 2 |
|
), f"expected one or two code outputs got {len(source_codes)}" |
|
return source_codes[0] |
|
|
|
|
|
@contextlib.contextmanager |
|
def override_lowering(aten_op, override_fn): |
|
""" |
|
Override the lowering of aten_op with override_fn. |
|
The first argument of override_fn is the original lowering fn. |
|
""" |
|
from torch._inductor import lowering |
|
|
|
orig_fn = lowering.lowerings[aten_op] |
|
try: |
|
lowering.lowerings[aten_op] = functools.partial(override_fn, orig_fn) |
|
yield |
|
finally: |
|
lowering.lowerings[aten_op] = orig_fn |
|
|
|
|
|
def add_scheduler_init_hook(pre_fn, post_fn=None): |
|
""" |
|
Add hook functions to be called at the beginning and end of Scheduler.__init__. |
|
Used for unit tests. |
|
""" |
|
from torch._inductor.scheduler import Scheduler |
|
|
|
orig_fn = Scheduler.__init__ |
|
|
|
def wrapper(scheduler, nodes): |
|
pre_fn(scheduler, nodes) |
|
out = orig_fn(scheduler, nodes) |
|
if post_fn: |
|
post_fn(scheduler, nodes) |
|
return out |
|
|
|
return unittest.mock.patch.object(Scheduler, "__init__", wrapper) |
|
|
|
|
|
def developer_warning(msg): |
|
""" |
|
Warnings that will be actionable for PyTorch developers, but not |
|
end users. Allows us to easily disable them in stable releases but |
|
keep them on for nightly builds. |
|
""" |
|
if config.developer_warnings: |
|
log.warning(msg) |
|
else: |
|
log.info(msg) |
|
|
|
|
|
def get_benchmark_name(): |
|
""" |
|
An experimental API used only when config.benchmark_kernel is true. |
|
|
|
The benchmark name is only available at codegen time. So we can not |
|
directly call it in benchmark_all_kernels which is run after codegen. |
|
|
|
The function assumes the argument after --only is the benchmark name. |
|
    It works for torchbench.py/huggingface.py/timm_models.py. But for ad-hoc
|
scripts, this function may return None. |
|
|
|
    There are 2 flavors of the --only argument we need to handle:
|
1. --only model_name |
|
2. --only=model_name |
|
""" |
|
try: |
|
idx = sys.argv.index("--only") |
|
if ( |
|
idx + 1 < len(sys.argv) |
|
and len(sys.argv[idx + 1]) > 0 |
|
and sys.argv[idx + 1][0] != "-" |
|
): |
|
return sys.argv[idx + 1] |
|
except ValueError: |
|
pass |
|
|
|
for arg in sys.argv: |
|
if arg.startswith("--only="): |
|
return arg[len("--only=") :] |
|
|
|
|
|
def is_ones(items): |
|
return all(x == 1 for x in items) |
|
|
|
|
|
def is_zeros(items): |
|
return all(x == 0 for x in items) |
|
|
|
|
|
def is_cpu_device(inputs): |
|
return all( |
|
item.device == torch.device("cpu") |
|
for item in inputs |
|
if isinstance(item, torch.Tensor) |
|
) |
|
|
|
|
|
def get_sympy_Expr_dtype(val: sympy.Expr) -> torch.dtype: |
|
assert isinstance( |
|
val, sympy.Expr |
|
), "only support sympy.Expr as input to get_sympy_Expr_dtype" |
|
if val.is_integer: |
|
return torch.int64 |
|
else: |
|
return torch.float64 |
|
|
|
|
|
@contextlib.contextmanager |
|
def maybe_profile(should_profile, *args, **kwargs): |
|
if should_profile: |
|
with torch.profiler.profile(*args, **kwargs) as p: |
|
yield p |
|
else: |
|
yield |
|
|
|
|
|
def parallel_num_threads(): |
|
threads = config.cpp.threads |
|
if threads < 1: |
|
threads = torch.get_num_threads() |
|
return threads |
|
|
|
|
|
@functools.lru_cache(None) |
|
def get_device_tflops(dtype): |
|
from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops |
|
|
|
assert dtype in (torch.float16, torch.bfloat16, torch.float32) |
|
|
|
if inspect.signature(get_max_simd_tflops).parameters.get("clock_rate"): |
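        # newer Triton versions take an explicit SM clock rate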
|
|
|
from torch._utils_internal import max_clock_rate |
|
|
|
sm_clock = max_clock_rate() |
|
if dtype in (torch.float16, torch.bfloat16): |
|
return get_max_tensorcore_tflops(dtype, sm_clock) |
|
|
|
if torch.backends.cuda.matmul.allow_tf32: |
|
return get_max_tensorcore_tflops(torch.float32, sm_clock) |
|
else: |
|
return get_max_simd_tflops(torch.float32, sm_clock) |
|
else: |
|
if dtype in (torch.float16, torch.bfloat16): |
|
return get_max_tensorcore_tflops(dtype) |
|
|
|
if torch.backends.cuda.matmul.allow_tf32: |
|
return get_max_tensorcore_tflops(torch.float32) |
|
else: |
|
return get_max_simd_tflops(torch.float32) |
|
|
|
|
|
@functools.lru_cache(None) |
|
def get_gpu_dram_gbps(): |
|
from triton.testing import get_dram_gbps |
|
|
|
return get_dram_gbps() |
|
|
|
|
|
def get_gpu_shared_memory(): |
|
from triton.runtime import driver |
|
|
|
return driver.active.utils.get_device_properties(0).get("max_shared_mem", 0) |
|
|
|
|
|
def is_welford_reduction(reduction_type): |
|
return reduction_type.startswith("welford") |
|
|
|
|
|
def reduction_num_outputs(reduction_type): |
|
return 3 if is_welford_reduction(reduction_type) else 1 |
|
|
|
|
|
def is_linux() -> bool: |
|
return platform.system() == "Linux" |
|
|
|
|
|
def has_free_symbols(itr: Iterable[Any]): |
|
return any(isinstance(x, sympy.Expr) and not x.is_number for x in itr) |
|
|
|
|
|
def is_dynamic(*args): |
|
from . import ir |
|
|
|
for t in args: |
|
if isinstance(t, ir.TensorBox): |
|
if has_free_symbols(t.data.get_size()) or ( |
|
hasattr(t.data, "get_stride") and has_free_symbols(t.data.get_stride()) |
|
): |
|
return True |
|
elif isinstance(t, (ir.StorageBox, ir.BaseView, ir.ComputedBuffer)): |
|
assert hasattr(t, "get_size") and hasattr(t, "get_stride") |
|
if has_free_symbols(t.get_size()) or has_free_symbols(t.get_stride()): |
|
return True |
|
elif not isinstance(t, ir.IRNode): |
|
continue |
|
else: |
|
raise TypeError(f"unexpected type for is_dynamic {type(t)}") |
|
|
|
return False |
|
|
|
|
|
|
|
class Placeholder(enum.Enum): |
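    """Placeholder strings that are substituted during code generation."""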
|
|
|
|
|
KERNEL_NAME = "KERNEL_NAME" |
|
|
|
|
|
|
|
DESCRIPTIVE_NAME = "DESCRIPTIVE_NAME" |
|
|
|
|
|
def pass_execution_and_save(func, gm, inp, msg): |
|
from .pattern_matcher import stable_topological_sort |
|
|
|
with tempfile.NamedTemporaryFile( |
|
mode="w", |
|
encoding="utf-8", |
|
delete=False, |
|
) as f: |
|
before_io = io.StringIO() |
|
after_io = io.StringIO() |
|
ShapeProp(gm=gm, fake_mode=detect_fake_mode(inp)).propagate(*inp) |
|
print(f"Before:\n{gm.graph}", file=f) |
|
print(gm.graph, file=before_io) |
|
start_time = datetime.now() |
|
func(gm.graph) |
|
time_elapsed = datetime.now() - start_time |
|
|
|
stable_topological_sort(gm.graph) |
|
gm.graph.lint() |
|
gm.recompile() |
|
|
|
print(f"After:\n{gm.graph}", file=f) |
|
print(gm.graph, file=after_io) |
|
t = before_io.getvalue() == after_io.getvalue() |
|
log.info( |
|
"%s, save before/after graph to %s, graph before/after are the same = %s, time elapsed = %s", |
|
msg, |
|
f.name, |
|
t, |
|
time_elapsed, |
|
) |
|
|
|
|
|
def is_collective(node): |
|
from . import ir |
|
|
|
return type(node) == ir._CollectiveKernel |
|
|
|
|
|
def is_wait(node): |
|
from . import ir |
|
|
|
return type(node) == ir._WaitKernel |
|
|
|
|
|
def num_fw_fixed_arguments(dynamo_gm_num_inputs: int, aot_fw_gm_num_inputs: int): |
|
"Computes the number of inputs to the aot fw graph which have fixed addresses (params and buffers)" |
|
num_rng_seed_offset_inputs = ( |
|
2 if torch._functorch.config.functionalize_rng_ops else 0 |
|
) |
|
return aot_fw_gm_num_inputs - dynamo_gm_num_inputs - num_rng_seed_offset_inputs |
|
|
|
|
|
def count_tangents(fx_g: torch.fx.GraphModule): |
|
""" |
|
Infers which inputs are static for a backwards graph |
|
""" |
|
|
|
def is_saved_tensor(x): |
|
return ( |
|
"tangents" not in x.name |
|
and "bwd_seed" not in x.name |
|
and "bwd_base_offset" not in x.name |
|
) |
|
|
|
arg_count = 0 |
|
static_arg_idxs = [] |
|
for n in fx_g.graph.nodes: |
|
if n.op == "placeholder": |
|
if is_saved_tensor(n): |
|
static_arg_idxs.append(arg_count) |
|
arg_count += 1 |
|
|
|
assert static_arg_idxs == list(range(len(static_arg_idxs))) |
|
return len(static_arg_idxs) |
|
|
|
|
|
@dataclasses.dataclass |
|
class BoxedBool: |
|
value: bool |
|
|
|
def __bool__(self): |
|
return self.value |
|
|
|
@staticmethod |
|
def disable(obj): |
|
if isinstance(obj, BoxedBool): |
|
obj.value = False |
|
return obj |
|
return False |
|
|
|
|
|
@contextlib.contextmanager |
|
def collect_defined_kernels(kernel_list): |
|
from .codegen.wrapper import WrapperCodeGen |
|
|
|
orig_define_kernel = WrapperCodeGen.define_kernel |
|
|
|
def new_define_kernel(wrapper, name, kernel_code, metadata, *args, **kwargs): |
|
nonlocal kernel_list |
|
kernel_list.append(kernel_code) |
|
return orig_define_kernel(wrapper, name, kernel_code, metadata, *args, **kwargs) |
|
|
|
with unittest.mock.patch.object(WrapperCodeGen, "define_kernel", new_define_kernel): |
|
yield |
|
|
|
|
|
def get_cloned_parameter_buffer_name(name: str): |
|
return name + "__original__" |
|
|
|
|
|
def is_gpu(device: str): |
|
return device in ["cuda", "xpu"] |
|
|
|
|
|
def device_need_guard(device: str): |
|
assert isinstance(device, str) |
|
return is_gpu(device) |
|
|
|
|
|
def needs_fallback_due_to_atomic_add_limitations(dtype): |
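    # Triton's tl.atomic_add does not support these dtypes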
|
|
|
return dtype in {torch.int64, torch.bool, torch.bfloat16} |
|
|
|
|
|
def use_scatter_fallback( |
|
op_overload: torch._ops.OpOverload, |
|
reduction_type, |
|
self_dtype, |
|
src_dtype, |
|
src_device_type, |
|
src_is_tensor, |
|
): |
|
reduce_ty = ( |
|
"add" if op_overload.overloadpacket == torch.ops.aten.scatter_ else "sum" |
|
) |
|
|
|
return ( |
|
reduction_type not in {None, reduce_ty} |
|
or ( |
|
src_is_tensor |
|
and is_gpu(src_device_type) |
|
and needs_fallback_due_to_atomic_add_limitations(src_dtype) |
|
) |
|
or ( |
|
op_overload.overloadpacket == torch.ops.aten.scatter_reduce_ |
|
and reduction_type == "sum" |
|
and src_is_tensor |
|
and src_device_type == "cpu" |
|
and config.cpp.fallback_scatter_reduce_sum |
|
and (config.cpp.dynamic_threads or parallel_num_threads() != 1) |
|
) |
|
or (reduction_type == reduce_ty and self_dtype in {torch.bool, torch.int64}) |
|
or torch.are_deterministic_algorithms_enabled() |
|
) |
|
|
|
|
|
def dump_node_schedule(node_schedule): |
|
""" |
|
An API that can be used in pdb to dump a node_schedule. |
|
    Right now it mainly dumps the read/write dependencies, but more can be added as needed.
|
""" |
|
from torch._inductor.codegen.simd import DisableReduction, EnableReduction |
|
from torch._inductor.scheduler import SchedulerNode |
|
|
|
print(f"Node schedule with {len(node_schedule)} nodes") |
|
for idx, node in enumerate(node_schedule): |
|
print(f" {idx:3}:") |
|
if node is EnableReduction: |
|
print("enable reduction") |
|
elif node is DisableReduction: |
|
print("disable reduction") |
|
elif isinstance(node, SchedulerNode): |
|
is_red = node.is_reduction() |
|
print(f"{'red' if is_red else 'pw'} scheduler node") |
|
if is_red: |
|
assert node.node is not None |
|
print(f"original reduction hint {node.node.data.reduction_hint}") |
|
print("ReadDep:") |
|
for dep in node.read_writes.reads: |
|
print(dep) |
|
print("WriteDep:") |
|
for dep in node.read_writes.writes: |
|
print(dep) |
|
else: |
|
raise RuntimeError(f"Unrecognized node type: {type(node)}") |
|
|
|
|
|
def tensor_is_aligned(tensor: torch.Tensor): |
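    # a tensor is treated as aligned when its storage offset, in bytes,
    # is a multiple of GPU_ALIGN_BYTES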
|
|
|
|
|
|
|
|
|
|
|
|
|
return ( |
|
tensor.storage_offset() * get_dtype_size(tensor.dtype) |
|
) % GPU_ALIGN_BYTES == 0 |
|
|
|
|
|
def should_assume_input_aligned(example_input: torch.Tensor): |
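    # only GPU inputs benefit from alignment assumptions; config.assume_aligned_inputs
    # forces the assumption regardless of the example input's actual alignment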
|
|
|
|
|
|
|
if not is_gpu(example_input.device.type): |
|
return False |
|
return config.assume_aligned_inputs or tensor_is_aligned(example_input) |
|
|
|
|
|
def maybe_get_suppress_shape_guards_ctx(): |
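    # return a context manager that suppresses shape-guard creation while active;
    # falls back to a nullcontext when there is no tracing context or shape_env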
|
|
|
|
|
|
|
|
|
tracing_context = torch._guards.TracingContext.try_get() |
|
if not tracing_context: |
|
return contextlib.nullcontext() |
|
|
|
|
|
shape_env = tracing_context.fake_mode.shape_env |
|
if not shape_env: |
|
return contextlib.nullcontext() |
|
|
|
return shape_env.suppress_guards() |
|
|
|
|
|
def aoti_eager_cache_dir(namespace: str, device: str): |
|
return Path(cache_dir()) / "aoti_eager" / namespace / device |
|
|
|
|
|
def aoti_eager_op_conf_lock(op_func_name_with_overload: str): |
|
from filelock import FileLock |
|
|
|
|
|
from torch._inductor.codecache import get_lock_dir, LOCK_TIMEOUT |
|
|
|
op_conf_lock_file = f"{op_func_name_with_overload}.lock" |
|
lock_dir = get_lock_dir() |
|
return FileLock(os.path.join(lock_dir, op_conf_lock_file), timeout=LOCK_TIMEOUT) |
|
|
|
|
|
def load_aoti_eager_cache(ns: str, op_func_name_with_overload: str, device_type: str): |
|
device_kernel_cache = aoti_eager_cache_dir(ns, device_type) |
|
op_conf = device_kernel_cache / f"{op_func_name_with_overload}.json" |
|
if not op_conf.exists(): |
|
return [] |
|
|
|
with aoti_eager_op_conf_lock(op_func_name_with_overload): |
|
with open(op_conf) as f: |
|
json_data = json.load(f) |
|
for item in json_data: |
|
|
|
kernel_lib_abs_path = device_kernel_cache / item["kernel_path"] |
|
item["kernel_path"] = kernel_lib_abs_path.as_posix() |
|
|
|
|
|
if not kernel_lib_abs_path.exists(): |
|
return [] |
|
|
|
for metadata in item["meta_info"]: |
|
assert not metadata[ |
|
"is_dynamic" |
|
], "Only support static shape for now" |
|
if metadata["device_type"] == "cpu": |
|
metadata["device_index"] = -1 |
|
metadata["dtype"] = getattr(torch, metadata["dtype"].split(".")[-1]) |
|
|
|
return json_data |
|
|
|
|
|
def aoti_compile_with_persistent_cache( |
|
ns: str, |
|
op_func_name_with_overload: str, |
|
device_type: str, |
|
dynamic: bool, |
|
f: Callable[..., Any], |
|
args: Tuple[Any], |
|
kwargs: Dict[str, Any], |
|
*, |
|
dynamic_shapes: Optional[Dict[str, Any]] = None, |
|
options: Optional[Dict[str, Any]] = None, |
|
remove_runtime_assertions: bool = False, |
|
disable_constraint_solver: bool = False, |
|
): |
|
""" |
|
Compile the given function with persistent cache for AOTI eager mode. |
|
""" |
|
assert not dynamic, "Only support static shape for now" |
|
type_to_torch_dtype = {int: torch.int32, float: torch.float, bool: torch.bool} |
|
supported_scalar_types = tuple(type_to_torch_dtype.keys()) |
|
flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs) |
|
if not all( |
|
isinstance(input, (supported_scalar_types, torch.Tensor)) |
|
for input in flattened_inputs |
|
): |
|
raise NotImplementedError("Only support tensor, int, float, bool for now") |
|
|
|
persistent_cache = aoti_eager_cache_dir(ns, device_type) |
|
if not persistent_cache.exists(): |
|
persistent_cache.mkdir(parents=True) |
|
|
|
persistent_cache_lib = persistent_cache / "lib" |
|
if not persistent_cache_lib.exists(): |
|
persistent_cache_lib.mkdir() |
|
|
|
with mock.patch.dict( |
|
os.environ, |
|
{"TORCHINDUCTOR_CACHE_DIR": persistent_cache_lib.absolute().as_posix()}, |
|
): |
|
try: |
|
kernel_lib_path = torch._export.aot_compile( |
|
f, |
|
args, |
|
kwargs, |
|
dynamic_shapes=dynamic_shapes, |
|
options=options, |
|
remove_runtime_assertions=remove_runtime_assertions, |
|
disable_constraint_solver=disable_constraint_solver, |
|
|
|
|
|
|
|
same_signature=False, |
|
) |
|
|
|
kernel_metadata_items = [] |
|
for input in flattened_inputs: |
|
|
|
metadata: Dict[str, Any] = {} |
|
metadata["is_dynamic"] = dynamic |
|
|
|
if isinstance(input, torch.Tensor): |
|
metadata["device_type"] = f"{input.device.type}" |
|
if is_cpu_device([input]): |
|
metadata["device_index"] = -1 |
|
else: |
|
metadata["device_index"] = input.device.index |
|
metadata["dtype"] = f"{input.dtype}" |
|
metadata["sizes"] = list(input.size()) |
|
metadata["strides"] = list(input.stride()) |
|
else: |
|
assert isinstance(input, supported_scalar_types) |
|
|
|
metadata["device_type"] = device_type |
|
metadata["device_index"] = -1 if device_type == "cpu" else 0 |
|
metadata["dtype"] = f"{type_to_torch_dtype[type(input)]}" |
|
metadata["sizes"] = [] |
|
metadata["strides"] = [] |
|
metadata["scalar_value"] = input |
|
|
|
kernel_metadata_items.append(metadata) |
|
|
|
kernel_meta_info: Dict[str, Any] = {} |
|
kernel_meta_info["meta_info"] = kernel_metadata_items |
|
kernel_meta_info["kernel_path"] = ( |
|
Path(kernel_lib_path).relative_to(persistent_cache).as_posix() |
|
) |
|
|
|
json_data = [] |
|
update_json = True |
|
op_conf = persistent_cache / f"{op_func_name_with_overload}.json" |
|
mode = "r" if op_conf.exists() else "w" |
|
with aoti_eager_op_conf_lock(op_func_name_with_overload): |
|
with open(op_conf, mode) as op_conf_file: |
|
try: |
|
json_data = json.load(op_conf_file) |
|
except Exception as e: |
|
json_data = [] |
|
|
|
assert isinstance(json_data, list) |
|
for item in json_data: |
|
assert isinstance(item, dict) |
|
|
|
if item["meta_info"] == kernel_metadata_items: |
|
update_json = False |
|
break |
|
|
|
if update_json: |
|
json_data.append(kernel_meta_info) |
|
with open(op_conf, "w") as op_conf_file: |
|
json.dump(json_data, op_conf_file, indent=4) |
|
|
|
return kernel_lib_path |
|
except Exception as e: |
|
return "" |
|
|
|
|
|
def run_and_get_cpp_code(fn, *args, **kwargs): |
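    # capture the generated C++ code by temporarily enabling debug mode and
    # attaching a handler to the output_code logger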
|
|
|
|
|
|
|
with unittest.mock.patch.object(config, "debug", True): |
|
torch._dynamo.reset() |
|
import io |
|
import logging |
|
|
|
log_capture_string = io.StringIO() |
|
ch = logging.StreamHandler(log_capture_string) |
|
from torch._inductor.graph import output_code_log |
|
|
|
output_code_log.addHandler(ch) |
|
prev_level = output_code_log.level |
|
output_code_log.setLevel(logging.DEBUG) |
|
result = fn(*args, **kwargs) |
|
s = log_capture_string.getvalue() |
|
output_code_log.setLevel(prev_level) |
|
output_code_log.removeHandler(ch) |
|
return result, s |
|
|