|
#pragma once |
|
|
|
#include <ATen/CollapseDims.h>
#include <ATen/Parallel.h>
#include <ATen/TensorUtils.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>

#include <algorithm>
#include <cstring>
#include <limits>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
|
|
|
namespace at { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/// Returns `tensor_.permute(indices)` where `indices` orders the dimensions
/// of `tensor_` by non-increasing stride (largest stride first).
///
/// Dimensions with equal strides keep their original relative order, so the
/// resulting permutation is deterministic regardless of the standard-library
/// implementation.
///
/// @param tensor_ tensor whose dimensions are reordered (not modified here).
/// @return the tensor produced by `permute` with the stride-sorted order.
inline Tensor sort_strides(Tensor& tensor_) {
  IntArrayRef strides = tensor_.strides();
  std::vector<int64_t> indices;
  indices.reserve(tensor_.ndimension());
  for (const auto i : c10::irange(tensor_.ndimension())) {
    indices.push_back(i);
  }
  // std::sort is not stable: ties in `strides` (e.g. size-1 or stride-0
  // dims) would end up in an unspecified order. stable_sort keeps them in
  // their original order.
  std::stable_sort(
      indices.begin(), indices.end(), [&strides](int64_t i1, int64_t i2) {
        return strides[i1] > strides[i2];
      });
  return tensor_.permute(indices);
}
|
|
|
template <typename T, int N> |
|
struct strided_tensor_iter_fixed { |
|
public: |
|
T* data_ = NULL; |
|
int64_t dim_ = 0; |
|
|
|
int64_t counter_[N] = {0}; |
|
int64_t sizes_[N] = {0}; |
|
int64_t strides_[N] = {0}; |
|
|
|
strided_tensor_iter_fixed(strided_tensor_iter_fixed const&) = delete; |
|
void operator=(strided_tensor_iter_fixed const& x) = delete; |
|
strided_tensor_iter_fixed(strided_tensor_iter_fixed&&) = default; |
|
strided_tensor_iter_fixed(Tensor& tensor, bool sort_strides = false) |
|
: data_(tensor.data_ptr<T>()) { |
|
(void)sort_strides; |
|
std::memset(counter_, 0, sizeof(int64_t) * N); |
|
if (tensor.dim() > 0) { |
|
std::memcpy( |
|
sizes_, tensor.sizes().data(), tensor.dim() * sizeof(int64_t)); |
|
std::memcpy( |
|
strides_, tensor.strides().data(), tensor.dim() * sizeof(int64_t)); |
|
} |
|
dim_ = std::get<1>(collapse_dims(sizes_, strides_, tensor.ndimension())); |
|
} |
|
}; |
|
|
|
template <typename T> |
|
struct strided_tensor_iter { |
|
private: |
|
public: |
|
T* data_ = NULL; |
|
int64_t dim_; |
|
|
|
std::vector<int64_t> counter_; |
|
std::vector<int64_t> sizes_; |
|
std::vector<int64_t> strides_; |
|
|
|
strided_tensor_iter(strided_tensor_iter const&) = delete; |
|
void operator=(strided_tensor_iter const& x) = delete; |
|
strided_tensor_iter(strided_tensor_iter&&) = default; |
|
strided_tensor_iter(Tensor& tensor) |
|
: data_(tensor.data_ptr<T>()), |
|
dim_(tensor.ndimension()), |
|
counter_(dim_, 0), |
|
sizes_(tensor.sizes().vec()), |
|
strides_(tensor.strides().vec()) { |
|
dim_ = std::get<1>(collapse_dims(sizes_.data(), strides_.data(), dim_)); |
|
} |
|
}; |
|
|
|
/// Returns true when every tensor in `tensors` has the same number of
/// elements as the first one; an empty list trivially qualifies.
inline bool _all_equal_numel(at::ArrayRef<Tensor> tensors) {
  if (tensors.empty()) {
    return true;
  }
  const int64_t reference = tensors[0].numel();
  for (const auto& t : tensors) {
    if (t.numel() != reference) {
      return false;
    }
  }
  return true;
}
|
|
|
/// Builds the error message reported when `tensors` do not all have the same
/// number of elements. The message lists every tensor's sizes, then every
/// tensor's numel.
///
/// @param tensors the tensors that failed the equal-numel check.
/// @return a human-readable description of the mismatch. An empty list
///         yields a short message instead of indexing out of range.
inline std::string _all_equal_numel_error(at::ArrayRef<Tensor> tensors) {
  std::ostringstream oss;
  if (tensors.empty()) {
    // `tensors.size() - 1` would wrap around to SIZE_MAX below.
    oss << "inconsistent tensor size: no tensors were given";
    return oss.str();
  }
  const size_t last = tensors.size() - 1;
  oss << "inconsistent tensor size, expected ";
  for (size_t i = 0; i < last; i++) {
    oss << tensors[i].sizes() << ", ";
  }
  oss << "and " << tensors[last].sizes()
      << " to have the same number of elements, but got ";
  for (size_t i = 0; i < last; i++) {
    oss << tensors[i].numel() << ", ";
  }
  oss << "and " << tensors[last].numel() << " elements respectively";
  return oss.str();
}
|
|
|
/// Validates the inputs of the CPU_tensor_apply* family.
///
/// Checks that every tensor is a CPU tensor with strided layout and that all
/// tensors have the same number of elements; otherwise it raises via the
/// check helpers / AT_ERROR.
///
/// @return false if any tensor has zero elements (nothing to apply),
///         true otherwise.
inline bool _apply_preamble(ArrayRef<Tensor> tensors) {
  checkDeviceType("CPU_tensor_apply", tensors, kCPU);
  checkLayout("CPU_tensor_apply", tensors, kStrided);
  const bool numel_matches = _all_equal_numel(tensors);
  if (!numel_matches) {
    AT_ERROR(_all_equal_numel_error(tensors));
  }

  bool every_tensor_non_empty = true;
  for (const auto& t : tensors) {
    if (t.numel() == 0) {
      every_tensor_non_empty = false;
      break;
    }
  }
  return every_tensor_non_empty;
}
|
|
|
/// Returns the largest `ndimension()` among `tensors`, or 0 if empty.
inline int64_t _max_dim_tensors(ArrayRef<Tensor> tensors) {
  int64_t widest = 0;
  for (const auto& t : tensors) {
    const int64_t nd = t.ndimension();
    if (nd > widest) {
      widest = nd;
    }
  }
  return widest;
}
|
|
|
// Terminates the recursion: no iterators remain.
inline void iterate(int64_t /*size*/) {}

/// Advances every iterator by `size` elements along its innermost
/// (dim_ - 1) dimension. This only bumps the innermost counter and data
/// pointer; carrying into outer dimensions is done by `iterate_overflow`.
template <typename Arg, typename... Args>
inline void iterate(int64_t size, Arg& iter, Args&... iter_tail) {
  const int64_t inner = iter.dim_ - 1;
  iter.counter_[inner] += size;
  iter.data_ += size * iter.strides_[inner];
  iterate(size, iter_tail...);
}
|
|
|
// Terminates the recursion: with no iterators left, iteration may continue.
inline bool iterate_continue() { return true; }

/// Returns true while every iterator's innermost counter is still below its
/// innermost size, i.e. no iterator has reached the end of its current row.
/// Stops evaluating at the first iterator that has.
template <typename Arg, typename... Args>
inline bool iterate_continue(Arg& iter, Args&... iter_tail) {
  const int64_t inner = iter.dim_ - 1;
  if (!(iter.counter_[inner] < iter.sizes_[inner])) {
    return false;
  }
  return iterate_continue(iter_tail...);
}
|
|
|
// Terminates the recursion: no iterator imposes a limit.
inline int64_t max_iterate_size() {
  return std::numeric_limits<int64_t>::max();
}

/// Returns the smallest number of elements remaining in the innermost
/// dimension across all iterators (sizes_ - counter_ at index dim_ - 1).
template <typename Arg, typename... Args>
inline int64_t max_iterate_size(Arg& iter, Args&... iter_tail) {
  const int64_t inner = iter.dim_ - 1;
  const int64_t remaining_here = iter.sizes_[inner] - iter.counter_[inner];
  const int64_t remaining_rest = max_iterate_size(iter_tail...);
  return std::min(remaining_here, remaining_rest);
}
|
|
|
// Terminates the recursion: no iterators remain.
inline void iterate_overflow() {}

/// For each iterator whose innermost counter has reached its size, carries
/// into the outer dimensions: every dimension d (from innermost toward
/// dimension 1) whose counter equals its size is reset to 0, the counter of
/// dimension d - 1 is incremented, and data_ is rewound by
/// sizes_[d] * strides_[d] and advanced by strides_[d - 1].
template <typename Arg, typename... Args>
inline void iterate_overflow(Arg& iter, Args&... iter_tail) {
  const int64_t inner = iter.dim_ - 1;
  if (iter.counter_[inner] == iter.sizes_[inner]) {
    int64_t d = inner;
    while (d > 0) {
      if (iter.counter_[d] == iter.sizes_[d]) {
        iter.counter_[d] = 0;
        ++iter.counter_[d - 1];
        iter.data_ = iter.data_ - (iter.sizes_[d] * iter.strides_[d]) +
            iter.strides_[d - 1];
      }
      --d;
    }
  }
  iterate_overflow(iter_tail...);
}
|
|
|
// Terminates the recursion: no iterators remain.
inline void forward(int64_t /*offset*/) {}

/// Moves every iterator forward by `offset` elements by decomposing `offset`
/// into per-dimension digits (innermost dimension first, using sizes_ as the
/// radix) and adding each digit to counter_ and digit * stride to data_.
/// No carry is propagated between dimensions, so this assumes counters start
/// at zero. Callers must ensure no size is 0 (it is used as a divisor).
template <typename Arg, typename... Args>
inline void forward(int64_t offset, Arg& iter, Args&... iter_tail) {
  int64_t rest = offset;
  for (int64_t d = iter.dim_; d-- > 0;) {
    const int64_t extent = iter.sizes_[d];
    const int64_t digit = rest % extent;
    rest /= extent;
    iter.counter_[d] += digit;
    iter.data_ += digit * iter.strides_[d];
  }
  forward(offset, iter_tail...);
}
|
|
|
// Terminates the recursion: no iterators means rank 0.
inline int64_t max_dim() { return 0; }

/// Returns the largest `dim_` among the given iterators (0 if none).
template <typename Arg, typename... Args>
inline int64_t max_dim(Arg& iter, Args&... iter_tail) {
  const int64_t tail_max = max_dim(iter_tail...);
  return iter.dim_ > tail_max ? iter.dim_ : tail_max;
}
|
|
|
// Zero-argument overload; does nothing.
inline void apply_op() {}

/// Invokes `op(*iter.data_...)` for `numel` consecutive positions, walking
/// all iterators in lockstep.
///
/// @param numel number of positions to visit.
/// @param offset if > 0, every iterator is first moved forward by this many
///               elements via `forward`.
/// @param op     callable receiving one dereferenced element per iterator.
/// @param iters  iterator objects (taken by value; they are move-only).
template <typename Op, typename... Args>
inline void apply_op(
    int64_t numel,
    int64_t offset,
    const Op& op,
    Args... iters) {
  // A single element where every iterator has rank 0 is applied directly.
  const bool single_scalar = numel == 1 && max_dim(iters...) == 0;
  if (single_scalar) {
    op(*iters.data_...);
    return;
  }
  if (offset > 0) {
    forward(offset, iters...);
  }

  int64_t visited = 0;
  while (visited < numel) {
    // Walk the innermost dimension until some iterator exhausts its row or
    // all requested elements have been visited.
    while (iterate_continue(iters...) && visited < numel) {
      op(*iters.data_...);
      iterate(1, iters...);
      ++visited;
    }
    // Carry exhausted rows into the outer dimensions.
    iterate_overflow(iters...);
  }
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename scalar1, typename scalar2, typename Op> |
|
inline void CPU_tensor_apply2(Tensor tensor1, Tensor tensor2, const Op op) { |
|
if (!_apply_preamble({tensor1, tensor2})) |
|
return; |
|
if (_max_dim_tensors({tensor1, tensor2}) <= 8) { |
|
apply_op( |
|
tensor1.numel(), |
|
0, |
|
op, |
|
strided_tensor_iter_fixed<scalar1, 8>(tensor1), |
|
strided_tensor_iter_fixed<scalar2, 8>(tensor2)); |
|
} else { |
|
apply_op( |
|
tensor1.numel(), |
|
0, |
|
op, |
|
strided_tensor_iter<scalar1>(tensor1), |
|
strided_tensor_iter<scalar2>(tensor2)); |
|
} |
|
} |
|
|
|
template <typename scalar1, typename scalar2, typename scalar3, typename Op> |
|
inline void CPU_tensor_apply3( |
|
Tensor tensor1, |
|
Tensor tensor2, |
|
Tensor tensor3, |
|
const Op op) { |
|
if (!_apply_preamble({tensor1, tensor2, tensor3})) |
|
return; |
|
if (_max_dim_tensors({tensor1, tensor2, tensor3}) <= 8) { |
|
apply_op( |
|
tensor1.numel(), |
|
0, |
|
op, |
|
strided_tensor_iter_fixed<scalar1, 8>(tensor1), |
|
strided_tensor_iter_fixed<scalar2, 8>(tensor2), |
|
strided_tensor_iter_fixed<scalar3, 8>(tensor3)); |
|
} else { |
|
apply_op( |
|
tensor1.numel(), |
|
0, |
|
op, |
|
strided_tensor_iter<scalar1>(tensor1), |
|
strided_tensor_iter<scalar2>(tensor2), |
|
strided_tensor_iter<scalar3>(tensor3)); |
|
} |
|
} |
|
|
|
template < |
|
typename scalar1, |
|
typename scalar2, |
|
typename scalar3, |
|
typename scalar4, |
|
typename Op> |
|
inline void CPU_tensor_apply4( |
|
Tensor tensor1, |
|
Tensor tensor2, |
|
Tensor tensor3, |
|
Tensor tensor4, |
|
const Op op) { |
|
if (!_apply_preamble({tensor1, tensor2, tensor3, tensor4})) |
|
return; |
|
if (_max_dim_tensors({tensor1, tensor2, tensor3, tensor4}) <= 8) { |
|
apply_op( |
|
tensor1.numel(), |
|
0, |
|
op, |
|
strided_tensor_iter_fixed<scalar1, 8>(tensor1), |
|
strided_tensor_iter_fixed<scalar2, 8>(tensor2), |
|
strided_tensor_iter_fixed<scalar3, 8>(tensor3), |
|
strided_tensor_iter_fixed<scalar4, 8>(tensor4)); |
|
} else { |
|
apply_op( |
|
tensor1.numel(), |
|
0, |
|
op, |
|
strided_tensor_iter<scalar1>(tensor1), |
|
strided_tensor_iter<scalar2>(tensor2), |
|
strided_tensor_iter<scalar3>(tensor3), |
|
strided_tensor_iter<scalar4>(tensor4)); |
|
} |
|
} |
|
|
|
} |
|
|