|
|
|
|
|
"""Assorted utilities, which do not need anything other then torch and stdlib. |
|
""" |
|
|
|
import operator |
|
|
|
import torch |
|
|
|
from . import _dtypes_impl |
|
|
|
|
|
|
|
def is_sequence(seq):
    """Report whether ``seq`` looks like a (non-string) sequence.

    A ``str`` is never considered a sequence. Anything else counts as one
    exactly when calling ``len()`` on it succeeds; any exception raised by
    ``len()`` is treated as a negative answer.
    """
    if isinstance(seq, str):
        return False
    try:
        _ = len(seq)
    except Exception:
        has_length = False
    else:
        has_length = True
    return has_length
|
|
|
|
|
class AxisError(ValueError, IndexError):
    """Error for an axis argument that is invalid for an array.

    Inherits from both ``ValueError`` and ``IndexError`` so callers may
    catch it as either.
    """
|
|
|
|
|
class UFuncTypeError(TypeError, RuntimeError):
    """Type error for ufunc-style operations.

    Inherits from both ``TypeError`` and ``RuntimeError`` so callers may
    catch it as either.
    """
|
|
|
|
|
def cast_if_needed(tensor, dtype):
    """Return ``tensor`` converted to ``dtype``, or ``tensor`` itself.

    When ``dtype`` is None or already equals ``tensor.dtype``, the very same
    tensor object is returned without any conversion.
    """
    needs_cast = dtype is not None and tensor.dtype != dtype
    return tensor.to(dtype) if needs_cast else tensor
|
|
|
|
|
def cast_int_to_float(x):
    """Promote ``x`` to the default float dtype if its dtype category is below 2.

    The category comes from ``_dtypes_impl._category``. Categories below 2
    are presumably bool and integer dtypes (TODO confirm against
    ``_dtypes_impl``). The target dtype is
    ``_dtypes_impl.default_dtypes().float_dtype``. Tensors in higher
    categories are returned unchanged.
    """
    category = _dtypes_impl._category(x.dtype)
    if category >= 2:
        return x
    return x.to(_dtypes_impl.default_dtypes().float_dtype)
|
|
|
|
|
|
|
def normalize_axis_index(ax, ndim, argname=None):
    """Map a possibly negative axis index into the range ``[0, ndim)``.

    Parameters
    ----------
    ax : int
        The axis to normalize. Negative values count from the last axis.
    ndim : int
        The number of dimensions of the array.
    argname : str, optional
        If given, prefixed to the error message (typically the name of the
        argument that carried ``ax``), e.g. ``"source: axis 3 is out of ..."``.

    Returns
    -------
    int
        The equivalent non-negative axis.

    Raises
    ------
    AxisError
        If ``ax`` is not in ``[-ndim, ndim)``.
    """
    if not (-ndim <= ax < ndim):
        msg = f"axis {ax} is out of bounds for array of dimension {ndim}"
        # Previously ``argname`` was accepted but ignored; use it as the
        # documented message prefix (matches NumPy's ``msg_prefix``).
        if argname:
            msg = f"{argname}: {msg}"
        raise AxisError(msg)
    return ax + ndim if ax < 0 else ax
|
|
|
|
|
|
|
def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False):
    """Turn an axis argument into a tuple of non-negative axis indices.

    A single integer-like ``axis`` (anything accepted by ``operator.index``)
    is wrapped into a one-element tuple. Every entry is then normalized via
    `normalize_axis_index`, which resolves negative indices and rejects
    out-of-range ones. By default, the same axis may not appear twice.

    Parameters
    ----------
    axis : int, iterable of int
        The raw axis or axes.
    ndim : int
        Number of dimensions of the array the axes refer to.
    argname : str, optional
        Name of the originating argument. It is used in the repeated-axis
        error message and forwarded to `normalize_axis_index`.
    allow_duplicate : bool, optional
        When False (the default), a repeated axis raises ``ValueError``.

    Returns
    -------
    normalized_axes : tuple of int
        Axes satisfying ``0 <= normalized_axis < ndim``.
    """
    # Only exact tuples/lists skip the scalar path. Anything else is first
    # tried as an integer and, failing that, iterated as-is below.
    is_plain_sequence = type(axis) in (tuple, list)
    if not is_plain_sequence:
        try:
            axis = (operator.index(axis),)
        except TypeError:
            pass

    normalized = tuple(normalize_axis_index(ax, ndim, argname) for ax in axis)

    if allow_duplicate or len(set(normalized)) == len(normalized):
        return normalized

    msg = f"repeated axis in `{argname}` argument" if argname else "repeated axis"
    raise ValueError(msg)
