# mypy: ignore-errors
import functools
import itertools
import math
import sys
from typing import Callable, Union
import torch
import torch._custom_op
import torch._logging
from torch._ops import OpOverload
from torch._prims_common import (
elementwise_dtypes,
ELEMENTWISE_TYPE_PROMOTION_KIND,
is_boolean_dtype,
is_float_dtype,
is_integer_dtype,
)
from torch._subclasses.fake_tensor import (
DataDependentOutputException,
DynamicOutputShapeException,
FakeTensor,
in_kernel_invocation_manager,
run_fallback_kernel,
UnsupportedOperatorException,
)
from torch.fx.operator_schemas import normalize_function
from torch.utils._stats import count_label
import torch.utils._pytree as pytree
__all__ = [
"op_implementations_checks",
"get_fast_op_impls",
"stride_incorrect_op",
"has_meta",
]
op_implementations_dict = {}
op_implementations_checks = []
aten = torch._ops.ops.aten
def ordered_set(*items):
return dict.fromkeys(items, True)
# This function indicates if the backend device
# supports non-contiguous tensors
def is_noncontiguous_supported(device):
if device.type == "hpu":
return False
return True
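# For example, is_noncontiguous_supported(torch.device("hpu")) is False, so
# run_and_return_new_tensor_of_input_device below replaces a possibly
# non-contiguous fake output with a freshly allocated contiguous one for that
# backend.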
_like_tensor_constructors = ordered_set(
aten.empty_like.default,
aten.empty_like.out,
aten.full_like.default,
aten.full_like.out,
aten.ones_like.default,
aten.ones_like.out,
aten.rand_like.default,
aten.rand_like.out,
aten.randn_like.default,
aten.randn_like.out,
aten.randint_like.default,
aten.randint_like.out,
aten.randint_like.low_dtype,
aten.randint_like.low_dtype_out,
aten.zeros_like.default,
aten.zeros_like.out,
aten.new_empty.default,
aten.new_empty.out,
aten.new_empty_strided.default,
aten.new_empty_strided.out,
aten.new_full.default,
aten.new_full.out,
aten.new_zeros.default,
aten.new_zeros.out,
aten.new_ones.default,
aten.new_ones.out,
)
_device_not_kwarg_ops = ordered_set(
aten._resize_output_.default,
aten._nested_tensor_from_tensor_list.default,
aten._nested_tensor_from_tensor_list.out,
aten.pin_memory.default,
aten.is_pinned.default,
aten.to.device,
aten.to.prim_Device,
aten._pin_memory.default,
aten._pin_memory.out,
aten._resize_output.default,
aten._resize_output.out,
)
# this op is never actually used
_non_kwarg_device_constructors = (aten._list_to_tensor,)
def contains_tensor_types(type):
tensor_type = torch._C.TensorType.get()
return type.isSubtypeOf(tensor_type) or any(
contains_tensor_types(e) for e in type.containedTypes()
)
@functools.lru_cache(None)
def _is_tensor_constructor(func: OpOverload):
assert isinstance(func, OpOverload)
schema = func._schema
if any(contains_tensor_types(arg.type) for arg in schema.arguments):
return False
# TODO: no real reason to restrict multiple outputs
return (
len(schema.returns) == 1 and schema.returns[0].type is torch._C.TensorType.get()
)
def register_op_impl(run_impl_check: Union[Callable[[OpOverload], bool], OpOverload]):
def impl_decorator(op_impl):
if isinstance(run_impl_check, OpOverload):
assert (
run_impl_check not in op_implementations_dict
), f"duplicate registration: {run_impl_check}"
op_implementations_dict[run_impl_check] = op_impl
elif isinstance(run_impl_check, (list, tuple)):
for op in run_impl_check:
register_op_impl(op)(op_impl)
else:
assert callable(run_impl_check)
op_implementations_checks.append((run_impl_check, op_impl))
return op_impl
return impl_decorator
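# A minimal sketch (hypothetical helper, not used elsewhere in this module) of
# the lookup order the registrations above imply: exact OpOverload matches
# live in op_implementations_dict, and predicate registrations are scanned in
# registration order via op_implementations_checks (the dict check itself is
# registered as the first predicate just below).
def _example_find_op_impl(func):
    # Exact-overload registrations effectively win first.
    if func in op_implementations_dict:
        return op_implementations_dict[func]
    # Otherwise the first matching predicate's impl is used.
    for run_impl_check, op_impl in op_implementations_checks:
        if run_impl_check(func):
            return op_impl
    return None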
@register_op_impl(op_implementations_dict.__contains__)
def dispatch_to_op_implementations_dict(fake_mode, func, *args, **kwargs):
return op_implementations_dict[func](fake_mode, func, *args, **kwargs)
@register_op_impl(_is_tensor_constructor)
@register_op_impl([*_like_tensor_constructors])
def constructors(fake_mode, func, *args, **kwargs):
assert func not in _non_kwarg_device_constructors
_, new_kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
if "names" in kwargs:
raise UnsupportedOperatorException(
"torch.compile doesn't support named tensors"
)
if func in _like_tensor_constructors:
default_device = new_kwargs["input"].device
# TODO: file issue
args = (new_kwargs.pop("input"),)
else:
# cpu is default device if none is specified
default_device = torch.device("cpu")
args = ()
out_device = new_kwargs.pop("device", None)
out_device = out_device if out_device is not None else default_device
new_kwargs["device"] = torch.device("meta")
# _like constructors have fake tensor inputs (maybe this causes the non-like
# to fail? hmmm)
with in_kernel_invocation_manager(fake_mode):
r = func(*args, **new_kwargs)
return FakeTensor(fake_mode, r, out_device)
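# A hedged usage sketch (comment only, nothing here executes at import time):
# under a FakeTensorMode, a constructor call is normalized to kwargs, rerun
# with device="meta", and the meta result is rewrapped to carry the requested
# device, so no real allocation happens:
#
#   from torch._subclasses.fake_tensor import FakeTensorMode
#   with FakeTensorMode():
#       t = torch.empty(2, 3)  # the meta kernel runs under the hood
#   # t is a FakeTensor that reports device "cpu" while its storage is meta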
@register_op_impl(aten.to.prim_Device)
@register_op_impl(aten.to.device)
def non_kwarg_to(fake_mode, func, *args, **kwargs):
_, new_kwargs = normalize_function(
func, args, kwargs, normalize_to_only_use_kwargs=True
)
input_device = new_kwargs["device"]
out_device = input_device if input_device else new_kwargs["input"].device
new_kwargs["device"] = torch.device("meta")
inp = new_kwargs.pop("input")
with in_kernel_invocation_manager(fake_mode):
r = func(inp, **new_kwargs)
# TODO: I think this does the wrong thing if r is inp
return fake_mode.fake_tensor_converter.from_meta_and_device(
fake_mode, r, out_device
)
def stride_incorrect_op(op):
if op.namespace not in ("aten", "prims"):
return False
if op is aten._fft_c2c.default:
return False
op_name = op.name()
if "fft" in op_name:
return True
return False
# These operators have meta implementations with incorrect strides
@register_op_impl(stride_incorrect_op)
def workaround_stride_incorrect_op(fake_mode, func, *args, **kwargs):
    # This is a workaround for meta implementations with incorrect strides
def is_symbolic(x):
if isinstance(x, FakeTensor):
return x._has_symbolic_sizes_strides
if isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool)):
return True
return False
# For static shapes, we can fall back to eager for the real strides
if fake_mode.allow_fallback_kernels:
require_dynamic = any(
is_symbolic(x) for x in itertools.chain(args, kwargs.values())
)
if not require_dynamic:
flat_args, args_spec = pytree.tree_flatten((args, kwargs))
return run_fallback_kernel(fake_mode, func, flat_args, args_spec, None)
raise UnsupportedOperatorException(func)
# Don't use default device handling,
# since the device of `the_template` is ignored
@register_op_impl(aten.resize_as_.default)
def resize_as_(fake_mode, func, *args, **kwargs):
with in_kernel_invocation_manager(fake_mode):
return func(*args, **kwargs)
@register_op_impl(aten._sparse_coo_tensor_with_dims_and_tensors.default)
def _sparse_coo_tensor_with_dims_and_tensors(fake_mode, func, *args, **kwargs):
# TODO: remove me
return constructors(fake_mode, func, *args, **kwargs)
# index.Tensor is data-dependent only under some conditions
@register_op_impl(
lambda func: torch.Tag.dynamic_output_shape in func.tags
and func
not in [aten.index.Tensor, aten.nonzero.default, aten.repeat_interleave.Tensor]
)
def dyn_shape(fake_mode, func, *args, **kwargs):
raise DynamicOutputShapeException(func)
def _unique(
fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False
):
if (
fake_mode.shape_env is None
or not fake_mode.shape_env.allow_dynamic_output_shape_ops
):
# Without symints/symfloats, cannot handle this
raise DynamicOutputShapeException(func)
# Do not use a memo for unique_dim
if dim is not None or (nnz := arg.unique_memo) is None:
# Avoid importing sympy at a module level
from torch.fx.experimental.symbolic_shapes import (
_constrain_range_for_size,
has_free_symbols,
)
if not has_free_symbols(arg.numel()) and arg.numel() == 0:
# If numel is zero, then the output size must be zero.
# In this case, we must not allocate an unbacked SymInt,
# because if we do, it will immediately get refined to
# zero, but this will be inconsistent with size oblivious
# tests (which will continue to claim that the unbacked
# symint cannot equal zero). We could also unconditionally
# allocate an unbacked SymInt and not refine its range,
# but this seems more precise.
nnz = 0
else:
nnz = fake_mode.shape_env.create_unbacked_symint()
maxval = sys.maxsize - 1
numel = arg.numel() if dim is None else arg.size(dim)
if not has_free_symbols(numel):
maxval = int(numel)
_constrain_range_for_size(nnz, max=maxval)
if dim is None:
arg.unique_memo = nnz
if dim is None:
ret = [arg.new_empty((nnz,))]
else:
ret = [arg.new_empty(*arg.shape[:dim], nnz, *arg.shape[dim + 1 :])]
return_if_dim_and_cpu = dim is not None and arg.fake_device == torch.device("cpu")
if return_inverse or return_if_dim_and_cpu:
inverse = arg.new_empty(arg.shape if dim is None else (arg.shape[dim],))
else:
inverse = arg.new_empty(0)
ret.append(inverse)
if return_counts or return_if_dim_and_cpu:
counts = arg.new_empty(ret[0].shape if dim is None else (ret[0].shape[dim],))
else:
counts = arg.new_empty(0)
ret.append(counts)
return tuple(ret)
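# A hedged shape sketch for the helper above: with dynamic output shapes
# enabled and a 1-D input of length N, _unique returns (values, inverse,
# counts) where the data-dependent length nnz is an unbacked SymInt:
#   values:  (nnz,)
#   inverse: (N,)   if return_inverse (or dim is not None on cpu), else (0,)
#   counts:  (nnz,) if return_counts (or dim is not None on cpu), else (0,)
# For dim=None the nnz symbol is memoized on the input via arg.unique_memo.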
@register_op_impl(aten._unique2.default)
def unique2(
fake_mode, func, arg, sorted=True, return_inverse=False, return_counts=False
):
return _unique(fake_mode, func, arg, None, sorted, return_inverse, return_counts)
@register_op_impl(aten.unique_dim.default)
def unique_dim(
fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False
):
return _unique(
fake_mode,
func,
arg,
# normalize dim to be non-negative
dim if dim >= 0 else dim % max(arg.ndim, 1),
sorted,
return_inverse,
return_counts,
)
@register_op_impl(aten.repeat_interleave.Tensor)
def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None):
if output_size is None:
if (
fake_mode.shape_env is None
or not fake_mode.shape_env.allow_dynamic_output_shape_ops
):
raise DynamicOutputShapeException(func)
output_size = fake_mode.shape_env.create_unbacked_symint()
# Avoid importing sympy at a module level
from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size
_constrain_range_for_size(output_size)
# TODO: consider a memo
return repeats.new_empty(output_size)
@register_op_impl(torch.ops.aten._local_scalar_dense.default)
def local_scalar_dense(fake_mode, func, arg):
if (r := arg.item_memo) is not None:
return r
if fake_mode.shape_env is None or (
not fake_mode.shape_env.allow_scalar_outputs
and not fake_mode.allow_scalar_outputs
):
# Without symints/symfloats, cannot handle this
raise DataDependentOutputException(func)
if is_float_dtype(arg.dtype):
r = fake_mode.shape_env.create_unbacked_symfloat()
elif is_integer_dtype(arg.dtype):
r = fake_mode.shape_env.create_unbacked_symint()
elif is_boolean_dtype(arg.dtype):
r = fake_mode.shape_env.create_unbacked_symbool()
else:
raise NotImplementedError(f"local_scalar_dense/item NYI for {arg.dtype}")
arg.item_memo = r
return r
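# A hedged usage sketch (comment only): under a FakeTensorMode whose ShapeEnv
# allows scalar outputs, .item() on a fake tensor routes here and yields an
# unbacked symbol instead of raising:
#
#   with fake_mode:              # assumes a suitably configured FakeTensorMode
#       u = fake_tensor.item()   # SymInt / SymFloat / SymBool, chosen by dtype
#   # a second .item() call returns the same symbol via arg.item_memo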
@register_op_impl(torch.ops.aten.nonzero.default)
def nonzero(fake_mode, func, arg):
if (
fake_mode.shape_env is None
or not fake_mode.shape_env.allow_dynamic_output_shape_ops
):
# Without symints/symfloats, cannot handle this
raise DynamicOutputShapeException(func)
if (nnz := arg.nonzero_memo) is None:
# Avoid importing sympy at a module level
from torch.fx.experimental.symbolic_shapes import (
_constrain_range_for_size,
has_free_symbols,
)
if not has_free_symbols(arg.numel()) and arg.numel() == 0:
# If numel is zero, then the output size must be zero.
# In this case, we must not allocate an unbacked SymInt,
# because if we do, it will immediately get refined to
# zero, but this will be inconsistent with size oblivious
# tests (which will continue to claim that the unbacked
# symint cannot equal zero). We could also unconditionally
# allocate an unbacked SymInt and not refine its range,
# but this seems more precise.
nnz = 0
else:
nnz = fake_mode.shape_env.create_unbacked_symint()
maxval = sys.maxsize - 1
if not has_free_symbols(arg.numel()):
maxval = int(arg.numel())
_constrain_range_for_size(nnz, max=maxval)
arg.nonzero_memo = nnz
return arg.new_empty((nnz, arg.dim()), dtype=torch.int64)
@register_op_impl(torch.ops.aten.masked_select.default)
def masked_select(fake_mode, func, self, mask):
if (
fake_mode.shape_env is None
or not fake_mode.shape_env.allow_dynamic_output_shape_ops
):
# Without symints/symfloats, cannot handle this
raise DynamicOutputShapeException(func)
nnz = fake_mode.shape_env.create_unbacked_symint()
# see nonzero for commentary
maxval = sys.maxsize - 1
# Avoid importing sympy at a module level
from torch.fx.experimental.symbolic_shapes import (
_constrain_range_for_size,
has_free_symbols,
)
if not has_free_symbols(self.numel()):
if self.numel() > 2:
maxval = int(self.numel())
_constrain_range_for_size(nnz, max=maxval)
return self.new_empty((nnz,))
# NB: this must be ordered after local_scalar_dense
@register_op_impl(lambda func: torch.Tag.data_dependent_output in func.tags)
def data_dep(fake_mode, func, *args, **kwargs):
raise DataDependentOutputException(func)
# Bool Indices get Expanded as Masks
# See: IndexingUtils.h:expandTensors
def check_no_bool_index_tensors(func, self, indices):
for index in indices:
if index is not None and index.dtype in (torch.bool, torch.uint8):
raise DynamicOutputShapeException(func)
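# A hedged example: an indexing call like x[mask] with a torch.bool (or
# torch.uint8) mask is morally x[mask.nonzero(...)], so its output shape is
# data-dependent; the helper above lets callers bail out to the dynamic
# output shape path rather than guess a size.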
def run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs):
_, new_kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
out_device = new_kwargs["input"].device
with in_kernel_invocation_manager(fake_mode):
out = func(*args, **kwargs)
if not is_noncontiguous_supported(out_device):
out = out.new_empty(out.shape)
if out is new_kwargs["input"]:
return out # copy_
return FakeTensor(fake_mode, out, out_device)
_is_builtin_namespaces = ordered_set("aten", "prims", "prim")
def is_builtin(op):
return op.namespace in _is_builtin_namespaces
def has_meta(func):
return torch._C._dispatch_has_computed_kernel_for_dispatch_key(func.name(), "Meta")
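# A hedged usage note: OpOverload.name() returns qualified strings (e.g.
# "aten::add.Tensor" for torch.ops.aten.add.Tensor), so has_meta asks the
# dispatcher whether a computed Meta kernel exists under that name.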
@register_op_impl(
lambda func: is_builtin(func) and "foreach" in func.name() and has_meta(func)
)
def foreach_run_and_map_input_device(fake_mode, func, *args, **kwargs):
tensor_lists = []
for arg in itertools.chain(args, kwargs.values()):
if (
isinstance(arg, (list, tuple))
and len(arg)
and isinstance(arg[0], torch.Tensor)
):
tensor_lists.append(arg)
try:
with in_kernel_invocation_manager(fake_mode):
out_meta = func(*args, **kwargs)
    except NotImplementedError:
        return NotImplemented
if not out_meta:
return out_meta
assert tensor_lists
out_fake = []
for i, meta_t in enumerate(out_meta):
device, _ = FakeTensor._find_common_device(func, [tl[i] for tl in tensor_lists])
out_fake.append(
fake_mode.fake_tensor_converter.from_meta_and_device(
fake_mode, meta_t, device
)
)
return out_fake
# Don't use default device handling,
# since the op can take non-zero sized cpu
# index tensors with cuda self
@register_op_impl(aten.index.Tensor)
def index_tensor(fake_mode, func, *args, **kwargs):
from torch._meta_registrations import meta_index_Tensor
_, new_kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
out_device = new_kwargs["input"].device
# ensure nonzero call goes to fake tensor
with fake_mode:
out = meta_index_Tensor(*args, **kwargs)
return out.to(out_device)
# Can take mixed meta/non-meta arguments; the meta registration
# will roughly do the right thing even when given real devices
@register_op_impl(aten._embedding_bag.default)
def embedding_bag(fake_mode, func, *args, **kwargs):
from torch._meta_registrations import meta_embedding_bag
with fake_mode:
return meta_embedding_bag(*args, **kwargs)
# takes inputs on multiple devices; don't use default device handling
@register_op_impl(aten._unsafe_index_put.default)
@register_op_impl(aten.copy.default)
@register_op_impl(aten.copy_.default)
@register_op_impl(aten.slice_scatter.default)
def multi_device_op_default(fake_mode, func, *args, **kwargs):
return run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs)
# same as multi_device_op_default, but returns the input
@register_op_impl(aten.copy.out)
@register_op_impl(aten.slice_scatter.out)
def multi_device_op_out(fake_mode, func, *args, **kwargs):
with in_kernel_invocation_manager(fake_mode):
out = func(*args, **kwargs)
_, new_kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
return new_kwargs["input"]
@register_op_impl(aten.index_put.default)
@register_op_impl(aten.index_put_.default)
def index_put_impl(fake_mode, func, *args, **kwargs):
_, new_kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
values = new_kwargs["values"]
self_device = new_kwargs["input"].fake_device
torch._check(
self_device == values.fake_device or (values.ndim == 0 and values.numel() == 1),
lambda: f"Mismatching {func} device between self ({self_device}) and values ({values.device})",
)
out = run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs)
if func is aten.index_put_.default:
return new_kwargs["input"]
else:
return out
@register_op_impl(aten._nested_tensor_from_tensor_list.default)
@register_op_impl(aten._nested_tensor_from_tensor_list.out)
@register_op_impl(aten._nested_view_from_buffer.default)
@register_op_impl(aten._nested_view_from_buffer_copy.default)
def nested_tensors_unsupported(fake_mode, func, *args, **kwargs):
raise UnsupportedOperatorException(
"torch.compile does not support strided NestedTensor"
)
@register_op_impl(
[
x
for x in _device_not_kwarg_ops
if x
not in (
# these are already registered elsewhere
aten.to.device,
aten.to.prim_Device,
aten._nested_tensor_from_tensor_list.default,
aten._nested_tensor_from_tensor_list.out,
)
]
)
def nyi(fake_mode, func, *args, **kwargs):
assert func not in _device_not_kwarg_ops, f"NYI: {func}"
@register_op_impl([aten.convolution.default, aten.convolution_backward.default])
def conv(fake_mode, func, *args, **kwargs):
_, kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
device = kwargs["input"].fake_device
# need to re-enable mode so the tensors report fake device
with fake_mode:
        # if the unsqueezing of the input is done in Convolution.cpp we get a segfault
k = kwargs["weight"].ndim
batch = kwargs["input"].shape[0]
# Avoid importing sympy at a module level
from torch.fx.experimental.symbolic_shapes import has_hint
if not has_hint(batch):
# TODO: We can make this a little more faithful with best effort
# channels last detection (but only if it's statically obvious!)
mem_fmt = None
elif k == 3 and not kwargs["input"].is_mkldnn and not kwargs["input"].is_xpu:
mem_fmt = None
else:
if func is aten.convolution.default:
conv_backend = torch._C._select_conv_backend(**kwargs)
else:
conv_backend = torch._C._select_conv_backend(
kwargs["input"],
kwargs["weight"],
bias=None,
stride=kwargs["stride"],
padding=kwargs["padding"],
dilation=kwargs["dilation"],
transposed=kwargs["transposed"],
output_padding=kwargs["output_padding"],
groups=kwargs["groups"],
bias_sizes=kwargs["bias_sizes"],
)
mem_fmt = torch._C._conv_determine_backend_memory_format(
kwargs["input"], kwargs["weight"], conv_backend
)
def convert(t, mem_fmt):
if t is None:
return t
if mem_fmt is not None:
t = t.to(memory_format=mem_fmt)
return FakeTensor(fake_mode, t, device)
with in_kernel_invocation_manager(fake_mode):
out = func(**kwargs)
        if func is aten.convolution.default:
            return convert(out, mem_fmt)
        else:
            # convolution_backward returns (grad_input, grad_weight, grad_bias);
            # the first two follow the determined memory format, while grad_bias
            # is 1-D, so no memory format conversion applies to it.
            return (
                convert(out[0], mem_fmt),
                convert(out[1], mem_fmt),
                convert(out[2], None),
            )
@register_op_impl(aten._scaled_dot_product_flash_attention.default)
def meta__scaled_dot_product_flash(fake_mode, func, *args, **kwargs):
_, kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
query = kwargs["query"]
key = kwargs["key"]
return_debug_mask = kwargs["return_debug_mask"]
# unused: value, dropout_p, is_causal, scale
def convert_tensor(t, device):
return FakeTensor(fake_mode, t, device)
batch_size = query.size(0)
num_heads = query.size(1)
max_seqlen_batch_q = query.size(2)
head_dim = query.size(3)
max_seqlen_batch_k = key.size(2)
query_t = query.transpose(1, 2)
# empty_like already returns a fake tensor so we don't need to convert it
attention = torch.empty_like(query_t).transpose(1, 2)
logsumexp = convert_tensor(
torch.empty(
(batch_size, num_heads, max_seqlen_batch_q),
dtype=torch.float,
device="meta",
),
device=query.device,
)
if return_debug_mask:
blocksize_c = 128 if head_dim > 64 else 256
max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
if max_seqlen_batch_k <= 128:
max_seqlen_k = 128
elif max_seqlen_batch_k <= 256:
max_seqlen_k = 256
debug_mask = convert_tensor(
torch.empty(
(batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
dtype=query.dtype,
device="meta",
),
device=query.device,
)
else:
debug_mask = convert_tensor(
torch.empty(0, dtype=query.dtype, device="meta"),
query.device,
)
    # Note [Seed and Offset]: the device of seed and offset below depends on
    # whether we are capturing with cudagraphs, but at trace time we don't yet
    # know whether cudagraphs will be used, so we return meta tensors here.
    # It's possible we'll need special handling in inductor for sdpa.
return (
attention,
logsumexp,
None,
None,
max_seqlen_batch_q,
max_seqlen_batch_k,
convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
debug_mask,
)
@register_op_impl(aten._scaled_dot_product_efficient_attention.default)
def meta__scaled_dot_product_efficient(fake_mode, func, *args, **kwargs):
_, kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
query = kwargs["query"]
key = kwargs["key"]
value = kwargs["value"]
compute_log_sumexp = kwargs["compute_log_sumexp"]
# unused: attn_bias, dropout_p, is_causal, scale
def convert_tensor(t, device):
return FakeTensor(fake_mode, t, device)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
B = query.size(0)
M = query.size(1)
N = key.size(1)
num_heads = query.size(-2)
K = query.size(-1)
Kv = value.size(-1)
res = convert_tensor(
torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device="meta"),
query.device,
)
logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0
logsum_exp = convert_tensor(
torch.empty(
(B, num_heads, logsumexp_dim),
dtype=torch.float,
device="meta",
),
query.device,
)
res = res.transpose(1, 2)
# See Note [Seed and Offset]:
seed = convert_tensor(
torch.empty((), dtype=torch.long, device="meta"), query.device
)
offset = convert_tensor(
torch.empty((), dtype=torch.long, device="meta"), query.device
)
return res, logsum_exp, seed, offset
@register_op_impl(aten._flash_attention_forward.default)
def meta__flash_attention_forward(fake_mode, func, *args, **kwargs):
_, kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
query = kwargs["query"]
key = kwargs["key"]
cum_seq_q = kwargs["cum_seq_q"]
cum_seq_k = kwargs["cum_seq_k"]
max_q = kwargs["max_q"]
max_k = kwargs["max_k"]
return_debug_mask = kwargs["return_debug_mask"]
# unused: value, dropout_p, is_causal, scale
# unused: seqused_k, alibi_slopes, window_size_left, window_size_right
def convert_tensor(t, device):
return FakeTensor(fake_mode, t, device)
# NB: there are two underlying paths:
# 1. normal dense path; expect 4D inputs of shape (batch_size, seqlen, num_heads, head_dim)
# 2. varseqlen path; expect 3D inputs of shape (total, num_heads, head_dim) where total
# includes all batch item sequences. cum_seq_q / cum_seq_k contain offsets into total
batch_size = query.size(0) if cum_seq_q is None else cum_seq_q.numel() - 1
max_seqlen_batch_q = query.size(1) if cum_seq_q is None else max_q
max_seqlen_batch_k = key.size(1) if cum_seq_k is None else max_k
num_heads = query.size(-2)
head_dim = query.size(-1)
# Cuda Path
    # note: empty_like already returns a fake tensor, so we don't need to wrap it
attention = torch.empty_like(query)
logsumexp = convert_tensor(
torch.empty(
(batch_size, num_heads, max_seqlen_batch_q),
dtype=torch.float,
device="meta",
),
device=query.device,
)
if return_debug_mask:
blocksize_c = 128 if head_dim > 64 else 256
max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
if max_seqlen_batch_k <= 128:
max_seqlen_k = 128
elif max_seqlen_batch_k <= 256:
max_seqlen_k = 256
debug_mask = convert_tensor(
torch.empty(
(batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
dtype=query.dtype,
device="meta",
),
query.device,
)
else:
debug_mask = convert_tensor(
torch.empty(0, dtype=query.dtype, device="meta"),
query.device,
)
# See Note [Seed and Offset]:
return (
attention,
logsumexp,
convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
debug_mask,
)
@register_op_impl(aten._efficient_attention_forward.default)
def meta__efficient_attention_forward(fake_mode, func, *args, **kwargs):
_, kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
query = kwargs["query"]
key = kwargs["key"]
value = kwargs["value"]
cu_seqlens_q = kwargs["cu_seqlens_q"]
max_seqlen_q = kwargs["max_seqlen_q"]
max_seqlen_k = kwargs["max_seqlen_k"]
compute_log_sumexp = kwargs["compute_log_sumexp"]
# unused: bias, cu_seqlens_k, dropout_p, custom_mask_type, scale, seqlen_k
def convert_tensor(t, device):
return FakeTensor(fake_mode, t, device)
B = query.size(0)
M = query.size(1)
N = key.size(1)
num_heads = query.size(-2)
K = query.size(-1)
Kv = value.size(-1)
res = convert_tensor(
torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device="meta"),
query.device,
)
logsumexp_batch_dim = cu_seqlens_q.size(0) - 1 if (cu_seqlens_q is not None) else B
actual_max_seqlen_q = M
if cu_seqlens_q is not None:
assert max_seqlen_q is not None
actual_max_seqlen_q = max_seqlen_q
actual_max_seqlen_k = max_seqlen_k if max_seqlen_k is not None else N
logsumexp_dim = (
math.ceil(actual_max_seqlen_q / 32) * 32 if compute_log_sumexp else 0
)
logsum_exp = convert_tensor(
torch.empty(
(logsumexp_batch_dim, num_heads, logsumexp_dim),
dtype=torch.float,
device="meta",
),
query.device,
)
# See Note [Seed and Offset]:
seed = convert_tensor(
torch.empty((), dtype=torch.long, device="meta"), query.device
)
offset = convert_tensor(
torch.empty((), dtype=torch.long, device="meta"), query.device
)
return res, logsum_exp, seed, offset, actual_max_seqlen_q, actual_max_seqlen_k
@register_op_impl(torch.ops.aten._pack_padded_sequence.default)
def _pack_padded_sequence(fake_mode, func, inputs, lengths, batch_first):
if (
fake_mode.shape_env is None
or not fake_mode.shape_env.allow_dynamic_output_shape_ops
):
# Without symints/symfloats, cannot handle this
raise DynamicOutputShapeException(func)
new_batch_size = fake_mode.shape_env.create_unbacked_symint()
from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size
_constrain_range_for_size(new_batch_size)
if not batch_first:
# Inputs should have shape (batch_size, seq_len, *)
inputs = inputs.transpose(0, 1)
res_size = inputs.shape[1:]
packed_data = inputs.new_empty(res_size)
batch_size = inputs.new_empty((new_batch_size,))
return (packed_data, batch_size)
FAST_OP_IMPLEMENTATIONS = {}
# Unlike register_op_impl, these don't do the slow iteration for
# run_impl_check, and these run BEFORE decompositions
def register_fast_op_impl(func: OpOverload):
def impl_decorator(op_impl):
FAST_OP_IMPLEMENTATIONS[func] = op_impl
return op_impl
return impl_decorator
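# A hedged usage sketch (hypothetical caller, not part of this module): fake
# tensor mode consults this table before decompositions, roughly like:
#
#   fast_impl = FAST_OP_IMPLEMENTATIONS.get(func)
#   if fast_impl is not None:
#       return fast_impl(fake_mode, *args, **kwargs)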
# infer_size_impl in ExpandUtils
def infer_size(a, b):
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
dimsA = len(a)
dimsB = len(b)
ndim = max(dimsA, dimsB)
expandedSizes = [0] * ndim
for i in range(ndim - 1, -1, -1):
offset = ndim - 1 - i
dimA = dimsA - 1 - offset
dimB = dimsB - 1 - offset
sizeA = a[dimA] if dimA >= 0 else 1
sizeB = b[dimB] if dimB >= 0 else 1
# NB: It is very important to test for broadcasting, before testing
# sizeA == sizeB. This is because the broadcasting tests are likely
# to be statically known (in particular, if sizeA/sizeB is unbacked
# but size-like, we will unsoundly assume they never equal 1), but
# the sizeA == sizeB test may not be statically known. However, once
# we have established that no broadcasting is happening, the
# sizeA == sizeB is now expect_true and we can defer it as a runtime
# assert (this works because Python will return the terminal
# expression of an or statement as-is, without bool()'ing it; if this
# were not the case, we'd need to write this using torch.sym_or() or
# something like that).
torch._check(
guard_size_oblivious(sizeA == 1)
or guard_size_oblivious(sizeB == 1)
or sizeA == sizeB,
lambda: f"The size of tensor a ({sizeA}) "
f"must match the size of tensor b ({sizeB}) "
f"at non-singleton dimension {i})",
)
expandedSizes[i] = sizeB if guard_size_oblivious(sizeA == 1) else sizeA
return tuple(expandedSizes)
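# Worked examples of the broadcasting rule above (trailing dimensions are
# aligned and a size of 1 expands to the other operand's size):
#   infer_size((2, 1, 4), (3, 1)) -> (2, 3, 4)
#   infer_size((5,), ())          -> (5,)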
def make_fast_binary_impl(slow_ref):
def fast_binary_impl(mode, *args, **kwargs):
def slow(msg):
count_label(f"slow {msg}")
with mode:
return slow_ref(*args, **kwargs)
count_label("attempt fast")
# Fast path (based off of TensorIterator fast path).
# Unfortunately, there is no way to easily deduplicate
# this with either the TensorIterator C++ implementation
# (which we don't want to SymIntify, and also the algorithm
# here is slightly different from TensorIterator to allow
# for broadcasting), nor the PrimTorch implementation
# (which does not actually implement a fast path.)
operands = args
# compute_shape
has_scalars = False
has_tensors = False
final_shape = None
for op in operands:
shape = op.shape if isinstance(op, torch.Tensor) else ()
if len(shape) == 0:
has_scalars = True
else:
has_tensors = True
if final_shape is None:
final_shape = shape
# TODO: Minor optimization: track if the shapes
# were equal so you can skip the equality check
# below if unnecessary
final_shape = infer_size(final_shape, shape)
assert final_shape is not None
# Do some extra safety checks to see if the output
# stride is obvious
for op in operands:
if (
isinstance(op, torch.Tensor)
and len(op.shape) == len(final_shape)
and op.shape == final_shape
):
break
else:
return slow("both tensors nontrivially broadcast")
# compute_types
cpu = torch.device("cpu")
common_device = cpu
common_dtype = None
output_dtype = None
has_different_input_dtypes = False
for op in operands:
if not isinstance(op, torch.Tensor):
# Use elementwise_dtypes for the tricky case
has_different_input_dtypes = True
continue
            if common_device == cpu and op.device.type != "cpu":
common_device = op.device
# Slightly simplified here as target_dtype cannot vary
if common_dtype is None:
common_dtype = op.dtype
elif common_dtype != op.dtype:
has_different_input_dtypes = True
if has_different_input_dtypes:
# compute promotion
# TODO: we don't need the compute type
_, common_dtype = elementwise_dtypes(
*operands, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
)
# check all tensors on same device
        # cpu scalars are assumed to be allowed
current_cpu_scalars_on_non_cpu = 0
max_cpu_scalars_on_non_cpu = 1 # hard coded atm
for op in operands:
if not isinstance(op, torch.Tensor):
continue
if common_device != cpu and op.dim() == 0 and op.device == cpu:
if current_cpu_scalars_on_non_cpu >= max_cpu_scalars_on_non_cpu:
return slow("error")
current_cpu_scalars_on_non_cpu += 1
elif op.device != common_device:
return slow("error")
# compute_fast_setup_type
is_contiguous = True
is_channels_last = True
        # TODO: is_non_overlapping_and_dense (not bound from Python)
# no inplace, no out, everything defined
if is_noncontiguous_supported(common_device):
for op in operands:
if not isinstance(op, torch.Tensor):
continue
is_contiguous = is_contiguous and op.is_contiguous(
memory_format=torch.contiguous_format
)
is_channels_last = is_channels_last and op.is_contiguous(
memory_format=torch.channels_last
)
if is_contiguous:
# do contiguous
count_label("fast is_contiguous")
return FakeTensor(
mode,
torch.empty(
final_shape,
dtype=common_dtype,
device="meta",
memory_format=torch.contiguous_format,
),
device=common_device,
)
if is_channels_last:
count_label("fast channels_last")
# do channels last
return FakeTensor(
mode,
torch.empty(
final_shape,
dtype=common_dtype,
device="meta",
memory_format=torch.channels_last,
),
device=common_device,
)
return slow("no contiguity match")
return fast_binary_impl
@functools.lru_cache(None)
def get_fast_op_impls():
import torch._refs
register_fast_op_impl(torch.ops.aten.add.Tensor)(
make_fast_binary_impl(torch._refs.add)
)
register_fast_op_impl(torch.ops.aten.sub.Tensor)(
make_fast_binary_impl(torch._refs.sub)
)
register_fast_op_impl(torch.ops.aten.mul.Tensor)(make_fast_binary_impl(torch._refs.mul)) # type: ignore[has-type]
register_fast_op_impl(torch.ops.aten.div.Tensor)(
make_fast_binary_impl(torch._refs.div)
)
return FAST_OP_IMPLEMENTATIONS
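# A hedged usage sketch (hypothetical, illustration only): consumers fetch the
# table once (functools.lru_cache makes repeated calls cheap) and then index
# it by overload:
#
#   impls = get_fast_op_impls()
#   fast_add = impls[torch.ops.aten.add.Tensor]
#   # fast_add(fake_mode, fake_a, fake_b) returns a FakeTensor, falling back
#   # to the slow torch._refs.add reference when the fast checks don't apply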