Spaces:
Running
Running
import itertools | |
import unittest.mock | |
from contextlib import contextmanager | |
from typing import Iterator | |
import torch | |
import torch._C | |
import torch._ops | |
import torch.utils._python_dispatch | |
import torch.utils._pytree as pytree | |
# Public API: the three dispatcher guard aliases below.
__all__ = ["enable_python_dispatcher", "no_python_dispatcher", "enable_pre_dispatch"]

# Guard classes implemented in C++; instantiating one toggles the Python
# dispatcher / pre-dispatch state for the duration of its lifetime (they are
# used as context managers, e.g. `with enable_python_dispatcher():` below).
no_python_dispatcher = torch._C._DisablePythonDispatcher
enable_python_dispatcher = torch._C._EnablePythonDispatcher
enable_pre_dispatch = torch._C._EnablePreDispatch

# Debug flag consulted by the Python dispatcher; flipped to True (via
# unittest.mock.patch) inside enable_crossref_functionalize below.
CROSSREF_FUNCTIONALIZE = False
def all_py_loaded_overloads() -> Iterator[torch._ops.OpOverload]:
    """
    Yield every OpOverload object that has been materialized in Python so far.

    Warning: the set of overloads this reports is very subtle.  It is precisely
    the set of torch.ops functions that have actually been accessed from Python
    (e.g., we actually called torch.ops.aten.blah at some point).  This is
    DIFFERENT from the set of registered operators, which is in general larger:
    that set includes every operator registered via C++ static initializers or
    Python operator registration.  Attribute access on torch.ops.aten is lazy,
    so nothing here forces population of that namespace.

    In other words, this is good for traversing over everything that has an
    OpOverload object allocated in Python.  We use it for cache invalidation,
    but don't rely on this list being complete.

    Note that even a full report of C++ registered overloads wouldn't be
    guaranteed complete either, since a later lazy library load can trigger
    additional registrations.
    """
    for namespace_name in torch.ops:
        namespace = getattr(torch.ops, namespace_name)
        for packet_name in namespace:
            packet = getattr(namespace, packet_name)
            for overload_name in packet:
                yield getattr(packet, overload_name)
@contextmanager
def suspend_functionalization() -> Iterator[None]:
    """
    Context manager that temporarily disables functionalization in the
    current thread, restoring the previous state (including the
    reapply-views TLS flag) on exit.

    Fix: the body uses ``yield`` inside ``try/finally`` and is consumed with
    ``with suspend_functionalization():`` elsewhere in this module, but the
    ``@contextmanager`` decorator was missing — a bare generator has no
    ``__enter__``/``__exit__``, so every ``with`` entry raised.
    """
    # Snapshot current functionalization TLS state so we only restore
    # what we actually changed.
    f_tls = torch._C._dispatch_tls_is_dispatch_key_included(
        torch._C.DispatchKey.Functionalize
    )
    f_rv = torch._C._functionalization_reapply_views_tls()
    if f_tls:
        torch._disable_functionalization()
    try:
        yield
    finally:
        # Re-enable only if it was enabled on entry, with the same
        # reapply_views setting as before.
        if f_tls:
            torch._enable_functionalization(reapply_views=f_rv)
def check_tensor_metadata_matches(nv, rv, desc):
    """
    Assert that tensors nv and rv agree on sizes, dtype, and significant
    strides.  `desc` is a zero-argument callable producing a human-readable
    context string; it is only invoked when an assertion fails.
    """
    assert callable(desc)
    sizes_equal = nv.size() == rv.size()
    assert sizes_equal, f"{desc()}: sizes {nv.size()} != {rv.size()}"
    dtypes_equal = nv.dtype == rv.dtype
    assert dtypes_equal, f"{desc()}: dtype {nv.dtype} != {rv.dtype}"
    # Only "significant" strides are compared (size-1 / degenerate dims are
    # ignored by the helper); idx reports the first mismatching dimension.
    strides_ok, idx = torch._prims_common.check_significant_strides(
        nv, rv, only_cuda=False
    )
    assert (
        strides_ok
    ), f"{desc()}: strides {nv.stride()} != {rv.stride()} (mismatch at index {idx})"
def check_metadata_matches(n, r, desc):
    """
    Flatten both pytrees and assert matching tensor metadata position by
    position.  Non-tensor leaves on the `r` side are skipped.  `desc` is a
    zero-argument callable producing a context string for failure messages.
    """
    assert callable(desc)
    n_vals, n_spec = pytree.tree_flatten(n)
    r_vals, r_spec = pytree.tree_flatten(r)
    # TODO: test the specs match; empirically sometimes we have a tuple
    # on one side and a list on the other
    assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}"
    for i, (nv, rv) in enumerate(zip(n_vals, r_vals)):
        if isinstance(rv, torch.Tensor):
            check_tensor_metadata_matches(nv, rv, lambda: f"{desc()} output {i}")
