import collections |
|
import dataclasses |
|
import functools |
|
import itertools |
|
import logging |
|
import math |
|
import operator |
|
import os |
|
import pprint |
|
import textwrap |
|
import typing |
|
from typing import ( |
|
Any, |
|
Counter, |
|
DefaultDict, |
|
Dict, |
|
Generic, |
|
List, |
|
Optional, |
|
Sequence, |
|
Set, |
|
Tuple, |
|
TypeVar, |
|
Union, |
|
) |
|
|
|
import sympy |
|
|
|
import torch |
|
import torch._inductor.async_compile  # noqa: F401  (imported for side effects)
|
from torch._dynamo.utils import counters, dynamo_timed |
|
from torch._inductor.metrics import get_metric_table, is_metric_table_enabled |
|
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols |
|
from torch.utils._sympy.symbol import free_symbol_is_type, SymT |
|
from torch.utils._triton import has_triton |
|
|
|
from . import comms, config, dependencies, ir, metrics |
|
from .codecache import write_text |
|
from .codegen.common import get_scheduling_for_device, Kernel |
|
from .comm_analysis import estimate_nccl_collective_runtime |
|
from .dependencies import Dep, MemoryDep, StarDep, WeakDep |
|
from .ir import ComputedBuffer, MultiOutput, MultiOutputLayout |
|
from .runtime.runtime_utils import green_text, red_text |
|
from .sizevars import SimplifyIndexing |
|
from .utils import ( |
|
cache_on_self, |
|
cmp, |
|
device_need_guard, |
|
get_device_tflops, |
|
get_dtype_size, |
|
get_gpu_dram_gbps, |
|
IndentedBuffer, |
|
is_collective, |
|
is_gpu, |
|
is_wait, |
|
sympy_product, |
|
) |
|
from .virtualized import V |
|
|
|
|
|
log = logging.getLogger(__name__) |
|
fusion_log = torch._logging.getArtifactLogger(__name__, "fusion") |
|
|
|
|
|
class BaseSchedulerNode: |
|
group: Tuple[torch.device, Tuple[Tuple[sympy.Expr, ...], ...]] |
|
read_writes: dependencies.ReadWrites |
|
unmet_dependencies: Set[Dep] |
|
|
|
def __init__(self, scheduler: "Scheduler", node: ir.Buffer) -> None: |
|
self.scheduler: Scheduler = scheduler |
|
self.node: Optional[ir.Buffer] = node |
|
self.users: List[NodeUser] = [] |
|
self.inverse_users: List[BaseSchedulerNode] = [] |
|
self.node_users: List[BaseSchedulerNode] = [] |
|
self.set_read_writes(node.get_read_writes()) |
|
self.ancestors: Set[str] = set() |
|
self.min_order: int |
|
self.max_order: int |
|
        self.last_usage: Set[str] = set()  # buffers that won't be used after this kernel
|
self.written = False |
|
|
|
def __repr__(self) -> str: |
|
return f"{type(self).__name__}(name={self.get_name()!r})" |
|
|
|
def debug_str(self) -> str: |
|
"""Longer form printout for trace logs""" |
|
name = self.get_name() |
|
lines = [ |
|
f"{name}: {type(self).__name__}({type(getattr(self, 'node', None)).__name__})", |
|
f"{name}.writes = {pformat(self.read_writes.writes)}", |
|
f"{name}.unmet_dependencies = {pformat(self.unmet_dependencies)}", |
|
f"{name}.met_dependencies = {pformat(self.read_writes.reads - self.unmet_dependencies)}", |
|
f"{name}.users = {self.users}", |
|
] |
|
try: |
|
lines += [ |
|
self.debug_str_extra(), |
|
] |
|
except Exception: |
|
log.warning("Ignoring error in debug_str()", exc_info=True) |
|
|
|
return "\n".join(lines).rstrip() |
|
|
|
def debug_str_extra(self) -> str: |
|
return "" |
|
|
|
def log_details(self) -> None: |
|
log.info( |
|
"%s: unmet_dependencies = %s, writes = %s", |
|
self, |
|
self.unmet_dependencies, |
|
self.read_writes.writes, |
|
) |
|
|
|
def update_mutated_names(self, renames: Dict[str, str]) -> None: |
|
self.set_read_writes(self.read_writes.rename(renames)) |
|
|
|
def add_fake_dep(self, dep: Dep) -> None: |
|
self.set_read_writes(self.read_writes.with_read(dep)) |
|
|
|
def set_users(self, users: List["NodeUser"]) -> None: |
|
|
|
result: Dict[int, NodeUser] = {} |
|
for use in users: |
|
if id(use.node) in result: |
|
result[id(use.node)] = use.merge(result[id(use.node)]) |
|
else: |
|
result[id(use.node)] = use |
|
self.users = list(result.values()) |
|
|
|
def set_last_usage( |
|
self, future_used_buffers: Set[str], mutation_real_name: Dict[str, str] |
|
) -> None: |
|
used_buffers = self.used_or_aliased_buffer_names() |
|
used_buffers = {mutation_real_name.get(k, k) for k in used_buffers} |
|
self.last_usage = used_buffers - future_used_buffers |
|
|
|
def get_aliases(self) -> Sequence[str]: |
|
assert self.node is not None |
|
return self.node.get_inputs_that_alias_output() |
|
|
|
def get_mutations(self) -> List[str]: |
|
assert self.node is not None |
|
return self.node.get_mutation_names() |
|
|
|
def has_aliasing_or_mutation(self) -> bool: |
|
return bool(self.get_aliases() or self.get_mutations()) |
|
|
|
def set_read_writes(self, rw: dependencies.ReadWrites) -> None: |
|
self.read_writes = rw |
|
self.unmet_dependencies = self.read_writes.reads |
|
self.prune_deps() |
|
|
|
def op_counts(self) -> Counter[str]: |
|
return self.read_writes.op_counts |
|
|
|
def used_buffer_names(self) -> Set[str]: |
|
return { |
|
dep.name |
|
for dep in itertools.chain(self.read_writes.reads, self.read_writes.writes) |
|
} |
|
|
|
def used_or_aliased_buffer_names(self) -> Set[str]: |
|
used_names = set() |
|
|
|
deps = [ |
|
dep.name |
|
for dep in itertools.chain(self.read_writes.reads, self.read_writes.writes) |
|
] |
|
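        # Alias chains are chased transitively below, so a view of a view of
        # buf0 still marks buf0 as used.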
while len(deps) > 0: |
|
dep = deps.pop() |
|
used_names.add(dep) |
|
if V.graph.name_to_buffer.get(dep): |
|
for alias in V.graph.name_to_buffer[dep].get_inputs_that_alias_output(): |
|
if alias not in used_names: |
|
deps.append(alias) |
|
return used_names |
|
|
|
def prune_deps(self) -> None: |
|
self.unmet_dependencies = { |
|
dep |
|
for dep in self.unmet_dependencies |
|
if dep.name not in self.scheduler.available_buffer_names |
|
} |
|
|
|
def prune_weak_deps(self) -> None: |
|
|
|
def should_prune(dep: Dep) -> bool: |
|
return isinstance(dep, WeakDep) and dep.name in V.graph.removed_buffers |
|
|
|
to_remove = {dep for dep in self.read_writes.reads if should_prune(dep)} |
|
self.set_read_writes(self.read_writes.remove_reads(to_remove)) |
|
|
|
def prune_redundant_deps( |
|
self, name_to_fused_node: Dict[str, "BaseSchedulerNode"] |
|
) -> None: |
|
_prune_redundant_deps(self, name_to_fused_node) |
|
|
|
def get_name(self) -> str: |
|
assert self.node is not None |
|
return self.node.get_name() |
|
|
|
def get_first_name(self) -> str: |
|
return self.get_name() |
|
|
|
def get_names(self) -> Set[str]: |
|
return {self.get_name()} |
|
|
|
def get_nodes(self) -> Sequence["BaseSchedulerNode"]: |
|
return [self] |
|
|
|
def get_device(self) -> torch.device: |
|
assert self.node is not None |
|
return self.node.get_device() |
|
|
|
def is_reduction(self) -> bool: |
|
return False |
|
|
|
def is_split_scan(self) -> bool: |
|
return False |
|
|
|
def is_template(self) -> bool: |
|
return False |
|
|
|
def is_extern(self) -> bool: |
|
return False |
|
|
|
def is_foreach(self) -> bool: |
|
return False |
|
|
|
def can_inplace(self, read_dep: dependencies.Dep) -> bool: |
|
return False |
|
|
|
def has_side_effects(self) -> bool: |
|
return False |
|
|
|
def decide_inplace_update(self) -> None: |
|
""" |
|
Decide if there should be inplace updates for the node |
|
and record the decision in the active kernel. |
|
""" |
|
assert self.node is not None |
|
if not self.node.should_allocate(): |
|
return |
|
|
|
if isinstance(self, (SchedulerNode,)) and ( |
|
self.node.get_inputs_that_alias_output() or self.node.get_mutation_names() |
|
): |
|
return |
|
|
|
if ( |
|
isinstance(self, (SchedulerNode,)) |
|
and config.inplace_buffers |
|
and ( |
|
not isinstance(V.kernel, torch._inductor.codegen.simd.SIMDKernel) |
|
or getattr(V.kernel, "mutations", None) is not None |
|
) |
|
): |
|
from .codegen.wrapper import buffer_reuse_key |
|
|
|
ordered_reads = sorted(self.read_writes.reads, key=lambda x: x.name) |
|
|
|
for read in ordered_reads: |
|
input_node: Optional[ |
|
BaseSchedulerNode |
|
] = self.scheduler.name_to_node.get(read.name) |
|
if ( |
|
input_node |
|
and V.graph.wrapper_code.can_reuse(input_node, self) |
|
and not isinstance(input_node, NopKernelSchedulerNode) |
|
): |
|
assert input_node.users is not None |
|
remaining_uses = [ |
|
x |
|
for x in input_node.users |
|
if x.node.get_name() |
|
not in self.scheduler.available_buffer_names |
|
] |
|
if ( |
|
len(remaining_uses) == 1 |
|
and remaining_uses[0].can_inplace |
|
and remaining_uses[0].node is self |
|
and input_node.node is not None |
|
and not isinstance( |
|
input_node.node.get_layout(), |
|
( |
|
ir.MultiOutputLayout, |
|
ir.MutationLayoutSHOULDREMOVE, |
|
), |
|
) |
|
and not ( |
|
isinstance( |
|
input_node.node, (ir.FallbackKernel, ir.MultiOutput) |
|
) |
|
and len(input_node.node.get_inputs_that_alias_output()) > 0 |
|
) |
|
and buffer_reuse_key(input_node.node) |
|
== buffer_reuse_key(self.node) |
|
): |
|
|
|
if hasattr(V.kernel, "args"): |
|
|
|
|
|
|
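                            # V.kernel may be a NullHandler outside real kernel
                            # codegen; only kernels exposing `args` can record
                            # the in-place reuse.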
|
V.kernel.args.make_inplace( |
|
input_node.get_name(), self.get_name() |
|
) |
|
|
|
if isinstance( |
|
V.kernel, torch._inductor.codegen.simd.SIMDKernel |
|
): |
|
V.kernel.mutations.add(input_node.get_name()) |
|
V.kernel.mutations.add(self.get_name()) |
|
|
|
|
|
self.last_usage.discard(input_node.get_name()) |
|
|
|
V.kernel.inplace_update_buffers[ |
|
self.get_name() |
|
] = input_node.get_name() |
|
break |
|
|
|
def allocate(self) -> None: |
|
assert self.node is not None |
|
if not self.node.should_allocate(): |
|
return |
|
|
|
if isinstance(self, (SchedulerNode,)) and ( |
|
self.node.get_inputs_that_alias_output() or self.node.get_mutation_names() |
|
): |
|
V.graph.wrapper_code.codegen_allocation(self.node) |
|
return |
|
|
|
|
|
if ( |
|
hasattr(V.kernel, "args") |
|
and self.get_name() in V.kernel.inplace_update_buffers |
|
): |
|
V.graph.wrapper_code.codegen_inplace_reuse( |
|
self.scheduler.name_to_node[ |
|
V.kernel.inplace_update_buffers[self.get_name()] |
|
].node, |
|
self.node, |
|
) |
|
else: |
|
V.graph.wrapper_code.codegen_allocation(self.node) |
|
|
|
def can_free(self) -> bool: |
|
|
|
assert self.node is not None |
|
if isinstance(self.node.layout, ir.NoneLayout): |
|
return False |
|
for use in self.users: |
|
if isinstance(use.node, OutputNode): |
|
return False |
|
return True |
|
|
|
def codegen_originating_info( |
|
self, buffer: IndentedBuffer, only_once: bool = True |
|
) -> None: |
|
if not config.comment_origin: |
|
return |
|
|
|
if only_once and self.written: |
|
return |
|
assert self.node is not None |
|
origins = self.node.origins |
|
out_lines = [] |
|
|
|
for o in origins: |
|
if o.op == "output": |
|
|
|
continue |
|
|
|
out_lines.append("") |
|
|
|
out_lines.append("#pragma CMT ORIGIN:") |
|
op_info_str = f"#pragma CMT {o.op} {o.target}" |
|
if "seq_nr" in o.meta: |
|
op_info_str = op_info_str + f" seq_nr:{o.meta['seq_nr']}" |
|
out_lines.append(op_info_str) |
|
if "stack_trace" in o.meta: |
|
stack_trace = f"{o.meta['stack_trace']}" |
|
stack_trace_last_line = stack_trace.split("|")[-1] |
|
out_lines.append( |
|
"#pragma CMT " |
|
+ stack_trace_last_line.replace("{", "{{") |
|
.replace("}", "}}") |
|
.replace("\n", "\\") |
|
) |
|
out_lines.append("#pragma CMT END ORIGIN") |
|
out_lines.append("") |
|
|
|
if len(out_lines) == 0: |
|
return |
|
|
|
|
|
|
|
buffer.writelines(out_lines) |
|
self.written = True |
|
|
|
def get_read_write_buffers_sizes(self) -> int: |
|
""" |
|
Counting the number of bytes accessed for a kernel is |
|
surprisingly tricky. In particular, there is a differentiation |
|
between 'theoretical' memory accesses and practical memory |
|
accesses. For example, a layernorm kernel may actually access an |
|
input 3 times, but in theory, it only needs to access its input |
|
        once (and may be optimized to do so through, say, persistent
        reductions).
|
|
|
Another example is that even though a buffer is passed in, we may |
|
not access the entire buffer. This may occur if we are accessing |
|
a slice of the buffer. Another tricky case is for indirect |
|
indexing, where the amount of bytes accessed depends on the |
|
values of the input. |
|
|
|
What this function aims to compute is the memory accesses for |
|
worst-case inputs, best-case optimization. What this means is |
|
that for each buffer we compute the amount of potential accesses in two ways and take the minimum. |
|
|
|
1. Numel in ranges multiplied by number of deps the buffer has |
|
2. The buffer size |
|
""" |
|
if isinstance(self, NopKernelSchedulerNode): |
|
return 0 |
|
if isinstance(self, ExternKernelSchedulerNode) and isinstance( |
|
self.node, MultiOutput |
|
): |
|
|
|
return 0 |
|
|
|
def try_size_hint(s: sympy.Expr) -> int: |
|
return V.graph.sizevars.size_hint(s, fallback=0) |
|
|
|
if isinstance(self, SchedulerNode): |
|
node_numel = try_size_hint( |
|
sympy_product(self.get_ranges()[0]) |
|
* sympy_product(self.get_ranges()[1]), |
|
) |
|
else: |
|
node_numel = int(1e9) |
|
buf_accesses = collections.defaultdict(list) |
|
for dep in self.read_writes.reads | self.read_writes.writes: |
|
buf_accesses[dep.name].append(dep) |
|
|
|
reads = {dep.name for dep in self.read_writes.reads} |
|
writes = {dep.name for dep in self.read_writes.writes} |
|
|
|
def is_materialized(buf: str, snodes: Sequence[BaseSchedulerNode]) -> bool: |
|
users = [ |
|
user |
|
for user in self.scheduler.name_to_node[buf].users |
|
if not user.is_weak |
|
] |
|
buf_uses = {user.node for user in users} |
|
return len(buf_uses - set(snodes)) > 0 |
|
|
|
if isinstance(self, FusedSchedulerNode): |
|
removed_buffers = { |
|
dep for dep in writes if not is_materialized(dep, self.snodes) |
|
} |
|
writes = writes - removed_buffers |
|
reads = reads - removed_buffers |
|
node_bytes = 0 |
|
|
|
for buf_name in reads | writes: |
|
buf_accessed_elems = sum(node_numel for dep in buf_accesses[buf_name]) |
|
buf: Union[ir.Buffer, ir.TensorBox] |
|
if buf_name in V.graph.name_to_buffer: |
|
buf = V.graph.name_to_buffer[buf_name] |
|
elif buf_name in V.graph.graph_inputs: |
|
buf = V.graph.graph_inputs[buf_name] |
|
else: |
|
continue |
|
|
|
def get_buf_elems(buf: Optional[Union[ir.Buffer, ir.TensorBox]]) -> int: |
|
if not buf: |
|
return 0 |
|
|
|
|
|
if isinstance(buf.layout, MultiOutputLayout): |
|
users = self.scheduler.name_to_node[buf.get_name()].users |
|
tot = 0 |
|
for user in users: |
|
assert isinstance(user.node, BaseSchedulerNode) |
|
if isinstance(user.node.node, MultiOutput): |
|
tot += get_buf_elems(user.node.node) |
|
else: |
|
|
|
|
|
|
|
return 0 |
|
return tot |
|
else: |
|
return try_size_hint(sympy_product(buf.get_size())) |
|
|
|
buf_elems = get_buf_elems(buf) |
|
node_bytes += min(buf_elems, buf_accessed_elems) * get_dtype_size( |
|
buf.get_dtype() |
|
) |
|
|
|
return node_bytes |
|
|
|
def get_estimated_runtime(self) -> float: |
|
""" |
|
Returns estimated op runtime in nanoseconds (ns) |
|
""" |
|
layout = None |
|
dtype = None |
|
if not hasattr(self, "node") or not self.node: |
|
assert isinstance( |
|
self, (FusedSchedulerNode, ForeachKernelSchedulerNode) |
|
), f"{type(self)=}" |
|
assert self.snodes |
|
if not self.snodes[0].node: |
|
return 0 |
|
layout = self.snodes[0].node.get_layout() |
|
dtype = self.snodes[0].node.get_dtype() |
|
else: |
|
layout = self.node.get_layout() |
|
dtype = self.node.get_dtype() |
|
|
|
if layout.device is not None and not is_gpu(layout.device.type): |
|
|
|
return 0 |
|
|
|
|
|
if is_collective(self.node): |
|
assert self.node is not None |
|
try: |
|
return estimate_nccl_collective_runtime(self.node) |
|
except ValueError as e: |
|
|
|
|
|
log.info(e) |
|
return 0 |
|
|
|
elif is_wait(self.node): |
|
|
|
|
|
|
|
|
|
return 0 |
|
|
|
try: |
|
gpu_memory_bandwidth = get_gpu_dram_gbps() |
|
gpu_flops = get_device_tflops(dtype) * 10**12 |
|
except Exception: |
|
return 0 |
|
|
|
if isinstance(self, ExternKernelSchedulerNode): |
|
assert isinstance(self.node, ir.ExternKernel), f"{type(self.node)=}" |
|
op = kernel_name_to_op.get( |
|
getattr(self.node, "python_kernel_name", ""), None |
|
) |
|
|
|
|
|
if op is not None: |
|
from torch._subclasses.fake_tensor import FakeTensorMode |
|
from torch.utils.flop_counter import FlopCounterMode |
|
|
|
if any( |
|
len(free_unbacked_symbols(n.get_numel())) > 0 |
|
for n in self.node.inputs |
|
): |
|
|
|
|
|
return 0 |
|
|
|
with FakeTensorMode() as fake_mode, FlopCounterMode( |
|
display=False |
|
) as flop_counter_mode, V.set_current_node( |
|
self.node.fx_node |
|
), V.set_fake_mode( |
|
fake_mode |
|
): |
|
from .ir import ir_node_to_tensor |
|
|
|
fake_inputs = [ |
|
ir_node_to_tensor(input, guard_shape=False) |
|
for input in self.node.inputs |
|
] |
|
cls = self.node.__class__ |
|
cls.process_kernel(op, *fake_inputs, **self.node.kwargs) |
|
|
|
|
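                    # Roofline-style estimate: report the max of compute-bound
                    # and memory-bound time, assuming the two overlap perfectly.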
|
factor = 1.0 |
|
counted_flops = flop_counter_mode.get_total_flops() |
|
counted_bytes = self.get_read_write_buffers_sizes() |
|
compute_time = (factor * counted_flops / gpu_flops) * 1e9 |
|
transfer_time = counted_bytes / gpu_memory_bandwidth |
|
|
|
|
|
return max(compute_time, transfer_time) |
|
|
|
elif isinstance(self, FusedSchedulerNode) or isinstance( |
|
self.node, ComputedBuffer |
|
): |
|
|
|
return self.get_read_write_buffers_sizes() / gpu_memory_bandwidth |
|
|
|
return 0 |
|
|
|
def get_template_node(self) -> Optional[ir.TemplateBuffer]: |
|
return None |
|
|
|
|
|
class WhyNoFuse: |
|
|
|
|
|
__slots__ = ["node1", "node2", "reason", "args"] |
|
reason: str |
|
args: Tuple[Any, ...] |
|
|
|
def __init__(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode): |
|
self.node1 = node1 |
|
self.node2 = node2 |
|
|
|
def __call__(self, reason: str, *args: Any) -> None: |
|
self.reason = reason |
|
self.args = args |
|
fusion_log.debug(self) |
|
|
|
def __str__(self) -> str: |
|
return f"cannot fuse {self.node1.get_name()} with {self.node2.get_name()}: " + ( |
|
self.reason % self.args |
|
) |
|
|
|
|
|
def pformat(obj: Any) -> str: |
|
if isinstance(obj, set): |
|
|
|
obj = sorted(obj, key=str) |
|
result = pprint.pformat(obj, indent=4) |
|
if "\n" in result: |
|
return f"\n{textwrap.indent(result, ' '*4)}" |
|
return result |
|
|
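# Illustrative sketch (not part of the original source): pformat renders short
# values inline and pushes multi-line reprs onto an indented block, e.g.
#   pformat({"b", "a"}) == "['a', 'b']"   (sets are sorted for determinism)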
|
|
|
class OutputNode: |
|
def __init__(self, dep: StarDep) -> None: |
|
self.unmet_dependencies = {dep} |
|
self.inverse_users: List[BaseSchedulerNode] = [] |
|
|
|
def is_reduction(self) -> bool: |
|
return False |
|
|
|
def get_inputs_that_alias_output(self) -> Sequence[str]: |
|
return () |
|
|
|
def get_name(self) -> str: |
|
return "OUTPUT" |
|
|
|
__repr__ = get_name |
|
|
|
|
|
def _prune_redundant_deps( |
|
node: BaseSchedulerNode, name_to_fused_node: Dict[str, BaseSchedulerNode] |
|
) -> None: |
|
""" |
|
    Prunes weakdeps intended for mutation ordering when, after fusion, there is
    another dependency on the fused upstream node, making the weakdep
    redundant.

    In essence this enforces an ordering on fusions: as fusions occur, weakdeps
    are incrementally removed, enabling other fusions and ensuring they happen
    in order.
|
""" |
|
name_to_dep_count: Counter[str] = collections.Counter() |
|
|
|
for dep in node.unmet_dependencies: |
|
if not isinstance(dep, WeakDep): |
|
name_to_dep_count[name_to_fused_node[dep.name].get_name()] += 1 |
|
|
|
def should_prune(dep: Dep) -> bool: |
|
if isinstance(dep, WeakDep): |
|
is_redundant = ( |
|
name_to_dep_count[name_to_fused_node[dep.name].get_name()] > 0 |
|
) |
|
|
|
|
|
|
|
is_self_dep = name_to_fused_node[dep.name] == node |
|
return is_redundant or is_self_dep |
|
else: |
|
return False |
|
|
|
deps_to_prune = {dep for dep in node.unmet_dependencies if should_prune(dep)} |
|
|
|
if deps_to_prune: |
|
node.unmet_dependencies = node.unmet_dependencies - deps_to_prune |
|
node.set_read_writes(node.read_writes.remove_reads(deps_to_prune)) |
|
|
|
|
|
|
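# Maps extern kernel names to ATen ops so get_estimated_runtime() can recover
# FLOP counts for them via FlopCounterMode.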
|
kernel_name_to_op = { |
|
"extern_kernels.convolution": torch.ops.aten.convolution, |
|
"extern_kernels.mm": torch.ops.aten.mm, |
|
"extern_kernels.bmm": torch.ops.aten.bmm, |
|
"extern_kernels.addmm": torch.ops.aten.addmm, |
|
} |
|
|
|
|
|
class ExternKernelSchedulerNode(BaseSchedulerNode): |
|
def debug_str_extra(self) -> str: |
|
return f"{self.get_name()}.node.kernel = {getattr(self.node, 'python_kernel_name', None)}" |
|
|
|
def is_extern(self) -> bool: |
|
return True |
|
|
|
def has_side_effects(self) -> bool: |
|
assert self.node is not None |
|
return hasattr(self.node, "has_side_effects") and self.node.has_side_effects() |
|
|
|
|
|
class NopKernelSchedulerNode(BaseSchedulerNode): |
|
pass |
|
|
|
|
|
class SchedulerNode(BaseSchedulerNode): |
|
def __init__( |
|
self, |
|
scheduler: "Scheduler", |
|
node: Union[ir.ComputedBuffer, ir.TemplateBuffer], |
|
) -> None: |
|
super().__init__(scheduler, node) |
|
self._compute_attrs() |
|
|
|
def _compute_attrs( |
|
self, |
|
extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, |
|
) -> None: |
|
assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)) |
|
self._sizes, self._body = self.node.simplify_and_reorder( |
|
extra_indexing_constraints=extra_indexing_constraints |
|
) |
|
|
|
group_fn = self.scheduler.get_backend(self.node.get_device()).group_fn |
|
self.group = (self.node.get_device(), group_fn(self._sizes)) |
|
|
|
if isinstance(self.node, ir.TemplateBuffer): |
|
self.set_read_writes(self.node.normalized_read_writes()) |
|
else: |
|
self.set_read_writes( |
|
dependencies.extract_read_writes( |
|
self._body, *self._sizes, normalize=True |
|
) |
|
) |
|
|
|
def recompute_size_and_body( |
|
self, extra_indexing_constraints: Tuple[Dict[Any, Any], List[Any]] |
|
) -> None: |
|
self._compute_attrs(extra_indexing_constraints=extra_indexing_constraints) |
|
|
|
def debug_str_extra(self) -> str: |
|
name = self.get_name() |
|
lines = [ |
|
f"{name}.group.device = {self.group[0]}", |
|
f"{name}.group.iteration = {self.group[1]}", |
|
f"{name}.sizes = {self._sizes}", |
|
] |
|
for dep in self.read_writes.reads_and_writes(): |
|
buf_name = dep.name |
|
buf = V.graph.get_buffer(buf_name) |
|
lines.append(f"{buf_name}_layout = {pformat(buf.layout)}") |
|
if self.get_aliases(): |
|
lines.append(f"{name}.aliases = {pformat(self.get_aliases())}") |
|
if self.get_mutations(): |
|
lines.append(f"{name}.mutations = {pformat(self.get_mutations())}") |
|
if isinstance(self._body, ir.LoopBody): |
|
lines.append(f"class {name}_loop_body:") |
|
lines.append(textwrap.indent(self._body.debug_str(), " ")) |
|
|
|
assert self.node is not None |
|
if ir.is_triton(self.node.get_device()): |
|
lines.extend(debug_triton_code(self)) |
|
|
|
return "\n".join(lines) |
|
|
|
def get_ranges(self) -> Sequence[Sequence[sympy.Expr]]: |
|
return self._sizes |
|
|
|
def is_reduction(self) -> bool: |
|
assert isinstance( |
|
self.node, (ir.ComputedBuffer, ir.TemplateBuffer) |
|
), f"{type(self.node)=}" |
|
return bool(self.node.get_reduction_type()) |
|
|
|
def is_split_scan(self) -> bool: |
|
assert isinstance( |
|
self.node, (ir.ComputedBuffer, ir.TemplateBuffer) |
|
), f"{type(self.node)=}" |
|
return isinstance(self.node, ir.ComputedBuffer) and isinstance( |
|
self.node.data, ir.SplitScan |
|
) |
|
|
|
def is_template(self) -> bool: |
|
return isinstance(self.node, ir.TemplateBuffer) |
|
|
|
def get_template_node(self) -> Optional[ir.TemplateBuffer]: |
|
return self.node if isinstance(self.node, ir.TemplateBuffer) else None |
|
|
|
def run(self, *index_vars: Sequence[sympy.Expr]) -> None: |
|
self.decide_inplace_update() |
|
self.mark_run() |
|
self.codegen(index_vars) |
|
|
|
def mark_run(self) -> None: |
|
self.allocate() |
|
|
|
def ranges_from_index_vars( |
|
self, index_vars: Sequence[Sequence[sympy.Expr]] |
|
) -> Dict[sympy.Expr, sympy.Expr]: |
|
sizes = self._sizes |
|
assert sum(map(len, sizes)) == sum(map(len, index_vars)) |
|
var_ranges = dict( |
|
zip( |
|
itertools.chain.from_iterable(index_vars), |
|
itertools.chain.from_iterable(sizes), |
|
) |
|
) |
|
return var_ranges |
|
|
|
def codegen(self, index_vars: Sequence[Sequence[sympy.Expr]]) -> None: |
|
var_ranges = self.ranges_from_index_vars(index_vars) |
|
try: |
|
with V.set_ops_handler( |
|
SimplifyIndexing(V.get_ops_handler(), var_ranges) |
|
), V.kernel.set_current_node(self): |
|
self._body(*index_vars) |
|
except Exception: |
|
log.fatal("Error in codegen for %s", self.node) |
|
raise |
|
|
|
def pointwise_read_writes(self) -> dependencies.ReadWrites: |
|
""" |
|
Get the memory dependencies in the non-reduction axis. |
|
""" |
|
sizes, reduction_sizes = self._sizes |
|
|
|
        def fn(index: Sequence[sympy.Symbol]) -> Any:
|
return self._body(index, [sympy.Integer(0) for _ in reduction_sizes]) |
|
|
|
return dependencies.extract_read_writes(fn, sizes) |
|
|
|
def can_inplace(self, read_dep: dependencies.Dep) -> bool: |
|
if self.get_aliases() or self.is_template(): |
|
return False |
|
if len(self.read_writes.writes) == 1 and isinstance( |
|
read_dep, dependencies.MemoryDep |
|
): |
|
write_dep = next(iter(self.read_writes.writes)) |
|
assert isinstance(write_dep, dependencies.MemoryDep), f"{type(write_dep)=}" |
|
return read_dep.index == write_dep.index and read_dep.size == write_dep.size |
|
return False |
|
|
|
@cache_on_self |
|
def _get_atomic_add_buffers(self) -> Set[str]: |
|
buffers_store_as_atomic_add = set() |
|
if isinstance(self._body, ir.LoopBody): |
|
for node in self._body.get_nodes(): |
|
if ( |
|
node.op == "call_method" |
|
and node.target == "store" |
|
and ( |
|
("mode" in node.kwargs and node.kwargs["mode"] == "atomic_add") |
|
or (len(node.args) == 5 and node.args[4] == "atomic_add") |
|
) |
|
): |
|
buffers_store_as_atomic_add.add( |
|
node.kwargs["name"] |
|
if "name" in node.kwargs |
|
else (node.args[1] if len(node.args) >= 2 else "") |
|
) |
|
return buffers_store_as_atomic_add |
|
|
|
|
|
class FusedSchedulerNode(BaseSchedulerNode): |
|
""" |
|
This is a "fake" scheduler node that represents a group of scheduler nodes |
|
that are meant to be fused together. The way it does this is by maintaining |
|
its unmet dependencies as the union of its constituent nodes. |
|
""" |
|
|
|
@classmethod |
|
def fuse( |
|
cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode |
|
) -> "FusedSchedulerNode": |
|
assert node1.scheduler is node2.scheduler |
|
assert isinstance(node1, (SchedulerNode, FusedSchedulerNode)) |
|
assert isinstance(node2, (SchedulerNode, FusedSchedulerNode)) |
|
nodes = list(itertools.chain(node1.get_nodes(), node2.get_nodes())) |
|
return cls(node1.scheduler, nodes) |
|
|
|
def __init__( |
|
self, scheduler: "Scheduler", snodes: Sequence[BaseSchedulerNode] |
|
) -> None: |
|
|
|
self.snodes = snodes |
|
self.scheduler = scheduler |
|
self.node = None |
|
self.users: List[NodeUser] = [] |
|
self.inverse_users = [] |
|
self.node_users = [] |
|
self.group = max(snodes, key=lambda x: int(x.is_reduction())).group |
|
self.ancestors = set.union( |
|
*[x.ancestors for x in snodes if x.ancestors is not None] |
|
) |
|
|
|
self.set_read_writes( |
|
dependencies.ReadWrites.merge_list([x.read_writes for x in snodes]) |
|
) |
|
|
|
self.unmet_dependencies = { |
|
dep |
|
for dep in set.union(*[x.unmet_dependencies for x in snodes]) |
|
if dep.name not in self.get_names() |
|
} - self.read_writes.writes |
|
self.min_order = min(x.min_order for x in self.snodes) |
|
self.max_order = max(x.max_order for x in self.snodes) |
|
|
|
@cache_on_self |
|
def get_name(self) -> str: |
|
return "_".join([x.get_name() for x in self.snodes]) |
|
|
|
def get_first_name(self) -> str: |
|
return self.snodes[0].get_name() |
|
|
|
@cache_on_self |
|
def get_names(self) -> Set[str]: |
|
return set.union(*[x.get_names() for x in self.snodes]) |
|
|
|
def debug_str_extra(self) -> str: |
|
lines = [ |
|
f"{self.get_name()}.snodes[{i}] =\n{node.debug_str()}" |
|
for i, node in enumerate(self.snodes) |
|
] |
|
node = self.snodes[0].node |
|
assert node is not None |
|
device = node.get_device() |
|
if ir.is_triton(device): |
|
lines.extend(debug_triton_code(self)) |
|
|
|
return textwrap.indent("\n".join(lines).rstrip(), " ") |
|
|
|
def set_last_usage( |
|
self, future_used_buffers: Set[str], mutation_real_name: Dict[str, str] |
|
) -> None: |
|
|
|
|
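        # First set self.last_usage from the global (cross-node) information.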
|
super().set_last_usage(future_used_buffers, mutation_real_name) |
|
|
|
|
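        # Then recompute last_usage per sub-node using only local information,
        # walking the fused nodes in reverse.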
|
future_used_buffers: Set[str] = set() |
|
for node in reversed(self.snodes): |
|
node.set_last_usage(future_used_buffers, mutation_real_name) |
|
future_used_buffers.update(node.last_usage) |
|
|
|
@cache_on_self |
|
def used_buffer_names(self) -> Set[str]: |
|
return set.union(*[x.used_buffer_names() for x in self.snodes]) |
|
|
|
@cache_on_self |
|
def used_or_aliased_buffer_names(self) -> Set[str]: |
|
return set.union(*[x.used_or_aliased_buffer_names() for x in self.snodes]) |
|
|
|
def get_nodes(self) -> Sequence[BaseSchedulerNode]: |
|
return self.snodes |
|
|
|
def __repr__(self) -> str: |
|
return f"{type(self).__name__}(nodes={self.get_name()})" |
|
|
|
@cache_on_self |
|
def is_reduction(self) -> bool: |
|
return any(x.is_reduction() for x in self.snodes) |
|
|
|
@cache_on_self |
|
def is_split_scan(self) -> bool: |
|
return any(x.is_split_scan() for x in self.snodes) |
|
|
|
@cache_on_self |
|
def is_template(self) -> bool: |
|
return any(x.is_template() for x in self.snodes) |
|
|
|
@cache_on_self |
|
def get_template_node(self) -> Optional[ir.TemplateBuffer]: |
|
for node in self.snodes: |
|
if node.is_template(): |
|
return node.get_template_node() |
|
return None |
|
|
|
def get_device(self) -> torch.device: |
|
return self.group[0] |
|
|
|
@cache_on_self |
|
def has_aliasing_or_mutation(self) -> bool: |
|
return any(x.has_aliasing_or_mutation() for x in self.snodes) |
|
|
|
@cache_on_self |
|
def op_counts(self) -> Counter[str]: |
|
op_counts: Counter[str] = collections.Counter() |
|
for node in self.snodes: |
|
op_counts.update(node.op_counts()) |
|
return op_counts |
|
|
|
|
|
|
|
def update_mutated_names(self, renames: Dict[str, str]) -> None: |
|
raise NotImplementedError |
|
|
|
def add_fake_dep(self, name: Dep) -> None: |
|
raise NotImplementedError |
|
|
|
def set_users(self, users: List["NodeUser"]) -> None: |
|
raise NotImplementedError |
|
|
|
def get_aliases(self) -> Sequence[str]: |
|
raise NotImplementedError |
|
|
|
def get_mutations(self) -> List[str]: |
|
raise NotImplementedError |
|
|
|
def can_inplace(self, read_dep: dependencies.Dep) -> bool: |
|
raise NotImplementedError |
|
|
|
def allocate(self) -> None: |
|
raise NotImplementedError |
|
|
|
def can_free(self) -> bool: |
|
raise NotImplementedError |
|
|
|
def debug_str(self) -> str: |
|
"""Longer form printout for trace logs""" |
|
name = self.get_name() |
|
node_typestr = ",".join(type(n).__name__ for n in self.snodes) |
|
lines = [ |
|
f"{name}: {type(self).__name__}({node_typestr})", |
|
f"{name}.writes = {pformat(self.read_writes.writes)}", |
|
f"{name}.unmet_dependencies = {pformat(self.unmet_dependencies)}", |
|
f"{name}.met_dependencies = {pformat(self.read_writes.reads - self.unmet_dependencies)}", |
|
f"{name}.users = {self.users}", |
|
] |
|
try: |
|
lines += [ |
|
self.debug_str_extra(), |
|
] |
|
except Exception: |
|
log.warning("Ignoring error in debug_str()", exc_info=True) |
|
|
|
return "\n".join(lines).rstrip() |
|
|
|
|
|
class ForeachKernelSchedulerNode(FusedSchedulerNode): |
|
"""Scheduler node which consists of a list of scheduler nodes that each operate on a |
|
distinct tensor in a list of tensors.""" |
|
|
|
def get_consumer_subnode_for( |
|
self, producer: BaseSchedulerNode |
|
) -> Optional[BaseSchedulerNode]: |
|
if producer.get_name() in self.read_to_node: |
|
return self.read_to_node[producer.get_name()] |
|
|
|
return None |
|
|
|
def get_producer_subnode_for( |
|
self, consumer: BaseSchedulerNode |
|
) -> Optional[BaseSchedulerNode]: |
|
for rd in consumer.read_writes.reads: |
|
if rd.name in self.name_to_node: |
|
return self.name_to_node[rd.name] |
|
|
|
return None |
|
|
|
@classmethod |
|
def can_fuse(cls, producer: BaseSchedulerNode, consumer: BaseSchedulerNode) -> bool: |
|
why = WhyNoFuse(producer, consumer) |
|
if producer.is_foreach() and consumer.is_foreach(): |
|
producer = typing.cast(ForeachKernelSchedulerNode, producer) |
|
consumer = typing.cast(ForeachKernelSchedulerNode, consumer) |
|
foreach_match = len(producer.snodes) == len(consumer.snodes) |
|
if not foreach_match: |
|
why("foreach do not have same length") |
|
return foreach_match and all( |
|
producer.scheduler.can_fuse(l, r) |
|
for l, r in zip(producer.snodes, consumer.snodes) |
|
) |
|
elif consumer.is_foreach(): |
|
if producer.is_reduction(): |
|
why( |
|
"candidate producer is a reduction, foreach ops cannot be fused with reductions currently" |
|
) |
|
return False |
|
|
|
consumer = typing.cast(ForeachKernelSchedulerNode, consumer) |
|
consumer_subnode = consumer.get_consumer_subnode_for(producer) |
|
if consumer_subnode is not None: |
|
return consumer.scheduler.can_fuse(producer, consumer_subnode) |
|
|
|
why("candidate producer is not dep of any foreach consumer") |
|
return False |
|
|
|
elif producer.is_foreach(): |
|
if consumer.is_reduction(): |
|
why( |
|
"candidate consumer is a reduction, foreach ops cannot be fused with reductions currently" |
|
) |
|
return False |
|
|
|
producer = typing.cast(ForeachKernelSchedulerNode, producer) |
|
producer_subnode = producer.get_producer_subnode_for(consumer) |
|
if producer_subnode is not None: |
|
return producer.scheduler.can_fuse(producer_subnode, consumer) |
|
|
|
why("candidate consumer has no dep in any foreach producer") |
|
return False |
|
|
|
raise AssertionError( |
|
"At least one node passed to ForeachKernelSchedulerNode.can_fuse should be a foreach node" |
|
) |
|
|
|
@classmethod |
|
def fuse( |
|
cls, producer: BaseSchedulerNode, consumer: BaseSchedulerNode |
|
) -> "ForeachKernelSchedulerNode": |
|
assert producer.is_foreach() or consumer.is_foreach() |
|
prev_node_1 = None |
|
prev_node_2 = None |
|
fused_nodes: List[BaseSchedulerNode] |
|
if producer.is_foreach() and consumer.is_foreach(): |
|
producer = typing.cast(ForeachKernelSchedulerNode, producer) |
|
consumer = typing.cast(ForeachKernelSchedulerNode, consumer) |
|
fused_nodes = [ |
|
FusedSchedulerNode.fuse(l, r) |
|
for l, r in zip(producer.snodes, consumer.snodes) |
|
] |
|
elif producer.is_foreach(): |
|
producer = typing.cast(ForeachKernelSchedulerNode, producer) |
|
producer_subnode = producer.get_producer_subnode_for(consumer) |
|
fused_nodes = [] |
|
prev_node_1 = producer |
|
prev_node_2 = None |
|
for node in producer.snodes: |
|
if node is producer_subnode: |
|
new_node = FusedSchedulerNode.fuse(node, consumer) |
|
prev_node_2 = new_node |
|
fused_nodes.append(new_node) |
|
else: |
|
fused_nodes.append(node) |
|
|
|
elif consumer.is_foreach(): |
|
consumer = typing.cast(ForeachKernelSchedulerNode, consumer) |
|
consumer_subnode = consumer.get_consumer_subnode_for(producer) |
|
fused_nodes = [] |
|
prev_node_1 = consumer |
|
prev_node_2 = None |
|
|
|
for node in consumer.snodes: |
|
if node is consumer_subnode: |
|
new_node = FusedSchedulerNode.fuse(producer, node) |
|
prev_node_2 = new_node |
|
fused_nodes.append(new_node) |
|
else: |
|
fused_nodes.append(node) |
|
|
|
return cls(producer.scheduler, fused_nodes, prev_node_1, prev_node_2) |
|
|
|
def __init__( |
|
self, |
|
scheduler: "Scheduler", |
|
nodes: Sequence[BaseSchedulerNode], |
|
prev_node_1: Optional[BaseSchedulerNode] = None, |
|
prev_node_2: Optional[BaseSchedulerNode] = None, |
|
) -> None: |
|
self.read_to_node = {} |
|
self.name_to_node = {} |
|
|
|
if prev_node_1 is None or prev_node_2 is None: |
|
super().__init__(scheduler, nodes) |
|
|
|
for node in nodes: |
|
for read in node.read_writes.reads: |
|
self.read_to_node[read.name] = node |
|
|
|
for name in node.get_names(): |
|
self.name_to_node[name] = node |
|
else: |
|
self.scheduler = scheduler |
|
self.snodes = nodes |
|
self.node = None |
|
self.users: List[NodeUser] = [] |
|
|
|
self.set_read_writes( |
|
dependencies.ReadWrites.merge_list( |
|
[prev_node_1.read_writes, prev_node_2.read_writes] |
|
) |
|
) |
|
|
|
self.unmet_dependencies = { |
|
dep |
|
for dep in set.union( |
|
prev_node_1.unmet_dependencies, prev_node_2.unmet_dependencies |
|
) |
|
if dep.name not in self.get_names() |
|
} - self.read_writes.writes |
|
|
|
self.min_order = min([prev_node_1.min_order, prev_node_2.min_order]) |
|
self.max_order = max([prev_node_1.max_order, prev_node_2.max_order]) |
|
|
|
if prev_node_1.is_foreach(): |
|
assert isinstance(prev_node_1, ForeachKernelSchedulerNode) |
|
foreach_node, other_node = prev_node_1, prev_node_2 |
|
else: |
|
assert isinstance(prev_node_2, ForeachKernelSchedulerNode) |
|
foreach_node, other_node = prev_node_2, prev_node_1 |
|
|
|
self.ancestors = foreach_node.ancestors |
|
self.ancestors.update(other_node.ancestors) |
|
|
|
self.name_to_node = foreach_node.name_to_node |
|
for name in other_node.get_names(): |
|
self.name_to_node[name] = other_node |
|
|
|
self.group = (nodes[0].get_device(), ((sympy.Expr("foreach"),),)) |
|
|
|
self.origins: Set[torch.fx.Node] = set() |
|
|
|
def mark_run(self) -> None: |
|
raise NotImplementedError |
|
|
|
def codegen(self) -> None: |
|
assert isinstance(self.node, ir.ComputedBuffer), f"{type(self.node)=}" |
|
self.node.get_store_function()(self.node.make_loader()()) |
|
|
|
def can_free(self) -> bool: |
|
raise NotImplementedError |
|
|
|
def is_foreach(self) -> bool: |
|
return True |
|
|
|
def get_subkernel_nodes(self) -> List[BaseSchedulerNode]: |
|
"""Returns a list of nodes which comprise the foreach kernel, operating on corresponding elements of our input lists. |
|
These nodes may be vertically fused.""" |
|
return list(self.snodes) |
|
|
|
def get_nodes(self) -> Sequence[BaseSchedulerNode]: |
|
"""Returns all nodes contained in this kernel, unpacking fused nodes |
|
into their constituent scheduler nodes.""" |
|
return list(itertools.chain.from_iterable(x.get_nodes() for x in self.snodes)) |
|
|
|
def get_first_name(self) -> str: |
|
return self.snodes[0].get_first_name() |
|
|
|
def prune_redundant_deps( |
|
self, name_to_fused_node: Dict[str, BaseSchedulerNode] |
|
) -> None: |
|
_prune_redundant_deps(self, name_to_fused_node) |
|
|
|
for node in self.snodes: |
|
node.prune_redundant_deps(name_to_fused_node) |
|
|
|
|
|
def pick_loop_order( |
|
stride_lengths: List[List[int]], |
|
sizes: List[sympy.Expr], |
|
priority_idx: Tuple[int, ...] = (), |
|
) -> List[int]: |
|
""" |
|
A heuristic to decide loop iteration orders. This has not been well |
|
tuned and may be something we should autotune. |
|
""" |
|
|
|
@functools.cmp_to_key |
|
def index_cmp(a: int, b: int) -> int: |
|
if sizes[a] == 1 or sizes[b] == 1: |
|
|
|
return cmp(sizes[a] == 1, sizes[b] == 1) |
|
|
|
|
|
|
|
stride_len_a = [abs(sl[a]) for sl in stride_lengths] |
|
stride_len_b = [abs(sl[b]) for sl in stride_lengths] |
|
|
|
|
|
|
|
a_first = sum( |
|
sl_b == 0 or sl_a < sl_b for sl_a, sl_b in zip(stride_len_a, stride_len_b) |
|
) |
|
b_first = sum( |
|
sl_a == 0 or sl_b < sl_a for sl_a, sl_b in zip(stride_len_a, stride_len_b) |
|
) |
|
if a_first > b_first: |
|
return -1 |
|
if b_first > a_first: |
|
return 1 |
|
|
|
|
|
return cmp(b, a) |
|
|
|
order = list(reversed(range(len(stride_lengths[0])))) |
|
if len(priority_idx) > 0: |
|
|
|
stride_lengths = [stride_lengths[pi] for pi in priority_idx] |
|
if config.pick_loop_orders: |
|
order.sort(key=index_cmp) |
|
return order |
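# Illustrative sketch (hypothetical shapes, not from the original source): for
# a single dep over a contiguous [A, B] buffer, stride_lengths == [[B, 1]] and
# pick_loop_order([[B, 1]], [A, B]) == [1, 0]; dimensions with smaller strides
# sort first, and size-1 dimensions sort last.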
|
|
|
|
|
@dataclasses.dataclass |
|
class NodeUser: |
|
node: Union[BaseSchedulerNode, OutputNode] |
|
can_inplace: bool = False |
|
|
|
|
|
|
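    # A weak user must be scheduled after this node but does not actually read
    # its result; it exists only to order mutations.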
|
is_weak: bool = False |
|
|
|
def __hash__(self) -> int: |
|
return hash((self.node.get_name(), self.can_inplace, self.is_weak)) |
|
|
|
def __eq__(self, other: object) -> bool: |
|
return ( |
|
isinstance(other, NodeUser) |
|
and self.get_name() == other.get_name() |
|
and self.can_inplace == other.can_inplace |
|
and self.is_weak == other.is_weak |
|
) |
|
|
|
def get_name(self) -> str: |
|
return self.node.get_name() |
|
|
|
def merge(self, other: "NodeUser") -> "NodeUser": |
|
assert self.node is other.node |
|
return NodeUser( |
|
self.node, |
|
self.can_inplace and other.can_inplace, |
|
self.is_weak and other.is_weak, |
|
) |
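    # Illustrative sketch (hypothetical node `n`): merging duplicate users
    # keeps the most conservative flags, e.g.
    #   NodeUser(n, can_inplace=True, is_weak=True).merge(
    #       NodeUser(n, can_inplace=False, is_weak=False)
    #   ) == NodeUser(n, can_inplace=False, is_weak=False)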
|
|
|
|
|
_post_grad_graph_counter = itertools.count() |
|
|
|
|
|
class Scheduler: |
|
__dep_size_hint_cache: Dict[Dep, int] |
|
|
|
@dynamo_timed |
|
def __init__(self, nodes: List[ir.Buffer]) -> None: |
|
super().__init__() |
|
self.__dep_size_hint_cache = {} |
|
V.graph.scheduler = self |
|
self.backends: Dict[torch.device, BaseScheduling] = {} |
|
self.post_grad_graph_id = next(_post_grad_graph_counter) |
|
|
|
self.available_buffer_names = { |
|
*V.graph.graph_inputs.keys(), |
|
*V.graph.constants.keys(), |
|
*V.graph.torchbind_constants.keys(), |
|
} |
|
|
|
self.nodes = [self.create_scheduler_node(n) for n in nodes] |
|
|
|
|
|
self.available_buffer_names.update(V.graph.constants.keys()) |
|
for node in self.nodes: |
|
node.prune_deps() |
|
|
|
self.name_to_node: Dict[str, BaseSchedulerNode] = { |
|
n.get_name(): n for n in self.nodes |
|
} |
|
self.name_to_fused_node: Dict[ |
|
str, BaseSchedulerNode |
|
] = dict() |
|
|
|
|
|
|
|
|
|
|
|
|
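        # Maps a renamed buffer back to the name codegen should emit.
        # Example (hypothetical buffers): if buf0 is mutated inside buf1's
        # kernel, mutation_real_name == {"buf1": "buf0"} and later uses of
        # buf1 read buf0 in the generated code.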
|
self.mutation_real_name: Dict[str, str] = {} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
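        # Mutation is modeled by renaming successive versions of a mutated
        # buffer in the dependency graph to avoid cycles. Example (hypothetical
        # buffers): after buf1's kernel mutates buf0,
        # mutation_renames == {"buf0": "buf1"}.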
|
self.mutation_renames: Dict[str, str] = {} |
|
|
|
self.compute_dependencies() |
|
self.topological_sort_schedule() |
|
self.dead_node_elimination() |
|
if config.reorder_for_compute_comm_overlap: |
|
comms.decide_global_ordering_of_comms(self.nodes) |
|
self.compute_ancestors() |
|
|
|
metrics.ir_nodes_pre_fusion += len(self.nodes) |
|
V.debug.ir_pre_fusion(self.nodes) |
|
self.num_orig_nodes = len(self.nodes) |
|
self.name_to_fused_node = {n.get_name(): n for n in self.nodes} |
|
self.create_foreach_nodes() |
|
self.topological_sort_schedule() |
|
self.logged_slow_fusion: Set[Tuple[str, str]] = set() |
|
self.fuse_nodes() |
|
self.finalize_multi_template_buffers() |
|
if config.reorder_for_compute_comm_overlap: |
|
|
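            # recompute node_users/inverse_users so they reflect post-fusion
            # nodes before reordering for overlap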
|
self.compute_node_users() |
|
self.nodes = comms.reorder_compute_and_comm_for_overlap(self.nodes) |
|
self.compute_last_usage() |
|
V.debug.ir_post_fusion(self.nodes) |
|
V.debug.graph_diagram(self.nodes) |
|
self.debug_draw_graph() |
|
|
|
|
|
self.current_device: Optional[torch.device] = None |
|
self.buffer_names_to_free: Set[str] = set() |
|
|
|
|
|
|
|
self.origin_to_index: Dict[torch.fx.Node, int] = {} |
|
|
|
get_metric_table("graph_stats").add_row( |
|
lambda: { |
|
"graph_id": self.post_grad_graph_id, |
|
"num_nodes_before_fusion": self.num_orig_nodes, |
|
"num_nodes_after_fusion": len(self.nodes), |
|
} |
|
) |
|
|
|
def get_current_device_or_throw(self) -> torch.device: |
|
if device := self.current_device: |
|
return device |
|
else: |
|
raise RuntimeError("No current device") |
|
|
|
def debug_draw_graph(self) -> None: |
|
"""Generate an image of the graph for debugging""" |
|
if os.environ.get("INDUCTOR_WRITE_SCHEDULER_GRAPH", None) == "1": |
|
from .debug import draw_buffers |
|
|
|
draw_buffers(self.nodes, print_graph=True) |
|
|
|
def debug_print_nodes(self, label: str) -> None: |
|
if log.isEnabledFor(logging.INFO): |
|
log.info("%s:", label) |
|
for node in self.nodes: |
|
node.log_details() |
|
|
|
def create_scheduler_node(self, node: ir.Buffer) -> BaseSchedulerNode: |
|
assert ( |
|
node.origins is not None |
|
), "All nodes passed to scheduling must have an origin" |
|
if node.is_no_op(): |
|
return NopKernelSchedulerNode(self, node) |
|
elif isinstance(node, (ir.ComputedBuffer, ir.TemplateBuffer)): |
|
return SchedulerNode(self, node) |
|
elif isinstance(node, ir.ExternKernel): |
|
return ExternKernelSchedulerNode(self, node) |
|
else: |
|
raise NotImplementedError(node) |
|
|
|
def create_foreach_nodes(self) -> None: |
|
removed_node_names = set() |
|
fe_nodes = [] |
|
kept_node_names = self.name_to_fused_node.keys() |
|
|
|
for names in V.graph.lists.values(): |
|
names = [ |
|
name |
|
for name in names |
|
if name in kept_node_names |
|
and not isinstance(self.name_to_node[name], NopKernelSchedulerNode) |
|
] |
|
if not names: |
|
|
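                # every node in this list was a nop or already removed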
|
continue |
|
|
|
removed_node_names.update(names) |
|
snodes = [self.name_to_node[name] for name in names] |
|
|
|
fe_node = ForeachKernelSchedulerNode(self, snodes) |
|
|
|
fe_nodes.append(fe_node) |
|
|
|
for name in names: |
|
self.name_to_fused_node[name] = fe_node |
|
|
|
self.nodes = [ |
|
node for node in self.nodes if node.get_name() not in removed_node_names |
|
] + list(fe_nodes) |
|
|
|
def compute_dependencies(self) -> None: |
|
""" |
|
Create dependency edges between nodes, handling aliasing and |
|
mutation properly. |
|
""" |
|
|
|
T = TypeVar("T") |
|
|
|
class DedupList(Generic[T]): |
|
""" |
|
This data structure behaves like a list except it makes sure the |
|
elements remain unique. |
|
            Normally one could use a set/dict for this purpose, but the list
            in question has elements appended to it while it is being iterated
            over, so we need to keep list semantics.
|
""" |
|
|
|
def __init__( |
|
self, |
|
items: Optional[List[T]] = None, |
|
membership: Optional[Set[T]] = None, |
|
) -> None: |
|
self.items = items or list() |
|
self.membership = membership or set() |
|
|
|
def append(self, node_user: T) -> None: |
|
if node_user in self.membership: |
|
return |
|
self.items.append(node_user) |
|
self.membership.add(node_user) |
|
|
|
def __add__(self, other: "DedupList[T]") -> "DedupList[T]": |
|
new_membership = set.union(self.membership, other.membership) |
|
new_items = self.items + [ |
|
x for x in other.items if x not in self.membership |
|
] |
|
return DedupList(new_items, new_membership) |
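        # Sketch of the dedup behavior (illustrative, not from the source):
        #   d = DedupList(); d.append(1); d.append(1); d.items == [1]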
|
|
|
name_to_users: DefaultDict[str, DedupList[NodeUser]] = collections.defaultdict( |
|
DedupList |
|
) |
|
|
|
|
|
|
|
|
|
for node1 in self.nodes: |
|
node1_name = node1.get_name() |
|
for node2_name in node1.get_aliases(): |
|
if node1_name in name_to_users and node2_name in name_to_users: |
|
|
|
list1 = name_to_users[node1_name] |
|
list2 = name_to_users[node2_name] |
|
combined = list1 + list2 |
|
for key in name_to_users.keys(): |
|
if name_to_users[key] is list1 or name_to_users[key] is list2: |
|
name_to_users[key] = combined |
|
elif node1_name in name_to_users: |
|
name_to_users[node2_name] = name_to_users[node1_name] |
|
else: |
|
name_to_users[node1_name] = name_to_users[node2_name] |
|
|
|
def rename(n: str) -> str: |
|
if n in self.mutation_renames: |
|
return rename(self.mutation_renames[n]) |
|
return n |
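        # Follows rename chains, e.g. with mutation_renames ==
        # {"buf0": "buf1", "buf1": "buf2"} (hypothetical), rename("buf0")
        # returns "buf2".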
|
|
|
def dep_closure(node_name: str) -> Set[str]: |
|
reachable_names = {node_name} |
|
node = self.name_to_node[node_name] |
|
write_dep = next(iter(node.read_writes.writes)) |
|
for read_dep in node.read_writes.reads: |
|
if ( |
|
read_dep.name in self.name_to_node |
|
and isinstance(read_dep, dependencies.MemoryDep) |
|
and isinstance(write_dep, dependencies.MemoryDep) |
|
and read_dep.index == write_dep.index |
|
and read_dep.size == write_dep.size |
|
): |
|
reachable_names.update(dep_closure(read_dep.name)) |
|
return reachable_names |
|
|
|
def add_user( |
|
used_by_name: str, |
|
user_node: Union[BaseSchedulerNode, OutputNode], |
|
can_inplace: bool = False, |
|
is_weak: bool = False, |
|
) -> None: |
|
name_to_users[rename(used_by_name)].append( |
|
NodeUser(user_node, can_inplace, is_weak) |
|
) |
|
|
|
unbacked_symbol_to_origin_node = {} |
|
|
|
|
|
|
|
|
|
for name, val in V.graph.graph_inputs.items(): |
|
if isinstance(val, sympy.Expr): |
|
for fs in val.free_symbols: |
|
unbacked_symbol_to_origin_node[fs] = None |
|
|
|
for node in self.nodes: |
|
log.debug("scheduling %s", node.node) |
|
|
|
|
|
|
|
assert node.node is not None |
|
unbacked_symbol_defs = sorted( |
|
node.node.get_unbacked_symbol_defs(), key=lambda x: x.name |
|
) |
|
for s in unbacked_symbol_defs: |
|
assert isinstance(s, sympy.Symbol) |
|
|
|
|
|
|
|
if s not in unbacked_symbol_to_origin_node: |
|
unbacked_symbol_to_origin_node[s] = node.get_name() |
|
|
|
unbacked_symbol_uses = sorted( |
|
node.node.get_unbacked_symbol_uses(), key=lambda x: x.name |
|
) |
|
|
|
for s in unbacked_symbol_uses: |
|
assert ( |
|
s in unbacked_symbol_to_origin_node |
|
), f"{s} not in {unbacked_symbol_to_origin_node}" |
|
if (r := unbacked_symbol_to_origin_node[s]) is not None: |
|
node.add_fake_dep(StarDep(r)) |
|
|
|
if ( |
|
len(node.read_writes.writes) == 1 |
|
and (dep := next(iter(node.read_writes.writes))) |
|
and isinstance(dep, MemoryDep) |
|
): |
|
node_mode = dep.mode |
|
else: |
|
node_mode = None |
|
|
|
|
|
assert len(node.get_mutations()) <= 1 |
|
for alt_name in node.get_mutations(): |
|
alt_name = rename(alt_name) |
|
|
|
add_user(alt_name, node) |
|
node.add_fake_dep(StarDep(alt_name, mode=node_mode)) |
|
for other_node in name_to_users[alt_name].items: |
|
|
|
other_name = rename(other_node.get_name()) |
|
known_dep_node_names = dep_closure(node.get_name()) |
|
if other_name not in known_dep_node_names: |
|
|
|
|
|
node.add_fake_dep(WeakDep(other_name)) |
|
add_user(other_name, node, is_weak=True) |
|
|
|
|
|
for read in node.read_writes.reads: |
|
is_weak = isinstance(read, WeakDep) |
|
add_user(read.name, node, node.can_inplace(read), is_weak) |
|
|
|
node.update_mutated_names(self.mutation_renames) |
|
|
|
|
|
for alt_name in node.get_mutations(): |
|
self.mutation_renames[rename(alt_name)] = node.get_name() |
|
self.mutation_renames[alt_name] = node.get_name() |
|
self.mutation_real_name[node.get_name()] = self.mutation_real_name.get( |
|
alt_name, alt_name |
|
) |
|
|
|
|
|
for node_name in V.graph.get_output_names(): |
|
log.debug("scheduling output %s", node_name) |
|
add_user(node_name, OutputNode(StarDep(node_name))) |
|
|
|
|
|
for node in V.graph.graph_outputs: |
|
for s in node.get_unbacked_symbol_uses(): |
|
assert ( |
|
s in unbacked_symbol_to_origin_node |
|
), f"{s} not in {unbacked_symbol_to_origin_node.keys()}" |
|
if (node_name := unbacked_symbol_to_origin_node[s]) is not None: |
|
log.debug( |
|
"scheduling output %s for unbacked symint %s", node_name, s |
|
) |
|
add_user(node_name, OutputNode(StarDep(node_name))) |
|
|
|
|
|
for name in self.mutation_renames: |
|
if name in V.graph.graph_inputs: |
|
add_user(name, OutputNode(StarDep(name))) |
|
V.graph.mutated_inputs.add(name) |
|
elif name in V.graph.constants: |
|
|
|
add_user(name, OutputNode(StarDep(name))) |
|
|
|
inp_names = { |
|
name: index for index, name in enumerate(V.graph.graph_inputs.keys()) |
|
} |
|
V.graph.mutated_input_idxs = [ |
|
inp_names[name] for name in V.graph.mutated_inputs |
|
] |
|
|
|
|
|
for node in self.nodes: |
|
node.set_users(name_to_users[node.get_name()].items) |
|
|
|
|
|
for node in self.nodes: |
|
for user in node.users: |
|
user.node.inverse_users.append(node) |
|
|
|
def compute_node_users(self) -> None: |
|
|
|
buf_to_snode: Dict[str, BaseSchedulerNode] = {} |
|
for node in self.nodes: |
|
if isinstance(node, FusedSchedulerNode): |
|
for x in node.snodes: |
|
buf_to_snode[x.get_name()] = node |
|
buf_to_snode[node.get_name()] = node |
|
|
|
for node in self.nodes: |
|
node.node_users = [] |
|
node.inverse_users = [] |
|
|
|
|
|
for node in self.nodes: |
|
inverse_users: List[BaseSchedulerNode] = [] |
|
for dep in node.unmet_dependencies: |
|
assert dep.name in buf_to_snode |
|
dep_node = buf_to_snode[dep.name] |
|
inverse_users.append(dep_node) |
|
node.inverse_users = inverse_users |
|
|
|
|
|
|
|
|
|
|
|
node_to_users: Dict[BaseSchedulerNode, List[BaseSchedulerNode]] = {} |
|
for node in self.nodes: |
|
for inverse_user in node.inverse_users: |
|
node_to_users.setdefault(inverse_user, []).append(node) |
|
for node, users in node_to_users.items(): |
|
node.node_users = users |
|
|
|
def dead_node_elimination(self) -> None: |
|
""" |
|
Remove any nodes without users |
|
""" |
|
again = True |
|
while again: |
|
updated_nodes = [] |
|
for node in self.nodes: |
|
|
|
def can_eliminate_user(user: NodeUser) -> bool: |
|
return user.is_weak or user.get_name() in V.graph.removed_buffers |
|
|
|
can_eliminate = not node.has_side_effects() and all( |
|
can_eliminate_user(u) for u in node.users |
|
) |
|
|
|
if not can_eliminate: |
|
updated_nodes.append(node) |
|
else: |
|
|
|
log.debug("removed dead node: %s", node.get_name()) |
|
V.graph.removed_buffers.add(node.get_name()) |
|
|
|
again = len(self.nodes) > len(updated_nodes) |
|
self.nodes = updated_nodes |
|
|
|
|
|
for node in self.nodes: |
|
node.prune_weak_deps() |
|
|
|
def topological_sort_schedule(self) -> None: |
|
""" |
|
Ensure self.nodes is in topologically sorted order |
|
""" |
|
seen: Set[BaseSchedulerNode] = set() |
|
name_to_node: Dict[str, BaseSchedulerNode] = dict() |
|
result: List[BaseSchedulerNode] = [] |
|
|
|
def visit(n: BaseSchedulerNode) -> None: |
|
if n not in seen: |
|
seen.add(n) |
|
for dep in sorted(n.unmet_dependencies, key=lambda d: d.name): |
|
visit(name_to_node[dep.name]) |
|
result.append(n) |
|
|
|
for node in self.nodes: |
|
for name in node.get_names(): |
|
name_to_node[name] = node |
|
for node in self.nodes: |
|
visit(node) |
|
self.nodes = result |
|
|
|
def compute_ancestors(self) -> None: |
|
""" |
|
Populate each node.ancestors |
|
""" |
|
|
|
name_to_ancestors: Dict[str, Set[str]] = {} |
|
for node in self.nodes: |
|
ancestors = set() |
|
for dep in node.unmet_dependencies: |
|
ancestors.add(dep.name) |
|
ancestors |= name_to_ancestors[dep.name] |
|
name_to_ancestors[node.get_name()] = ancestors |
|
node.ancestors = ancestors |
|
|
|
for order, node in enumerate(self.nodes): |
|
node.min_order = order |
|
node.max_order = order |
|
|
|
def fuse_nodes(self) -> None: |
|
""" |
|
Mutates self.nodes to combine nodes into FusedSchedulerNodes. |
|
""" |
|
for i in range(10): |
|
old_len = len(self.nodes) |
|
fusion_log.debug( |
|
"===== attempting fusion (%d/10): %d nodes =====", |
|
i + 1, |
|
old_len, |
|
) |
|
self.fuse_nodes_once() |
|
new_len = len(self.nodes) |
|
fusion_log.debug( |
|
"completed fusion round (%d/10): fused %d nodes into %d nodes\n", |
|
i + 1, |
|
old_len, |
|
new_len, |
|
) |
|
if new_len == old_len or new_len == 1: |
|
fusion_log.debug("===== fusion complete (%d iterations) =====", i + 1) |
|
break |
|
|
|
def benchmark_fused_nodes( |
|
self, nodes: Sequence[BaseSchedulerNode] |
|
) -> Tuple[float, str]: |
|
""" |
|
Benchmark fused list of nodes and return the execution time |
|
in milliseconds on randomly generated inputs. |
|
""" |
|
assert len(nodes) > 0 |
|
device = nodes[0].get_device() |
|
self.current_device = device |
|
backend = self.get_backend(device) |
|
return backend.benchmark_fused_nodes(nodes) |
|
|
|
def finalize_multi_template_buffers(self) -> None: |
|
def replace_buffer( |
|
orig_node: ir.MultiTemplateBuffer, new_node: ir.Buffer |
|
) -> None: |
|
replaced_name = new_node.name |
|
orig_name = orig_node.get_name() |
|
assert isinstance(orig_name, str) and isinstance(replaced_name, str) |
|
|
|
del V.graph.name_to_buffer[replaced_name] |
|
new_node.name = orig_name |
|
|
|
orig = V.graph.buffers.index(orig_node) |
|
V.graph.buffers.remove(new_node) |
|
V.graph.buffers[orig] = new_node |
|
V.graph.name_to_buffer[orig_name] = new_node |
|
|
|
for i, node in enumerate(self.nodes): |
|
if isinstance(node, SchedulerNode) and isinstance( |
|
node.node, ir.MultiTemplateBuffer |
|
): |
|
multi_node = node.node |
|
min_node_unfused, _ = multi_node.get_min_choice() |
|
|
|
if isinstance( |
|
min_node_unfused, |
|
torch._inductor.ir.TritonTemplateCallerBase, |
|
): |
|
node.node.finalize_as_triton_caller(min_node_unfused) |
|
continue |
|
|
|
out_tensorbox = min_node_unfused.output_node() |
|
out_storage = out_tensorbox.data |
|
assert isinstance(out_storage, ir.StorageBox) |
|
out_buffer = out_storage.data |
|
assert isinstance(out_buffer, ir.Buffer) |
|
|
|
out_buffer.layout = multi_node.layout |
|
replace_buffer(multi_node, out_buffer) |
|
new_scheduler_node = self.create_scheduler_node(out_buffer) |
|
|
|
self.nodes[i] = new_scheduler_node |
|
self.name_to_node[node.get_name()] = new_scheduler_node |
|
self.name_to_fused_node[node.get_name()] = new_scheduler_node |
|
|
|
new_scheduler_node.users = node.users |
|
new_scheduler_node.min_order = node.min_order |
|
new_scheduler_node.max_order = node.max_order |
|
new_scheduler_node.last_usage = node.last_usage |
|
for user in new_scheduler_node.users: |
|
user.node.inverse_users.remove(node) |
|
user.node.inverse_users.append(new_scheduler_node) |
|
|
|
def speedup_by_fusion( |
|
self, node1: BaseSchedulerNode, node2: BaseSchedulerNode |
|
) -> bool: |
|
""" |
|
If config.benchmark_fusion is False, always return True. |
|
        Otherwise, return True only if benchmarking shows the fusion brings a speedup.
|
""" |
|
|
|
is_multi_template = node1.is_template() and isinstance( |
|
node1.get_template_node(), ir.MultiTemplateBuffer |
|
) |
|
if not config.benchmark_fusion and not is_multi_template: |
|
return True |
|
|
|
if ( |
|
node1.is_template() |
|
and not isinstance(node1.get_template_node(), ir.TritonTemplateBuffer) |
|
or node1.is_foreach() |
|
or node2.is_foreach() |
|
): |
|
|
|
return True |
|
|
|
node_list_1 = node1.get_nodes() |
|
device = node_list_1[0].get_device() |
|
|
|
|
|
if device.type == "cpu": |
|
return True |
|
|
|
node_list_2 = node2.get_nodes() |
|
node_list_fused = list(itertools.chain(node_list_1, node_list_2)) |
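# Kernels that use atomic_add cannot be benchmarked accurately on the
# randomly generated inputs, so skip benchmarking and allow the fusion.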
|
|
|
|
|
|
|
|
|
if any( |
|
n.node is not None
|
and hasattr(n.node, "data")
|
and hasattr(n.node.data, "scatter_mode") |
|
and n.node.data.scatter_mode == "atomic_add" |
|
for n in node_list_fused |
|
): |
|
return True |
|
|
|
from triton.compiler.errors import CompilationError |
|
|
|
why = WhyNoFuse(node1, node2) |
|
|
|
def log_fusion(ms_fused: float, ms1: float, ms2: float) -> None: |
|
if fusion_log.isEnabledFor(logging.DEBUG): |
|
if ms_fused < ms1 + ms2: |
|
fusion_log.debug( |
|
"can fuse (benchmark): fusing %s with %s cause %sx speedup", |
|
node1.get_names(), |
|
node2.get_names(), |
|
green_text(f"{(ms1 + ms2) / ms_fused:.3f}"), |
|
) |
|
else: |
|
fusion_log.debug( |
|
"cannot fuse (benchmark): fusing %s with %s cause %sx slowdown", |
|
node1.get_names(), |
|
node2.get_names(), |
|
red_text(f"{ms_fused / (ms1 + ms2):.3f}"), |
|
) |
|
|
|
if isinstance(node1, SchedulerNode) and isinstance( |
|
node1.node, ir.MultiTemplateBuffer |
|
): |
|
multi_node = node1.node |
|
choice_timings = multi_node.choice_timings |
|
|
|
_, ms1 = multi_node.get_min_choice() |
|
ms2, path2 = self.benchmark_fused_nodes(node_list_2) |
|
|
|
min_ms_fused = float("inf") |
|
ms_fused_choice = None |
|
|
|
triton_choices = 0 |
|
|
|
for choice, unfused_time in choice_timings.items(): |
|
if not isinstance(choice, torch._inductor.ir.TritonTemplateCallerBase): |
|
continue |
|
|
|
if unfused_time >= ms1 + ms2: |
|
continue |
|
|
|
triton_choices += 1 |
|
if triton_choices > config.max_epilogue_benchmarked_choices: |
|
break |
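# Re-benchmark the fused kernel with this template choice temporarily
# swapped in.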
|
|
|
|
|
|
|
with node1.node.swap_as_triton_caller(choice): |
|
ms_fused, _ = self.benchmark_fused_nodes(node_list_fused) |
|
|
|
if ms_fused < min_ms_fused: |
|
min_ms_fused = ms_fused |
|
ms_fused_choice = choice |
|
|
|
log_fusion(min_ms_fused, ms1, ms2) |
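# If some choice makes the fused kernel faster than running the two
# kernels separately, finalize the multi-template buffer to that choice.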
|
|
|
|
|
|
|
if min_ms_fused < (ms1 + ms2) and ms_fused_choice is not None: |
|
node1.node.finalize_as_triton_caller(ms_fused_choice) |
|
return True |
|
else: |
|
return False |
|
else: |
|
try: |
|
ms1, path1 = self.benchmark_fused_nodes(node_list_1) |
|
if math.isinf(ms1): |
|
why("register spilling of the first kernel") |
|
return False |
|
ms2, path2 = self.benchmark_fused_nodes(node_list_2) |
|
if math.isinf(ms2): |
|
why("register spilling of the second kernel") |
|
return False |
|
ms_fused, path_fused = self.benchmark_fused_nodes(node_list_fused) |
|
if math.isinf(ms_fused): |
|
why("register spilling of the fused kernel") |
|
return False |
|
except CompilationError as e: |
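# Some fusions trigger a Triton compiler error about loop-carried
# variables; optimistically allow those fusions instead of failing.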
|
|
|
if "Loop-carried variable" in str(e): |
|
return True |
|
else: |
|
raise |
|
|
|
log_fusion(ms_fused, ms1, ms2) |
|
if ( |
|
is_metric_table_enabled("slow_fusion") |
|
and ms_fused >= ms1 + ms2 |
|
and (path1, path2) not in self.logged_slow_fusion |
|
): |
|
self.logged_slow_fusion.add((path1, path2)) |
|
get_metric_table("slow_fusion").add_row( |
|
lambda: { |
|
"kernel1_path": path1, |
|
"kernel1_latency": ms1, |
|
"kernel2_path": path2, |
|
"kernel2_latency": ms2, |
|
"fused_kernel_path": path_fused, |
|
"fused_kernel_latency": ms_fused, |
|
"slow_down_ratio": ms_fused / (ms1 + ms2), |
|
} |
|
) |
|
return ms_fused < ms1 + ms2 |
|
|
|
def fuse_nodes_once(self) -> None: |
|
""" |
|
Mutates self.nodes to combine nodes into FusedSchedulerNodes. |
|
|
|
This relies on two key functions to control the logic: |
|
- self.can_fuse(): checks if a fusion is legal |
|
- self.score_fusion(): assigns priority to a given fusion |
|
""" |
|
fused_nodes = set(self.nodes) |
|
for node1, node2 in self.get_possible_fusions(): |
|
node1 = self.name_to_fused_node[node1.get_first_name()] |
|
node2 = self.name_to_fused_node[node2.get_first_name()] |
|
if self.can_fuse(node1, node2) and not self.will_fusion_create_cycle( |
|
node1, node2 |
|
): |
|
if not self.speedup_by_fusion(node1, node2): |
|
continue |
|
fusion_log.debug( |
|
"fusing %s with %s", node1.get_name(), node2.get_name() |
|
) |
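# can_fuse has already verified that both nodes are on the same device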
|
|
|
|
|
device = node1.get_device() |
|
node3 = self.get_backend(device).fuse(node1, node2) |
|
fused_nodes.remove(node1) |
|
fused_nodes.remove(node2) |
|
fused_nodes.add(node3) |
|
self.name_to_fused_node.update( |
|
{n.get_name(): node3 for n in node3.get_nodes()} |
|
) |
|
self.nodes = sorted(fused_nodes, key=lambda x: x.min_order) |
|
self.topological_sort_schedule() |
|
self.prune_redundant_deps() |
|
|
|
def prune_redundant_deps(self) -> None: |
|
for node in self.nodes: |
|
node.prune_redundant_deps(self.name_to_fused_node) |
|
|
|
def get_possible_fusions(self) -> List[Tuple[BaseSchedulerNode, BaseSchedulerNode]]: |
|
""" |
|
Helper to find all legal fusion opportunities, sorted by self.score_fusion() |
|
""" |
|
possible_fusions = [] |
|
seen = set() |
|
|
|
def check_all_pairs(nodes: List[BaseSchedulerNode]) -> None: |
|
for node1_index, node1 in enumerate(nodes): |
|
for node2 in nodes[node1_index + 1 :]: |
|
key = (node1, node2) |
|
if key in seen: |
|
continue |
|
seen.add(key) |
|
|
|
if self.can_fuse(node1, node2): |
|
possible_fusions.append(key) |
|
elif (node2.is_template() or node2.is_foreach()) and self.can_fuse( |
|
node2, node1 |
|
): |
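# foreach and template fusions are order dependent; retry with the
# operands swapped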
|
|
|
possible_fusions.append((node2, node1)) |
|
|
|
buffer_names_grouping = collections.defaultdict(list) |
|
for node in self.nodes: |
|
for buf in node.used_buffer_names(): |
|
buffer_names_grouping[buf].append(node) |
|
for node_grouping in buffer_names_grouping.values(): |
|
check_all_pairs(node_grouping) |
|
|
|
if config.aggressive_fusion: |
|
group_grouping = collections.defaultdict(list) |
|
for node in self.nodes: |
|
group = getattr(node, "group", None) |
|
if group: |
|
group_grouping[group].append(node) |
|
for node_grouping in group_grouping.values(): |
|
check_all_pairs(node_grouping) |
|
|
|
possible_fusions = self.get_possible_fusions_with_highest_priority( |
|
possible_fusions |
|
) |
|
possible_fusions.sort(key=self.score_fusion_key, reverse=True) |
|
fusion_log.debug("found %d possible fusions", len(possible_fusions)) |
|
return possible_fusions |
|
|
|
def will_fusion_create_cycle( |
|
self, node1: BaseSchedulerNode, node2: BaseSchedulerNode |
|
) -> bool: |
|
""" |
|
Determine whether there is a path from node1 to node2 (or vice versa)
introduced indirectly by other fusions, which would make this fusion
create a cycle.
|
""" |
|
|
|
visited = set() |
|
|
|
def found_path(node: BaseSchedulerNode) -> bool: |
|
|
|
if isinstance(node, FusedSchedulerNode) and node not in visited: |
|
visited.add(node) |
|
if node.get_names().issubset(combined_ancestors): |
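# every output of this fused node is already an ancestor of node1 and
# node2, so following it cannot introduce a new path between them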
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return False |
|
else: |
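# continue the DFS through ancestors newly introduced by this fusion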
|
|
|
return bool(combined_names & node.ancestors) or any( |
|
found_path(self.name_to_fused_node[n]) |
|
for n in node.ancestors - combined_ancestors |
|
) |
|
return False |
|
|
|
combined_names = node1.get_names() | node2.get_names() |
|
combined_ancestors = (node1.ancestors | node2.ancestors) - combined_names |
|
cycle = any(found_path(self.name_to_fused_node[n]) for n in combined_ancestors) |
|
if cycle: |
|
WhyNoFuse(node1, node2)("will create cycle") |
|
return cycle |
|
|
|
def can_fusion_increase_peak_memory( |
|
self, node1: BaseSchedulerNode, node2: BaseSchedulerNode |
|
) -> bool: |
|
""" |
|
This function prevents fusion for nodes that can increase memory |
|
footprint. This problem is more common in horizontal fusion, where nodes |
|
that are far apart in the original order get fused, lengthening the live |
|
intervals of tensors. This is very evident in models with activation |
|
checkpointing, where the recomputed nodes from different checkpointed |
|
regions get fused and significantly increase the memory footprint. |
|
|
|
The current attempt is a quick, possibly hacky, heuristic to prevent the |
|
fusion of nodes that are far away in the original order. |
|
|
|
A better but harder-to-implement heuristic would be to use the live
intervals of the buffers: find the region of peak pressure in the
original program and prevent fusions that cross that peak region. This
would need special care or a good approximation, since fusing nodes
changes live intervals, and re-computing live intervals and peak memory
after each fusion can introduce large compilation overhead.
|
""" |
|
proximity_score = max( |
|
abs(node1.min_order - node2.max_order), |
|
abs(node2.min_order - node1.max_order), |
|
) |
|
return proximity_score > 64 |
|
|
|
def decide_fusion_fail_reason( |
|
self, |
|
node1: BaseSchedulerNode, |
|
node2: BaseSchedulerNode, |
|
common_buf_names: Tuple[str, ...], |
|
) -> str: |
|
""" |
|
Try to determine why a fusion failed with no shared memory even though
the two nodes access common buffers.
|
""" |
|
reasons = {} |
|
node1_name2dep = {dep.name: dep for dep in node1.read_writes.reads_and_writes()} |
|
node2_name2dep = {dep.name: dep for dep in node2.read_writes.reads_and_writes()} |
|
|
|
for buf_name in common_buf_names: |
|
buf = V.graph.get_buffer(buf_name) |
|
lhs_dep = node1_name2dep[buf_name] |
|
rhs_dep = node2_name2dep[buf_name] |
|
|
|
if lhs_dep.get_numel() != rhs_dep.get_numel(): |
|
reasons[ |
|
buf_name |
|
] = f"different numel: {lhs_dep.get_numel()} v.s. {rhs_dep.get_numel()}" |
|
continue |
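# same numel but different sizes: most likely a broadcast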
|
|
|
|
|
if sympy_product(lhs_dep.size) != sympy_product(rhs_dep.size): |
|
reasons[buf_name] = "broadcast" |
|
continue |
|
|
|
if not isinstance(lhs_dep, MemoryDep) or not isinstance(rhs_dep, MemoryDep): |
|
reasons[ |
|
buf_name |
|
] = f"not MemoryDep: {type(lhs_dep)} v.s. {type(rhs_dep)}" |
|
continue |
|
|
|
lhs_off = lhs_dep.get_offset() |
|
rhs_off = rhs_dep.get_offset() |
|
if lhs_off != rhs_off: |
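# the two nodes access the same buffer at different offsets, e.g. views
# produced by splitting one concatenated buffer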
|
|
|
|
|
|
|
reasons[buf_name] = f"different offset: {lhs_off} v.s. {rhs_off}" |
|
continue |
|
|
|
if ( |
|
lhs_dep.normalize_with_stride_order() |
|
== rhs_dep.normalize_with_stride_order() |
|
): |
|
reasons[buf_name] = f"Mismatch loop orders: {lhs_dep} v.s. {rhs_dep}" |
|
continue |
|
|
|
|
|
reasons[ |
|
buf_name |
|
] = f"Unknown reason: {lhs_dep} v.s. {rhs_dep}. Layout: {buf.layout}" |
|
|
|
return str(reasons) |
|
|
|
def can_fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: |
|
""" |
|
Determine if it is possible to combine node1 and node2 into a |
|
single fused node. |
|
""" |
|
|
|
if node1 is node2: |
|
return False |
|
|
|
why = WhyNoFuse(node1, node2) |
|
|
|
if ( |
|
isinstance(node1, (ExternKernelSchedulerNode, NopKernelSchedulerNode)) |
|
and not node1.is_template() |
|
): |
|
why("node1 is extern or nop") |
|
return False |
|
if ( |
|
isinstance(node2, (ExternKernelSchedulerNode, NopKernelSchedulerNode)) |
|
and not node2.is_template() |
|
): |
|
why("node2 is extern or nop") |
|
return False |
|
|
|
if node2.get_names() & node1.ancestors: |
|
why("node1 must go before node2") |
|
return False |
|
|
|
if node2.is_template(): |
|
why("templates can only fuse epilogues") |
|
return False |
|
if node1.is_template() and ( |
|
node2.has_aliasing_or_mutation() |
|
or node2.is_reduction() |
|
or not config.epilogue_fusion |
|
): |
|
why("template epilogue not satisfied") |
|
return False |
|
|
|
device = node1.get_device() |
|
device2 = node2.get_device() |
|
if device != device2: |
|
why("device mismatch (%s vs %s)", device, device2) |
|
return False |
|
del device2 |
|
|
|
no_shared_data = self.score_fusion_memory(node1, node2) == 0 |
|
if no_shared_data and ( |
|
not config.aggressive_fusion or node1.is_reduction() or node2.is_reduction() |
|
): |
|
if is_metric_table_enabled("fusion_failure_due_to_indexing_mismatch"): |
|
common_buf_names = ( |
|
node1.read_writes.buffer_names() & node2.read_writes.buffer_names() |
|
) |
|
if len(common_buf_names) > 0: |
|
get_metric_table("fusion_failure_due_to_indexing_mismatch").add_row( |
|
lambda: { |
|
"pre_grad_graph_id": V.graph.graph_id, |
|
"post_grad_graph_id": V.graph.post_grad_graph_id, |
|
"node1_name": node1.get_name(), |
|
"node2_name": node2.get_name(), |
|
"node1_debug_str": write_text(node1.debug_str()), |
|
"node2_debug_str": write_text(node2.debug_str()), |
|
"common_buffer_names": list(common_buf_names), |
|
"failure_reason": self.decide_fusion_fail_reason( |
|
node1, node2, common_buf_names |
|
), |
|
} |
|
) |
|
|
|
why("no shared data due to indexing mismatch") |
|
return False |
|
why("no shared data") |
|
return False |
|
|
|
if ( |
|
not node1.is_foreach() |
|
and not node2.is_foreach() |
|
and len(node1.get_nodes()) + len(node2.get_nodes()) > config.max_fusion_size |
|
): |
|
why("exceeds max fusion") |
|
return False |
|
|
|
if node1.get_names() & node2.ancestors: |
|
|
|
if not self.can_fuse_vertical(node1, node2): |
|
return False |
|
return self.get_backend(device).can_fuse_vertical(node1, node2) |
|
else: |
|
if self.can_fusion_increase_peak_memory(node1, node2): |
|
why("will increase peak memory") |
|
return False |
|
return self.get_backend(device).can_fuse_horizontal(node1, node2) |
|
|
|
def can_fuse_vertical( |
|
self, node1: BaseSchedulerNode, node2: BaseSchedulerNode |
|
) -> bool: |
|
""" |
|
Check if it is legal to fuse a consumer (node2) into a producer (node1). |
|
|
|
We can fuse them if all the reads of node2 either match |
|
corresponding writes in node1, or are written by nodes that can |
|
be scheduled before the fusion of node1 and node2. |
|
|
|
We also disable fusion of a write subsequent to a read if the reads |
|
and writes do not align. |
|
""" |
|
node1_names = node1.get_names() |
|
computed_deps = set() |
|
why = WhyNoFuse(node1, node2) |
|
|
|
for cd in node1.read_writes.writes: |
|
if not isinstance(cd, MemoryDep): |
|
continue |
|
for rd in node2.unmet_dependencies: |
|
if self.fusable_read_and_write(rd, cd): |
|
computed_deps.add(rd) |
|
|
|
remaining_deps = {dep.name for dep in node2.unmet_dependencies - computed_deps} |
|
if remaining_deps & node1_names: |
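# node2 reads buffers that node1 writes, but with deps that did not
# match any of node1's writes (e.g. different indices into the same
# buffer), so fusing would be unsafe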
|
|
|
|
|
|
|
|
|
why("memory deps did not match") |
|
return False |
|
for name in remaining_deps: |
|
if node1_names & self.name_to_fused_node[name].ancestors: |
|
why("intermediate nodes between node1 & node2") |
|
return False |
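# similar to can_inplace: if node2 writes a buffer that node1 reads,
# the read and write must use the same indexing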
|
|
|
|
|
|
|
for write in node2.read_writes.writes: |
|
if not isinstance(write, MemoryDep): |
|
continue |
|
for read in node1.read_writes.reads: |
|
if write.name != self.mutation_renames.get(read.name, read.name): |
|
continue |
|
|
|
|
|
if not self.fusable_read_and_write(read, write): |
|
why("fusing a write into a read with different indexing formula") |
|
return False |
|
|
|
return True |
|
|
|
|
|
|
|
|
|
|
|
def fusable_read_and_write(self, read: Dep, write: MemoryDep) -> bool: |
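"""
Decide whether a read dep is satisfied by a write dep for fusion purposes.

Two MemoryDeps match if they share a non-None mode, or if the
(mutation-renamed) buffer names and index formulas agree with no
indirect (TMP) indexing and the read's size matches the write's size up
to extra trailing read dimensions (e.g. stripped by broadcasting). A
StarDep only matches a write with the same non-None mode and buffer name.
"""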
|
if isinstance(read, MemoryDep): |
|
if read.mode == write.mode and write.mode is not None: |
|
return True |
|
read_name = read.name |
|
if read_name in self.mutation_renames: |
|
read_name = self.mutation_renames[read_name] |
|
return ( |
|
read_name == write.name |
|
and not free_symbol_is_type(read.index, SymT.TMP) |
|
and not free_symbol_is_type(write.index, SymT.TMP) |
|
and read.index == write.index |
|
and len(read.size) >= len(write.size) |
|
and read.size[: len(write.size)] == write.size |
|
) |
|
elif isinstance(read, StarDep): |
|
read_name = self.mutation_renames.get(read.name, read.name) |
|
write_name = self.mutation_renames.get(write.name, write.name) |
|
if ( |
|
read.mode == write.mode |
|
and write.mode is not None |
|
and read_name == write_name |
|
): |
|
return True |
|
return False |
|
|
|
def score_fusion( |
|
self, node1: BaseSchedulerNode, node2: BaseSchedulerNode |
|
) -> Tuple[bool, bool, int, int]: |
|
""" |
|
Assign a score (higher comes first) to the fusion of node1 |
|
and node2. When different fusions conflict with each other, |
|
this is the way we decide what order to run them in. |
|
|
|
Our current score is based on: |
|
- Estimate of the saved memory operations |
|
- Fusions closer together in original order |
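
For example, with the boolean terms equal, a candidate pair that shares
a large intermediate buffer outranks a pair that shares only a few
bytes; proximity in the original order breaks ties.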
|
""" |
|
memory_score = self.score_fusion_memory(node1, node2) |
|
proximity_score = -max( |
|
abs(node1.min_order - node2.max_order), |
|
abs(node2.min_order - node1.max_order), |
|
) |
|
return ( |
|
node1.is_template() == config.epilogue_fusion_first and memory_score > 0, |
|
node1.is_reduction() == node2.is_reduction() and memory_score > 0, |
|
memory_score, |
|
proximity_score, |
|
) |
|
|
|
def dep_size_hint(self, dep: Dep) -> int: |
|
res = 0 |
|
if dep not in self.__dep_size_hint_cache: |
|
try: |
|
if not dep.has_unbacked_symbols(): |
|
res = dep.numbytes_hint() |
|
except KeyError: |
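# treat deps whose size cannot be resolved (KeyError) as contributing
# 0 bytes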
|
|
|
|
|
|
|
pass |
|
self.__dep_size_hint_cache[dep] = res |
|
else: |
|
res = self.__dep_size_hint_cache[dep] |
|
return res |
|
|
|
def score_fusion_memory( |
|
self, node1: BaseSchedulerNode, node2: BaseSchedulerNode |
|
) -> int: |
|
""" |
|
The first term in our fusion score that estimates number of saved |
|
memory operations. |
|
""" |
|
common_memory_deps = (node1.read_writes.reads | node1.read_writes.writes) & ( |
|
node2.read_writes.reads | node2.read_writes.writes |
|
) |
|
return sum(self.dep_size_hint(dep) for dep in common_memory_deps) |
|
|
|
def get_possible_fusions_with_highest_priority( |
|
self, possible_fusions: List[Tuple[BaseSchedulerNode, BaseSchedulerNode]] |
|
) -> List[Tuple[BaseSchedulerNode, BaseSchedulerNode]]: |
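# group the candidate fusions by the backend-assigned priority and keep
# only the group with the highest priority (the smallest value)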
|
|
|
|
|
if len(possible_fusions) == 0: |
|
return possible_fusions |
|
possible_fusions_group_by_priority: Dict[ |
|
int, List[Tuple[BaseSchedulerNode, BaseSchedulerNode]] |
|
] = {} |
|
|
|
for node1, node2 in possible_fusions: |
|
assert node1.get_device() == node2.get_device() |
|
device = node1.get_device() |
|
fusion_pair_priority = int( |
|
self.get_backend(device).get_fusion_pair_priority(node1, node2) |
|
) |
|
if fusion_pair_priority not in possible_fusions_group_by_priority: |
|
possible_fusions_group_by_priority[fusion_pair_priority] = [ |
|
(node1, node2), |
|
] |
|
else: |
|
possible_fusions_group_by_priority[fusion_pair_priority].append( |
|
(node1, node2) |
|
) |
|
|
|
possible_fusions_with_highest_priority = min( |
|
possible_fusions_group_by_priority.items(), key=operator.itemgetter(0) |
|
)[1] |
|
assert len(possible_fusions_with_highest_priority) > 0 |
|
return possible_fusions_with_highest_priority |
|
|
|
def score_fusion_key( |
|
self, nodes: Tuple[BaseSchedulerNode, BaseSchedulerNode] |
|
) -> Tuple[bool, bool, int, int]: |
|
""" |
|
Shim for list.sort(key=...) |
|
""" |
|
node1, node2 = nodes |
|
return self.score_fusion(node1, node2) |
|
|
|
def compute_last_usage(self) -> None: |
|
""" |
|
Populate node.last_usage recursively (also for the nodes within a FusedSchedulerNode) |
|
""" |
|
|
|
future_used_buffers = set(V.graph.get_output_names()) |
|
|
|
for node in reversed(self.nodes): |
|
node.set_last_usage(future_used_buffers, self.mutation_real_name) |
|
future_used_buffers.update(node.last_usage) |
|
|
|
def free_buffers(self) -> None: |
|
"""Free any buffers that are no longer needed""" |
|
for name in sorted( |
|
self.buffer_names_to_free |
|
- V.graph.removed_buffers |
|
- V.graph.wrapper_code.freed |
|
): |
|
if name in self.name_to_node: |
|
node = self.name_to_node[name] |
|
if node.can_free(): |
|
V.graph.wrapper_code.codegen_free(node.node) |
|
elif name in V.graph.graph_inputs: |
|
storage = V.graph.graph_inputs[name].data |
|
assert isinstance(storage, ir.StorageBox) and storage.is_input_buffer() |
|
V.graph.wrapper_code.codegen_free(storage.data) |
|
|
|
self.buffer_names_to_free.clear() |
|
|
|
def remove_kernel_local_buffers(self) -> None: |
|
""" |
|
Any buffers that are both created and have a last use in the |
|
same kernel can be removed. |
|
""" |
|
|
|
|
|
|
|
fused_node_names = V.kernel.store_buffer_names |
|
names_to_remove = [] |
|
for out_buf in V.kernel.store_buffer_names: |
|
users = self.name_to_node[out_buf].users |
|
assert users is not None |
|
users = {user.get_name() for user in users if not user.is_weak} |
|
if users.issubset(fused_node_names): |
|
names_to_remove.append(out_buf) |
|
|
|
def remove_filter(n: str) -> bool: |
|
return ( |
|
n not in V.kernel.must_keep_buffers |
|
and n not in V.kernel.args.input_buffers |
|
and n not in self.mutation_renames |
|
and n not in self.mutation_real_name |
|
) |
|
|
|
names_to_remove = list(filter(remove_filter, names_to_remove)) |
|
|
|
for name in names_to_remove: |
|
if name in V.kernel.args.inplace_buffers: |
|
buf = V.kernel.args.inplace_buffers[name] |
|
if isinstance(buf, str) and buf.startswith("REMOVED"): |
|
continue |
|
remove = all(n in names_to_remove for n in buf.other_names) |
|
if remove: |
|
self.remove_inplace_buffer(name) |
|
V.kernel.inplaced_to_remove.add(name) |
|
else: |
|
self.remove_buffer(name) |
|
|
|
def remove_buffer(self, name: str) -> None: |
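# assign a sentinel value instead of deleting the entry, because the
# length of output_buffers is still used to generate unique arg names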
|
|
|
|
|
|
|
log.debug("remove_buffer(%r)", name) |
|
V.kernel.args.output_buffers[name] = "REMOVED" |
|
V.kernel.removed_buffers.add(name) |
|
|
|
def remove_inplace_buffer(self, name: str) -> None: |
|
log.debug("removing_inplace_buffer(%r)", name) |
|
inner_name = V.kernel.args.inplace_buffers[name].inner_name |
|
V.kernel.args.inplace_buffers[name] = inner_name.replace( |
|
"in_out_ptr", "REMOVED" |
|
) |
|
V.kernel.removed_buffers.add(name) |
|
|
|
def flush(self) -> None: |
|
for backend in self.backends.values(): |
|
backend.flush() |
|
self.free_buffers() |
|
|
|
def codegen_extern_call(self, scheduler_node: ExternKernelSchedulerNode) -> None: |
|
assert isinstance(scheduler_node, ExternKernelSchedulerNode) |
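# decide_inplace_update() stores its inplace-update decisions on the
# active kernel, from which allocate() later retrieves them, so a
# kernel handler must be installed before calling either.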
|
|
|
|
|
|
|
|
|
counters["inductor"]["extern_calls"] += 1 |
|
with V.set_kernel_handler(Kernel(increase_kernel_count=False)): |
|
scheduler_node.decide_inplace_update() |
|
scheduler_node.allocate() |
|
node = scheduler_node.node |
|
assert isinstance(node, ir.ExternKernel), f"{type(node)=}" |
|
node.codegen(V.graph.wrapper_code) |
|
self.free_buffers() |
|
|
|
def create_backend(self, device: torch.device) -> "BaseScheduling": |
|
assert ( |
|
not is_gpu(device.type) or device.index is not None |
|
), f"{device} should have been normalized in lowering" |
|
V.graph.add_device_info(device) |
|
|
|
device_scheduling = get_scheduling_for_device(device.type) |
|
if device_scheduling is None: |
|
raise RuntimeError(f"Unsupported device type: {device.type}") |
|
|
|
if not has_triton(): |
|
if ( |
|
device.type == "cuda" |
|
and (device_props := torch.cuda.get_device_properties(device)).major < 7 |
|
): |
|
raise RuntimeError( |
|
f"Found {device_props.name} which is too old to be supported by the triton GPU compiler, which is used as the backend. Triton only supports devices of CUDA Capability >= 7.0, but your device is of CUDA capability {device_props.major}.{device_props.minor}" |
|
) |
|
elif is_gpu(device.type): |
|
raise RuntimeError( |
|
"Cannot find a working triton installation. More information on installing Triton can be found at https://github.com/openai/triton" |
|
) |
|
|
|
return device_scheduling(self) |
|
|
|
def get_backend(self, device: torch.device) -> "BaseScheduling": |
|
if device not in self.backends: |
|
self.backends[device] = self.create_backend(device) |
|
return self.backends[device] |
|
|
|
def enter_context(self, node: BaseSchedulerNode) -> None: |
|
def get_order(n: torch.fx.Node) -> int: |
|
if n not in self.origin_to_index: |
|
self.origin_to_index.update({n: i for i, n in enumerate(n.graph.nodes)}) |
|
return self.origin_to_index[n] |
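# dict keys act as an ordered set of (order, origin) pairs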
|
|
|
|
|
origins = { |
|
(get_order(e), e): None |
|
for n in node.get_nodes() |
|
if n.node is not None |
|
for e in n.node.origins |
|
} |
|
origins = list(origins.keys()) |
|
if origins: |
|
_, last = max(origins, key=operator.itemgetter(0)) |
|
V.graph.wrapper_code.enter_context(last) |
|
|
|
@dynamo_timed |
|
def codegen(self) -> None: |
|
for node in self.nodes: |
|
try: |
|
log.debug( |
|
"Generating code for node %s with estimated runtime %f", |
|
node.get_name(), |
|
node.get_estimated_runtime(), |
|
) |
|
except Exception:
|
log.debug( |
|
"Generating code for node %s with estimated runtime 0.0", |
|
node.get_name(), |
|
) |
|
|
|
self.enter_context(node) |
|
|
|
if not isinstance(node, NopKernelSchedulerNode) and ( |
|
device := node.get_device() |
|
): |
|
if ( |
|
device != self.current_device |
|
or node.is_extern() |
|
or node.is_template() |
|
): |
|
self.flush() |
|
if device != self.current_device: |
|
if self.current_device and device_need_guard( |
|
self.current_device.type |
|
): |
|
V.graph.wrapper_code.codegen_device_guard_exit() |
|
if device_need_guard(device.type): |
|
assert device.index is not None, "device should have an index" |
|
V.graph.wrapper_code.codegen_device_guard_enter(device.index) |
|
|
|
self.current_device = device |
|
|
|
self.buffer_names_to_free.update(node.last_usage) |
|
|
|
if node.is_template(): |
|
node, *epilogue = node.get_nodes() |
|
self.get_backend(device).codegen_template(node, epilogue) |
|
elif node.is_extern(): |
|
node = typing.cast(ExternKernelSchedulerNode, node) |
|
self.codegen_extern_call(node) |
|
elif node.is_foreach(): |
|
node = typing.cast(ForeachKernelSchedulerNode, node) |
|
backend_ = self.get_backend(device) |
|
from .codegen.cuda_combined_scheduling import CUDACombinedScheduling |
|
from .codegen.simd import SIMDScheduling |
|
|
|
if isinstance(backend_, (SIMDScheduling, CUDACombinedScheduling)): |
|
backend = backend_ |
|
else: |
|
raise AssertionError(f"{type(backend_)=}")
|
backend.codegen_foreach(node) |
|
elif isinstance(node, (FusedSchedulerNode, SchedulerNode)): |
|
self.get_backend(device).codegen_node(node) |
|
else: |
|
assert isinstance(node, NopKernelSchedulerNode) |
|
node.allocate() |
|
|
|
if config.triton.debug_sync_kernel: |
|
self.get_backend(device).codegen_sync() |
|
|
|
self.available_buffer_names.update(node.get_names()) |
|
|
|
if not isinstance(node, NopKernelSchedulerNode): |
|
device = node.get_device() |
|
if device is not None and self.get_backend(device).ready_to_flush(): |
|
self.flush() |
|
|
|
if self.current_device and device_need_guard(self.current_device.type): |
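# exit the outermost device guard; this matters for the nested
# indentation of the generated code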
|
|
|
|
|
V.graph.wrapper_code.codegen_device_guard_exit() |
|
|
|
self.flush() |
|
|
|
def get_buffer_layout(self, buf_name: str) -> ir.Layout: |
|
node = self.name_to_node[buf_name] |
|
assert node.node is not None |
|
return node.node.get_layout() |
|
|
|
|
|
class BaseScheduling: |
|
def can_fuse_vertical( |
|
self, node1: BaseSchedulerNode, node2: BaseSchedulerNode |
|
) -> bool: |
|
""" |
|
Check whether node1 and node2 can be vertically fused or not. |
|
""" |
|
raise NotImplementedError |
|
|
|
def can_fuse_horizontal( |
|
self, node1: BaseSchedulerNode, node2: BaseSchedulerNode |
|
) -> bool: |
|
""" |
|
Check whether node1 and node2 can be horizontally fused or not. |
|
""" |
|
raise NotImplementedError |
|
|
|
def fuse( |
|
self, node1: BaseSchedulerNode, node2: BaseSchedulerNode |
|
) -> FusedSchedulerNode: |
|
""" |
|
Fuse two nodes |
|
""" |
|
if node1.is_foreach() or node2.is_foreach(): |
|
return ForeachKernelSchedulerNode.fuse(node1, node2) |
|
else: |
|
return FusedSchedulerNode.fuse(node1, node2) |
|
|
|
def group_fn( |
|
self, sizes: Sequence[Sequence[sympy.Expr]] |
|
) -> Tuple[Tuple[sympy.Expr, ...], ...]: |
|
""" |
|
Process the iteration sizes in case a transformation needs to be applied. |
|
""" |
|
raise NotImplementedError |
|
|
|
def codegen_template( |
|
self, |
|
template_node: BaseSchedulerNode, |
|
epilogue_nodes: Sequence[BaseSchedulerNode], |
|
) -> Optional[str]: |
|
""" |
|
Given a template node, generate a kernel. |
|
|
|
This is currently only implemented for Triton; a third-party backend
that subclasses TritonScheduling can override or reuse it.
|
""" |
|
raise NotImplementedError |
|
|
|
def codegen_node(self, node: Union[FusedSchedulerNode, SchedulerNode]) -> None: |
|
""" |
|
Generate a kernel for the given node, which may be a FusedSchedulerNode
containing a list of pre-fused nodes.
|
""" |
|
raise NotImplementedError |
|
|
|
def codegen_sync(self) -> None: |
|
""" |
|
Generate synchronization code for the kernel. This method depends on the hardware characteristics. |
|
""" |
|
raise NotImplementedError |
|
|
|
def ready_to_flush(self) -> bool: |
|
""" |
|
Check whether the backend is requesting the scheduler to flush the generated kernel. |
|
Backends that do not support flushing should return False.
|
""" |
|
return False |
|
|
|
def flush(self) -> None: |
|
""" |
|
Flush the generated kernel and python wrapper code to the source code file. |
|
""" |
|
raise NotImplementedError |
|
|
|
def benchmark_fused_nodes( |
|
self, nodes: Sequence[BaseSchedulerNode] |
|
) -> Tuple[float, str]: |
|
""" |
|
Benchmark a fused list of nodes and return the execution time
|
in milliseconds on randomly generated inputs. |
|
""" |
|
raise NotImplementedError |
|
|
|
def get_fusion_pair_priority( |
|
self, node1: BaseSchedulerNode, node2: BaseSchedulerNode |
|
) -> int: |
|
""" |
|
Return an unsigned integer which represents the priority of this fusion pair. |
|
Smaller values indicate higher priority.
|
""" |
|
return 0 |
|
|
|
|
|
def debug_triton_code(node: Union[SchedulerNode, FusedSchedulerNode]) -> List[str]: |
|
lines = [] |
|
multi_template = node.get_template_node() |
|
assert multi_template is None or isinstance(multi_template, ir.MultiTemplateBuffer) |
|
if multi_template and multi_template.make_kernel_render is None: |
|
lines.append(f"{node.get_name()} Unfinalized multi template buffer") |
|
else: |
|
from torch._inductor.codegen.cuda_combined_scheduling import ( |
|
CUDACombinedScheduling, |
|
) |
|
from torch._inductor.codegen.triton import TritonScheduling |
|
|
|
snodes = (node,) if isinstance(node, SchedulerNode) else node.snodes |
|
device = snodes[0].get_device() |
|
backend = node.scheduler.get_backend(device) |
|
assert isinstance(backend, (TritonScheduling, CUDACombinedScheduling)) |
|
V.graph.scheduler.current_device = device |
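# generating the kernel for debugging must not change
# metrics.generated_kernel_count, which some tests assert on; save and
# restore it around code generation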
|
|
|
|
|
|
|
|
|
old_generated_kernel_count = metrics.generated_kernel_count |
|
triton_code = backend.generate_kernel_code_from_nodes(snodes).strip() |
|
metrics.generated_kernel_count = old_generated_kernel_count |
|
|
|
lines.append(f"{node.get_name()} Triton code:") |
|
lines.append(textwrap.indent(triton_code, " ")) |
|
return lines |
|
|