|
|
|
|
|
def allow_only_single_axis(axis):
    """Unwrap a one-element axis sequence, passing ``None`` through.

    Raises ``NotImplementedError`` when ``axis`` is not None and does not
    have exactly one element.
    """
    if axis is not None:
        if len(axis) != 1:
            raise NotImplementedError("does not handle tuple axis")
        axis = axis[0]
    return axis
|
|
|
|
|
def expand_shape(arr_shape, axis):
    """Return ``arr_shape`` as a list with size-1 entries inserted at ``axis``.

    ``axis`` may be a single axis or a list/tuple of axes. The axes refer to
    the *output* shape, whose rank is ``len(arr_shape) + number of axes``,
    and are normalized (negatives allowed, duplicates rejected) by
    `normalize_axis_tuple`. The original sizes fill the remaining positions
    in order.
    """
    axes = axis if type(axis) in (list, tuple) else (axis,)
    out_ndim = len(arr_shape) + len(axes)
    axes = normalize_axis_tuple(axes, out_ndim)

    sizes = iter(arr_shape)
    return [1 if pos in axes else next(sizes) for pos in range(out_ndim)]
|
|
|
|
|
def apply_keepdims(tensor, axis, ndim):
    """Reinsert reduced dimensions of ``tensor`` as size-1 axes.

    Presumably used to implement ``keepdims=True`` for reductions (confirm
    with callers).

    With ``axis=None`` the tensor is broadcast to an all-ones shape of rank
    ``ndim`` (this assumes ``tensor`` has all-size-1 or no dimensions, e.g.
    a 0-d result). Otherwise it is reshaped to `expand_shape(tensor.shape, axis)`.
    """
    if axis is not None:
        return tensor.reshape(expand_shape(tensor.shape, axis))

    # ``expand`` only creates a broadcast view, so ``contiguous()``
    # materializes a real tensor.
    return tensor.expand((1,) * ndim).contiguous()
|
|
|
|
|
def axis_none_flatten(*tensors, axis=None):
    """Flatten every tensor when ``axis`` is None.

    Returns a ``(tensors, axis)`` pair. If ``axis`` is None, every tensor is
    flattened and the returned axis is ``0``. Otherwise the input tuple and
    ``axis`` are returned untouched.
    """
    if axis is not None:
        return tensors, axis
    flattened = tuple(t.flatten() for t in tensors)
    return flattened, 0
|
|
|
|
|
def typecast_tensor(t, target_dtype, casting):
    """Dtype-cast tensor to target_dtype.

    Parameters
    ----------
    t : torch.Tensor
        The tensor to cast.
    target_dtype : torch dtype object
        The dtype to cast the tensor to.
    casting : str
        The casting mode, see `np.can_cast`.

    Returns
    -------
    `torch.Tensor` of the `target_dtype` dtype. If ``t`` already has that
    dtype, ``t`` itself is returned (no copy).

    Raises
    ------
    TypeError
        if the argument cannot be cast according to the `casting` rule
    """
    can_cast = _dtypes_impl.can_cast_impl

    if not can_cast(t.dtype, target_dtype, casting=casting):
        raise TypeError(
            f"Cannot cast array data from {t.dtype} to"
            f" {target_dtype} according to the rule '{casting}'"
        )
    return cast_if_needed(t, target_dtype)