class Lit:
    """
    Literal wrapper: repr() returns exactly the wrapped string, so the
    wrapped text is embedded verbatim when formatting containers.
    """

    def __init__(self, s):
        # Stored verbatim; echoed back by __repr__.
        self.s = s

    def __repr__(self):
        return self.s
def _fmt(a: object) -> object: | |
if isinstance(a, torch.Tensor): | |
return Lit( | |
f"torch.empty_strided({tuple(a.size())}, {a.stride()}, dtype={a.dtype})" | |
) | |
else: | |
return a | |
def make_crossref_functionalize(op, final_key):
    """
    Build a cross-referencing handler for `op`: it runs the op once under
    functionalization on fake tensors and once via the real dispatch key
    `final_key`, then asserts the two results agree on tensor metadata
    (sizes/strides/dtype).  Returns the real result.

    For ops that should not be cross-checked (currently only
    aten.lift_fresh.default), returns `final_key` unchanged instead of a
    handler.
    """
    from torch._subclasses.fake_tensor import FakeTensorMode

    # This case is pretty weird, suppress it for now
    if op == torch.ops.aten.lift_fresh.default:
        return final_key

    def handler(*args, **kwargs):
        fake_mode = FakeTensorMode()

        # Convert a (possibly functional) tensor argument into a fake tensor
        # so the functionalized run allocates no real storage.  Non-tensor
        # leaves pass through untouched.
        def fakeify_defun(t):
            if isinstance(t, torch.Tensor):
                if torch._is_functional_tensor(t):
                    r = torch._from_functional_tensor(t)
                    # NB: This assumes that the inner tensor sizes/strides match
                    # the outer tensor sizes/strides.  This doesn't necessarily
                    # have to be the case, see discussion at
                    # https://github.com/pytorch/pytorch/pull/87610/files/401ddeda1d769bedc88a12de332c7357b60e51a4#r1007264456
                    assert t.size() == r.size()
                    assert t.stride() == r.stride()
                else:
                    r = t
                # TODO: suppress guards
                return fake_mode.from_tensor(r)
            return t

        # Detach tensors before saving them for the error message, so the
        # saved copies cannot be changed by in-place mutation during the op.
        def maybe_detach(t):
            if isinstance(t, torch.Tensor):
                return t.detach()
            else:
                return t

        # TODO: This probably does the wrong thing if you're running other
        # substantive modes with the normal op outside here
        #
        # Functionalized run: disable active dispatch modes and any ambient
        # functionalization, fakeify the inputs, and run the op under the
        # fake mode.
        with torch.utils._python_dispatch._disable_current_modes(), suspend_functionalization():
            f_args, f_kwargs = pytree.tree_map(fakeify_defun, (args, kwargs))
            orig_f_args, orig_f_kwargs = pytree.tree_map(
                maybe_detach, (f_args, f_kwargs)
            )
            with fake_mode:
                f_r = op(*f_args, **f_kwargs)
        # Real run: dispatch straight to the originally intended key with the
        # original (non-fake) arguments.
        r = op._op_dk(final_key, *args, **kwargs)

        # Lazy description of the call for assertion messages; only evaluated
        # if a metadata mismatch is found.
        def desc():
            fmt_args = ", ".join(
                itertools.chain(
                    (repr(pytree.tree_map(_fmt, a)) for a in orig_f_args),
                    (
                        f"{k}={pytree.tree_map(_fmt, v)}"
                        for k, v in orig_f_kwargs.items()
                    ),
                )
            )
            return f"{op}({fmt_args})"

        check_metadata_matches(f_r, r, desc)
        return r

    return handler
# NB: enabling this is slow, don't do it in a hot loop.  This is purely
# for debugging purposes.
@contextmanager
def enable_crossref_functionalize() -> Iterator[None]:
    """
    Context manager that turns on crossref functionalization checking for
    the duration of the block: the Python dispatcher is enabled and
    CROSSREF_FUNCTIONALIZE is patched to True.  Dispatch caches for the
    Functionalize key are invalidated on entry and exit so the flag change
    is actually observed.

    Fix: the body uses ``yield`` inside ``try/finally`` (and wraps another
    context manager in ``with``), but the ``@contextmanager`` decorator was
    missing — calling this in a ``with`` statement returned a raw generator
    with no ``__enter__``, so none of the cache invalidation or patching
    ever ran.
    """
    # Drop cached Functionalize kernels so the patched flag takes effect on
    # the next dispatch of each already-loaded overload.
    for op in all_py_loaded_overloads():
        op._uncache_dispatch(torch._C.DispatchKey.Functionalize)
    try:
        with enable_python_dispatcher(), unittest.mock.patch(
            "torch._dispatch.python.CROSSREF_FUNCTIONALIZE", True
        ):
            yield
    finally:
        # Invalidate again so later dispatches don't reuse the crossref
        # handlers that were cached while the flag was True.
        for op in all_py_loaded_overloads():
            op._uncache_dispatch(torch._C.DispatchKey.Functionalize)