import dis
import inspect
from typing import Sequence, Union
import torch
import functorch._C
from functorch._C import dim as _C
from .tree_map import tree_flatten, tree_map
from .wrap_type import wrap_type
_C._patch_tensor_class()
dims, DimList, dimlists = _C.dims, _C.DimList, _C.dimlists
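# Illustrative usage sketch (comment only, not executed on import; follows the
# public torchdim-style API and is an assumption about intended usage, not a
# definitive reference):
#
#   from functorch.dim import dims
#   batch, channel = dims(2)       # create two first-class Dim objects
#   x = torch.rand(4, 3)
#   xd = x[batch, channel]         # bind the dims to x's positional dimensions
#   xd.order(batch, channel)       # recover a plain positional torch.Tensor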
class DimensionMismatchError(Exception):
    pass


class DimensionBindError(Exception):
    pass
from . import op_properties
# use dict to avoid writing C++ bindings for set
pointwise = dict.fromkeys(op_properties.pointwise, True)
use_c = True
if not use_c:
    from . import reference
class _Tensor:
    # fast path around slow wrapping/unwrapping logic for simple queries used
    # by the implementation...
    @property
    def dims(self):
        return tuple(d for d in self._levels if isinstance(d, Dim))

    def dim(self):
        return self.ndim

    if use_c:
        __torch_function__ = classmethod(_C.__torch_function__)
        expand = _C._instancemethod(_C.expand)
    else:
        __torch_function__ = reference.__torch_function__
        expand = reference.expand

    index = _C._instancemethod(_C.index)

    def __repr__(self):
        tensor, levels, ndim = self._tensor, self._levels, self.ndim
        return f"{tensor}\nwith dims={tuple(l + ndim if isinstance(l, int) else l for l in levels)} sizes={tuple(tensor.size())}"
TensorLike = (_Tensor, torch.Tensor)
class Dim(_C.Dim, _Tensor):
    # note that _C.Dim comes before tensor because we want the Dim API for things like size to take precedence.
    # Tensor defines format, but we want to print Dims with special formatting
    __format__ = object.__format__
class Tensor(_Tensor, _C.Tensor):
    if not use_c:
        from_batched = staticmethod(_C.Tensor_from_batched)
    from_positional = staticmethod(_C.Tensor_from_positional)
    sum = _C._instancemethod(_C.Tensor_sum)


def cat(tensors, dim, new_dim):
    n = dims()
    return stack(tensors, n, dim).index([n, dim], new_dim)
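# How cat works, step by step (descriptive comment based on the body above):
#   1. stack(tensors, n, dim) stacks the inputs along a fresh first-class dim
#      `n`, keeping `dim` bound on the result.
#   2. .index([n, dim], new_dim) flattens `n` and `dim` into the single dim
#      `new_dim`, which is the concatenated dimension.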
if use_c:
    _wrap = _C._wrap

    def _def(name, *args, **kwargs):
        orig = getattr(torch.Tensor, name)
        setattr(_Tensor, name, _C._instancemethod(_wrap(orig, *args, **kwargs)))

    t__getitem__ = _C._instancemethod(_C.__getitem__)
    stack = _C.stack
    split = _C._instancemethod(_C.split)
else:
    _wrap, _def = reference._wrap, reference._def
    t__getitem__ = reference.t__getitem__
    stack = reference.stack
    split = reference.split
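# `stack` mirrors the call made inside cat() above: stack(tensors, new_dim, dim)
# stacks `tensors` along the fresh first-class dim `new_dim` while `dim` stays
# bound on the result (descriptive comment; the argument roles are inferred
# from the cat() definition, not from separate documentation).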
# note: there is no python reference
t__setitem__ = _C._instancemethod(_C.__setitem__)
# this is patched in the C API because otherwise torch.Tensor will
# no longer be considered a sequence and things will break
# torch.Tensor.__getitem__ = t__getitem__
_Tensor.__getitem__ = t__getitem__
# torch.Tensor.__setitem__ = t__setitem__
_Tensor.__setitem__ = t__setitem__
torch.Tensor.split = split
_Tensor.split = split
torch.Tensor.expand = _C._instancemethod(_C.expand)
torch.Tensor.index = _C._instancemethod(_C.index)
wrap_type(use_c, _Tensor, torch.Tensor, _Tensor.__torch_function__)
del _Tensor.ndim
if use_c:
    _Tensor.order = _C._instancemethod(_C.order)
else:
    _Tensor.order = reference.positional
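# Illustrative sketch of `order` (comment only; assumes the binding syntax
# shown earlier, where `order` is the inverse of indexing with dims):
#
#   i, j = dims(2)
#   a = torch.rand(3, 4)[i, j]
#   a.order(j, i)   # plain torch.Tensor of shape (4, 3); dims made positional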
_def("mean")
_def("sum")
_def("all")
_def("amax")
_def("amin")
_def("aminmax")
_def("any")
_def("count_nonzero")
_def("logsumexp")
_def("nanmean")
_def("nansum")
_def("prod")
_def("std", keepdim_offset=2)
_def("var", keepdim_offset=2)
_def("max", single_dim=True)
_def("min", single_dim=True)
_def("argmax", single_dim=True)
_def("argmin", single_dim=True)
_def("kthvalue", single_dim=True)
_def("median", single_dim=True)
_def("nanmedian", single_dim=True)
_def("mode", single_dim=True)
_def("sort", reduce=False)
_def("argsort", reduce=False)
_def("unbind", single_dim=True)
_def("chunk", dim_offset=1, reduce=False)
_def("cummax", single_dim=True, reduce=False)
_def("cummin", single_dim=True, reduce=False)
_def("cumprod", single_dim=True, reduce=False)
_def("cumprod_", single_dim=True, reduce=False)
_def("cumsum", single_dim=True, reduce=False)
_def("cumsum_", single_dim=True, reduce=False)
_def("logcumsumexp", single_dim=True, reduce=False)
_def("renorm", dim_offset=1, single_dim=True, reduce=False)
_def("softmax", single_dim=True, reduce=False)
softmax = _wrap(torch.nn.functional.softmax, single_dim=True, reduce=False)
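# The _def() calls above rebind these torch.Tensor methods so their `dim`
# argument also accepts first-class Dim objects. Illustrative sketch (comment
# only; assumes the binding syntax shown earlier):
#
#   i, j = dims(2)
#   t = torch.rand(3, 4)[i, j]
#   t.sum(i)         # reduce over the first-class dim i
#   softmax(t, j)    # module-level wrapper taking a Dim as its `dim` argument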
# stuff to handle in the future, because they require special
# binding logic for dims
# cross
# diag_embed
# diagonal
# diagonal_scatter
# diff
# nanquantile
# quantile
# roll
# rot90
# topk (new dims on output)
# should these all be subsumed by inplace indexing?
# index_add_
# index_add
# index_copy
# index_copy_
# index_fill
# index_fill_
# index_select
# scatter
# scatter_
# scatter_add
# scatter_add_
# scatter_reduce