import functools
import warnings
from typing import Callable, Union

import torch
import torch.utils._pytree as pytree
from torch._ops import OpOverload
from torch._subclasses.fake_tensor import (
    FakeTensorMode,
    tree_flatten_only,
    UnsupportedFakeTensorException,
)
from torch.utils._python_dispatch import TorchDispatchMode


aten = torch._ops.ops.aten


def outputs_alias_inputs(outputs, inputs):
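    """Return True if any tensor output shares a storage with any tensor input.

    Tensors without storage (e.g. some tensor subclasses) are skipped.
    """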
    input_storages = {
        inp._typed_storage()._cdata
        for inp in tree_flatten_only(torch.Tensor, inputs)
        if torch._C._has_storage(inp)
    }
    return any(
        torch._C._has_storage(out) and out._typed_storage()._cdata in input_storages
        for out in tree_flatten_only(torch.Tensor, outputs)
    )


def outputs_are_inputs(outputs, inputs):
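    """Return True if any tensor output is the same Python object as a tensor input."""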
    input_ids = {id(inp) for inp in tree_flatten_only(torch.Tensor, inputs)}
    return any(id(out) in input_ids for out in tree_flatten_only(torch.Tensor, outputs))


def output_alias_each_other(outputs):
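    """Return True if two or more tensor outputs share the same storage."""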
    storages = set()
    for out in tree_flatten_only(torch.Tensor, outputs):
        if not torch._C._has_storage(out):
            continue
        stor = out._typed_storage()._cdata
        if stor in storages:
            return True
        storages.add(stor)
    return False


def is_sdpa_error(func, idx, e):
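    """Return True if ``e`` is a known, tolerated device mismatch reported for the
    auxiliary outputs of the scaled dot product attention kernels, which are allowed
    to differ between real and fake execution.
    """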
    if (
        (
            func is aten._scaled_dot_product_flash_attention.default
            or func is aten._flash_attention_forward.default
        )
        and idx in (6, 7)
        and "Devices" in repr(e)
    ):
        return True
    if (
        (
            func is aten._scaled_dot_product_efficient_attention.default
            or func is aten._efficient_attention_forward.default
        )
        and idx in (2, 3)
        and "Devices" in repr(e)
    ):
        return True
    return False


class CrossRefFakeMode(TorchDispatchMode):
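    """TorchDispatchMode that cross-checks real op execution against FakeTensor.

    Every dispatched ATen op is additionally run under a fresh ``FakeTensorMode``,
    and the two results are compared: number of outputs, input/output aliasing,
    ``requires_grad``, storage offsets, and tensor metadata (optionally strides).
    Any disagreement raises an ``AssertionError`` or ``RuntimeError``.

    Args:
        ignore_op_fn: optional predicate; ops for which it returns ``True`` are
            not cross-checked.
        check_strides: also compare output strides.
        check_aliasing: also compare aliasing behavior between inputs and outputs.
    """
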
    def __init__(
        self,
        ignore_op_fn: Union[Callable[[OpOverload], bool], None] = None,
        *,
        check_strides=True,
        check_aliasing=True,
    ):
        self.ignore_op_fn = (
            ignore_op_fn if ignore_op_fn is not None else lambda fn: False
        )
        self.check_strides = check_strides
        self.check_aliasing = check_aliasing

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
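        """Run ``func`` on the real inputs and, when supported, re-run it on
        FakeTensor copies of the inputs, asserting that the two executions agree.
        """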
        kwargs = kwargs or {}

        fake_r = None

        # Skip the FakeTensor run for ops that cannot be meaningfully compared:
        # explicitly exempted ops, user-ignored ops, and ops tagged as having
        # data-dependent outputs or as in-place views.
        if (
            func
            not in (
                aten.lift_fresh.default,
                aten.lift_fresh_copy.default,
                aten.set_.source_Storage_storage_offset,
            )
            and not self.ignore_op_fn(func)
            and torch.Tag.dynamic_output_shape not in func.tags
            and torch.Tag.inplace_view not in func.tags
            and torch.Tag.data_dependent_output not in func.tags
        ):
            from torch.fx.experimental.symbolic_shapes import ShapeEnv

            try:
                # Convert the real tensor inputs to FakeTensors and run the op
                # under a fresh FakeTensorMode.
                with FakeTensorMode(shape_env=ShapeEnv()) as fake_mode:
                    fake_args, fake_kwargs = pytree.tree_map_only(
                        torch.Tensor,
                        functools.partial(fake_mode.from_tensor, static_shapes=True),
                        (args, kwargs),
                    )
                    with warnings.catch_warnings():
                        fake_r = func(*fake_args, **fake_kwargs)
            except UnsupportedFakeTensorException:
                pass

        context = (
            f"When comparing the output of {func} on FakeTensor and concrete Tensors, "
            f"found"
        )
        r = func(*args, **kwargs)
        if fake_r is not None:
            r_flat = pytree.tree_leaves(r)
            f_flat = pytree.tree_leaves(fake_r)
            assert len(f_flat) == len(
                r_flat
            ), f"{context} mismatch in number of returns {len(f_flat)} != {len(r_flat)}"

            if self.check_aliasing:
                # The fake op must alias (or not alias) its inputs exactly as the
                # real op does.
                r_aliasing = outputs_alias_inputs(r, (args, kwargs))
                f_aliasing = outputs_alias_inputs(fake_r, (fake_args, fake_kwargs))
                assert (
                    r_aliasing == f_aliasing
                ), f"{context} mismatch in outputs_alias_inputs check {f_aliasing} != {r_aliasing}"

                r_identity_eq = outputs_are_inputs(r, (args, kwargs))
                f_identity_eq = outputs_are_inputs(fake_r, (fake_args, fake_kwargs))
                assert (
                    r_identity_eq == f_identity_eq
                ), f"{context} mismatch in outputs_are_inputs check {f_identity_eq} != {r_identity_eq}"

                r_output_alias_each_other = output_alias_each_other(r)
                f_output_alias_each_other = output_alias_each_other(fake_r)
                assert r_output_alias_each_other == f_output_alias_each_other, (
                    f"{context} mismatch in outputs_alias_each_other check "
                    f"{f_output_alias_each_other} != {r_output_alias_each_other}"
                )

            for idx, (r_out, fake_out) in enumerate(
                zip(pytree.tree_leaves(r), pytree.tree_leaves(fake_r))
            ):
                r_is_ten = isinstance(r_out, torch.Tensor)
                assert r_is_ten == isinstance(
                    fake_out, torch.Tensor
                ), f"{context} mismatched number of tensor outputs"
                if r_is_ten:
                    assert r_out.requires_grad == fake_out.requires_grad, (
                        f"{context} mismatched requires_grad-ness of outputs. "
                        f"This usually means that you have added autograd support "
                        f"for your operator at a dispatch key other than Autograd, "
                        f"which will lead to problems"
                    )
                    if torch._C._has_storage(r_out):
                        r_offset = r_out.storage_offset()
                        f_offset = fake_out.storage_offset()
                        assert (
                            r_offset == f_offset
                        ), f"{context} mismatched storage offset"

                    try:
                        torch._prims.utils.compare_tensor_meta(
                            r_out,
                            fake_out,
                            check_strides=self.check_strides,
                            allow_rhs_unbacked=True,
                        )
                    except Exception as e:
                        if is_sdpa_error(func, idx, e):
                            continue
                        error_message = (
                            f"{context} mismatched tensor metadata: {e}"
                            if len(r_flat) == 1
                            else f"{context} mismatched tensor metadata for output[{idx}]: {e}"
                        )
                        raise RuntimeError(error_message) from e
        return r
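

# Minimal usage sketch (illustrative only; the specific tensors and flags below are
# arbitrary, not part of this module's API surface). Entering the mode cross-checks
# every ATen op dispatched inside the block against its FakeTensor execution and
# raises on any mismatch.
if __name__ == "__main__":
    a, b = torch.randn(3, 4), torch.randn(4, 3)
    with CrossRefFakeMode(check_strides=True, check_aliasing=True):
        _ = a + b.t()  # elementwise op; aliasing and metadata are compared
        _ = torch.mm(a, b)  # matmul; shape/stride metadata are compared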