// Roopansh's picture
// Initial Commit
// 73c83cf
#include <ATen/ATen.h>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>
#include <vector>
#include "utils/checks.h"
#include "utils/cuda.cuh"
#include "inplace_abn.h"
#include <ATen/cuda/CUDAContext.h>
// Operations for reduce
/**
 * Element accessor used as the per-element operation of a reduction.
 *
 * Invoking the functor with (batch, plane, n) yields the tensor element at
 * flat offset (batch * chn + plane) * sp + n, i.e. the buffer is addressed as
 * a dense [batch][chn][sp] array.
 */
template<typename T>
struct SumOp {
  /**
   * @param data      pointer to the first element of the buffer
   * @param channels  size of the plane (channel) dimension
   * @param spatial   number of elements per (batch, plane) slice
   */
  __device__ SumOp(const T *data, int channels, int spatial)
      : tensor(data), chn(channels), sp(spatial) {}

  /** Loads the element identified by (batch, plane, n). */
  __device__ __forceinline__ T operator()(int batch, int plane, int n) {
    const int offset = (batch * chn + plane) * sp + n;
    return tensor[offset];
  }

  const T *tensor;  // source buffer
  const int chn;    // stride multiplier between consecutive batches, in planes
  const int sp;     // stride between consecutive planes, in elements
};
/**
 * Squared-deviation accessor used as the per-element operation of a reduction.
 *
 * Invoking the functor with (batch, plane, n) reads the element at flat offset
 * (batch * chn + plane) * sp + n and returns its squared distance from `mean`.
 */
template<typename T>
struct VarOp {
  /**
   * @param center    value subtracted from every element before squaring
   * @param data      pointer to the first element of the buffer
   * @param channels  size of the plane (channel) dimension
   * @param spatial   number of elements per (batch, plane) slice
   */
  __device__ VarOp(T center, const T *data, int channels, int spatial)
      : mean(center), tensor(data), chn(channels), sp(spatial) {}

  /** Returns (x - mean)^2 for the element identified by (batch, plane, n). */
  __device__ __forceinline__ T operator()(int batch, int plane, int n) {
    const int offset = (batch * chn + plane) * sp + n;
    const T element = tensor[offset];
    return (element - mean) * (element - mean);
  }

  const T mean;     // centre used for the squared deviation
  const T *tensor;  // source buffer
  const int chn;    // stride multiplier between consecutive batches, in planes
  const int sp;     // stride between consecutive planes, in elements
};
/**
 * Per-element operation producing the two quantities accumulated by the
 * backward reduction.
 *
 * For the element at flat offset (batch * chn + plane) * sp + n it computes
 * y = (z - bias) / weight and returns the pair (dz, y * dz).
 */
template<typename T>
struct GradOp {
  /**
   * @param _weight  divisor applied after removing the bias from z
   * @param _bias    value subtracted from z
   * @param _z       pointer to the z buffer
   * @param _dz      pointer to the dz buffer (same indexing as z)
   * @param channels size of the plane (channel) dimension
   * @param spatial  number of elements per (batch, plane) slice
   */
  __device__ GradOp(T _weight, T _bias, const T *_z, const T *_dz, int channels, int spatial)
      : weight(_weight), bias(_bias), z(_z), dz(_dz), chn(channels), sp(spatial) {}

  /** Returns Pair(dz, y * dz) for the element identified by (batch, plane, n). */
  __device__ __forceinline__ Pair<T> operator()(int batch, int plane, int n) {
    const int offset = (batch * chn + plane) * sp + n;
    T normalized = (z[offset] - bias) / weight;
    T grad = dz[offset];
    return Pair<T>(grad, normalized * grad);
  }

  const T weight;
  const T bias;
  const T *z;
  const T *dz;
  const int chn;
  const int sp;
};
/***********
* mean_var
***********/
/**
 * Computes the mean and the (biased, divided by num * sp) variance of one
 * plane of x.
 *
 * blockIdx.x selects the plane; thread 0 of the block stores the results in
 * mean[plane] and var[plane]. The variance is obtained with a second pass over
 * the data using the mean from the first pass.
 *
 * `reduce` is declared outside this file section; presumably it is a
 * block-wide reduction over all (batch, n) positions of the plane — confirm in
 * utils/cuda.cuh.
 */
