AbeShinzo0708's picture
Upload 2229 files
7e50900
raw
history blame
6.45 kB
#pragma once
#include <algorithm>
#include <vector>
#include <ATen/div_rtn.h>
#include <ATen/core/Tensor.h>
#include <c10/util/irange.h>
// TORCH_CHECK_DIM_SIZE(T, DIM, DIM_SIZE, SIZE)
//
// Checks (through TORCH_CHECK) that tensor T has exactly DIM dimensions and
// that T.size(DIM_SIZE) == SIZE. The failure message names T (stringified
// with #T), the expected DIM, DIM_SIZE and SIZE values, and T.sizes().
//
// The arguments are pasted into the expansion without parentheses and each of
// them appears more than once, so pass simple, side-effect-free expressions
// and parenthesize compound ones at the call site (as the call in this file
// does for its conditional DIM_SIZE argument).
// NOTE(review): a macro can only be redefined with an identical token
// sequence; if another header defines this same name, keep the definitions
// in sync -- confirm before editing.
#define TORCH_CHECK_DIM_SIZE(T, DIM, DIM_SIZE, SIZE) \
TORCH_CHECK( \
T.dim() == DIM && T.size(DIM_SIZE) == SIZE, \
"Need " #T " of dimension ", \
DIM, \
" and " #T ".size[", \
DIM_SIZE, \
"] == ", \
SIZE, \
" but got input to be of shape ", \
T.sizes())
namespace at {
namespace native {
namespace internal {
namespace {
// Returns true when every element of `arr` is strictly greater than zero
// (vacuously true for an empty range).
//
// `arr` is a non-owning view that is only read here, so it is taken by value
// rather than by mutable reference; this also lets callers pass const views
// and temporaries.
inline bool all_positive(IntArrayRef arr) {
  return std::all_of(
      arr.begin(), arr.end(), [](int64_t item) { return item > 0; });
}
// Returns true when every element of `arr` is greater than or equal to zero
// (vacuously true for an empty vector).
//
// The vector is only inspected, so it is taken by const reference; this
// accepts const vectors and temporaries in addition to mutable lvalues.
inline bool all_nonnegative(const std::vector<int64_t>& arr) {
  return std::all_of(
      arr.begin(), arr.end(), [](int64_t item) { return item >= 0; });
}
} // namespace
// Computes the trailing `dim` (spatial) entries of the convolution output
// shape. For each spatial axis the result is
//   div_rtn(in + 2 * pad - (dilation * (kernel - 1) + 1), stride) + 1
// where `in` is the matching size taken from the last `dim` axes of `input`.
template <int64_t dim>
std::vector<int64_t> get_output_size(
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef stride_size,
    IntArrayRef pad_size,
    IntArrayRef dilation_size) {
  std::vector<int64_t> spatial_out;
  for (const auto axis : c10::irange(dim)) {
    // Size of the matching input axis plus the padding added on both sides.
    const int64_t padded_extent =
        input.size(axis + input.dim() - dim) + 2 * pad_size[axis];
    // Span covered by the kernel once dilation gaps are accounted for.
    const int64_t kernel_extent =
        dilation_size[axis] * (kernel_size[axis] - 1) + 1;
    const int64_t steps =
        div_rtn<int64_t>(padded_extent - kernel_extent, stride_size[axis]);
    spatial_out.push_back(steps + 1);
  }
  return spatial_out;
}
// Computes the complete output shape: an optional leading input.size(0)
// (only when `input` has dim + 2 dimensions), then weight.size(0), followed
// by the spatial sizes produced by the overload that takes no weight.
template <int64_t dim>
std::vector<int64_t> get_output_size(
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride_size,
    IntArrayRef pad_size,
    IntArrayRef dilation_size) {
  const std::vector<int64_t> spatial = get_output_size<dim>(
      input, kernel_size, stride_size, pad_size, dilation_size);
  const int64_t out_channels = weight.size(0);

  std::vector<int64_t> full_size;
  const bool has_batch_axis = (input.dim() == dim + 2);
  if (has_batch_axis) {
    full_size.push_back(input.size(0));
  }
  full_size.push_back(out_channels);
  full_size.insert(full_size.end(), spatial.begin(), spatial.end());
  return full_size;
}
/*
slow_conv_dilated_shape_check - check user-input to dilated convolution
forward and backward functions.
*/
template <int64_t dim>
void slow_conv_dilated_shape_check(
const Tensor& input,
const Tensor& weight,
const Tensor& bias,
const Tensor& grad_output,
IntArrayRef kernel_size,
IntArrayRef stride_size,
IntArrayRef pad_size,
IntArrayRef dilation_size) {
/*
When the following tensors are defined:
bias, grad_weight, grad_output
then these are assumed to be contiguous without checking
because of these tensors are made contiguous by calling
.contiguous() method or by resizing of zero-sized tensors in
forward/backward functions.
When grad_weight is defined then it is assumed without
checking to have the same shape as weight, see backward
functions.
*/
// Check size arguments
TORCH_CHECK(
kernel_size.size() == dim,
"kernel sizes length should be ",
dim,
", but got ",
kernel_size.size());
TORCH_CHECK(
stride_size.size() == dim,
"strides length should be ",
dim,
", but got ",
stride_size.size());
TORCH_CHECK(
dilation_size.size() == dim,
"dilations length should be ",
dim,
", but got ",
dilation_size.size());
TORCH_CHECK(
pad_size.size() == dim,
"pads length should be ",
dim,
", but got ",
pad_size.size());
TORCH_CHECK(
all_positive(kernel_size),
"kernel size should be greater than zero, but got ",
kernel_size);
TORCH_CHECK(
all_positive(stride_size),
"stride should be greater than zero, but got ",
stride_size);
TORCH_CHECK(
all_positive(dilation_size),
"dilation should be greater than zero, but got ",
dilation_size);
// check input
TORCH_CHECK(input.defined(), "input must be defined");
bool is_batch = input.dim() == dim + 2;
int64_t n = (is_batch ? 2 : 1);
int64_t ndim = n + dim;
if (!is_batch) {
// input dim has to be dim + 1 if not batched
TORCH_CHECK(
input.dim() == dim + 1,
"input must be 4D or 5D tensor but got ",
input.dim(),
"D tensor");
}
// check output sizes
auto output_size = get_output_size<dim>(
input, kernel_size, stride_size, pad_size, dilation_size);
TORCH_CHECK(
all_nonnegative(output_size),
"calculated output size ",
output_size,
" is too small (all sizes must be non-negative)");
// check weight
TORCH_CHECK(weight.defined(), "weight must be defined");
TORCH_CHECK(
weight.dim() == dim + 2,
"weight must be ",
dim + 2,
"D tensor but got ",
weight.dim(),
"D tensor dim=",
dim);
TORCH_CHECK(
weight.sizes().slice(2) == kernel_size,
"weight[2:] shape ",
weight.sizes().slice(2),
" must be equal to kernel_size ",
kernel_size);
TORCH_CHECK_DIM_SIZE(input, input.dim(), (is_batch ? 1 : 0), weight.size(1));
// check bias when present
if (bias.defined()) {
TORCH_CHECK(
bias.dim() == 1,
"bias must be 1D tensor but got ",
bias.dim(),
"D tensor");
TORCH_CHECK_DIM_SIZE(bias, 1, 0, weight.size(0));
}
// check grad_output when present
if (grad_output.defined()) {
TORCH_CHECK(
grad_output.dim() == ndim,
"grad_output must be ",
ndim,
"D tensor but got ",
grad_output.dim(),
"D tensor");
if (is_batch) {
TORCH_CHECK(
grad_output.size(0) == input.size(0),
"grad_output.size(0)=",
grad_output.size(0),
" must be input.size(0)=",
input.size(0));
}
TORCH_CHECK(
grad_output.size(n - 1) == weight.size(0),
"grad_output.size(",
n - 1,
")=",
grad_output.size(n - 1),
" must be weight.size(0)=",
weight.size(0));
TORCH_CHECK(
grad_output.sizes().slice(n) == output_size,
"grad_output[",
n,
":] shape",
grad_output.sizes().slice(n),
" must be equal to output size ",
output_size);
}
}
} // namespace internal
} // namespace native
} // namespace at