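// In-place attention softmax CUDA kernels.
//
// attn_softmax_inplace_ overwrites each row of a flattened, contiguous
// [rows, cols] attention-score matrix with its softmax.
// attn_softmax_inplace_grad_ overwrites the saved softmax output with the
// gradient with respect to the pre-softmax scores, given the gradient of the
// (softmax @ values) product and the values themselves. Both kernels assign
// one warp per row and four rows per 128-thread block.
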
#include <math_constants.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>

#include <iostream>

#include "ATen/ATen.h"
#include "ATen/cuda/CUDAContext.h"
#include "compat.h"

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
    CHECK_CUDA(x);     \
    CHECK_CONTIGUOUS(x)

// Butterfly warp reduction: XOR-shuffling with strides 1, 2, 4, 8, 16 leaves
// every lane holding the maximum of all 32 lanes' values.
__inline__ __device__ float WarpAllReduceMax(float val) {
    for (int mask = 1; mask < 32; mask *= 2) {
        val = max(val, __shfl_xor_sync(0xffffffff, val, mask));
    }
    return val;
}

// Same butterfly pattern, accumulating the warp-wide sum instead.
__inline__ __device__ float WarpAllReduceSum(float val) {
    for (int mask = 1; mask < 32; mask *= 2) {
        val += __shfl_xor_sync(0xffffffff, val, mask);
    }
    return val;
}
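
// Row-wise, numerically stable softmax computed in place.
//
// Thread mapping: threadIdx.x / 32 picks one of the four rows handled by the
// block, threadIdx.x % 32 is the lane within that row's warp, and each lane
// owns a contiguous chunk of ceil(cols / 32) columns. The per-lane register
// buffer holds 32 floats, so the kernel assumes cols <= 32 * 32 = 1024.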
template<typename T>
__global__ void attn_softmax_inplace_(
    T *input,
    long long rows, int cols
) {
    int threadidx_x = threadIdx.x / 32;   // warp index within the block -> row
    int threadidx_y = threadIdx.x % 32;   // lane index within the warp
    // Promote blockIdx.x before the multiply so the row index cannot wrap in
    // 32-bit arithmetic for very large row counts.
    long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
    int cols_per_thread = (cols + 31) / 32;
    int cols_this_thread = cols_per_thread;

    // The lane covering the tail of the row takes the remainder; lanes beyond
    // it have no columns to process.
    int last_y = (cols / cols_per_thread);

    if (threadidx_y == last_y) {
        cols_this_thread = cols - cols_per_thread * last_y;
    }
    else if (threadidx_y > last_y) {
        cols_this_thread = 0;
    }

    float buf[32];

    int lane_id = threadidx_y;

    if (row_offset < rows) {
        T *row_input = input + row_offset * cols;
        T *row_output = row_input;

        // Load this lane's slice of the row into registers.
        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            int idx = lane_id * cols_per_thread + i;
            buf[i] = static_cast<float>(row_input[idx]);
        }

        // Row maximum, for the numerically stable exp(x - max).
        float thread_max = -CUDART_INF_F;
        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            thread_max = max(thread_max, buf[i]);
        }

        float warp_max = WarpAllReduceMax(thread_max);

        // Exponentiate and accumulate the row sum.
        float thread_sum = 0.f;
        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            buf[i] = __expf(buf[i] - warp_max);
            thread_sum += buf[i];
        }

        // Normalize and write the result back over the input row.
        float warp_sum = WarpAllReduceSum(thread_sum);
        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            row_output[lane_id * cols_per_thread + i] =
                static_cast<T>(__fdividef(buf[i], warp_sum));
        }
    }
}


void attn_softmax_inplace_forward_(
    at::Tensor input,
    long long rows, int cols
) {
    CHECK_INPUT(input);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));

    // One block per four rows, one warp per row.
    int grid = (rows + 3) / 4;
    dim3 block(128);

    if (input.dtype() == torch::kFloat32) {
        attn_softmax_inplace_<float><<<grid, block>>>(
            (float *)input.data_ptr(),
            rows, cols
        );
    }
    else {
        // Any non-float32 input is assumed to be bfloat16.
        attn_softmax_inplace_<at::BFloat16><<<grid, block>>>(
            (at::BFloat16 *)input.data_ptr(),
            rows, cols
        );
    }
}
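
// Usage sketch (illustrative only; the tensor names below are not part of this
// file): the launcher treats `input` as a flattened, contiguous [rows, cols]
// matrix and softmaxes every row in place, so a check against ATen's softmax
// could look roughly like
//
//   at::Tensor scores = at::rand({8, 64}, at::device(at::kCUDA).dtype(at::kFloat));
//   at::Tensor reference = at::softmax(scores, /*dim=*/-1);
//   attn_softmax_inplace_forward_(scores, /*rows=*/8, /*cols=*/64);
//   TORCH_CHECK(at::allclose(scores, reference, /*rtol=*/1e-5, /*atol=*/1e-6));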

// Backward pass of the in-place softmax, also computed in place.
//
// For a softmax row y with upstream gradient dy, the gradient with respect to
// the pre-softmax scores is dx_i = (dy_i - sum_j y_j * dy_j) * y_i. Here dy is
// first reconstructed as d_ov @ values^T (d_ov being the gradient of the
// softmax @ values product), and dx overwrites `output`, which holds the saved
// softmax result y on entry. As in the forward kernel, cols_output must not
// exceed 1024.
template<typename T>
__global__ void attn_softmax_inplace_grad_(
    T *output,
    T *d_ov,
    T *values,
    long long rows,
    int cols_output,
    int cols_values
) {
    int threadidx_x = threadIdx.x / 32;   // warp index within the block -> row
    int threadidx_y = threadIdx.x % 32;   // lane index within the warp
    long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
    int cols_per_thread = (cols_output + 31) / 32;
    int cols_this_thread = cols_per_thread;
    int rows_values = cols_output;

    // Output rows come in groups of rows_values (= cols_output) that share one
    // block of `values`; value_row_offset is the first row of that block.
    long long value_row_offset = row_offset - row_offset % rows_values;
    int last_y = (cols_output / cols_per_thread);

    if (threadidx_y == last_y) {
        cols_this_thread = cols_output - cols_per_thread * last_y;
    }
    else if (threadidx_y > last_y) {
        cols_this_thread = 0;
    }

    float y_buf[32];
    float dy_buf[32];

    int lane_id = threadidx_y;

    if (row_offset < rows) {
        T *row_output = output + row_offset * cols_output;
        T *row_d_ov = d_ov + row_offset * cols_values;
        T *row_values = values + value_row_offset * cols_values;

        // dy_i = sum_j d_ov[j] * values[i][j], i.e. one row of d_ov @ values^T.
        // The inner accumulation is done in T, so it runs in reduced precision
        // for bfloat16 inputs.
        int value_row_idx = 0;
        int value_idx = 0;
        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            T sum = 0.;
            #pragma unroll
            for (int j = 0; j < cols_values; j++) {
                value_row_idx = ((lane_id * cols_per_thread) + i);
                value_idx = value_row_idx * cols_values + j;
                sum += row_d_ov[j] * row_values[value_idx];
            }
            dy_buf[i] = static_cast<float>(sum);
        }

        // Reload the saved softmax output y.
        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            y_buf[i] = static_cast<float>(row_output[lane_id * cols_per_thread + i]);
        }

        // sum_j y_j * dy_j, reduced across the whole row.
        float thread_sum = 0.;

        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            thread_sum += y_buf[i] * dy_buf[i];
        }

        float warp_sum = WarpAllReduceSum(thread_sum);

        // dx_i = (dy_i - sum_j y_j * dy_j) * y_i, written back over the output.
        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            row_output[lane_id * cols_per_thread + i] = static_cast<T>(
                (dy_buf[i] - warp_sum) * y_buf[i]
            );
        }
    }
}


void attn_softmax_inplace_backward_(
    at::Tensor output,
    at::Tensor d_ov,
    at::Tensor values,
    long long rows,
    int cols_output,
    int cols_values
) {
    CHECK_INPUT(output);
    CHECK_INPUT(d_ov);
    CHECK_INPUT(values);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(output));

    // Same launch configuration as the forward pass: four rows per block, one
    // warp per row.
    int grid = (rows + 3) / 4;
    dim3 block(128);

    if (output.dtype() == torch::kFloat32) {
        attn_softmax_inplace_grad_<float><<<grid, block>>>(
            (float *)output.data_ptr(),
            (float *)d_ov.data_ptr(),
            (float *)values.data_ptr(),
            rows, cols_output, cols_values
        );
    } else {
        // Any non-float32 tensors are assumed to be bfloat16.
        attn_softmax_inplace_grad_<at::BFloat16><<<grid, block>>>(
            (at::BFloat16 *)output.data_ptr(),
            (at::BFloat16 *)d_ov.data_ptr(),
            (at::BFloat16 *)values.data_ptr(),
            rows, cols_output, cols_values
        );
    }
}
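
// Binding sketch (an assumption, not part of this file): these launchers are
// typically exposed to Python from a separate extension source, roughly like
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//       m.def("forward_", &attn_softmax_inplace_forward_,
//             "Attention softmax forward (CUDA, in place)");
//       m.def("backward_", &attn_softmax_inplace_backward_,
//             "Attention softmax backward (CUDA, in place)");
//   }
//
// The module and function names above are illustrative only.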