template<typename T>
__global__ void mean_var_kernel(const T *x, T *mean, T *var, int num, int chn, int sp) {
  const int channel = blockIdx.x;
  const T inv_count = T(1) / T(num * sp);

  // First pass: average of all elements in this plane.
  const T plane_mean = reduce<T, SumOp<T>>(SumOp<T>(x, chn, sp), channel, num, sp) * inv_count;

  // NOTE(review): presumably separates the two reductions so scratch state
  // used by `reduce` is not reused too early — confirm against its definition.
  __syncthreads();

  // Second pass: average squared deviation from the mean.
  const T plane_var = reduce<T, VarOp<T>>(VarOp<T>(plane_mean, x, chn, sp), channel, num, sp) * inv_count;

  // Only a single thread per block publishes the statistics.
  if (threadIdx.x != 0) {
    return;
  }
  mean[channel] = plane_mean;
  var[channel] = plane_var;
}
/**
 * Host entry point: computes per-channel mean and variance of x on the GPU.
 *
 * @param x  input tensor; validated with CHECK_CUDA_INPUT
 * @return   {mean, var}, two tensors of shape {chn} created with x's options
 *
 * The kernel is enqueued on the current CUDA stream with one block per
 * channel. The int64_t dimensions are narrowed to int at the launch because
 * the kernel takes int parameters.
 */
std::vector<at::Tensor> mean_var_cuda(at::Tensor x) {
  CHECK_CUDA_INPUT(x);

  // num / chn / sp are filled in by the project helper get_dims.
  int64_t num, chn, sp;
  get_dims(x, num, chn, sp);

  // Outputs: one statistic per channel.
  const auto out_options = x.options();
  auto mean = at::empty({chn}, out_options);
  auto var = at::empty({chn}, out_options);

  // Launch configuration: one block per channel.
  auto stream = at::cuda::getCurrentCUDAStream();
  const dim3 grid(chn);
  const dim3 block(getNumThreads(sp));

  AT_DISPATCH_FLOATING_TYPES(x.type(), "mean_var_cuda", ([&] {
    mean_var_kernel<scalar_t><<<grid, block, 0, stream>>>(
        x.data<scalar_t>(), mean.data<scalar_t>(), var.data<scalar_t>(), num, chn, sp);
  }));

  return {mean, var};
}
/**********
* forward
**********/
/**
 * Normalizes one plane of x in place:
 *   x <- (x - mean) * rsqrt(var + eps) * w + b
 * where, when `affine` is true, w = |weight[plane]| + eps and b = bias[plane];
 * otherwise w = 1 and b = 0 (weight and bias are not read in that case).
 *
 * blockIdx.x selects the plane. Within a block the threads stride over the sp
 * elements of each (batch, plane) slice in steps of blockDim.x, and batches
 * are processed one after another.
 */
template<typename T>
__global__ void forward_kernel(T *x, const T *mean, const T *var, const T *weight, const T *bias,
                               bool affine, float eps, int num, int chn, int sp) {
  const int channel = blockIdx.x;

  const T mu = mean[channel];
  const T sigma2 = var[channel];

  // Per-channel affine parameters; identity when affine is disabled.
  T gamma;
  T beta;
  if (affine) {
    gamma = abs(weight[channel]) + eps;
    beta = bias[channel];
  } else {
    gamma = T(1);
    beta = T(0);
  }

  // Combined multiplier: inverse standard deviation times the scale.
  const T scale = rsqrt(sigma2 + eps) * gamma;

  for (int batch = 0; batch < num; ++batch) {
    const int base = (batch * chn + channel) * sp;
    for (int n = threadIdx.x; n < sp; n += blockDim.x) {
      const T value = x[base + n];
      x[base + n] = (value - mu) * scale + beta;
    }
  }
}
/**
 * Host entry point for the forward normalization.
 *
 * Launches forward_kernel on the current CUDA stream with one block per
 * channel. The kernel overwrites x, and that same tensor is returned.
 *
 * @param x       tensor normalized in place
 * @param mean    per-channel mean
 * @param var     per-channel variance
 * @param weight  per-channel scale (used by the kernel only when affine is true)
 * @param bias    per-channel shift (used by the kernel only when affine is true)
 * @param affine  whether weight/bias are applied
 * @param eps     constant forwarded to the kernel
 * @return        x
 *
 * All five tensors are validated with CHECK_CUDA_INPUT regardless of `affine`.
 */
