|
|
|
import itertools |
|
import logging |
|
import operator |
|
import os |
|
import re |
|
import sys |
|
import time |
|
from collections import defaultdict |
|
from contextlib import contextmanager |
|
from typing import ( |
|
Any, |
|
Callable, |
|
DefaultDict, |
|
Dict, |
|
List, |
|
Optional, |
|
Set, |
|
Tuple, |
|
TYPE_CHECKING, |
|
Union, |
|
) |
|
|
|
import sympy |
|
|
|
import torch |
|
import torch._logging |
|
import torch.fx |
|
from torch._decomp import get_decompositions |
|
from torch._dynamo.utils import defake, dynamo_timed |
|
from torch._logging import LazyString, trace_structured |
|
from torch._prims_common import make_channels_last_strides_for |
|
from torch._subclasses.fake_tensor import FakeTensor |
|
from torch.fx.experimental._backward_state import BackwardState |
|
from torch.fx.experimental.sym_node import magic_methods, method_to_operator |
|
from torch.fx.experimental.symbolic_shapes import ( |
|
free_unbacked_symbols, |
|
has_free_symbols, |
|
resolve_unbacked_bindings, |
|
RuntimeAssert, |
|
ShapeEnv, |
|
SymTypes, |
|
) |
|
from torch.utils._mode_utils import no_dispatch |
|
|
|
from . import config, ir |
|
from .codegen.common import ( |
|
DeviceOpOverrides, |
|
get_device_op_overrides, |
|
get_scheduling_for_device, |
|
get_wrapper_codegen_for_device, |
|
register_backend_for_device, |
|
) |
|
from .codegen.cpp_wrapper_cpu import CppWrapperCpu |
|
from .codegen.cpp_wrapper_cuda import CppWrapperCuda |
|
from .codegen.wrapper import WrapperCodeGen |
|
from .exc import ( |
|
CppWrapperCodeGenError, |
|
LoweringException, |
|
MissingOperatorWithDecomp, |
|
MissingOperatorWithoutDecomp, |
|
) |
|
from .ir import ( |
|
Constant, |
|
FixedLayout, |
|
InputBuffer, |
|
Pointwise, |
|
Reduction, |
|
StorageBox, |
|
TensorBox, |
|
TorchBindObject, |
|
) |
|
from .lowering import ( |
|
constrain_to_fx_strides, |
|
FALLBACK_ALLOW_LIST, |
|
fallback_handler, |
|
fallback_node_due_to_unsupported_type, |
|
layout_constraints, |
|
lowerings, |
|
make_fallback, |
|
needs_realized_inputs, |
|
unsupported_output_tensor, |
|
) |
|
from .sizevars import SizeVarAllocator |
|
from .utils import ( |
|
convert_shape_to_inductor, |
|
gather_origins, |
|
get_cloned_parameter_buffer_name, |
|
get_sympy_Expr_dtype, |
|
maybe_get_suppress_shape_guards_ctx, |
|
should_assume_input_aligned, |
|
) |
|
from .virtualized import NullHandler, V |
|
|
|
if TYPE_CHECKING: |
|
from torch._higher_order_ops.effects import _EffectType |
|
|
|
log = logging.getLogger(__name__) |
|
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") |
|
output_code_log = torch._logging.getArtifactLogger(__name__, "output_code") |
|
aten = torch.ops.aten |
|
|
|
_post_grad_graph_counter = itertools.count() |
|
|
|
if config.is_fbcode(): |
|
from torch._inductor.fb.utils import log_module_code |
|
else: |
|
|
|
def log_module_code(*args, **kwargs): |
|
pass |
|
|
|
|
|
def supported_dtype_of_cpp_wrapper(dtype, cuda): |
|
supported_dtype = { |
|
torch.float32, |
|
torch.float64, |
|
torch.int64, |
|
torch.int32, |
|
torch.int16, |
|
torch.int8, |
|
torch.uint8, |
|
torch.bool, |
|
torch.bfloat16, |
|
torch.complex32, |
|
torch.complex64, |
|
torch.complex128, |
|
torch.float16, |
|
} |
|
if cuda: |
|
supported_dtype.add(torch.float8_e4m3fn) |
|
supported_dtype.add(torch.float8_e5m2) |
|
supported_dtype.add(torch.float8_e4m3fnuz) |
|
supported_dtype.add(torch.float8_e5m2fnuz) |
|
|
|
return dtype in supported_dtype |
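# Illustrative usage sketch (not exercised by this module; dtypes chosen for
# illustration only):
#
#     >>> supported_dtype_of_cpp_wrapper(torch.float16, cuda=False)
#     True
#     >>> supported_dtype_of_cpp_wrapper(torch.float8_e5m2, cuda=False)
#     False
#     >>> supported_dtype_of_cpp_wrapper(torch.float8_e5m2, cuda=True)
#     True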
|
|
|
|
|
def may_get_constant_buffer_dtype(constant_buffer): |
|
assert isinstance(
    constant_buffer, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer)
), "may_get_constant_buffer_dtype only supports sympy.Symbol, sympy.Expr or sympy.core.numbers.Integer inputs"
|
if isinstance(constant_buffer, sympy.core.numbers.Integer): |
|
return torch.int64 |
|
|
|
if isinstance(constant_buffer, sympy.Expr): |
|
return get_sympy_Expr_dtype(constant_buffer) |
|
|
|
if constant_buffer.is_integer: |
|
return torch.int64 |
|
elif constant_buffer.is_float: |
|
return torch.float32 |
|
else: |
|
return None |
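# Illustrative sketch (hypothetical input): a concrete sympy integer maps to
# int64, while symbolic expressions defer to get_sympy_Expr_dtype above.
#
#     >>> may_get_constant_buffer_dtype(sympy.Integer(3))
#     torch.int64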
|
|
|
|
|
def is_magic_method(op): |
|
magic_ops = {method_to_operator(m) for m in magic_methods} |
|
return op in magic_ops |
|
|
|
|
|
def getattr_recursive(obj, target): |
|
target_atoms = target.split(".") |
|
attr_itr = obj |
|
for i, atom in enumerate(target_atoms): |
|
if not hasattr(attr_itr, atom): |
|
raise RuntimeError( |
|
f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}" |
|
) |
|
attr_itr = getattr(attr_itr, atom) |
|
return attr_itr |
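# Illustrative sketch with hypothetical classes: the loop above walks one dotted
# attribute at a time and, on failure, reports the longest prefix that resolved.
#
#     >>> class Leaf: weight = "W"
#     >>> class Mid: leaf = Leaf()
#     >>> class Root: mid = Mid()
#     >>> getattr_recursive(Root(), "mid.leaf.weight")
#     'W'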
|
|
|
|
|
def mark_nodes_dislike_padding(g): |
|
""" |
|
Nodes like convolution/convolution_backward want its input to be dense. |
|
If we pad their inputs, we result in extra calls to copy kernels! On the other hand, padding usually helps reduction. |
|
|
|
The pass finds nodes that dislike padding. These are nodes that can be reached |
|
from a convolution/convolution_backward in the backward direction without |
|
going thru a reduction. |
|
""" |
|
if not config.comprehensive_padding: |
|
return |
|
ops_dislike_padding = { |
|
aten.convolution, |
|
aten.convolution_backward, |
|
} |
|
|
|
ops_like_padding = { |
|
aten.var_mean, |
|
aten.sum, |
|
aten.mean, |
|
aten.prod, |
|
aten.any, |
|
aten.amin, |
|
aten.amax, |
|
aten.min, |
|
aten.max, |
|
aten.argmin, |
|
aten.argmax, |
|
aten.scatter_reduce, |
|
} |
|
|
|
def _get_overload_packet(node): |
|
return ( |
|
node.target._overloadpacket |
|
if node.op == "call_function" and hasattr(node.target, "_overloadpacket") |
|
else None |
|
) |
|
|
|
for cur in reversed(g.nodes): |
|
op = _get_overload_packet(cur) |
|
if not op: |
|
continue |
|
if op in ops_dislike_padding: |
|
cur.meta["dislike_padding"] = True |
|
|
|
if cur.meta.get("dislike_padding", False): |
|
|
|
for prior in cur.all_input_nodes: |
|
prior_op = _get_overload_packet(prior) |
|
if not prior_op: |
|
continue |
|
if prior_op not in ops_like_padding: |
|
prior.meta["dislike_padding"] = True |
|
|
|
|
|
class GraphLowering(torch.fx.Interpreter): |
|
graph_outputs: List[ir.IRNode] |
|
|
|
def symbolic_sizes_strides(self, ex: torch.Tensor): |
|
""" |
|
Support dynamic shapes and dynamic strides by assigning variables |
|
to each dimension. We duck-shape tensors, so if two tensors |
|
have the same size they get assigned the same symbolic variable. |
|
""" |
|
if self.reuse_shape_env: |
|
return convert_shape_to_inductor(ex.size()), convert_shape_to_inductor( |
|
ex.stride() |
|
) |
|
else: |
|
from torch._dynamo.source import ConstantSource |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
source = ConstantSource( |
|
f"__inductor_unknown_tensor_{len(self._shape_env.var_to_val)}" |
|
) |
|
( |
|
size, |
|
stride, |
|
_, |
|
) = self._shape_env.create_symbolic_sizes_strides_storage_offset( |
|
ex, |
|
source, |
|
) |
|
|
|
size = [i.node.expr if isinstance(i, torch.SymInt) else i for i in size] |
|
stride = [i.node.expr if isinstance(i, torch.SymInt) else i for i in stride] |
|
return size, stride |
|
|
|
def static_sizes_strides(self, ex: torch.Tensor): |
|
""" |
|
Primarily used to weights |
|
""" |
|
size = [sympy.Integer(i) for i in ex.size()] |
|
stride = [sympy.Integer(i) for i in ex.stride()] |
|
return size, stride |
|
|
|
def init_backend_registration(self): |
|
if get_scheduling_for_device("cpu") is None: |
|
from .codegen.cpp import CppScheduling |
|
|
|
register_backend_for_device( |
|
"cpu", CppScheduling, WrapperCodeGen, CppWrapperCpu |
|
) |
|
|
|
if get_scheduling_for_device("cuda") is None: |
|
from .codegen.cuda_combined_scheduling import CUDACombinedScheduling |
|
|
|
|
|
register_backend_for_device( |
|
"cuda", CUDACombinedScheduling, WrapperCodeGen, CppWrapperCuda |
|
) |
|
|
|
if get_scheduling_for_device("xpu") is None: |
|
from .codegen.triton import TritonScheduling |
|
|
|
register_backend_for_device("xpu", TritonScheduling, WrapperCodeGen) |
|
|
|
def __init__( |
|
self, |
|
gm: torch.fx.GraphModule, |
|
example_inputs: Optional[List[torch.Tensor]] = None, |
|
shape_env=None, |
|
graph_id=None, |
|
cpp_wrapper=False, |
|
aot_mode=False, |
|
user_visible_outputs=None, |
|
layout_opt=None, |
|
extern_node_serializer=None, |
|
is_inference=False, |
|
is_const_graph=False, |
|
const_output_index=None, |
|
const_code=None, |
|
const_module=None, |
|
name=None, |
|
): |
|
super().__init__(gm) |
|
self.example_inputs = example_inputs |
|
self.layout_opt = ( |
|
layout_opt |
|
if layout_opt is not None |
|
else self.decide_layout_opt(gm, is_inference=is_inference) |
|
) |
|
self.num_channels_last_conv = 0 |
|
self.is_inference = is_inference |
|
self.is_const_graph = is_const_graph |
|
self.const_code = const_code |
|
self.const_module = const_module |
|
|
|
self.extra_traceback = False |
|
if shape_env is None: |
|
shape_env = ShapeEnv() |
|
self.reuse_shape_env = False |
|
else: |
|
|
self.reuse_shape_env = True |
|
self._shape_env = shape_env |
|
|
|
|
|
shape_env.freeze_runtime_asserts() |
|
|
|
self.ras_by_symbol: Dict[ |
|
sympy.Symbol, List[RuntimeAssert] |
|
] = shape_env.deferred_runtime_asserts.copy() |
|
self.bound_unbacked_symbols: Set[sympy.Symbol] = set() |
|
self.sizevars = SizeVarAllocator(shape_env) |
|
self.graph_input_names: List[str] = [] |
|
self.graph_inputs: Dict[str, TensorBox] = {} |
|
self.graph_inputs_original: Dict[str, InputBuffer] = {} |
|
self.device_types: Set[str] = ( |
|
const_module.device_types if const_module else set() |
|
) |
|
self.device_idxs: Set[int] = const_module.device_idxs if const_module else set() |
|
self.cuda = False |
|
self.buffers: List[ir.Buffer] = [] |
|
self.const_output_index: Dict[str, int] = ( |
|
const_output_index if const_output_index else {} |
|
) |
|
self.folded_constants: Set[str] = ( |
|
set(const_output_index.keys()) if const_output_index else set() |
|
) |
|
self.constants: Dict[str, torch.Tensor] = ( |
|
const_module.constants if const_module else {} |
|
) |
|
self.torchbind_constants: Dict[str, torch._C.ScriptObject] = {} |
|
self.constant_reprs: Dict[str, str] = {} |
|
self.removed_buffers: Set[str] = set() |
|
self.removed_inplace_buffers: Set[str] = set() |
|
self.mutated_buffers: Set[str] = set() |
|
self.never_reuse_buffers: Set[str] = set() |
|
self.inplaced_to_remove: Set[str] = set() |
|
self.device_ops: DeviceOpOverrides = None |
|
self.wrapper_code: WrapperCodeGen = None |
|
|
|
self.extern_kernel_nodes: List[ir.ExternKernelNode] = [] |
|
self.extern_node_serializer: Optional[ |
|
Callable[[List[ir.ExternKernelNode]], Any] |
|
] = extern_node_serializer |
|
self.current_node: torch.fx.Node = None |
|
self.lists: Dict[str, List[str]] = {} |
|
self.mutated_inputs: Set[str] = set() |
|
self.mutated_input_idxs: List[int] = [] |
|
self.name_to_buffer: Dict[str, ir.Buffer] = {} |
|
self.name_to_users: DefaultDict[str, List[ir.IRNode]] = defaultdict(list) |
|
self.creation_time = time.time() |
|
self.name = name |
|
self.cpp_wrapper = cpp_wrapper |
|
|
|
|
|
|
|
|
|
self.record_multi_kernel_choice = cpp_wrapper |
|
self.multi_kernel_to_choice: Dict[str, int] = {} |
|
|
|
self.aot_mode = aot_mode |
|
self.graph_id = graph_id |
|
self.post_grad_graph_id = next(_post_grad_graph_counter) |
|
self.scheduler: torch._inductor.scheduler.Scheduler = None |
|
self.nodes_prefer_channels_last = ( |
|
self.find_nodes_prefer_channels_last() if self.layout_opt else set() |
|
) |
|
mark_nodes_dislike_padding(gm.graph) |
|
self._warned_fallback = {"aten.convolution_backward"} |
|
self.user_visible_outputs = ( |
|
user_visible_outputs if user_visible_outputs is not None else {} |
|
) |
|
self.cache_key: str = "" |
|
self.cache_path: str = "" |
|
self.cache_linemap: List[ |
|
Tuple[int, str] |
|
] = ( |
|
[] |
|
) |
|
|
|
self.disable_cudagraphs_reason: Optional[str] = None |
|
|
|
|
|
self.device_node_mapping: Dict[torch.device, torch.fx.Node] = {} |
|
self.orig_gm: torch.fx.GraphModule = gm.__copy__() |
|
self.dynamo_flat_name_to_original_fqn = self.module.meta.get( |
|
"dynamo_flat_name_to_original_fqn", {} |
|
) |
|
self.allocated_constant_name = ( |
|
const_module.allocated_constant_name if const_module is not None else {} |
|
) |
|
self.init_backend_registration() |
|
|
|
self.effectful_ops: Dict[_EffectType, ir.Buffer] = {} |
|
|
|
self.aligned_inputs: Set[str] = set() |
|
|
|
@staticmethod |
|
def decide_layout_opt(gm, *, is_inference) -> bool: |
|
""" |
|
Decide if we should enable layout optimization for this graph based on |
|
heuristics. |
|
""" |
|
if not config.layout_optimization: |
|
return False |
|
|
|
if config.force_layout_optimization: |
|
return True |
|
|
|
conv_nodes = [ |
|
n for n in gm.graph.nodes if n.target == torch.ops.aten.convolution.default |
|
] |
|
nconv = len(conv_nodes) |
|
|
|
if nconv == 0: |
|
return False |
|
|
|
|
|
if ( |
|
torch.backends.mkldnn.enabled |
|
and torch.backends.mkldnn.is_available() |
|
and all( |
|
n.args[idx].meta["val"].device == torch.device("cpu") |
|
for n in conv_nodes |
|
for idx in [0, 1] |
|
) |
|
): |
|
return True |
|
|
|
|
|
|
|
|
|
if len(list(gm.graph.nodes)) >= 300 * nconv: |
|
log.debug("Skipped layout opt because only a few conv") |
|
return False |
|
|
|
if any( |
|
has_free_symbols(n.args[idx].meta["val"]) |
|
for n in conv_nodes |
|
for idx in [0, 1] |
|
): |
|
log.debug( |
|
"See perf regression with dynamic shape. Follow up in https://github.com/pytorch/pytorch/issues/102670" |
|
) |
|
return False |
|
|
|
def is_grouped(n): |
|
return n.args[-1] > 1 and n.args[1].meta["val"].size(1) > 1 |
|
|
|
def is_in_out_channel(n): |
|
return ( |
|
n.args[1].meta["val"].size(0) * 2 <= n.args[1].meta["val"].size(1) |
|
and n.args[1].meta["val"].size(2) > 1 |
|
) |
|
|
|
def is_small_channel(n): |
|
return ( |
|
n.args[1].meta["val"].size(0) <= 64 |
|
and n.args[1].meta["val"].size(1) <= 64 |
|
) |
|
|
|
|
|
if is_inference: |
|
from torch.utils.flop_counter import FlopCounterMode |
|
|
|
flop_counts: Dict[str, float] = defaultdict(float) |
|
for node in conv_nodes: |
|
success, args, kwargs = torch._inductor.fx_utils.get_fake_args_kwargs( |
|
node |
|
) |
|
|
|
if success: |
|
with FlopCounterMode(display=False) as flop_counter_mode: |
|
with V.fake_mode: |
|
node.target(*args, **kwargs) |
|
|
|
counted_flops = flop_counter_mode.get_total_flops() |
|
if is_grouped(node): |
|
node_type = "grouped" |
|
elif is_small_channel(node): |
|
node_type = "small" |
|
elif is_in_out_channel(node): |
|
node_type = "in_out" |
|
else: |
|
node_type = "default" |
|
|
|
flop_counts[node_type] += counted_flops |
|
else: |
|
log.debug("Conv inputs meta not found") |
|
|
|
|
|
|
|
|
|
GROUPED_MULTIPLIER = 1.358 |
|
DEFAULT_MULTIPLIER = 0.823 |
|
IN_OUT_MULTIPLIER = 0.725 |
|
SMALL_MULTIPLIER = 0.783 |
|
|
|
total_flops = sum(flop_counts.values()) |
|
|
|
weighted_flops = ( |
|
flop_counts["grouped"] * GROUPED_MULTIPLIER |
|
+ flop_counts["small"] * SMALL_MULTIPLIER |
|
+ flop_counts["in_out"] * IN_OUT_MULTIPLIER |
|
+ flop_counts["default"] * DEFAULT_MULTIPLIER |
|
) |
|
do_layout_opt = weighted_flops <= total_flops |
|
if not do_layout_opt: |
|
log.debug( |
|
"Skipped layout opt in inference because weighted flops indicate slowdown, default: %d, channels last: %d", |
|
total_flops, |
|
weighted_flops, |
|
) |
|
return do_layout_opt |
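# Illustrative arithmetic for the heuristic above (hypothetical counts): with
# flop_counts = {"default": 100e9, "grouped": 10e9},
#   total_flops    = 110e9
#   weighted_flops = 100e9 * 0.823 + 10e9 * 1.358 = 95.88e9
# and 95.88e9 <= 110e9, so layout opt is predicted to pay off.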
|
|
|
|
|
|
if any(map(is_grouped, conv_nodes)): |
|
log.debug( |
|
"Skip layout opt because found grouped convolution with >1 in_channels!" |
|
) |
|
return False |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if any(map(is_in_out_channel, conv_nodes)): |
|
log.debug( |
|
"Skip layout opt because some convolutions have smaller out_channel" |
|
) |
|
return False |
|
|
|
|
|
|
|
if all(map(is_small_channel, conv_nodes)): |
|
log.debug("Skip layout opt because all convolution channels are too small") |
|
return False |
|
|
|
return True |
|
|
|
def qualify_name(self, name: str) -> str: |
|
"""Prepend the given name with the graph name if any.""" |
|
if self.name is not None: |
|
return f"{self.name}_{name}" |
|
return name |
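# Illustrative sketch: with a hypothetical subgraph name self.name == "subgraph_fn",
# qualify_name("buf3") returns "subgraph_fn_buf3"; when self.name is None it
# returns "buf3" unchanged.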
|
|
|
def make_subgraph( |
|
self, |
|
gm: torch.fx.GraphModule, |
|
example_inputs: List[torch.Tensor], |
|
subgraph_name: str, |
|
) -> "GraphLowering": |
|
""" |
|
Make a subgraph of the current graph with all inherited |
|
parts, except the graph module (`gm`) and `example_inputs`. |
|
The subgraphs are lowered separately, but intended to be |
|
inlined in the parent graph's code generation. Hence the need
|
for maintaining the same `shape_env` and other properties. |
|
The subgraph name is qualified by the parent graph's name. |
|
""" |
|
return GraphLowering( |
|
gm=gm, |
|
example_inputs=example_inputs, |
|
shape_env=self._shape_env, |
|
cpp_wrapper=self.cpp_wrapper, |
|
aot_mode=self.aot_mode, |
|
extern_node_serializer=self.extern_node_serializer, |
|
is_inference=self.is_inference, |
|
name=self.qualify_name(subgraph_name), |
|
) |
|
|
|
def find_nodes_prefer_channels_last(self): |
|
""" |
|
The rule to decide if an node prefer channels last is simple. |
|
1. if it's input/output of a convolution |
|
2. if one of its user prefers channels last |
|
|
|
We have rule 1 because cudnn runs a faster convolution kernel for channels last inputs; |
|
Rule 2 is also important. It makes sure that indirect inputs to convolution also prefers |
|
channels last. |
|
|
|
Consider the scenario: conv -> batch-norm -> relu -> conv |
|
Without rule 2, batch-norm output may use a contiguous layout. That will cause 2 extra copies: |
|
1. the output of batch-norm should be channels last initially since its input is a conv's output. |
|
Forcing the batch-norm's output to be contiguous results in the first copy |
|
2. The second conv's input is initially contiguous. This layout is propagated from the batch-norm's output. |
|
We need convert it to channels last layout which results in the second copy. |
|
With rule 2, we makes sure all the tensors in the chain uses channels last layout. So both copies |
|
can be saved. |
|
""" |
|
output_set = set() |
|
for n in reversed(self.module.graph.nodes): |
|
if n.target == torch.ops.aten.convolution.default: |
|
output_set.add(n) |
|
continue |
|
|
|
for user in n.users: |
|
if user in output_set: |
|
output_set.add(n) |
|
break |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for n in self.module.graph.nodes: |
|
if n in output_set: |
|
output_set.update(n.users) |
|
|
|
return output_set |
|
|
|
def warn_fallback(self, name): |
|
if name not in self._warned_fallback: |
|
self._warned_fallback.add(name) |
|
perf_hint_log.info("Using FallbackKernel: %s", name) |
|
|
|
def add_device_info(self, device: torch.device): |
|
self.device_types.add(device.type) |
|
if device.index is not None: |
|
self.device_idxs.add(device.index) |
|
if V.graph.current_node and device not in self.device_node_mapping: |
|
self.device_node_mapping[device] = V.graph.current_node |
|
|
|
@property |
|
def fake_mode(self): |
|
return V.fake_mode |
|
|
|
def get_buffer(self, buffer_name: str): |
|
if buffer_name in self.name_to_buffer: |
|
return self.name_to_buffer[buffer_name] |
|
if buffer_name in self.graph_inputs: |
|
return self.graph_inputs[buffer_name] |
|
if buffer_name in self.constants: |
|
data = V.graph.constants[buffer_name] |
|
return ir.ConstantBuffer( |
|
buffer_name, |
|
ir.FixedLayout( |
|
data.device, data.dtype, *V.graph.static_sizes_strides(data) |
|
), |
|
) |
|
return None |
|
|
|
def get_dtype(self, buffer_name: str): |
|
if buffer_name in self.constants: |
|
return self.constants[buffer_name].dtype |
|
if buffer_name in self.name_to_buffer: |
|
return self.name_to_buffer[buffer_name].get_dtype() |
|
if buffer_name in self.graph_inputs: |
|
return self.graph_inputs[buffer_name].get_dtype() |
|
m = re.match(r"(as_strided|reinterpret_tensor)\(([a-zA-Z0-9_]+),", buffer_name) |
|
if m: |
|
return self.get_dtype(m.group(2))
|
raise KeyError(f"could not find {buffer_name}") |
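# Illustrative sketch of the regex fallback above (hypothetical buffer name):
# a name such as "reinterpret_tensor(buf0, ..." resolves to the dtype of the
# wrapped buffer "buf0", so wrapper-generated views report the dtype of the
# underlying storage.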
|
|
|
def get_numel(self, buffer_name: str): |
|
from .ir import MultiOutputLayout |
|
|
|
if buffer_name in self.constants: |
|
return self.constants[buffer_name].numel() |
|
if buffer_name in self.name_to_buffer: |
|
buf = self.name_to_buffer[buffer_name] |
|
if isinstance(getattr(buf, "layout", None), MultiOutputLayout): |
|
return 1 |
|
return buf.get_numel() |
|
if buffer_name in self.graph_inputs: |
|
return self.graph_inputs[buffer_name].get_numel() |
|
raise KeyError(f"could not find {buffer_name}") |
|
|
|
@dynamo_timed |
|
def run(self, *args): |
|
return super().run(*args) |
|
|
|
def register_buffer(self, buffer: ir.Buffer, *, set_name: bool = False): |
|
name = self.qualify_name(f"buf{len(self.buffers)}") |
|
self.buffers.append(buffer) |
|
self.name_to_buffer[name] = buffer |
|
|
|
if ( |
|
not (isinstance(buffer, ir.ComputedBuffer) and buffer.is_zero_elements()) |
|
and buffer.get_device() is not None |
|
): |
|
self.add_device_info(buffer.get_device()) |
|
|
|
if set_name: |
|
buffer.name = name |
|
return name |
|
|
|
def register_list(self, buffer_names: List[str]): |
|
name = self.qualify_name("list_" + "_".join(buffer_names)) |
|
self.lists[name] = buffer_names |
|
return name |
|
|
|
def register_users_of(self, node_output): |
|
def register(value): |
|
if isinstance(value, (list, tuple)): |
|
for x in value: |
|
register(x) |
|
if isinstance(value, ir.IRNode): |
|
if ( |
|
not hasattr(value, "data") |
|
or not isinstance(value.data, ir.IRNode) |
|
or not ( |
|
hasattr(value.data, "data") |
|
and isinstance(value.data.data, ir.IRNode) |
|
) |
|
): |
|
return |
|
|
|
for read_name in value.get_read_names(): |
|
self.name_to_users[read_name].append(value) |
|
|
|
register(node_output) |
|
|
|
def mark_buffer_mutated(self, name: str): |
|
""" |
|
When a buffer is mutated we need to make sure all the reads to |
|
the old version are realized before the mutation happens. |
|
""" |
|
assert isinstance(name, str) |
|
self.mutated_buffers.add(name) |
|
|
|
if name not in self.name_to_users: |
|
return |
|
|
|
for user in self.name_to_users[name]: |
|
user.realize() |
|
|
|
def get_original_value_of_constant(self, name: str): |
|
""" |
|
In AOTI, module buffers may have been mutated during the tracing and compilation. |
|
Thus we need to read from previously stored original buffers, to make sure the |
|
generated model.so uses correct initial values. |
|
""" |
|
assert name in self.allocated_constant_name and name in self.constants, ( |
|
"Can not find the original value for " + name |
|
) |
|
orig_name = get_cloned_parameter_buffer_name(self.allocated_constant_name[name]) |
|
return ( |
|
self.module.meta[orig_name] |
|
if orig_name in self.module.meta |
|
else self.constants[name] |
|
) |
|
|
|
def allocate_non_dup_const_name(self, name, data): |
|
orig_name = name |
|
if not config.aot_inductor.use_runtime_constant_folding: |
|
for constant_name, value in self.constants.items(): |
|
if ( |
|
not data.is_mkldnn |
|
and data.size() == value.size() |
|
and data.stride() == value.stride() |
|
and data.dtype == value.dtype |
|
and data.device == value.device |
|
and data.untyped_storage().data_ptr() |
|
== value.untyped_storage().data_ptr() |
|
and data.storage_offset() == value.storage_offset() |
|
): |
|
return constant_name |
|
|
|
if name is None: |
|
name = f"constant{len(self.constants)}" |
|
if name[0].isdigit(): |
|
name = f"constant_{name}" |
|
name = self.qualify_name(name) |
|
|
|
|
|
prefix = re.sub(r"[^a-zA-Z0-9_]", "_", name) |
|
name = prefix |
|
cnt = 0 |
|
while name in self.constants: |
|
name = f"{prefix}_{cnt}" |
|
cnt += 1 |
|
self.constants[name] = data |
|
self.constant_reprs[name] = ( |
|
f"{data.device!r} {data.dtype!r} " |
|
f"{tuple(data.size())!r} {tuple(data.stride())!r} " |
|
f"{hash(data):x}" |
|
) |
|
self.allocated_constant_name[name] = orig_name |
|
return name |
|
|
|
def add_tensor_constant(self, data, name=None): |
|
new_name = self.allocate_non_dup_const_name(name, data) |
|
return TensorBox.create( |
|
ir.ConstantBuffer( |
|
new_name, |
|
FixedLayout(data.device, data.dtype, *self.static_sizes_strides(data)), |
|
) |
|
) |
|
|
|
def constant_name(self, name: str, device_override: Optional[torch.device]): |
|
""" |
|
We AOT copy constants to the devices they are needed on. |
|
If device_override doesn't match the constant's device, then |
|
copy it and return a different name. |
|
""" |
|
if self.constants[name].device == device_override or device_override is None: |
|
return name |
|
with torch.utils._python_dispatch._disable_current_modes(): |
|
|
|
|
|
return self.allocate_non_dup_const_name( |
|
f"{name}_{device_override.type}{device_override.index or 0}", |
|
self.constants[name].to(device_override), |
|
) |
|
|
|
def placeholder(self, target: str, args, kwargs): |
|
example = super().placeholder(target, args, kwargs) |
|
self.graph_input_names.append(target) |
|
if isinstance(example, SymTypes): |
|
expr = example.node.expr |
|
self.graph_inputs[target] = expr |
|
return expr |
|
elif isinstance(example, (int, bool, float)): |
|
expr = sympy.sympify(example) |
|
self.graph_inputs[target] = expr |
|
return expr |
|
if isinstance(example, BackwardState): |
|
|
|
|
|
return None |
|
assert isinstance(example, torch.Tensor), example |
|
|
|
|
|
|
|
|
|
if not example._has_symbolic_sizes_strides: |
|
|
|
sizes, strides = self.static_sizes_strides(example) |
|
else: |
|
sizes, strides = self.symbolic_sizes_strides(example) |
|
|
|
target = self.qualify_name(target) |
|
tensor = TensorBox.create( |
|
InputBuffer( |
|
target, |
|
FixedLayout(example.device, example.dtype, sizes, strides), |
|
) |
|
) |
|
self.graph_inputs[target] = tensor |
|
self.graph_inputs_original[target] = tensor.data.data |
|
self.add_device_info(example.device) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with maybe_get_suppress_shape_guards_ctx(): |
|
if should_assume_input_aligned(example): |
|
self.aligned_inputs.add(target) |
|
return tensor |
|
|
|
def call_function(self, target, args, kwargs): |
|
if target is operator.getitem and isinstance(args[0], (list, tuple, dict)): |
|
return super().call_function(target, args, kwargs) |
|
|
|
if hasattr(target, "_inductor_lowering_function"): |
|
|
|
return target(*args, **kwargs) |
|
|
|
def get_custom_op_layout_constraints(target, args, kwargs): |
|
|
|
|
|
|
|
layout_constraint = None |
|
if torch._C.Tag.needs_fixed_stride_order in target.tags: |
|
|
|
|
|
|
|
args, kwargs = constrain_to_fx_strides( |
|
self.current_node, *args, **kwargs |
|
) |
|
|
|
|
|
layout_constraint = constrain_to_fx_strides |
|
return layout_constraint, args, kwargs |
|
|
|
if target not in lowerings: |
|
assert isinstance( |
|
target, torch._ops.OpOverload |
|
), f"{target} is not an OpOverload" |
|
base_name = target.name().split(".")[0] |
|
if base_name in FALLBACK_ALLOW_LIST: |
|
make_fallback(target) |
|
elif config.implicit_fallbacks: |
|
layout_constraint, args, kwargs = get_custom_op_layout_constraints( |
|
target, args, kwargs |
|
) |
|
error = ( |
|
MissingOperatorWithDecomp |
|
if get_decompositions([target]) |
|
else MissingOperatorWithoutDecomp |
|
) |
|
log.info( |
|
"Creating implicit fallback for:\n%s", |
|
error.operator_str(target, args, kwargs), |
|
) |
|
make_fallback(target, layout_constraint) |
|
|
|
elif get_decompositions([target]): |
|
|
|
|
|
|
|
raise MissingOperatorWithDecomp(target, args, kwargs) |
|
else: |
|
raise MissingOperatorWithoutDecomp(target, args, kwargs) |
|
|
|
try: |
|
log.debug(" via %s", lowerings[target]) |
|
out = lowerings[target](*args, **kwargs) |
|
return out |
|
except Exception as e: |
|
raise LoweringException(e, target, args, kwargs).with_traceback( |
|
e.__traceback__ |
|
) from None |
|
|
|
@staticmethod |
|
def can_inline_constant(t: torch.Tensor) -> bool: |
|
""" |
|
True if this is a small constant attr that will be inlined. |
|
""" |
|
return len(t.shape) == 1 and t.shape[0] <= 8 |
|
|
|
def get_attr(self, target, args, kwargs): |
|
|
|
value = getattr_recursive(self.module, target) |
|
|
|
if isinstance(value, torch.fx.GraphModule): |
|
return ir.Subgraph(name=target, graph_module=value) |
|
|
|
if isinstance(value, torch._C.ScriptObject): |
|
self.torchbind_constants[target] = value |
|
self.constant_reprs[target] = "" |
|
return TorchBindObject(target, value) |
|
|
|
if ( |
|
config.aot_inductor.use_runtime_constant_folding |
|
or config.always_keep_tensor_constants |
|
or unsupported_output_tensor(value) |
|
): |
|
return self.add_tensor_constant(value, target) |
|
|
|
with no_dispatch(): |
|
if value.shape == (): |
|
return Constant(value.item(), value.dtype, value.device) |
|
if self.can_inline_constant(value): |
|
|
|
from .lowering import tensor |
|
|
|
return tensor(value.tolist(), dtype=value.dtype, device=value.device) |
|
|
|
return self.add_tensor_constant(value, target) |
|
|
|
def call_module(self, target, args, kwargs): |
|
raise AssertionError |
|
|
|
def call_method(self, target, args, kwargs): |
|
raise AssertionError |
|
|
|
def output(self, target, args, kwargs): |
|
result = super().output(target, args, kwargs) |
|
if not isinstance(result, (tuple, list)): |
|
|
|
result = (result,) |
|
assert isinstance(result, (tuple, list)), type(result) |
|
assert all( |
|
isinstance( |
|
x, |
|
( |
|
TensorBox, |
|
ir.Constant, |
|
type(None), |
|
ir.ConstantBuffer, |
|
sympy.Expr, |
|
sympy.logic.boolalg.Boolean, |
|
int, |
|
ir.EffectfulKernel, |
|
), |
|
) |
|
for x in result |
|
), result |
|
|
|
fx_node_args = V.graph.current_node.args[0] |
|
if not isinstance(fx_node_args, (tuple, list)): |
|
|
|
fx_node_args = (fx_node_args,) |
|
result = [ir.ExternKernel.realize_input(x) for x in result] |
|
result_correct_strides = [] |
|
|
|
assert len(fx_node_args) == len(result) |
|
for r, fx_node in zip(result, fx_node_args): |
|
if not isinstance(r, (ir.TensorBox, ir.BaseView)): |
|
result_correct_strides.append(r) |
|
else: |
|
|
|
|
|
result_correct_strides.append( |
|
self.try_match_insignificant_strides( |
|
r, fx_node.meta["val"].stride() |
|
) |
|
) |
|
|
|
self.graph_outputs = result_correct_strides |
|
value: ir.IRNode |
|
for name, value in self.graph_inputs.items(): |
|
assert isinstance( |
|
value, (TensorBox, sympy.Expr) |
|
), f"Unsupported inductor graph input type: {type(value)}" |
|
if not isinstance(value, TensorBox): |
|
continue |
|
value.realize() |
|
assert isinstance(value, TensorBox) |
|
value = value.data |
|
assert isinstance(value, ir.StorageBox) |
|
value_storage_box = value |
|
value = value.data |
|
if not isinstance(value, InputBuffer) or value.get_name() != name: |
|
|
|
ir.MutationLayoutSHOULDREMOVE.realize_into( |
|
value, self.graph_inputs_original[name] |
|
) |
|
|
|
try: |
|
ind = self.graph_outputs.index(value_storage_box) |
|
self.graph_outputs[ind] = self.graph_inputs_original[name] |
|
except ValueError: |
|
pass |
|
|
|
self.finalize() |
|
log.debug( |
|
"Force channels last inputs for %d conv for the current graph with id %d", |
|
self.num_channels_last_conv, |
|
self.graph_id if self.graph_id is not None else -1, |
|
) |
|
|
|
def finalize(self): |
|
for buf in self.buffers: |
|
buf.decide_layout() |
|
|
|
@contextmanager |
|
def set_current_node(self, node: torch.fx.Node): |
|
old = self.current_node |
|
try: |
|
self.current_node = node |
|
yield |
|
finally: |
|
self.current_node = old |
|
|
|
def try_match_insignificant_strides( |
|
self, |
|
tensor, |
|
meta_strides_inp: Tuple[Union[int, torch.SymInt], ...], |
|
) -> ir.TensorBox: |
|
""" |
|
Tries to match the strides of the tensor to those in the meta_strides. Strides of insignificant |
|
dimensions - size 0 or 1 - will be updated. |
|
|
|
If there are real stride differences (NHWC vs NCHW) then the input will be returned. |
|
""" |
|
|
|
|
|
assert torch._inductor.ir.is_storage_and_layout(tensor) |
|
|
|
meta_strides = [ |
|
s.node.expr if isinstance(s, torch.SymInt) else s for s in meta_strides_inp |
|
] |
|
|
|
if all( |
|
self.sizevars.statically_known_equals(s1, s2) |
|
for s1, s2 in zip(meta_strides, tensor.get_stride()) |
|
): |
|
return tensor |
|
|
|
def significant_strides_equal(shape, meta_strides, tensor_strides): |
|
for dim, s1, s2 in zip(shape, meta_strides, tensor_strides): |
|
if self.sizevars.statically_known_leq(dim, 1): |
|
continue |
|
|
|
if not self.sizevars.statically_known_equals(s1, s2): |
|
return False |
|
|
|
return True |
|
|
|
if not significant_strides_equal( |
|
tensor.get_size(), meta_strides, tensor.get_stride() |
|
): |
|
return tensor |
|
|
|
storage, old_layout = torch._inductor.ir.as_storage_and_layout(tensor) |
|
new_stride = list(old_layout.stride) |
|
for i, s in enumerate(tensor.get_size()): |
|
if self.sizevars.statically_known_leq(s, 1): |
|
new_stride[i] = meta_strides[i] |
|
|
|
new_layout = torch._inductor.ir.FixedLayout( |
|
old_layout.device, |
|
old_layout.dtype, |
|
old_layout.size, |
|
new_stride, |
|
old_layout.offset, |
|
) |
|
return ir.TensorBox(torch._inductor.ir.ReinterpretView(storage, new_layout)) |
|
|
|
def run_node(self, n: torch.fx.Node): |
|
def debug(msg): |
|
log.debug("lowering %s %s", LazyString(n.format_node), msg) |
|
|
|
buffer_watermark = len(self.buffers) |
|
|
|
origins = {n} |
|
if n.op == "call_function": |
|
args, kwargs = self.fetch_args_kwargs_from_env(n) |
|
origins |= gather_origins(args, kwargs) |
|
with ir.IRNode.current_origins(origins), self.set_current_node( |
|
n |
|
), V.set_current_node(n): |
|
if ( |
|
n.op == "call_function" |
|
and n.target is not operator.getitem |
|
and fallback_node_due_to_unsupported_type(n) |
|
): |
|
debug("fallback_handler") |
|
result = fallback_handler(n.target, add_to_fallback_set=False)( |
|
*args, **kwargs |
|
) |
|
elif n.op == "call_function" and n.target in layout_constraints: |
|
debug("layout_constraints") |
|
args, kwargs = layout_constraints[n.target](n, *args, **kwargs) |
|
result = self.call_function(n.target, args, kwargs) |
|
elif is_magic_method(n.target): |
|
|
|
|
|
|
|
debug("is_magic_method") |
|
if isinstance( |
|
n.meta["val"], (torch.SymInt, torch.SymFloat, torch.SymBool) |
|
): |
|
result = n.meta["val"].node.expr |
|
else: |
|
result = super().run_node(n) |
|
else: |
|
debug("") |
|
result = super().run_node(n) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
as_strided_ops = [ |
|
torch.ops.aten.as_strided.default, |
|
torch.ops.aten.as_strided_.default, |
|
torch.ops.aten.as_strided_scatter.default, |
|
torch.ops.aten.resize.default, |
|
torch.ops.aten.resize_as.default, |
|
] |
|
is_output = any(user.op == "output" for user in n.users) |
|
is_input_for_as_strided = any( |
|
user.target in as_strided_ops for user in n.users |
|
) |
|
|
|
if n.meta.get("inductor_realize_to_strides", False) and isinstance( |
|
result, TensorBox |
|
): |
|
result.realize() |
|
strides = n.meta["val"].stride() |
|
sym_strides = torch._inductor.utils.any_is_symbolic(*strides) |
|
if ( |
|
not hasattr(result, "get_stride") |
|
or result.get_stride() != strides |
|
and not sym_strides |
|
): |
|
stride_order = ir.get_stride_order(strides) |
|
result = ir.ExternKernel.require_stride_order(result, stride_order) |
|
if ( |
|
is_output |
|
and isinstance(result, TensorBox) |
|
and isinstance(result.data, ir.BaseView) |
|
): |
|
|
|
result.realize() |
|
|
|
if (is_output or is_input_for_as_strided) and isinstance( |
|
n.meta["val"], torch.Tensor |
|
): |
|
strides = n.meta["val"].stride() |
|
dense = torch._prims_common.is_non_overlapping_and_dense(n.meta["val"]) |
|
unbacked_symbols_in_strides = len(free_unbacked_symbols(strides)) > 0 |
|
|
|
|
|
if not unbacked_symbols_in_strides and dense and len(strides): |
|
stride_order = ir.get_stride_order(strides) |
|
if ( |
|
len(result.get_size()) == 4 |
|
and n in self.nodes_prefer_channels_last |
|
and n.name not in self.user_visible_outputs |
|
and not is_input_for_as_strided |
|
): |
|
stride_order = ir.NHWC_STRIDE_ORDER |
|
|
|
allow_padding = ( |
|
n.name not in self.user_visible_outputs |
|
and not is_input_for_as_strided |
|
) |
|
result = ir.ExternKernel.require_stride_order( |
|
result, stride_order, allow_padding=allow_padding |
|
) |
|
|
|
|
|
|
|
num_users = len(set(n.users)) |
|
if num_users > 1 and isinstance(result, TensorBox): |
|
for user in n.users: |
|
if user.target in needs_realized_inputs: |
|
result.realize_hint() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
need_fixed_layout = [ |
|
torch.ops.aten.convolution_backward.default, |
|
torch.ops.aten.mm.default, |
|
torch.ops.aten._int_mm.default, |
|
] |
|
need_fixed_channels_last_layout = [] |
|
if not self.layout_opt: |
|
need_fixed_layout.append(torch.ops.aten.convolution.default) |
|
if torch._C._has_mkldnn: |
|
need_fixed_layout += [ |
|
torch.ops.mkldnn._linear_pointwise.default, |
|
torch.ops.mkldnn._linear_pointwise.binary, |
|
torch.ops.aten.mkldnn_rnn_layer.default, |
|
torch.ops.onednn.qlinear_pointwise.default, |
|
torch.ops.onednn.qlinear_pointwise.tensor, |
|
torch.ops.onednn.qlinear_pointwise.binary, |
|
torch.ops.onednn.qlinear_pointwise.binary_tensor, |
|
] |
|
need_fixed_channels_last_layout += [ |
|
torch.ops.mkldnn._convolution_pointwise.default, |
|
torch.ops.mkldnn._convolution_pointwise.binary, |
|
torch.ops.mkldnn._convolution_pointwise_.binary, |
|
torch.ops.mkldnn._convolution_transpose_pointwise.default, |
|
torch.ops.onednn.qconv2d_pointwise.default, |
|
torch.ops.onednn.qconv2d_pointwise.binary, |
|
] |
|
if torch._C.has_mkl: |
|
need_fixed_layout += [torch.ops.mkl._mkl_linear.default] |
|
if user.target in need_fixed_layout: |
|
result = ir.ExternKernel.require_stride_order( |
|
result, |
|
ir.get_stride_order(n.meta["val"].stride()), |
|
allow_padding=True, |
|
) |
|
if ( |
|
user.target in need_fixed_channels_last_layout |
|
and n is user.args[0] |
|
): |
|
result = ir.ExternKernel.require_stride_order( |
|
result, |
|
ir.get_stride_order( |
|
make_channels_last_strides_for(n.meta["val"].shape) |
|
), |
|
) |
|
if user.op == "output": |
|
if isinstance(result.data.data, (Pointwise, Reduction)): |
|
result.realize() |
|
|
|
|
|
result.mark_reuse(len(n.users)) |
|
|
|
|
|
if isinstance(result, TensorBox) and result.has_exceeded_max_reads(): |
|
|
|
|
|
|
|
result.realize_hint() |
|
|
|
|
|
|
|
if isinstance(result, TensorBox) and isinstance(result.data, StorageBox): |
|
curr = result.data.data |
|
if isinstance(curr, Pointwise): |
|
|
|
if curr.has_large_inner_fn(): |
|
result.realize() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if isinstance(result, TensorBox) and isinstance(result.data, ir.StorageBox): |
|
if isinstance(result.data.data, ir.Loops): |
|
result.data.data.origin_node = n |
|
elif isinstance(result.data.data, ir.Buffer): |
|
result.data.data.origin_node = n |
|
if isinstance(result.data.data, ir.ComputedBuffer) and isinstance( |
|
result.data.data.data, ir.Loops |
|
): |
|
result.data.data.data.origin_node = n |
|
|
|
elif ( |
|
isinstance(result.data.data, ir.MultiOutput) |
|
and not result.data.data.indices |
|
): |
|
if isinstance(result.data.data.inputs[0], ir.Buffer): |
|
result.data.data.inputs[0].origin_node = n |
|
|
|
self.register_users_of(result) |
|
|
|
new_unbacked_defs = set() |
|
for i in range(buffer_watermark, len(self.buffers)): |
|
new_unbacked_defs |= self.buffers[i].get_unbacked_symbol_defs() |
|
|
|
def format_buffers(): |
|
r = [] |
|
for b in self.buffers[buffer_watermark:]: |
|
r.append( |
|
f"unbacked_symbol_defs={b.get_unbacked_symbol_defs()} in:\n{b}\n" |
|
) |
|
return "***\n".join(r) |
|
|
|
if n.op != "placeholder": |
|
|
|
|
|
|
shape_env = V.graph.sizevars.shape_env |
|
|
|
for i0 in new_unbacked_defs: |
|
ras = self.ras_by_symbol.pop(i0, []) |
|
|
|
vr = shape_env.var_to_range[i0] |
|
if not shape_env._default_unspecified_value_range().issubset(vr): |
|
|
|
def convert(s): |
|
try: |
|
return int(s) |
|
except TypeError: |
|
return None |
|
|
|
if (lower := convert(vr.lower)) is not None: |
|
self.register_buffer( |
|
ir.AssertScalar(i0 >= vr.lower, f"{i0} >= {vr.lower}"), |
|
set_name=True, |
|
) |
|
if (upper := convert(vr.upper)) is not None: |
|
self.register_buffer( |
|
ir.AssertScalar(i0 <= vr.upper, f"{i0} <= {vr.upper}"), |
|
set_name=True, |
|
) |
|
|
|
for ra in ras: |
|
fvs = free_unbacked_symbols(ra.expr) |
|
missing = fvs - self.bound_unbacked_symbols |
|
if missing: |
|
i1 = sorted(missing, key=lambda x: str(x))[0] |
|
self.ras_by_symbol.setdefault(i1, []).append(ra) |
|
else: |
|
self.register_buffer( |
|
ir.AssertScalar(ra.expr, f"{ra.expr}"), set_name=True |
|
) |
|
|
|
self.bound_unbacked_symbols |= new_unbacked_defs |
|
|
|
unbacked_bindings = resolve_unbacked_bindings( |
|
V.graph.sizevars.shape_env, n.meta.get("unbacked_bindings", {}) |
|
) |
|
|
|
|
renamed_unbacked_bindings = { |
|
V.fake_mode.shape_env.unbacked_renamings.get(s, s) |
|
for s in unbacked_bindings.keys() |
|
} |
|
assert new_unbacked_defs >= renamed_unbacked_bindings, ( |
|
f"failed {new_unbacked_defs} >= {renamed_unbacked_bindings} (inductor >= fx)\n" |
|
f"fx node is: {n.format_node()}\n" |
|
f"new buffers are:\n\n{format_buffers()}" |
|
) |
|
|
|
return result |
|
|
|
def validate_can_generate_cpp_wrapper(self): |
|
if config.disable_cpp_codegen: |
|
raise CppWrapperCodeGenError("C++ codegen is disabled") |
|
|
|
if sys.platform not in ["linux", "darwin"]: |
|
raise CppWrapperCodeGenError(f"Unsupported platform {sys.platform}") |
|
|
|
for value in self.graph_inputs.values(): |
|
dtype = None |
|
if isinstance(value, TensorBox): |
|
dtype = value.get_dtype() |
|
elif isinstance( |
|
value, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer) |
|
): |
|
dtype = may_get_constant_buffer_dtype(value) |
|
|
|
if not supported_dtype_of_cpp_wrapper(dtype, self.cuda): |
|
raise CppWrapperCodeGenError(f"Unsupported input dtype {dtype}") |
|
|
|
def init_wrapper_code(self): |
|
self.cuda = "cuda" in self.device_types |
|
if self.cpp_wrapper: |
|
self.validate_can_generate_cpp_wrapper() |
|
|
|
device_types = self.device_types.copy() |
|
device_types.discard("cpu") |
|
device_types.discard("meta") |
|
|
|
assert len(device_types) <= 1, "Does not support mixing {}".format( |
|
"+".join(device_types) |
|
) |
|
only_cpu = len(device_types) == 0 |
|
device_type = "cpu" if only_cpu else device_types.pop() |
|
|
|
self.device_ops = get_device_op_overrides(device_type) |
|
wrapper_code_gen_cls = get_wrapper_codegen_for_device( |
|
device_type, self.cpp_wrapper |
|
) |
|
assert wrapper_code_gen_cls is not None, f"Device {device_type} not supported" |
|
self.wrapper_code = wrapper_code_gen_cls() |
|
|
|
if self.const_module: |
|
|
|
|
|
self.wrapper_code._names_iter = self.const_module.wrapper_code._names_iter |
|
self.wrapper_code.src_to_kernel = ( |
|
self.const_module.wrapper_code.src_to_kernel |
|
) |
|
|
|
def codegen_with_cpp_wrapper(self): |
|
""" |
|
For CPU, the cpp wrapper codegen is done in one pass. |
|
For GPU, the cpp wrapper codegen is done in two passes: in the first pass, JIT-compile the
model with python wrapper code and run it to generate autotuned kernel binaries; in the
second pass, generate cpp wrapper code and compile it to a dynamic library.
|
""" |
|
if "cuda" in self.device_types: |
|
|
|
self.cpp_wrapper = False |
|
|
|
|
|
|
|
with config.patch({"triton.store_cubin": True}): |
|
compiled = self.compile_to_module().call |
|
|
|
def materialize(x): |
|
if isinstance(x, (torch.SymInt, torch.SymFloat)): |
|
|
|
return x.node.hint |
|
elif isinstance(x, FakeTensor): |
|
return defake(x) |
|
else: |
|
assert isinstance( |
|
x, torch.Tensor |
|
), "Unknown type when creating real inputs" + str(type(x)) |
|
return x |
|
|
|
tracing_context = torch._guards.TracingContext.try_get() |
|
if tracing_context is not None and not isinstance( |
|
V.real_inputs, NullHandler |
|
): |
|
if tracing_context.output_strides: |
|
tracing_context.output_strides.clear() |
|
|
|
params_flat = [ |
|
param |
|
for param in tracing_context.params_flat |
|
if param is not None |
|
] |
|
real_inputs = [ |
|
materialize(x) for x in itertools.chain(params_flat, V.real_inputs) |
|
] |
|
else: |
|
|
|
|
|
|
|
real_inputs = [ |
|
materialize(x) |
|
for x in ( |
|
self.example_inputs |
|
if isinstance(V.real_inputs, NullHandler) |
|
else V.real_inputs |
|
) |
|
] |
|
|
|
if self.mutated_inputs: |
|
from .compile_fx import clone_preserve_strides |
|
|
|
mutated_input_idxs = [ |
|
idx |
|
for idx, name in enumerate(self.graph_inputs) |
|
if name in self.mutated_inputs |
|
and isinstance(real_inputs[idx], torch.Tensor) |
|
] |
|
for idx in mutated_input_idxs: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
real_inputs[idx] = clone_preserve_strides(real_inputs[idx]) |
|
|
|
with torch.utils._python_dispatch._disable_current_modes(): |
|
compiled(real_inputs) |
|
del real_inputs |
|
|
|
|
|
|
|
self.cpp_wrapper = True |
|
self.removed_buffers.clear() |
|
self.inplaced_to_remove.clear() |
|
V.graph.sizevars.precomputed_replacements.clear() |
|
V.graph.sizevars.inv_precomputed_replacements.clear() |
|
return self.codegen() |
|
else: |
|
|
|
return self.codegen() |
|
|
|
def codegen(self): |
|
from .scheduler import Scheduler |
|
|
|
self.init_wrapper_code() |
|
|
|
self.scheduler = Scheduler(self.buffers) |
|
V.debug.draw_orig_fx_graph(self.orig_gm, self.scheduler.nodes) |
|
|
|
self.wrapper_code.push_codegened_graph(self) |
|
self.scheduler.codegen() |
|
result = self.wrapper_code.generate(self.is_inference) |
|
self.wrapper_code.pop_codegened_graph() |
|
return result |
|
|
|
def codegen_subgraph(self, parent_graph): |
|
""" |
|
This is a more compact version of the `codegen()` above |
|
where we codegen this graph as a subgraph of some parent |
|
graph. The parent graph is passed as an argument: the |
|
intention is to inline codegening of the subgraph in |
|
the parent graph's wrapper code (including the generated |
|
kerenls). The wrapper code is not finalized (via `.generate()` |
|
call), as this will be done in the parent graph's `codegen()`. |
|
""" |
|
from .scheduler import Scheduler |
|
|
|
self.wrapper_code = parent_graph.wrapper_code |
|
self.device_ops = parent_graph.device_ops |
|
self.cpp_wrapper = parent_graph.cpp_wrapper |
|
|
|
self.scheduler = Scheduler(self.buffers) |
|
self.scheduler.codegen() |
|
|
|
def count_bytes(self): |
|
total_bytes = 0 |
|
node_counts = [] |
|
node_runtimes = [] |
|
for node in self.scheduler.nodes: |
|
num_bytes = node.get_read_write_buffers_sizes() |
|
total_bytes += num_bytes |
|
node_counts.append((node, num_bytes // 4)) |
|
node_runtimes.append((node, node.get_estimated_runtime())) |
|
return total_bytes, node_counts, node_runtimes |
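# Illustrative sketch (hypothetical node): a scheduler node that reads and writes
# 4 MiB in total contributes num_bytes = 4 * 2**20 to total_bytes and
# (node, 2**20) to node_counts, since the element count is approximated as
# num_bytes // 4 (i.e. assuming 4-byte elements).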
|
|
|
@dynamo_timed(phase_name="code_gen", fwd_only=False) |
|
def compile_to_module(self): |
|
from .codecache import PyCodeCache |
|
|
|
code, linemap = ( |
|
self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen() |
|
) |
|
|
|
output_code_log.debug("Output code: \n%s", code) |
|
try: |
|
linemap = [(line_no, node.stack_trace) for line_no, node in linemap] |
|
key, path = PyCodeCache.write(code) |
|
except Exception: |
|
trace_structured( |
|
"inductor_output_code", |
|
|
|
payload_fn=lambda: code, |
|
) |
|
raise |
|
else: |
|
trace_structured( |
|
"inductor_output_code", |
|
lambda: {"filename": path}, |
|
payload_fn=lambda: code, |
|
) |
|
|
|
mod = PyCodeCache.load_by_key_path( |
|
key, |
|
path, |
|
linemap=linemap, |
|
attrs={**self.constants, **self.torchbind_constants}, |
|
) |
|
self.cache_key = key |
|
self.cache_path = path |
|
self.cache_linemap = linemap |
|
|
|
|
|
|
|
assert mod.__file__ is not None |
|
|
|
log_module_code(mod.__file__) |
|
log.debug("Output code written to: %s", mod.__file__) |
|
output_code_log.info("Output code written to: %s", mod.__file__) |
|
if config.benchmark_kernel: |
|
print(f"Compiled module path: {mod.__file__}", file=sys.stderr) |
|
V.debug.output_code(mod.__file__) |
|
V.debug.copy(os.path.splitext(mod.__file__)[0] + ".debug") |
|
return mod |
|
|
|
def compile_to_fn(self): |
|
if self.aot_mode: |
|
from .codecache import AotCodeCompiler |
|
|
|
assert self.cpp_wrapper, "AOT mode only supports C++ wrapper" |
|
code, linemap = self.codegen_with_cpp_wrapper() |
|
output_code_log.debug("Output code: \n%s", code) |
|
|
|
serialized_extern_kernel_nodes = None |
|
if ( |
|
config.is_fbcode() |
|
and self.extern_kernel_nodes |
|
and self.extern_node_serializer |
|
): |
|
serialized_extern_kernel_nodes = self.extern_node_serializer( |
|
self.extern_kernel_nodes |
|
) |
|
output_code_log.debug( |
|
"Serialized Extern Kernel Nodes: \n%s", |
|
serialized_extern_kernel_nodes, |
|
) |
|
|
|
|
|
return AotCodeCompiler.compile( |
|
self, code, serialized_extern_kernel_nodes, cuda=self.cuda |
|
) |
|
else: |
|
return self.compile_to_module().call |
|
|
|
def get_output_names(self): |
|
return [ |
|
node.get_name() |
|
for node in self.graph_outputs |
|
if not isinstance(node, ir.NoneAsConstantBuffer) |
|
and not isinstance(node, ir.ShapeAsConstantBuffer) |
|
] |
|
|
|
def is_unspec_arg(self, name: str): |
|
|
|
|
|
return ( |
|
name in self.graph_inputs.keys() |
|
and self.graph_inputs[name].get_numel() == 1 |
|
and self.graph_inputs[name].get_device().type == "cpu" |
|
) |
|
|