from __future__ import annotations

from typing import Optional

import torch

from . import _binary_ufuncs_impl, _dtypes_impl, _unary_ufuncs_impl, _util
from ._normalizations import (
    ArrayLike,
    ArrayLikeOrScalar,
    CastingModes,
    DTypeLike,
    normalizer,
    NotImplementedType,
    OutArray,
)


def _ufunc_postprocess(result, out, casting):
    if out is not None:
        result = _util.typecast_tensor(result, out.dtype.torch_dtype, casting)
        result = torch.broadcast_to(result, out.shape)
    return result
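
# Illustrative sketch of what the ``out=`` postprocessing above achieves
# (assumes this package is importable as ``torch._numpy``; names such as
# ``tnp.empty`` and ``tnp.float64`` mirror the NumPy namespace):
#
#   >>> import torch._numpy as tnp
#   >>> out = tnp.empty(3, dtype=tnp.float64)
#   >>> tnp.add(tnp.ones(3, dtype=tnp.float32), 1.0, out=out)
#
# The float32 result is cast to ``out``'s float64 dtype (under the active
# casting rule) and broadcast to ``out.shape`` before being written out.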


# ############# Binary ufuncs ######################

_binary = [
    name
    for name in dir(_binary_ufuncs_impl)
    if not name.startswith("_") and name not in ["torch", "matmul", "divmod", "ldexp"]
]


NEP50_FUNCS = (
    "add",
    "subtract",
    "multiply",
    "floor_divide",
    "true_divide",
    "divide",
    "remainder",
    "bitwise_and",
    "bitwise_or",
    "bitwise_xor",
    "bitwise_left_shift",
    "bitwise_right_shift",
    "hypot",
    "arctan2",
    "logaddexp",
    "logaddexp2",
    "heaviside",
    "copysign",
    "fmax",
    "minimum",
    "fmin",
    "maximum",
    "fmod",
    "gcd",
    "lcm",
    "pow",
)
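
# For the functions listed above, Python scalar operands follow NEP 50 "weak
# promotion": the scalar adopts the dtype of the tensor operand instead of
# widening the result.  Illustrative sketch (assumes ``torch._numpy`` is
# importable as below):
#
#   >>> import torch._numpy as tnp
#   >>> a = tnp.ones(3, dtype=tnp.int8)
#   >>> tnp.add(a, 1).dtype    # stays int8; the Python int is "weak"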


def deco_binary_ufunc(torch_func):
    """Common infra for binary ufuncs.

    Normalize arguments, sort out type casting and broadcasting, and delegate
    to the PyTorch functions for the actual work.
    """

    @normalizer
    def wrapped(
        x1: ArrayLikeOrScalar,
        x2: ArrayLikeOrScalar,
        /,
        out: Optional[OutArray] = None,
        *,
        where: NotImplementedType = True,
        casting: Optional[CastingModes] = "same_kind",
        order: NotImplementedType = "K",
        dtype: Optional[DTypeLike] = None,
        subok: NotImplementedType = False,
        signature: NotImplementedType = None,
        extobj: NotImplementedType = None,
    ):
        if dtype is not None:

            def cast(x, dtype):
                if isinstance(x, torch.Tensor):
                    return _util.typecast_tensor(x, dtype, casting)
                else:
                    return torch.as_tensor(x, dtype=dtype)

            x1 = cast(x1, dtype)
            x2 = cast(x2, dtype)
        elif isinstance(x1, torch.Tensor) and isinstance(x2, torch.Tensor):
            dtype = _dtypes_impl.result_type_impl(x1, x2)
            x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting)
        else:
            x1, x2 = _dtypes_impl.nep50_to_tensors(
                x1, x2, torch_func.__name__ in NEP50_FUNCS, torch_func.__name__
            )

        result = torch_func(x1, x2)

        return _ufunc_postprocess(result, out, casting)

    wrapped.__qualname__ = torch_func.__name__
    wrapped.__name__ = torch_func.__name__

    return wrapped


#
# matmul's signature is slightly different from the other binary ufuncs:
# it has no where=..., but additional axes=... and axis=... arguments,
# so it gets a dedicated wrapper instead of going through deco_binary_ufunc.
#
@normalizer
def matmul(
    x1: ArrayLike,
    x2: ArrayLike,
    /,
    out: Optional[OutArray] = None,
    *,
    casting: Optional[CastingModes] = "same_kind",
    order: NotImplementedType = "K",
    dtype: Optional[DTypeLike] = None,
    subok: NotImplementedType = False,
    signature: NotImplementedType = None,
    extobj: NotImplementedType = None,
    axes: NotImplementedType = None,
    axis: NotImplementedType = None,
):
    if dtype is None:
        dtype = _dtypes_impl.result_type_impl(x1, x2)
    x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting)

    result = _binary_ufuncs_impl.matmul(x1, x2)

    result = _ufunc_postprocess(result, out, casting)
    return result
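
# Illustrative sketch (assumes ``torch._numpy`` is importable): the inputs
# are promoted to a common dtype before delegating to the torch matmul:
#
#   >>> import torch._numpy as tnp
#   >>> a = tnp.ones((2, 2), dtype=tnp.int8)
#   >>> b = tnp.ones(2, dtype=tnp.float32)
#   >>> tnp.matmul(a, b).dtype    # both operands promoted to float32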


# ldexp is special-cased: the second argument must be an integer, and the
# first is cast to a float if needed.
@normalizer
def ldexp(
    x1: ArrayLikeOrScalar,
    x2: ArrayLikeOrScalar,
    /,
    out: Optional[OutArray] = None,
    *,
    where: NotImplementedType = True,
    casting: Optional[CastingModes] = "same_kind",
    order: NotImplementedType = "K",
    dtype: Optional[DTypeLike] = None,
    subok: NotImplementedType = False,
    signature: NotImplementedType = None,
    extobj: NotImplementedType = None,
):
    if dtype is not None:
        if isinstance(x1, torch.Tensor):
            x1 = _util.typecast_tensor(x1, dtype, casting)
        else:
            x1 = torch.as_tensor(x1, dtype=dtype)
    else:
        if not isinstance(x1, torch.Tensor):
            x1 = torch.as_tensor(x1)
            x1 = _util.cast_int_to_float(x1)

    x2 = torch.as_tensor(x2)
    # the second argument must be integer-valued
    if _dtypes_impl._category(x2.dtype) != 1:
        raise ValueError("ldexp 2nd arg must be integer")

    result = _binary_ufuncs_impl.ldexp(x1, x2)

    if x1.dtype == torch.float16:
        # the underlying torch ldexp promotes float16 inputs; cast back
        result = result.to(torch.float16)

    return _ufunc_postprocess(result, out, casting)
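
# Illustrative: ldexp(x, n) computes x * 2**n elementwise, e.g. (assuming
# ``torch._numpy`` is importable):
#
#   >>> import torch._numpy as tnp
#   >>> tnp.ldexp(1.5, tnp.asarray([1, 2, 3]))   # 1.5 * 2**n -> [3., 6., 12.]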


@normalizer
def divmod(
    x1: ArrayLike,
    x2: ArrayLike,
    out1: Optional[OutArray] = None,
    out2: Optional[OutArray] = None,
    /,
    out: tuple[Optional[OutArray], Optional[OutArray]] = (None, None),
    *,
    where: NotImplementedType = True,
    casting: Optional[CastingModes] = "same_kind",
    order: NotImplementedType = "K",
    dtype: Optional[DTypeLike] = None,
    subok: NotImplementedType = False,
    signature: NotImplementedType = None,
    extobj: NotImplementedType = None,
):
    # make sure we have either no out arrays at all, or positional out1/out2,
    # or the out=(out1, out2) tuple -- but not both forms at once
    num_outs = sum(x is not None for x in [out1, out2])
    if num_outs == 1:
        raise ValueError("both out1 and out2 need to be provided")
    elif num_outs == 2:
        o1, o2 = out
        if o1 is not None or o2 is not None:
            raise TypeError(
                "cannot specify 'out' as both a positional and keyword argument"
            )
    else:
        out1, out2 = out

    if dtype is None:
        dtype = _dtypes_impl.result_type_impl(x1, x2)
    x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting)

    quot, rem = _binary_ufuncs_impl.divmod(x1, x2)

    quot = _ufunc_postprocess(quot, out1, casting)
    rem = _ufunc_postprocess(rem, out2, casting)
    return quot, rem
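
# Illustrative (assuming ``torch._numpy`` is importable): divmod returns the
# floor-division quotient and remainder, so that quot * x2 + rem == x1:
#
#   >>> import torch._numpy as tnp
#   >>> tnp.divmod(tnp.asarray([7, -7]), 3)   # quotients [2, -3], remainders [1, 2]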


for name in _binary:
    ufunc = getattr(_binary_ufuncs_impl, name)
    vars()[name] = deco_binary_ufunc(ufunc)
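
# The loop above binds module-level ufuncs such as ``add`` and ``multiply``.
# Illustrative usage (assuming ``torch._numpy`` is importable):
#
#   >>> import torch._numpy as tnp
#   >>> tnp.multiply(tnp.asarray([1, 2, 3]), 2)   # -> array([2, 4, 6])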


def modf(x, /, *args, **kwds):
    quot, rem = divmod(x, 1, *args, **kwds)
    return rem, quot
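
# Illustrative: like NumPy's modf, this returns the fractional and integral
# parts, in that order (assuming ``torch._numpy`` is importable):
#
#   >>> import torch._numpy as tnp
#   >>> tnp.modf(tnp.asarray([2.5]))   # -> (array([0.5]), array([2.]))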


_binary = _binary + ["divmod", "modf", "matmul", "ldexp"]


# ############# Unary ufuncs ######################

_unary = [
    name
    for name in dir(_unary_ufuncs_impl)
    if not name.startswith("_") and name != "torch"
]


# these ufuncs cast integer inputs to the default float dtype
_fp_unary = [
    "arccos",
    "arccosh",
    "arcsin",
    "arcsinh",
    "arctan",
    "arctanh",
    "cbrt",
    "cos",
    "cosh",
    "deg2rad",
    "degrees",
    "exp",
    "exp2",
    "expm1",
    "log",
    "log10",
    "log1p",
    "log2",
    "rad2deg",
    "radians",
    "reciprocal",
    "sin",
    "sinh",
    "sqrt",
    "square",
    "tan",
    "tanh",
    "trunc",
]
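
# Illustrative: functions in this list produce floating results even for
# integer input (assuming ``torch._numpy`` is importable):
#
#   >>> import torch._numpy as tnp
#   >>> tnp.sin(tnp.asarray([0, 1])).dtype   # int input cast to the default float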


def deco_unary_ufunc(torch_func):
    """Common infra for unary ufuncs.

    Normalize arguments, sort out type casting and broadcasting, and delegate
    to the PyTorch functions for the actual work.
    """

    @normalizer
    def wrapped(
        x: ArrayLike,
        /,
        out: Optional[OutArray] = None,
        *,
        where: NotImplementedType = True,
        casting: Optional[CastingModes] = "same_kind",
        order: NotImplementedType = "K",
        dtype: Optional[DTypeLike] = None,
        subok: NotImplementedType = False,
        signature: NotImplementedType = None,
        extobj: NotImplementedType = None,
    ):
        if dtype is not None:
            x = _util.typecast_tensor(x, dtype, casting)

        if torch_func.__name__ in _fp_unary:
            x = _util.cast_int_to_float(x)

        result = torch_func(x)
        result = _ufunc_postprocess(result, out, casting)
        return result

    wrapped.__qualname__ = torch_func.__name__
    wrapped.__name__ = torch_func.__name__

    return wrapped


for name in _unary:
    ufunc = getattr(_unary_ufuncs_impl, name)
    vars()[name] = deco_unary_ufunc(ufunc)
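
# As with the binary loop above, this binds unary ufuncs such as ``sqrt``
# and ``exp`` at module level.  Illustrative (assuming ``torch._numpy``):
#
#   >>> import torch._numpy as tnp
#   >>> tnp.sqrt(tnp.asarray([4.0, 9.0]))   # -> array([2., 3.])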


__all__ = _binary + _unary