|
|
|
|
|
def typecast_tensors(tensors, target_dtype, casting):
    """Cast every tensor in ``tensors`` with `typecast_tensor`.

    Returns a tuple of the converted tensors, in input order. Any casting
    failure propagates from `typecast_tensor`.
    """
    converted = [typecast_tensor(t, target_dtype, casting) for t in tensors]
    return tuple(converted)
|
|
|
|
|
def _try_convert_to_tensor(obj):
    """Convert ``obj`` to a tensor with ``torch.as_tensor``.

    Raises
    ------
    NotImplementedError
        If ``torch.as_tensor`` raises. The original exception is included in
        the message and chained as ``__cause__`` so its traceback is kept.
    """
    try:
        tensor = torch.as_tensor(obj)
    except Exception as e:
        mesg = f"failed to convert {obj} to ndarray. \nInternal error is: {str(e)}."
        raise NotImplementedError(mesg) from e
    return tensor
|
|
|
|
|
def _coerce_to_tensor(obj, dtype=None, copy=False, ndmin=0):
    """The core logic of the array(...) function.

    Parameters
    ----------
    obj : tensor_like
        The thing to coerce.
    dtype : torch.dtype object or None
        Coerce to this torch dtype. None keeps the inferred dtype.
    copy : bool
        If True, the result is ``clone()``d at the end. If False, the result
        may be the very same object as ``obj`` when ``obj`` is already a
        tensor of the requested dtype and rank.
    ndmin : int
        The result has at least this many dimensions. Missing dimensions are
        prepended as size-1 axes.

    Returns
    -------
    tensor : torch.Tensor
        A tensor object with the requested dtype, ndim and copy semantics.

    Notes
    -----
    This is almost a "tensor_like" coercion function. Does not handle wrapper
    ndarrays (those should be handled in the ndarray-aware layer prior to
    invoking this function).
    """
    if isinstance(obj, torch.Tensor):
        tensor = obj
    else:
        # Temporarily switch torch's global default dtype so that Python
        # floats inside ``obj`` are inferred as whatever
        # ``_dtypes_impl.get_default_dtype_for(torch.float32)`` selects.
        # The previous default is always restored in ``finally``.
        # NOTE(review): set_default_dtype mutates process-global state, so
        # this swap is presumably not safe against concurrent callers.
        default_dtype = torch.get_default_dtype()
        torch.set_default_dtype(_dtypes_impl.get_default_dtype_for(torch.float32))
        try:
            tensor = _try_convert_to_tensor(obj)
        finally:
            torch.set_default_dtype(default_dtype)

    # Apply the requested dtype (no-op when dtype is None or already matches).
    tensor = cast_if_needed(tensor, dtype)

    # Prepend size-1 axes until the tensor has at least ``ndmin`` dimensions.
    ndim_extra = ndmin - tensor.ndim
    if ndim_extra > 0:
        tensor = tensor.view((1,) * ndim_extra + tensor.shape)

    # Only clone on request; otherwise the result may alias ``obj``.
    if copy:
        tensor = tensor.clone()

    return tensor
|
|
|
|
|
def ndarrays_to_tensors(*inputs):
    """Convert all ndarrays from `inputs` to tensors, leaving other objects intact.

    With a single argument, returns the converted object. An ``ndarray``
    becomes its ``.tensor``, a tuple is converted element-wise (recursively,
    so nested tuples are handled), and anything else is returned unchanged.
    With several arguments, returns a tuple of the converted arguments.

    Raises
    ------
    ValueError
        If called with no arguments.
    """
    if not inputs:
        # Raise rather than return the exception, so an empty call cannot be
        # mistaken for a (bogus) result by the caller.
        raise ValueError("ndarrays_to_tensors requires at least one input")

    if len(inputs) > 1:
        # Several arguments: treat them as one tuple and convert element-wise.
        return ndarrays_to_tensors(inputs)

    # Imported lazily, presumably to avoid a circular import with ._ndarray.
    from ._ndarray import ndarray

    input_ = inputs[0]
    if isinstance(input_, ndarray):
        return input_.tensor
    if isinstance(input_, tuple):
        return tuple(ndarrays_to_tensors(sub_input) for sub_input in input_)
    return input_
|
|