#pragma once
#include <ATen/ExpandUtils.h>
#include <ATen/native/CanUse32BitIndexMath.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/core/IListRef.h>
#include <c10/util/irange.h>
namespace at { namespace native {
[[noreturn]]
static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask, int64_t maskIdx) {
  TORCH_CHECK_INDEX(false, "The shape of the mask ", mask.sizes(), " at index ", maskIdx,
    " does not match the shape of the indexed tensor ", self.sizes(), " at index ", idx);
}
static C10_UNUSED std::vector<Tensor> expandTensors(const Tensor & self, IOptTensorListRef indices) {
  // If indices come in as ByteTensor or BoolTensor (masks), expand them into the equivalent indexing by LongTensors
  std::vector<Tensor> result;
  for (const auto& index_opt : indices) {
    if (!index_opt.has_value()) {
      result.emplace_back();
    } else {
      const auto& index = *index_opt;
      if (index.scalar_type() == kByte || index.scalar_type() == kBool) {
        if (index.scalar_type() == kByte) {
          TORCH_WARN("indexing with dtype torch.uint8 is now deprecated," \
          " please use a dtype torch.bool instead.");
        }
        // The sizes of the ByteTensor mask or bool tensor must match the sizes of the
        // corresponding dimensions in self
        for (const auto j : c10::irange(index.dim())) {
          int64_t srcIdx = result.size() + j;
          if (index.size(j) != self.size(srcIdx)) {
            invalid_mask(self, srcIdx, index, j);
          }
        }
        // Replace with nonzeros
        auto nonzero = index.nonzero();
        for (const auto j : c10::irange(index.dim())) {
          result.emplace_back(nonzero.select(1, j));
        }
      } else {
        result.emplace_back(std::move(index));
      }
    }
  }
  return result;
}
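
// Illustrative sketch of expandTensors (hypothetical caller code, not part of
// the original header): a bool mask of matching shape is replaced by one long
// index tensor per mask dimension, holding the coordinates of the true entries.
//
//   Tensor t = at::arange(6).reshape({2, 3});
//   Tensor mask = t > 2;                                    // bool, shape [2, 3]
//   auto expanded = expandTensors(t, {c10::optional<Tensor>(mask)});
//   // expanded[0] == {1, 1, 1} (row coords), expanded[1] == {0, 1, 2} (col coords),
//   // i.e. the columns of mask.nonzero().
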
static C10_UNUSED void checkIndexTensorTypes(IOptTensorListRef indices) {
  for (const auto& tensor : indices) {
    if (tensor.has_value() && tensor->defined()) {
      auto scalarType = tensor->scalar_type();
      if (scalarType != kLong && scalarType != kByte && scalarType != kBool) {
        TORCH_CHECK_INDEX(false, "tensors used as indices must be long, byte or bool tensors");
      }
    }
  }
}
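
// Sketch of the failure mode (hypothetical caller code): long, byte and bool
// index tensors pass checkIndexTensorTypes, anything else is rejected.
//
//   Tensor bad = at::tensor({0, 2}, at::kInt);
//   checkIndexTensorTypes({c10::optional<Tensor>(bad)});  // raises an index error
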
inline torch::List<c10::optional<Tensor>> toListOfOptionalTensors(ArrayRef<Tensor> list) {
  torch::List<c10::optional<Tensor>> result;
  result.reserve(list.size());
  for (const Tensor& a : list) {
    result.push_back(a);
  }
  return result;
}
inline torch::List<c10::optional<Tensor>> toListOfOptionalTensors(ArrayRef<IValue> list) {
  torch::List<c10::optional<Tensor>> result;
  result.reserve(list.size());
  for (const IValue& a : list) {
    result.push_back(a.isTensor() ? c10::optional<Tensor>(a.toTensor()) : c10::optional<Tensor>());
  }
  return result;
}
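
// Illustrative sketch: both overloads wrap a plain tensor list into the
// optional-tensor list form taken by the indexing ops; non-Tensor IValues
// become empty optionals.
//
//   std::vector<Tensor> idx = {at::tensor({0, 1})};          // hypothetical index
//   auto opt_list = toListOfOptionalTensors(ArrayRef<Tensor>(idx));
//   // opt_list.get(0).has_value() == true
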
static C10_UNUSED bool hasContiguousSubspace(TensorList tl) {
  // true if all the non-null tensors are adjacent
  auto isDefined = [](const Tensor & tensor){ return tensor.defined(); };
  auto isNull = [](const Tensor & tensor){ return !tensor.defined(); };
  auto start = std::find_if(tl.begin(), tl.end(), isDefined);
  auto stop = std::find_if(tl.rbegin(), tl.rend(), isDefined);
  auto it = std::find_if(start, stop.base(), isNull);
  return it == stop.base();
}
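
// Examples for hasContiguousSubspace (a, b defined index tensors; "-" an undefined slot):
//   {-, a, b, -} -> true   (all defined indices are adjacent)
//   {a, -, b}    -> false  (an undefined slot separates defined indices)
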
// Transposes the tensor and indices together so that all the non-null indices
// index the first k dimensions of the tensor. Returns the transposed tensor
// and the reordered indices. For example:
// transposeToFront(tensor, {nullptr, a, nullptr, b})
// returns
// tensor.permute([1, 3, 0, 2]), {a, b, nullptr, nullptr}
static C10_UNUSED std::tuple<Tensor, std::vector<Tensor>>
transposeToFront(Tensor self, TensorList indices) {
  std::vector<int64_t> dims;
  std::vector<Tensor> transposedIndices;
  dims.reserve(self.dim());
  for (const auto i : c10::irange(self.dim())) {
    if (indices[i].defined()) {
      dims.push_back(i);
      transposedIndices.emplace_back(indices[i]);
    }
  }
  for (const auto i : c10::irange(self.dim())) {
    if (!indices[i].defined()) {
      dims.push_back(i);
      transposedIndices.emplace_back();
    }
  }
  return std::make_tuple(self.permute(dims), std::move(transposedIndices));
}
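
// Illustrative sketch (hypothetical caller code; t, a, b are assumed tensors):
// for a 4-d tensor indexed as t[:, a, :, b] the defined indices are not
// adjacent, so the tensor is permuted to bring dims 1 and 3 to the front,
// matching the example in the comment above.
//
//   auto [permuted, reordered] = transposeToFront(t, {Tensor(), a, Tensor(), b});
//   // permuted  == t.permute({1, 3, 0, 2})
//   // reordered == {a, b, Tensor(), Tensor()}
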
inline std::tuple<Tensor, std::vector<Tensor>, std::vector<int64_t>>
transposeToFrontAndInvPerm(Tensor self, TensorList indices) {
  std::vector<int64_t> dims;
  std::vector<int64_t> invPerm;
  std::vector<Tensor> transposedIndices;
  dims.reserve(self.dim());
  invPerm.resize(self.dim());
  for (const auto i : c10::irange(self.dim())) {
    if (indices[i].defined()) {
      dims.push_back(i);
      transposedIndices.emplace_back(indices[i]);
    }
  }
  for (const auto i : c10::irange(self.dim())) {
    if (!indices[i].defined()) {
      dims.push_back(i);
      transposedIndices.emplace_back();
    }
  }
  for (const auto i : c10::irange(self.dim())) {
    invPerm[dims[i]] = i;
  }
  return std::make_tuple(self.permute(dims), std::move(transposedIndices), std::move(invPerm));
}
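
// Continuing the sketch above: transposeToFrontAndInvPerm additionally returns
// the inverse permutation so the original dimension order can be restored
// after indexing. With dims == {1, 3, 0, 2} the inverse is {2, 0, 3, 1}, and
//   permuted.permute(invPerm)  // has the same dimension order as the original tensor
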
struct AdvancedIndex {
  AdvancedIndex(const Tensor& src, TensorList indices);

  // Source tensor, restrided by the out-of-line constructor so that the
  // indexed dimensions can be iterated together with the index tensors.
  Tensor src;
  // Expanded (long) index tensors, reshaped for broadcasting.
  std::vector<Tensor> indices;
  // Sizes and strides of the indexed dimensions of the original tensor.
  DimVector indexed_sizes;
  DimVector indexed_strides;
  // Number of non-indexed dimensions before/after the indexed subspace.
  int64_t dims_before;
  int64_t dims_after;
};
}}