|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include "../gptq_marlin/marlin.cuh" |
|
#include "../gptq_marlin/marlin_dtypes.cuh" |
|
|
|
using namespace marlin; |
|
|
|
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t)               \
  static_assert(std::is_same<scalar_t, half>::value ||          \
                    std::is_same<scalar_t, nv_bfloat16>::value, \
                "only float16 and bfloat16 are supported");
|
|
|
template <typename T> |
|
inline std::string str(T x) { |
|
return std::to_string(x); |
|
} |
|
|
|
namespace fp8_marlin { |
|
|
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 |
|
|
|
template <typename scalar_t, |
|
const int num_bits, |
|
const int threads, |
|
const int thread_m_blocks, |
|
|
|
|
|
const int thread_n_blocks, |
|
const int thread_k_blocks, |
|
const int stages, |
|
|
|
const int group_blocks = -1 |
|
|
|
> |
|
__global__ void Marlin( |
|
const int4* __restrict__ A, |
|
const int4* __restrict__ B, |
|
int4* __restrict__ C, |
|
const int4* __restrict__ scales_ptr, |
|
|
|
int num_groups, |
|
int prob_m, |
|
int prob_n, |
|
int prob_k, |
|
int* locks |
|
) {} |
|
|
|
}  // namespace fp8_marlin
|
|
|
torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, |
|
torch::Tensor& b_scales, torch::Tensor& workspace, |
|
int64_t num_bits, int64_t size_m, int64_t size_n, |
|
int64_t size_k) { |
|
  TORCH_CHECK_NOT_IMPLEMENTED(false,
                              "fp8_marlin_gemm(..) requires CUDA_ARCH >= 8.0");
|
return torch::empty({1, 1}); |
|
} |
|
|
|
#else |
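
// m16n8k16 tensor-core matmul: multiply an fp16/bf16 A fragment with a B
// fragment and accumulate into fp32.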
|
|
|
|
|
|
|
template <typename scalar_t> |
|
__device__ inline void mma(const typename ScalarType<scalar_t>::FragA& a_frag, |
|
const typename ScalarType<scalar_t>::FragB& frag_b, |
|
typename ScalarType<scalar_t>::FragC& frag_c) { |
|
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag); |
|
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b); |
|
float* c = reinterpret_cast<float*>(&frag_c); |
|
if constexpr (std::is_same<scalar_t, half>::value) { |
|
asm volatile( |
|
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " |
|
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" |
|
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) |
|
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), |
|
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); |
|
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) { |
|
asm volatile( |
|
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " |
|
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" |
|
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) |
|
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), |
|
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); |
|
} else { |
|
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); |
|
} |
|
} |
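
// Load a full 16x16 fragment of operand A from shared memory, directly in the
// layout expected by the tensor cores (ldmatrix).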
|
|
|
|
|
|
|
template <typename scalar_t> |
|
__device__ inline void ldsm4(typename ScalarType<scalar_t>::FragA& frag_a, |
|
const void* smem_ptr) { |
|
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a); |
|
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); |
|
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" |
|
: "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) |
|
: "r"(smem)); |
|
} |
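
// Dequantize four packed FP8 (E4M3) values from an int into a fragment of two
// half2 / nv_bfloat162 pairs; only the specializations below are usable.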
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename scalar_t> |
|
__device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit(int q) { |
|
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); |
|
} |
|
|
|
template <> |
|
__device__ inline typename ScalarType<half>::FragB dequant_8bit<half>(int q) { |
|
|
|
constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, FP16_EXPONENT = 5; |
|
constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT; |
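
  // Build the mask (0x7f007f00) selecting the exponent and mantissa bits of
  // the FP8 values that sit in the upper byte of each 16-bit lane.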
|
|
|
|
|
constexpr int MASK1 = 0x80000000; |
|
constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA); |
|
constexpr int MASK3 = MASK2 & 0x7fffffff; |
|
constexpr int MASK = MASK3 | (MASK3 >> 16); |
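
  // Keep the sign bits in place and shift the exponent/mantissa bits into the
  // FP16 positions; Out1 takes the FP8 value in the upper byte of each 16-bit
  // lane, Out2 the one in the lower byte (promoted by << 8).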
|
|
|
|
|
|
|
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); |
|
int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT); |
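
  // Exponent bias correction: multiply by 2^(FP16 bias 15 - FP8 bias 7) = 2^8.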
|
|
|
|
|
constexpr int BIAS_OFFSET = |
|
(1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); |
|
const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); |
|
|
|
|
|
typename ScalarType<half>::FragB frag_b; |
|
|
|
frag_b[1] = __hmul2(*reinterpret_cast<const half2*>(&Out1), bias_reg); |
|
frag_b[0] = __hmul2(*reinterpret_cast<const half2*>(&Out2), bias_reg); |
|
return frag_b; |
|
} |
|
|
|
template <> |
|
__device__ inline typename ScalarType<nv_bfloat16>::FragB |
|
dequant_8bit<nv_bfloat16>(int q) { |
|
|
|
constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, BF16_EXPONENT = 8; |
|
constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; |
|
|
|
|
|
constexpr int MASK1 = 0x80000000; |
|
constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA); |
|
constexpr int MASK3 = MASK2 & 0x7fffffff; |
|
constexpr int MASK = MASK3 | (MASK3 >> 16); |
|
|
|
|
|
|
|
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); |
|
int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT); |
|
|
|
|
|
constexpr int BIAS_OFFSET = |
|
(1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); |
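
  // BIAS_OFFSET = 127 - 7 = 120; since 2^120 cannot be formed with an integer
  // shift, the bias is constructed directly from its float bit pattern.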
|
|
|
|
|
constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; |
|
const nv_bfloat162 bias_reg = |
|
__float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS)); |
|
|
|
|
|
typename ScalarType<nv_bfloat16>::FragB frag_b; |
|
|
|
frag_b[1] = __hmul2(*reinterpret_cast<const nv_bfloat162*>(&Out1), bias_reg); |
|
frag_b[0] = __hmul2(*reinterpret_cast<const nv_bfloat162*>(&Out2), bias_reg); |
|
return frag_b; |
|
} |
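
// Multiply a dequantized B fragment by the corresponding quantization scale.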
|
|
|
|
|
|
|
template <typename scalar_t> |
|
__device__ inline void scale(typename ScalarType<scalar_t>::FragB& frag_b, |
|
typename ScalarType<scalar_t>::FragS& frag_s, |
|
int i) { |
|
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2; |
|
scalar_t2 s = |
|
ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_s)[i]); |
|
frag_b[0] = __hmul2(frag_b[0], s); |
|
frag_b[1] = __hmul2(frag_b[1], s); |
|
} |
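
// Multiply two fp32 accumulator values by two scales taken from a FragS.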
|
|
|
|
|
template <typename scalar_t> |
|
__device__ inline void scale_float(float* c, |
|
typename ScalarType<scalar_t>::FragS& s) { |
|
scalar_t* s_ptr = reinterpret_cast<scalar_t*>(&s); |
|
c[0] = __fmul_rn(c[0], ScalarType<scalar_t>::num2float(s_ptr[0])); |
|
c[1] = __fmul_rn(c[1], ScalarType<scalar_t>::num2float(s_ptr[1])); |
|
} |
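
// Spin on an acquire load of the global lock until it reaches count, so that
// prior writes of the releasing threadblocks are visible.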
|
|
|
|
|
__device__ inline void barrier_acquire(int* lock, int count) { |
|
if (threadIdx.x == 0) { |
|
int state = -1; |
|
do |
|
|
|
|
|
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" |
|
: "=r"(state) |
|
: "l"(lock)); |
|
while (state != count); |
|
} |
|
__syncthreads(); |
|
} |
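
// Release the barrier: either reset the counter or, after a release fence,
// atomically increment it so the next acquirer sees this block's writes.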
|
|
|
|
|
__device__ inline void barrier_release(int* lock, bool reset = false) { |
|
__syncthreads(); |
|
if (threadIdx.x == 0) { |
|
if (reset) { |
|
lock[0] = 0; |
|
return; |
|
} |
|
int val = 1; |
|
|
|
|
|
asm volatile("fence.acq_rel.gpu;\n"); |
|
asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" |
|
: |
|
: "l"(lock), "r"(val)); |
|
} |
|
} |
|
|
|
template <typename scalar_t,          // compute dtype, half or nv_bfloat16
          const int num_bits,         // number of bits used for the weights
          const int threads,          // number of threads in a threadblock
          const int thread_m_blocks,  // number of 16x16 blocks in the m
                                      // dimension (batchsize) of the
                                      // threadblock
          const int thread_n_blocks,  // same for the n dimension (output)
          const int thread_k_blocks,  // same for the k dimension (reduction)
          const int stages,           // number of stages for the async
                                      // global->shared fetch pipeline
          const int group_blocks = -1  // -1 means channelwise scales, the
                                       // only mode instantiated for FP8
          >
|
__global__ void Marlin(
    const int4* __restrict__ A,           // fp16/bf16 input matrix of shape mxk
    const int4* __restrict__ B,           // fp8 quantized weight matrix of shape kxn
    int4* __restrict__ C,                 // fp16/bf16 output buffer of shape mxn
    const int4* __restrict__ scales_ptr,  // fp16/bf16 per-channel scales of shape 1xn
    int num_groups,                       // number of scale groups per output channel
    int prob_m,                           // batch dimension m
    int prob_n,                           // output dimension n
    int prob_k,                           // reduction dimension k
    int* locks                            // extra global storage for barrier synchronization
|
) { |
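  // Each threadblock processes one "stripe" of the quantized weight matrix B.
  // A stripe may span multiple column "slices" of width 16 * thread_n_blocks;
  // when several blocks contribute partial results to the same slice, they
  // are combined through the locks-based global reduction further below.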
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
using Dtype = ScalarType<scalar_t>; |
|
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2; |
|
using FragA = typename ScalarType<scalar_t>::FragA; |
|
using FragB = typename ScalarType<scalar_t>::FragB; |
|
using FragC = typename ScalarType<scalar_t>::FragC; |
|
using FragS = typename ScalarType<scalar_t>::FragS; |
|
|
|
constexpr int pack_factor = 32 / num_bits; |
|
|
|
|
|
|
|
int parallel = 1; |
|
if (prob_m > 16 * thread_m_blocks) { |
|
parallel = prob_m / (16 * thread_m_blocks); |
|
prob_m = 16 * thread_m_blocks; |
|
} |
|
|
|
int k_tiles = prob_k / 16 / thread_k_blocks; |
|
int n_tiles = prob_n / 16 / thread_n_blocks; |
|
int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); |
|
|
|
int slice_row = (iters * blockIdx.x) % k_tiles; |
|
int slice_col_par = (iters * blockIdx.x) / k_tiles; |
|
int slice_col = slice_col_par; |
|
int slice_iters; |
|
int slice_count = |
|
0; |
|
int slice_idx; |
|
|
|
|
|
|
|
|
|
if (slice_col_par >= n_tiles) { |
|
A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; |
|
C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; |
|
locks += (slice_col_par / n_tiles) * n_tiles; |
|
slice_col = slice_col_par % n_tiles; |
|
} |
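
  // Compute everything this block needs to know about its current slice: how
  // many iterations it owns (slice_iters), how many blocks cooperate on the
  // slice (slice_count) and this block's position among them (slice_idx).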
|
|
|
|
|
|
|
auto init_slice = [&]() { |
|
slice_iters = |
|
iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); |
|
if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; |
|
if (slice_iters == 0) return; |
|
if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; |
|
slice_count = 1; |
|
slice_idx = 0; |
|
int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); |
|
if (col_first <= k_tiles * (slice_col_par + 1)) { |
|
int col_off = col_first - k_tiles * slice_col_par; |
|
slice_count = div_ceil(k_tiles - col_off, iters); |
|
if (col_off > 0) slice_count++; |
|
int delta_first = iters * blockIdx.x - col_first; |
|
if (delta_first < 0 || (col_off == 0 && delta_first == 0)) |
|
slice_idx = slice_count - 1; |
|
else { |
|
slice_idx = slice_count - 1 - delta_first / iters; |
|
if (col_off > 0) slice_idx--; |
|
} |
|
} |
|
if (slice_col == n_tiles) { |
|
A += 16 * thread_m_blocks * prob_k / 8; |
|
C += 16 * thread_m_blocks * prob_n / 8; |
|
locks += n_tiles; |
|
slice_col = 0; |
|
} |
|
}; |
|
init_slice(); |
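
  // Global / shared memory strides and deltas for the A, B and scale tiles,
  // all expressed in int4 (16-byte) units.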
|
|
|
|
|
|
|
|
|
int a_gl_stride = prob_k / 8; |
|
|
|
constexpr int a_sh_stride = 16 * thread_k_blocks / 8; |
|
|
|
constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; |
|
|
|
int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); |
|
|
|
constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); |
|
|
|
constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); |
|
|
|
constexpr int a_sh_rd_delta_i = a_sh_stride * 16; |
|
|
|
constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); |
|
|
|
constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); |
|
|
|
|
|
int b_gl_stride = 16 * prob_n / (pack_factor * 4); |
|
constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; |
|
constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2; |
|
constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; |
|
|
|
int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; |
|
int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); |
|
constexpr int b_sh_wr_delta = threads * b_thread_vecs; |
|
constexpr int b_sh_rd_delta = threads * b_thread_vecs; |
|
constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; |
|
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; |
|
|
|
|
|
int s_gl_stride = prob_n / 8; |
|
constexpr int s_sh_stride = 16 * thread_n_blocks / 8; |
|
|
|
|
|
constexpr int tb_k = 16 * thread_k_blocks; |
|
constexpr int g_idx_stage = 0; |
|
|
|
|
|
int act_s_col_stride = 1; |
|
int act_s_col_warp_stride = act_s_col_stride * 8; |
|
int tb_n_warps = thread_n_blocks / 4; |
|
int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; |
|
|
|
|
|
int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + |
|
(threadIdx.x % a_gl_rd_delta_o); |
|
a_gl_rd += a_gl_rd_delta_o * slice_row; |
|
|
|
int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + |
|
(threadIdx.x % a_gl_rd_delta_o); |
|
|
|
int a_sh_rd = |
|
a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; |
|
a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); |
|
|
|
int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + |
|
(threadIdx.x % b_sh_stride_threads) * b_thread_vecs; |
|
b_gl_rd += b_sh_stride * slice_col; |
|
b_gl_rd += b_gl_rd_delta_o * slice_row; |
|
int b_sh_wr = threadIdx.x * b_thread_vecs; |
|
int b_sh_rd = threadIdx.x * b_thread_vecs; |
|
|
|
|
|
int slice_k_start = tb_k * slice_row; |
|
int slice_k_start_shared_fetch = slice_k_start; |
|
int slice_n_offset = act_s_col_tb_stride * slice_col; |
|
|
|
|
|
int s_gl_rd = s_sh_stride * slice_col + threadIdx.x; |
|
int s_sh_wr = threadIdx.x; |
|
bool s_sh_wr_pred = threadIdx.x < s_sh_stride; |
|
|
|
|
|
int s_sh_rd = |
|
8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; |
|
|
|
|
|
|
|
|
|
bool a_sh_wr_pred[a_sh_wr_iters]; |
|
#pragma unroll |
|
for (int i = 0; i < a_sh_wr_iters; i++) |
|
a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; |
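
  // XOR-swizzle the column index of every A element with its row index so
  // that both the shared-memory writes and the later ldmatrix reads stay
  // bank-conflict free.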
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto transform_a = [&](int i) { |
|
int row = i / a_gl_rd_delta_o; |
|
return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; |
|
}; |
|
|
|
|
|
|
|
int a_sh_wr_trans[a_sh_wr_iters]; |
|
#pragma unroll |
|
for (int i = 0; i < a_sh_wr_iters; i++) |
|
a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); |
|
int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; |
|
#pragma unroll |
|
for (int i = 0; i < b_sh_wr_iters; i++) { |
|
#pragma unroll |
|
for (int j = 0; j < thread_m_blocks; j++) |
|
a_sh_rd_trans[i][j] = |
|
transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); |
|
} |
|
|
|
|
|
|
|
|
|
|
|
const int4* B_ptr[b_sh_wr_iters]; |
|
#pragma unroll |
|
for (int i = 0; i < b_sh_wr_iters; i++) |
|
B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; |
|
|
|
extern __shared__ int4 sh[]; |
|
|
|
int4* sh_a = sh; |
|
int4* sh_b = sh_a + (stages * a_sh_stage); |
|
int4* sh_g_idx = sh_b + (stages * b_sh_stage); |
|
int4* sh_s = sh_g_idx + (stages * g_idx_stage); |
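
  // Double-buffered register fragments for A and packed B, plus the fp32
  // accumulators and the scale fragments.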
|
|
|
|
|
FragA frag_a[2][thread_m_blocks]; |
|
I4 frag_b_quant[2][b_thread_vecs]; |
|
FragC frag_c[thread_m_blocks][4][2]; |
|
FragS frag_s[2][4]; |
|
|
|
|
|
auto zero_accums = [&]() { |
|
#pragma unroll |
|
for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) |
|
reinterpret_cast<float*>(frag_c)[i] = 0; |
|
}; |
|
|
|
int sh_first_group_id = -1; |
|
int sh_num_groups = -1; |
|
constexpr int sh_max_num_groups = 32; |
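
  // Copy a window of the channelwise scales from global into shared memory,
  // either asynchronously (cp.async) or synchronously.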
|
|
|
auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, |
|
int last_group_id) { |
|
sh_first_group_id = first_group_id; |
|
sh_num_groups = last_group_id - first_group_id + 1; |
|
|
|
if (sh_num_groups < sh_max_num_groups) { |
|
sh_num_groups = sh_max_num_groups; |
|
} |
|
|
|
if (sh_first_group_id + sh_num_groups > num_groups) { |
|
sh_num_groups = num_groups - sh_first_group_id; |
|
} |
|
|
|
int row_offset = first_group_id * s_gl_stride; |
|
|
|
if (is_async) { |
|
for (int i = 0; i < sh_num_groups; i++) { |
|
if (threadIdx.x < s_sh_stride) { |
|
cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], |
|
&scales_ptr[row_offset + (i * s_gl_stride) + |
|
slice_n_offset + threadIdx.x]); |
|
} |
|
} |
|
} else { |
|
for (int i = 0; i < sh_num_groups; i++) { |
|
if (threadIdx.x < s_sh_stride) { |
|
sh_s[(i * s_sh_stride) + threadIdx.x] = |
|
scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + |
|
threadIdx.x]; |
|
} |
|
} |
|
} |
|
}; |
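
  // Asynchronously fetch the A and B tiles for the given pipeline stage from
  // global into shared memory.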
|
|
|
|
|
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { |
|
if (pred) { |
|
int4* sh_a_stage = sh_a + a_sh_stage * pipe; |
|
#pragma unroll |
|
for (int i = 0; i < a_sh_wr_iters; i++) { |
|
cp_async4_pred( |
|
&sh_a_stage[a_sh_wr_trans[i]], |
|
&A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], |
|
a_sh_wr_pred[i]); |
|
} |
|
int4* sh_b_stage = sh_b + b_sh_stage * pipe; |
|
#pragma unroll |
|
for (int i = 0; i < b_sh_wr_iters; i++) { |
|
#pragma unroll |
|
for (int j = 0; j < b_thread_vecs; j++) { |
|
cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); |
|
} |
|
|
|
B_ptr[i] += b_gl_rd_delta_o; |
|
} |
|
} |
|
|
|
|
|
cp_async_fence(); |
|
}; |
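
  // Wait until the next tile has arrived in shared memory. Only stages - 2
  // async fetches are kept in flight: with double buffering, the next fetch
  // may only be issued once the previous shared-memory load has completed.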
|
|
|
|
|
auto wait_for_stage = [&]() { |
|
|
|
|
|
|
|
|
|
cp_async_wait<stages - 2>(); |
|
__syncthreads(); |
|
}; |
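
  // Load the A and quantized B sub-tiles for the given k step and pipeline
  // stage from shared memory into registers.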
|
|
|
|
|
|
|
auto fetch_to_registers = [&](int k, int pipe) { |
|
int4* sh_a_stage = sh_a + a_sh_stage * pipe; |
|
#pragma unroll |
|
for (int i = 0; i < thread_m_blocks; i++) |
|
ldsm4<scalar_t>(frag_a[k % 2][i], |
|
&sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); |
|
int4* sh_b_stage = sh_b + b_sh_stage * pipe; |
|
|
|
#pragma unroll |
|
for (int i = 0; i < b_thread_vecs; i++) { |
|
frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>( |
|
&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); |
|
} |
|
}; |
|
|
|
bool is_same_group[stages]; |
|
int same_group_id[stages]; |
|
|
|
auto init_same_group = [&](int pipe) { |
|
is_same_group[pipe] = false; |
|
same_group_id[pipe] = 0; |
|
return; |
|
}; |
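
  // Execute the tensor-core matmul for one sub-tile: dequantize the packed
  // FP8 weights on the fly, then issue an mma per m block.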
|
|
|
|
|
auto matmul = [&](int k) { |
|
|
|
|
|
#pragma unroll |
|
for (int j = 0; j < 4; j++) { |
|
FragB frag_b0; |
|
FragB frag_b1; |
|
|
|
int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]); |
|
int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; |
|
int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; |
|
|
|
frag_b0 = dequant_8bit<scalar_t>(b_quant_0); |
|
frag_b1 = dequant_8bit<scalar_t>(b_quant_1); |
|
|
|
#pragma unroll |
|
for (int i = 0; i < thread_m_blocks; i++) { |
|
mma<scalar_t>(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); |
|
mma<scalar_t>(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); |
|
} |
|
} |
|
}; |
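
  // Reduce, via shared memory, the partial sums of the warps that sliced the
  // same output tile along the k dimension.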
|
|
|
|
|
|
|
|
|
|
|
auto thread_block_reduce = [&]() { |
|
constexpr int red_off = threads / b_sh_stride_threads / 2; |
|
if (red_off >= 1) { |
|
int red_idx = threadIdx.x / b_sh_stride_threads; |
|
constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; |
|
constexpr int red_sh_delta = b_sh_stride_threads; |
|
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + |
|
(threadIdx.x % b_sh_stride_threads); |
|
|
|
|
|
|
|
|
|
|
|
#pragma unroll |
|
for (int m_block = 0; m_block < thread_m_blocks; m_block++) { |
|
#pragma unroll |
|
for (int i = red_off; i > 0; i /= 2) { |
|
if (i <= red_idx && red_idx < 2 * i) { |
|
#pragma unroll |
|
for (int j = 0; j < 4 * 2; j++) { |
|
int red_sh_wr = |
|
red_sh_delta * j + (red_sh_rd - red_sh_stride * i); |
|
if (i < red_off) { |
|
float* c_rd = |
|
reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]); |
|
float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]); |
|
#pragma unroll |
|
for (int k = 0; k < 4; k++) |
|
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] += |
|
c_rd[k] + c_wr[k]; |
|
} |
|
sh[red_sh_wr] = |
|
reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j]; |
|
} |
|
} |
|
__syncthreads(); |
|
} |
|
if (red_idx == 0) { |
|
#pragma unroll |
|
for (int i = 0; i < 4 * 2; i++) { |
|
float* c_rd = |
|
reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]); |
|
#pragma unroll |
|
for (int j = 0; j < 4; j++) |
|
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] += |
|
c_rd[j]; |
|
} |
|
} |
|
__syncthreads(); |
|
} |
|
} |
|
}; |
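
  // Serially reduce, directly in the output buffer C, the partial results of
  // the different threadblocks that worked on the same output slice.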
|
|
|
|
|
|
|
|
|
|
|
auto global_reduce = [&](bool first = false, bool last = false) { |
|
|
|
|
|
|
|
constexpr int active_threads = 32 * thread_n_blocks / 4; |
|
if (threadIdx.x < active_threads) { |
|
int c_gl_stride = prob_n / 8; |
|
int c_gl_wr_delta_o = 8 * c_gl_stride; |
|
int c_gl_wr_delta_i = 4 * (active_threads / 32); |
|
int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + |
|
4 * (threadIdx.x / 32) + threadIdx.x % 4; |
|
c_gl_wr += (2 * thread_n_blocks) * slice_col; |
|
constexpr int c_sh_wr_delta = active_threads; |
|
int c_sh_wr = threadIdx.x; |
|
|
|
int row = (threadIdx.x % 32) / 4; |
|
|
|
if (!first) { |
|
|
|
|
|
|
|
#pragma unroll |
|
for (int i = 0; i < thread_m_blocks * 4; i++) { |
|
cp_async4_pred( |
|
&sh[c_sh_wr + c_sh_wr_delta * i], |
|
&C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + |
|
c_gl_wr_delta_i * (i % 2)], |
|
i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); |
|
} |
|
cp_async_fence(); |
|
cp_async_wait<0>(); |
|
} |
|
|
|
#pragma unroll |
|
for (int i = 0; i < thread_m_blocks * 4; i++) { |
|
if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { |
|
if (!first) { |
|
int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; |
|
#pragma unroll |
|
for (int j = 0; j < 2 * 4; j++) { |
|
reinterpret_cast<float*>( |
|
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += |
|
Dtype::num2float(reinterpret_cast<scalar_t*>(&c_red)[j]); |
|
} |
|
} |
|
if (!last) { |
|
int4 c; |
|
#pragma unroll |
|
for (int j = 0; j < 2 * 4; j++) { |
|
reinterpret_cast<scalar_t*>(&c)[j] = |
|
Dtype::float2num(reinterpret_cast<float*>( |
|
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); |
|
} |
|
C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = |
|
c; |
|
} |
|
} |
|
} |
|
} |
|
}; |
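
  // Write the final result to global memory in the correct layout; fragments
  // are first reshuffled through shared memory.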
|
|
|
|
|
|
|
|
|
auto write_result = [&]() { |
|
int c_gl_stride = prob_n / 8; |
|
constexpr int c_sh_stride = 2 * thread_n_blocks + 1; |
|
int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); |
|
constexpr int c_sh_rd_delta = |
|
c_sh_stride * (threads / (2 * thread_n_blocks)); |
|
|
|
int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + |
|
(threadIdx.x % (2 * thread_n_blocks)); |
|
c_gl_wr += (2 * thread_n_blocks) * slice_col; |
|
int c_sh_wr = |
|
(4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; |
|
c_sh_wr += 32 * (threadIdx.x / 32); |
|
int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + |
|
(threadIdx.x % (2 * thread_n_blocks)); |
|
|
|
int c_gl_wr_end = c_gl_stride * prob_m; |
|
|
|
|
|
|
|
auto write = [&](int idx, float c0, float c1, FragS& s) { |
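      // The channelwise scales were already applied to the fp32 accumulators,
      // so the FragS argument is unused here.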
|
scalar_t2 res = |
|
Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); |
|
|
|
((scalar_t2*)sh)[idx] = res; |
|
}; |
|
|
|
if (threadIdx.x / 32 < thread_n_blocks / 4) { |
|
#pragma unroll |
|
for (int i = 0; i < thread_m_blocks; i++) { |
|
#pragma unroll |
|
for (int j = 0; j < 4; j++) { |
|
int wr = c_sh_wr + 8 * j; |
|
write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], |
|
frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); |
|
write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], |
|
frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); |
|
write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], |
|
frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); |
|
write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], |
|
frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); |
|
} |
|
c_sh_wr += 16 * (4 * c_sh_stride); |
|
} |
|
} |
|
__syncthreads(); |
|
|
|
#pragma unroll |
|
for (int i = 0; |
|
i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); |
|
i++) { |
|
if (c_gl_wr < c_gl_wr_end) { |
|
C[c_gl_wr] = sh[c_sh_rd]; |
|
c_gl_wr += c_gl_wr_delta; |
|
c_sh_rd += c_sh_rd_delta; |
|
} |
|
} |
|
}; |
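
  // Prime the pipeline: issue the first stages - 1 global->shared fetches and
  // preload the first register tile.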
|
|
|
|
|
auto start_pipes = [&]() { |
|
|
|
#pragma unroll |
|
for (int i = 0; i < stages - 1; i++) { |
|
fetch_to_shared(i, i, i < slice_iters); |
|
} |
|
|
|
zero_accums(); |
|
wait_for_stage(); |
|
init_same_group(0); |
|
fetch_to_registers(0, 0); |
|
a_gl_rd += a_gl_rd_delta_o * (stages - 1); |
|
slice_k_start_shared_fetch += tb_k * (stages - 1); |
|
}; |
|
if (slice_iters) { |
|
start_pipes(); |
|
} |
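
  // Main loop: iterate over the k tiles of this block's current slice. The
  // pipeline is fully unrolled over both the global fetch and the register
  // load stages so that all shared memory accesses are static.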
|
|
|
|
|
while (slice_iters) { |
|
|
|
|
|
|
|
|
|
|
|
#pragma unroll |
|
for (int pipe = 0; pipe < stages;) { |
|
#pragma unroll |
|
for (int k = 0; k < b_sh_wr_iters; k++) { |
|
fetch_to_registers(k + 1, pipe % stages); |
|
if (k == b_sh_wr_iters - 2) { |
|
fetch_to_shared((pipe + stages - 1) % stages, pipe, |
|
slice_iters >= stages); |
|
pipe++; |
|
wait_for_stage(); |
|
init_same_group(pipe % stages); |
|
} |
|
matmul(k); |
|
} |
|
slice_iters--; |
|
if (slice_iters == 0) { |
|
break; |
|
} |
|
} |
|
|
|
a_gl_rd += a_gl_rd_delta_o * stages; |
|
slice_k_start += tb_k * stages; |
|
slice_k_start_shared_fetch += tb_k * stages; |
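
    // Slice finished: reduce the partial results and write them out before
    // moving on to the next column slice.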
|
|
|
|
|
|
|
|
|
if (slice_iters == 0) { |
|
cp_async_wait<0>(); |
|
bool last = slice_idx == slice_count - 1; |
|
|
|
|
|
if (s_sh_wr_pred) { |
|
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); |
|
} |
|
cp_async_fence(); |
|
|
|
thread_block_reduce(); |
|
|
|
cp_async_wait<0>(); |
|
__syncthreads(); |
|
if (threadIdx.x / 32 < thread_n_blocks / 4) { |
|
reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0]; |
|
reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4]; |
|
} |
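
      // Apply the channelwise scales to the fp32 accumulators before they are
      // converted to fp16/bf16 for the global reduction and write-out.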
|
|
|
|
|
|
|
|
|
if (threadIdx.x / 32 < thread_n_blocks / 4) { |
|
#pragma unroll |
|
for (int i = 0; i < thread_m_blocks; i++) { |
|
#pragma unroll |
|
for (int j = 0; j < 4; j++) { |
|
scale_float<scalar_t>(reinterpret_cast<float*>(&frag_c[i][j][0][0]), |
|
frag_s[j / 2][2 * (j % 2) + 0]); |
|
scale_float<scalar_t>(reinterpret_cast<float*>(&frag_c[i][j][0][2]), |
|
frag_s[j / 2][2 * (j % 2) + 0]); |
|
|
|
scale_float<scalar_t>(reinterpret_cast<float*>(&frag_c[i][j][1][0]), |
|
frag_s[j / 2][2 * (j % 2) + 1]); |
|
scale_float<scalar_t>(reinterpret_cast<float*>(&frag_c[i][j][1][2]), |
|
frag_s[j / 2][2 * (j % 2) + 1]); |
|
} |
|
} |
|
} |
|
|
|
if (slice_count > 1) { |
|
|
|
barrier_acquire(&locks[slice_col], slice_idx); |
|
global_reduce(slice_idx == 0, last); |
|
barrier_release(&locks[slice_col], last); |
|
} |
|
if (last) |
|
write_result(); |
|
slice_row = 0; |
|
slice_col_par++; |
|
slice_col++; |
|
init_slice(); |
|
if (slice_iters) { |
|
a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + |
|
(threadIdx.x % a_gl_rd_delta_o); |
|
#pragma unroll |
|
for (int i = 0; i < b_sh_wr_iters; i++) |
|
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; |
|
if (slice_col == 0) { |
|
#pragma unroll |
|
for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; |
|
} |
|
|
|
|
|
s_gl_rd = s_sh_stride * slice_col + threadIdx.x; |
|
|
|
start_pipes(); |
|
} |
|
} |
|
} |
|
} |
|
|
|
#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS,                 \
                  THREAD_K_BLOCKS, GROUP_BLOCKS, NUM_THREADS)                 \
  else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS &&      \
           thread_n_blocks == THREAD_N_BLOCKS &&                              \
           thread_k_blocks == THREAD_K_BLOCKS &&                              \
           group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) {      \
    cudaFuncSetAttribute(                                                     \
        Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS,              \
               THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, GROUP_BLOCKS>,  \
        cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);         \
    Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS,                  \
           THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, GROUP_BLOCKS>       \
        <<<blocks, NUM_THREADS, max_shared_mem, stream>>>(                    \
            A_ptr, B_ptr, C_ptr, s_ptr, num_groups, prob_m, prob_n, prob_k,   \
            locks);                                                           \
  }
|
|
|
typedef struct { |
|
int thread_k; |
|
int thread_n; |
|
int num_threads; |
|
} thread_config_t; |
|
|
|
typedef struct { |
|
int max_m_blocks; |
|
thread_config_t tb_cfg; |
|
} exec_config_t; |
|
|
|
thread_config_t small_batch_thread_configs[] = { |
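    // Ordered by priority; entries are {thread_k, thread_n, num_threads}.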
|
|
|
|
|
|
|
{128, 128, 256}, |
|
{64, 128, 128}, |
|
{128, 64, 128}, |
|
}; |
|
|
|
thread_config_t large_batch_thread_configs[] = { |
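    // Ordered by priority; entries are {thread_k, thread_n, num_threads}.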
|
|
|
|
|
|
|
{64, 256, 256}, |
|
{64, 128, 128}, |
|
{128, 64, 128}, |
|
|
|
}; |
|
|
|
int get_scales_cache_size(thread_config_t const& th_config, int prob_m, |
|
int prob_n, int prob_k, int num_bits, |
|
int group_size) { |
|
int tb_n = th_config.thread_n; |
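
  // FP8 scales are channelwise, so each threadblock only ever caches a single
  // scale group (2 bytes per output column) per pipeline stage.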
|
|
|
|
|
|
|
int tb_groups = 1; |
|
int tb_scales = tb_groups * tb_n * 2; |
|
|
|
return tb_scales * pipe_stages; |
|
} |
|
|
|
bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, |
|
int prob_m, int prob_n, int prob_k, int num_bits, |
|
int scales_cache_size, int max_shared_mem) { |
|
int pack_factor = 32 / num_bits; |
|
|
|
|
|
int tb_k = th_config.thread_k; |
|
int tb_n = th_config.thread_n; |
|
|
|
int b_size = (tb_k * tb_n / pack_factor) * 4; |
|
|
|
|
|
int m_blocks = div_ceil(prob_m, 16); |
|
int tb_max_m = 16; |
|
|
|
while (true) { |
|
if (m_blocks >= max_m_blocks) { |
|
tb_max_m *= max_m_blocks; |
|
break; |
|
} |
|
|
|
max_m_blocks--; |
|
if (max_m_blocks == 0) { |
|
TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks); |
|
} |
|
} |
|
|
|
int a_size = (tb_max_m * tb_k) * 2; |
|
|
|
float pipe_size = (a_size + b_size) * pipe_stages; |
|
|
|
TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); |
|
|
|
return pipe_size < 0.95f * (max_shared_mem - scales_cache_size); |
|
} |
|
|
|
bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, |
|
int prob_m, int prob_n, int prob_k, int num_bits, |
|
int group_size, int max_shared_mem) { |
|
|
|
if (th_config.thread_k == -1 || th_config.thread_n == -1 || |
|
th_config.num_threads == -1) { |
|
return false; |
|
} |
|
|
|
|
|
if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { |
|
return false; |
|
} |
|
|
|
|
|
if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { |
|
return false; |
|
} |
|
|
|
|
|
if (th_config.num_threads < 128) { |
|
return false; |
|
} |
|
|
|
|
|
int scales_cache_size = get_scales_cache_size(th_config, prob_m, prob_n, |
|
prob_k, num_bits, group_size); |
|
|
|
|
|
if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k, |
|
num_bits, scales_cache_size, max_shared_mem)) { |
|
return false; |
|
} |
|
|
|
return true; |
|
} |
|
|
|
exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, |
|
int num_bits, int group_size, |
|
int max_shared_mem) { |
|
int max_m_blocks = 4; |
|
while (max_m_blocks > 0) { |
|
if (prob_m <= 16) { |
|
for (auto th_config : small_batch_thread_configs) { |
|
if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, |
|
num_bits, group_size, max_shared_mem)) { |
|
return exec_config_t{max_m_blocks, th_config}; |
|
} |
|
} |
|
} else { |
|
for (auto th_config : large_batch_thread_configs) { |
|
if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, |
|
num_bits, group_size, max_shared_mem)) { |
|
return exec_config_t{max_m_blocks, th_config}; |
|
} |
|
} |
|
} |
|
|
|
max_m_blocks--; |
|
|
|
} |
|
|
|
return exec_config_t{0, {-1, -1, -1}}; |
|
} |
|
|
|
#define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS)    \
  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS)
|
|
|
template <typename scalar_t> |
|
void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, int prob_m, |
|
int prob_n, int prob_k, void* workspace, int num_bits, |
|
int num_groups, int group_size, int dev, |
|
cudaStream_t stream, int thread_k, int thread_n, int sms, |
|
int max_par) { |
|
TORCH_CHECK(num_bits == 8, "num_bits must be 8. Got = ", num_bits); |
|
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, |
|
", ", prob_n, ", ", prob_k, "]"); |
|
|
|
int tot_m = prob_m; |
|
int tot_m_blocks = div_ceil(tot_m, 16); |
|
int pad = 16 * tot_m_blocks - tot_m; |
|
|
|
if (sms == -1) { |
|
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); |
|
} |
|
|
|
int max_shared_mem = 0; |
|
cudaDeviceGetAttribute(&max_shared_mem, |
|
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); |
|
TORCH_CHECK(max_shared_mem > 0); |
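
  // Pick the thread-block configuration: honor user-supplied thread_k /
  // thread_n, otherwise auto-select the first config that fits in shared
  // memory.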
|
|
|
|
|
exec_config_t exec_cfg; |
|
if (thread_k != -1 && thread_n != -1) { |
|
|
|
exec_cfg = |
|
exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}}; |
|
} else { |
|
|
|
exec_cfg = determine_thread_config(prob_m, prob_n, prob_k, num_bits, |
|
group_size, max_shared_mem); |
|
} |
|
|
|
TORCH_CHECK( |
|
exec_cfg.max_m_blocks > 0 && |
|
is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, prob_m, |
|
prob_n, prob_k, num_bits, group_size, max_shared_mem), |
|
"Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks, |
|
", thread_k = ", exec_cfg.tb_cfg.thread_k, |
|
", thread_n = ", exec_cfg.tb_cfg.thread_n, |
|
", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", prob_m, |
|
", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, |
|
", group_size = ", group_size, ", max_shared_mem = ", max_shared_mem); |
|
|
|
int num_threads = exec_cfg.tb_cfg.num_threads; |
|
thread_k = exec_cfg.tb_cfg.thread_k; |
|
thread_n = exec_cfg.tb_cfg.thread_n; |
|
|
|
int thread_k_blocks = thread_k / 16; |
|
int thread_n_blocks = thread_n / 16; |
|
|
|
int blocks = sms; |
|
|
|
TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, |
|
" is not divisible by thread_n = ", thread_n); |
|
TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, |
|
" is not divisible by thread_k = ", thread_k); |
|
|
|
int group_blocks = -1; |
|
|
|
const int4* A_ptr = (const int4*)A; |
|
const int4* B_ptr = (const int4*)B; |
|
int4* C_ptr = (int4*)C; |
|
const int4* s_ptr = (const int4*)s; |
|
|
|
int* locks = (int*)workspace; |
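
  // Launch the kernel once per chunk of up to 16 * max_m_blocks rows; larger
  // batches are additionally folded into `par` parallel problems per launch,
  // capped at max_par.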
|
|
|
|
|
for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) { |
|
int thread_m_blocks = tot_m_blocks - i; |
|
prob_m = tot_m - 16 * i; |
|
int par = 1; |
|
if (thread_m_blocks > exec_cfg.max_m_blocks) { |
|
|
|
|
|
par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks); |
|
if (par > max_par) par = max_par; |
|
prob_m = (16 * exec_cfg.max_m_blocks) * par; |
|
i += exec_cfg.max_m_blocks * (par - 1); |
|
thread_m_blocks = exec_cfg.max_m_blocks; |
|
} |
|
|
|
|
|
if (false) { |
|
} |
|
CALL_IF(8, 32, 2, 256) |
|
CALL_IF(8, 16, 4, 256) |
|
CALL_IF(8, 8, 8, 256) |
|
CALL_IF(8, 8, 4, 128) |
|
CALL_IF(8, 4, 8, 128) |
|
else { |
|
TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + |
|
str(prob_n) + ", " + str(prob_k) + "]" + |
|
", num_groups = " + str(num_groups) + |
|
", group_size = " + str(group_size) + |
|
", thread_m_blocks = " + str(thread_m_blocks) + |
|
", thread_n_blocks = " + str(thread_n_blocks) + |
|
", thread_k_blocks = " + str(thread_k_blocks)); |
|
} |
|
|
|
A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; |
|
C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; |
|
} |
|
} |
|
|
|
}  // namespace fp8_marlin
|
|
|
torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, |
|
torch::Tensor& b_scales, torch::Tensor& workspace, |
|
int64_t num_bits, int64_t size_m, int64_t size_n, |
|
int64_t size_k) { |
|
|
|
TORCH_CHECK(num_bits == 8, "num_bits must be 8. Got = ", num_bits); |
|
int pack_factor = 32 / num_bits; |
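
  // Verify the shapes of A and of the packed, tile-laid-out B weight.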
|
|
|
|
|
TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), |
|
", size_m = ", size_m); |
|
TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1), |
|
", size_k = ", size_k); |
|
|
|
|
|
TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k, |
|
" is not divisible by tile_size = ", marlin::tile_size); |
|
TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0), |
|
"Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), |
|
", size_k = ", size_k, ", tile_size = ", marlin::tile_size); |
|
TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0, |
|
"b_q_weight.size(1) = ", b_q_weight.size(1), |
|
" is not divisible by tile_size = ", marlin::tile_size); |
|
int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor; |
|
TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n, |
|
", actual_size_n = ", actual_size_n); |
|
|
|
|
|
TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); |
|
TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); |
|
|
|
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); |
|
TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); |
|
|
|
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); |
|
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); |
|
|
|
|
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); |
|
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); |
|
torch::Tensor c = torch::empty({size_m, size_n}, options); |
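
  // thread_k / thread_n are the thread-tile sizes in the weight matrix and
  // sms is the number of SMs to use; -1 lets the launcher choose automatically.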
|
|
|
|
|
|
|
int thread_k = -1; |
|
|
|
|
|
int thread_n = -1; |
|
|
|
int sms = -1; |
|
|
|
|
|
int num_groups = -1; |
|
int group_size = -1; |
|
|
|
int b_rank = b_scales.sizes().size(); |
|
TORCH_CHECK(b_rank == 2, "b_scales rank = ", b_rank, " is not 2"); |
|
TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1), |
|
" is not size_n = ", size_n); |
|
|
|
  // FP8 Marlin only supports channelwise scales, i.e. a single scale group.
  TORCH_CHECK(b_scales.size(0) == 1);
|
num_groups = b_scales.size(0); |
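
  // Verify the workspace can hold the per-slice lock counters used for the
  // global reduction.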
|
|
|
|
|
TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n, |
|
", is not divisible by min_thread_n = ", marlin::min_thread_n); |
|
int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par; |
|
TORCH_CHECK(workspace.numel() >= min_workspace_size, |
|
"workspace.numel = ", workspace.numel(), |
|
" is below min_workspace_size = ", min_workspace_size); |
|
|
|
int dev = a.get_device(); |
|
if (a.scalar_type() == at::ScalarType::Half) { |
|
fp8_marlin::marlin_mm_f16i4<half>( |
|
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(), |
|
b_scales.data_ptr<at::Half>(), size_m, size_n, size_k, |
|
workspace.data_ptr(), num_bits, num_groups, group_size, dev, |
|
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, |
|
marlin::max_par); |
|
} else if (a.scalar_type() == at::ScalarType::BFloat16) { |
|
fp8_marlin::marlin_mm_f16i4<nv_bfloat16>( |
|
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(), |
|
c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(), size_m, |
|
size_n, size_k, workspace.data_ptr(), num_bits, num_groups, group_size, |
|
dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, |
|
marlin::max_par); |
|
} else { |
|
TORCH_CHECK(false, "fp8_marlin_gemm only supports bfloat16 and float16"); |
|
} |
|
|
|
return c; |
|
} |
|
|
|
#endif |
|
|
|
|