File size: 15,148 Bytes
d1ceb73 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 |
#pragma once
#include <ATen/cuda/cub.h>
#include <cstddef>
#include <type_traits>
#include <iterator>
#include <limits>
#include <ATen/cuda/cub_definitions.cuh>
#if USE_GLOBAL_CUB_WRAPPED_NAMESPACE()
#include <cub/cub.cuh>
#else
// include cub in a safe manner, see:
// https://github.com/pytorch/pytorch/pull/55292
#undef CUB_NS_POSTFIX //undef to avoid redefinition warnings
#undef CUB_NS_PREFIX
#undef CUB_NS_QUALIFIER
#define CUB_NS_PREFIX namespace at_cuda_detail {
#define CUB_NS_POSTFIX }
#define CUB_NS_QUALIFIER ::at_cuda_detail::cub
#include <cub/cub.cuh>
#undef CUB_NS_POSTFIX
#undef CUB_NS_PREFIX
#undef CUB_NS_QUALIFIER
#endif
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAStream.h>
// handle the temporary storage and 'twice' calls for cub API
#define CUB_WRAPPER(func, ...) do { \
size_t temp_storage_bytes = 0; \
func(nullptr, temp_storage_bytes, __VA_ARGS__); \
auto& caching_allocator = *::c10::cuda::CUDACachingAllocator::get(); \
auto temp_storage = caching_allocator.allocate(temp_storage_bytes); \
func(temp_storage.get(), temp_storage_bytes, __VA_ARGS__); \
AT_CUDA_CHECK(cudaGetLastError()); \
} while (false)
#ifdef USE_ROCM
#define NO_ROCM(x)
#define ROCM_HIPCUB(x) ::hipcub
#else
#define NO_ROCM(x) x
#define ROCM_HIPCUB(x) x
#endif
#if (!defined(USE_ROCM) && !CUB_SUPPORTS_NV_BFLOAT16()) || defined(USE_ROCM)
#if !defined(USE_ROCM)
namespace at_cuda_detail {
#endif
// backport https://github.com/NVIDIA/cub/pull/306 for c10::BFloat16
template <>
struct ROCM_HIPCUB(cub)::FpLimits<c10::BFloat16>
{
static __host__ __device__ __forceinline__ c10::BFloat16 Max() {
unsigned short max_word = 0x7F7F;
return reinterpret_cast<c10::BFloat16&>(max_word);
}
static __host__ __device__ __forceinline__ c10::BFloat16 Lowest() {
unsigned short lowest_word = 0xFF7F;
return reinterpret_cast<c10::BFloat16&>(lowest_word);
}
};
template <>
struct ROCM_HIPCUB(cub)::NumericTraits<c10::BFloat16>:
ROCM_HIPCUB(cub)::BaseTraits<ROCM_HIPCUB(cub)::FLOATING_POINT, true, false, unsigned short, c10::BFloat16> {};
#if !defined(USE_ROCM)
} // namespace at_cuda_detail
#endif
#endif
#if !defined(USE_ROCM)
namespace at::native {
namespace cub = ::at_cuda_detail::cub;
} // namespace at::native
#endif
namespace at::cuda::cub {
namespace detail {
template<typename T>
struct cuda_type {
using type = T;
};
template<>
struct cuda_type<c10::Half> {
using type = __half;
};
#if !defined(USE_ROCM) && CUB_SUPPORTS_NV_BFLOAT16()
template<>
struct cuda_type<c10::BFloat16> {
using type = __nv_bfloat16;
};
#elif defined(USE_ROCM)
template<>
struct cuda_type<c10::BFloat16> {
using type = hip_bfloat16;
};
#endif
} // namespace detail
template<typename key_t, typename value_t, typename OffsetIteratorT>
inline void segmented_sort_pairs(
const key_t *keys_in, key_t *keys_out,
const value_t *values_in, value_t *values_out,
int64_t num_elements, int64_t num_segments,
OffsetIteratorT begin_offsets, OffsetIteratorT end_offsets,
bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8
) {
TORCH_CHECK(num_elements <= std::numeric_limits<int>::max(),
"cub sort does not support sorting more than INT_MAX elements");
TORCH_CHECK(num_segments <= std::numeric_limits<int>::max(),
"cub sort does not support sorting more than INT_MAX elements");
using key_t_ = typename detail::cuda_type<key_t>::type;
auto allocator = c10::cuda::CUDACachingAllocator::get();
c10::DataPtr keys_out_owner;
if (keys_out == nullptr) {
keys_out_owner = allocator->allocate(num_elements * sizeof(key_t));
keys_out = reinterpret_cast<key_t *>(keys_out_owner.get());
}
const key_t_ *keys_in_ = reinterpret_cast<const key_t_*>(keys_in);
key_t_ *keys_out_ = reinterpret_cast<key_t_*>(keys_out);
if (descending) {
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSegmentedRadixSort::SortPairsDescending,
keys_in_, keys_out_, values_in, values_out,
num_elements, num_segments, begin_offsets, end_offsets,
begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
} else {
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSegmentedRadixSort::SortPairs,
keys_in_, keys_out_, values_in, values_out,
num_elements, num_segments, begin_offsets, end_offsets,
begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
}
}
#if CUB_SUPPORTS_UNIQUE_BY_KEY()
template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename KeysOutputIteratorT, typename ValuesOutputIteratorT, typename NumSelectedIteratorT>
inline void unique_by_key(
KeysInputIteratorT keys_in, ValuesInputIteratorT values_in,
KeysOutputIteratorT keys_out, ValuesOutputIteratorT values_out,
NumSelectedIteratorT num_selected, int64_t num_input_items)
{
// TODO: use thrust::discard_iterator to handle null keys_out when https://github.com/NVIDIA/cub/issues/406 is fixed.
constexpr bool null_keys_out = std::is_same<KeysOutputIteratorT, std::nullptr_t>::value;
using KeyT = typename std::iterator_traits<KeysInputIteratorT>::value_type;
using RealKeysOutputIteratorT = typename std::conditional<null_keys_out, KeyT *, KeysOutputIteratorT>::type;
RealKeysOutputIteratorT keys_out_;
auto allocator = c10::cuda::CUDACachingAllocator::get();
c10::DataPtr keys_out_owner;
if constexpr (null_keys_out) {
keys_out_owner = allocator->allocate(num_input_items * sizeof(KeyT));
keys_out_ = static_cast<KeyT *>(keys_out_owner.get());
} else {
keys_out_ = keys_out;
}
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::UniqueByKey,
keys_in, values_in, keys_out_, values_out, num_selected, num_input_items, c10::cuda::getCurrentCUDAStream());
}
#endif
namespace impl {
template<typename InputIteratorT1, typename InputIteratorT2, typename OutputIteratorT, class ScanOpT>
C10_LAUNCH_BOUNDS_1(1)
__global__ void transform_vals(InputIteratorT1 a, InputIteratorT2 b, OutputIteratorT out, ScanOpT scan_op){
// NOTE: out here not the final scan output, but an intermediate of the accumulation type.
using acc_t = typename std::iterator_traits<OutputIteratorT>::value_type;
*out = scan_op(static_cast<acc_t>(*a), static_cast<acc_t>(*b));
}
#if !CUB_SUPPORTS_FUTURE_VALUE()
template<typename ValueT, typename InputIteratorT>
struct chained_iterator {
using iterator_category = std::random_access_iterator_tag;
using difference_type = std::ptrdiff_t;
using value_type = ValueT;
using pointer = ValueT*;
using reference = ValueT&;
InputIteratorT iter;
ValueT *first;
difference_type offset = 0;
__device__ ValueT operator[](difference_type i) {
i += offset;
if (i == 0) {
return *first;
} else {
return ValueT(iter[i - 1]);
}
}
__device__ chained_iterator operator+(difference_type i) {
return chained_iterator{iter, first, i};
}
__device__ ValueT operator*() {
return (*this)[0];
}
};
#endif
// even though cub is supposed to support tensors with int_max elements, in reality it doesn't,
// so split at int_max/2
constexpr int max_cub_size = std::numeric_limits<int>::max() / 2 + 1; // 2**30
}
// non synchronizing cub call
// even though cub is supposed to support tensors with int_max elements, in reality it doesn't,
// so split at int_max/2
template<typename InputIteratorT, typename OutputIteratorT, typename ScanOpT, int max_cub_size=impl::max_cub_size>
inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT scan_op, int64_t num_items) {
#if defined(USE_ROCM)
//For ROCm, use hipCUB chained iterators
CUB_WRAPPER(NO_ROCM(detail)::hipcub::DeviceScan::InclusiveScan,
input,
output,
scan_op,
num_items,
at::cuda::getCurrentCUDAStream());
C10_HIP_KERNEL_LAUNCH_CHECK();
#else
// non synchronizing cub call
// even though cub is supposed to support tensors with int_max elements, in reality it doesn't,
// so split at int_max/2
int size_cub = std::min<int64_t>(num_items, max_cub_size);
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
input,
output,
scan_op,
size_cub,
at::cuda::getCurrentCUDAStream());
C10_CUDA_KERNEL_LAUNCH_CHECK();
using input_t = typename std::iterator_traits<InputIteratorT>::value_type;
for (int64_t i = max_cub_size; i < num_items; i += max_cub_size) {
auto allocator = c10::cuda::CUDACachingAllocator::get();
c10::DataPtr first_elem = allocator->allocate(sizeof(input_t));
auto first_elem_ptr = reinterpret_cast<input_t *>(first_elem.get());
size_cub = std::min<int64_t>(num_items - i, max_cub_size);
impl::transform_vals<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
output + i - 1,
input + i,
first_elem_ptr,
scan_op);
C10_CUDA_KERNEL_LAUNCH_CHECK();
#if !CUB_SUPPORTS_FUTURE_VALUE()
using ArgIndexInputIterator = NO_ROCM(at_cuda_detail)::cub::ArgIndexInputIterator<InputIteratorT>;
using tuple = typename ArgIndexInputIterator::value_type;
auto input_iter_transform = [=] __device__ (const tuple &x)->input_t {
if (x.key == 0) {
return *first_elem_ptr;
} else {
return x.value;
}
};
auto input_ = NO_ROCM(at_cuda_detail)::cub::TransformInputIterator<input_t, decltype(input_iter_transform), ArgIndexInputIterator>(
ArgIndexInputIterator(input + i), input_iter_transform);
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
input_,
output + i,
scan_op,
size_cub,
at::cuda::getCurrentCUDAStream());
#else
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
input + i + 1,
output + i,
scan_op,
::at_cuda_detail::cub::FutureValue<input_t>(first_elem_ptr),
size_cub,
at::cuda::getCurrentCUDAStream());
#endif
}
#endif
}
template<typename InputIteratorT, typename OutputIteratorT, typename ScanOpT, typename InitValueT, int max_cub_size=impl::max_cub_size>
inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT scan_op, InitValueT init_value, int64_t num_items) {
#if defined(USE_ROCM)
//For ROCm, use hipCUB chained iterators
CUB_WRAPPER(NO_ROCM(detail)::hipcub::DeviceScan::ExclusiveScan,
input,
output,
scan_op,
init_value,
num_items,
at::cuda::getCurrentCUDAStream());
C10_HIP_KERNEL_LAUNCH_CHECK();
#else
// non synchronizing cub call
// even though cub is supposed to support tensors with int_max elements, in reality it doesn't,
// so split at int_max/2
int size_cub = std::min<int64_t>(num_items, max_cub_size);
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
input,
output,
scan_op,
init_value,
size_cub,
at::cuda::getCurrentCUDAStream());
C10_CUDA_KERNEL_LAUNCH_CHECK();
for (int64_t i = max_cub_size; i < num_items; i += max_cub_size) {
auto allocator = c10::cuda::CUDACachingAllocator::get();
c10::DataPtr first_elem = allocator->allocate(sizeof(InitValueT));
auto first_elem_ptr = reinterpret_cast<InitValueT *>(first_elem.get());
size_cub = std::min<int64_t>(num_items - i, max_cub_size);
impl::transform_vals<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
output + i - 1,
input + i - 1,
first_elem_ptr,
scan_op);
C10_CUDA_KERNEL_LAUNCH_CHECK();
#if !CUB_SUPPORTS_FUTURE_VALUE()
auto input_ = impl::chained_iterator<InitValueT, InputIteratorT>{
input + i, first_elem_ptr};
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
input_,
output + i,
scan_op,
size_cub,
at::cuda::getCurrentCUDAStream());
#else
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
input + i,
output + i,
scan_op,
::at_cuda_detail::cub::FutureValue<InitValueT>(first_elem_ptr),
size_cub,
at::cuda::getCurrentCUDAStream());
#endif
}
#endif
}
#if CUB_SUPPORTS_SCAN_BY_KEY()
template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT>
inline void inclusive_sum_by_key(KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items) {
TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
"cub InclusiveSumByKey does not support more than INT_MAX elements");
CUB_WRAPPER(at_cuda_detail::cub::DeviceScan::InclusiveSumByKey,
keys, input, output, num_items, at_cuda_detail::cub::Equality(), at::cuda::getCurrentCUDAStream());
}
template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT, typename ScanOpT>
inline void inclusive_scan_by_key(KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, ScanOpT scan_op, int64_t num_items) {
TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
"cub InclusiveSumByKey does not support more than INT_MAX elements");
CUB_WRAPPER(at_cuda_detail::cub::DeviceScan::InclusiveScanByKey,
keys, input, output, scan_op, num_items, at_cuda_detail::cub::Equality(), at::cuda::getCurrentCUDAStream());
}
#endif
template <typename InputIteratorT, typename OutputIteratorT, typename NumSelectedIteratorT>
void unique(InputIteratorT input, OutputIteratorT output,
NumSelectedIteratorT num_selected_out, int64_t num_items) {
TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
"cub unique does not support more than INT_MAX elements");
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::Unique,
input, output, num_selected_out, num_items, at::cuda::getCurrentCUDAStream());
}
template <typename InputIteratorT, typename OutputIteratorT, typename CountsOutputIteratorT,
typename LengthOutputIteratorT>
void run_length_encode(InputIteratorT input, OutputIteratorT output, CountsOutputIteratorT counts_out,
LengthOutputIteratorT length_out, int64_t num_items) {
TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
"cub run_length_encode does not support more than INT_MAX elements");
CUB_WRAPPER(
NO_ROCM(at_cuda_detail)::cub::DeviceRunLengthEncode::Encode,
input, output, counts_out, length_out, num_items,
at::cuda::getCurrentCUDAStream());
}
template <typename InputIteratorT, typename OutputIteratorT, typename ReductionOpT, typename T>
void reduce(InputIteratorT input, OutputIteratorT output, int64_t num_items, ReductionOpT op, T init) {
TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
"cub reduce does not support more than INT_MAX elements");
CUB_WRAPPER(
NO_ROCM(at_cuda_detail)::cub::DeviceReduce::Reduce,
input, output, num_items, op, init,
at::cuda::getCurrentCUDAStream());
}
} // namespace at::cuda::cub
|