import collections
import dataclasses
import functools
import itertools
import logging
import math
import operator
import os
import pprint
import textwrap
from typing import (
    Any,
    Counter,
    DefaultDict,
    Dict,
    Generic,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    TypeVar,
    Union,
)

import sympy

import torch
from torch._dynamo.utils import dynamo_timed
from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
from torch.utils._triton import has_triton

from . import comms, config, dependencies, ir, metrics
from .codegen.common import get_scheduling_for_device, Kernel
from .comm_analysis import estimate_nccl_collective_runtime
from .dependencies import Dep, MemoryDep, StarDep, WeakDep
from .ir import ComputedBuffer, MultiOutput, MultiOutputLayout
from .sizevars import SimplifyIndexing
from .utils import (
    cache_on_self,
    cmp,
    free_symbol_has,
    get_device_tflops,
    get_dtype_size,
    get_gpu_dram_gbps,
    green_text,
    is_collective,
    is_wait,
    red_text,
    sympy_product,
)
from .virtualized import V

log = logging.getLogger(__name__)
fusion_log = torch._logging.getArtifactLogger(__name__, "fusion")
class WhyNoFuse: | |
# TODO when we drop support for Python < 3.10, we can use | |
# @dataclass(slots=True) instead of manually specifying __slots__. | |
__slots__ = ["node1", "node2", "reason", "args"] | |
reason: str | |
args: Tuple[Any, ...] | |
def __init__(self, node1: "BaseSchedulerNode", node2: "BaseSchedulerNode"): | |
self.node1 = node1 | |
self.node2 = node2 | |
def __call__(self, reason, *args): | |
self.reason = reason | |
self.args = args | |
fusion_log.debug(self) | |
def __str__(self): | |
return f"cannot fuse {self.node1.get_name()} with {self.node2.get_name()}: " + ( | |
self.reason % self.args | |
) | |
def pformat(obj): | |
if isinstance(obj, set): | |
# pformat has trouble with sets of sympy exprs | |
obj = sorted(obj, key=str) | |
result = pprint.pformat(obj, indent=4) | |
if "\n" in result: | |
return f"\n{textwrap.indent(result, ' '*4)}" | |
return result | |
class OutputNode: | |
def __init__(self, dep): | |
self.unmet_dependencies = {dep} | |
self.inverse_users = [] | |
def is_reduction(self): | |
return False | |
def get_alias_names(self): | |
return () | |
def get_name(self): | |
return "OUTPUT" | |
__repr__ = get_name | |
def _prune_redundant_deps(node, name_to_fused_node): | |
""" | |
Prunes weakdeps intended for mutation ordering | |
on an upstream fused node if after fusion there is another dependency | |
on the fused upstream node, making the weakdep redundant | |
In essence this enforces an ordering on fusions. As fusions occur, weakdeps will | |
be incrementally removed, enabling other fusions, ensuring they are fused in order. | |
""" | |
name_to_dep_count: Counter[str] = collections.Counter() | |
for dep in node.unmet_dependencies: | |
if not isinstance(dep, WeakDep): | |
name_to_dep_count[name_to_fused_node[dep.name].get_name()] += 1 | |
def should_prune(dep): | |
if isinstance(dep, WeakDep): | |
is_redundant = ( | |
name_to_dep_count[name_to_fused_node[dep.name].get_name()] > 0 | |
) | |
            # These can occur because fused nodes always gather deps from their
            # snodes: if B has a WeakDep on A and B gets fused with C, then any
            # time BC is fused further, the WeakDep will reappear.
is_self_dep = name_to_fused_node[dep.name] == node | |
return is_redundant or is_self_dep | |
else: | |
return False | |
deps_to_prune = {dep for dep in node.unmet_dependencies if should_prune(dep)} | |
if deps_to_prune: | |
node.unmet_dependencies = node.unmet_dependencies - deps_to_prune | |
node.set_read_writes(node.read_writes.remove_reads(deps_to_prune)) | |
# TODO(xmfan): reuse an existing mapping for this if it exists, or formalize this into ir.py:ExternKernel | |
kernel_name_to_op = { | |
"extern_kernels.convolution": torch.ops.aten.convolution, | |
"extern_kernels.mm": torch.ops.aten.mm, | |
"extern_kernels.bmm": torch.ops.aten.bmm, | |
"extern_kernels.addmm": torch.ops.aten.addmm, | |
} | |
class BaseSchedulerNode: | |
def __init__(self, scheduler: "Scheduler", node: ir.Buffer): | |
self.scheduler: Scheduler = scheduler | |
self.node: ir.Buffer = node | |
self.users: List[NodeUser] = [] | |
self.inverse_users: List[BaseSchedulerNode] = [] | |
self.node_users: List[BaseSchedulerNode] = [] | |
self.set_read_writes(node.get_read_writes()) | |
self.ancestors: Set[str] = set() | |
self.min_order: int | |
self.max_order: int | |
self.last_usage: Set[ | |
str | |
] = set() # buffers that won't be used after this kernel | |
self.written = False | |
def __repr__(self): | |
return f"{type(self).__name__}(name={self.get_name()!r})" | |
def debug_str(self) -> str: | |
"""Longer form printout for trace logs""" | |
name = self.get_name() | |
lines = [ | |
f"{name}: {type(self).__name__}({type(getattr(self, 'node', None)).__name__})", | |
f"{name}.writes = {pformat(self.read_writes.writes)}", | |
f"{name}.unmet_dependencies = {pformat(self.unmet_dependencies)}", | |
f"{name}.met_dependencies = {pformat(self.read_writes.reads - self.unmet_dependencies)}", | |
f"{name}.users = {self.users}", | |
] | |
try: | |
lines += [ | |
self.debug_str_extra(), | |
] | |
except Exception: | |
log.warning("Ignoring error in debug_str()", exc_info=True) | |
return "\n".join(lines).rstrip() | |
def debug_str_extra(self) -> str: | |
return "" | |
def log_details(self): | |
log.info( | |
"%s: unmet_dependencies = %s, writes = %s", | |
self, | |
self.unmet_dependencies, | |
self.read_writes.writes, | |
) | |
def update_mutated_names(self, renames: Dict[str, str]): | |
self.set_read_writes(self.read_writes.rename(renames)) | |
def add_mutation_dep(self, dep): | |
self.set_read_writes(self.read_writes.with_read(dep)) | |
def add_fake_dep(self, dep): | |
self.set_read_writes(self.read_writes.with_read(dep)) | |
def set_users(self, users: List["NodeUser"]): | |
# deduplicate | |
result: Dict[int, NodeUser] = {} | |
for use in users: | |
if id(use.node) in result: | |
result[id(use.node)] = use.merge(result[id(use.node)]) | |
else: | |
result[id(use.node)] = use | |
self.users = list(result.values()) | |
def set_last_usage( | |
self, future_used_buffers: Set[str], mutation_real_name: Dict[str, str] | |
): | |
used_buffers = self.used_or_aliased_buffer_names() | |
used_buffers = {mutation_real_name.get(k, k) for k in used_buffers} | |
self.last_usage = used_buffers - future_used_buffers | |
def get_aliases(self): | |
return self.node.get_alias_names() | |
def get_mutations(self): | |
return self.node.get_mutation_names() | |
def has_aliasing_or_mutation(self): | |
return bool(self.get_aliases() or self.get_mutations()) | |
def set_read_writes(self, rw: dependencies.ReadWrites): | |
self.read_writes: dependencies.ReadWrites = rw | |
self.unmet_dependencies = self.read_writes.reads | |
self.prune_deps() | |
def op_counts(self): | |
return self.read_writes.op_counts | |
def used_buffer_names(self) -> Set[str]: | |
return { | |
dep.name | |
for dep in itertools.chain(self.read_writes.reads, self.read_writes.writes) | |
} | |
def used_or_aliased_buffer_names(self) -> Set[str]: | |
used_names = set() | |
for dep in itertools.chain(self.read_writes.reads, self.read_writes.writes): | |
used_names.add(dep.name) | |
if V.graph.name_to_buffer.get(dep.name): | |
layout = V.graph.name_to_buffer[dep.name].get_layout() | |
# needed to avoid deallocating aliased buffer | |
# if there are still uses of aliases ahead | |
if isinstance(layout, ir.AliasedLayout): | |
used_names.add(layout.view.data.get_name()) | |
return used_names | |
def prune_deps(self): | |
self.unmet_dependencies = { | |
dep | |
for dep in self.unmet_dependencies | |
if dep.name not in self.scheduler.available_buffer_names | |
} | |
def prune_weak_deps(self): | |
# Prune weak dependencies on buffers that have been removed | |
def should_prune(dep): | |
return isinstance(dep, WeakDep) and dep.name in V.graph.removed_buffers | |
to_remove = {dep for dep in self.read_writes.reads if should_prune(dep)} | |
self.set_read_writes(self.read_writes.remove_reads(to_remove)) | |
def prune_redundant_deps(self, name_to_fused_node): | |
_prune_redundant_deps(self, name_to_fused_node) | |
def get_name(self) -> str: | |
return self.node.get_name() | |
def get_first_name(self) -> str: | |
return self.get_name() | |
def get_names(self) -> Set[str]: | |
return {self.get_name()} | |
def get_nodes(self) -> Sequence["BaseSchedulerNode"]: | |
return [self] | |
def get_device(self): | |
return self.node.get_device() | |
def is_reduction(self): | |
return False | |
def is_split_scan(self): | |
return False | |
def is_template(self): | |
return False | |
def is_extern(self): | |
return False | |
def is_foreach(self): | |
return False | |
def can_inplace(self, read_dep: dependencies.MemoryDep): | |
return False | |
def has_side_effects(self): | |
return False | |
def decide_inplace_update(self): | |
""" | |
Decide if there should be inplace updates for the node | |
and record the decision in the active kernel. | |
""" | |
if not self.node.should_allocate(): | |
return | |
if isinstance(self, (SchedulerNode,)) and ( | |
self.node.get_alias_names() or self.node.get_mutation_names() | |
): | |
return | |
if ( | |
( | |
isinstance(self, (SchedulerNode,)) | |
                # TODO: make this an API on the IR node instead of special-casing here
or ( | |
isinstance(self, ExternKernelSchedulerNode) | |
and isinstance(self.node, (ir.AllReduce, ir.InPlaceHint)) | |
) | |
) | |
and config.inplace_buffers | |
and ( | |
not isinstance(V.kernel, torch._inductor.codegen.triton.TritonKernel) | |
or getattr(V.kernel, "mutations", None) is not None | |
) | |
): | |
from .codegen.wrapper import buffer_reuse_key | |
ordered_reads = sorted(self.read_writes.reads, key=lambda x: x.name) | |
for read in ordered_reads: | |
input_node: Optional[ | |
BaseSchedulerNode | |
] = self.scheduler.name_to_node.get(read.name) | |
if input_node and V.graph.wrapper_code.can_reuse(input_node, self): | |
assert input_node.users is not None | |
remaining_uses = [ | |
x | |
for x in input_node.users | |
if x.node.get_name() | |
not in self.scheduler.available_buffer_names | |
] | |
if ( | |
len(remaining_uses) == 1 | |
and remaining_uses[0].can_inplace | |
and remaining_uses[0].node is self | |
and not isinstance( | |
input_node.node.get_layout(), | |
( | |
ir.MultiOutputLayout, | |
ir.MutationLayout, | |
ir.AliasedLayout, | |
), | |
) | |
and not ( | |
isinstance( | |
input_node.node, (ir.FallbackKernel, ir.MultiOutput) | |
) | |
and len(input_node.node.get_alias_names()) > 0 | |
) | |
and buffer_reuse_key(input_node.node) | |
== buffer_reuse_key(self.node) | |
): | |
# hacky check for if V.kernel is a real kernel or NullHandler | |
if hasattr(V.kernel, "args"): | |
# if there isn't a triton kernel, then we don't need to call triton-specific things. | |
# but TODO this might be a convenient place to signal to the Collective kernels to inplace | |
# (and, can we make "kernel" less generic of a name?) | |
V.kernel.args.make_inplace( | |
input_node.get_name(), self.get_name() | |
) | |
# mutations not tracked in cpp kernels | |
if isinstance( | |
V.kernel, torch._inductor.codegen.triton.TritonKernel | |
): | |
V.kernel.mutations.add(input_node.get_name()) | |
V.kernel.mutations.add(self.get_name()) | |
# update last usage of reused node | |
self.last_usage.discard(input_node.get_name()) | |
V.kernel.inplace_update_buffers[ | |
self.get_name() | |
] = input_node.get_name() | |
break | |
def allocate(self): | |
if not self.node.should_allocate(): | |
return | |
if isinstance(self, (SchedulerNode,)) and ( | |
self.node.get_alias_names() or self.node.get_mutation_names() | |
): | |
V.graph.wrapper_code.codegen_allocation(self.node) | |
return | |
# hacky check for if V.kernel is a real kernel or NullHandler | |
if ( | |
hasattr(V.kernel, "args") | |
and self.get_name() in V.kernel.inplace_update_buffers | |
): | |
V.graph.wrapper_code.codegen_inplace_reuse( | |
self.scheduler.name_to_node[ | |
V.kernel.inplace_update_buffers[self.get_name()] | |
].node, | |
self.node, | |
) | |
else: | |
V.graph.wrapper_code.codegen_allocation(self.node) | |
def can_free(self): | |
# There's no real allocated buffer, no need to free it | |
if isinstance(self.node.layout, ir.NoneLayout): | |
return False | |
for use in self.users: | |
if isinstance(use.node, OutputNode): | |
return False | |
return True | |
def codegen_originating_info(self, buffer, only_once=True): | |
if not config.comment_origin: | |
return | |
if only_once and self.written: | |
return | |
origins = self.node.origins | |
out_lines = [] | |
for o in origins: | |
if o.op == "output": | |
# These are boring and samey | |
continue | |
out_lines.append("") | |
# TODO(voz): Should the pragma be constant somewhere? | |
out_lines.append("#pragma CMT ORIGIN:") | |
op_info_str = f"#pragma CMT {o.op} {o.target}" | |
if "seq_nr" in o.meta: | |
op_info_str = op_info_str + f" seq_nr:{o.meta['seq_nr']}" | |
out_lines.append(op_info_str) | |
if "stack_trace" in o.meta: | |
stack_trace = f"{o.meta['stack_trace']}" | |
stack_trace_last_line = stack_trace.split("|")[-1] | |
out_lines.append( | |
"#pragma CMT " | |
+ stack_trace_last_line.replace("{", "{{") | |
.replace("}", "}}") | |
.replace("\n", "\\") | |
) | |
out_lines.append("#pragma CMT END ORIGIN") | |
out_lines.append("") | |
if len(out_lines) == 0: | |
return | |
# TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does | |
# not use BracesBuffer, so we have no good indicator of a C++ buffer atm. | |
buffer.writelines(out_lines) | |
self.written = True | |
def get_read_write_buffers_sizes(self) -> int: | |
""" | |
Counting the number of bytes accessed for a kernel is | |
surprisingly tricky. In particular, there is a differentiation | |
between 'theoretical' memory accesses and practical memory | |
accesses. For example, a layernorm kernel may actually access an | |
input 3 times, but in theory, it only needs to access its input | |
once (and may be optimized to do so through say, persistent | |
reductions) | |
Another example is that even though a buffer is passed in, we may | |
not access the entire buffer. This may occur if we are accessing | |
a slice of the buffer. Another tricky case is for indirect | |
indexing, where the amount of bytes accessed depends on the | |
values of the input. | |
What this function aims to compute is the memory accesses for | |
worst-case inputs, best-case optimization. What this means is | |
that for each buffer we compute the amount of potential accesses in two ways and take the minimum. | |
1. Numel in ranges multiplied by number of deps the buffer has | |
2. The buffer size | |
""" | |
if isinstance(self, NopKernelSchedulerNode): | |
return 0 | |
if isinstance(self, ExternKernelSchedulerNode) and isinstance( | |
self.node, MultiOutput | |
): | |
return 0 | |
if isinstance(self, SchedulerNode): | |
node_numel = V.graph.sizevars.size_hint( | |
sympy_product(self.get_ranges()[0]) | |
* sympy_product(self.get_ranges()[1]) | |
) | |
else: | |
node_numel = int(1e9) | |
buf_accesses = collections.defaultdict(list) | |
for dep in self.read_writes.reads | self.read_writes.writes: | |
buf_accesses[dep.name].append(dep) | |
reads = {dep.name for dep in self.read_writes.reads} | |
writes = {dep.name for dep in self.read_writes.writes} | |
def is_materialized(buf, snodes): | |
users = self.scheduler.name_to_node[buf].users | |
buf_uses = {user.node for user in users} | |
return len(buf_uses - set(snodes)) > 0 | |
if isinstance(self, FusedSchedulerNode): | |
removed_buffers = { | |
dep for dep in writes if not is_materialized(dep, self.snodes) | |
} | |
writes = writes - removed_buffers | |
reads = reads - removed_buffers | |
node_bytes = 0 | |
for buf_name in reads | writes: | |
buf_accessed_elems = sum([node_numel for dep in buf_accesses[buf_name]]) | |
buf: Union[ir.Buffer, ir.TensorBox] | |
if buf_name in V.graph.name_to_buffer: | |
buf = V.graph.name_to_buffer[buf_name] | |
elif buf_name in V.graph.graph_inputs: | |
buf = V.graph.graph_inputs[buf_name] | |
else: | |
continue | |
def get_buf_elems(buf): | |
return V.graph.sizevars.size_hint(sympy_product(buf.get_size())) | |
# Kind of a lazy way to get the MultiOutput nodes corresponding to | |
# a MultiOutputLayout | |
if isinstance(buf.layout, MultiOutputLayout): | |
users = self.scheduler.name_to_node[buf.get_name()].users | |
buf_elems = sum(get_buf_elems(user.node.node) for user in users) | |
else: | |
buf_elems = get_buf_elems(buf) | |
node_bytes += min(buf_elems, buf_accessed_elems) * get_dtype_size( | |
buf.get_dtype() | |
) | |
return node_bytes | |
def get_estimated_runtime(self) -> float: | |
""" | |
Returns estimated op runtime in nanoseconds (ns) | |
""" | |
layout = None | |
dtype = None | |
if not hasattr(self, "node") or not self.node: | |
assert isinstance( | |
self, (FusedSchedulerNode, ForeachKernelSchedulerNode) | |
), f"{type(self)=}" | |
assert self.snodes | |
if not self.snodes[0].node: | |
return 0 | |
layout = self.snodes[0].node.get_layout() | |
dtype = self.snodes[0].node.get_dtype() | |
else: | |
layout = self.node.get_layout() | |
dtype = self.node.get_dtype() | |
if "cuda" != layout.device.type: | |
# default to no reordering based on runtime | |
return 0 | |
# Collective kernels | |
if is_collective(self.node): | |
return estimate_nccl_collective_runtime(self.node) | |
elif is_wait(self.node): | |
# ir.Wait is only used for collective ops. | |
# The time needed for the collective op is already estimated and considered | |
# when we are processing the collective op IR node, so ir.Wait takes 0 time | |
# since it doesn't take extra time to get the result after the collective is completed. | |
return 0 | |
try: | |
gpu_memory_bandwidth = get_gpu_dram_gbps() | |
gpu_flops = get_device_tflops(dtype) * 10**12 | |
except Exception: | |
return 0 | |
if isinstance(self, ExternKernelSchedulerNode): | |
assert isinstance(self.node, ir.ExternKernel), f"{type(self.node)=}" | |
op = kernel_name_to_op.get( | |
getattr(self.node, "python_kernel_name", ""), None | |
) | |
# if there is a resolved op, dry-run using fake mode and record flop count | |
if op is not None: | |
from torch._subclasses.fake_tensor import FakeTensorMode | |
from torch.utils.flop_counter import FlopCounterMode | |
with FakeTensorMode(), FlopCounterMode( | |
display=False | |
) as flop_counter_mode: | |
from .ir import ir_node_to_tensor | |
fake_inputs = [ | |
ir_node_to_tensor(input, guard_shape=False) | |
for input in self.node.inputs | |
] | |
cls = self.node.__class__ | |
cls.process_kernel(op, *fake_inputs, **self.node.kwargs) | |
# TODO(xmfan): find a better heuristic to model FLOPS/latency relationship | |
factor = 1.0 | |
counted_flops = flop_counter_mode.get_total_flops() | |
counted_bytes = self.get_read_write_buffers_sizes() | |
compute_time = (factor * counted_flops / gpu_flops) * 1e9 | |
transfer_time = counted_bytes / gpu_memory_bandwidth | |
# Return estimated runtime in nanoseconds | |
return max(compute_time, transfer_time) | |
elif isinstance(self, FusedSchedulerNode) or isinstance( | |
self.node, ComputedBuffer | |
): | |
# Return estimated runtime in nanoseconds (bytes / gbps) | |
return self.get_read_write_buffers_sizes() / gpu_memory_bandwidth | |
return 0 | |
class ExternKernelSchedulerNode(BaseSchedulerNode): | |
def debug_str_extra(self) -> str: | |
return f"{self.get_name()}.node.kernel = {getattr(self.node, 'python_kernel_name', None)}" | |
def is_extern(self): | |
return True | |
def has_side_effects(self): | |
return hasattr(self.node, "has_side_effects") and self.node.has_side_effects() | |
def can_inplace(self, read_dep: dependencies.MemoryDep): | |
if self.get_aliases() or self.is_template(): | |
return False | |
if read_dep.name not in self.scheduler.name_to_node: | |
# don't allow reuse of an 'input' buffer, we don't own it | |
# (would this have been fixed if I tracked mutations properly above?) | |
return False | |
if not isinstance( | |
self.node, (torch._inductor.ir.AllReduce, torch._inductor.ir.InPlaceHint) | |
): | |
# TODO make this a property of the IR | |
return False | |
if len(self.read_writes.writes) == 1: | |
write_dep = next(iter(self.read_writes.writes)) | |
numel_diff = read_dep.get_numel() - write_dep.get_numel() | |
return V.graph.sizevars.simplify(numel_diff) == 0 | |
return False | |
class NopKernelSchedulerNode(BaseSchedulerNode): | |
pass | |
class SchedulerNode(BaseSchedulerNode): | |
def __init__( | |
self, | |
scheduler: "Scheduler", | |
node: Union[ir.ComputedBuffer, ir.TemplateBuffer], | |
): | |
super().__init__(scheduler, node) | |
self._compute_attrs() | |
def _compute_attrs( | |
self, | |
extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, | |
): | |
assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)) | |
self._sizes, self._body = self.node.simplify_and_reorder( | |
extra_indexing_constraints=extra_indexing_constraints | |
) | |
group_fn = self.scheduler.get_backend(self.node.get_device()).group_fn | |
self.group = (self.node.get_device(), group_fn(self._sizes)) | |
if isinstance(self.node, ir.TemplateBuffer): | |
self.set_read_writes(self.node.normalized_read_writes()) | |
else: | |
self.set_read_writes( | |
dependencies.extract_read_writes( | |
self._body, *self._sizes, normalize=True | |
) | |
) | |
def recompute_size_and_body( | |
self, extra_indexing_constraints: Tuple[Dict[Any, Any], List[Any]] | |
): | |
self._compute_attrs(extra_indexing_constraints=extra_indexing_constraints) | |
def debug_str_extra(self) -> str: | |
name = self.get_name() | |
lines = [ | |
f"{name}.group.device = {self.group[0]}", | |
f"{name}.group.iteration = {self.group[1]}", | |
f"{name}.sizes = {self._sizes}", | |
] | |
if self.get_aliases(): | |
lines.append(f"{name}.aliases = {pformat(self.get_aliases())}") | |
if self.get_mutations(): | |
lines.append(f"{name}.mutations = {pformat(self.get_mutations())}") | |
if isinstance(self._body, ir.LoopBody): | |
lines.append(f"class {name}_loop_body:") | |
lines.append(textwrap.indent(self._body.debug_str(), " ")) | |
return "\n".join(lines) | |
def get_ranges(self): | |
return self._sizes | |
def is_reduction(self): | |
assert isinstance( | |
self.node, (ir.ComputedBuffer, ir.TemplateBuffer) | |
), f"{type(self.node)=}" | |
return bool(self.node.get_reduction_type()) | |
def is_split_scan(self): | |
assert isinstance( | |
self.node, (ir.ComputedBuffer, ir.TemplateBuffer) | |
), f"{type(self.node)=}" | |
return isinstance(self.node, ir.ComputedBuffer) and isinstance( | |
self.node.data, ir.SplitScan | |
) | |
def is_template(self): | |
return isinstance(self.node, ir.TemplateBuffer) | |
def get_template_node(self): | |
return self.node if self.is_template() else None | |
def run(self, *index_vars): | |
self.decide_inplace_update() | |
self.mark_run() | |
self.codegen(index_vars) | |
def mark_run(self): | |
self.allocate() | |
def ranges_from_index_vars(self, index_vars): | |
sizes = self._sizes | |
assert sum(map(len, sizes)) == sum(map(len, index_vars)) | |
var_ranges = dict( | |
zip( | |
itertools.chain.from_iterable(index_vars), | |
itertools.chain.from_iterable(sizes), | |
) | |
) | |
return var_ranges | |
def codegen(self, index_vars): | |
var_ranges = self.ranges_from_index_vars(index_vars) | |
try: | |
with V.set_ops_handler( | |
SimplifyIndexing(V.get_ops_handler(), var_ranges) | |
), V.kernel.set_current_node(self): | |
self._body(*index_vars) | |
except Exception: | |
log.fatal("Error in codegen for %s", self.node) | |
raise | |
def pointwise_read_writes(self): | |
""" | |
Get the memory dependencies in the non-reduction axis. | |
""" | |
sizes, reduction_sizes = self._sizes | |
def fn(index): | |
return self._body(index, [sympy.Integer(0) for _ in reduction_sizes]) | |
return dependencies.extract_read_writes(fn, sizes) | |
def can_inplace(self, read_dep: dependencies.MemoryDep): | |
if self.get_aliases() or self.is_template(): | |
return False | |
if len(self.read_writes.writes) == 1 and isinstance( | |
read_dep, dependencies.MemoryDep | |
): | |
write_dep = next(iter(self.read_writes.writes)) | |
assert isinstance(write_dep, dependencies.MemoryDep), f"{type(write_dep)=}" | |
return read_dep.index == write_dep.index and read_dep.size == write_dep.size | |
return False | |
def _get_atomic_add_buffers(self) -> Set[str]: | |
buffers_store_as_atomic_add = set() | |
if isinstance(self._body, ir.LoopBody): | |
for node in self._body.get_nodes(): | |
if ( | |
node.op == "call_method" | |
and node.target == "store" | |
and ( | |
("mode" in node.kwargs and node.kwargs["mode"] == "atomic_add") | |
or (len(node.args) == 5 and node.args[4] == "atomic_add") | |
) | |
): | |
buffers_store_as_atomic_add.add( | |
node.kwargs["name"] | |
if "name" in node.kwargs | |
else (node.args[1] if len(node.args) >= 2 else "") | |
) | |
return buffers_store_as_atomic_add | |
def has_atomic_add(self, check_buf): | |
return check_buf in self._get_atomic_add_buffers() | |
class FusedSchedulerNode(BaseSchedulerNode): | |
""" | |
This is a "fake" scheduler node that represents a group of scheduler nodes | |
that are meant to be fused together. The way it does this is by maintaining | |
its unmet dependencies as the union of its constituent nodes. | |
""" | |
    @classmethod
    def fuse(cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode):
assert node1.scheduler is node2.scheduler | |
assert isinstance(node1, (SchedulerNode, FusedSchedulerNode)) and isinstance( | |
node2, (SchedulerNode, FusedSchedulerNode) | |
) | |
return cls(node1.scheduler, list(node1.get_nodes()) + list(node2.get_nodes())) # type: ignore[arg-type] | |
def __init__(self, scheduler: "Scheduler", snodes: List[SchedulerNode]): | |
# NB: No need to call super().__init__() because we don't need to re-use any of its logic. | |
self.snodes = snodes | |
self.scheduler = scheduler | |
self.node: ir.Buffer = None # type: ignore[assignment] | |
self.users: List[NodeUser] = [] | |
self.inverse_users = [] | |
self.node_users = [] | |
self.group = max(snodes, key=lambda x: int(x.is_reduction())).group | |
self.ancestors = set.union( | |
*[x.ancestors for x in snodes if x.ancestors is not None] | |
) | |
self.set_read_writes( | |
dependencies.ReadWrites.merge_list([x.read_writes for x in snodes]) | |
) | |
self.unmet_dependencies = { | |
dep | |
for dep in set.union(*[x.unmet_dependencies for x in snodes]) | |
if dep.name not in self.get_names() | |
} - self.read_writes.writes | |
self.min_order = min([x.min_order for x in self.snodes]) | |
self.max_order = max([x.max_order for x in self.snodes]) | |
def get_name(self) -> str: | |
return "_".join([x.get_name() for x in self.snodes]) | |
def get_first_name(self) -> str: | |
return self.snodes[0].get_name() | |
def get_names(self) -> Set[str]: | |
return set.union(*[x.get_names() for x in self.snodes]) | |
def debug_str_extra(self) -> str: | |
lines = [ | |
f"{self.get_name()}.snodes[{i}] =\n{node.debug_str()}" | |
for i, node in enumerate(self.snodes) | |
] | |
return textwrap.indent("\n".join(lines).rstrip(), " ") | |
def set_last_usage( | |
self, future_used_buffers: Set[str], mutation_real_name: Dict[str, str] | |
): | |
# Set self.last_usage using the global information | |
# This will be used for inter-kernel optimisations | |
super().set_last_usage(future_used_buffers, mutation_real_name) | |
# Set self.last_usage on the snodes | |
# This will be used for optimisations within the kernel | |
future_used_buffers: Set[str] = set() | |
for node in reversed(self.snodes): | |
node.set_last_usage(future_used_buffers, mutation_real_name) | |
future_used_buffers.update(node.last_usage) # type: ignore[arg-type] | |
def used_buffer_names(self) -> Set[str]: | |
return set.union(*[x.used_buffer_names() for x in self.snodes]) | |
def used_or_aliased_buffer_names(self) -> Set[str]: | |
return set.union(*[x.used_or_aliased_buffer_names() for x in self.snodes]) | |
def get_nodes(self) -> List[SchedulerNode]: | |
return self.snodes | |
def __repr__(self): | |
return f"{type(self).__name__}(nodes={self.get_name()})" | |
def is_reduction(self): | |
return any(x.is_reduction() for x in self.snodes) | |
def is_split_scan(self): | |
return any(x.is_split_scan() for x in self.snodes) | |
def is_template(self): | |
return any(x.is_template() for x in self.snodes) | |
def get_template_node(self): | |
for node in self.snodes: | |
if node.is_template(): | |
return node | |
return None | |
def get_device(self): | |
return self.group[0] | |
def has_aliasing_or_mutation(self): | |
return any(x.has_aliasing_or_mutation() for x in self.snodes) | |
def op_counts(self): | |
op_counts: Counter[str] = collections.Counter() | |
for node in self.snodes: | |
op_counts.update(node.op_counts()) | |
return op_counts | |
def has_atomic_add(self, check_buf): | |
return any( | |
( | |
isinstance(sub_schedule_node1, SchedulerNode) | |
and sub_schedule_node1.has_atomic_add(check_buf) | |
) | |
for sub_schedule_node1 in self.get_nodes() | |
) | |
# None of these need to be implemented, as a FusedSchedulerNode is just an | |
# abstraction for scheduling purposes | |
def update_mutated_names(self, renames: Dict[str, str]): | |
raise NotImplementedError | |
def add_mutation_dep(self, name): | |
raise NotImplementedError | |
def set_users(self, users: List["NodeUser"]): | |
raise NotImplementedError | |
def get_aliases(self): | |
raise NotImplementedError | |
def get_mutations(self): | |
raise NotImplementedError | |
def can_inplace(self, read_dep: dependencies.MemoryDep): | |
raise NotImplementedError | |
def allocate(self): | |
raise NotImplementedError | |
def can_free(self): | |
raise NotImplementedError | |
def debug_str(self) -> str: | |
"""Longer form printout for trace logs""" | |
name = self.get_name() | |
node_typestr = ",".join(type(n).__name__ for n in self.snodes) | |
lines = [ | |
f"{name}: {type(self).__name__}({node_typestr})", | |
f"{name}.writes = {pformat(self.read_writes.writes)}", | |
f"{name}.unmet_dependencies = {pformat(self.unmet_dependencies)}", | |
f"{name}.met_dependencies = {pformat(self.read_writes.reads - self.unmet_dependencies)}", | |
f"{name}.users = {self.users}", | |
] | |
try: | |
lines += [ | |
self.debug_str_extra(), | |
] | |
except Exception: | |
log.warning("Ignoring error in debug_str()", exc_info=True) | |
return "\n".join(lines).rstrip() | |
class ForeachKernelSchedulerNode(FusedSchedulerNode): | |
"""Scheduler node which consists of a list of scheduler nodes that each operate on a | |
distinct tensor in a list of tensors.""" | |
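    # In other words, this is a horizontal grouping: the kernels generated for a
    # foreach op (e.g. torch._foreach_add) over a list of tensors become one
    # ForeachKernelSchedulerNode whose snodes each handle one list element and may
    # themselves be vertically fused with their producers or consumers.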
def get_consumer_subnode_for(self, producer): | |
if producer.get_name() in self.read_to_node: | |
return self.read_to_node[producer.get_name()] | |
return None | |
def get_producer_subnode_for(self, consumer): | |
for rd in consumer.read_writes.reads: | |
if rd.name in self.name_to_node: | |
return self.name_to_node[rd.name] | |
return None | |
    @classmethod
    def can_fuse(cls, producer, consumer):
why = WhyNoFuse(producer, consumer) | |
if producer.is_foreach() and consumer.is_foreach(): | |
foreach_match = len(producer.snodes) == len(consumer.snodes) | |
if not foreach_match: | |
why("foreach do not have same length") | |
return foreach_match and all( | |
producer.scheduler.can_fuse(l, r) | |
for l, r in zip(producer.snodes, consumer.snodes) | |
) | |
elif consumer.is_foreach(): | |
consumer_subnode = consumer.get_consumer_subnode_for(producer) | |
if consumer_subnode is not None: | |
return consumer.scheduler.can_fuse(producer, consumer_subnode) | |
why("candidate producer is not dep of any foreach consumer") | |
return False | |
elif producer.is_foreach(): | |
producer_subnode = producer.get_producer_subnode_for(consumer) | |
if producer_subnode is not None: | |
return producer.scheduler.can_fuse(producer_subnode, consumer) | |
why("candidate consumer has no dep in any foreach producer") | |
return False | |
raise AssertionError( | |
"At least one node passed to ForeachKernelSchedulerNode.can_fuse should be a foreach node" | |
) | |
    @classmethod
    def fuse(cls, producer, consumer):
assert producer.is_foreach() or consumer.is_foreach() | |
prev_node_1 = None | |
prev_node_2 = None | |
if producer.is_foreach() and consumer.is_foreach(): | |
fused_nodes = [ | |
FusedSchedulerNode.fuse(l, r) | |
for l, r in zip(producer.snodes, consumer.snodes) | |
] | |
elif producer.is_foreach(): | |
producer_subnode = producer.get_producer_subnode_for(consumer) | |
fused_nodes = [] | |
prev_node_1 = producer | |
prev_node_2 = None | |
for node in producer.snodes: | |
if node is producer_subnode: | |
new_node = FusedSchedulerNode.fuse(node, consumer) | |
prev_node_2 = new_node | |
fused_nodes.append(new_node) | |
else: | |
fused_nodes.append(node) | |
elif consumer.is_foreach(): | |
consumer_subnode = consumer.get_consumer_subnode_for(producer) | |
fused_nodes = [] | |
prev_node_1 = consumer | |
prev_node_2 = None | |
for node in consumer.snodes: | |
if node is consumer_subnode: | |
new_node = FusedSchedulerNode.fuse(producer, node) | |
prev_node_2 = new_node | |
fused_nodes.append(new_node) | |
else: | |
fused_nodes.append(node) | |
return cls(producer.scheduler, fused_nodes, prev_node_1, prev_node_2) # type: ignore[possibly-undefined] | |
def __init__( | |
self, | |
scheduler: "Scheduler", | |
nodes: List[SchedulerNode], | |
prev_node_1=None, | |
prev_node_2=None, | |
): | |
self.read_to_node = {} | |
self.name_to_node = {} | |
if prev_node_1 is None or prev_node_2 is None: | |
super().__init__(scheduler, nodes) | |
for node in nodes: | |
for read in node.read_writes.reads: | |
self.read_to_node[read.name] = node | |
for name in node.get_names(): | |
self.name_to_node[name] = node | |
else: | |
self.scheduler = scheduler | |
self.snodes = nodes | |
self.node: ir.Buffer = None # type: ignore[assignment] | |
self.users: List[NodeUser] = [] | |
self.set_read_writes( | |
dependencies.ReadWrites.merge_list( | |
[prev_node_1.read_writes, prev_node_2.read_writes] | |
) | |
) | |
self.unmet_dependencies = { | |
dep | |
for dep in set.union( | |
prev_node_1.unmet_dependencies, prev_node_2.unmet_dependencies | |
) | |
if dep.name not in self.get_names() | |
} - self.read_writes.writes | |
self.min_order = min([prev_node_1.min_order, prev_node_2.min_order]) | |
self.max_order = max([prev_node_1.max_order, prev_node_2.max_order]) | |
foreach_node = prev_node_1 if prev_node_1.is_foreach() else prev_node_2 | |
other_node = prev_node_2 if prev_node_1.is_foreach() else prev_node_1 | |
self.ancestors = foreach_node.ancestors | |
self.ancestors.update(other_node.ancestors) | |
self.name_to_node = foreach_node.name_to_node | |
for name in other_node.get_names(): | |
self.name_to_node[name] = other_node | |
self.group = (nodes[0].get_device(), "foreach") | |
self.origins: Set[torch.fx.Node] = set() | |
def mark_run(self): | |
raise NotImplementedError | |
def codegen(self): | |
assert isinstance(self.node, ir.ComputedBuffer), f"{type(self.node)=}" | |
self.node.get_store_function()(self.node.make_loader()()) | |
def can_free(self): | |
        raise NotImplementedError
def is_foreach(self): | |
return True | |
def get_subkernel_nodes(self): | |
"""Returns a list of nodes which comprise the foreach kernel, operating on corresponding elements of our input lists. | |
These nodes may be vertically fused.""" | |
return list(self.snodes) | |
def get_nodes(self): | |
"""Returns all nodes contained in this kernel, unpacking fused nodes into their constituent scheduler nodes.""" | |
return list(itertools.chain.from_iterable(x.get_nodes() for x in self.snodes)) | |
def get_first_name(self): | |
return self.snodes[0].get_first_name() | |
def prune_redundant_deps(self, name_to_fused_node): | |
_prune_redundant_deps(self, name_to_fused_node) | |
for node in self.snodes: | |
node.prune_redundant_deps(name_to_fused_node) | |
def pick_loop_order(stride_lengths, sizes, priority_idx=()): | |
""" | |
A heuristic to decide loop iteration orders. This has not been well | |
tuned and may be something we should autotune. | |
""" | |
def index_cmp(a, b): | |
if sizes[a] == 1 or sizes[b] == 1: | |
# 1-sizes don't matter, just move them to the end | |
return cmp(sizes[a] == 1, sizes[b] == 1) | |
stride_len_a = [sl[a] for sl in stride_lengths] | |
stride_len_b = [sl[b] for sl in stride_lengths] | |
# equivalent to | |
# np.logical_or(stride_lengths[:, b] == 0, stride_lengths[:, a] < stride_lengths[:, b]).all() | |
a_first = sum( | |
sl_b == 0 or sl_a < sl_b for sl_a, sl_b in zip(stride_len_a, stride_len_b) | |
) | |
b_first = sum( | |
sl_a == 0 or sl_b < sl_a for sl_a, sl_b in zip(stride_len_a, stride_len_b) | |
) | |
if a_first > b_first: | |
return -1 | |
if b_first > a_first: | |
return 1 | |
# otherwise contiguous | |
return cmp(b, a) | |
order = list(reversed(range(len(stride_lengths[0])))) | |
if len(priority_idx) > 0: | |
# if we have priority node, only use that node's order | |
stride_lengths = [stride_lengths[pi] for pi in priority_idx] | |
if config.pick_loop_orders: | |
        order.sort(key=functools.cmp_to_key(index_cmp))
return order | |
@dataclasses.dataclass
class NodeUser:
node: BaseSchedulerNode | |
can_inplace: bool = False | |
# A weak user must be scheduled after a given node, but doesn't actually | |
# use the result | |
is_weak: bool = False | |
def __hash__(self): | |
return hash((self.node.get_name(), self.can_inplace, self.is_weak)) | |
def __eq__(self, other): | |
return ( | |
self.get_name() == other.get_name() | |
and self.can_inplace == other.can_inplace | |
and self.is_weak == other.is_weak | |
) | |
def get_name(self): | |
return self.node.get_name() | |
def merge(self, other: "NodeUser") -> "NodeUser": | |
assert self.node is other.node | |
return NodeUser( | |
self.node, | |
self.can_inplace and other.can_inplace, | |
self.is_weak and other.is_weak, | |
) | |
_post_grad_graph_counter = itertools.count() | |
class Scheduler: | |
    @dynamo_timed
    def __init__(self, nodes):
super().__init__() | |
self.backends = {} | |
self.fuse_cache = {} | |
self.post_grad_graph_id = next(_post_grad_graph_counter) | |
self.nodes = [] | |
self.available_buffer_names = { | |
*V.graph.graph_inputs.keys(), | |
*V.graph.constants.keys(), | |
} | |
self.nodes = [self.create_scheduler_node(n) for n in nodes] | |
# some new constants could have been created above | |
self.available_buffer_names.update(V.graph.constants.keys()) | |
for node in self.nodes: | |
node.prune_deps() | |
self.name_to_node: Dict[str, BaseSchedulerNode] = { | |
n.get_name(): n for n in self.nodes | |
} | |
self.name_to_fused_node: Dict[ | |
str, BaseSchedulerNode | |
] = dict() # set in fuse_nodes() | |
# mutation_real_name: Maps back to the original name for codegen | |
# Example: | |
# If you mutate buf0 inside of buf1's kernel, then: | |
# mutation_real_name = {"buf0" : "buf1"} | |
# all subsequent uses of buf0 become buf1's usage in dependency graph | |
self.mutation_real_name = {} | |
# We handle mutation by renaming modified versions of the same | |
# buffer in the dependency graph to prevent cycles. | |
# mutation_renames: tracks the current name for a given buffer | |
# (changed once per mutation) | |
# Example: | |
# If you mutate buf0 inside of buf1's kernel, then: | |
# mutation_renames = {"buf1" : "buf0"} | |
# in codegen we only use buf0, never buf1 | |
self.mutation_renames = {} | |
self.compute_dependencies() | |
self.topological_sort_schedule() | |
self.dead_node_elimination() | |
if config.reorder_for_compute_comm_overlap: | |
comms.decide_global_ordering_of_comms(self.nodes) | |
self.compute_ancestors() | |
metrics.ir_nodes_pre_fusion += len(self.nodes) | |
V.debug.ir_pre_fusion(self.nodes) | |
self.num_orig_nodes = len(self.nodes) | |
self.name_to_fused_node = {n.get_name(): n for n in self.nodes} | |
self.create_foreach_nodes() | |
self.topological_sort_schedule() | |
self.logged_slow_fusion = set() | |
self.fuse_nodes() | |
if config.reorder_for_compute_comm_overlap: | |
# Refresh node_users and inverse_users to reflect fused nodes | |
self.compute_node_users() | |
self.nodes = comms.reorder_compute_and_comm_for_overlap(self.nodes) | |
self.compute_last_usage() | |
V.debug.ir_post_fusion(self.nodes) | |
V.debug.graph_diagram(self.nodes) | |
self.debug_draw_graph() | |
# used during codegen: | |
self.current_device: torch.device = None # type: ignore[assignment] | |
self.buffer_names_to_free = set() | |
# fx graph node to the position it appears in the graph | |
# for debug attribution | |
self.origin_to_index = {} | |
get_metric_table("graph_stats").add_row( | |
lambda: { | |
"graph_id": self.post_grad_graph_id, | |
"num_nodes_before_fusion": self.num_orig_nodes, | |
"num_nodes_after_fusion": len(self.nodes), | |
} | |
) | |
def debug_draw_graph(self): | |
"""Generate an image of the graph for debugging""" | |
if os.environ.get("INDUCTOR_WRITE_SCHEDULER_GRAPH", None) == "1": | |
from .debug import draw_buffers | |
draw_buffers(self.nodes, print_graph=True) | |
def debug_print_nodes(self, label): | |
if log.isEnabledFor(logging.INFO): | |
log.info("%s:", label) | |
for node in self.nodes: | |
node.log_details() | |
def create_scheduler_node(self, node): | |
assert ( | |
node.origins is not None | |
), "All nodes passed to scheduling must have an origin" | |
if node.is_no_op(): | |
return NopKernelSchedulerNode(self, node) | |
elif isinstance(node, (ir.ComputedBuffer, ir.TemplateBuffer)): | |
return SchedulerNode(self, node) | |
elif isinstance(node, ir.ExternKernel): | |
return ExternKernelSchedulerNode(self, node) | |
else: | |
raise NotImplementedError(node) | |
def create_foreach_nodes(self): | |
removed_node_names = set() | |
fe_nodes = [] | |
kept_node_names = self.name_to_fused_node.keys() | |
for names in V.graph.lists.values(): | |
names = [ | |
name | |
for name in names | |
if name in kept_node_names | |
and not isinstance(self.name_to_node[name], NopKernelSchedulerNode) | |
] | |
if not names: | |
# All nodes eliminated | |
continue | |
removed_node_names.update(names) | |
snodes = [self.name_to_node[name] for name in names] | |
fe_node = ForeachKernelSchedulerNode(self, snodes) # type: ignore[arg-type] | |
fe_nodes.append(fe_node) | |
for name in names: | |
self.name_to_fused_node[name] = fe_node | |
self.nodes = [ | |
node for node in self.nodes if node.get_name() not in removed_node_names | |
] + fe_nodes | |
def compute_dependencies(self): | |
""" | |
Create dependency edges between nodes, handling aliasing and | |
mutation properly. | |
""" | |
T = TypeVar("T") | |
class DedupList(Generic[T]): | |
""" | |
This data structure behaves like a list except it makes sure the | |
elements remain unique. | |
Normally one could use a set/dict for this purpose however | |
the list in question gets elements appended as it is being | |
iterated over which means that we need to keep the list | |
semantics. | |
""" | |
def __init__(self, items=None, membership=None): | |
self.items = items or list() | |
self.membership = membership or set() | |
def append(self, node_user: T) -> None: | |
if node_user in self.membership: | |
return | |
self.items.append(node_user) | |
self.membership.add(node_user) | |
def __add__(self, other: "DedupList[T]") -> "DedupList[T]": | |
new_membership = set.union(self.membership, other.membership) | |
new_items = self.items + [ | |
x for x in other.items if x not in self.membership | |
] | |
return DedupList(new_items, new_membership) | |
name_to_users: DefaultDict[str, DedupList[NodeUser]] = collections.defaultdict( | |
DedupList | |
) | |
# handle aliasing by using python aliasing in name_to_users | |
# if foo aliases bar then we will make name_to_users["foo"] point | |
# to the same python list as name_to_users["bar"] | |
for node1 in self.nodes: | |
node1_name = node1.get_name() | |
for node2_name in node1.get_aliases(): | |
if node1_name in name_to_users and node2_name in name_to_users: | |
# merge the two | |
list1 = name_to_users[node1_name] | |
list2 = name_to_users[node2_name] | |
combined = list1 + list2 | |
for key in name_to_users.keys(): | |
if name_to_users[key] is list1 or name_to_users[key] is list2: | |
name_to_users[key] = combined | |
elif node1_name in name_to_users: | |
name_to_users[node2_name] = name_to_users[node1_name] | |
else: | |
name_to_users[node1_name] = name_to_users[node2_name] | |
def rename(n): | |
if n in self.mutation_renames: | |
return rename(self.mutation_renames[n]) | |
return n | |
def dep_closure(node_name): | |
reachable_names = {node_name} | |
node = self.name_to_node[node_name] | |
write_dep = next(iter(node.read_writes.writes)) | |
for read_dep in node.read_writes.reads: | |
if ( | |
read_dep.name in self.name_to_node | |
and isinstance(read_dep, dependencies.MemoryDep) | |
and isinstance(write_dep, dependencies.MemoryDep) | |
and read_dep.index == write_dep.index | |
and read_dep.size == write_dep.size | |
): | |
reachable_names.update(dep_closure(read_dep.name)) | |
return reachable_names | |
def add_user(used_by_name, user_node, can_inplace=False, is_weak=False): | |
name_to_users[rename(used_by_name)].append( | |
NodeUser(user_node, can_inplace, is_weak) | |
) | |
unbacked_symbol_to_origin_node = {} | |
for node in self.nodes: | |
log.debug("scheduling %s", node.node) | |
# unbacked symbols don't follow ordinary buffer dependencies, so | |
# we track their def/uses separately | |
unbacked_symbol_defs = sorted( | |
node.node.get_unbacked_symbol_defs(), key=lambda x: x.name | |
) | |
for s in unbacked_symbol_defs: | |
assert isinstance(s, sympy.Symbol) | |
# Pick the first definer as canonical. There may be multiple | |
# because if a MultiOutputLayout buffer propagates an unbacked | |
# symint to multiple outputs, they will all claim to def it. | |
if s not in unbacked_symbol_to_origin_node: | |
unbacked_symbol_to_origin_node[s] = node | |
unbacked_symbol_uses = sorted( | |
node.node.get_unbacked_symbol_uses(), key=lambda x: x.name | |
) | |
# if a kernel takes unbacked symints, register dependencies | |
for s in unbacked_symbol_uses: | |
assert ( | |
s in unbacked_symbol_to_origin_node | |
), f"{s} not in {unbacked_symbol_to_origin_node}" | |
node.add_fake_dep(StarDep(unbacked_symbol_to_origin_node[s].get_name())) | |
# a node will mutate either 0 or 1 buffers | |
assert len(node.get_mutations()) <= 1 | |
for alt_name in node.get_mutations(): | |
alt_name = rename(alt_name) | |
# this node must run after the prior writer | |
add_user(alt_name, node) | |
node.add_mutation_dep(StarDep(alt_name)) | |
for other_node in name_to_users[alt_name].items: | |
# this node must run after all prior readers | |
other_name = rename(other_node.get_name()) | |
known_dep_node_names = dep_closure(node.get_name()) | |
if other_name not in known_dep_node_names: | |
# If this node already directly or indirectly depends on other_node, | |
# we don't need to insert an extra dep. | |
node.add_mutation_dep(WeakDep(other_name)) | |
add_user(other_name, node, is_weak=True) | |
# add normal non-mutation dependencies | |
for read in node.read_writes.reads: | |
is_weak = isinstance(read, WeakDep) | |
add_user(read.name, node, node.can_inplace(read), is_weak) | |
node.update_mutated_names(self.mutation_renames) | |
# update our renaming scheme for the next iteration | |
for alt_name in node.get_mutations(): | |
self.mutation_renames[rename(alt_name)] = node.get_name() | |
self.mutation_renames[alt_name] = node.get_name() | |
self.mutation_real_name[node.get_name()] = self.mutation_real_name.get( | |
alt_name, alt_name | |
) | |
# make sure outputs aren't dead-code-eliminated | |
for node_name in V.graph.get_output_names(): | |
log.debug("scheduling output %s", node_name) | |
add_user(node_name, OutputNode(StarDep(node_name))) | |
# make sure unbacked symints aren't dead-code-eliminated | |
for node in V.graph.graph_outputs: | |
for s in node.get_unbacked_symbol_uses(): | |
assert ( | |
s in unbacked_symbol_to_origin_node | |
), f"{s} not in {unbacked_symbol_to_origin_node.keys()}" | |
node_name = unbacked_symbol_to_origin_node[s].node.name | |
log.debug("scheduling output %s for unbacked symint %s", node_name, s) | |
add_user(node_name, OutputNode(StarDep(node_name))) | |
# make sure input mutation isn't dead-code-eliminated | |
for name in self.mutation_renames: | |
if name in V.graph.graph_inputs: | |
add_user(name, OutputNode(StarDep(name))) | |
V.graph.mutated_inputs.add(name) | |
inp_names = { | |
name: index for index, name in enumerate(V.graph.graph_inputs.keys()) | |
} | |
V.graph.mutated_input_idxs = [ | |
inp_names[name] for name in V.graph.mutated_inputs | |
] | |
# copy users information onto the nodes | |
for node in self.nodes: | |
node.set_users(name_to_users[node.get_name()].items) | |
# populate inverse_users | |
for node in self.nodes: | |
for user in node.users: | |
user.node.inverse_users.append(node) | |
def compute_node_users(self): | |
# set up buffer name to (fused)snode mapping | |
buf_to_snode = {} | |
for node in self.nodes: | |
if isinstance(node, FusedSchedulerNode): | |
for x in node.snodes: | |
buf_to_snode[x.get_name()] = node | |
buf_to_snode[node.get_name()] = node | |
for node in self.nodes: | |
node.node_users = [] | |
node.inverse_users = [] | |
# compute inverse_users | |
for node in self.nodes: | |
inverse_users = [] | |
for dep in node.unmet_dependencies: | |
assert dep.name in buf_to_snode | |
dep_node = buf_to_snode[dep.name] | |
inverse_users.append(dep_node) | |
node.inverse_users = inverse_users | |
# compute node_users | |
# TODO: ideally, we should deduplicate .users and .node_users, | |
# but currently .users contains extra information that's difficult to | |
# extract into a standalone container. | |
node_to_users: Dict[BaseSchedulerNode, List[BaseSchedulerNode]] = {} | |
for node in self.nodes: | |
for inverse_user in node.inverse_users: | |
node_to_users.setdefault(inverse_user, []).append(node) | |
for node, users in node_to_users.items(): | |
node.node_users = users | |
def dead_node_elimination(self): | |
""" | |
Remove any nodes without users | |
""" | |
again = True # repeat until a fixed point | |
while again: | |
updated_nodes = [] | |
for node in self.nodes: | |
def can_eliminate_user(user: NodeUser): | |
return user.is_weak or user.get_name() in V.graph.removed_buffers | |
can_eliminate = not node.has_side_effects() and all( | |
can_eliminate_user(u) for u in node.users | |
) | |
if not can_eliminate: | |
updated_nodes.append(node) | |
else: | |
# dead code | |
log.debug("removed dead node: %s", node.get_name()) | |
V.graph.removed_buffers.add(node.get_name()) | |
again = len(self.nodes) > len(updated_nodes) | |
self.nodes = updated_nodes | |
# Prune any WeakDeps no longer needed | |
for node in self.nodes: | |
node.prune_weak_deps() | |
def topological_sort_schedule(self): | |
""" | |
Ensure self.nodes is in topologically sorted order | |
""" | |
seen: Set[ir.Buffer] = set() | |
name_to_node: Dict[str, ir.Buffer] = dict() | |
result: List[ir.Buffer] = [] | |
def visit(n): | |
if n not in seen: | |
seen.add(n) | |
for dep in sorted(n.unmet_dependencies, key=lambda d: d.name): | |
visit(name_to_node[dep.name]) | |
result.append(n) | |
for node in self.nodes: | |
for name in node.get_names(): | |
name_to_node[name] = node | |
for node in self.nodes: | |
visit(node) | |
self.nodes = result | |
def compute_ancestors(self): | |
""" | |
Populate each node.ancestors | |
""" | |
# note self.nodes is topologically sorted | |
name_to_ancestors: Dict[str, Set[str]] = {} | |
for node in self.nodes: | |
ancestors = set() | |
for dep in node.unmet_dependencies: | |
ancestors.add(dep.name) | |
ancestors |= name_to_ancestors[dep.name] | |
name_to_ancestors[node.get_name()] = ancestors | |
node.ancestors = ancestors | |
for order, node in enumerate(self.nodes): | |
node.min_order = order | |
node.max_order = order | |
def fuse_nodes(self): | |
""" | |
Mutates self.nodes to combine nodes into FusedSchedulerNodes. | |
""" | |
for i in range(10): | |
old_len = len(self.nodes) | |
fusion_log.debug( | |
"===== attempting fusion (%d/10): %d nodes =====", i + 1, old_len | |
) | |
self.fuse_nodes_once() | |
new_len = len(self.nodes) | |
fusion_log.debug( | |
"completed fusion round (%d/10): fused %d nodes into %d nodes\n", | |
i + 1, | |
old_len, | |
new_len, | |
) | |
if new_len == old_len or new_len == 1: | |
fusion_log.debug("===== fusion complete (%d iterations) =====", i + 1) | |
break | |
def benchmark_fused_nodes(self, nodes): | |
""" | |
Benchmark fused list of nodes and return the execution time | |
in milliseconds on randomly generated inputs. | |
""" | |
assert len(nodes) > 0 | |
device = nodes[0].get_device() | |
V.graph.scheduler = self | |
self.current_device = device | |
backend = self.get_backend(device) | |
return backend.benchmark_fused_nodes(nodes) | |
def speedup_by_fusion(self, node1, node2): | |
""" | |
If config.benchmark_fusion is False, always return True. | |
Otherwise, return True if fusion can brings speedup. | |
""" | |
if not config.benchmark_fusion: | |
return True | |
if ( | |
node1.is_template() | |
and not isinstance(node1.get_template_node(), ir.TritonTemplateBuffer) | |
or node1.is_foreach() | |
or node2.is_foreach() | |
): | |
# TODO support benchmarking epilogue fusion | |
return True | |
node_list_1 = node1.get_nodes() | |
device = node_list_1[0].get_device() | |
# don't support benchmark fusion for CPU right now. | |
if device.type == "cpu": | |
return True | |
node_list_2 = node2.get_nodes() | |
node_list_fused = node_list_1 + node_list_2 | |
        # We cannot accurately benchmark kernels that use atomic_add due to how we
        # generate random integer inputs.
        # Skip benchmarking them by always allowing fusion.
if any( | |
hasattr(n.node, "data") | |
and hasattr(n.node.data, "scatter_mode") | |
and n.node.data.scatter_mode == "atomic_add" | |
for n in node_list_fused | |
): | |
return True | |
from triton.compiler.errors import CompilationError | |
why = WhyNoFuse(node1, node2) | |
try: | |
ms1, path1 = self.benchmark_fused_nodes(node_list_1) | |
if math.isinf(ms1): | |
why("register spilling of the first kernel") | |
return False | |
ms2, path2 = self.benchmark_fused_nodes(node_list_2) | |
if math.isinf(ms2): | |
why("register spilling of the second kernel") | |
return False | |
ms_fused, path_fused = self.benchmark_fused_nodes(node_list_fused) | |
if math.isinf(ms_fused): | |
why("register spilling of the fused kernel") | |
return False | |
except CompilationError as e: | |
# workaround triton issue: https://github.com/openai/triton/issues/2151 | |
if "Loop-carried variable" in str(e): | |
return True # allow fusion | |
else: | |
raise | |
if fusion_log.isEnabledFor(logging.DEBUG): | |
if ms_fused < ms1 + ms2: | |
fusion_log.debug( | |
"can fuse (benchmark): fusing %s with %s cause %sx speedup", | |
node1.get_names(), | |
node2.get_names(), | |
green_text(f"{(ms1 + ms2) / ms_fused:.3f}"), | |
) | |
else: | |
fusion_log.debug( | |
"cannot fuse (benchmark): fusing %s with %s cause %sx slowdown", | |
node1.get_names(), | |
node2.get_names(), | |
red_text(f"{ms_fused / (ms1 + ms2):.3f}"), | |
) | |
if ( | |
is_metric_table_enabled("slow_fusion") | |
and ms_fused >= ms1 + ms2 | |
and (path1, path2) not in self.logged_slow_fusion | |
): | |
self.logged_slow_fusion.add((path1, path2)) | |
get_metric_table("slow_fusion").add_row( | |
lambda: { | |
"kernel1_path": path1, | |
"kernel1_latency": ms1, | |
"kernel2_path": path2, | |
"kernel2_latency": ms2, | |
"fused_kernel_path": path_fused, | |
"fused_kernel_latency": ms_fused, | |
"slow_down_ratio": ms_fused / (ms1 + ms2), | |
} | |
) | |
return ms_fused < ms1 + ms2 | |
def fuse_nodes_once(self): | |
""" | |
Mutates self.nodes to combine nodes into FusedSchedulerNodes. | |
This relies on two key functions to control the logic: | |
- self.can_fuse(): checks if a fusion is legal | |
- self.score_fusion(): assigns priority to a given fusion | |
""" | |
fused_nodes = set(self.nodes) | |
for node1, node2 in self.get_possible_fusions(): | |
node1 = self.name_to_fused_node[node1.get_first_name()] | |
node2 = self.name_to_fused_node[node2.get_first_name()] | |
if self.can_fuse(node1, node2) and not self.will_fusion_create_cycle( | |
node1, node2 | |
): | |
if not self.speedup_by_fusion(node1, node2): | |
continue | |
fusion_log.debug( | |
"fusing %s with %s", node1.get_name(), node2.get_name() | |
) | |
# above can_fuse asserts that node2 has the same device | |
device = node1.get_device() | |
node3 = self.get_backend(device).fuse(node1, node2) | |
fused_nodes.remove(node1) | |
fused_nodes.remove(node2) | |
fused_nodes.add(node3) | |
self.name_to_fused_node.update( | |
{n.get_name(): node3 for n in node3.get_nodes()} | |
) | |
self.nodes = sorted(fused_nodes, key=lambda x: x.min_order) | |
self.topological_sort_schedule() | |
self.prune_redundant_deps() | |
def prune_redundant_deps(self): | |
for node in self.nodes: | |
node.prune_redundant_deps(self.name_to_fused_node) | |
def get_possible_fusions(self): | |
""" | |
Helper to find all legal fusion opportunities, sorted by self.score_fusion() | |
""" | |
possible_fusions = [] | |
seen = set() | |
def check_all_pairs(nodes): | |
for node1_index, node1 in enumerate(nodes): | |
for node2 in nodes[node1_index + 1 :]: | |
key = (node1, node2) | |
if key in seen: | |
continue | |
seen.add(key) | |
if self.can_fuse(node1, node2): | |
possible_fusions.append(key) | |
elif (node2.is_template() or node2.is_foreach()) and self.can_fuse( | |
node2, node1 | |
): | |
# foreach fusions and epilogue fusions are order dependent | |
possible_fusions.append((node2, node1)) | |
buffer_names_grouping = collections.defaultdict(list) | |
for node in self.nodes: | |
for buf in node.used_buffer_names(): | |
buffer_names_grouping[buf].append(node) | |
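# Illustrative example (hypothetical names): if buf0 is used by node1 and node3
# while buf1 is used by node2 and node3, the grouping is
# {"buf0": [node1, node3], "buf1": [node2, node3]}, so only pairs that share at
# least one buffer are ever checked.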
for node_grouping in buffer_names_grouping.values(): | |
check_all_pairs(node_grouping) | |
if config.aggressive_fusion: | |
group_grouping = collections.defaultdict(list) | |
for node in self.nodes: | |
group = getattr(node, "group", None) | |
if group: | |
group_grouping[group].append(node) | |
for node_grouping in group_grouping.values(): | |
check_all_pairs(node_grouping) | |
possible_fusions.sort(key=self.score_fusion_key, reverse=True) | |
fusion_log.debug("found %d possible fusions", len(possible_fusions)) | |
return possible_fusions | |
def will_fusion_create_cycle(self, node1, node2): | |
""" | |
Determine whether there is a path from node1 to node2 (or vice-versa)
introduced indirectly by other fusions.
""" | |
def found_path(node): | |
# only fused nodes can introduce new ancestors. | |
if isinstance(node, FusedSchedulerNode) and node not in visited: | |
visited.add(node) | |
if node.get_names().issubset(combined_ancestors): | |
# All fusion outputs are in ancestors of node1 and node2, thus | |
# cannot introduce new path: | |
# | |
# 1. if an output is a descendant of neither node1 nor node2, the
# output cannot introduce a path
# 2. due to [can_fuse]: if WLOG an output is a descendant of node1, it cannot be
# on path(node1->node2), hence it cannot be an ancestor of node2
# 3. due to [acyclic]: if WLOG an output is a descendant of node1, it cannot be
# an ancestor of node1
return False | |
else: | |
# continue DFS of new ancestors introduced by the fusion | |
return bool(combined_names & node.ancestors) or any( | |
found_path(self.name_to_fused_node[n]) | |
for n in node.ancestors - combined_ancestors | |
) | |
return False | |
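# Illustrative example (hypothetical graph): suppose node1 -> node3 -> node2, where
# node3 is an already-fused node that is not part of this fusion. Fusing node1 with
# node2 would force node3 to run both after node1 and before node2's half of the
# fused kernel, i.e. a cycle, so found_path reports a path and the fusion is rejected.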
visited = set() | |
combined_names = node1.get_names() | node2.get_names() | |
combined_ancestors = (node1.ancestors | node2.ancestors) - combined_names | |
cycle = any(found_path(self.name_to_fused_node[n]) for n in combined_ancestors) | |
if cycle: | |
WhyNoFuse(node1, node2)("will create cycle") | |
return cycle | |
def can_fusion_increase_peak_memory( | |
self, node1: BaseSchedulerNode, node2: BaseSchedulerNode | |
): | |
""" | |
This function prevents fusion for nodes that can increase memory | |
footprint. This problem is more common in horizontal fusion, where nodes | |
that are far apart in the original order get fused, lengthening the live | |
intervals of tensors. This is very evident in models with activation | |
checkpointing, where the recomputed nodes from different checkpointed | |
regions get fused and significantly increase the memory footprint. | |
The current attempt is a quick, possibly hacky, heuristic to prevent the | |
fusion of nodes that are far away in the original order. | |
A better, but harder to implement, heuristic would be to use the live
intervals of the buffers, find the region of peak memory pressure in the
original program, and prevent fusions that cross that peak region. We might need
special care or a good approximation in this implementation, as fusing
nodes changes live intervals, and re-computing live intervals and peak
memory after each fusion can introduce large compilation overhead.
""" | |
proximity_score = max( | |
abs(node1.min_order - node2.max_order), | |
abs(node2.min_order - node1.max_order), | |
) | |
return proximity_score > 64 | |
def can_fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode): | |
""" | |
Determine if it is possible to combine node1 and node2 into a | |
single fused node. | |
""" | |
if node1 is node2: | |
return False | |
why = WhyNoFuse(node1, node2) | |
if ( | |
isinstance(node1, (ExternKernelSchedulerNode, NopKernelSchedulerNode)) | |
and not node1.is_template() | |
): | |
why("node1 is extern or nop") | |
return False | |
if ( | |
isinstance(node2, (ExternKernelSchedulerNode, NopKernelSchedulerNode)) | |
and not node2.is_template() | |
): | |
why("node2 is extern or nop") | |
return False | |
if node2.get_names() & node1.ancestors: | |
why("node1 must go before node2") | |
return False | |
if ( | |
isinstance(node1, (FusedSchedulerNode, SchedulerNode)) | |
and isinstance(node2, SchedulerNode) | |
and isinstance(node2._body, ir.LoopBody) | |
): | |
# Fixes https://github.com/pytorch/pytorch/issues/108963
# Check:
# If node2 reads a buffer that is a mutation buffer of node1 (a SchedulerNode)
# or of one of the nodes inside node1 (a FusedSchedulerNode), look up the
# corresponding mutation buffer and check whether it is stored in atomic_add mode.
# If so, disable the fusion of node1 and node2.
if any( | |
( | |
node2_used_buf in self.mutation_renames | |
and node1.has_atomic_add(self.mutation_renames[node2_used_buf]) | |
) | |
for node2_used_buf in node2._body.reads_name2expr.keys() | |
): | |
return False | |
if node2.is_template(): | |
why("templates can only fuse epilogues") | |
return False | |
if node1.is_template() and ( | |
node2.has_aliasing_or_mutation() | |
or node2.is_reduction() | |
or not config.epilogue_fusion | |
): | |
why("template epilogue not satisfied") | |
return False | |
device = node1.get_device() | |
device2 = node2.get_device() | |
if device != device2: | |
why("device mismatch (%s vs %s)", device, device2) | |
return False | |
del device2 | |
no_shared_data = self.score_fusion_memory(node1, node2) == 0 | |
if no_shared_data and ( | |
not config.aggressive_fusion or node1.is_reduction() or node2.is_reduction() | |
): | |
why("no shared data") | |
return False # heuristic not needed for correctness | |
if ( | |
not node1.is_foreach() | |
and not node2.is_foreach() | |
and len(node1.get_nodes()) + len(node2.get_nodes()) > config.max_fusion_size | |
): | |
why("exceeds max fusion") | |
return False # heuristic not needed for correctness | |
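# Illustrative example (hypothetical value): if config.max_fusion_size were 64,
# fusing a 40-node group with a 30-node group would be rejected by the check above
# even when otherwise legal, purely as a kernel-size / compile-time heuristic.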
if node1.get_names() & node2.ancestors: | |
# node2 depends on node1 outputs | |
if not self.can_fuse_vertical(node1, node2): | |
return False | |
return self.get_backend(device).can_fuse_vertical(node1, node2) | |
else: # nodes don't depend on each other, but may have common reads | |
if self.can_fusion_increase_peak_memory(node1, node2): | |
why("will increase peak memory") | |
return False | |
return self.get_backend(device).can_fuse_horizontal(node1, node2) | |
def can_fuse_vertical(self, node1, node2): | |
""" | |
Check if it is legal to fuse a consumer (node2) into a producer (node1). | |
We can fuse them if all the reads of node2 either match | |
corresponding writes in node1, or are written by nodes that can | |
be scheduled before the fusion of node1 and node2. | |
We also disable fusion of a write subsequent to a read if the reads | |
and writes do not align. | |
""" | |
node1_names = node1.get_names() | |
computed_deps = set() | |
why = WhyNoFuse(node1, node2) | |
# A StarDep doesn't match a MemoryDep, and different indices don't match.
# However, broadcasting sometimes strips dimensions, and if that's the case
# we can still match the unmet dep.
# If there's indirect indexing, don't match it.
def fusable_read_and_write(read: Dep, write: Dep): | |
return ( | |
self.mutation_renames.get(read.name, read.name) == write.name | |
and (isinstance(read, MemoryDep) and isinstance(write, MemoryDep)) | |
and not free_symbol_has(read.index, "tmp") | |
and not free_symbol_has(write.index, "tmp") | |
and read.index == write.index | |
and len(read.size) >= len(write.size) | |
and read.size[: len(write.size)] == write.size | |
) | |
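# Illustrative example (hypothetical deps): a read MemoryDep("buf0", index=c0,
# size=(s0, s1)) against a write MemoryDep("buf0", index=c0, size=(s0,)) is still
# considered fusable, since the read's leading sizes match the write's sizes (a
# broadcast stripped a dimension), whereas MemoryDep("buf0", c0) vs
# MemoryDep("buf0", c0 + 1) is not, because the indices differ.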
for rd in node2.unmet_dependencies: | |
for cd in node1.read_writes.writes: | |
if fusable_read_and_write(rd, cd): | |
computed_deps.add(rd) | |
remaining_deps = {dep.name for dep in node2.unmet_dependencies - computed_deps} | |
if remaining_deps & node1_names: | |
# MemoryDeps didn't match and read different locations of the same buffer. | |
# Examples here include: | |
# - MemoryDep("foo", x) != MemoryDep("foo", x + 1) | |
# - MemoryDep("foo", x) != StarDep("foo") | |
why("memory deps did not match") | |
return False | |
for name in remaining_deps: | |
if node1_names & self.name_to_fused_node[name].ancestors: | |
why("intermediate nodes between node1 & node2") | |
return False | |
# Similar to can_inplace: if we are going to fuse a write subsequent to a read,
# require that the indexing and size are the same.
for write in node2.read_writes.writes: | |
for read in node1.read_writes.reads: | |
if write.name != self.mutation_renames.get(read.name, read.name): | |
continue | |
# bail on StarDep | |
if not fusable_read_and_write(read=read, write=write): | |
why("fusing a write into a read with different indexing formula") | |
return False | |
return True | |
def score_fusion(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode): | |
""" | |
Assign a score (higher comes first) to the fusion of node1 | |
and node2. When different fusions conflict with each other, | |
this is the way we decide what order to run them in. | |
Our current score is based on: | |
- Estimate of the saved memory operations | |
- Fusions closer together in original order | |
""" | |
memory_score = self.score_fusion_memory(node1, node2) | |
proximity_score = -max( | |
abs(node1.min_order - node2.max_order), | |
abs(node2.min_order - node1.max_order), | |
) | |
return ( | |
node1.is_template() == config.epilogue_fusion_first and memory_score > 0, | |
node1.is_reduction() == node2.is_reduction() and memory_score > 0, | |
memory_score, | |
proximity_score, | |
) | |
def score_fusion_memory(self, node1, node2): | |
""" | |
The first term in our fusion score; it estimates the number of saved memory operations.
""" | |
common_memory_deps = (node1.read_writes.reads | node1.read_writes.writes) & ( | |
node2.read_writes.reads | node2.read_writes.writes | |
) | |
common_memory_deps = { | |
dep for dep in common_memory_deps if not dep.has_unbacked_symbols() | |
} | |
return sum(dep.numbytes_hint() for dep in common_memory_deps) | |
def score_fusion_key(self, nodes): | |
""" | |
Shim for list.sort(key=...) | |
""" | |
node1, node2 = nodes | |
return self.score_fusion(node1, node2) | |
def compute_last_usage(self): | |
""" | |
Populate node.last_usage recursively (also for the nodes within a FusedSchedulerNode) | |
""" | |
future_used_buffers = set() | |
for node_name in V.graph.get_output_names(): | |
future_used_buffers.add(node_name) | |
for node in reversed(self.nodes): | |
node.set_last_usage(future_used_buffers, self.mutation_real_name) | |
future_used_buffers.update(node.last_usage) | |
def free_buffers(self): | |
"""Free any buffers that are no longer needed""" | |
for name in sorted( | |
self.buffer_names_to_free | |
- V.graph.removed_buffers | |
- V.graph.wrapper_code.freed | |
): | |
if name in self.name_to_node: | |
node = self.name_to_node[name] | |
if node.can_free(): | |
V.graph.wrapper_code.codegen_free(node.node) | |
elif name in V.graph.graph_inputs: | |
storage = V.graph.graph_inputs[name].data | |
assert isinstance(storage, ir.StorageBox) and storage.is_input_buffer() | |
V.graph.wrapper_code.codegen_free(storage.data) | |
self.buffer_names_to_free.clear() | |
def remove_kernel_local_buffers(self): | |
""" | |
Any buffers that are both created and have a last use in the | |
same kernel can be removed. | |
""" | |
# V.kernel.store_buffer_names should represent the set of nodes
# that get fused into this kernel.
fused_node_names = V.kernel.store_buffer_names | |
names_to_remove = [] | |
for out_buf in V.kernel.store_buffer_names: | |
users = self.name_to_node[out_buf].users | |
assert users is not None | |
users = {user.get_name() for user in users if not user.is_weak} | |
if users.issubset(fused_node_names): | |
names_to_remove.append(out_buf) | |
def remove_filter(n): | |
return ( | |
n not in V.kernel.must_keep_buffers | |
and n not in V.kernel.args.input_buffers | |
and n not in self.mutation_renames | |
and n not in self.mutation_real_name | |
) | |
names_to_remove = list(filter(remove_filter, names_to_remove)) | |
for name in names_to_remove: | |
if name in V.kernel.args.inplace_buffers: | |
buf = V.kernel.args.inplace_buffers[name] | |
if isinstance(buf, str) and buf.startswith("REMOVED"): | |
continue | |
remove = all(n in names_to_remove for n in buf.other_names) | |
if remove: | |
self.remove_inplace_buffer(name) | |
V.kernel.inplaced_to_remove.add(name) | |
else: | |
self.remove_buffer(name) | |
def remove_buffer(self, name): | |
# Assign a special value instead of deleting the entry
# because we still rely on the length of output_buffers to
# generate unique arg names.
log.debug("remove_buffer(%r)", name) | |
V.kernel.args.output_buffers[name] = "REMOVED" | |
V.kernel.removed_buffers.add(name) | |
def remove_inplace_buffer(self, name): | |
log.debug("removing_inplace_buffer(%r)", name) | |
inner_name = V.kernel.args.inplace_buffers[name].inner_name | |
V.kernel.args.inplace_buffers[name] = inner_name.replace( | |
"in_out_ptr", "REMOVED" | |
) | |
V.kernel.removed_buffers.add(name) | |
def flush(self): | |
for backend in self.backends.values(): | |
backend.flush() | |
self.free_buffers() | |
def codegen_extern_call(self, scheduler_node: ExternKernelSchedulerNode): | |
assert isinstance(scheduler_node, ExternKernelSchedulerNode) | |
# 'decide_inplace_update' stores the inplace update decisions in
# the current kernel, from where 'allocate' retrieves those decisions.
# We have to make sure there is a non-null kernel handler to store
# those inplace update decisions.
with V.set_kernel_handler(Kernel(increase_kernel_count=False)): | |
scheduler_node.decide_inplace_update() | |
scheduler_node.allocate() | |
node = scheduler_node.node | |
assert isinstance(node, ir.ExternKernel), f"{type(node)=}" | |
node.codegen(V.graph.wrapper_code) | |
self.free_buffers() | |
def create_backend(self, device: torch.device): | |
assert ( | |
device.type != "cuda" or device.index is not None | |
), f"{device} should have been normalized in lowering" | |
V.graph.add_device_info(device) | |
device_scheduling = get_scheduling_for_device(device.type) | |
if device_scheduling is None: | |
raise RuntimeError(f"Unsupported device type: {device.type}") | |
if device.type == "cuda" and not has_triton(): | |
device_props = torch.cuda.get_device_properties(device) | |
if device_props.major < 7: | |
raise RuntimeError( | |
f"Found {device_props.name} which is too old to be supported by the triton GPU compiler, which is used as the backend. Triton only supports devices of CUDA Capability >= 7.0, but your device is of CUDA capability {device_props.major}.{device_props.minor}" # noqa: B950 | |
) | |
else: | |
raise RuntimeError( | |
"Cannot find a working triton installation. More information on installing Triton can be found at https://github.com/openai/triton" # noqa: B950 | |
) | |
return device_scheduling(self) | |
def get_backend(self, device: torch.device): | |
if device not in self.backends: | |
self.backends[device] = self.create_backend(device) | |
return self.backends[device] | |
def enter_context(self, node): | |
def get_order(n): | |
if n not in self.origin_to_index: | |
self.origin_to_index.update({n: i for i, n in enumerate(n.graph.nodes)}) | |
return self.origin_to_index[n] | |
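# Note: each origin is an FX node; get_order lazily builds an index over the
# origin's owning FX graph so the most recently defined origin can be picked below.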
# Use a dict to preserve insertion order
origins = { | |
(get_order(e), e): None for n in node.get_nodes() for e in n.node.origins | |
} | |
origins = list(origins.keys()) | |
if origins: | |
_, last = max(origins, key=operator.itemgetter(0)) | |
V.graph.wrapper_code.enter_context(last) | |
def codegen(self): | |
for node in self.nodes: | |
try: | |
log.debug( | |
"Generating code for node %s with estimated runtime %f", | |
node.get_name(), | |
node.get_estimated_runtime(), | |
) | |
except Exception:
log.debug( | |
"Generating code for node %s with estimated runtime 0.0", | |
node.get_name(), | |
) | |
self.enter_context(node) | |
if not isinstance(node, NopKernelSchedulerNode): | |
device = node.get_device() | |
if ( | |
device != self.current_device | |
or node.is_extern() | |
or node.is_template() | |
): | |
self.flush() | |
if device != self.current_device: | |
if device.type == "cuda": | |
if self.current_device and self.current_device.type == "cuda": | |
V.graph.wrapper_code.codegen_device_guard_exit() | |
assert device.index is not None, "device should have an index" | |
V.graph.wrapper_code.codegen_device_guard_enter(device.index) | |
elif self.current_device and self.current_device.type == "cuda": | |
V.graph.wrapper_code.codegen_device_guard_exit() | |
self.current_device = device | |
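# Illustrative sequence (hypothetical devices): cpu -> cuda:0 -> cuda:1 -> cpu
# emits enter(0), exit, enter(1), exit around the corresponding kernels, keeping
# the generated device guards balanced.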
self.buffer_names_to_free.update(node.last_usage) | |
if node.is_template(): | |
node, *epilogue = node.get_nodes() | |
self.get_backend(device).codegen_template(node, epilogue) # type: ignore[possibly-undefined] | |
elif node.is_extern(): | |
self.codegen_extern_call(node) | |
elif node.is_foreach(): | |
self.get_backend(device).codegen_foreach(node) # type: ignore[possibly-undefined] | |
elif isinstance(node, (FusedSchedulerNode, SchedulerNode)): | |
self.get_backend(device).codegen_nodes(node.get_nodes()) # type: ignore[possibly-undefined] | |
else: | |
assert isinstance(node, NopKernelSchedulerNode) | |
node.allocate() | |
if config.debug_check_inf_and_nan: | |
V.graph.wrapper_code.generate_inf_and_nan_checker(node) | |
if config.triton.debug_sync_kernel: | |
self.get_backend(device).codegen_sync() # type: ignore[possibly-undefined] | |
self.available_buffer_names.update(node.get_names()) | |
if not isinstance(node, NopKernelSchedulerNode): | |
device = node.get_device() | |
if self.get_backend(device).ready_to_flush(): | |
self.flush() | |
if self.current_device and self.current_device.type == "cuda": | |
# Exit the outermost CUDA device guard. This is
# important for nested indentation code generation.
V.graph.wrapper_code.codegen_device_guard_exit() | |
self.flush() | |
def is_unaligned_buffer(self, buf_name): | |
if buf_name in V.graph.graph_inputs or buf_name in V.graph.constants: | |
# All graph inputs and constants are assumed to be aligned.
return False | |
node = self.name_to_node[buf_name] | |
layout = node.node.get_layout() | |
if isinstance(layout, ir.AliasedLayout): | |
return not layout.maybe_guard_aligned() | |
else: | |
return False | |
class BaseScheduling: | |
def can_fuse_vertical(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode): | |
""" | |
Check whether node1 and node2 can be vertically fused or not. | |
""" | |
raise NotImplementedError() | |
def can_fuse_horizontal(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode): | |
""" | |
Check whether node1 and node2 can be horizontally fused or not. | |
""" | |
raise NotImplementedError() | |
def fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode): | |
""" | |
Fuse two nodes | |
""" | |
if node1.is_foreach() or node2.is_foreach(): | |
return ForeachKernelSchedulerNode.fuse(node1, node2) | |
else: | |
return FusedSchedulerNode.fuse(node1, node2) | |
def group_fn(self, sizes): | |
""" | |
Process the iteration sizes in case a transformation needs to be applied. | |
""" | |
raise NotImplementedError() | |
def codegen_template( | |
self, template_node: SchedulerNode, epilogue_nodes: List[SchedulerNode] | |
): | |
""" | |
Given a template node, generate a kernel. | |
This function is only available for Triton for now. If a third-party backend is a subclass
of TritonScheduling, it can override or reuse this function.
""" | |
raise NotImplementedError() | |
def codegen_nodes(self, nodes: List[SchedulerNode]): | |
""" | |
Generate a kernel given a list of pre-fused nodes. | |
""" | |
raise NotImplementedError() | |
def codegen_sync(self): | |
""" | |
Generate synchronization code for the kernel. This method depends on the hardware characteristics. | |
""" | |
raise NotImplementedError() | |
def ready_to_flush(self) -> bool: | |
""" | |
Check whether the backend is requesting the scheduler to flush the generated kernel. | |
If not supported, return False.
""" | |
return False | |
def flush(self): | |
""" | |
Flush the generated kernel and python wrapper code to the source code file. | |
""" | |
raise NotImplementedError() | |
def benchmark_fused_nodes(self, nodes): | |
""" | |
Benchmark a fused list of nodes and return the execution time
in milliseconds on randomly generated inputs. | |
""" | |
raise NotImplementedError() | |