|
|
|
import inspect |
|
import warnings |
|
from functools import wraps |
|
from itertools import chain |
|
|
|
from typing import Callable, NamedTuple, Optional, overload, Sequence, Tuple |
|
|
|
import torch |
|
import torch._prims_common as utils |
|
from torch._prims_common import ( |
|
CustomOutParamAnnotation, |
|
ELEMENTWISE_TYPE_PROMOTION_KIND, |
|
Number, |
|
NumberType, |
|
ShapeType, |
|
TensorLike, |
|
TensorLikeType, |
|
) |
|
from torch.utils import _pytree as pytree |
|
from torch.utils._pytree import tree_flatten, tree_unflatten |
|
|
|
|
|
@overload |
|
def _maybe_convert_to_dtype(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType: |
|
pass |
|
|
|
|
|
@overload |
|
def _maybe_convert_to_dtype(a: NumberType, dtype: torch.dtype) -> NumberType: |
|
pass |
|
|
|
|
|
@overload |
|
def _maybe_convert_to_dtype(a: Sequence, dtype: torch.dtype) -> Sequence: |
|
pass |
|
|
|
|
|
@overload |
|
def _maybe_convert_to_dtype(a: None, dtype: torch.dtype) -> None: |
|
pass |
|
|
|
|
|
|
|
def _maybe_convert_to_dtype(a, dtype): |
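    # Converts a tensor, Python number, or Sequence thereof to the given dtype.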
|
if isinstance(a, TensorLike): |
|
if a.dtype != dtype: |
|
return a.to(dtype) |
|
return a |
|
if isinstance(a, Number): |
|
return utils.dtype_to_type_ctor(dtype)(a) |
|
if isinstance(a, Sequence): |
|
return tuple(_maybe_convert_to_dtype(x, dtype) for x in a) |
|
|
|
|
|
    # Pass None through, since some wrapped functions have optional arguments
    if a is None:
|
return None |
|
|
|
    raise ValueError(f"Received type {type(a)} that is neither a tensor nor a number!")
|
|
|
|
|
def _maybe_convert_to_type(a: NumberType, typ: type) -> NumberType: |
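    # Converts a Python number to the given Python scalar type, rejecting casts
    # to a weaker scalar type (e.g. float -> int).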
|
if not isinstance(a, Number): |
|
msg = f"Found unknown type {type(a)} when trying to convert scalars!" |
|
raise ValueError(msg) |
|
if not utils.is_weakly_lesser_type(type(a), typ): |
|
msg = f"Scalar {a} of type {type(a)} cannot be safely cast to type {typ}!" |
|
raise ValueError(msg) |
|
|
|
return typ(a) |
|
|
|
|
|
def _annotation_has_type(*, typ, annotation): |
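    # Returns True if typ appears anywhere in the (possibly nested) annotation.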
|
if hasattr(annotation, "__args__"): |
|
for a in annotation.__args__: |
|
if _annotation_has_type(typ=typ, annotation=a): |
|
return True |
|
return False |
|
|
|
return typ is annotation |
|
|
|
|
|
class elementwise_type_promotion_wrapper: |
|
""" |
|
Adds elementwise type promotion to a Python reference implementation. |
|
|
|
Takes two kwargs, type_promoting_args and type_promotion_kind. |
|
|
|
    type_promoting_args must be a string Sequence specifying the argument names of all
    arguments that participate in type promotion (and should be type promoted). If an
    argument is a Sequence, then every element of that Sequence will participate in
    type promotion.
|
|
|
type_promotion_kind must be one of the kinds specified by ELEMENTWISE_TYPE_PROMOTION_KIND. |
|
See its documentation for details. |
|
|
|
    The result dtype will be coerced to the wrapped function's dtype argument if it is
    provided and not None.
|
|
|
Other type promotion behavior, like validating the Python type of scalar arguments, must |
|
be handled separately. |
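
    A minimal usage sketch (here _add_ref is an illustrative name, not an
    existing function):

        @elementwise_type_promotion_wrapper(
            type_promoting_args=("a", "b"),
            type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
        )
        def _add_ref(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
            return torch.add(a, b)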
|
""" |
|
|
|
def __init__( |
|
self, |
|
*, |
|
type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND, |
|
type_promoting_args: Optional[Sequence[str]] = None, |
|
): |
|
self.type_promoting_arg_names = type_promoting_args |
|
self.type_promotion_kind = type_promotion_kind |
|
|
|
def __call__(self, fn: Callable) -> Callable: |
|
sig = inspect.signature(fn) |
|
|
|
@wraps(fn) |
|
def _fn(*args, **kwargs): |
|
bound = sig.bind(*args, **kwargs) |
|
type_promoting_args = tuple( |
|
bound.arguments[x] |
|
for x in self.type_promoting_arg_names |
|
if x in bound.arguments.keys() |
|
) |
|
|
|
flattened_type_promoting_args = pytree.arg_tree_leaves(*type_promoting_args) |
|
compute_dtype, result_dtype = utils.elementwise_dtypes( |
|
*flattened_type_promoting_args, |
|
type_promotion_kind=self.type_promotion_kind, |
|
) |
|
|
|
promoted_args = { |
|
x: _maybe_convert_to_dtype(bound.arguments[x], compute_dtype) |
|
for x in self.type_promoting_arg_names |
|
if x in bound.arguments.keys() |
|
} |
|
bound.arguments.update(promoted_args) |
|
|
|
result = fn(**bound.arguments) |
|
|
|
|
|
if "dtype" in bound.arguments: |
|
maybe_dtype = bound.arguments["dtype"] |
|
if maybe_dtype: |
|
result_dtype = maybe_dtype |
|
|
|
if isinstance(result, TensorLike): |
|
return _maybe_convert_to_dtype(result, result_dtype) |
|
if isinstance(result, Sequence): |
|
return tuple(_maybe_convert_to_dtype(x, result_dtype) for x in result) |
|
raise AssertionError(f"Unhandled result type: {type(result)}") |
|
|
|
_fn.__signature__ = sig |
|
return _fn |
|
|
|
|
|
|
|
def _resize_output_check(out: TensorLikeType, shape: ShapeType): |
|
|
|
    # If the shapes already match, there is nothing to resize
    if utils.same_shape(out.shape, shape):
|
return False |
|
if out.numel() != 0: |
|
msg = ( |
|
f"An output with one or more elements was resized since it had shape {str(out.shape)} " |
|
"which does not match the required output shape {str(shape)}. " |
|
"This behavior is deprecated, and in a future PyTorch release outputs will not " |
|
"be resized unless they have zero elements. " |
|
"You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0)." |
|
) |
|
warnings.warn(msg) |
|
return True |
|
|
|
|
|
|
|
def _maybe_resize_out( |
|
out: TensorLikeType, |
|
shape: ShapeType, |
|
memory_format: Optional[torch.memory_format] = None, |
|
): |
|
if _resize_output_check(out, shape): |
|
return out.resize_(shape, memory_format=memory_format) |
|
else: |
|
return out |
|
|
|
|
|
def _safe_copy_out( |
|
*, copy_from: TensorLikeType, copy_to: TensorLikeType, exact_dtype: bool = False |
|
): |
|
|
|
    # Check that this is not a cross-device copy
    if copy_from.device != copy_to.device:
|
msg = ( |
|
f"Attempting to copy from device {copy_from.device} " |
|
f"to device {copy_to.device}, but cross-device copies are not allowed!" |
|
) |
|
raise RuntimeError(msg) |
|
|
|
|
|
if exact_dtype: |
|
torch._check( |
|
copy_from.dtype == copy_to.dtype, |
|
lambda: f"Expected out tensor to have dtype {copy_from.dtype} " |
|
f"but got {copy_to.dtype} instead", |
|
) |
|
else: |
|
torch._check( |
|
utils.can_safe_cast_to(cast_from=copy_from.dtype, cast_to=copy_to.dtype), |
|
lambda: f"Attempting to cast from {copy_from.dtype} to out tensor with dtype {copy_to.dtype}, " |
|
"but this can't be cast because it is not safe!", |
|
) |
|
|
|
return copy_to.copy_(copy_from) |
|
|
|
|
|
def out_wrapper( |
|
*out_names: str, |
|
exact_dtype: bool = False, |
|
pass_is_out: bool = False, |
|
preserve_memory_format=False, |
|
): |
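    # Adds a keyword-only "out" argument to a Python reference: when out is
    # provided, the computed result is copied into the given tensor(s),
    # resizing them if necessary (exact_dtype requires the dtypes to match
    # exactly; otherwise a safe cast is allowed).
    #
    # A minimal usage sketch (_sin_ref is an illustrative name, not an
    # existing function):
    #
    #   @out_wrapper()
    #   def _sin_ref(a: TensorLikeType) -> TensorLikeType:
    #       return torch.sin(a)
    #
    #   t = torch.empty(3)
    #   _sin_ref(torch.randn(3), out=t)  # the result is copied into t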
|
|
|
|
|
|
|
|
|
|
|
|
|
default_out_names = ("out",) |
|
if len(out_names) == 0: |
|
|
|
out_names = default_out_names |
|
|
|
is_tensor = len(out_names) == 1 |
|
|
|
def maybe_compute_memory_format(t): |
|
return utils.suggest_memory_format(t) if preserve_memory_format else None |
|
|
|
def _out_wrapper(fn: Callable) -> Callable: |
|
""" |
|
Adds the out parameter to a Python reference. |
|
""" |
|
out_type = ( |
|
TensorLikeType |
|
if is_tensor |
|
else Tuple[tuple(TensorLikeType for _ in range(len(out_names)))] |
|
) |
|
return_type = ( |
|
TensorLikeType |
|
if is_tensor |
|
else NamedTuple( |
|
f"return_types_{fn.__name__}", [(o, TensorLikeType) for o in out_names] |
|
) |
|
) |
|
|
|
sig = inspect.signature(fn) |
|
factory_kwargs = ("device", "dtype") |
|
is_factory_fn = all(p in sig.parameters for p in factory_kwargs) |
|
|
|
@wraps(fn) |
|
def _fn(*args, out=None, **kwargs): |
|
            # Factory functions take device/dtype kwargs; default any that were
            # not passed explicitly to the corresponding attribute of out.
            if is_factory_fn and out is not None:
|
for k in factory_kwargs: |
|
out_attr = getattr(out, k) |
|
if k not in kwargs: |
|
kwargs[k] = out_attr |
|
if pass_is_out: |
|
result = fn(*args, is_out=(out is not None), **kwargs) |
|
else: |
|
result = fn(*args, **kwargs) |
|
            assert (
                (isinstance(result, TensorLike) and is_tensor)
                or (isinstance(result, Tuple) and len(result) == len(out_names))
            )
|
if out is not None: |
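                # Resize the provided out tensor(s) to the result's shape if
                # needed, then safely copy the computed result into them.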
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if is_tensor: |
|
assert isinstance(out, TensorLike) |
|
|
|
_maybe_resize_out( |
|
out, result.shape, maybe_compute_memory_format(result) |
|
) |
|
_safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype) |
|
else: |
|
assert isinstance(out, Tuple) |
|
torch._check_type( |
|
len(out) == len(result), |
|
lambda: f"expected tuple of {len(result)} elements but got {len(out)}", |
|
) |
|
for r, o in zip(result, out): |
|
|
|
_maybe_resize_out(o, r.shape, maybe_compute_memory_format(r)) |
|
_safe_copy_out(copy_from=r, copy_to=o, exact_dtype=exact_dtype) |
|
else: |
|
out = result |
|
|
|
return out if is_tensor else return_type(*out) |
|
|
|
out_param = inspect.Parameter( |
|
"out", |
|
kind=inspect.Parameter.KEYWORD_ONLY, |
|
default=None, |
|
annotation=out_type, |
|
) |
|
|
|
assert isinstance(sig.return_annotation, str) or sig.return_annotation in ( |
|
sig.empty, |
|
out_type, |
|
) |
|
params = chain(sig.parameters.values(), (out_param,)) |
|
_fn.__signature__ = inspect.Signature( |
|
parameters=params, return_annotation=return_type |
|
) |
|
|
|
_fn.__annotations__ = fn.__annotations__ |
|
_fn.__annotations__["out"] = out_type |
|
_fn.__annotations__["return"] = return_type |
|
|
|
|
|
|
|
        # In the special case of a single tensor out parameter with a name
        # other than "out", record that name with a custom annotation.
        if is_tensor and out_names != default_out_names:
            _fn.__annotations__[CustomOutParamAnnotation] = out_names[0]
|
|
|
|
|
|
|
|
|
        # Mark the wrapper so it can later be detected (and unwrapped by
        # _maybe_remove_out_wrapper).
        _fn._torch_decompositions_out_wrapper = (
            f"This function is wrapped by {out_wrapper.__module__}.out_wrapper"
        )
|
|
|
return _fn |
|
|
|
return _out_wrapper |
|
|
|
|
|
def _maybe_remove_out_wrapper(fn: Callable): |
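    # Unwraps the layer(s) added by out_wrapper, identified by the
    # _torch_decompositions_out_wrapper marker attribute, to recover the
    # underlying function.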
|
return inspect.unwrap( |
|
fn, |
|
stop=lambda f: not hasattr(f, "_torch_decompositions_out_wrapper"), |
|
) |
|
|
|
|
|
def backwards_not_supported(prim): |
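    # Wraps a prim so that, when autograd is enabled and any Tensor input
    # requires grad, the call is routed through an autograd.Function whose
    # backward unconditionally raises; otherwise the prim is redispatched
    # below the autograd key.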
|
def redispatch_prim(args, kwargs): |
|
with torch._C._AutoDispatchBelowAutograd(): |
|
old = torch._C._dispatch_tls_is_dispatch_key_excluded( |
|
torch._C.DispatchKey.ADInplaceOrView |
|
) |
|
return prim(*args, **kwargs) |
|
|
|
class BackwardsNotSupported(torch.autograd.Function): |
|
@staticmethod |
|
def forward(ctx, args_spec, *flat_args): |
|
args, kwargs = tree_unflatten(flat_args, args_spec) |
|
return redispatch_prim(args, kwargs) |
|
|
|
@staticmethod |
|
def backward(ctx, *args): |
|
raise RuntimeError("backwards not supported on prim") |
|
|
|
@wraps(prim) |
|
def _autograd_impl(*args, **kwargs): |
|
flat_args, args_spec = tree_flatten((args, kwargs)) |
|
if torch.is_grad_enabled() and any( |
|
a.requires_grad for a in flat_args if isinstance(a, torch.Tensor) |
|
): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return BackwardsNotSupported.apply(args_spec, *flat_args) |
|
else: |
|
return redispatch_prim(args, kwargs) |
|
|
|
return _autograd_impl |
|
|
|
|
|
|
|
|
|
|
|
def elementwise_unary_scalar_wrapper(fn: Callable) -> Callable: |
|
""" |
|
Allows unary operators that accept tensors to work with Python numbers. |
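
    A sketch of the effect (here _sin_ref is an illustrative name, not an
    existing function):

        @elementwise_unary_scalar_wrapper
        def _sin_ref(a: TensorLikeType) -> TensorLikeType:
            return torch.sin(a)

        _sin_ref(0.5)  # the number is lifted to a tensor and .item() is returned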
|
""" |
|
sig = inspect.signature(fn) |
|
|
|
@wraps(fn) |
|
def _fn(*args, **kwargs): |
|
if len(args) > 0 and isinstance(args[0], Number): |
|
dtype = utils.type_to_dtype(type(args[0])) |
|
args_ = list(args) |
|
args_[0] = torch.tensor(args[0], dtype=dtype) |
|
result = fn(*args_, **kwargs) |
|
assert isinstance(result, torch.Tensor) |
|
return result.item() |
|
|
|
return fn(*args, **kwargs) |
|
|
|
_fn.__signature__ = sig |
|
return _fn |
|
|