at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
                        bool affine, float eps) {
  CHECK_CUDA_INPUT(x);
  CHECK_CUDA_INPUT(mean);
  CHECK_CUDA_INPUT(var);
  CHECK_CUDA_INPUT(weight);
  CHECK_CUDA_INPUT(bias);

  // num / chn / sp are filled in by the project helper get_dims.
  int64_t num, chn, sp;
  get_dims(x, num, chn, sp);

  // Launch configuration: one block per channel.
  auto stream = at::cuda::getCurrentCUDAStream();
  const dim3 grid(chn);
  const dim3 block(getNumThreads(sp));

  AT_DISPATCH_FLOATING_TYPES(x.type(), "forward_cuda", ([&] {
    forward_kernel<scalar_t><<<grid, block, 0, stream>>>(
        x.data<scalar_t>(), mean.data<scalar_t>(), var.data<scalar_t>(),
        weight.data<scalar_t>(), bias.data<scalar_t>(),
        affine, eps, num, chn, sp);
  }));

  return x;
}
/***********
* edz_eydz
***********/
/**
 * Accumulates, for one plane, the two sums needed by the backward pass:
 *   edz[plane]  = reduction of dz
 *   eydz[plane] = reduction of y * dz,  with y = (z - b) / w
 * where w = |weight[plane]| + eps and b = bias[plane] when `affine` is true,
 * and w = 1, b = 0 otherwise (weight and bias are not read in that case).
 *
 * blockIdx.x selects the plane; thread 0 stores the results. `reduce` is
 * declared outside this file section; presumably a block-wide reduction —
 * confirm in utils/cuda.cuh.
 */
template<typename T>
__global__ void edz_eydz_kernel(const T *z, const T *dz, const T *weight, const T *bias,
                                T *edz, T *eydz, bool affine, float eps, int num, int chn, int sp) {
  const int channel = blockIdx.x;

  // Per-channel affine parameters; identity when affine is disabled.
  T gamma;
  T beta;
  if (affine) {
    gamma = abs(weight[channel]) + eps;
    beta = bias[channel];
  } else {
    gamma = T(1);
    beta = T(0);
  }

  Pair<T> sums = reduce<Pair<T>, GradOp<T>>(GradOp<T>(gamma, beta, z, dz, chn, sp), channel, num, sp);
  __syncthreads();

  // Only a single thread per block publishes the sums.
  if (threadIdx.x != 0) {
    return;
  }
  edz[channel] = sums.v1;
  eydz[channel] = sums.v2;
}
/**
 * Host entry point: computes the per-channel sums edz and eydz on the GPU.
 *
 * @param z       tensor whose dimensions define num / chn / sp
 * @param dz      tensor indexed by the kernel exactly like z
 * @param weight  per-channel scale (used by the kernel only when affine is true)
 * @param bias    per-channel shift (used by the kernel only when affine is true)
 * @param affine  whether weight/bias are applied
 * @param eps     constant forwarded to the kernel
 * @return        {edz, eydz}, two tensors of shape {chn} created with z's options
 *
 * The kernel is enqueued on the current CUDA stream with one block per channel.
 */
std::vector<at::Tensor> edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
                                      bool affine, float eps) {
  CHECK_CUDA_INPUT(z);
  CHECK_CUDA_INPUT(dz);
  CHECK_CUDA_INPUT(weight);
  CHECK_CUDA_INPUT(bias);

  // num / chn / sp are filled in by the project helper get_dims.
  int64_t num, chn, sp;
  get_dims(z, num, chn, sp);

  // Outputs: one value per channel for each sum.
  const auto out_options = z.options();
  auto edz = at::empty({chn}, out_options);
  auto eydz = at::empty({chn}, out_options);

  // Launch configuration: one block per channel.
  auto stream = at::cuda::getCurrentCUDAStream();
  const dim3 grid(chn);
  const dim3 block(getNumThreads(sp));

  AT_DISPATCH_FLOATING_TYPES(z.type(), "edz_eydz_cuda", ([&] {
    edz_eydz_kernel<scalar_t><<<grid, block, 0, stream>>>(
        z.data<scalar_t>(), dz.data<scalar_t>(),
        weight.data<scalar_t>(), bias.data<scalar_t>(),
        edz.data<scalar_t>(), eydz.data<scalar_t>(),
        affine, eps, num, chn, sp);
  }));

  return {edz, eydz};
}
/***********
* backward
***********/
/**
 * Computes the input gradient for one plane:
 *   dx = (dz - edz / count - y * eydz / count) * w * rsqrt(var + eps)
 * with y = (z - b) / w and count = num * sp. When `affine` is true,
 * w = |weight[plane]| + eps and b = bias[plane]; otherwise w = 1 and b = 0
 * (weight and bias are not read in that case).
 *
 * blockIdx.x selects the plane. Within a block the threads stride over the sp
 * elements of each (batch, plane) slice in steps of blockDim.x.
 */
