|
#pragma once |
|
|
|
#include <torch/all.h> |
|
|
|
#include <ATen/cuda/CUDAContext.h> |
|
#include <c10/cuda/CUDAGuard.h> |
|
#include <cuda.h> |
|
#include <cuda_fp16.h> |
|
#include <cuda_runtime.h> |
|
#include <iostream> |
|
|
|
namespace marlin { |
|
|
|
|
|
|
|
|
|
|
|
|
|
// Compile-time tuning parameters shared by the code in this namespace.
// NOTE(review): the role descriptions below are inferred from the identifier
// names only -- confirm against the kernels that consume them.

// Default thread count (presumably threads per block for kernel launches).
static constexpr int default_threads = 256;

// Stage count (presumably the depth of the asynchronous copy pipeline).
static constexpr int pipe_stages = 4;

// Lower bounds (presumably on the per-block tile extent along n and k).
static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;

// Base tile edge and an upper limit named `max_par`.
static constexpr int tile_size = 16;
static constexpr int max_par = 16;

// Repack parameters: stage count and thread count.
static constexpr int repack_stages = 8;
static constexpr int repack_threads = 256;

// Tile extents: k equals the base tile edge (16); n is four times k (64).
static constexpr int tile_k_size = tile_size;
static constexpr int tile_n_size = 4 * tile_k_size;
|
|
|
|
|
// Fixed-size aggregate holding `n` contiguous elements of type `T`.
//
// The storage member `elems` is public, so the type stays a plain aggregate
// (brace-initializable, no user-provided constructors). Element access through
// operator[] is available in device code only and performs no bounds checking.
template <typename T, int n>
struct Vec {
  T elems[n];

  // Mutable access to element `i`; the caller must keep 0 <= i < n.
  __device__ T& operator[](int i) { return elems[i]; }

  // Read-only access to element `i`, so that const-qualified Vec objects and
  // const references can be indexed as well; the caller must keep 0 <= i < n.
  __device__ const T& operator[](int i) const { return elems[i]; }
};

// Vector of four ints.
using I4 = Vec<int, 4>;
|
|
|
// Integer ceiling division: returns the smallest integer that is >= a / b.
//
// Precondition: b > 0 (b == 0 divides by zero; negative b is not supported).
//
// The result is derived from the truncated quotient and the remainder instead
// of the usual (a + b - 1) / b, because that sum overflows `int` when `a` is
// close to INT_MAX. For a >= 0 the value is identical to the classic formula;
// for negative `a` the remainder is <= 0, so the truncated quotient (which
// already rounds toward zero, i.e. upward) is returned unchanged.
constexpr int div_ceil(int a, int b) {
  return (a % b > 0) ? a / b + 1 : a / b;
}
|
|
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 |
|
|
|
#else |
|
|
|
// Conditionally issues a 16-byte `cp.async.cg.shared.global` copy from
// `glob_ptr` to `smem_ptr`.
//
//   smem_ptr  destination; converted with __cvta_generic_to_shared and
//             narrowed to a 32-bit shared-memory address.
//   glob_ptr  source pointer, passed to the instruction as a 64-bit operand.
//   pred      the copy instruction is executed only when this is true
//             (default: true); when false nothing is issued.
//
// This only issues the copy. Per the PTX ISA, cp.async completes
// asynchronously and must be committed and waited on before the destination
// is read.
// NOTE(review): the asm statement has no "memory" clobber -- presumably
// callers place a barrier before reading the destination; confirm.
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                      bool pred = true) {
  constexpr int kCopyBytes = 16;
  const uint32_t smem_addr =
      static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  const int issue_copy = static_cast<int>(pred);
  asm volatile(
      "{\n"
      " .reg .pred p;\n"
      " setp.ne.b32 p, %0, 0;\n"
      " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
      "}\n"
      :
      : "r"(issue_copy), "r"(smem_addr), "l"(glob_ptr), "n"(kCopyBytes));
}
|
|
|
// Unconditionally issues a 16-byte `cp.async.cg.shared.global` copy from
// `glob_ptr` to `smem_ptr`.
//
//   smem_ptr  destination; converted with __cvta_generic_to_shared and
//             narrowed to a 32-bit shared-memory address.
//   glob_ptr  source pointer, passed to the instruction as a 64-bit operand.
//
// This only issues the copy. Per the PTX ISA, cp.async completes
// asynchronously and must be committed and waited on before the destination
// is read.
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
  constexpr int kCopyBytes = 16;
  const uint32_t smem_addr =
      static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      " cp.async.cg.shared.global [%0], [%1], %2;\n"
      "}\n"
      :
      : "r"(smem_addr), "l"(glob_ptr), "n"(kCopyBytes));
}
|
|
|
// Emits `cp.async.commit_group`. Per the PTX ISA, this bundles the calling
// thread's previously issued, not-yet-committed cp.async operations into one
// group that can later be waited on.
__device__ inline void cp_async_fence() {
  asm volatile("cp.async.commit_group;\n" : :);
}
|
|
|
// Emits `cp.async.wait_group n`, with `n` supplied as a compile-time
// immediate. Per the PTX ISA, this waits until at most `n` of the calling
// thread's most recently committed cp.async groups are still pending.
// NOTE(review): the asm statement has no "memory" clobber -- presumably
// callers follow it with a barrier before reading shared memory; confirm.
template <int n>
__device__ inline void cp_async_wait() {
  asm volatile("cp.async.wait_group %0;\n"
               :
               : "n"(n));
}
|
|
|
#endif |
|
|
|
} |
|
|