|
|
|
|
|
import functools |
|
import itertools |
|
import math |
|
import sys |
|
from typing import Callable, Union |
|
|
|
import torch |
|
import torch._custom_op |
|
import torch._logging |
|
|
|
from torch._ops import OpOverload |
|
from torch._prims_common import ( |
|
elementwise_dtypes, |
|
ELEMENTWISE_TYPE_PROMOTION_KIND, |
|
is_boolean_dtype, |
|
is_float_dtype, |
|
is_integer_dtype, |
|
) |
|
|
|
from torch._subclasses.fake_tensor import ( |
|
DataDependentOutputException, |
|
DynamicOutputShapeException, |
|
FakeTensor, |
|
in_kernel_invocation_manager, |
|
run_fallback_kernel, |
|
UnsupportedOperatorException, |
|
) |
|
from torch.fx.operator_schemas import normalize_function |
|
|
|
from torch.utils._stats import count_label |
|
|
|
pytree = torch.utils._pytree |
|
|
|
__all__ = [ |
|
"op_implementations_checks", |
|
"get_fast_op_impls", |
|
"stride_incorrect_op", |
|
"has_meta", |
|
] |
|
|
|
op_implementations_dict = {} |
|
op_implementations_checks = [] |
|
|
|
|
|
aten = torch._ops.ops.aten |
|
|
|
|
|
def ordered_set(*items): |
|
return dict.fromkeys(items, True) |
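# e.g. ordered_set(1, 2, 3) == {1: True, 2: True, 3: True}. A dict is used
# instead of a set so that iteration follows deterministic insertion order
# (plain sets do not guarantee ordering).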


# This function indicates whether the backend device
# supports non-contiguous tensors
def is_noncontiguous_supported(device): |
|
if device.type == "hpu": |
|
return False |
|
return True |
|
|
|
|
|
_like_tensor_constructors = ordered_set( |
|
aten.empty_like.default, |
|
aten.empty_like.out, |
|
aten.full_like.default, |
|
aten.full_like.out, |
|
aten.ones_like.default, |
|
aten.ones_like.out, |
|
aten.rand_like.default, |
|
aten.rand_like.out, |
|
aten.randn_like.default, |
|
aten.randn_like.out, |
|
aten.randint_like.default, |
|
aten.randint_like.out, |
|
aten.randint_like.low_dtype, |
|
aten.randint_like.low_dtype_out, |
|
aten.zeros_like.default, |
|
aten.zeros_like.out, |
|
aten.new_empty.default, |
|
aten.new_empty.out, |
|
aten.new_empty_strided.default, |
|
aten.new_empty_strided.out, |
|
aten.new_full.default, |
|
aten.new_full.out, |
|
aten.new_zeros.default, |
|
aten.new_zeros.out, |
|
aten.new_ones.default, |
|
aten.new_ones.out, |
|
) |
|
|
|
|
|
_device_not_kwarg_ops = ordered_set( |
|
aten._resize_output_.default, |
|
aten._nested_tensor_from_tensor_list.default, |
|
aten._nested_tensor_from_tensor_list.out, |
|
aten.pin_memory.default, |
|
aten.is_pinned.default, |
|
aten.to.device, |
|
aten.to.prim_Device, |
|
aten._pin_memory.default, |
|
aten._pin_memory.out, |
|
aten._resize_output.default, |
|
aten._resize_output.out, |
|
) |
|
|
|
|
|
_non_kwarg_device_constructors = (aten._list_to_tensor,) |
|
|
|
|
|
def contains_tensor_types(type): |
|
tensor_type = torch._C.TensorType.get() |
|
return type.isSubtypeOf(tensor_type) or any( |
|
contains_tensor_types(e) for e in type.containedTypes() |
|
) |
|
|
|
|
|
@functools.lru_cache(None) |
|
def _is_tensor_constructor(func: OpOverload): |
|
assert isinstance(func, OpOverload) |
|
schema = func._schema |
|
if any(contains_tensor_types(arg.type) for arg in schema.arguments): |
|
return False |
|
|
|
return ( |
|
len(schema.returns) == 1 and schema.returns[0].type is torch._C.TensorType.get() |
|
) |
|
|
|
|
|
def register_op_impl(run_impl_check: Union[Callable[[OpOverload], bool], OpOverload]): |
|
def impl_decorator(op_impl): |
|
if isinstance(run_impl_check, OpOverload): |
|
assert ( |
|
run_impl_check not in op_implementations_dict |
|
), f"duplicate registration: {run_impl_check}" |
|
op_implementations_dict[run_impl_check] = op_impl |
|
elif isinstance(run_impl_check, (list, tuple)): |
|
for op in run_impl_check: |
|
register_op_impl(op)(op_impl) |
|
else: |
|
assert callable(run_impl_check) |
|
op_implementations_checks.append((run_impl_check, op_impl)) |
|
|
|
return op_impl |
|
|
|
return impl_decorator |
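# Example registrations (hypothetical op, for illustration only):
#
#   @register_op_impl(aten.my_op.default)
#   def my_op_impl(fake_mode, func, *args, **kwargs): ...
#
#   @register_op_impl(lambda func: func.namespace == "aten")
#   def aten_predicate_impl(fake_mode, func, *args, **kwargs): ...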
|
|
|
|
|
@register_op_impl(op_implementations_dict.__contains__) |
|
def dispatch_to_op_implementations_dict(fake_mode, func, *args, **kwargs): |
|
return op_implementations_dict[func](fake_mode, func, *args, **kwargs) |
|
|
|
|
|
@register_op_impl(_is_tensor_constructor) |
|
@register_op_impl([*_like_tensor_constructors]) |
|
def constructors(fake_mode, func, *args, **kwargs): |
|
assert func not in _non_kwarg_device_constructors |
|
_, new_kwargs = normalize_function( |
|
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
|
) |
|
if "names" in kwargs: |
|
raise UnsupportedOperatorException( |
|
"torch.compile doesn't support named tensors" |
|
) |
|
|
|
    if func in _like_tensor_constructors:
        default_device = new_kwargs["input"].device
        # TODO: file issue
        args = (new_kwargs.pop("input"),)
    else:
        # cpu is the default device if none is specified
        default_device = torch.device("cpu")
        args = ()
    out_device = new_kwargs.pop("device", None)
    out_device = out_device if out_device is not None else default_device
    new_kwargs["device"] = torch.device("meta")

    # run the constructor on the meta device, then rewrap the result as a
    # FakeTensor carrying the real output device
    with in_kernel_invocation_manager(fake_mode):
        r = func(*args, **new_kwargs)
|
return FakeTensor(fake_mode, r, out_device) |
|
|
|
|
|
@register_op_impl(aten.to.prim_Device) |
|
@register_op_impl(aten.to.device) |
|
def non_kwarg_to(fake_mode, func, *args, **kwargs): |
|
_, new_kwargs = normalize_function( |
|
func, args, kwargs, normalize_to_only_use_kwargs=True |
|
) |
|
input_device = new_kwargs["device"] |
|
out_device = input_device if input_device else new_kwargs["input"].device |
|
new_kwargs["device"] = torch.device("meta") |
|
inp = new_kwargs.pop("input") |
|
with in_kernel_invocation_manager(fake_mode): |
|
r = func(inp, **new_kwargs) |
|
|
|
return fake_mode.fake_tensor_converter.from_meta_and_device( |
|
fake_mode, r, out_device |
|
) |
|
|
|
|
|
def stride_incorrect_op(op): |
|
if op.namespace not in ("aten", "prims"): |
|
return False |
|
if op is aten._fft_c2c.default: |
|
return False |
|
|
|
op_name = op.name() |
|
if "fft" in op_name: |
|
return True |
|
return False |
|
|
|
|
|
|
|
# These operators have meta implementations with incorrect strides
@register_op_impl(stride_incorrect_op)
def workaround_stride_incorrect_op(fake_mode, func, *args, **kwargs):
    # This is a workaround for meta implementations with incorrect strides

    def is_symbolic(x):
        if isinstance(x, FakeTensor):
            return x._has_symbolic_sizes_strides
        if isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool)):
            return True
        return False
|
|
|
|
|
    # The fallback kernels only work with concrete (static) shapes, so they
    # can only be used when no symbolic values are involved.
    if fake_mode.allow_fallback_kernels:
        require_dynamic = any(
            is_symbolic(x) for x in itertools.chain(args, kwargs.values())
        )
        if not require_dynamic:
            flat_args, args_spec = pytree.tree_flatten((args, kwargs))
            return run_fallback_kernel(fake_mode, func, flat_args, args_spec, None)

    raise UnsupportedOperatorException(func)
|
|
|
|
|
|
|
|
|
# Don't default to the common device handling,
# since the device of `the_template` is ignored
@register_op_impl(aten.resize_as_.default)
def resize_as_(fake_mode, func, *args, **kwargs):
    with in_kernel_invocation_manager(fake_mode):
        return func(*args, **kwargs)
|
|
|
|
|
@register_op_impl(aten._sparse_coo_tensor_with_dims_and_tensors.default) |
|
def _sparse_coo_tensor_with_dims_and_tensors(fake_mode, func, *args, **kwargs):
    # This is a constructor, so route it through the generic constructor path.
    return constructors(fake_mode, func, *args, **kwargs)
|
|
|
|
|
|
|
# index.Tensor is data-dependent in only some conditions, and nonzero /
# repeat_interleave.Tensor have dedicated implementations below
@register_op_impl(
    lambda func: torch.Tag.dynamic_output_shape in func.tags
    and func
    not in [aten.index.Tensor, aten.nonzero.default, aten.repeat_interleave.Tensor]
)
|
def dyn_shape(fake_mode, func, *args, **kwargs): |
|
raise DynamicOutputShapeException(func) |
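# These ops have output sizes that depend on input *values*, which a fake
# tensor cannot know; raising here lets callers fall back (e.g. take a graph
# break) instead of guessing a shape.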
|
|
|
|
|
def _unique( |
|
fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False |
|
): |
|
    if (
        fake_mode.shape_env is None
        or not fake_mode.shape_env.allow_dynamic_output_shape_ops
    ):
        # Without symints/symfloats, cannot handle this
        raise DynamicOutputShapeException(func)

    # Do not use a memo for unique_dim
    if dim is not None or (nnz := arg.unique_memo) is None:
        # Avoid importing sympy at a module level
        from torch.fx.experimental.symbolic_shapes import (
            _constrain_range_for_size,
            has_free_symbols,
        )

        if not has_free_symbols(arg.numel()) and arg.numel() == 0:
            # If numel is zero, then the output size must be zero.
            # In this case, we must not allocate an unbacked SymInt,
            # because if we do, it will immediately get refined to
            # zero, but this will be inconsistent with size oblivious
            # tests (which will continue to claim that the unbacked
            # symint cannot equal zero).  We could also unconditionally
            # allocate an unbacked SymInt and not refine its range,
            # but this seems more precise.
            nnz = 0
        else:
            nnz = fake_mode.shape_env.create_unbacked_symint()

            maxval = sys.maxsize - 1

            numel = arg.numel() if dim is None else arg.size(dim)
            if not has_free_symbols(numel):
                maxval = int(numel)

            _constrain_range_for_size(nnz, max=maxval)

        if dim is None:
            arg.unique_memo = nnz
|
|
|
if dim is None: |
|
ret = [arg.new_empty((nnz,))] |
|
else: |
|
ret = [arg.new_empty(*arg.shape[:dim], nnz, *arg.shape[dim + 1 :])] |
|
|
|
return_if_dim_and_cpu = dim is not None and arg.fake_device == torch.device("cpu") |
|
if return_inverse or return_if_dim_and_cpu: |
|
inverse = arg.new_empty(arg.shape if dim is None else (arg.shape[dim],)) |
|
else: |
|
inverse = arg.new_empty(0) |
|
ret.append(inverse) |
|
|
|
if return_counts or return_if_dim_and_cpu: |
|
counts = arg.new_empty(ret[0].shape if dim is None else (ret[0].shape[dim],)) |
|
else: |
|
counts = arg.new_empty(0) |
|
ret.append(counts) |
|
|
|
return tuple(ret) |
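# The tuple above mirrors (values, inverse_indices, counts); outputs that are
# not needed are still present, just allocated with zero elements.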
|
|
|
|
|
@register_op_impl(aten._unique2.default) |
|
def unique2( |
|
fake_mode, func, arg, sorted=True, return_inverse=False, return_counts=False |
|
): |
|
return _unique(fake_mode, func, arg, None, sorted, return_inverse, return_counts) |
|
|
|
|
|
@register_op_impl(aten.unique_dim.default) |
|
def unique_dim( |
|
fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False |
|
): |
|
    return _unique(
        fake_mode,
        func,
        arg,
        # normalize dim to be non-negative
        dim if dim >= 0 else dim % max(arg.ndim, 1),
        sorted,
        return_inverse,
        return_counts,
    )
|
|
|
|
|
@register_op_impl(aten.repeat_interleave.Tensor) |
|
def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None): |
|
if output_size is None: |
|
if ( |
|
fake_mode.shape_env is None |
|
or not fake_mode.shape_env.allow_dynamic_output_shape_ops |
|
): |
|
raise DynamicOutputShapeException(func) |
|
|
|
        output_size = fake_mode.shape_env.create_unbacked_symint()

        # Avoid importing sympy at a module level
        from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size

        _constrain_range_for_size(output_size)
|
|
|
return repeats.new_empty(output_size) |
|
|
|
|
|
@register_op_impl(torch.ops.aten._local_scalar_dense.default) |
|
def local_scalar_dense(fake_mode, func, arg): |
|
if (r := arg.item_memo) is not None: |
|
return r |
|
    if fake_mode.shape_env is None or (
        not fake_mode.shape_env.allow_scalar_outputs
        and not fake_mode.allow_scalar_outputs
    ):
        # Without symints/symfloats, cannot handle this
        raise DataDependentOutputException(func)
|
if is_float_dtype(arg.dtype): |
|
r = fake_mode.shape_env.create_unbacked_symfloat() |
|
elif is_integer_dtype(arg.dtype): |
|
r = fake_mode.shape_env.create_unbacked_symint() |
|
elif is_boolean_dtype(arg.dtype): |
|
r = fake_mode.shape_env.create_unbacked_symbool() |
|
else: |
|
raise NotImplementedError(f"local_scalar_dense/item NYI for {arg.dtype}") |
|
arg.item_memo = r |
|
return r |
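# NB: the item_memo ensures that repeated .item() calls on the same fake
# tensor return the same unbacked symbol rather than allocating a fresh one
# each time.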
|
|
|
|
|
@register_op_impl(torch.ops.aten.nonzero.default) |
|
def nonzero(fake_mode, func, arg): |
|
    if (
        fake_mode.shape_env is None
        or not fake_mode.shape_env.allow_dynamic_output_shape_ops
    ):
        # Without symints/symfloats, cannot handle this
        raise DynamicOutputShapeException(func)

    if (nnz := arg.nonzero_memo) is None:
        # Avoid importing sympy at a module level
        from torch.fx.experimental.symbolic_shapes import (
            _constrain_range_for_size,
            has_free_symbols,
        )

        if not has_free_symbols(arg.numel()) and arg.numel() == 0:
            # If numel is zero, then the output size must be zero.
            # In this case, we must not allocate an unbacked SymInt,
            # because if we do, it will immediately get refined to
            # zero, but this will be inconsistent with size oblivious
            # tests (which will continue to claim that the unbacked
            # symint cannot equal zero).  We could also unconditionally
            # allocate an unbacked SymInt and not refine its range,
            # but this seems more precise.
            nnz = 0
        else:
            nnz = fake_mode.shape_env.create_unbacked_symint()

            maxval = sys.maxsize - 1

            if not has_free_symbols(arg.numel()):
                maxval = int(arg.numel())

            _constrain_range_for_size(nnz, max=maxval)

        arg.nonzero_memo = nnz
|
|
|
return arg.new_empty((nnz, arg.dim()), dtype=torch.int64) |
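# e.g. a fake input of shape (B, 3) produces a fake output of shape (u0, 2),
# where u0 is the fresh unbacked symint allocated above.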
|
|
|
|
|
@register_op_impl(torch.ops.aten.masked_select.default) |
|
def masked_select(fake_mode, func, self, mask): |
|
    if (
        fake_mode.shape_env is None
        or not fake_mode.shape_env.allow_dynamic_output_shape_ops
    ):
        # Without symints/symfloats, cannot handle this
        raise DynamicOutputShapeException(func)

    nnz = fake_mode.shape_env.create_unbacked_symint()

    # see nonzero above for commentary on the bound
    maxval = sys.maxsize - 1

    # Avoid importing sympy at a module level
    from torch.fx.experimental.symbolic_shapes import (
        _constrain_range_for_size,
        has_free_symbols,
    )

    if not has_free_symbols(self.numel()):
        if self.numel() > 2:
            maxval = int(self.numel())

    _constrain_range_for_size(nnz, max=maxval)

    return self.new_empty((nnz,))
|
|
|
|
|
|
|
# NB: this must be ordered after local_scalar_dense
@register_op_impl(lambda func: torch.Tag.data_dependent_output in func.tags)
def data_dep(fake_mode, func, *args, **kwargs):
    raise DataDependentOutputException(func)
|
|
|
|
|
|
|
|
|
# Bool indices get expanded as masks,
# so the first dim of the fake result would be incorrect
def check_no_bool_index_tensors(func, self, indices):
    for index in indices:
        if index is not None and index.dtype in (torch.bool, torch.uint8):
            raise DynamicOutputShapeException(func)
|
|
|
|
|
def run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs): |
|
_, new_kwargs = normalize_function( |
|
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
|
) |
|
|
|
out_device = new_kwargs["input"].device |
|
with in_kernel_invocation_manager(fake_mode): |
|
out = func(*args, **kwargs) |
|
if not is_noncontiguous_supported(out_device): |
|
out = out.new_empty(out.shape) |
|
|
|
    if out is new_kwargs["input"]:
        return out  # copy_
|
return FakeTensor(fake_mode, out, out_device) |
|
|
|
|
|
_is_builtin_namespaces = ordered_set("aten", "prims", "prim") |
|
|
|
|
|
def is_builtin(op): |
|
return op.namespace in _is_builtin_namespaces |
|
|
|
|
|
def has_meta(func): |
|
return torch._C._dispatch_has_computed_kernel_for_dispatch_key(func.name(), "Meta") |
|
|
|
|
|
@register_op_impl( |
|
lambda func: is_builtin(func) and "foreach" in func.name() and has_meta(func) |
|
) |
|
def foreach_run_and_map_input_device(fake_mode, func, *args, **kwargs): |
|
tensor_lists = [] |
|
for arg in itertools.chain(args, kwargs.values()): |
|
if ( |
|
isinstance(arg, (list, tuple)) |
|
and len(arg) |
|
and isinstance(arg[0], torch.Tensor) |
|
): |
|
tensor_lists.append(arg) |
|
|
|
    try:
        with in_kernel_invocation_manager(fake_mode):
            out_meta = func(*args, **kwargs)
    except NotImplementedError:
        return NotImplemented
|
|
|
if not out_meta: |
|
return out_meta |
|
|
|
assert tensor_lists |
|
out_fake = [] |
|
|
|
for i, meta_t in enumerate(out_meta): |
|
device, _ = FakeTensor._find_common_device(func, [tl[i] for tl in tensor_lists]) |
|
out_fake.append( |
|
fake_mode.fake_tensor_converter.from_meta_and_device( |
|
fake_mode, meta_t, device |
|
) |
|
) |
|
|
|
return out_fake |
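# Each output's device is recovered from the i-th element of the input tensor
# lists, since the meta kernel itself only produces device-less meta tensors.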
|
|
|
|
|
|
|
|
|
|
|
@register_op_impl(aten.index.Tensor) |
|
def index_tensor(fake_mode, func, *args, **kwargs): |
|
from torch._meta_registrations import meta_index_Tensor |
|
|
|
_, new_kwargs = normalize_function( |
|
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
|
) |
|
|
|
    out_device = new_kwargs["input"].device
    # ensure the nonzero call goes to the fake tensor
    with fake_mode:
        out = meta_index_Tensor(*args, **kwargs)
        return out.to(out_device)
|
|
|
|
|
|
|
|
|
@register_op_impl(aten._embedding_bag.default)
def embedding_bag(fake_mode, func, *args, **kwargs):
    from torch._meta_registrations import meta_embedding_bag

    # run the meta registration under the fake mode so inputs and outputs
    # are handled as fake tensors
    with fake_mode:
        return meta_embedding_bag(*args, **kwargs)
|
|
|
|
|
|
|
# takes in multiple devices, don't default to the common device handling
@register_op_impl(aten._unsafe_index_put.default)
@register_op_impl(aten.copy.default)
@register_op_impl(aten.copy_.default)
@register_op_impl(aten.slice_scatter.default)
def multi_device_op_default(fake_mode, func, *args, **kwargs):
    return run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs)
|
|
|
|
|
|
|
# same as multi_device_op_default, but returns the input (out= variants)
@register_op_impl(aten.copy.out)
@register_op_impl(aten.slice_scatter.out)
def multi_device_op_out(fake_mode, func, *args, **kwargs):
|
with in_kernel_invocation_manager(fake_mode): |
|
out = func(*args, **kwargs) |
|
|
|
_, new_kwargs = normalize_function( |
|
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
|
) |
|
|
|
return new_kwargs["input"] |
|
|
|
|
|
@register_op_impl(aten.index_put.default) |
|
@register_op_impl(aten.index_put_.default) |
|
def index_put_impl(fake_mode, func, *args, **kwargs): |
|
_, new_kwargs = normalize_function( |
|
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
|
) |
|
|
|
values = new_kwargs["values"] |
|
self_device = new_kwargs["input"].fake_device |
|
    torch._check(
        self_device == values.fake_device or (values.ndim == 0 and values.numel() == 1),
        lambda: f"Mismatching {func} device between self ({self_device}) and values ({values.fake_device})",
    )
|
|
|
out = run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs) |
|
if func is aten.index_put_.default: |
|
return new_kwargs["input"] |
|
else: |
|
return out |
|
|
|
|
|
@register_op_impl(aten._nested_tensor_from_tensor_list.default) |
|
@register_op_impl(aten._nested_tensor_from_tensor_list.out) |
|
@register_op_impl(aten._nested_view_from_buffer.default) |
|
@register_op_impl(aten._nested_view_from_buffer_copy.default) |
|
def nested_tensors_unsupported(fake_mode, func, *args, **kwargs): |
|
raise UnsupportedOperatorException( |
|
"torch.compile does not support strided NestedTensor" |
|
) |
|
|
|
|
|
@register_op_impl( |
|
[ |
|
x |
|
for x in _device_not_kwarg_ops |
|
if x |
|
        not in (
            # these are handled in other registrations above
            aten.to.device,
            aten.to.prim_Device,
|
aten._nested_tensor_from_tensor_list.default, |
|
aten._nested_tensor_from_tensor_list.out, |
|
) |
|
] |
|
) |
|
def nyi(fake_mode, func, *args, **kwargs): |
|
assert func not in _device_not_kwarg_ops, f"NYI: {func}" |
|
|
|
|
|
@register_op_impl([aten.convolution.default, aten.convolution_backward.default]) |
|
def conv(fake_mode, func, *args, **kwargs): |
|
_, kwargs = normalize_function( |
|
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
|
) |
|
    device = kwargs["input"].fake_device

    # need to re-enable the mode so the tensors report the fake device
    with fake_mode:
        # the memory format of the output depends on the selected conv
        # backend, which is queried below
        k = kwargs["weight"].ndim
        batch = kwargs["input"].shape[0]

        # Avoid importing sympy at a module level
        from torch.fx.experimental.symbolic_shapes import has_hint

        if not has_hint(batch):
            # TODO: We can make this a little more faithful with best effort
            # channels last detection (but only if it's statically obvious!)
            mem_fmt = None
        elif k == 3 and not kwargs["input"].is_mkldnn and not kwargs["input"].is_xpu:
            mem_fmt = None
|
else: |
|
if func is aten.convolution.default: |
|
conv_backend = torch._C._select_conv_backend(**kwargs) |
|
else: |
|
conv_backend = torch._C._select_conv_backend( |
|
kwargs["input"], |
|
kwargs["weight"], |
|
bias=None, |
|
stride=kwargs["stride"], |
|
padding=kwargs["padding"], |
|
dilation=kwargs["dilation"], |
|
transposed=kwargs["transposed"], |
|
output_padding=kwargs["output_padding"], |
|
groups=kwargs["groups"], |
|
bias_sizes=kwargs["bias_sizes"], |
|
) |
|
mem_fmt = torch._C._conv_determine_backend_memory_format( |
|
kwargs["input"], kwargs["weight"], conv_backend |
|
) |
|
|
|
def convert(t, mem_fmt): |
|
if t is None: |
|
return t |
|
if mem_fmt is not None: |
|
t = t.to(memory_format=mem_fmt) |
|
return FakeTensor(fake_mode, t, device) |
|
|
|
with in_kernel_invocation_manager(fake_mode): |
|
out = func(**kwargs) |
|
|
|
if func is aten.convolution.default: |
|
return convert(out, mem_fmt) |
|
else: |
|
return ( |
|
convert(out[0], mem_fmt), |
|
convert(out[1], mem_fmt), |
|
convert(out[2], None), |
|
) |
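        # NB: out[2] above is the bias gradient, which is returned without
        # imposing a memory format (hence the None).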
|
|
|
|
|
@register_op_impl(aten._scaled_dot_product_flash_attention.default) |
|
def meta__scaled_dot_product_flash(fake_mode, func, *args, **kwargs): |
|
_, kwargs = normalize_function( |
|
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
|
) |
|
|
|
    query = kwargs["query"]
    key = kwargs["key"]
    return_debug_mask = kwargs["return_debug_mask"]
    # unused here: value, dropout_p, is_causal, scale

    def convert_tensor(t, device):
        return FakeTensor(fake_mode, t, device)
|
|
|
batch_size = query.size(0) |
|
num_heads = query.size(1) |
|
max_seqlen_batch_q = query.size(2) |
|
head_dim = query.size(3) |
|
max_seqlen_batch_k = key.size(2) |
|
|
|
    query_t = query.transpose(1, 2)

    # empty_like already returns a fake tensor, so no conversion is needed
    attention = torch.empty_like(query_t).transpose(1, 2)
|
logsumexp = convert_tensor( |
|
torch.empty( |
|
(batch_size, num_heads, max_seqlen_batch_q), |
|
dtype=torch.float, |
|
device="meta", |
|
), |
|
device=query.device, |
|
) |
|
|
|
if return_debug_mask: |
|
blocksize_c = 128 if head_dim > 64 else 256 |
|
max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c) |
|
if max_seqlen_batch_k <= 128: |
|
max_seqlen_k = 128 |
|
elif max_seqlen_batch_k <= 256: |
|
max_seqlen_k = 256 |
|
debug_mask = convert_tensor( |
|
torch.empty( |
|
(batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k), |
|
dtype=query.dtype, |
|
device="meta", |
|
), |
|
device=query.device, |
|
) |
|
else: |
|
debug_mask = convert_tensor( |
|
torch.empty(0, dtype=query.dtype, device="meta"), |
|
query.device, |
|
) |
|
|
    # Note [Seed and Offset]: the device of the seed and offset tensors below
    # depends on whether cudagraphs are in use, which is not known at trace
    # time, so empty meta tensors are returned here; downstream consumers
    # (e.g. inductor) may need special handling for them.
    return (
|
attention, |
|
logsumexp, |
|
None, |
|
None, |
|
max_seqlen_batch_q, |
|
max_seqlen_batch_k, |
|
convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device), |
|
convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device), |
|
debug_mask, |
|
) |
|
|
|
|
|
@register_op_impl(aten._scaled_dot_product_efficient_attention.default) |
|
def meta__scaled_dot_product_efficient(fake_mode, func, *args, **kwargs): |
|
_, kwargs = normalize_function( |
|
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
|
) |
|
|
|
query = kwargs["query"] |
|
key = kwargs["key"] |
|
value = kwargs["value"] |
|
    compute_log_sumexp = kwargs["compute_log_sumexp"]
    # unused here: attn_bias, dropout_p, is_causal, scale

    def convert_tensor(t, device):
        return FakeTensor(fake_mode, t, device)
|
|
|
query = query.transpose(1, 2) |
|
key = key.transpose(1, 2) |
|
value = value.transpose(1, 2) |
|
|
|
B = query.size(0) |
|
M = query.size(1) |
|
N = key.size(1) |
|
num_heads = query.size(-2) |
|
K = query.size(-1) |
|
Kv = value.size(-1) |
|
|
|
res = convert_tensor( |
|
torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device="meta"), |
|
query.device, |
|
) |
|
|
|
logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0 |
|
logsum_exp = convert_tensor( |
|
torch.empty( |
|
(B, num_heads, logsumexp_dim), |
|
dtype=torch.float, |
|
device="meta", |
|
), |
|
query.device, |
|
) |
|
|
|
    res = res.transpose(1, 2)

    # See Note [Seed and Offset]
    seed = convert_tensor(
|
torch.empty((), dtype=torch.long, device="meta"), query.device |
|
) |
|
offset = convert_tensor( |
|
torch.empty((), dtype=torch.long, device="meta"), query.device |
|
) |
|
|
|
return res, logsum_exp, seed, offset |
|
|
|
|
|
@register_op_impl(aten._flash_attention_forward.default) |
|
def meta__flash_attention_forward(fake_mode, func, *args, **kwargs): |
|
_, kwargs = normalize_function( |
|
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
|
) |
|
|
|
query = kwargs["query"] |
|
key = kwargs["key"] |
|
cum_seq_q = kwargs["cum_seq_q"] |
|
cum_seq_k = kwargs["cum_seq_k"] |
|
max_q = kwargs["max_q"] |
|
max_k = kwargs["max_k"] |
|
    return_debug_mask = kwargs["return_debug_mask"]
    # unused here: value, dropout_p, is_causal, scale

    def convert_tensor(t, device):
        return FakeTensor(fake_mode, t, device)
|
|
|
|
|
|
|
|
|
|
|
    # NB: there are two underlying paths:
    # 1. normal dense attention (cum_seq_q/cum_seq_k are None)
    # 2. varseqlen attention, with variable sequence lengths
    batch_size = query.size(0) if cum_seq_q is None else cum_seq_q.numel() - 1
    max_seqlen_batch_q = query.size(1) if cum_seq_q is None else max_q
    max_seqlen_batch_k = key.size(1) if cum_seq_k is None else max_k
    num_heads = query.size(-2)
    head_dim = query.size(-1)

    # empty_like already returns a fake tensor, so no conversion is needed
    attention = torch.empty_like(query)
|
logsumexp = convert_tensor( |
|
torch.empty( |
|
(batch_size, num_heads, max_seqlen_batch_q), |
|
dtype=torch.float, |
|
device="meta", |
|
), |
|
device=query.device, |
|
) |
|
|
|
if return_debug_mask: |
|
blocksize_c = 128 if head_dim > 64 else 256 |
|
max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c) |
|
if max_seqlen_batch_k <= 128: |
|
max_seqlen_k = 128 |
|
elif max_seqlen_batch_k <= 256: |
|
max_seqlen_k = 256 |
|
debug_mask = convert_tensor( |
|
torch.empty( |
|
(batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k), |
|
dtype=query.dtype, |
|
device="meta", |
|
), |
|
query.device, |
|
) |
|
else: |
|
debug_mask = convert_tensor( |
|
torch.empty(0, dtype=query.dtype, device="meta"), |
|
query.device, |
|
) |
|
|
    # See Note [Seed and Offset]
    return (
|
attention, |
|
logsumexp, |
|
convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device), |
|
convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device), |
|
debug_mask, |
|
) |
|
|
|
|
|
@register_op_impl(aten._efficient_attention_forward.default) |
|
def meta__efficient_attention_forward(fake_mode, func, *args, **kwargs): |
|
_, kwargs = normalize_function( |
|
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
|
) |
|
|
|
query = kwargs["query"] |
|
key = kwargs["key"] |
|
value = kwargs["value"] |
|
cu_seqlens_q = kwargs["cu_seqlens_q"] |
|
max_seqlen_q = kwargs["max_seqlen_q"] |
|
max_seqlen_k = kwargs["max_seqlen_k"] |
|
    compute_log_sumexp = kwargs["compute_log_sumexp"]
    # remaining kwargs (bias, cu_seqlens_k, dropout_p, etc.) only affect
    # values, not the shapes computed here

    def convert_tensor(t, device):
        return FakeTensor(fake_mode, t, device)
|
|
|
B = query.size(0) |
|
M = query.size(1) |
|
N = key.size(1) |
|
num_heads = query.size(-2) |
|
K = query.size(-1) |
|
Kv = value.size(-1) |
|
|
|
res = convert_tensor( |
|
torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device="meta"), |
|
query.device, |
|
) |
|
|
|
logsumexp_batch_dim = cu_seqlens_q.size(0) - 1 if (cu_seqlens_q is not None) else B |
|
actual_max_seqlen_q = M |
|
if cu_seqlens_q is not None: |
|
assert max_seqlen_q is not None |
|
actual_max_seqlen_q = max_seqlen_q |
|
actual_max_seqlen_k = max_seqlen_k if max_seqlen_k is not None else N |
|
logsumexp_dim = ( |
|
math.ceil(actual_max_seqlen_q / 32) * 32 if compute_log_sumexp else 0 |
|
) |
|
logsum_exp = convert_tensor( |
|
torch.empty( |
|
(logsumexp_batch_dim, num_heads, logsumexp_dim), |
|
dtype=torch.float, |
|
device="meta", |
|
), |
|
query.device, |
|
) |
|
|
|
    # See Note [Seed and Offset]
    seed = convert_tensor(
|
torch.empty((), dtype=torch.long, device="meta"), query.device |
|
) |
|
offset = convert_tensor( |
|
torch.empty((), dtype=torch.long, device="meta"), query.device |
|
) |
|
|
|
return res, logsum_exp, seed, offset, actual_max_seqlen_q, actual_max_seqlen_k |
|
|
|
|
|
@register_op_impl(torch.ops.aten._pack_padded_sequence.default) |
|
def _pack_padded_sequence(fake_mode, func, inputs, lengths, batch_first): |
|
    if (
        fake_mode.shape_env is None
        or not fake_mode.shape_env.allow_dynamic_output_shape_ops
    ):
        # Without symints/symfloats, cannot handle this
        raise DynamicOutputShapeException(func)

    new_batch_size = fake_mode.shape_env.create_unbacked_symint()

    # Avoid importing sympy at a module level
    from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size

    _constrain_range_for_size(new_batch_size)
|
|
|
    if not batch_first:
        # inputs should have shape (batch_size, seq_len, ...)
        inputs = inputs.transpose(0, 1)
|
|
|
res_size = inputs.shape[1:] |
|
packed_data = inputs.new_empty(res_size) |
|
batch_size = inputs.new_empty((new_batch_size,)) |
|
return (packed_data, batch_size) |
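# The length of batch_size depends on the actual values in `lengths`, which a
# fake tensor cannot see; hence the unbacked symint allocated for it above.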
|
|
|
|
|
FAST_OP_IMPLEMENTATIONS = {} |
|

# Unlike register_op_impl, these don't do the slow iteration for
# run_impl_check, and these run BEFORE decompositions
def register_fast_op_impl(func: OpOverload):
|
def impl_decorator(op_impl): |
|
FAST_OP_IMPLEMENTATIONS[func] = op_impl |
|
return op_impl |
|
|
|
return impl_decorator |
|
|
|
|
|
|
|
def infer_size(a, b): |
|
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious |
|
|
|
dimsA = len(a) |
|
dimsB = len(b) |
|
ndim = max(dimsA, dimsB) |
|
expandedSizes = [0] * ndim |
|
for i in range(ndim - 1, -1, -1): |
|
offset = ndim - 1 - i |
|
dimA = dimsA - 1 - offset |
|
dimB = dimsB - 1 - offset |
|
        sizeA = a[dimA] if dimA >= 0 else 1
        sizeB = b[dimB] if dimB >= 0 else 1

        # NB: It is very important to test for broadcasting before testing
        # sizeA == sizeB.  This is because the broadcasting tests are likely
        # to be statically known (in particular, if sizeA/sizeB is unbacked
        # but size-like, we will unsoundly assume they never equal 1), but
        # the sizeA == sizeB test may not be statically known.  However, once
        # we have established that no broadcasting is happening, the
        # sizeA == sizeB check becomes expect_true and can be deferred as a
        # runtime assert (this works because Python returns the terminal
        # expression of an `or` statement as-is, without bool()'ing it; if
        # this were not the case, we'd need something like torch.sym_or()).
        torch._check(
            guard_size_oblivious(sizeA == 1)
            or guard_size_oblivious(sizeB == 1)
            or sizeA == sizeB,
            lambda: f"The size of tensor a ({sizeA}) "
            f"must match the size of tensor b ({sizeB}) "
            f"at non-singleton dimension {i}",
        )
        expandedSizes[i] = sizeB if guard_size_oblivious(sizeA == 1) else sizeA
|
return tuple(expandedSizes) |
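# e.g. infer_size((2, 1, 4), (3, 1)) == (2, 3, 4), matching standard
# PyTorch/NumPy broadcasting semantics.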
|
|
|
|
|
def make_fast_binary_impl(slow_ref): |
|
def fast_binary_impl(mode, *args, **kwargs): |
|
def slow(msg): |
|
count_label(f"slow {msg}") |
|
with mode: |
|
return slow_ref(*args, **kwargs) |
|
|
|
        count_label("attempt fast")

        # Fast path (based off of TensorIterator's fast path).
        # Unfortunately there is no easy way to deduplicate this with either
        # the TensorIterator C++ implementation (which we don't have access
        # to) or the refs implementation (which does not do the dtype
        # promotion computation ahead of time, as it reuses the prim
        # implementation).

        operands = args

        # compute the broadcasted output shape
        has_scalars = False
        has_tensors = False
        final_shape = None
|
for op in operands: |
|
shape = op.shape if isinstance(op, torch.Tensor) else () |
|
if len(shape) == 0: |
|
has_scalars = True |
|
else: |
|
has_tensors = True |
|
                if final_shape is None:
                    final_shape = shape
                # TODO: minor optimization: we could track whether all shapes
                # were equal and skip the equality check below if so
                final_shape = infer_size(final_shape, shape)
        assert final_shape is not None

        # Do some extra safety checks to see if the output
        # stride is obvious
        for op in operands:
|
if ( |
|
isinstance(op, torch.Tensor) |
|
and len(op.shape) == len(final_shape) |
|
and op.shape == final_shape |
|
): |
|
break |
|
        else:
            return slow("both tensors nontrivially broadcast")

        # compute dtype
        cpu = torch.device("cpu")
        common_device = cpu
        common_dtype = None
        has_different_input_dtypes = False
|
        for op in operands:
            if not isinstance(op, torch.Tensor):
                # defer to elementwise_dtypes below for the tricky cases
                has_different_input_dtypes = True
                continue
            if common_device == cpu and op.device.type != "cpu":
                common_device = op.device

            if common_dtype is None:
                common_dtype = op.dtype
            elif common_dtype != op.dtype:
                has_different_input_dtypes = True
|
|
|
        if has_different_input_dtypes:
            # compute promotion
            # TODO: we don't need the compute type
            _, common_dtype = elementwise_dtypes(
                *operands, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
            )
|
|
|
|
|
|
|
        # check that all tensors are on the same device
        # (cpu scalars are assumed to be allowed)
        current_cpu_scalars_on_non_cpu = 0
        max_cpu_scalars_on_non_cpu = 1
|
for op in operands: |
|
if not isinstance(op, torch.Tensor): |
|
continue |
|
            if common_device != cpu and op.dim() == 0 and op.device == cpu:
                if current_cpu_scalars_on_non_cpu >= max_cpu_scalars_on_non_cpu:
                    return slow("error: too many cpu scalars on non-cpu device")
                current_cpu_scalars_on_non_cpu += 1
            elif op.device != common_device:
                return slow("error: mismatched devices")
|
|
|
|
|
        # compute memory format
        is_contiguous = True
        is_channels_last = True
        # TODO: is_non_overlapping_and_dense (not bound from Python);
        # no inplace, no out, everything defined

        if is_noncontiguous_supported(common_device):
|
for op in operands: |
|
if not isinstance(op, torch.Tensor): |
|
continue |
|
is_contiguous = is_contiguous and op.is_contiguous( |
|
memory_format=torch.contiguous_format |
|
) |
|
is_channels_last = is_channels_last and op.is_contiguous( |
|
memory_format=torch.channels_last |
|
) |
|
            if is_contiguous:
                # do contiguous
                count_label("fast is_contiguous")
|
return FakeTensor( |
|
mode, |
|
torch.empty( |
|
final_shape, |
|
dtype=common_dtype, |
|
device="meta", |
|
memory_format=torch.contiguous_format, |
|
), |
|
device=common_device, |
|
) |
|
            if is_channels_last:
                # do channels last
                count_label("fast channels_last")
|
|
|
return FakeTensor( |
|
mode, |
|
torch.empty( |
|
final_shape, |
|
dtype=common_dtype, |
|
device="meta", |
|
memory_format=torch.channels_last, |
|
), |
|
device=common_device, |
|
) |
|
|
|
return slow("no contiguity match") |
|
|
|
return fast_binary_impl |
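# The fast path only materializes an output with the right shape, dtype, and
# device; it never computes values. Anything it cannot cheaply reason about
# (nontrivial broadcasting on every operand, mismatched devices, unusual
# layouts) defers to the slow reference implementation under the fake mode.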
|
|
|
|
|
@functools.lru_cache(None) |
|
def get_fast_op_impls(): |
|
import torch._refs |
|
|
|
register_fast_op_impl(torch.ops.aten.add.Tensor)( |
|
make_fast_binary_impl(torch._refs.add) |
|
) |
|
register_fast_op_impl(torch.ops.aten.sub.Tensor)( |
|
make_fast_binary_impl(torch._refs.sub) |
|
) |
|
    register_fast_op_impl(torch.ops.aten.mul.Tensor)(
        make_fast_binary_impl(torch._refs.mul)
    )
|
register_fast_op_impl(torch.ops.aten.div.Tensor)( |
|
make_fast_binary_impl(torch._refs.div) |
|
) |
|
return FAST_OP_IMPLEMENTATIONS |
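# Dispatch sketch (hypothetical caller, for illustration only):
#
#   fast_impls = get_fast_op_impls()
#   if func in fast_impls:
#       return fast_impls[func](fake_mode, *args, **kwargs)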
|
|