template<typename T>
__global__ void backward_kernel(const T *z, const T *dz, const T *var, const T *weight, const T *bias, const T *edz,
                                const T *eydz, T *dx, bool affine, float eps, int num, int chn, int sp) {
  const int channel = blockIdx.x;

  // Per-channel affine parameters; identity when affine is disabled.
  T gamma;
  T beta;
  if (affine) {
    gamma = abs(weight[channel]) + eps;
    beta = bias[channel];
  } else {
    gamma = T(1);
    beta = T(0);
  }

  const T sigma2 = var[channel];
  const T sum_dz = edz[channel];
  const T sum_ydz = eydz[channel];

  const T scale = gamma * rsqrt(sigma2 + eps);
  const T count = T(num * sp);

  for (int batch = 0; batch < num; ++batch) {
    const int base = (batch * chn + channel) * sp;
    for (int n = threadIdx.x; n < sp; n += blockDim.x) {
      const T grad = dz[base + n];
      const T normalized = (z[base + n] - beta) / gamma;
      dx[base + n] = (grad - sum_dz / count - normalized * sum_ydz / count) * scale;
    }
  }
}
/**
 * Host entry point for the backward pass with respect to the input.
 *
 * Allocates dx with at::zeros_like(z), launches backward_kernel on the current
 * CUDA stream with one block per channel, and returns dx.
 *
 * @param z       tensor whose dimensions define num / chn / sp
 * @param dz      tensor indexed by the kernel exactly like z
 * @param var     per-channel variance
 * @param weight  per-channel scale (used by the kernel only when affine is true)
 * @param bias    per-channel shift (used by the kernel only when affine is true)
 * @param edz     per-channel sum, as produced by edz_eydz_cuda
 * @param eydz    per-channel sum, as produced by edz_eydz_cuda
 * @param affine  whether weight/bias are applied
 * @param eps     constant forwarded to the kernel
 * @return        dx, a new tensor shaped like z
 */
at::Tensor backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
                         at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
  CHECK_CUDA_INPUT(z);
  CHECK_CUDA_INPUT(dz);
  CHECK_CUDA_INPUT(var);
  CHECK_CUDA_INPUT(weight);
  CHECK_CUDA_INPUT(bias);
  CHECK_CUDA_INPUT(edz);
  CHECK_CUDA_INPUT(eydz);

  // num / chn / sp are filled in by the project helper get_dims.
  int64_t num, chn, sp;
  get_dims(z, num, chn, sp);

  auto dx = at::zeros_like(z);

  // Launch configuration: one block per channel.
  auto stream = at::cuda::getCurrentCUDAStream();
  const dim3 grid(chn);
  const dim3 block(getNumThreads(sp));

  AT_DISPATCH_FLOATING_TYPES(z.type(), "backward_cuda", ([&] {
    backward_kernel<scalar_t><<<grid, block, 0, stream>>>(
        z.data<scalar_t>(), dz.data<scalar_t>(), var.data<scalar_t>(),
        weight.data<scalar_t>(), bias.data<scalar_t>(),
        edz.data<scalar_t>(), eydz.data<scalar_t>(),
        dx.data<scalar_t>(),
        affine, eps, num, chn, sp);
  }));

  return dx;
}
/**************
* activations
**************/
/**
 * In-place leaky-ReLU backward step over `count` elements, run with Thrust on
 * the current CUDA stream.
 *
 * Two passes are issued, in this order:
 *   1. dz[i] <- dz[i] * slope   for every i with z[i] < 0
 *   2. z[i]  <- z[i] / slope    for every i with z[i] < 0
 * The gradient pass runs first because it uses z as its stencil, before z is
 * rewritten by the second pass.
 *
 * @param z      device pointer, read and modified
 * @param dz     device pointer, read and modified
 * @param slope  negative-side slope
 * @param count  number of elements in each buffer
 */
