import dis
import inspect
from typing import Sequence, Union

import torch

import functorch._C
from functorch._C import dim as _C
from .tree_map import tree_flatten, tree_map
from .wrap_type import wrap_type

# Runs at import time, before anything below is defined.
# NOTE(review): presumably this is where torch.Tensor.__getitem__/__setitem__
# get installed (see the "patched in the C API" note further down) — confirm
# against the C extension.
_C._patch_tensor_class()

# Re-export the dim constructors implemented by the C extension.
dims, DimList, dimlists = _C.dims, _C.DimList, _C.dimlists


class DimensionMismatchError(Exception):
    pass


class DimensionBindError(Exception):
    pass


from . import op_properties

# use dict to avoid writing C++ bindings for set
pointwise = dict.fromkeys(op_properties.pointwise, True)

# Selects the implementation used throughout this module: the C extension
# (`_C`) when True, the pure-Python `reference` module otherwise. `reference`
# is only imported when it is actually needed.
use_c = True
if not use_c:
    from . import reference


# Base class shared by `Dim` and `Tensor` below. Further methods are attached
# to it after the class body by `wrap_type(...)`, the `_def(...)` calls and
# the direct attribute assignments later in this module.
class _Tensor:
    # fast path around slow wrapping/unwrapping logic for simple queries used
    # by the implementation...

    @property
    def dims(self):
        """Return a tuple of the entries of ``self._levels`` that are `Dim` objects."""
        return tuple(d for d in self._levels if isinstance(d, Dim))

    def dim(self):
        """Return ``self.ndim``."""
        return self.ndim

    if use_c:
        __torch_function__ = classmethod(_C.__torch_function__)
        expand = _C._instancemethod(_C.expand)
    else:
        __torch_function__ = reference.__torch_function__
        expand = reference.expand

    # `index` is taken from the C extension regardless of `use_c`.
    index = _C._instancemethod(_C.index)

    def __repr__(self):
        """Render the underlying tensor followed by its levels and sizes."""
        tensor, levels, ndim = self._tensor, self._levels, self.ndim
        # Integer levels are shifted by `ndim` for display; non-integer levels
        # are shown as-is.
        # NOTE(review): presumably integer levels are stored as negative
        # positional indices, so adding ndim yields 0-based positions — confirm.
        return f"{tensor}\nwith dims={tuple(l + ndim if isinstance(l, int) else l for l in levels)} sizes={tuple(tensor.size())}"


# Tuple of the two tensor types defined/handled by this module.
TensorLike = (_Tensor, torch.Tensor)


class Dim(_C.Dim, _Tensor):
    # note that _C.Dim comes before tensor because we want the Dim API for things like size to take precedence.
    # Tensor defines format, but we want to print Dims with special formatting
    __format__ = object.__format__


class Tensor(_Tensor, _C.Tensor):
    # NOTE(review): the layout of this class body was reconstructed from a
    # whitespace-collapsed source; confirm that only `from_batched` is meant
    # to be conditional on `use_c`.
    if not use_c:
        from_batched = staticmethod(_C.Tensor_from_batched)
    from_positional = staticmethod(_C.Tensor_from_positional)
    sum = _C._instancemethod(_C.Tensor_sum)


def cat(tensors, dim, new_dim):
    """Stack ``tensors`` and re-index the result onto ``new_dim``.

    Creates a fresh value ``n`` via ``dims()``, calls
    ``stack(tensors, n, dim)``, and then ``.index([n, dim], new_dim)`` on the
    result.

    NOTE(review): presumably ``n`` is a single new first-class dim and the
    final ``index`` call flattens ``(n, dim)`` into ``new_dim`` — confirm
    against the ``dims``/``index`` implementations.
    """
    n = dims()
    return stack(tensors, n, dim).index([n, dim], new_dim)


