|
|
|
import contextlib |
|
|
|
import warnings |
|
from dataclasses import dataclass |
|
from typing import Any, Dict, List, Optional, Set, Union |
|
|
|
import torch |
|
import torchgen |
|
import torchgen.model |
|
from torch._C import ( |
|
_get_dispatch_stack_at, |
|
_len_torch_dispatch_stack, |
|
_pop_torch_dispatch_stack, |
|
_push_on_torch_dispatch_stack, |
|
DispatchKey, |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Global flag recording whether any TorchDispatchMode is currently active,
# so callers can check cheaply without walking the mode stack.
_is_in_torch_dispatch_mode = False
|
|
|
|
|
def is_in_torch_dispatch_mode() -> bool: |
|
return _is_in_torch_dispatch_mode |
|
|
|
|
|
class TorchDispatchMode: |
|
""" |
|
A ``TorchDispatchMode`` allows you to override the meaning of all |
|
``__torch_dispatch__`` overrideable functions within a dynamic scope, |
|
without having to actually create a tensor subclass or manually |
|
monkey-patch functions in the PyTorch API. Some common situations |
|
where you should use a mode: |
|
|
|
* You want to override the meaning of factory functions, or other |
|
functions that do not otherwise take a tensor as an argument |
|
(these cannot be overridden with tensor subclasses). |
|
|
|
* You want to override the behavior of all functions without needing |
|
to wrap your inputs in tensor subclasses; e.g., if you are just |
|
interested in logging intermediate computations. |
|
|
|
* You want to control the order of execution of various tensor |
|
subclasses explicitly, rather than implicitly via the return of |
|
``NotImplemented``. |
|
|
|
Independent subclasses of :class:`TorchDispatchMode` are compositional: |
|
modes can be pushed onto a stack using ``with MyMode():``. |
|
When you call functions in the PyTorch API inside your |
|
``__torch_dispatch__`` implementation, by default, they will forward on to |
|
    the next mode on the mode stack. If you want to recursively call back into
    your current ``__torch_dispatch__`` implementation, either explicitly
    invoke ``self.__torch_dispatch__(...)``, or use the context manager
    ``__torch_dispatch__(self)`` to make the PyTorch
    API self-referential (beware of infinite loops in this case!).
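
    Example (a minimal sketch; ``LoggingMode`` is a hypothetical subclass,
    not something PyTorch provides)::

        class LoggingMode(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                kwargs = kwargs or {}
                print(f"dispatching {func}")
                return func(*args, **kwargs)

        with LoggingMode():
            torch.ones(2) + 1  # each aten op is printed as it dispatches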
|
""" |
|
|
|
def __init__(self, _dispatch_key=None): |
|
if _dispatch_key is not None: |
|
assert isinstance(_dispatch_key, torch._C.DispatchKey) |
|
self.__dict__["_dispatch_key"] = _dispatch_key |
|
|
|
self.old_dispatch_mode_flag = False |
|
|
|
def __torch_dispatch__(self, func, types, args=(), kwargs=None): |
|
raise NotImplementedError |
|
|
|
def __enter__(self): |
|
global _is_in_torch_dispatch_mode |
|
self.old_dispatch_mode_flag = _is_in_torch_dispatch_mode |
|
_is_in_torch_dispatch_mode = True |
|
_push_mode(self) |
|
return self |
|
|
|
def __exit__(self, exc_type, exc_val, exc_tb): |
|
mb_dk_or_mode_key = self.__dict__.get("_dispatch_key", None) |
|
if mb_dk_or_mode_key is None: |
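            # Infra modes (e.g. functional/proxy) are registered under a
            # _TorchDispatchModeKey rather than a DispatchKey; fall back to it.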
|
|
|
|
|
mb_dk_or_mode_key = self.__dict__.get("_mode_key", None) |
|
global _is_in_torch_dispatch_mode |
|
_is_in_torch_dispatch_mode = self.old_dispatch_mode_flag |
|
_pop_mode(mb_dk_or_mode_key) |
|
|
|
@classmethod |
|
def push(cls, *args, **kwargs): |
|
warnings.warn( |
|
"`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`" |
|
) |
|
instance = cls(*args, **kwargs) |
|
return instance |
|
|
|
|
|
def _get_current_dispatch_mode(): |
|
stack_len = _len_torch_dispatch_stack() |
|
|
|
if stack_len > 0: |
|
return _get_dispatch_stack_at(stack_len - 1) |
|
return None |
|
|
|
|
|
def _detect_infra_mode(key): |
|
    assert key in (
        torch._C._TorchDispatchModeKey.FUNCTIONAL,
        torch._C._TorchDispatchModeKey.PROXY,
    )
|
from torch._ops import _get_dispatch_mode_pre_dispatch |
|
|
|
    pre_dispatch_mode = _get_dispatch_mode_pre_dispatch(key)
    post_dispatch_mode = torch._C._get_dispatch_mode(key)
|
|
|
    assert (pre_dispatch_mode is None) or (post_dispatch_mode is None)
|
|
|
if pre_dispatch_mode is None: |
|
return post_dispatch_mode |
|
|
|
return pre_dispatch_mode |
|
|
|
|
|
def _unset_infra_mode(key): |
|
from torch._ops import _get_dispatch_mode_pre_dispatch, unset_mode_pre_dispatch |
|
|
|
pre_dispatch_mode = _get_dispatch_mode_pre_dispatch(key) |
|
post_dispatch_mode = torch._C._get_dispatch_mode(key) |
|
if pre_dispatch_mode and post_dispatch_mode: |
|
raise AssertionError( |
|
"Can't have active infra mode on both pre and post dispatch mode stack" |
|
) |
|
|
|
if pre_dispatch_mode: |
|
mode = unset_mode_pre_dispatch(key) |
|
return mode |
|
if post_dispatch_mode: |
|
return torch._C._unset_dispatch_mode(key) |
|
|
|
|
|
@contextlib.contextmanager
def _disable_infra_mode(key):
|
assert key in ( |
|
torch._C._TorchDispatchModeKey.FUNCTIONAL, |
|
torch._C._TorchDispatchModeKey.PROXY, |
|
) |
|
mode_unset = _unset_infra_mode(key) |
|
try: |
|
yield mode_unset |
|
finally: |
|
if mode_unset is not None: |
|
_push_mode(mode_unset) |
|
|
|
|
|
def _get_current_dispatch_mode_stack(): |
|
stack_len = _len_torch_dispatch_stack() |
|
return [_get_dispatch_stack_at(i) for i in range(stack_len)] |
|
|
|
|
|
def _push_mode(mode: TorchDispatchMode): |
|
k = mode._dispatch_key if hasattr(mode, "_dispatch_key") else None |
|
assert k is None or k == torch._C.DispatchKey.PreDispatch |
|
if k is None: |
|
_push_on_torch_dispatch_stack(mode) |
|
return |
|
|
|
from torch._ops import _set_mode_pre_dispatch, get_cached_ops |
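
    # PreDispatch mode handlers are not cached per-dispatch-key, so clear the
    # dispatch cache of every op used so far for the affected backend keys.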
|
|
|
|
|
|
|
ks = torch._C._functionality_to_backend_keys(k) |
|
for op in get_cached_ops(): |
|
for key in ks: |
|
op._uncache_dispatch(key) |
|
_set_mode_pre_dispatch(mode) |
|
|
|
|
|
def _pop_mode(k: Optional[Union[DispatchKey, torch._C._TorchDispatchModeKey]] = None): |
|
if k == torch._C.DispatchKey.PreDispatch: |
|
from torch._ops import _pop_mode_from_pre_dispatch |
|
|
|
return _pop_mode_from_pre_dispatch() |
|
|
|
if k is None or isinstance(k, torch._C._TorchDispatchModeKey): |
|
return _pop_torch_dispatch_stack(k) |
|
|
|
|
|
@contextlib.contextmanager |
|
def _pop_mode_temporarily(k: Optional[DispatchKey] = None): |
|
old = _pop_mode(k) |
|
try: |
|
yield old |
|
finally: |
|
_push_mode(old) |
|
|
|
|
|
@contextlib.contextmanager |
|
def _disable_current_modes(): |
|
from torch._ops import ( |
|
_len_torch_dispatch_stack_pre_dispatch, |
|
_pop_mode_from_pre_dispatch, |
|
) |
|
from torch._subclasses.functional_tensor import FunctionalTensorMode |
|
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode |
|
from torch._subclasses.schema_check_mode import SchemaCheckMode |
|
|
|
mode_len_pre_dispatch = _len_torch_dispatch_stack_pre_dispatch() |
|
old_pre_dispatch_modes = [ |
|
_pop_mode_from_pre_dispatch() for _ in range(mode_len_pre_dispatch) |
|
] |
|
|
|
has_proxy_mode_in_pre_dispatch = False |
|
has_functional_mode_in_pre_dispatch = False |
|
has_schema_check_mode_in_pre_dispatch = False |
|
|
|
for i in old_pre_dispatch_modes: |
|
if isinstance(i, ProxyTorchDispatchMode): |
|
has_proxy_mode_in_pre_dispatch = True |
|
if isinstance(i, FunctionalTensorMode): |
|
has_functional_mode_in_pre_dispatch = True |
|
if isinstance(i, SchemaCheckMode): |
|
has_schema_check_mode_in_pre_dispatch = True |
|
|
|
mode_len = _len_torch_dispatch_stack() |
|
old_modes = [_pop_mode() for _ in range(mode_len)] |
|
|
|
for old in old_modes: |
|
if ( |
|
isinstance(old, FunctionalTensorMode) |
|
and has_functional_mode_in_pre_dispatch |
|
): |
|
raise AssertionError( |
|
"Can't have FunctionalMode available both in PreDispatch and Python Key" |
|
) |
|
if isinstance(old, ProxyTorchDispatchMode) and has_proxy_mode_in_pre_dispatch: |
|
raise AssertionError( |
|
"Can't have ProxyTorchDispatchMode available both in PreDispatch and Python Key" |
|
) |
|
if ( |
|
isinstance(old, SchemaCheckMode) |
|
and has_schema_check_mode_in_pre_dispatch |
|
): |
|
raise AssertionError( |
|
"Can't have SchemaCheckMode available both in PreDispatch and Python Key" |
|
) |
|
|
|
|
|
try: |
|
yield old_pre_dispatch_modes + old_modes |
|
finally: |
|
for mode in reversed(old_modes): |
|
_push_mode(mode) |
|
for mode in reversed(old_pre_dispatch_modes): |
|
_push_mode(mode) |
|
|
|
|
|
class BaseTorchDispatchMode(TorchDispatchMode): |
|
def __torch_dispatch__(self, func, types, args=(), kwargs=None): |
|
if kwargs is None: |
|
kwargs = {} |
|
return func(*args, **kwargs) |
|
|
|
|
|
def is_traceable_wrapper_subclass(t): |
|
""" |
|
Returns whether or not a tensor subclass that implements __torch_dispatch__ |
|
is 'traceable' with torch.compile. |
|
    In order for a tensor subclass to support TorchDispatchMode-style tracing in PT2,
    it must implement two magic methods: __tensor_flatten__ and __tensor_unflatten__.
|
It is also expected to obey some restrictions around traceability and aliasing: |
|
    * The subclass's __torch_dispatch__() implementation should desugar into PyTorch
      dispatcher operations that can be traced into a graph.
|
* The subclass should use return_and_correct_aliasing(). This is needed today to make |
|
sure that torch.compile does the right thing in a few cases around input mutation |
|
and output aliasing. |
|
|
|
Expected magic method signatures: |
|
attrs, ctx = t.__tensor_flatten__() |
|
attrs: list of attribute name strings for inner tensors |
|
ctx: dict containing any other subclass-specific metadata needed for unflattening |
|
|
|
t = MySubClass.__tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride) |
|
inner_tensors: dict mapping attribute name -> tensor for each inner tensor |
|
ctx: dict with subclass metadata in the form that __tensor_flatten__() produces |
|
outer_size: expected (possibly symbolic) size that the returned subclass |
|
instance should have. Note that this arg is useful for certain subclasses |
|
that require the shape info to be constructed. In most cases, this arg can be |
|
safely ignored. |
|
outer_stride: expected (possibly symbolic) stride that the returned subclass |
|
instance should have. Note that this arg is useful for certain subclasses |
|
that require the stride info to be constructed. In most cases, this arg can be |
|
safely ignored. |
|
""" |
|
    is_subclass = isinstance(t, torch.Tensor) and type(t) is not torch.Tensor
|
return ( |
|
is_subclass |
|
and hasattr(t, "__tensor_flatten__") |
|
and hasattr(t, "__tensor_unflatten__") |
|
) |
|
|
|
|
|
def transform_subclass(t, callback, outer_size=None, outer_stride=None): |
|
""" |
|
Given a traceable, wrapper tensor subclass ``t`` that implements |
|
``__torch_dispatch__`` and holds some inner tensors, |
|
and a callback of type ``Callable[[str, torch.Tensor], torch.Tensor]``, |
|
`transform_subclass` will construct a fresh instance of the wrapper tensor subclass. |
|
It will do so by grabbing each inner tensor attribute from the wrapper, |
|
passing them into ``callback`` to get a transformed tensor, |
|
and putting each transformed tensor into the fresh tensor subclass instance. |
|
|
|
Note: this function will not handle ensuring that the fresh subclass |
|
gets the same (autograd, and aliasing) metadata as the original tensor. |
|
This is generally handled in other subsystems like AOTAutograd. |
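
    A minimal usage sketch (assuming ``t`` is a traceable wrapper subclass)::

        # Build a fresh wrapper whose inner tensors are detached clones.
        fresh = transform_subclass(t, lambda attr, inner: inner.detach().clone())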
|
""" |
|
outer_size = outer_size if outer_size is not None else t.size() |
|
outer_stride = outer_stride if outer_stride is not None else t.stride() |
|
|
|
attrs, ctx = t.__tensor_flatten__() |
|
transformed_tensors_dict = {} |
|
for attr in attrs: |
|
transformed_tensors_dict[attr] = callback(attr, getattr(t, attr)) |
|
sub = type(t).__tensor_unflatten__( |
|
transformed_tensors_dict, ctx, outer_size, outer_stride |
|
) |
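
    # Sanity-check that __tensor_unflatten__ respected the requested metadata.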
|
|
|
|
|
|
|
|
|
    assert sub.shape == outer_size, (
        f"Expected return value from {type(t)}.__tensor_unflatten__() to have "
        f"shape equal to {outer_size}, but got: {sub.shape}"
    )
    assert sub.stride() == outer_stride, (
        f"Expected return value from {type(t)}.__tensor_unflatten__() to have "
        f"stride equal to {outer_stride}, but got: {sub.stride()}"
    )
|
|
|
return sub |
|
|
|
|
|
def _correct_storage_aliasing(func, schema_info, args, outs): |
|
""" |
|
Given: an OpOverload, a SchemaInfo (cached information from torchgen about schema), |
|
and the inputs/outputs to the OpOverload, |
|
this function checks to see if func is a view operator |
|
(by checking if any of the outputs in the op's schema |
|
are immutable aliases of inputs). |
|
If so, this function manually aliases the storage of the output tensor |
|
with its corresponding input tensor alias. |
|
It does this by unsafely overwriting the storage field of the output tensor |
|
to be the same storage as the input. |
|
""" |
|
assert isinstance(func, torch._ops.OpOverload) |
|
assert isinstance(args, tuple) |
|
assert isinstance(outs, (list, tuple)) |
|
flat_outs = torch.utils._pytree.tree_leaves(outs) |
|
|
|
def alias_non_inplace_storage(arg, ret): |
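        # When a wrapper subclass is involved, the input and output of a view
        # op must be the same subclass type for storage aliasing to make sense.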
|
|
|
|
|
|
|
|
|
|
|
|
|
if is_traceable_wrapper_subclass(arg) or is_traceable_wrapper_subclass(ret): |
|
ret_list = ret if isinstance(ret, list) else [ret] |
|
for r in ret_list: |
|
                assert type(arg) == type(r), (
                    f"Called {str(func)} with input of type {type(arg)} "
                    f"and output of type {type(ret)}, but expected these types to match."
                )
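
            # Swap storages under no_dispatch() so the set_() call below does
            # not itself get intercepted by any active modes or subclasses.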
|
|
|
|
|
|
|
with torch.utils._mode_utils.no_dispatch(): |
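                # Temporarily include Meta in TLS dispatch so set_() also
                # works when the tensors involved are meta/fake tensors.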
|
|
|
|
|
meta_in_tls = torch._C._meta_in_tls_dispatch_include() |
|
torch._C._set_meta_in_tls_dispatch_include(True) |
|
try: |
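                    # Call this set_ overload directly with ret's existing
                    # shape/stride so that only the storage is swapped; the
                    # output's size/stride metadata is assumed correct and kept.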
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if isinstance(ret, list): |
|
for r in ret: |
|
torch.ops.aten.set_.source_Storage_storage_offset( |
|
r, |
|
arg.untyped_storage(), |
|
r.storage_offset(), |
|
r.shape, |
|
r.stride(), |
|
) |
|
else: |
|
assert isinstance(ret, torch.Tensor), f"type: {type(ret)}" |
|
torch.ops.aten.set_.source_Storage_storage_offset( |
|
ret, |
|
arg.untyped_storage(), |
|
ret.storage_offset(), |
|
ret.shape, |
|
ret.stride(), |
|
) |
|
finally: |
|
torch._C._set_meta_in_tls_dispatch_include(meta_in_tls) |
|
|
|
def is_read_only_alias_match(arg, ret): |
|
shared_aliases = arg.alias_set & ret.alias_set |
|
return len(shared_aliases) > 0 and not arg.is_write |
|
|
|
num_args = len(func._schema.arguments) |
|
num_returns = len(func._schema.returns) |
|
for arg_idx in range(num_args): |
|
for return_idx in range(num_returns): |
|
if is_read_only_alias_match( |
|
schema_info.args[arg_idx], schema_info.outs[return_idx] |
|
): |
|
alias_non_inplace_storage(args[arg_idx], outs[return_idx]) |
|
|
|
|
|
|
|
|
|
|
|
# Alias information for a single argument or return in an op schema.
@dataclass
class AliasInfo:
    alias_set: Set[str]
    is_write: bool
    name: Optional[str]
|
|
|
|
|
@dataclass |
|
class SchemaInfo: |
|
args: List[AliasInfo] |
|
outs: List[AliasInfo] |
|
|
|
|
|
|
|
# Cache of parsed schema info, keyed on OpOverload (typed as Any to avoid a
# circular import of torch._ops).
parsed_schema_map: Dict[Any, SchemaInfo] = {}
|
|
|
|
|
|
|
|
|
def get_alias_info(func) -> SchemaInfo: |
|
if func in parsed_schema_map: |
|
return parsed_schema_map[func] |
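
    # For aten ops, parse the schema with torchgen, which handles alias
    # annotations (e.g. on tensorlist outputs) more faithfully than the
    # torchscript schema objects do.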
|
|
|
|
|
if func.namespace == "aten": |
|
torchgen_schema_str = str(func._schema) |
|
assert torchgen_schema_str.startswith("aten::") |
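        # torchgen's schema parser does not expect the "aten::" namespace
        # prefix, so strip it off.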
|
|
|
|
|
torchgen_schema_str = torchgen_schema_str[6:] |
|
import re |
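
        # The torchscript schema printer expands scalar defaults on list types
        # into lists (e.g. "int[2]=0" becomes "=[0, 0]"), which torchgen
        # cannot parse; normalize those defaults back to scalars.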
|
|
|
|
|
|
|
torchgen_schema_str = re.sub(r"=\[[0, ]+\]", "=0", torchgen_schema_str) |
|
torchgen_schema_str = re.sub(r"=\[[1, ]+\]", "=1", torchgen_schema_str) |
|
|
|
torchgen_schema_str = torchgen_schema_str.replace("=[0, 1]", "=[0,1]") |
|
torchgen_schema = torchgen.model.FunctionSchema.parse(torchgen_schema_str) |
|
arg_schemas = [ |
|
AliasInfo( |
|
alias_set=( |
|
set() if a.annotation is None else set(a.annotation.alias_set) |
|
), |
|
is_write=a.annotation is not None and a.annotation.is_write, |
|
name=a.name, |
|
) |
|
for a in torchgen_schema.arguments.flat_all |
|
] |
|
out_schemas = [ |
|
AliasInfo( |
|
alias_set=( |
|
set() if a.annotation is None else set(a.annotation.alias_set) |
|
), |
|
is_write=a.annotation is not None and a.annotation.is_write, |
|
name=a.name, |
|
) |
|
for a in torchgen_schema.returns |
|
] |
|
else: |
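        # For non-aten (e.g. custom) ops, fall back to the torchscript
        # schema's alias_info directly.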
|
|
|
arg_schemas = [ |
|
AliasInfo( |
|
alias_set=( |
|
set() if a.alias_info is None else set(a.alias_info.before_set) |
|
), |
|
is_write=a.alias_info is not None and a.alias_info.is_write, |
|
name=a.name, |
|
) |
|
for a in func._schema.arguments |
|
] |
|
out_schemas = [ |
|
AliasInfo( |
|
alias_set=( |
|
set() if a.alias_info is None else set(a.alias_info.before_set) |
|
), |
|
is_write=a.alias_info is not None and a.alias_info.is_write, |
|
name=a.name, |
|
) |
|
for a in func._schema.returns |
|
] |
|
schema_info = SchemaInfo(args=arg_schemas, outs=out_schemas) |
|
parsed_schema_map[func] = schema_info |
|
return schema_info |
|
|
|
|
|
def return_and_correct_aliasing(func, args, kwargs, out): |
|
""" |
|
This function should be used by wrapper tensor ``__torch_dispatch__`` subclasses |
|
that would like to work with torch.compile. It ensures that the subclass |
|
properly implements the aliasing behavior of every op, |
|
which is needed for correctness in AOTAutograd. |
|
This function will handle: |
|
|
|
* When we see a view op, we will alias the storages of any |
|
input and output tensor subclasses |
|
|
|
* When we see an inplace or out= op, we will directly |
|
return the corresponding input tensor, instead of returning |
|
a (potentially) fresh output tensor. |
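
    Expected usage inside a wrapper subclass's ``__torch_dispatch__``
    (a sketch; ``unwrap``/``wrap`` are hypothetical helpers that convert
    between the wrapper subclass and its inner tensors)::

        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            out = wrap(func(*unwrap(args), **unwrap(kwargs)))
            return return_and_correct_aliasing(func, args, kwargs, out)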
|
""" |
|
|
|
|
|
|
|
schema_info = get_alias_info(func) |
|
|
|
def get_write_alias(x): |
|
if len(x.alias_set) == 0: |
|
return None |
|
alias_set = list(x.alias_set) |
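        # We assume each argument/return carries at most one alias annotation.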
|
|
|
assert len(alias_set) == 1 |
|
if x.is_write: |
|
return alias_set[0] |
|
return None |
|
|
|
def get_arg_from_alias(output_alias, schema_info, args, kwargs): |
|
new_args, new_kwargs = torch.fx.operator_schemas.normalize_function( |
|
func, args=args, kwargs=kwargs |
|
) |
|
|
|
arg_indices = [ |
|
i for i, a in enumerate(schema_info.args) if output_alias in a.alias_set |
|
] |
|
|
|
assert len(arg_indices) == 1 |
|
idx = arg_indices[0] |
|
arg_info = schema_info.args[idx] |
|
if arg_info.name is not None and arg_info.name in new_kwargs: |
|
return new_kwargs[arg_info.name] |
|
return new_args[idx] |
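
    # First: for view ops, alias the storages of the output wrapper
    # subclasses with their corresponding inputs.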
|
|
|
|
|
|
|
_correct_storage_aliasing( |
|
func, schema_info, args, (out,) if not isinstance(out, tuple) else out |
|
) |
|
|
|
|
|
|
|
if torch.Tag.inplace_view in func.tags: |
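        # inplace_view ops (e.g. ``t_()``) mutate an input's metadata; find
        # that input so we can replay the metadata change on the wrapper.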
|
|
|
|
|
mutated_args = [ |
|
x |
|
for i, x in enumerate(args) |
|
if get_write_alias(schema_info.args[i]) is not None |
|
] |
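
        # Assumption: inplace_view ops mutate exactly one of their arguments.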
|
|
|
|
|
assert len(mutated_args) == 1 |
|
|
|
|
|
|
|
from torch._subclasses.functional_tensor import FunctionalTensor |
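
        # FunctionalTensor plumbs metadata queries to its inner tensor, so it
        # needs no fixup; for other wrappers, re-run the op under no_dispatch()
        # (with Meta included in TLS) to update the wrapper's own metadata
        # without re-entering the subclass.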
|
|
|
if not isinstance(mutated_args[0], FunctionalTensor): |
|
with torch.utils._mode_utils.no_dispatch(): |
|
|
|
|
|
meta_in_tls = torch._C._meta_in_tls_dispatch_include() |
|
torch._C._set_meta_in_tls_dispatch_include(True) |
|
try: |
|
func(*args, **kwargs) |
|
finally: |
|
torch._C._set_meta_in_tls_dispatch_include(meta_in_tls) |
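
    # Next: if the op mutates its inputs (inplace or out= variants), return
    # the mutated input tensors directly rather than the freshly-allocated
    # outputs that the subclass produced.

    # Simple case: none of the outputs alias a mutated input; return as-is.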
|
|
|
|
|
|
|
|
|
if not any(get_write_alias(r) is not None for r in schema_info.outs): |
|
return out |
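
    # Simplifying assumption: an op's outputs either all alias mutated inputs
    # or none of them do (no mixed schemas like "-> (Tensor(a!), Tensor)").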
|
|
|
|
|
if not all(get_write_alias(r) is not None for r in schema_info.outs): |
|
raise RuntimeError("Unsupported schema: " + str(func._schema)) |
|
|
|
if len(func._schema.returns) == 1: |
|
return get_arg_from_alias( |
|
get_write_alias(schema_info.outs[0]), schema_info, args, kwargs |
|
) |
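
    # Multiple returns: for each output with a write alias, swap in the
    # corresponding (mutated) input; pass other outputs through unchanged.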
|
|
|
|
|
outs_to_return = type(out)( |
|
[ |
|
( |
|
get_arg_from_alias( |
|
get_write_alias(schema_info.outs[i]), schema_info, args, kwargs |
|
) |
|
if get_write_alias(r) is not None |
|
else o |
|
) |
|
for ((i, r), o) in zip(enumerate(schema_info.outs), out) |
|
] |
|
) |
|
return outs_to_return |
|
|