template<typename T>
inline void leaky_relu_backward_impl(T *z, T *dz, float slope, int64_t count) {
  // Wrap the raw device pointers so Thrust dispatches to the device backend.
  auto z_first = thrust::device_pointer_cast(z);
  auto dz_first = thrust::device_pointer_cast(dz);
  auto z_last = z_first + count;
  auto dz_last = dz_first + count;

  auto stream = at::cuda::getCurrentCUDAStream();

  // Pass 1: scale the gradient wherever the stencil value z is negative.
  thrust::transform_if(thrust::cuda::par.on(stream),
                       dz_first, dz_last, z_first, dz_first,
                       [slope] __device__ (const T& grad) { return grad * slope; },
                       [] __device__ (const T& value) { return value < 0; });

  // Pass 2: divide the negative entries of z by the slope, in place.
  thrust::transform_if(thrust::cuda::par.on(stream),
                       z_first, z_last, z_first,
                       [slope] __device__ (const T& value) { return value / slope; },
                       [] __device__ (const T& value) { return value < 0; });
}
/**
 * Host entry point for the in-place leaky-ReLU backward step.
 *
 * Validates both tensors with CHECK_CUDA_INPUT, then dispatches on z's
 * floating-point type to leaky_relu_backward_impl, which modifies z and dz.
 * The element count is taken from z only.
 *
 * @param z      tensor, modified in place
 * @param dz     tensor, modified in place
 * @param slope  negative-side slope
 */
void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope) {
  CHECK_CUDA_INPUT(z);
  CHECK_CUDA_INPUT(dz);

  const int64_t num_elements = z.numel();

  AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cuda", ([&] {
    scalar_t *z_data = z.data<scalar_t>();
    scalar_t *dz_data = dz.data<scalar_t>();
    leaky_relu_backward_impl<scalar_t>(z_data, dz_data, slope, num_elements);
  }));
}
/**
 * In-place ELU backward step over `count` elements, run with Thrust on the
 * current CUDA stream.
 *
 * Two passes are issued, in this order:
 *   1. dz[i] <- dz[i] * (z[i] + 1)   for every i with z[i] < 0
 *   2. z[i]  <- log1p(z[i])          for every i with z[i] < 0
 * The gradient pass runs first because it reads z (both as operand and as
 * stencil) before z is rewritten by the second pass.
 *
 * The constant one is written as T(1) so the arithmetic stays in the element
 * type; a bare `1.` literal would promote float inputs to double.
 *
 * @param z      device pointer, read and modified
 * @param dz     device pointer, read and modified
 * @param count  number of elements in each buffer
 */
template<typename T>
inline void elu_backward_impl(T *z, T *dz, int64_t count) {
  // Create thrust pointers
  thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z);
  thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz);

  auto stream = at::cuda::getCurrentCUDAStream();

  // Pass 1 (binary transform_if): inputs are dz and z, stencil is z, output is dz.
  thrust::transform_if(thrust::cuda::par.on(stream),
                       th_dz, th_dz + count, th_z, th_z, th_dz,
                       [] __device__ (const T& dz, const T& z) { return dz * (z + T(1)); },
                       [] __device__ (const T& z) { return z < 0; });

  // Pass 2 (unary transform_if): rewrite the negative entries of z in place.
  thrust::transform_if(thrust::cuda::par.on(stream),
                       th_z, th_z + count, th_z,
                       [] __device__ (const T& z) { return log1p(z); },
                       [] __device__ (const T& z) { return z < 0; });
}
/**
 * Host entry point for the in-place ELU backward step.
 *
 * Validates both tensors with CHECK_CUDA_INPUT, then dispatches on z's
 * floating-point type to elu_backward_impl, which modifies z and dz.
 * The element count is taken from z only.
 *
 * @param z   tensor, modified in place
 * @param dz  tensor, modified in place
 */
void elu_backward_cuda(at::Tensor z, at::Tensor dz) {
  CHECK_CUDA_INPUT(z);
  CHECK_CUDA_INPUT(dz);

  int64_t count = z.numel();

  // The dispatch name is what appears in unsupported-dtype errors, so it must
  // identify this function (it previously said "leaky_relu_backward_cuda").
  AT_DISPATCH_FLOATING_TYPES(z.type(), "elu_backward_cuda", ([&] {
    elu_backward_impl<scalar_t>(z.data<scalar_t>(), dz.data<scalar_t>(), count);
  }));
}