#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include <cuda_fp16.h>

#include <vector>

#include "utils/checks.h"
#include "utils/cuda.cuh"
#include "inplace_abn.h"

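// Reduction functor: loads one fp16 element of `tensor` at (batch, plane, n)
// and widens it to fp32, so per-channel sums are accumulated in fp32.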
struct SumOpH {
  __device__ SumOpH(const half *t, int c, int s)
      : tensor(t), chn(c), sp(s) {}
  __device__ __forceinline__ float operator()(int batch, int plane, int n) {
    return __half2float(tensor[(batch * chn + plane) * sp + n]);
  }
  const half *tensor;
  const int chn;
  const int sp;
};

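// Reduction functor for the biased variance: squared fp32 deviation of each
// fp16 element from the channel mean computed in the first reduction pass.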
struct VarOpH {
  __device__ VarOpH(float m, const half *t, int c, int s)
      : mean(m), tensor(t), chn(c), sp(s) {}
  __device__ __forceinline__ float operator()(int batch, int plane, int n) {
    const auto t = __half2float(tensor[(batch * chn + plane) * sp + n]);
    return (t - mean) * (t - mean);
  }
  const float mean;
  const half *tensor;
  const int chn;
  const int sp;
};

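// Reduction functor for the backward pass. Because the forward pass overwrote
// x with z in place, the normalized activation is recovered as
// y = (z - bias) / weight, and the pair (dz, y * dz) is reduced per channel
// to obtain edz and eydz.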
struct GradOpH {
  __device__ GradOpH(float _weight, float _bias, const half *_z, const half *_dz, int c, int s)
      : weight(_weight), bias(_bias), z(_z), dz(_dz), chn(c), sp(s) {}
  __device__ __forceinline__ Pair<float> operator()(int batch, int plane, int n) {
    float _y = (__half2float(z[(batch * chn + plane) * sp + n]) - bias) / weight;
    float _dz = __half2float(dz[(batch * chn + plane) * sp + n]);
    return Pair<float>(_dz, _y * _dz);
  }
  const float weight;
  const float bias;
  const half *z;
  const half *dz;
  const int chn;
  const int sp;
};

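// One thread block per channel: two block-wide reductions (reduce<> from
// utils/cuda.cuh) compute the fp32 mean and biased variance over the
// num * sp elements of that channel; thread 0 writes the results.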
__global__ void mean_var_kernel_h(const half *x, float *mean, float *var, int num, int chn, int sp) {
  int plane = blockIdx.x;
  float norm = 1.f / static_cast<float>(num * sp);

  float _mean = reduce<float, SumOpH>(SumOpH(x, chn, sp), plane, num, sp) * norm;
  __syncthreads();
  float _var = reduce<float, VarOpH>(VarOpH(_mean, x, chn, sp), plane, num, sp) * norm;

  if (threadIdx.x == 0) {
    mean[plane] = _mean;
    var[plane] = _var;
  }
}

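// Host wrapper for the fp16 statistics pass. mean and var are returned as
// fp32 tensors: accumulating in half would lose too much precision over
// large reductions.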
std::vector<at::Tensor> mean_var_cuda_h(at::Tensor x) {
  CHECK_CUDA_INPUT(x);

  int64_t num, chn, sp;
  get_dims(x, num, chn, sp);

  auto mean = at::empty({chn}, x.options().dtype(at::kFloat));
  auto var = at::empty({chn}, x.options().dtype(at::kFloat));

  dim3 blocks(chn);
  dim3 threads(getNumThreads(sp));
  auto stream = at::cuda::getCurrentCUDAStream();
  mean_var_kernel_h<<<blocks, threads, 0, stream>>>(
      reinterpret_cast<half*>(x.data<at::Half>()),
      mean.data<float>(),
      var.data<float>(),
      num, chn, sp);

  return {mean, var};
}

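// In-place forward transform: x <- (x - mean) * weight / sqrt(var + eps) + bias.
// abs(weight) + eps keeps the effective scale strictly positive, which is what
// lets the backward pass invert the transform and recover y from z.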
__global__ void forward_kernel_h(half *x, const float *mean, const float *var, const float *weight, const float *bias,
                                 bool affine, float eps, int num, int chn, int sp) {
  int plane = blockIdx.x;

  const float _mean = mean[plane];
  const float _var = var[plane];
  const float _weight = affine ? abs(weight[plane]) + eps : 1.f;
  const float _bias = affine ? bias[plane] : 0.f;

  const float mul = rsqrt(_var + eps) * _weight;

  for (int batch = 0; batch < num; ++batch) {
    for (int n = threadIdx.x; n < sp; n += blockDim.x) {
      half *x_ptr = x + (batch * chn + plane) * sp + n;
      float _x = __half2float(*x_ptr);
      float _y = (_x - _mean) * mul + _bias;

      *x_ptr = __float2half(_y);
    }
  }
}

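// Host wrapper for the in-place fp16 forward pass: x is normalized (and
// optionally affine-transformed) in place and returned.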
at::Tensor forward_cuda_h(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
                          bool affine, float eps) {
  CHECK_CUDA_INPUT(x);
  CHECK_CUDA_INPUT(mean);
  CHECK_CUDA_INPUT(var);
  CHECK_CUDA_INPUT(weight);
  CHECK_CUDA_INPUT(bias);

  int64_t num, chn, sp;
  get_dims(x, num, chn, sp);

  dim3 blocks(chn);
  dim3 threads(getNumThreads(sp));
  auto stream = at::cuda::getCurrentCUDAStream();
  forward_kernel_h<<<blocks, threads, 0, stream>>>(
      reinterpret_cast<half*>(x.data<at::Half>()),
      mean.data<float>(),
      var.data<float>(),
      weight.data<float>(),
      bias.data<float>(),
      affine, eps, num, chn, sp);

  return x;
}

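// Per-channel reduction of the gradient statistics edz = sum(dz) and
// eydz = sum(y * dz), with y reconstructed from the in-place output z.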
__global__ void edz_eydz_kernel_h(const half *z, const half *dz, const float *weight, const float *bias,
                                  float *edz, float *eydz, bool affine, float eps, int num, int chn, int sp) {
  int plane = blockIdx.x;

  float _weight = affine ? abs(weight[plane]) + eps : 1.f;
  float _bias = affine ? bias[plane] : 0.f;

  Pair<float> res = reduce<Pair<float>, GradOpH>(GradOpH(_weight, _bias, z, dz, chn, sp), plane, num, sp);
  __syncthreads();

  if (threadIdx.x == 0) {
    edz[plane] = res.v1;
    eydz[plane] = res.v2;
  }
}

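// Host wrapper for the gradient-statistics reduction; like the forward
// statistics, edz and eydz are accumulated and returned in fp32.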
std::vector<at::Tensor> edz_eydz_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
                                        bool affine, float eps) {
  CHECK_CUDA_INPUT(z);
  CHECK_CUDA_INPUT(dz);
  CHECK_CUDA_INPUT(weight);
  CHECK_CUDA_INPUT(bias);

  int64_t num, chn, sp;
  get_dims(z, num, chn, sp);

  auto edz = at::empty({chn}, z.options().dtype(at::kFloat));
  auto eydz = at::empty({chn}, z.options().dtype(at::kFloat));

  dim3 blocks(chn);
  dim3 threads(getNumThreads(sp));
  auto stream = at::cuda::getCurrentCUDAStream();
  edz_eydz_kernel_h<<<blocks, threads, 0, stream>>>(
      reinterpret_cast<half*>(z.data<at::Half>()),
      reinterpret_cast<half*>(dz.data<at::Half>()),
      weight.data<float>(),
      bias.data<float>(),
      edz.data<float>(),
      eydz.data<float>(),
      affine, eps, num, chn, sp);

  return {edz, eydz};
}

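// Gradient w.r.t. the input for the training case:
// dx = (dz - edz / count - y * eydz / count) * weight / sqrt(var + eps),
// where count = num * sp is the number of elements per channel.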
__global__ void backward_kernel_h(const half *z, const half *dz, const float *var, const float *weight, const float *bias, const float *edz,
                                  const float *eydz, half *dx, bool affine, float eps, int num, int chn, int sp) {
  int plane = blockIdx.x;

  float _weight = affine ? abs(weight[plane]) + eps : 1.f;
  float _bias = affine ? bias[plane] : 0.f;
  float _var = var[plane];
  float _edz = edz[plane];
  float _eydz = eydz[plane];

  float _mul = _weight * rsqrt(_var + eps);
  float count = float(num * sp);

  for (int batch = 0; batch < num; ++batch) {
    for (int n = threadIdx.x; n < sp; n += blockDim.x) {
      float _dz = __half2float(dz[(batch * chn + plane) * sp + n]);
      float _y = (__half2float(z[(batch * chn + plane) * sp + n]) - _bias) / _weight;

      dx[(batch * chn + plane) * sp + n] = __float2half((_dz - _edz / count - _y * _eydz / count) * _mul);
    }
  }
}

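// Host wrapper for the fp16 input-gradient kernel. dx is allocated with the
// same shape and dtype (half) as z.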
at::Tensor backward_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
                           at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
  CHECK_CUDA_INPUT(z);
  CHECK_CUDA_INPUT(dz);
  CHECK_CUDA_INPUT(var);
  CHECK_CUDA_INPUT(weight);
  CHECK_CUDA_INPUT(bias);
  CHECK_CUDA_INPUT(edz);
  CHECK_CUDA_INPUT(eydz);

  int64_t num, chn, sp;
  get_dims(z, num, chn, sp);

  auto dx = at::zeros_like(z);

  dim3 blocks(chn);
  dim3 threads(getNumThreads(sp));
  auto stream = at::cuda::getCurrentCUDAStream();
  backward_kernel_h<<<blocks, threads, 0, stream>>>(
      reinterpret_cast<half*>(z.data<at::Half>()),
      reinterpret_cast<half*>(dz.data<at::Half>()),
      var.data<float>(),
      weight.data<float>(),
      bias.data<float>(),
      edz.data<float>(),
      eydz.data<float>(),
      reinterpret_cast<half*>(dx.data<at::Half>()),
      affine, eps, num, chn, sp);

  return dx;
}

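// Fused in-place backward for leaky ReLU: for negative outputs it scales the
// incoming gradient by `slope` and restores the pre-activation value
// (z <- z / slope), using a grid-stride loop over all elements.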
__global__ void leaky_relu_backward_impl_h(half *z, half *dz, float slope, int64_t count) {
  // Use a 64-bit index so the loop does not overflow when count > INT_MAX.
  for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) {
    float _z = __half2float(z[i]);
    if (_z < 0) {
      dz[i] = __float2half(__half2float(dz[i]) * slope);
      z[i] = __float2half(_z / slope);
    }
  }
}

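// Host wrapper for the in-place leaky ReLU backward; z and dz are both
// modified in place, so nothing is returned.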
void leaky_relu_backward_cuda_h(at::Tensor z, at::Tensor dz, float slope) {
  CHECK_CUDA_INPUT(z);
  CHECK_CUDA_INPUT(dz);

  int64_t count = z.numel();
  dim3 threads(getNumThreads(count));
  dim3 blocks = (count + threads.x - 1) / threads.x;
  auto stream = at::cuda::getCurrentCUDAStream();
  leaky_relu_backward_impl_h<<<blocks, threads, 0, stream>>>(
      reinterpret_cast<half*>(z.data<at::Half>()),
      reinterpret_cast<half*>(dz.data<at::Half>()),
      slope, count);
}
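// Typical call sequence (a sketch mirroring the fp32 path of inplace_abn;
// the bindings that dispatch to these entry points are assumed, not shown):
//   forward:  {mean, var} = mean_var_cuda_h(x);
//             x = forward_cuda_h(x, mean, var, weight, bias, affine, eps);
//   backward: leaky_relu_backward_cuda_h(z, dz, slope);
//             {edz, eydz} = edz_eydz_cuda_h(z, dz, weight, bias, affine, eps);
//             dx = backward_cuda_h(z, dz, var, weight, bias, edz, eydz, affine, eps);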