#pragma once

#include <cstdint>
#include <type_traits>

#include <c10/core/ScalarType.h>
#include <ATen/cuda/CUDAConfig.h>

namespace at {
namespace cuda {
namespace cub {

// Returns the number of bits needed to represent max_key (minimum 1).
// Useful for narrowing the [begin_bit, end_bit) range passed to
// radix_sort_pairs when an upper bound on the keys is known.
inline int get_num_bits(uint64_t max_key) {
  int num_bits = 1;
  while (max_key > 1) {
    max_key >>= 1;
    num_bits++;
  }
  return num_bits;
}
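
// For example, get_num_bits(5) == 3 (binary 101), while get_num_bits(0)
// and get_num_bits(1) both return the 1-bit minimum.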

namespace detail {

// radix_sort_pairs does not interact with value_t other than to copy
// the data, so template instantiations can be shared by reinterpreting
// values as an opaque type of the same size and alignment.
template <int N> struct alignas(N) OpaqueType { char data[N]; };
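
// For example, int64_t and double both map to OpaqueType<8>, so a single
// radix_sort_pairs_impl instantiation per key type serves both value types.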

// Defined in cub.cu, which provides explicit instantiations for the
// supported value sizes (see the static_asserts in radix_sort_pairs).
template<typename key_t, int value_size>
void radix_sort_pairs_impl(
    const key_t *keys_in, key_t *keys_out,
    const OpaqueType<value_size> *values_in, OpaqueType<value_size> *values_out,
    int64_t n, bool descending, int64_t begin_bit, int64_t end_bit);

}  // namespace detail

template<typename key_t, typename value_t>
void radix_sort_pairs(
    const key_t *keys_in, key_t *keys_out,
    const value_t *values_in, value_t *values_out,
    int64_t n, bool descending=false, int64_t begin_bit=0,
    int64_t end_bit=sizeof(key_t)*8) {
  static_assert(std::is_trivially_copyable<value_t>::value ||
      AT_ROCM_ENABLED(),
      "radix_sort_pairs value type must be trivially copyable");

  // Type-erase the values: only their size and alignment matter here.
  using opaque_t = detail::OpaqueType<sizeof(value_t)>;
  static_assert(sizeof(value_t) <= 8 && (sizeof(value_t) & (sizeof(value_t) - 1)) == 0,
      "This size of value_t is not instantiated. Please instantiate it in cub.cu"
      " and modify this check.");
  static_assert(sizeof(value_t) == alignof(value_t), "Expected value_t to be size-aligned");
  detail::radix_sort_pairs_impl(
      keys_in, keys_out,
      reinterpret_cast<const opaque_t*>(values_in),
      reinterpret_cast<opaque_t*>(values_out),
      n, descending, begin_bit, end_bit);
}
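
// Illustrative usage (a sketch, not part of this header; d_keys_in,
// d_keys_out, d_values_in and d_values_out are hypothetical
// caller-allocated device buffers of length n):
//
//   at::cuda::cub::radix_sort_pairs(d_keys_in, d_keys_out,
//                                   d_values_in, d_values_out, n);
//
// When an upper bound max_key on the keys is known, the bit range can be
// narrowed so the radix sort performs fewer passes:
//
//   at::cuda::cub::radix_sort_pairs(d_keys_in, d_keys_out,
//                                   d_values_in, d_values_out, n,
//                                   /*descending=*/false,
//                                   0, get_num_bits(max_key));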

template<typename key_t>
void radix_sort_keys(
    const key_t *keys_in, key_t *keys_out,
    int64_t n, bool descending=false, int64_t begin_bit=0,
    int64_t end_bit=sizeof(key_t)*8);
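
// Keys-only variant, e.g. (d_in and d_out being hypothetical device
// buffers of length n):
//
//   at::cuda::cub::radix_sort_keys(d_in, d_out, n, /*descending=*/true);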

// NOTE: as the name suggests, intermediate sums are accumulated in
// input_t precision, so they may be truncated even when output_t is wider.
template <typename input_t, typename output_t>
void inclusive_sum_truncating(const input_t *input, output_t *output, int64_t n);

template <typename scalar_t>
void inclusive_sum(const scalar_t *input, scalar_t *output, int64_t n) {
  return inclusive_sum_truncating(input, output, n);
}
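
// For example, with int input {1, 2, 3}, inclusive_sum writes {1, 3, 6}:
// each output element includes the input element at its own position.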

// NOTE: as the name suggests, sums are computed in
// std::common_type<input_t, output_t>, avoiding the truncation above.
template <typename input_t, typename output_t>
void exclusive_sum_in_common_type(const input_t *input, output_t *output, int64_t n);

template <typename scalar_t>
void exclusive_sum(const scalar_t *input, scalar_t *output, int64_t n) {
  return exclusive_sum_in_common_type(input, output, n);
}
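
// For example, with int input {1, 2, 3}, exclusive_sum writes {0, 1, 3}:
// each output element sums only the inputs strictly before its position.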

// Exclusive sum over a boolean mask: output_idx[i] is the number of
// nonzero mask entries before position i, i.e. the compacted output slot
// for element i in stream-compaction style kernels (e.g. nonzero).
void mask_exclusive_sum(const uint8_t *mask, int64_t *output_idx, int64_t n);
inline void mask_exclusive_sum(const bool *mask, int64_t *output_idx, int64_t n) {
  return mask_exclusive_sum(
      reinterpret_cast<const uint8_t*>(mask), output_idx, n);
}
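
// For example, mask {0, 1, 1, 0, 1} yields output_idx {0, 0, 1, 2, 2}:
// the masked-in elements land at compacted positions 0, 1 and 2.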

}  // namespace cub
}  // namespace cuda
}  // namespace at