|
#pragma once |
|
|
|
#include <ATen/ceil_div.h> |
|
#include <ATen/cuda/DeviceUtils.cuh> |
|
#include <ATen/cuda/AsmUtils.cuh> |
|
#include <c10/macros/Macros.h> |
|
|
|
|
|
|
|
namespace at { |
|
namespace cuda { |
|
|
|
|
|
|
|
// Block-wide inclusive prefix scan over one boolean predicate per thread.
//
// Within a warp the result comes from a ballot plus a population count of the
// set predicates at or below the calling lane (per getLaneMaskLe()). The
// per-warp totals are then combined across warps through `smem` with `binop`.
// With an additive `binop`, `*out` is the number of threads with
// threadIdx.x <= the caller's whose predicate is true.
//
// Template parameters:
//   T                  - arithmetic type of the scan results.
//   KillWARDependency  - when true, a trailing barrier guarantees that every
//                        thread has finished reading `smem` before returning,
//                        so the caller may overwrite `smem` right away.
//   BinaryFunction     - callable combining two values (e.g. addition).
//
// Parameters:
//   smem  - scratch buffer visible to every thread of the block (the name
//           suggests __shared__ memory), holding at least
//           ceil_div(blockDim.x, C10_WARP_SIZE) elements of T.
//   in    - this thread's predicate.
//   out   - receives this thread's inclusive scan value.
//   binop - the combining operation.
//
// Only threadIdx.x / blockDim.x are consulted. The function contains
// unconditional __syncthreads() calls, so every thread of the block must
// reach it.
template <typename T, bool KillWARDependency, class BinaryFunction>
__device__ void inclusiveBinaryPrefixScan(T* smem, bool in, T* out, BinaryFunction binop) {
  // Intra-warp part: ballot the predicate, then count set bits.
  // The ballot is kept in an unsigned integer (not T) so the bitmask is never
  // truncated or value-converted when T is narrow or non-integral.
#if defined (USE_ROCM)
  unsigned long long int vote = WARP_BALLOT(in);
  T index = __popcll(getLaneMaskLe() & vote);
  T carry = __popcll(vote);
#else
  unsigned int vote = WARP_BALLOT(in);
  T index = __popc(getLaneMaskLe() & vote);
  T carry = __popc(vote);
#endif

  int warp = threadIdx.x / C10_WARP_SIZE;

  // Lane 0 of each warp publishes that warp's total.
  if (getLaneId() == 0) {
    smem[warp] = carry;
  }

  __syncthreads();

  // A single thread turns the per-warp totals into running totals.
  // The warp count is rounded up so that a trailing partial warp (blockDim.x
  // not a multiple of C10_WARP_SIZE) is included; exclusiveBinaryPrefixScan
  // reads slot ceil_div(blockDim.x, C10_WARP_SIZE) - 1 as the block carry.
  if (threadIdx.x == 0) {
    const int numWarps = at::ceil_div<int>(blockDim.x, C10_WARP_SIZE);
    T current = 0;
    for (int i = 0; i < numWarps; ++i) {
      T v = smem[i];
      smem[i] = binop(v, current);
      current = binop(current, v);
    }
  }

  __syncthreads();

  // Fold in the running total of all preceding warps.
  if (warp >= 1) {
    index = binop(index, smem[warp - 1]);
  }

  *out = index;

  if (KillWARDependency) {
    __syncthreads();
  }
}
|
|
|
|
|
|
|
// Block-wide exclusive prefix scan over one boolean predicate per thread.
//
// Runs inclusiveBinaryPrefixScan (without its trailing barrier) and then
// removes the calling thread's own contribution from the result.
//
// Parameters:
//   smem  - scratch buffer forwarded to inclusiveBinaryPrefixScan; it is read
//           again here at index ceil_div(blockDim.x, C10_WARP_SIZE) - 1.
//   in    - this thread's predicate.
//   out   - receives the inclusive value minus `in` converted to T.
//   carry - receives the value stored in the last per-warp slot of `smem`;
//           when blockDim.x is a multiple of C10_WARP_SIZE this is the
//           binop-combined total over all warps of the block.
//   binop - the combining operation handed to the inclusive scan.
//
// When KillWARDependency is true a final __syncthreads() is issued, so all
// threads are done reading `smem` before any of them returns. Every thread of
// the block must call this function because it synchronizes the block.
template <typename T, bool KillWARDependency, class BinaryFunction>
__device__ void exclusiveBinaryPrefixScan(T* smem, bool in, T* out, T* carry, BinaryFunction binop) {
  // The inclusive scan's own trailing barrier is disabled: smem is still
  // needed below, and the barrier (if requested) is issued at the end here.
  inclusiveBinaryPrefixScan<T, false, BinaryFunction>(smem, in, out, binop);

  // Drop this thread's own predicate to go from inclusive to exclusive.
  *out -= static_cast<T>(in);

  // Index of the slot written for the last warp of the block.
  const int lastWarpSlot = at::ceil_div<int>(blockDim.x, C10_WARP_SIZE) - 1;
  *carry = smem[lastWarpSlot];

  if (KillWARDependency) {
    __syncthreads();
  }
}
|
|
|
}} |
|
|