#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_fp16.h>

#include <vector>

#include "utils/checks.h"
#include "utils/cuda.cuh"
#include "inplace_abn.h"
// Reduction functors: each one maps a (batch, plane, n) index triple to the
// value accumulated by the block-wide reduce<> helper from the utility headers.

// Sum of the input values, promoted from half to float for accuracy.
struct SumOpH {
  __device__ SumOpH(const half *t, int c, int s)
      : tensor(t), chn(c), sp(s) {}
  __device__ __forceinline__ float operator()(int batch, int plane, int n) {
    return __half2float(tensor[(batch * chn + plane) * sp + n]);
  }
  const half *tensor;
  const int chn;
  const int sp;
};
// Squared deviation from a precomputed per-channel mean (variance term).
struct VarOpH {
  __device__ VarOpH(float m, const half *t, int c, int s)
      : mean(m), tensor(t), chn(c), sp(s) {}
  __device__ __forceinline__ float operator()(int batch, int plane, int n) {
    const auto t = __half2float(tensor[(batch * chn + plane) * sp + n]);
    return (t - mean) * (t - mean);
  }
  const float mean;
  const half *tensor;
  const int chn;
  const int sp;
};
// Gradient pair (dz, y * dz), where y is recovered from the in-place
// activation output z by inverting the affine transform.
struct GradOpH {
  __device__ GradOpH(float _weight, float _bias, const half *_z, const half *_dz, int c, int s)
      : weight(_weight), bias(_bias), z(_z), dz(_dz), chn(c), sp(s) {}
  __device__ __forceinline__ Pair<float> operator()(int batch, int plane, int n) {
    float _y = (__half2float(z[(batch * chn + plane) * sp + n]) - bias) / weight;
    float _dz = __half2float(dz[(batch * chn + plane) * sp + n]);
    return Pair<float>(_dz, _y * _dz);
  }
  const float weight;
  const float bias;
  const half *z;
  const half *dz;
  const int chn;
  const int sp;
};
/***********
 * mean_var
 ***********/

// One block per channel plane: computes the per-channel mean and biased
// variance of x, reducing in float precision.
__global__ void mean_var_kernel_h(const half *x, float *mean, float *var, int num, int chn, int sp) {
  int plane = blockIdx.x;
  float norm = 1.f / static_cast<float>(num * sp);

  float _mean = reduce<float, SumOpH>(SumOpH(x, chn, sp), plane, num, sp) * norm;
  __syncthreads();
  float _var = reduce<float, VarOpH>(VarOpH(_mean, x, chn, sp), plane, num, sp) * norm;

  if (threadIdx.x == 0) {
    mean[plane] = _mean;
    var[plane] = _var;
  }
}
std::vector<at::Tensor> mean_var_cuda_h(at::Tensor x) {
  CHECK_CUDA_INPUT(x);

  // Extract dimensions
  int64_t num, chn, sp;
  get_dims(x, num, chn, sp);

  // Prepare output tensors
  auto mean = at::empty({chn}, x.options().dtype(at::kFloat));
  auto var = at::empty({chn}, x.options().dtype(at::kFloat));

  // Run kernel
  dim3 blocks(chn);
  dim3 threads(getNumThreads(sp));
  auto stream = at::cuda::getCurrentCUDAStream();
  mean_var_kernel_h<<<blocks, threads, 0, stream>>>(
      reinterpret_cast<half*>(x.data<at::Half>()),
      mean.data<float>(),
      var.data<float>(),
      num, chn, sp);

  return {mean, var};
}
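// Sketch of the intended host-side call sequence for a forward pass, using
// only the functions defined in this file (the surrounding dispatch logic and
// Python bindings are assumed to live elsewhere in the extension):
//   auto stats = mean_var_cuda_h(x);  // per-channel float statistics
//   auto y = forward_cuda_h(x, stats[0], stats[1], weight, bias, affine, eps);  // normalizes x in place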
/**********
 * forward
 **********/

// Normalizes x in place: y = (x - mean) * rsqrt(var + eps) * |weight| + bias.
// Using |weight| + eps keeps the affine scale strictly positive, so the
// transform stays invertible for the in-place backward pass.
__global__ void forward_kernel_h(half *x, const float *mean, const float *var, const float *weight, const float *bias,
                                 bool affine, float eps, int num, int chn, int sp) {
  int plane = blockIdx.x;

  const float _mean = mean[plane];
  const float _var = var[plane];
  const float _weight = affine ? abs(weight[plane]) + eps : 1.f;
  const float _bias = affine ? bias[plane] : 0.f;

  const float mul = rsqrt(_var + eps) * _weight;

  for (int batch = 0; batch < num; ++batch) {
    for (int n = threadIdx.x; n < sp; n += blockDim.x) {
      half *x_ptr = x + (batch * chn + plane) * sp + n;
      float _x = __half2float(*x_ptr);
      float _y = (_x - _mean) * mul + _bias;
      *x_ptr = __float2half(_y);
    }
  }
}
at::Tensor forward_cuda_h(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
                          bool affine, float eps) {
  CHECK_CUDA_INPUT(x);
  CHECK_CUDA_INPUT(mean);
  CHECK_CUDA_INPUT(var);
  CHECK_CUDA_INPUT(weight);
  CHECK_CUDA_INPUT(bias);

  // Extract dimensions
  int64_t num, chn, sp;
  get_dims(x, num, chn, sp);

  // Run kernel
  dim3 blocks(chn);
  dim3 threads(getNumThreads(sp));
  auto stream = at::cuda::getCurrentCUDAStream();
  forward_kernel_h<<<blocks, threads, 0, stream>>>(
      reinterpret_cast<half*>(x.data<at::Half>()),
      mean.data<float>(),
      var.data<float>(),
      weight.data<float>(),
      bias.data<float>(),
      affine, eps, num, chn, sp);

  return x;
}
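/************
 * edz_eydz
 ************/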
// Per-channel reductions of the incoming gradient: edz = sum(dz) and
// eydz = sum(y * dz), with y reconstructed from z via GradOpH.
__global__ void edz_eydz_kernel_h(const half *z, const half *dz, const float *weight, const float *bias,
                                  float *edz, float *eydz, bool affine, float eps, int num, int chn, int sp) {
  int plane = blockIdx.x;

  float _weight = affine ? abs(weight[plane]) + eps : 1.f;
  float _bias = affine ? bias[plane] : 0.f;

  Pair<float> res = reduce<Pair<float>, GradOpH>(GradOpH(_weight, _bias, z, dz, chn, sp), plane, num, sp);
  __syncthreads();

  if (threadIdx.x == 0) {
    edz[plane] = res.v1;
    eydz[plane] = res.v2;
  }
}
std::vector<at::Tensor> edz_eydz_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
                                        bool affine, float eps) {
  CHECK_CUDA_INPUT(z);
  CHECK_CUDA_INPUT(dz);
  CHECK_CUDA_INPUT(weight);
  CHECK_CUDA_INPUT(bias);

  // Extract dimensions
  int64_t num, chn, sp;
  get_dims(z, num, chn, sp);

  // Prepare output tensors
  auto edz = at::empty({chn}, z.options().dtype(at::kFloat));
  auto eydz = at::empty({chn}, z.options().dtype(at::kFloat));

  // Run kernel
  dim3 blocks(chn);
  dim3 threads(getNumThreads(sp));
  auto stream = at::cuda::getCurrentCUDAStream();
  edz_eydz_kernel_h<<<blocks, threads, 0, stream>>>(
      reinterpret_cast<half*>(z.data<at::Half>()),
      reinterpret_cast<half*>(dz.data<at::Half>()),
      weight.data<float>(),
      bias.data<float>(),
      edz.data<float>(),
      eydz.data<float>(),
      affine, eps, num, chn, sp);

  return {edz, eydz};
}
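/***********
 * backward
 ***********/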
// Standard batch-norm input gradient, expressed in terms of the quantities
// reconstructable from z:
//   dx = (dz - edz/count - y * eydz/count) * weight * rsqrt(var + eps)
__global__ void backward_kernel_h(const half *z, const half *dz, const float *var, const float *weight, const float *bias, const float *edz,
                                  const float *eydz, half *dx, bool affine, float eps, int num, int chn, int sp) {
  int plane = blockIdx.x;

  float _weight = affine ? abs(weight[plane]) + eps : 1.f;
  float _bias = affine ? bias[plane] : 0.f;
  float _var = var[plane];
  float _edz = edz[plane];
  float _eydz = eydz[plane];

  float _mul = _weight * rsqrt(_var + eps);
  float count = float(num * sp);

  for (int batch = 0; batch < num; ++batch) {
    for (int n = threadIdx.x; n < sp; n += blockDim.x) {
      float _dz = __half2float(dz[(batch * chn + plane) * sp + n]);
      float _y = (__half2float(z[(batch * chn + plane) * sp + n]) - _bias) / _weight;

      dx[(batch * chn + plane) * sp + n] = __float2half((_dz - _edz / count - _y * _eydz / count) * _mul);
    }
  }
}
at::Tensor backward_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
                           at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
  CHECK_CUDA_INPUT(z);
  CHECK_CUDA_INPUT(dz);
  CHECK_CUDA_INPUT(var);
  CHECK_CUDA_INPUT(weight);
  CHECK_CUDA_INPUT(bias);
  CHECK_CUDA_INPUT(edz);
  CHECK_CUDA_INPUT(eydz);

  // Extract dimensions
  int64_t num, chn, sp;
  get_dims(z, num, chn, sp);

  auto dx = at::zeros_like(z);

  // Run kernel
  dim3 blocks(chn);
  dim3 threads(getNumThreads(sp));
  auto stream = at::cuda::getCurrentCUDAStream();
  backward_kernel_h<<<blocks, threads, 0, stream>>>(
      reinterpret_cast<half*>(z.data<at::Half>()),
      reinterpret_cast<half*>(dz.data<at::Half>()),
      var.data<float>(),
      weight.data<float>(),
      bias.data<float>(),
      edz.data<float>(),
      eydz.data<float>(),
      reinterpret_cast<half*>(dx.data<at::Half>()),
      affine, eps, num, chn, sp);

  return dx;
}
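/**************
 * leaky_relu
 **************/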
// Inverts the leaky ReLU in place for negative outputs (recovering the
// pre-activation value) and rescales the corresponding gradient by the slope.
// Grid-stride loop with a 64-bit index so tensors with more than INT_MAX
// elements are handled correctly.
__global__ void leaky_relu_backward_impl_h(half *z, half *dz, float slope, int64_t count) {
  for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) {
    float _z = __half2float(z[i]);
    if (_z < 0) {
      dz[i] = __float2half(__half2float(dz[i]) * slope);
      z[i] = __float2half(_z / slope);
    }
  }
}
void leaky_relu_backward_cuda_h(at::Tensor z, at::Tensor dz, float slope) {
  CHECK_CUDA_INPUT(z);
  CHECK_CUDA_INPUT(dz);

  int64_t count = z.numel();
  dim3 threads(getNumThreads(count));
  dim3 blocks((count + threads.x - 1) / threads.x);
  auto stream = at::cuda::getCurrentCUDAStream();

  leaky_relu_backward_impl_h<<<blocks, threads, 0, stream>>>(
      reinterpret_cast<half*>(z.data<at::Half>()),
      reinterpret_cast<half*>(dz.data<at::Half>()),
      slope, count);
}