# mypy: ignore-errors | |
import functools | |
import itertools | |
import math | |
import sys | |
from typing import Callable, Union | |
import torch | |
import torch._custom_op | |
import torch._logging | |
from torch._ops import OpOverload | |
from torch._prims_common import ( | |
elementwise_dtypes, | |
ELEMENTWISE_TYPE_PROMOTION_KIND, | |
is_boolean_dtype, | |
is_float_dtype, | |
is_integer_dtype, | |
) | |
from torch._subclasses.fake_tensor import ( | |
DataDependentOutputException, | |
DynamicOutputShapeException, | |
FakeTensor, | |
in_kernel_invocation_manager, | |
run_fallback_kernel, | |
UnsupportedOperatorException, | |
) | |
from torch.fx.operator_schemas import normalize_function | |
from torch.utils._stats import count_label | |
pytree = torch.utils._pytree | |
__all__ = [ | |
"op_implementations_checks", | |
"get_fast_op_impls", | |
"stride_incorrect_op", | |
"has_meta", | |
] | |
op_implementations_dict = {} | |
op_implementations_checks = [] | |
aten = torch._ops.ops.aten | |
def ordered_set(*items): | |
return dict.fromkeys(items, True) | |
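# For illustration: dict.fromkeys(items, True) yields an insertion-ordered mapping
# that doubles as a fast-membership set, e.g.
#   s = ordered_set("aten", "prims")
#   "aten" in s   # True
#   list(s)       # ["aten", "prims"], insertion order preserved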
# This function indicates if the backend device | |
# supports non-contiguous tensors | |
def is_noncontiguous_supported(device): | |
if device.type == "hpu": | |
return False | |
return True | |
_like_tensor_constructors = ordered_set( | |
aten.empty_like.default, | |
aten.empty_like.out, | |
aten.full_like.default, | |
aten.full_like.out, | |
aten.ones_like.default, | |
aten.ones_like.out, | |
aten.rand_like.default, | |
aten.rand_like.out, | |
aten.randn_like.default, | |
aten.randn_like.out, | |
aten.randint_like.default, | |
aten.randint_like.out, | |
aten.randint_like.low_dtype, | |
aten.randint_like.low_dtype_out, | |
aten.zeros_like.default, | |
aten.zeros_like.out, | |
aten.new_empty.default, | |
aten.new_empty.out, | |
aten.new_empty_strided.default, | |
aten.new_empty_strided.out, | |
aten.new_full.default, | |
aten.new_full.out, | |
aten.new_zeros.default, | |
aten.new_zeros.out, | |
aten.new_ones.default, | |
aten.new_ones.out, | |
) | |
_device_not_kwarg_ops = ordered_set( | |
aten._resize_output_.default, | |
aten._nested_tensor_from_tensor_list.default, | |
aten._nested_tensor_from_tensor_list.out, | |
aten.pin_memory.default, | |
aten.is_pinned.default, | |
aten.to.device, | |
aten.to.prim_Device, | |
aten._pin_memory.default, | |
aten._pin_memory.out, | |
aten._resize_output.default, | |
aten._resize_output.out, | |
) | |
# this op is never actually used | |
_non_kwarg_device_constructors = (aten._list_to_tensor,) | |
def contains_tensor_types(type): | |
tensor_type = torch._C.TensorType.get() | |
return type.isSubtypeOf(tensor_type) or any( | |
contains_tensor_types(e) for e in type.containedTypes() | |
) | |
def _is_tensor_constructor(func: OpOverload): | |
assert isinstance(func, OpOverload) | |
schema = func._schema | |
if any(contains_tensor_types(arg.type) for arg in schema.arguments): | |
return False | |
# TODO: no real reason to restrict multiple outputs | |
return ( | |
len(schema.returns) == 1 and schema.returns[0].type is torch._C.TensorType.get() | |
) | |
def register_op_impl(run_impl_check: Union[Callable[[OpOverload], bool], OpOverload]): | |
def impl_decorator(op_impl): | |
if isinstance(run_impl_check, OpOverload): | |
assert ( | |
run_impl_check not in op_implementations_dict | |
), f"duplicate registration: {run_impl_check}" | |
op_implementations_dict[run_impl_check] = op_impl | |
elif isinstance(run_impl_check, (list, tuple)): | |
for op in run_impl_check: | |
register_op_impl(op)(op_impl) | |
else: | |
assert callable(run_impl_check) | |
op_implementations_checks.append((run_impl_check, op_impl)) | |
return op_impl | |
return impl_decorator | |
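# Usage sketch (illustrative only; the decorated function names are hypothetical):
# register_op_impl accepts a concrete OpOverload, a list/tuple of overloads, or a
# predicate over OpOverloads that is evaluated at dispatch time, e.g.
#
#   @register_op_impl(aten.nonzero.default)               # key on one overload
#   def _my_nonzero_impl(fake_mode, func, *args, **kwargs): ...
#
#   @register_op_impl(lambda func: "fft" in func.name())  # key on a predicate
#   def _my_fft_impl(fake_mode, func, *args, **kwargs): ...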
def dispatch_to_op_implementations_dict(fake_mode, func, *args, **kwargs): | |
return op_implementations_dict[func](fake_mode, func, *args, **kwargs) | |
def constructors(fake_mode, func, *args, **kwargs): | |
assert func not in _non_kwarg_device_constructors | |
_, new_kwargs = normalize_function( | |
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True | |
) | |
if "names" in kwargs: | |
raise UnsupportedOperatorException( | |
"torch.compile doesn't support named tensors" | |
) | |
if func in _like_tensor_constructors: | |
default_device = new_kwargs["input"].device | |
# TODO: file issue | |
args = (new_kwargs.pop("input"),) | |
else: | |
# cpu is default device if none is specified | |
default_device = torch.device("cpu") | |
args = () | |
out_device = new_kwargs.pop("device", None) | |
out_device = out_device if out_device is not None else default_device | |
new_kwargs["device"] = torch.device("meta") | |
# _like constructors have fake tensor inputs (maybe this causes the non-like | |
# to fail? hmmm) | |
with in_kernel_invocation_manager(fake_mode): | |
r = func(*args, **new_kwargs) | |
return FakeTensor(fake_mode, r, out_device) | |
def non_kwarg_to(fake_mode, func, *args, **kwargs): | |
_, new_kwargs = normalize_function( | |
func, args, kwargs, normalize_to_only_use_kwargs=True | |
) | |
input_device = new_kwargs["device"] | |
out_device = input_device if input_device else new_kwargs["input"].device | |
new_kwargs["device"] = torch.device("meta") | |
inp = new_kwargs.pop("input") | |
with in_kernel_invocation_manager(fake_mode): | |
r = func(inp, **new_kwargs) | |
# TODO: I think this does the wrong thing if r is inp | |
return fake_mode.fake_tensor_converter.from_meta_and_device( | |
fake_mode, r, out_device | |
) | |
def stride_incorrect_op(op): | |
if op.namespace not in ("aten", "prims"): | |
return False | |
if op is aten._fft_c2c.default: | |
return False | |
op_name = op.name() | |
if "fft" in op_name: | |
return True | |
return False | |
# These operators have meta implementations with incorrect strides | |
def workaround_stride_incorrect_op(fake_mode, func, *args, **kwargs):
    # This is a workaround for meta implementations with incorrect strides
def is_symbolic(x): | |
if isinstance(x, FakeTensor): | |
return x._has_symbolic_sizes_strides | |
if isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool)): | |
return True | |
return False | |
# For static shapes, we can fall back to eager for the real strides | |
if fake_mode.allow_fallback_kernels: | |
require_dynamic = any( | |
is_symbolic(x) for x in itertools.chain(args, kwargs.values()) | |
) | |
if not require_dynamic: | |
flat_args, args_spec = pytree.tree_flatten((args, kwargs)) | |
return run_fallback_kernel(fake_mode, func, flat_args, args_spec, None) | |
raise UnsupportedOperatorException(func) | |
# Don't default to the default device handling,
# since the device of `the_template` is ignored
def resize_as_(fake_mode, func, *args, **kwargs): | |
with in_kernel_invocation_manager(fake_mode): | |
return func(*args, **kwargs) | |
def _sparse_coo_tensor_with_dims_and_tensors(fake_mode, func, *args, **kwargs): | |
# TODO: remove me | |
return constructors(fake_mode, func, *args, **kwargs) | |
# index.Tensor is data-dependent only in some conditions
def dyn_shape(fake_mode, func, *args, **kwargs): | |
raise DynamicOutputShapeException(func) | |
def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None): | |
if output_size is None: | |
if ( | |
fake_mode.shape_env is None | |
or not fake_mode.shape_env.allow_dynamic_output_shape_ops | |
): | |
raise DynamicOutputShapeException(func) | |
output_size = fake_mode.shape_env.create_unbacked_symint() | |
# Avoid importing sympy at a module level | |
from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size | |
_constrain_range_for_size(output_size) | |
# TODO: consider a memo | |
return repeats.new_empty(output_size) | |
def local_scalar_dense(fake_mode, func, arg): | |
if fake_mode.shape_env is None or not fake_mode.shape_env.allow_scalar_outputs: | |
# Without symints/symfloats, cannot handle this | |
raise DataDependentOutputException(func) | |
if is_float_dtype(arg.dtype): | |
return fake_mode.shape_env.create_unbacked_symfloat() | |
elif is_integer_dtype(arg.dtype): | |
return fake_mode.shape_env.create_unbacked_symint() | |
elif is_boolean_dtype(arg.dtype): | |
return fake_mode.shape_env.create_unbacked_symbool() | |
else: | |
raise NotImplementedError(f"local_scalar_dense/item NYI for {arg.dtype}") | |
def nonzero(fake_mode, func, arg): | |
if ( | |
fake_mode.shape_env is None | |
or not fake_mode.shape_env.allow_dynamic_output_shape_ops | |
): | |
# Without symints/symfloats, cannot handle this | |
raise DynamicOutputShapeException(func) | |
if arg.nonzero_memo is None: | |
nnz = fake_mode.shape_env.create_unbacked_symint() | |
# This is unsound, but it works well in practice | |
# See https://docs.google.com/document/d/1lFRYAJo5nrfxRhwIzGnfi2pbLpU6T4ytSRSuLJ5qebI/edit# | |
# TODO: Add a config knob to turn off this unsound behavior | |
# | |
# NB: If numel < 2, the bounds here might be COMPLETELY | |
# disjoint with what can actually occur. But this is fine: | |
# remember, the hypothesis is that if your later code works | |
# with N >= 2, it will work with N = 1 and N = 0. | |
maxval = sys.maxsize - 1 | |
# Avoid importing sympy at a module level | |
from torch.fx.experimental.symbolic_shapes import ( | |
_constrain_range_for_size, | |
has_free_symbols, | |
) | |
if not has_free_symbols(arg.numel()): | |
# Don't upgrade the range if numel is less than two, since we then | |
# have an empty range which makes things go explodey. We also | |
# don't allow for 2 because that would specialize the unbacked | |
# SymInt to 2, which is also likely to be buggy. | |
if arg.numel() > 2: | |
maxval = int(arg.numel()) | |
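        # Illustrative example: for a tensor with a static numel() of 5, nnz gets an
        # upper bound of 5; when numel() is symbolic or <= 2, the conservative
        # sys.maxsize - 1 bound above is kept.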
_constrain_range_for_size(nnz, max=maxval) | |
arg._nonzero_memo = nnz | |
arg._nonzero_memo_vc = arg._version | |
return arg.new_empty((arg.nonzero_memo, arg.dim()), dtype=torch.int64) | |
def masked_select(fake_mode, func, self, mask): | |
if ( | |
fake_mode.shape_env is None | |
or not fake_mode.shape_env.allow_dynamic_output_shape_ops | |
): | |
# Without symints/symfloats, cannot handle this | |
raise DynamicOutputShapeException(func) | |
nnz = fake_mode.shape_env.create_unbacked_symint() | |
# see nonzero for commentary | |
maxval = sys.maxsize - 1 | |
# Avoid importing sympy at a module level | |
from torch.fx.experimental.symbolic_shapes import ( | |
_constrain_range_for_size, | |
has_free_symbols, | |
) | |
if not has_free_symbols(self.numel()): | |
if self.numel() > 2: | |
maxval = int(self.numel()) | |
_constrain_range_for_size(nnz, max=maxval) | |
return self.new_empty((nnz,)) | |
# NB: this must be ordered after local_scalar_dense | |
def data_dep(fake_mode, func, *args, **kwargs): | |
raise DataDependentOutputException(func) | |
# Bool Indices get Expanded as Masks | |
# See: IndexingUtils.h:expandTensors | |
def check_no_bool_index_tensors(func, self, indices): | |
for index in indices: | |
if index is not None and index.dtype in (torch.bool, torch.uint8): | |
raise DynamicOutputShapeException(func) | |
def run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs): | |
_, new_kwargs = normalize_function( | |
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True | |
) | |
out_device = new_kwargs["input"].device | |
with in_kernel_invocation_manager(fake_mode): | |
out = func(*args, **kwargs) | |
if not is_noncontiguous_supported(out_device): | |
out = out.new_empty(out.shape) | |
if out is new_kwargs["input"]: | |
return out # copy_ | |
return FakeTensor(fake_mode, out, out_device) | |
_is_builtin_namespaces = ordered_set("aten", "prims", "prim") | |
def is_builtin(op): | |
return op.namespace in _is_builtin_namespaces | |
def has_meta(func): | |
return torch._C._dispatch_has_computed_kernel_for_dispatch_key(func.name(), "Meta") | |
def foreach_run_and_map_input_device(fake_mode, func, *args, **kwargs): | |
tensor_lists = [] | |
for arg in itertools.chain(args, kwargs.values()): | |
if ( | |
isinstance(arg, (list, tuple)) | |
and len(arg) | |
and isinstance(arg[0], torch.Tensor) | |
): | |
tensor_lists.append(arg) | |
try: | |
with in_kernel_invocation_manager(fake_mode): | |
out_meta = func(*args, **kwargs) | |
    except NotImplementedError:
return NotImplemented | |
if not out_meta: | |
return out_meta | |
assert tensor_lists | |
out_fake = [] | |
for i, meta_t in enumerate(out_meta): | |
device, _ = FakeTensor._find_common_device(func, [tl[i] for tl in tensor_lists]) | |
out_fake.append( | |
fake_mode.fake_tensor_converter.from_meta_and_device( | |
fake_mode, meta_t, device | |
) | |
) | |
return out_fake | |
# Don't default to the default device handling,
# since the op can take in non-zero-sized cpu
# index tensors with a cuda self
def index_tensor(fake_mode, func, *args, **kwargs): | |
from torch._meta_registrations import meta_index_Tensor | |
_, new_kwargs = normalize_function( | |
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True | |
) | |
out_device = new_kwargs["input"].device | |
# ensure nonzero call goes to fake tensor | |
with fake_mode: | |
out = meta_index_Tensor(*args, **kwargs) | |
return out.to(out_device) | |
# Can take mixed meta/non-meta arguments; the meta registration | |
# will roughly do the right thing even when given real devices | |
def embedding_bag(fake_mode, func, *args, **kwargs): | |
from torch._meta_registrations import meta_embedding_bag | |
with fake_mode: | |
return meta_embedding_bag(*args, **kwargs) | |
# takes in multiple devices; don't default to the default device handling
def multi_device_op_default(fake_mode, func, *args, **kwargs): | |
return run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs) | |
# same as multi_device_op_default, but returns the input
def multi_device_op_out(fake_mode, func, *args, **kwargs): | |
with in_kernel_invocation_manager(fake_mode): | |
out = func(*args, **kwargs) | |
_, new_kwargs = normalize_function( | |
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True | |
) | |
return new_kwargs["input"] | |
def index_put_impl(fake_mode, func, *args, **kwargs): | |
_, new_kwargs = normalize_function( | |
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True | |
) | |
values = new_kwargs["values"] | |
self_device = new_kwargs["input"].fake_device | |
torch._check( | |
self_device == values.fake_device or (values.ndim == 0 and values.numel() == 1), | |
lambda: f"Mismatching {func} device between self ({self_device}) and values ({values.device})", | |
) | |
out = run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs) | |
if func is aten.index_put_.default: | |
return new_kwargs["input"] | |
else: | |
return out | |
def nested_tensors_unsupported(fake_mode, func, *args, **kwargs): | |
raise UnsupportedOperatorException( | |
"torch.compile does not support strided NestedTensor" | |
) | |
def nyi(fake_mode, func, *args, **kwargs): | |
assert func not in _device_not_kwarg_ops, f"NYI: {func}" | |
def conv(fake_mode, func, *args, **kwargs): | |
_, kwargs = normalize_function( | |
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True | |
) | |
device = kwargs["input"].fake_device | |
# need to re-enable mode so the tensors report fake device | |
with fake_mode: | |
        # if input unsqueezing is done in Convolution.cpp we get a segfault
k = kwargs["weight"].ndim | |
batch = kwargs["input"].shape[0] | |
# Avoid importing sympy at a module level | |
from torch.fx.experimental.symbolic_shapes import has_hint | |
if not has_hint(batch): | |
# TODO: We can make this a little more faithful with best effort | |
# channels last detection (but only if it's statically obvious!) | |
mem_fmt = None | |
elif k == 3 and not kwargs["input"].is_mkldnn and not kwargs["input"].is_xpu: | |
mem_fmt = None | |
else: | |
if func is aten.convolution.default: | |
conv_backend = torch._C._select_conv_backend(**kwargs) | |
else: | |
conv_backend = torch._C._select_conv_backend( | |
kwargs["input"], | |
kwargs["weight"], | |
bias=None, | |
stride=kwargs["stride"], | |
padding=kwargs["padding"], | |
dilation=kwargs["dilation"], | |
transposed=kwargs["transposed"], | |
output_padding=kwargs["output_padding"], | |
groups=kwargs["groups"], | |
bias_sizes=kwargs["bias_sizes"], | |
) | |
mem_fmt = torch._C._conv_determine_backend_memory_format( | |
kwargs["input"], kwargs["weight"], conv_backend | |
) | |
def convert(t, mem_fmt): | |
if t is None: | |
return t | |
if mem_fmt is not None: | |
t = t.to(memory_format=mem_fmt) | |
return FakeTensor(fake_mode, t, device) | |
with in_kernel_invocation_manager(fake_mode): | |
out = func(**kwargs) | |
if func is aten.convolution.default: | |
return convert(out, mem_fmt) | |
else: | |
return ( | |
convert(out[0], mem_fmt), | |
convert(out[1], mem_fmt), | |
convert(out[2], None), | |
) | |
def meta__scaled_dot_product_flash(fake_mode, func, *args, **kwargs): | |
_, kwargs = normalize_function( | |
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True | |
) | |
query = kwargs["query"] | |
key = kwargs["key"] | |
return_debug_mask = kwargs["return_debug_mask"] | |
# unused: value, dropout_p, is_causal, scale | |
def convert_tensor(t, device): | |
return FakeTensor(fake_mode, t, device) | |
batch_size = query.size(0) | |
num_heads = query.size(1) | |
max_seqlen_batch_q = query.size(2) | |
head_dim = query.size(3) | |
max_seqlen_batch_k = key.size(2) | |
query_t = query.transpose(1, 2) | |
# empty_like already returns a fake tensor so we don't need to convert it | |
attention = torch.empty_like(query_t).transpose(1, 2) | |
logsumexp = convert_tensor( | |
torch.empty( | |
(batch_size, num_heads, max_seqlen_batch_q), | |
dtype=torch.float, | |
device="meta", | |
), | |
device=query.device, | |
) | |
if return_debug_mask: | |
blocksize_c = 128 if head_dim > 64 else 256 | |
max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c) | |
if max_seqlen_batch_k <= 128: | |
max_seqlen_k = 128 | |
elif max_seqlen_batch_k <= 256: | |
max_seqlen_k = 256 | |
debug_mask = convert_tensor( | |
torch.empty( | |
(batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k), | |
dtype=query.dtype, | |
device="meta", | |
), | |
device=query.device, | |
) | |
else: | |
debug_mask = convert_tensor( | |
torch.empty(0, dtype=query.dtype, device="meta"), | |
query.device, | |
) | |
    # Note [Seed and Offset]: the device for the seed and offset below depends on
    # whether we are capturing or not, but at the time of tracing we don't know if
    # we are going to use cudagraphs or not, so we return meta tensors here.
    # It's possible we'll need some special handling in inductor for sdpa.
return ( | |
attention, | |
logsumexp, | |
None, | |
None, | |
max_seqlen_batch_q, | |
max_seqlen_batch_k, | |
convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device), | |
convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device), | |
debug_mask, | |
) | |
def meta__scaled_dot_product_efficient(fake_mode, func, *args, **kwargs): | |
_, kwargs = normalize_function( | |
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True | |
) | |
query = kwargs["query"] | |
key = kwargs["key"] | |
value = kwargs["value"] | |
compute_log_sumexp = kwargs["compute_log_sumexp"] | |
# unused: attn_bias, dropout_p, is_causal, scale | |
def convert_tensor(t, device): | |
return FakeTensor(fake_mode, t, device) | |
query = query.transpose(1, 2) | |
key = key.transpose(1, 2) | |
value = value.transpose(1, 2) | |
B = query.size(0) | |
M = query.size(1) | |
N = key.size(1) | |
num_heads = query.size(-2) | |
K = query.size(-1) | |
Kv = value.size(-1) | |
res = convert_tensor( | |
torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device="meta"), | |
query.device, | |
) | |
logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0 | |
logsum_exp = convert_tensor( | |
torch.empty( | |
(B, num_heads, logsumexp_dim), | |
dtype=torch.float, | |
device="meta", | |
), | |
query.device, | |
) | |
res = res.transpose(1, 2) | |
# See Note [Seed and Offset]: | |
seed = convert_tensor( | |
torch.empty((), dtype=torch.long, device="meta"), query.device | |
) | |
offset = convert_tensor( | |
torch.empty((), dtype=torch.long, device="meta"), query.device | |
) | |
return res, logsum_exp, seed, offset | |
def meta__flash_attention_forward(fake_mode, func, *args, **kwargs): | |
_, kwargs = normalize_function( | |
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True | |
) | |
query = kwargs["query"] | |
key = kwargs["key"] | |
cum_seq_q = kwargs["cum_seq_q"] | |
cum_seq_k = kwargs["cum_seq_k"] | |
max_q = kwargs["max_q"] | |
max_k = kwargs["max_k"] | |
return_debug_mask = kwargs["return_debug_mask"] | |
# unused: value, dropout_p, is_causal, scale | |
def convert_tensor(t, device): | |
return FakeTensor(fake_mode, t, device) | |
# NB: there are two underlying paths: | |
# 1. normal dense path; expect 4D inputs of shape (batch_size, seqlen, num_heads, head_dim) | |
# 2. varseqlen path; expect 3D inputs of shape (total, num_heads, head_dim) where total | |
# includes all batch item sequences. cum_seq_q / cum_seq_k contain offsets into total | |
batch_size = query.size(0) if cum_seq_q is None else cum_seq_q.numel() - 1 | |
max_seqlen_batch_q = query.size(1) if cum_seq_q is None else max_q | |
max_seqlen_batch_k = key.size(1) if cum_seq_k is None else max_k | |
num_heads = query.size(-2) | |
head_dim = query.size(-1) | |
# Cuda Path | |
# note: empty_like already returns a fake tensor, we don't need to wrap it | |
attention = torch.empty_like(query) | |
logsumexp = convert_tensor( | |
torch.empty( | |
(batch_size, num_heads, max_seqlen_batch_q), | |
dtype=torch.float, | |
device="meta", | |
), | |
device=query.device, | |
) | |
if return_debug_mask: | |
blocksize_c = 128 if head_dim > 64 else 256 | |
max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c) | |
if max_seqlen_batch_k <= 128: | |
max_seqlen_k = 128 | |
elif max_seqlen_batch_k <= 256: | |
max_seqlen_k = 256 | |
debug_mask = convert_tensor( | |
torch.empty( | |
(batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k), | |
dtype=query.dtype, | |
device="meta", | |
), | |
query.device, | |
) | |
else: | |
debug_mask = convert_tensor( | |
torch.empty(0, dtype=query.dtype, device="meta"), | |
query.device, | |
) | |
# See Note [Seed and Offset]: | |
return ( | |
attention, | |
logsumexp, | |
convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device), | |
convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device), | |
debug_mask, | |
) | |
def meta__efficient_attention_forward(fake_mode, func, *args, **kwargs): | |
_, kwargs = normalize_function( | |
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True | |
) | |
query = kwargs["query"] | |
key = kwargs["key"] | |
value = kwargs["value"] | |
cu_seqlens_q = kwargs["cu_seqlens_q"] | |
max_seqlen_q = kwargs["max_seqlen_q"] | |
max_seqlen_k = kwargs["max_seqlen_k"] | |
compute_log_sumexp = kwargs["compute_log_sumexp"] | |
# unused: bias, cu_seqlens_k, dropout_p, custom_mask_type, scale, causal_diagonal, seqlen_k | |
def convert_tensor(t, device): | |
return FakeTensor(fake_mode, t, device) | |
B = query.size(0) | |
M = query.size(1) | |
N = key.size(1) | |
num_heads = query.size(-2) | |
K = query.size(-1) | |
Kv = value.size(-1) | |
res = convert_tensor( | |
torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device="meta"), | |
query.device, | |
) | |
logsumexp_batch_dim = cu_seqlens_q.size(0) - 1 if (cu_seqlens_q is not None) else B | |
actual_max_seqlen_q = M | |
if cu_seqlens_q is not None: | |
assert max_seqlen_q is not None | |
actual_max_seqlen_q = max_seqlen_q | |
actual_max_seqlen_k = max_seqlen_k if max_seqlen_k is not None else N | |
logsumexp_dim = ( | |
math.ceil(actual_max_seqlen_q / 32) * 32 if compute_log_sumexp else 0 | |
) | |
logsum_exp = convert_tensor( | |
torch.empty( | |
(logsumexp_batch_dim, num_heads, logsumexp_dim), | |
dtype=torch.float, | |
device="meta", | |
), | |
query.device, | |
) | |
# See Note [Seed and Offset]: | |
seed = convert_tensor( | |
torch.empty((), dtype=torch.long, device="meta"), query.device | |
) | |
offset = convert_tensor( | |
torch.empty((), dtype=torch.long, device="meta"), query.device | |
) | |
return res, logsum_exp, seed, offset, actual_max_seqlen_q, actual_max_seqlen_k | |
FAST_OP_IMPLEMENTATIONS = {} | |
# Unlike register_op_impl, these don't do the slow iteration for | |
# run_impl_check, and these run BEFORE decompositions | |
def register_fast_op_impl(func: OpOverload): | |
def impl_decorator(op_impl): | |
FAST_OP_IMPLEMENTATIONS[func] = op_impl | |
return op_impl | |
return impl_decorator | |
# infer_size_impl in ExpandUtils | |
def infer_size(a, b): | |
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious | |
dimsA = len(a) | |
dimsB = len(b) | |
ndim = max(dimsA, dimsB) | |
expandedSizes = [0] * ndim | |
for i in range(ndim - 1, -1, -1): | |
offset = ndim - 1 - i | |
dimA = dimsA - 1 - offset | |
dimB = dimsB - 1 - offset | |
sizeA = a[dimA] if dimA >= 0 else 1 | |
sizeB = b[dimB] if dimB >= 0 else 1 | |
# NB: It is very important to test for broadcasting, before testing | |
# sizeA == sizeB. This is because the broadcasting tests are likely | |
# to be statically known (in particular, if sizeA/sizeB is unbacked | |
# but size-like, we will unsoundly assume they never equal 1), but | |
# the sizeA == sizeB test may not be statically known. However, once | |
# we have established that no broadcasting is happening, the | |
# sizeA == sizeB is now expect_true and we can defer it as a runtime | |
# assert (this works because Python will return the terminal | |
# expression of an or statement as-is, without bool()'ing it; if this | |
# were not the case, we'd need to write this using torch.sym_or() or | |
# something like that). | |
torch._check( | |
guard_size_oblivious(sizeA == 1) | |
or guard_size_oblivious(sizeB == 1) | |
or sizeA == sizeB, | |
lambda: f"The size of tensor a ({sizeA}) " | |
f"must match the size of tensor b ({sizeB}) " | |
f"at non-singleton dimension {i})", | |
) | |
expandedSizes[i] = sizeB if guard_size_oblivious(sizeA == 1) else sizeA | |
return tuple(expandedSizes) | |
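# Worked example (illustrative): infer_size((3, 1, 5), (4, 5)) aligns shapes from
# the right and broadcasts size-1 / missing dims, returning (3, 4, 5); mismatched
# non-singleton sizes such as (3, 2) vs (3, 4) fail the torch._check above.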
def make_fast_binary_impl(slow_ref): | |
def fast_binary_impl(mode, *args, **kwargs): | |
def slow(msg): | |
count_label(f"slow {msg}") | |
with mode: | |
return slow_ref(*args, **kwargs) | |
count_label("attempt fast") | |
# Fast path (based off of TensorIterator fast path). | |
# Unfortunately, there is no way to easily deduplicate | |
# this with either the TensorIterator C++ implementation | |
# (which we don't want to SymIntify, and also the algorithm | |
# here is slightly different from TensorIterator to allow | |
# for broadcasting), nor the PrimTorch implementation | |
# (which does not actually implement a fast path.) | |
operands = args | |
# compute_shape | |
has_scalars = False | |
has_tensors = False | |
final_shape = None | |
for op in operands: | |
shape = op.shape if isinstance(op, torch.Tensor) else () | |
if len(shape) == 0: | |
has_scalars = True | |
else: | |
has_tensors = True | |
if final_shape is None: | |
final_shape = shape | |
# TODO: Minor optimization: track if the shapes | |
# were equal so you can skip the equality check | |
# below if unnecessary | |
final_shape = infer_size(final_shape, shape) | |
assert final_shape is not None | |
# Do some extra safety checks to see if the output | |
# stride is obvious | |
for op in operands: | |
if ( | |
isinstance(op, torch.Tensor) | |
and len(op.shape) == len(final_shape) | |
and op.shape == final_shape | |
): | |
break | |
else: | |
return slow("both tensors nontrivially broadcast") | |
# compute_types | |
cpu = torch.device("cpu") | |
common_device = cpu | |
common_dtype = None | |
output_dtype = None | |
has_different_input_dtypes = False | |
for op in operands: | |
if not isinstance(op, torch.Tensor): | |
# Use elementwise_dtypes for the tricky case | |
has_different_input_dtypes = True | |
continue | |
if common_device == cpu and not op.device.type == "cpu": | |
common_device = op.device | |
# Slightly simplified here as target_dtype cannot vary | |
if common_dtype is None: | |
common_dtype = op.dtype | |
elif common_dtype != op.dtype: | |
has_different_input_dtypes = True | |
if has_different_input_dtypes: | |
# compute promotion | |
# TODO: we don't need the compute type | |
_, common_dtype = elementwise_dtypes( | |
*operands, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT | |
) | |
# check all tensors on same device | |
        # cpu scalars are assumed to be allowed
current_cpu_scalars_on_non_cpu = 0 | |
max_cpu_scalars_on_non_cpu = 1 # hard coded atm | |
for op in operands: | |
if not isinstance(op, torch.Tensor): | |
continue | |
if common_device != cpu and op.dim() == 0 and op.device == cpu: | |
if current_cpu_scalars_on_non_cpu >= max_cpu_scalars_on_non_cpu: | |
return slow("error") | |
current_cpu_scalars_on_non_cpu += 1 | |
elif op.device != common_device: | |
return slow("error") | |
# compute_fast_setup_type | |
is_contiguous = True | |
is_channels_last = True | |
        # TODO: is_non_overlapping_and_dense (not bound from Python)
# no inplace, no out, everything defined | |
if is_noncontiguous_supported(common_device): | |
for op in operands: | |
if not isinstance(op, torch.Tensor): | |
continue | |
is_contiguous = is_contiguous and op.is_contiguous( | |
memory_format=torch.contiguous_format | |
) | |
is_channels_last = is_channels_last and op.is_contiguous( | |
memory_format=torch.channels_last | |
) | |
if is_contiguous: | |
# do contiguous | |
count_label("fast is_contiguous") | |
return FakeTensor( | |
mode, | |
torch.empty( | |
final_shape, | |
dtype=common_dtype, | |
device="meta", | |
memory_format=torch.contiguous_format, | |
), | |
device=common_device, | |
) | |
if is_channels_last: | |
count_label("fast channels_last") | |
# do channels last | |
return FakeTensor( | |
mode, | |
torch.empty( | |
final_shape, | |
dtype=common_dtype, | |
device="meta", | |
memory_format=torch.channels_last, | |
), | |
device=common_device, | |
) | |
return slow("no contiguity match") | |
return fast_binary_impl | |
def get_fast_op_impls(): | |
import torch._refs | |
register_fast_op_impl(torch.ops.aten.add.Tensor)( | |
make_fast_binary_impl(torch._refs.add) | |
) | |
register_fast_op_impl(torch.ops.aten.sub.Tensor)( | |
make_fast_binary_impl(torch._refs.sub) | |
) | |
register_fast_op_impl(torch.ops.aten.mul.Tensor)(make_fast_binary_impl(torch._refs.mul)) # type: ignore[has-type] | |
register_fast_op_impl(torch.ops.aten.div.Tensor)( | |
make_fast_binary_impl(torch._refs.div) | |
) | |
return FAST_OP_IMPLEMENTATIONS | |
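# Usage sketch (illustrative): get_fast_op_impls() returns the OpOverload-keyed
# table, which a caller can consult before the general dispatch path, e.g.
#
#   fast_impls = get_fast_op_impls()
#   impl = fast_impls.get(torch.ops.aten.add.Tensor)
#   if impl is not None:
#       # fake_mode, fake_x, fake_y are assumed to come from an active FakeTensorMode
#       out = impl(fake_mode, fake_x, fake_y)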