# Bind the implementation-specific helpers. Both branches define the same
# names: `_wrap`, `_def`, `t__getitem__`, `stack` and `split`.
if use_c:
    _wrap = _C._wrap

    def _def(name, *args, **kwargs):
        """Install a wrapped version of ``torch.Tensor.<name>`` on `_Tensor`.

        The original method is looked up on ``torch.Tensor``, passed to
        ``_wrap`` together with ``*args``/``**kwargs``, and the result is
        attached to `_Tensor` under the same name as an instance method.
        """
        orig = getattr(torch.Tensor, name)
        setattr(_Tensor, name, _C._instancemethod(_wrap(orig, *args, **kwargs)))

    t__getitem__ = _C._instancemethod(_C.__getitem__)
    stack = _C.stack
    split = _C._instancemethod(_C.split)
else:
    _wrap, _def = reference._wrap, reference._def
    t__getitem__ = reference.t__getitem__
    stack = reference.stack
    split = reference.split

# note: there is no python reference
t__setitem__ = _C._instancemethod(_C.__setitem__)
# this is patched in the C API because otherwise torch.Tensor will
# no longer be considered a sequence and things will break
# torch.Tensor.__getitem__ = t__getitem__

_Tensor.__getitem__ = t__getitem__
# torch.Tensor.__setitem__ = t__setitem__
_Tensor.__setitem__ = t__setitem__

# These assignments replace/add attributes on torch.Tensor itself, so they
# affect every tensor in the process once this module has been imported.
torch.Tensor.split = split
_Tensor.split = split
torch.Tensor.expand = _C._instancemethod(_C.expand)
torch.Tensor.index = _C._instancemethod(_C.index)
wrap_type(use_c, _Tensor, torch.Tensor, _Tensor.__torch_function__)
# Must run after wrap_type(...).
# NOTE(review): presumably wrap_type installs an `ndim` attribute on _Tensor
# that has to be removed so subclasses supply their own — confirm.
del _Tensor.ndim

if use_c:
    _Tensor.order = _C._instancemethod(_C.order)
else:
    _Tensor.order = reference.positional

# Register wrapped torch.Tensor methods on _Tensor. The keyword arguments are
# forwarded unchanged to `_wrap`; their exact meaning is defined there.
_def("mean")
_def("sum")
_def("all")
_def("amax")
_def("amin")
_def("aminmax")
_def("any")
_def("count_nonzero")
_def("logsumexp")
_def("nanmean")
_def("nansum")
_def("prod")
_def("std", keepdim_offset=2)
_def("var", keepdim_offset=2)
_def("max", single_dim=True)
_def("min", single_dim=True)
_def("argmax", single_dim=True)
_def("argmin", single_dim=True)
_def("kthvalue", single_dim=True)
_def("median", single_dim=True)
_def("nanmedian", single_dim=True)
_def("mode", single_dim=True)

_def("sort", reduce=False)
_def("argsort", reduce=False)
_def("unbind", single_dim=True)
_def("chunk", dim_offset=1, reduce=False)
_def("cummax", single_dim=True, reduce=False)
_def("cummin", single_dim=True, reduce=False)
_def("cumprod", single_dim=True, reduce=False)
_def("cumprod_", single_dim=True, reduce=False)
_def("cumsum", single_dim=True, reduce=False)
_def("cumsum_", single_dim=True, reduce=False)
_def("logcumsumexp", single_dim=True, reduce=False)
_def("renorm", dim_offset=1, single_dim=True, reduce=False)
_def("softmax", single_dim=True, reduce=False)
# Module-level function wrapping torch.nn.functional.softmax with the same
# options as the method registered on the line above.
softmax = _wrap(torch.nn.functional.softmax, single_dim=True, reduce=False)

# stuff to handle in the future, because they require special
# binding logic for dims
# cross
# diag_embed
# diagonal
# diagonal_scatter
# diff
# nanquantile
# quantile
# roll
# rot90
# topk (new dims on output)
# should these all be subsumed by inplace indexing?
# index_add_
# index_add
# index_copy
# index_copy_
# index_fill
# index_fill_
# index_select
# scatter
# scatter_
# scatter_add
# scatter_add_
# scatter_reduce