Add GPTQ-Marlin
- build.toml +14 -0
- ext-torch/torch_binding.cpp +26 -0
- ext-torch/torch_binding.h +28 -0
- gptq_marlin/awq_marlin_repack.cu +258 -0
- gptq_marlin/gptq_marlin.cu +2423 -0
- gptq_marlin/gptq_marlin_repack.cu +333 -0
build.toml
CHANGED
@@ -5,6 +5,7 @@ version = "0.0.1"
 name = "quantization"
 src = [
   "core/registration.h",
+  "core/scalar_type.hpp",
   "ext-torch/torch_binding.cpp",
   "ext-torch/torch_binding.h"
 ]
@@ -69,3 +70,16 @@ src = [
 ]
 include = [ "." ]
 depends = [ "torch" ]
+
+[kernel.gptq_marlin]
+capabilities = [ "8.0", "8.6", "8.7", "8.9", "9.0", "9.0a" ]
+src = [
+  "core/scalar_type.hpp",
+  "gptq_marlin/awq_marlin_repack.cu",
+  "gptq_marlin/gptq_marlin.cu",
+  "gptq_marlin/gptq_marlin_repack.cu",
+  "gptq_marlin/marlin.cuh",
+  "gptq_marlin/marlin_dtypes.cuh"
+]
+include = [ "." ]
+depends = [ "torch" ]
ext-torch/torch_binding.cpp
CHANGED
@@ -66,6 +66,32 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "Tensor! workspace, int num_bits, SymInt size_m, SymInt size_n, "
       "SymInt size_k) -> Tensor");
   ops.impl("fp8_marlin_gemm", &fp8_marlin_gemm);
+
+  // awq_marlin repack from AWQ.
+  ops.def(
+      "awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
+      "SymInt size_n, int num_bits) -> Tensor");
+  ops.impl("awq_marlin_repack", &awq_marlin_repack);
+
+  // gptq_marlin Optimized Quantized GEMM for GPTQ.
+  ops.impl("gptq_marlin_gemm", &gptq_marlin_gemm);
+  ops.def(
+      "gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
+      "Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
+      "int b_q_type, "
+      "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
+      "bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
+
+  // gptq_marlin repack from GPTQ.
+  ops.def(
+      "gptq_marlin_repack(Tensor b_q_weight, Tensor perm, "
+      "SymInt size_k, SymInt size_n, int num_bits) -> Tensor");
+  ops.impl("gptq_marlin_repack", &gptq_marlin_repack);
+}
+
+TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, ops) {
+  ops.impl("awq_marlin_repack", &awq_marlin_repack_meta);
+  ops.impl("gptq_marlin_repack", &gptq_marlin_repack_meta);
 }
 
 REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
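These `ops.def`/`ops.impl` calls are what expose the new kernels to Python through `torch.ops`. Below is a minimal sketch of driving the GEMM schema once the extension is loaded; the `torch.ops.quantization` namespace, the use of empty tensors for the unused `b_zeros`/`g_idx`/`perm` arguments, and the wrapper itself are assumptions for illustration, not part of this commit.

```python
import torch

# Assumed namespace: whatever TORCH_EXTENSION_NAME expands to for this build
# (build.toml names the extension "quantization"); adjust if the built module
# registers a different name.
ops = torch.ops.quantization

def marlin_gemm_fp16(a, marlin_qweight, scales, workspace,
                     b_q_type_id, size_n, size_k):
    """Hypothetical wrapper: 4-bit weights, no act_order, no zero-points."""
    empty_i32 = torch.empty(0, dtype=torch.int32, device=a.device)
    return ops.gptq_marlin_gemm(
        a, marlin_qweight, scales,
        empty_i32,       # b_zeros (unused when has_zp=False)
        empty_i32,       # g_idx   (unused without act_order)
        empty_i32,       # perm    (unused without act_order)
        workspace,
        b_q_type_id,     # integer ScalarTypeId of the weight type (e.g. kU4B8)
        a.shape[0],      # size_m
        size_n, size_k,
        True,            # is_k_full
        False,           # has_zp
        True,            # use_fp32_reduce
        False,           # is_zp_float
    )
```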
ext-torch/torch_binding.h
CHANGED
@@ -2,6 +2,8 @@
 
 #include <torch/torch.h>
 
+#include <core/scalar_type.hpp>
+
 bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
 
 void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
@@ -46,3 +48,29 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                               torch::Tensor& b_scales, torch::Tensor& workspace,
                               int64_t num_bits, int64_t size_m, int64_t size_n,
                               int64_t size_k);
+
+// GPTQ-Marlin
+
+torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
+                                int64_t size_n, int64_t num_bits);
+
+torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
+                                     c10::SymInt size_k, c10::SymInt size_n,
+                                     int64_t num_bits);
+
+torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
+                               torch::Tensor& b_scales, torch::Tensor& b_zeros,
+                               torch::Tensor& g_idx, torch::Tensor& perm,
+                               torch::Tensor& workspace,
+                               vllm::ScalarTypeId const& b_q_type_id,
+                               int64_t size_m, int64_t size_n, int64_t size_k,
+                               bool is_k_full, bool has_zp,
+                               bool use_fp32_reduce, bool is_zp_float);
+
+torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
+                                 int64_t size_k, int64_t size_n,
+                                 int64_t num_bits);
+
+torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
+                                      torch::Tensor& perm, c10::SymInt size_k,
+                                      c10::SymInt size_n, int64_t num_bits);
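The `_meta` declarations pair with the `Meta`-key registrations in torch_binding.cpp: they only compute output shapes, so shape propagation (for example under `torch.compile`, or on `meta` tensors) works without touching a GPU. A small sketch of that behaviour, assuming the extension loads under the `quantization` namespace and that `marlin::tile_size` is 16:

```python
import torch

ops = torch.ops.quantization  # assumed TORCH_EXTENSION_NAME; adjust as needed

# A 4-bit AWQ weight packs 8 values per int32, so its shape is (size_k, size_n // 8).
w = torch.zeros(4096, 512, dtype=torch.int32, device="meta")

# Dispatches to awq_marlin_repack_meta via the Meta key: no kernel launch,
# just the {size_k / tile_size, size_n * tile_size / pack_factor} shape math.
out = ops.awq_marlin_repack(w, 4096, 4096, 4)
print(out.shape)  # expected: torch.Size([256, 8192])
```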
gptq_marlin/awq_marlin_repack.cu
ADDED
@@ -0,0 +1,258 @@
#include "marlin.cuh"

namespace marlin {

template <int const num_threads, int const num_bits>
__global__ void awq_marlin_repack_kernel(
    uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
    int size_k, int size_n) {
  constexpr int pack_factor = 32 / num_bits;

  int k_tiles = size_k / tile_k_size;
  int n_tiles = size_n / tile_n_size;
  int block_k_tiles = div_ceil(k_tiles, gridDim.x);

  int start_k_tile = blockIdx.x * block_k_tiles;
  if (start_k_tile >= k_tiles) {
    return;
  }

  int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);

  // Wait until the next thread tile has been loaded to shared memory.
  auto wait_for_stage = [&]() {
    // We only have `stages - 2` active fetches since we are double buffering
    // and can only issue the next fetch when it is guaranteed that the previous
    // shared memory load is fully complete (as it may otherwise be
    // overwritten).
    cp_async_wait<repack_stages - 2>();
    __syncthreads();
  };

  extern __shared__ int4 sh[];

  constexpr int tile_n_ints = tile_n_size / pack_factor;

  constexpr int stage_n_threads = tile_n_ints / 4;
  constexpr int stage_k_threads = tile_k_size;
  constexpr int stage_size = stage_k_threads * stage_n_threads;

  auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      cp_async_fence();
      return;
    }

    int first_n = n_tile_id * tile_n_size;
    int first_n_packed = first_n / pack_factor;

    int4* sh_ptr = sh + stage_size * pipe;

    if (threadIdx.x < stage_size) {
      int k_id = threadIdx.x / stage_n_threads;
      int n_id = threadIdx.x % stage_n_threads;

      int first_k = k_tile_id * tile_k_size;

      cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
                reinterpret_cast<int4 const*>(
                    &(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) +
                                     first_n_packed + (n_id * 4)])));
    }

    cp_async_fence();
  };

  auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      return;
    }

    int warp_id = threadIdx.x / 32;
    int th_id = threadIdx.x % 32;

    if (warp_id >= 4) {
      return;
    }

    int tc_col = th_id / 4;
    int tc_row = (th_id % 4) * 2;

    constexpr int tc_offsets[4] = {0, 1, 8, 9};

    int cur_n = warp_id * 16 + tc_col;
    int cur_n_packed = cur_n / pack_factor;
    int cur_n_pos = cur_n % pack_factor;

    constexpr int sh_stride = tile_n_ints;
    constexpr uint32_t mask = (1 << num_bits) - 1;

    int4* sh_stage_ptr = sh + stage_size * pipe;
    uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);

    // Undo interleaving
    int cur_n_pos_unpacked;
    if constexpr (num_bits == 4) {
      constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7};
      cur_n_pos_unpacked = undo_pack[cur_n_pos];
    } else {
      constexpr int undo_pack[4] = {0, 2, 1, 3};
      cur_n_pos_unpacked = undo_pack[cur_n_pos];
    }

    uint32_t vals[8];
  #pragma unroll
    for (int i = 0; i < 4; i++) {
      int cur_elem = tc_row + tc_offsets[i];

      int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem];
      int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) +
                                          sh_stride * cur_elem];

      vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
      vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
    }

    constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
    int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;

    // Result of:
    // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
    if constexpr (num_bits == 4) {
      constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};

      uint32_t res = 0;
  #pragma unroll
      for (int i = 0; i < 8; i++) {
        res |= vals[pack_idx[i]] << (i * 4);
      }

      out_ptr[out_offset + th_id * 4 + warp_id] = res;

    } else {
      constexpr int pack_idx[4] = {0, 2, 1, 3};

      uint32_t res1 = 0;
      uint32_t res2 = 0;
  #pragma unroll
      for (int i = 0; i < 4; i++) {
        res1 |= vals[pack_idx[i]] << (i * 8);
        res2 |= vals[4 + pack_idx[i]] << (i * 8);
      }

      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
    }
  };

  auto start_pipes = [&](int k_tile_id, int n_tile_id) {
  #pragma unroll
    for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
      fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
    }

    wait_for_stage();
  };
  #pragma unroll
  for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
    int n_tile_id = 0;

    start_pipes(k_tile_id, n_tile_id);

    while (n_tile_id < n_tiles) {
  #pragma unroll
      for (int pipe = 0; pipe < repack_stages; pipe++) {
        fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
                        n_tile_id + pipe + repack_stages - 1);
        repack_tile(pipe, k_tile_id, n_tile_id + pipe);
        wait_for_stage();
      }
      n_tile_id += repack_stages;
    }
  }
}

}  // namespace marlin

#define CALL_IF(NUM_BITS)                                                     \
  else if (num_bits == NUM_BITS) {                                            \
    cudaFuncSetAttribute(                                                     \
        marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>,   \
        cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);         \
    marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>        \
        <<<blocks, marlin::repack_threads, max_shared_mem, stream>>>(         \
            b_q_weight_ptr, out_ptr, size_k, size_n);                         \
  }

torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
                                int64_t size_n, int64_t num_bits) {
  // Verify compatibility with marlin tile of 16x64
  TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
              " is not divisible by tile_k_size = ", marlin::tile_k_size);
  TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n,
              " is not divisible by tile_n_size = ", marlin::tile_n_size);

  TORCH_CHECK(num_bits == 4 || num_bits == 8,
              "num_bits must be 4 or 8. Got = ", num_bits);
  int const pack_factor = 32 / num_bits;

  // Verify B
  TORCH_CHECK(b_q_weight.size(0) == size_k,
              "b_q_weight.size(0) = ", b_q_weight.size(0),
              " is not size_k = ", size_k);
  TORCH_CHECK((size_n / pack_factor) == b_q_weight.size(1),
              "Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1),
              ", size_n = ", size_n, ", pack_factor = ", pack_factor);

  // Verify device and strides
  TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
  TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");

  // Alloc buffers
  const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
  auto options = torch::TensorOptions()
                     .dtype(b_q_weight.dtype())
                     .device(b_q_weight.device());
  torch::Tensor out = torch::empty(
      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
      options);

  // Get ptrs
  uint32_t const* b_q_weight_ptr =
      reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
  uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());

  // Get dev info
  int dev = b_q_weight.get_device();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
  int blocks;
  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);

  int max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  TORCH_CHECK(max_shared_mem > 0);

  if (false) {
  }
  CALL_IF(4)
  CALL_IF(8)
  else {
    TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits);
  }

  return out;
}

torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
                                     c10::SymInt size_k, c10::SymInt size_n,
                                     int64_t num_bits) {
  int const pack_factor = 32 / num_bits;
  auto options = torch::TensorOptions()
                     .dtype(b_q_weight.dtype())
                     .device(b_q_weight.device());
  return torch::empty_symint(
      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
      options);
}
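One detail worth calling out in the kernel above is the `undo_pack` table, which inverts the nibble interleaving of the source format before Marlin re-packs the values. As a quick sanity check (a sketch; the AWQ nibble order `[0, 2, 4, 6, 1, 3, 5, 7]` is taken from the upstream AWQ packing convention and is an assumption here, not something shown in this diff):

```python
# Assumed AWQ convention: nibble i of each packed int32 holds the weight from
# logical column awq_pack_order[i]. The kernel's undo_pack maps a logical
# column back to the nibble that stores it, i.e. the inverse permutation.
awq_pack_order = [0, 2, 4, 6, 1, 3, 5, 7]   # assumed upstream order
undo_pack_4bit = [0, 4, 1, 5, 2, 6, 3, 7]   # from awq_marlin_repack_kernel

for logical_col, nibble in enumerate(undo_pack_4bit):
    assert awq_pack_order[nibble] == logical_col
print("undo_pack is the inverse of the assumed AWQ nibble order")
```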
gptq_marlin/gptq_marlin.cu
ADDED
@@ -0,0 +1,2423 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Modified by Neural Magic
|
3 |
+
* Copyright (C) Marlin.2024 Elias Frantar
|
4 |
+
*
|
5 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
* you may not use this file except in compliance with the License.
|
7 |
+
* You may obtain a copy of the License at
|
8 |
+
*
|
9 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
*
|
11 |
+
* Unless required by applicable law or agreed to in writing, software
|
12 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
* See the License for the specific language governing permissions and
|
15 |
+
* limitations under the License.
|
16 |
+
*/
|
17 |
+
|
18 |
+
/*
|
19 |
+
* Adapted from https://github.com/IST-DASLab/marlin
|
20 |
+
*/
|
21 |
+
|
22 |
+
#include "marlin.cuh"
|
23 |
+
#include "marlin_dtypes.cuh"
|
24 |
+
#include "core/scalar_type.hpp"
|
25 |
+
|
26 |
+
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
27 |
+
static_assert(std::is_same<scalar_t, half>::value || \
|
28 |
+
std::is_same<scalar_t, nv_bfloat16>::value, \
|
29 |
+
"only float16 and bfloat16 is supported");
|
30 |
+
|
31 |
+
template <typename T>
|
32 |
+
inline std::string str(T x) {
|
33 |
+
return std::to_string(x);
|
34 |
+
}
|
35 |
+
|
36 |
+
namespace marlin {
|
37 |
+
|
38 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
39 |
+
|
40 |
+
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
|
41 |
+
int const* __restrict__ perm_int_ptr,
|
42 |
+
int4* __restrict__ out_int4_ptr, int size_m,
|
43 |
+
int size_k, int block_rows) {}
|
44 |
+
|
45 |
+
template <typename scalar_t, // compute dtype, half or nv_float16
|
46 |
+
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
47 |
+
const int threads, // number of threads in a threadblock
|
48 |
+
const int thread_m_blocks, // number of 16x16 blocks in the m
|
49 |
+
// dimension (batchsize) of the
|
50 |
+
// threadblock
|
51 |
+
const int thread_n_blocks, // same for n dimension (output)
|
52 |
+
const int thread_k_blocks, // same for k dimension (reduction)
|
53 |
+
const int stages, // number of stages for the async global->shared
|
54 |
+
// fetch pipeline
|
55 |
+
const bool has_act_order, // whether act_order is enabled
|
56 |
+
const int group_blocks = -1, // number of consecutive 16x16 blocks
|
57 |
+
// with a separate quantization scale
|
58 |
+
const bool is_zp_float // is zero point of float16 type?
|
59 |
+
>
|
60 |
+
__global__ void Marlin(
|
61 |
+
const int4* __restrict__ A, // fp16 input matrix of shape mxk
|
62 |
+
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
|
63 |
+
int4* __restrict__ C, // fp16 output buffer of shape mxn
|
64 |
+
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
|
65 |
+
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
|
66 |
+
// (k/groupsize)xn
|
67 |
+
const int* __restrict__ g_idx, // int32 group indices of shape k
|
68 |
+
int num_groups, // number of scale groups per output channel
|
69 |
+
int prob_m, // batch dimension m
|
70 |
+
int prob_n, // output dimension n
|
71 |
+
int prob_k, // reduction dimension k
|
72 |
+
int* locks, // extra global storage for barrier synchronization
|
73 |
+
bool use_fp32_reduce // whether to use fp32 global reduce
|
74 |
+
) {}
|
75 |
+
|
76 |
+
} // namespace marlin
|
77 |
+
|
78 |
+
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
79 |
+
torch::Tensor& b_scales, torch::Tensor& b_zeros,
|
80 |
+
torch::Tensor& g_idx, torch::Tensor& perm,
|
81 |
+
torch::Tensor& workspace,
|
82 |
+
vllm::ScalarTypeId const b_q_type_id,
|
83 |
+
int64_t size_m, int64_t size_n, int64_t size_k,
|
84 |
+
bool is_k_full, bool has_zp, bool is_zp_float) {
|
85 |
+
TORCH_CHECK_NOT_IMPLEMENTED(false,
|
86 |
+
"marlin_gemm(..) requires CUDA_ARCH >= 8.0");
|
87 |
+
return torch::empty({1, 1});
|
88 |
+
}
|
89 |
+
|
90 |
+
#else
|
91 |
+
|
92 |
+
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
|
93 |
+
// output/accumulation.
|
94 |
+
template <typename scalar_t>
|
95 |
+
__device__ inline void mma(const typename ScalarType<scalar_t>::FragA& a_frag,
|
96 |
+
const typename ScalarType<scalar_t>::FragB& frag_b,
|
97 |
+
typename ScalarType<scalar_t>::FragC& frag_c) {
|
98 |
+
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
|
99 |
+
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
100 |
+
float* c = reinterpret_cast<float*>(&frag_c);
|
101 |
+
if constexpr (std::is_same<scalar_t, half>::value) {
|
102 |
+
asm volatile(
|
103 |
+
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
104 |
+
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
105 |
+
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
106 |
+
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
107 |
+
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
108 |
+
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
109 |
+
asm volatile(
|
110 |
+
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
111 |
+
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
112 |
+
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
113 |
+
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
114 |
+
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
115 |
+
} else {
|
116 |
+
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
|
117 |
+
}
|
118 |
+
}
|
119 |
+
|
120 |
+
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
|
121 |
+
// memory, directly in tensor core layout.
|
122 |
+
template <typename scalar_t>
|
123 |
+
__device__ inline void ldsm4(typename ScalarType<scalar_t>::FragA& frag_a,
|
124 |
+
const void* smem_ptr) {
|
125 |
+
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
|
126 |
+
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
127 |
+
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
|
128 |
+
: "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
|
129 |
+
: "r"(smem));
|
130 |
+
}
|
131 |
+
|
132 |
+
// Lookup-table based 3-input logical operation; explicitly used for
|
133 |
+
// dequantization as the compiler does not seem to automatically recognize it in
|
134 |
+
// all cases.
|
135 |
+
template <int lut>
|
136 |
+
__device__ inline int lop3(int a, int b, int c) {
|
137 |
+
int res;
|
138 |
+
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
139 |
+
: "=r"(res)
|
140 |
+
: "r"(a), "r"(b), "r"(c), "n"(lut));
|
141 |
+
return res;
|
142 |
+
}
|
143 |
+
|
144 |
+
// Constructs destination register by taking bytes from 2 sources (based on
|
145 |
+
// mask)
|
146 |
+
template <int start_byte, int mask>
|
147 |
+
__device__ inline uint32_t prmt(uint32_t a) {
|
148 |
+
uint32_t res;
|
149 |
+
asm volatile("prmt.b32 %0, %1, %2, %3;\n"
|
150 |
+
: "=r"(res)
|
151 |
+
: "r"(a), "n"(start_byte), "n"(mask));
|
152 |
+
return res;
|
153 |
+
}
|
154 |
+
|
155 |
+
template <typename scalar_t, vllm::ScalarTypeId w_type_id>
|
156 |
+
__device__ inline typename ScalarType<scalar_t>::FragB dequant(int q);
|
157 |
+
|
158 |
+
//
|
159 |
+
// Efficiently dequantize 4bit values packed in an int32 value into a full
|
160 |
+
// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
|
161 |
+
// with some small changes:
|
162 |
+
// - FP16:
|
163 |
+
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
|
164 |
+
// - BF16:
|
165 |
+
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
|
166 |
+
//
|
167 |
+
template <>
|
168 |
+
__device__ inline typename ScalarType<half>::FragB
|
169 |
+
dequant<half, vllm::kU4B8.id()>(int q) {
|
170 |
+
const int LO = 0x000f000f;
|
171 |
+
const int HI = 0x00f000f0;
|
172 |
+
const int EX = 0x64006400;
|
173 |
+
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
174 |
+
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
175 |
+
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
176 |
+
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
|
177 |
+
// directly into `SUB` and `ADD`.
|
178 |
+
const int SUB = 0x64086408;
|
179 |
+
const int MUL = 0x2c002c00;
|
180 |
+
const int ADD = 0xd480d480;
|
181 |
+
typename ScalarType<half>::FragB frag_b;
|
182 |
+
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
183 |
+
*reinterpret_cast<const half2*>(&SUB));
|
184 |
+
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
|
185 |
+
*reinterpret_cast<const half2*>(&MUL),
|
186 |
+
*reinterpret_cast<const half2*>(&ADD));
|
187 |
+
return frag_b;
|
188 |
+
}
|
189 |
+
|
190 |
+
template <>
|
191 |
+
__device__ inline typename ScalarType<nv_bfloat16>::FragB
|
192 |
+
dequant<nv_bfloat16, vllm::kU4B8.id()>(int q) {
|
193 |
+
static constexpr uint32_t MASK = 0x000f000f;
|
194 |
+
static constexpr uint32_t EX = 0x43004300;
|
195 |
+
|
196 |
+
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
197 |
+
|
198 |
+
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
199 |
+
q >>= 4;
|
200 |
+
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
201 |
+
|
202 |
+
typename ScalarType<nv_bfloat16>::FragB frag_b;
|
203 |
+
static constexpr uint32_t MUL = 0x3F803F80;
|
204 |
+
static constexpr uint32_t ADD = 0xC308C308;
|
205 |
+
|
206 |
+
frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
|
207 |
+
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
208 |
+
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
209 |
+
frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
|
210 |
+
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
211 |
+
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
212 |
+
return frag_b;
|
213 |
+
}
|
214 |
+
|
215 |
+
template <>
|
216 |
+
__device__ inline typename ScalarType<half>::FragB
|
217 |
+
dequant<half, vllm::kU4.id()>(int q) {
|
218 |
+
const int LO = 0x000f000f;
|
219 |
+
const int HI = 0x00f000f0;
|
220 |
+
const int EX = 0x64006400;
|
221 |
+
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
222 |
+
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
223 |
+
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
224 |
+
|
225 |
+
const int SUB = 0x64006400;
|
226 |
+
const int MUL = 0x2c002c00;
|
227 |
+
const int ADD = 0xd400d400;
|
228 |
+
typename ScalarType<half>::FragB frag_b;
|
229 |
+
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
230 |
+
*reinterpret_cast<const half2*>(&SUB));
|
231 |
+
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
|
232 |
+
*reinterpret_cast<const half2*>(&MUL),
|
233 |
+
*reinterpret_cast<const half2*>(&ADD));
|
234 |
+
return frag_b;
|
235 |
+
}
|
236 |
+
|
237 |
+
template <>
|
238 |
+
__device__ inline typename ScalarType<nv_bfloat16>::FragB
|
239 |
+
dequant<nv_bfloat16, vllm::kU4.id()>(int q) {
|
240 |
+
static constexpr uint32_t MASK = 0x000f000f;
|
241 |
+
static constexpr uint32_t EX = 0x43004300;
|
242 |
+
|
243 |
+
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
244 |
+
|
245 |
+
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
246 |
+
q >>= 4;
|
247 |
+
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
248 |
+
|
249 |
+
typename ScalarType<nv_bfloat16>::FragB frag_b;
|
250 |
+
static constexpr uint32_t MUL = 0x3F803F80;
|
251 |
+
static constexpr uint32_t ADD = 0xC300C300;
|
252 |
+
|
253 |
+
frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
|
254 |
+
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
255 |
+
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
256 |
+
frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
|
257 |
+
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
258 |
+
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
259 |
+
return frag_b;
|
260 |
+
}
|
261 |
+
|
262 |
+
//
|
263 |
+
// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
|
264 |
+
// bf16 Reference:
|
265 |
+
// - FP16:
|
266 |
+
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
|
267 |
+
// - BF16:
|
268 |
+
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
|
269 |
+
//
|
270 |
+
template <>
|
271 |
+
__device__ inline typename ScalarType<half>::FragB
|
272 |
+
dequant<half, vllm::kU8B128.id()>(int q) {
|
273 |
+
static constexpr uint32_t mask_for_elt_01 = 0x5250;
|
274 |
+
static constexpr uint32_t mask_for_elt_23 = 0x5351;
|
275 |
+
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
|
276 |
+
|
277 |
+
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
|
278 |
+
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
|
279 |
+
|
280 |
+
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
|
281 |
+
|
282 |
+
typename ScalarType<half>::FragB frag_b;
|
283 |
+
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
284 |
+
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
285 |
+
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
|
286 |
+
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
287 |
+
return frag_b;
|
288 |
+
}
|
289 |
+
|
290 |
+
template <>
|
291 |
+
__device__ inline typename ScalarType<nv_bfloat16>::FragB
|
292 |
+
dequant<nv_bfloat16, vllm::kU8B128.id()>(int q) {
|
293 |
+
typename ScalarType<nv_bfloat16>::FragB frag_b;
|
294 |
+
|
295 |
+
float fp32_intermediates[4];
|
296 |
+
uint32_t* fp32_intermediates_casted =
|
297 |
+
reinterpret_cast<uint32_t*>(fp32_intermediates);
|
298 |
+
|
299 |
+
static constexpr uint32_t fp32_base = 0x4B000000;
|
300 |
+
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
|
301 |
+
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
|
302 |
+
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
|
303 |
+
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
|
304 |
+
|
305 |
+
fp32_intermediates[0] -= 8388736.f;
|
306 |
+
fp32_intermediates[1] -= 8388736.f;
|
307 |
+
fp32_intermediates[2] -= 8388736.f;
|
308 |
+
fp32_intermediates[3] -= 8388736.f;
|
309 |
+
|
310 |
+
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b);
|
311 |
+
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
|
312 |
+
fp32_intermediates_casted[1], 0x7632);
|
313 |
+
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
|
314 |
+
fp32_intermediates_casted[3], 0x7632);
|
315 |
+
|
316 |
+
return frag_b;
|
317 |
+
}
|
318 |
+
|
319 |
+
template <>
|
320 |
+
__device__ inline typename ScalarType<half>::FragB
|
321 |
+
dequant<half, vllm::kU8.id()>(int q) {
|
322 |
+
static constexpr uint32_t mask_for_elt_01 = 0x5250;
|
323 |
+
static constexpr uint32_t mask_for_elt_23 = 0x5351;
|
324 |
+
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
|
325 |
+
|
326 |
+
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
|
327 |
+
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
|
328 |
+
|
329 |
+
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
|
330 |
+
|
331 |
+
typename ScalarType<half>::FragB frag_b;
|
332 |
+
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
333 |
+
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
334 |
+
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
|
335 |
+
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
336 |
+
return frag_b;
|
337 |
+
}
|
338 |
+
|
339 |
+
template <>
|
340 |
+
__device__ inline typename ScalarType<nv_bfloat16>::FragB
|
341 |
+
dequant<nv_bfloat16, vllm::kU8.id()>(int q) {
|
342 |
+
typename ScalarType<nv_bfloat16>::FragB frag_b;
|
343 |
+
|
344 |
+
float fp32_intermediates[4];
|
345 |
+
uint32_t* fp32_intermediates_casted =
|
346 |
+
reinterpret_cast<uint32_t*>(fp32_intermediates);
|
347 |
+
|
348 |
+
static constexpr uint32_t fp32_base = 0x4B000000;
|
349 |
+
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
|
350 |
+
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
|
351 |
+
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
|
352 |
+
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
|
353 |
+
|
354 |
+
fp32_intermediates[0] -= 8388608.f;
|
355 |
+
fp32_intermediates[1] -= 8388608.f;
|
356 |
+
fp32_intermediates[2] -= 8388608.f;
|
357 |
+
fp32_intermediates[3] -= 8388608.f;
|
358 |
+
|
359 |
+
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b);
|
360 |
+
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
|
361 |
+
fp32_intermediates_casted[1], 0x7632);
|
362 |
+
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
|
363 |
+
fp32_intermediates_casted[3], 0x7632);
|
364 |
+
|
365 |
+
return frag_b;
|
366 |
+
}
|
367 |
+
|
368 |
+
// Multiply dequantized values by the corresponding quantization scale; used
|
369 |
+
// only for grouped quantization.
|
370 |
+
template <typename scalar_t>
|
371 |
+
__device__ inline void scale(typename ScalarType<scalar_t>::FragB& frag_b,
|
372 |
+
typename ScalarType<scalar_t>::FragS& frag_s,
|
373 |
+
int i) {
|
374 |
+
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
|
375 |
+
scalar_t2 s =
|
376 |
+
ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_s)[i]);
|
377 |
+
frag_b[0] = __hmul2(frag_b[0], s);
|
378 |
+
frag_b[1] = __hmul2(frag_b[1], s);
|
379 |
+
}
|
380 |
+
|
381 |
+
template <typename scalar_t>
|
382 |
+
__device__ inline void sub_zp(typename ScalarType<scalar_t>::FragB& frag_b,
|
383 |
+
typename ScalarType<scalar_t>::scalar_t2& frag_zp,
|
384 |
+
int i) {
|
385 |
+
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
|
386 |
+
scalar_t2 zp =
|
387 |
+
ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_zp)[i]);
|
388 |
+
frag_b[0] = __hsub2(frag_b[0], zp);
|
389 |
+
frag_b[1] = __hsub2(frag_b[1], zp);
|
390 |
+
}
|
391 |
+
|
392 |
+
// Same as above, but for act_order (each K is multiplied individually)
|
393 |
+
template <typename scalar_t>
|
394 |
+
__device__ inline void scale4(typename ScalarType<scalar_t>::FragB& frag_b,
|
395 |
+
typename ScalarType<scalar_t>::FragS& frag_s_1,
|
396 |
+
typename ScalarType<scalar_t>::FragS& frag_s_2,
|
397 |
+
typename ScalarType<scalar_t>::FragS& frag_s_3,
|
398 |
+
typename ScalarType<scalar_t>::FragS& frag_s_4,
|
399 |
+
int i) {
|
400 |
+
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
|
401 |
+
scalar_t2 s_val_1_2;
|
402 |
+
s_val_1_2.x = reinterpret_cast<scalar_t*>(&frag_s_1)[i];
|
403 |
+
s_val_1_2.y = reinterpret_cast<scalar_t*>(&frag_s_2)[i];
|
404 |
+
|
405 |
+
scalar_t2 s_val_3_4;
|
406 |
+
s_val_3_4.x = reinterpret_cast<scalar_t*>(&frag_s_3)[i];
|
407 |
+
s_val_3_4.y = reinterpret_cast<scalar_t*>(&frag_s_4)[i];
|
408 |
+
|
409 |
+
frag_b[0] = __hmul2(frag_b[0], s_val_1_2);
|
410 |
+
frag_b[1] = __hmul2(frag_b[1], s_val_3_4);
|
411 |
+
}
|
412 |
+
|
413 |
+
// Given 2 floats multiply by 2 scales (halves)
|
414 |
+
template <typename scalar_t>
|
415 |
+
__device__ inline void scale_float(float* c,
|
416 |
+
typename ScalarType<scalar_t>::FragS& s) {
|
417 |
+
scalar_t* s_ptr = reinterpret_cast<scalar_t*>(&s);
|
418 |
+
c[0] = __fmul_rn(c[0], ScalarType<scalar_t>::num2float(s_ptr[0]));
|
419 |
+
c[1] = __fmul_rn(c[1], ScalarType<scalar_t>::num2float(s_ptr[1]));
|
420 |
+
}
|
421 |
+
|
422 |
+
// Wait until barrier reaches `count`, then lock for current threadblock.
|
423 |
+
__device__ inline void barrier_acquire(int* lock, int count) {
|
424 |
+
if (threadIdx.x == 0) {
|
425 |
+
int state = -1;
|
426 |
+
do
|
427 |
+
// Guarantee that subsequent writes by this threadblock will be visible
|
428 |
+
// globally.
|
429 |
+
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
|
430 |
+
: "=r"(state)
|
431 |
+
: "l"(lock));
|
432 |
+
while (state != count);
|
433 |
+
}
|
434 |
+
__syncthreads();
|
435 |
+
}
|
436 |
+
|
437 |
+
// Release barrier and increment visitation count.
|
438 |
+
__device__ inline void barrier_release(int* lock, bool reset = false) {
|
439 |
+
__syncthreads();
|
440 |
+
if (threadIdx.x == 0) {
|
441 |
+
if (reset) {
|
442 |
+
lock[0] = 0;
|
443 |
+
return;
|
444 |
+
}
|
445 |
+
int val = 1;
|
446 |
+
// Make sure that all writes since acquiring this barrier are visible
|
447 |
+
// globally, while releasing the barrier.
|
448 |
+
asm volatile("fence.acq_rel.gpu;\n");
|
449 |
+
asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
|
450 |
+
:
|
451 |
+
: "l"(lock), "r"(val));
|
452 |
+
}
|
453 |
+
}
|
454 |
+
|
455 |
+
// For a given "a" of size [M,K] performs a permutation of the K columns based
|
456 |
+
// on the given "perm" indices.
|
457 |
+
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
|
458 |
+
int const* __restrict__ perm_int_ptr,
|
459 |
+
int4* __restrict__ out_int4_ptr, int size_m,
|
460 |
+
int size_k, int block_rows) {
|
461 |
+
int start_row = block_rows * blockIdx.x;
|
462 |
+
int finish_row = start_row + block_rows;
|
463 |
+
if (finish_row > size_m) {
|
464 |
+
finish_row = size_m;
|
465 |
+
}
|
466 |
+
int cur_block_rows = finish_row - start_row;
|
467 |
+
|
468 |
+
int row_stride = size_k * sizeof(half) / 16;
|
469 |
+
|
470 |
+
auto permute_row = [&](int row) {
|
471 |
+
int iters = size_k / default_threads;
|
472 |
+
int rest = size_k % default_threads;
|
473 |
+
|
474 |
+
int offset = row * row_stride;
|
475 |
+
|
476 |
+
half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + offset);
|
477 |
+
half* out_half = reinterpret_cast<half*>(out_int4_ptr + offset);
|
478 |
+
|
479 |
+
int base_k = 0;
|
480 |
+
|
481 |
+
for (int i = 0; i < iters; i++) {
|
482 |
+
int cur_k = base_k + threadIdx.x;
|
483 |
+
int src_pos = perm_int_ptr[cur_k];
|
484 |
+
|
485 |
+
out_half[cur_k] = a_row_half[src_pos];
|
486 |
+
|
487 |
+
base_k += default_threads;
|
488 |
+
}
|
489 |
+
|
490 |
+
if (rest) {
|
491 |
+
if (threadIdx.x < rest) {
|
492 |
+
int cur_k = base_k + threadIdx.x;
|
493 |
+
int src_pos = perm_int_ptr[cur_k];
|
494 |
+
|
495 |
+
out_half[cur_k] = a_row_half[src_pos];
|
496 |
+
}
|
497 |
+
}
|
498 |
+
};
|
499 |
+
|
500 |
+
for (int i = 0; i < cur_block_rows; i++) {
|
501 |
+
int cur_row = start_row + i;
|
502 |
+
if (cur_row < size_m) {
|
503 |
+
permute_row(cur_row);
|
504 |
+
}
|
505 |
+
}
|
506 |
+
}
|
507 |
+
|
508 |
+
template <typename scalar_t, // compute dtype, half or nv_float16
|
509 |
+
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
510 |
+
const int threads, // number of threads in a threadblock
|
511 |
+
const int thread_m_blocks, // number of 16x16 blocks in the m
|
512 |
+
// dimension (batchsize) of the
|
513 |
+
// threadblock
|
514 |
+
const int thread_n_blocks, // same for n dimension (output)
|
515 |
+
const int thread_k_blocks, // same for k dimension (reduction)
|
516 |
+
const int stages, // number of stages for the async global->shared
|
517 |
+
// fetch pipeline
|
518 |
+
const bool has_act_order, // whether act_order is enabled
|
519 |
+
const bool has_zp, // whether zero-points are enabled
|
520 |
+
const int group_blocks = -1, // number of consecutive 16x16 blocks
|
521 |
+
// with a separate quantization scale
|
522 |
+
const bool is_zp_float // is zero point of float16 type?
|
523 |
+
>
|
524 |
+
__global__ void Marlin(
|
525 |
+
const int4* __restrict__ A, // fp16 input matrix of shape mxk
|
526 |
+
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
|
527 |
+
int4* __restrict__ C, // fp16 output buffer of shape mxn
|
528 |
+
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
|
529 |
+
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
|
530 |
+
// (k/groupsize)xn
|
531 |
+
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
|
532 |
+
// (k/groupsize)x(n/pack_factor)
|
533 |
+
const int* __restrict__ g_idx, // int32 group indices of shape k
|
534 |
+
int num_groups, // number of scale groups per output channel
|
535 |
+
int prob_m, // batch dimension m
|
536 |
+
int prob_n, // output dimension n
|
537 |
+
int prob_k, // reduction dimension k
|
538 |
+
int* locks, // extra global storage for barrier synchronization
|
539 |
+
bool use_fp32_reduce // whether to use fp32 global reduce
|
540 |
+
) {
|
541 |
+
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
|
542 |
+
// same size, which might involve multiple column "slices" (of width 16 *
|
543 |
+
// `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
|
544 |
+
// example:
|
545 |
+
// 0 1 3
|
546 |
+
// 0 2 3
|
547 |
+
// 1 2 4
|
548 |
+
// While this kind of partitioning makes things somewhat more complicated, it
|
549 |
+
// ensures good utilization of all SMs for many kinds of shape and GPU
|
550 |
+
// configurations, while requiring as few slow global cross-threadblock
|
551 |
+
// reductions as possible.
|
552 |
+
using Dtype = ScalarType<scalar_t>;
|
553 |
+
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
|
554 |
+
using FragA = typename ScalarType<scalar_t>::FragA;
|
555 |
+
using FragB = typename ScalarType<scalar_t>::FragB;
|
556 |
+
using FragC = typename ScalarType<scalar_t>::FragC;
|
557 |
+
using FragS = typename ScalarType<scalar_t>::FragS;
|
558 |
+
using FragZP = typename ScalarType<scalar_t>::FragZP;
|
559 |
+
|
560 |
+
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
|
561 |
+
|
562 |
+
constexpr int pack_factor = 32 / w_type.size_bits();
|
563 |
+
|
564 |
+
// For larger GEMMs we run multiple batchsize 64 versions in parallel for a
|
565 |
+
// better partitioning with less reductions
|
566 |
+
int parallel = 1;
|
567 |
+
if (prob_m > 16 * thread_m_blocks) {
|
568 |
+
parallel = prob_m / (16 * thread_m_blocks);
|
569 |
+
prob_m = 16 * thread_m_blocks;
|
570 |
+
}
|
571 |
+
|
572 |
+
int k_tiles = prob_k / 16 / thread_k_blocks;
|
573 |
+
int n_tiles = prob_n / 16 / thread_n_blocks;
|
574 |
+
int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x);
|
575 |
+
|
576 |
+
if constexpr (!has_act_order && group_blocks != -1) {
|
577 |
+
if (group_blocks >= thread_k_blocks) {
|
578 |
+
// Ensure that the number of tiles in each stripe is a multiple of the
|
579 |
+
// groupsize; this avoids an annoying special case where a stripe starts
|
580 |
+
// in the middle of group.
|
581 |
+
iters = (group_blocks / thread_k_blocks) *
|
582 |
+
div_ceil(iters, (group_blocks / thread_k_blocks));
|
583 |
+
}
|
584 |
+
}
|
585 |
+
|
586 |
+
int slice_row = (iters * blockIdx.x) % k_tiles;
|
587 |
+
int slice_col_par = (iters * blockIdx.x) / k_tiles;
|
588 |
+
int slice_col = slice_col_par;
|
589 |
+
int slice_iters; // number of threadblock tiles in the current slice
|
590 |
+
int slice_count =
|
591 |
+
0; // total number of active threadblocks in the current slice
|
592 |
+
int slice_idx; // index of threadblock in current slice; numbered bottom to
|
593 |
+
// top
|
594 |
+
|
595 |
+
int par_id = 0;
|
596 |
+
|
597 |
+
// We can easily implement parallel problem execution by just remapping
|
598 |
+
// indices and advancing global pointers
|
599 |
+
if (slice_col_par >= n_tiles) {
|
600 |
+
A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;
|
601 |
+
C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
|
602 |
+
locks += (slice_col_par / n_tiles) * n_tiles;
|
603 |
+
slice_col = slice_col_par % n_tiles;
|
604 |
+
par_id = slice_col_par / n_tiles;
|
605 |
+
}
|
606 |
+
|
607 |
+
// Compute all information about the current slice which is required for
|
608 |
+
// synchronization.
|
609 |
+
auto init_slice = [&]() {
|
610 |
+
slice_iters =
|
611 |
+
iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
|
612 |
+
if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
|
613 |
+
if (slice_iters == 0) return;
|
614 |
+
if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
|
615 |
+
slice_count = 1;
|
616 |
+
slice_idx = 0;
|
617 |
+
int col_first = iters * div_ceil(k_tiles * slice_col_par, iters);
|
618 |
+
if (col_first <= k_tiles * (slice_col_par + 1)) {
|
619 |
+
int col_off = col_first - k_tiles * slice_col_par;
|
620 |
+
slice_count = div_ceil(k_tiles - col_off, iters);
|
621 |
+
if (col_off > 0) slice_count++;
|
622 |
+
int delta_first = iters * blockIdx.x - col_first;
|
623 |
+
if (delta_first < 0 || (col_off == 0 && delta_first == 0))
|
624 |
+
slice_idx = slice_count - 1;
|
625 |
+
else {
|
626 |
+
slice_idx = slice_count - 1 - delta_first / iters;
|
627 |
+
if (col_off > 0) slice_idx--;
|
628 |
+
}
|
629 |
+
}
|
630 |
+
if (slice_col == n_tiles) {
|
631 |
+
A += 16 * thread_m_blocks * prob_k / 8;
|
632 |
+
C += 16 * thread_m_blocks * prob_n / 8;
|
633 |
+
locks += n_tiles;
|
634 |
+
slice_col = 0;
|
635 |
+
par_id++;
|
636 |
+
}
|
637 |
+
};
|
638 |
+
init_slice();
|
639 |
+
|
640 |
+
// A sizes/strides
|
641 |
+
|
642 |
+
// stride of the A matrix in global memory
|
643 |
+
int a_gl_stride = prob_k / 8;
|
644 |
+
// stride of an A matrix tile in shared memory
|
645 |
+
constexpr int a_sh_stride = 16 * thread_k_blocks / 8;
|
646 |
+
// delta between subsequent A tiles in global memory
|
647 |
+
constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8;
|
648 |
+
// between subsequent accesses within a tile
|
649 |
+
int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o);
|
650 |
+
// between shared memory writes
|
651 |
+
constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o);
|
652 |
+
// between shared memory tile reads
|
653 |
+
constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4));
|
654 |
+
// within a shared memory tile
|
655 |
+
constexpr int a_sh_rd_delta_i = a_sh_stride * 16;
|
656 |
+
// overall size of a tile
|
657 |
+
constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks);
|
658 |
+
// number of shared write iterations for a tile
|
659 |
+
constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta);
|
660 |
+
|
661 |
+
// B sizes/strides
|
662 |
+
int b_gl_stride = 16 * prob_n / (pack_factor * 4);
|
663 |
+
constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;
|
664 |
+
constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2;
|
665 |
+
constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;
|
666 |
+
|
667 |
+
int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
|
668 |
+
int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads);
|
669 |
+
constexpr int b_sh_wr_delta = threads * b_thread_vecs;
|
670 |
+
constexpr int b_sh_rd_delta = threads * b_thread_vecs;
|
671 |
+
constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
|
672 |
+
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
|
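For concreteness, the B-tile strides above can be evaluated by hand. The sketch below plugs in one assumed configuration (4-bit weights, thread_n = 256, thread_k = 64, 256 threads, prob_n = 4096); the numbers are illustrative only and do not describe any particular launch.

// Illustrative sketch only: concrete values of the B strides for an assumed
// 4-bit configuration (pack_factor = 32 / 4 = 8, b_thread_vecs = 1).
#include <cstdio>

int main() {
  int prob_n = 4096, pack_factor = 8, threads = 256;
  int thread_n_blocks = 16, thread_k_blocks = 4, b_thread_vecs = 1;
  int b_gl_stride = 16 * prob_n / (pack_factor * 4);              // int4 per packed row
  int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;
  int b_sh_stage = b_sh_stride * thread_k_blocks;
  int b_sh_wr_iters = b_sh_stage / (threads * b_thread_vecs);
  printf("b_gl_stride=%d b_sh_stride=%d b_sh_stage=%d b_sh_wr_iters=%d\n",
         b_gl_stride, b_sh_stride, b_sh_stage, b_sh_wr_iters);
  return 0;
}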
673 |
+
|
674 |
+
// Scale sizes/strides without act_order
|
675 |
+
int s_gl_stride = prob_n / 8;
|
676 |
+
constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
|
677 |
+
constexpr int s_tb_groups =
|
678 |
+
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
|
679 |
+
? thread_k_blocks / group_blocks
|
680 |
+
: 1;
|
681 |
+
constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
|
682 |
+
int s_gl_rd_delta = s_gl_stride;
|
683 |
+
|
684 |
+
// Scale size/strides with act_order
|
685 |
+
constexpr int tb_k = 16 * thread_k_blocks;
|
686 |
+
constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0;
|
687 |
+
// constexpr int act_s_row_stride = 1;
|
688 |
+
// int act_s_col_stride = act_s_row_stride * num_groups;
|
689 |
+
int act_s_col_stride = 1;
|
690 |
+
int act_s_col_warp_stride = act_s_col_stride * 8;
|
691 |
+
int tb_n_warps = thread_n_blocks / 4;
|
692 |
+
int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;
|
693 |
+
|
694 |
+
// Zero-points sizes/strides
|
695 |
+
int zp_gl_stride = is_zp_float ? prob_n / 8 : (prob_n / pack_factor) / 4;
|
696 |
+
constexpr int zp_sh_stride = is_zp_float
|
697 |
+
? 16 * thread_n_blocks / 8
|
698 |
+
: ((16 * thread_n_blocks) / pack_factor) / 4;
|
699 |
+
constexpr int zp_tb_groups = s_tb_groups;
|
700 |
+
constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0;
|
701 |
+
int zp_gl_rd_delta = zp_gl_stride;
|
702 |
+
|
703 |
+
// Global A read index of current thread.
|
704 |
+
int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
|
705 |
+
(threadIdx.x % a_gl_rd_delta_o);
|
706 |
+
a_gl_rd += a_gl_rd_delta_o * slice_row;
|
707 |
+
// Shared write index of current thread.
|
708 |
+
int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
|
709 |
+
(threadIdx.x % a_gl_rd_delta_o);
|
710 |
+
// Shared read index.
|
711 |
+
int a_sh_rd =
|
712 |
+
a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
|
713 |
+
a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
|
714 |
+
|
715 |
+
int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) +
|
716 |
+
(threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
|
717 |
+
b_gl_rd += b_sh_stride * slice_col;
|
718 |
+
b_gl_rd += b_gl_rd_delta_o * slice_row;
|
719 |
+
int b_sh_wr = threadIdx.x * b_thread_vecs;
|
720 |
+
int b_sh_rd = threadIdx.x * b_thread_vecs;
|
721 |
+
|
722 |
+
// For act_order
|
723 |
+
constexpr int k_iter_size = tb_k / b_sh_wr_iters;
|
724 |
+
int slice_k_start = tb_k * slice_row;
|
725 |
+
int slice_k_finish = slice_k_start + tb_k * slice_iters;
|
726 |
+
int slice_k_start_shared_fetch = slice_k_start;
|
727 |
+
int slice_n_offset = act_s_col_tb_stride * slice_col;
|
728 |
+
|
729 |
+
// No act_order
|
730 |
+
int s_gl_rd;
|
731 |
+
if constexpr (!has_act_order) {
|
732 |
+
if constexpr (group_blocks == -1) {
|
733 |
+
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
|
734 |
+
} else {
|
735 |
+
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
|
736 |
+
s_sh_stride * slice_col + threadIdx.x;
|
737 |
+
}
|
738 |
+
}
|
739 |
+
int s_sh_wr = threadIdx.x;
|
740 |
+
bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
|
741 |
+
|
742 |
+
// Zero-points
|
743 |
+
int zp_gl_rd;
|
744 |
+
if constexpr (has_zp) {
|
745 |
+
if constexpr (group_blocks == -1) {
|
746 |
+
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
|
747 |
+
} else {
|
748 |
+
zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
|
749 |
+
zp_sh_stride * slice_col + threadIdx.x;
|
750 |
+
}
|
751 |
+
}
|
752 |
+
int zp_sh_wr = threadIdx.x;
|
753 |
+
bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;
|
754 |
+
|
755 |
+
// We use a different scale layout for grouped and column-wise quantization as
|
756 |
+
// we scale a `half2` tile in column-major layout in the former and in
|
757 |
+
// row-major in the latter case.
|
758 |
+
int s_sh_rd;
|
759 |
+
if constexpr (group_blocks != -1)
|
760 |
+
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
761 |
+
(threadIdx.x % 32) / 4;
|
762 |
+
else
|
763 |
+
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
764 |
+
(threadIdx.x % 32) % 4;
|
765 |
+
|
766 |
+
// Zero-points have the same read layout as the scales
|
767 |
+
// (without column-wise case)
|
768 |
+
constexpr int num_col_threads = 8;
|
769 |
+
constexpr int num_row_threads = 4;
|
770 |
+
constexpr int num_ints_per_thread = 8 / pack_factor;
|
771 |
+
int zp_sh_rd;
|
772 |
+
if constexpr (has_zp) {
|
773 |
+
if constexpr (is_zp_float) {
|
774 |
+
if constexpr (group_blocks != -1) {
|
775 |
+
zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
776 |
+
(threadIdx.x % 32) / 4;
|
777 |
+
}
|
778 |
+
} else {
|
779 |
+
zp_sh_rd = num_ints_per_thread * num_col_threads *
|
780 |
+
((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
781 |
+
num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads);
|
782 |
+
}
|
783 |
+
}
|
784 |
+
|
785 |
+
// Precompute which thread should not read memory in which iterations; this is
|
786 |
+
// needed if there are more threads than required for a certain tile size or
|
787 |
+
// when the batch size is not a multiple of 16.
|
788 |
+
bool a_sh_wr_pred[a_sh_wr_iters];
|
789 |
+
#pragma unroll
|
790 |
+
for (int i = 0; i < a_sh_wr_iters; i++)
|
791 |
+
a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
|
792 |
+
|
793 |
+
// To ensure that writing and reading A tiles to/from shared memory, the
|
794 |
+
// latter in fragment format, is fully bank conflict free, we need to use a
|
795 |
+
// rather fancy XOR-based layout. The key here is that neither reads nor
|
796 |
+
// writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
|
797 |
+
// same shared memory banks. Further, it seems (based on NSight-Compute) that
|
798 |
+
// each warp must also write a consecutive memory segment?
|
799 |
+
auto transform_a = [&](int i) {
|
800 |
+
int row = i / a_gl_rd_delta_o;
|
801 |
+
return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
|
802 |
+
};
|
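The XOR remapping in transform_a is what spreads the int4 accesses of consecutive threads across different shared memory banks. A toy sketch of the same idea follows, applied to a small hypothetical row length (the real row length here is a_gl_rd_delta_o int4 elements).

// Illustrative sketch only: XOR-ing the column index with the row index
// permutes each row differently, so a column of accesses never lands on a
// single bank. transform_a applies the same trick per int4 element.
#include <cstdio>

int main() {
  const int row_len = 8;                       // assumed elements per row
  for (int row = 0; row < 4; row++) {
    printf("row %d maps columns to:", row);
    for (int col = 0; col < row_len; col++)
      printf(" %d", col ^ row);                // swizzled column index
    printf("\n");
  }
  return 0;
}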
803 |
+
// Since the computation of this remapping is non-trivial and, due to our main
|
804 |
+
// loop unrolls, all shared memory accesses are static, we simply precompute
|
805 |
+
// both transformed reads and writes.
|
806 |
+
int a_sh_wr_trans[a_sh_wr_iters];
|
807 |
+
#pragma unroll
|
808 |
+
for (int i = 0; i < a_sh_wr_iters; i++)
|
809 |
+
a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
|
810 |
+
int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
|
811 |
+
#pragma unroll
|
812 |
+
for (int i = 0; i < b_sh_wr_iters; i++) {
|
813 |
+
#pragma unroll
|
814 |
+
for (int j = 0; j < thread_m_blocks; j++)
|
815 |
+
a_sh_rd_trans[i][j] =
|
816 |
+
transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
|
817 |
+
}
|
818 |
+
|
819 |
+
// Since B-accesses have non-constant stride they have to be computed at
|
820 |
+
// runtime; we break dependencies between subsequent accesses within a tile by
|
821 |
+
// maintaining multiple pointers (we have enough registers), a tiny
|
822 |
+
// optimization.
|
823 |
+
const int4* B_ptr[b_sh_wr_iters];
|
824 |
+
#pragma unroll
|
825 |
+
for (int i = 0; i < b_sh_wr_iters; i++)
|
826 |
+
B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
|
827 |
+
|
828 |
+
extern __shared__ int4 sh[];
|
829 |
+
// Shared memory storage for global fetch pipelines.
|
830 |
+
int4* sh_a = sh;
|
831 |
+
int4* sh_b = sh_a + (stages * a_sh_stage);
|
832 |
+
int4* sh_g_idx = sh_b + (stages * b_sh_stage);
|
833 |
+
int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
|
834 |
+
int4* sh_s = sh_zp + (stages * zp_sh_stage);
|
835 |
+
|
836 |
+
// Register storage for double buffer of shared memory reads.
|
837 |
+
FragA frag_a[2][thread_m_blocks];
|
838 |
+
I4 frag_b_quant[2][b_thread_vecs];
|
839 |
+
FragC frag_c[thread_m_blocks][4][2];
|
840 |
+
FragS frag_s[2][4]; // No act-order
|
841 |
+
FragS act_frag_s[2][4][4]; // For act-order
|
842 |
+
int frag_qzp[2][num_ints_per_thread]; // Zero-points
|
843 |
+
FragZP frag_zp; // Zero-points in fp16
|
844 |
+
FragZP frag_zpf[2]; // Zero-points in fp16 in HQQ
|
845 |
+
|
846 |
+
// Zero accumulators.
|
847 |
+
auto zero_accums = [&]() {
|
848 |
+
#pragma unroll
|
849 |
+
for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
|
850 |
+
reinterpret_cast<float*>(frag_c)[i] = 0;
|
851 |
+
};
|
852 |
+
|
853 |
+
int sh_first_group_id = -1;
|
854 |
+
int sh_num_groups = -1;
|
855 |
+
constexpr int sh_max_num_groups = 32;
|
856 |
+
|
857 |
+
auto fetch_scales_to_shared = [&](bool is_async, int first_group_id,
|
858 |
+
int last_group_id) {
|
859 |
+
sh_first_group_id = first_group_id;
|
860 |
+
sh_num_groups = last_group_id - first_group_id + 1;
|
861 |
+
|
862 |
+
if (sh_num_groups < sh_max_num_groups) {
|
863 |
+
sh_num_groups = sh_max_num_groups;
|
864 |
+
}
|
865 |
+
|
866 |
+
if (sh_first_group_id + sh_num_groups > num_groups) {
|
867 |
+
sh_num_groups = num_groups - sh_first_group_id;
|
868 |
+
}
|
869 |
+
|
870 |
+
int row_offset = first_group_id * s_gl_stride;
|
871 |
+
|
872 |
+
if (is_async) {
|
873 |
+
for (int i = 0; i < sh_num_groups; i++) {
|
874 |
+
if (threadIdx.x < s_sh_stride) {
|
875 |
+
cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x],
|
876 |
+
&scales_ptr[row_offset + (i * s_gl_stride) +
|
877 |
+
slice_n_offset + threadIdx.x]);
|
878 |
+
}
|
879 |
+
}
|
880 |
+
} else {
|
881 |
+
for (int i = 0; i < sh_num_groups; i++) {
|
882 |
+
if (threadIdx.x < s_sh_stride) {
|
883 |
+
sh_s[(i * s_sh_stride) + threadIdx.x] =
|
884 |
+
scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset +
|
885 |
+
threadIdx.x];
|
886 |
+
}
|
887 |
+
}
|
888 |
+
}
|
889 |
+
};
|
890 |
+
// Asynchronously fetch the next A, B and s tile from global to the next
|
891 |
+
// shared memory pipeline location.
|
892 |
+
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
|
893 |
+
if (pred) {
|
894 |
+
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
|
895 |
+
#pragma unroll
|
896 |
+
for (int i = 0; i < a_sh_wr_iters; i++) {
|
897 |
+
cp_async4_pred(
|
898 |
+
&sh_a_stage[a_sh_wr_trans[i]],
|
899 |
+
&A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
|
900 |
+
a_sh_wr_pred[i]);
|
901 |
+
}
|
902 |
+
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
|
903 |
+
#pragma unroll
|
904 |
+
for (int i = 0; i < b_sh_wr_iters; i++) {
|
905 |
+
#pragma unroll
|
906 |
+
for (int j = 0; j < b_thread_vecs; j++) {
|
907 |
+
cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);
|
908 |
+
}
|
909 |
+
|
910 |
+
B_ptr[i] += b_gl_rd_delta_o;
|
911 |
+
}
|
912 |
+
|
913 |
+
if constexpr (has_act_order) {
|
914 |
+
// Fetch g_idx thread-block portion
|
915 |
+
int full_pipe = a_off;
|
916 |
+
int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe;
|
917 |
+
if (cur_k < prob_k && cur_k < slice_k_finish) {
|
918 |
+
int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
|
919 |
+
|
920 |
+
int4 const* cur_g_idx_stage_ptr =
|
921 |
+
reinterpret_cast<int4 const*>(&g_idx[cur_k]);
|
922 |
+
|
923 |
+
if (threadIdx.x < g_idx_stage) {
|
924 |
+
cp_async4_pred(&sh_g_idx_stage[threadIdx.x],
|
925 |
+
&cur_g_idx_stage_ptr[threadIdx.x]);
|
926 |
+
}
|
927 |
+
}
|
928 |
+
} else {
|
929 |
+
if constexpr (group_blocks != -1) {
|
930 |
+
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
|
931 |
+
|
932 |
+
if constexpr (group_blocks >= thread_k_blocks) {
|
933 |
+
// Only fetch scales if this tile starts a new group
|
934 |
+
if (pipe % (group_blocks / thread_k_blocks) == 0) {
|
935 |
+
if (s_sh_wr_pred) {
|
936 |
+
cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
|
937 |
+
}
|
938 |
+
s_gl_rd += s_gl_rd_delta;
|
939 |
+
}
|
940 |
+
} else {
|
941 |
+
for (int i = 0; i < s_tb_groups; i++) {
|
942 |
+
if (s_sh_wr_pred) {
|
943 |
+
cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr],
|
944 |
+
&scales_ptr[s_gl_rd]);
|
945 |
+
}
|
946 |
+
s_gl_rd += s_gl_rd_delta;
|
947 |
+
}
|
948 |
+
}
|
949 |
+
}
|
950 |
+
|
951 |
+
if constexpr (has_zp && group_blocks != -1) {
|
952 |
+
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
|
953 |
+
|
954 |
+
if constexpr (group_blocks >= thread_k_blocks) {
|
955 |
+
// Only fetch zero-points if this tile starts a new group
|
956 |
+
if (pipe % (group_blocks / thread_k_blocks) == 0) {
|
957 |
+
if (zp_sh_wr_pred) {
|
958 |
+
cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]);
|
959 |
+
}
|
960 |
+
zp_gl_rd += zp_gl_rd_delta;
|
961 |
+
}
|
962 |
+
} else {
|
963 |
+
for (int i = 0; i < zp_tb_groups; i++) {
|
964 |
+
if (zp_sh_wr_pred) {
|
965 |
+
cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr],
|
966 |
+
&zp_ptr[zp_gl_rd]);
|
967 |
+
}
|
968 |
+
zp_gl_rd += zp_gl_rd_delta;
|
969 |
+
}
|
970 |
+
}
|
971 |
+
}
|
972 |
+
}
|
973 |
+
}
|
974 |
+
// Insert a fence even when we are winding down the pipeline to ensure that
|
975 |
+
// waiting is also correct at this point.
|
976 |
+
cp_async_fence();
|
977 |
+
};
|
978 |
+
|
979 |
+
auto fetch_zp_to_shared = [&]() {
|
980 |
+
if (zp_sh_wr_pred) {
|
981 |
+
cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]);
|
982 |
+
}
|
983 |
+
};
|
984 |
+
|
985 |
+
// Wait until the next thread tile has been loaded to shared memory.
|
986 |
+
auto wait_for_stage = [&]() {
|
987 |
+
// We only have `stages - 2` active fetches since we are double buffering
|
988 |
+
// and can only issue the next fetch when it is guaranteed that the previous
|
989 |
+
// shared memory load is fully complete (as it may otherwise be
|
990 |
+
// overwritten).
|
991 |
+
cp_async_wait<stages - 2>();
|
992 |
+
__syncthreads();
|
993 |
+
};
|
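The `stages - 2` in cp_async_wait above is the heart of the double buffering: waiting until at most stages - 2 async copy groups remain outstanding guarantees the stage about to be read has landed, while later fetches stay in flight. A host-side sketch of that counting argument, with an assumed example value for stages:

// Illustrative sketch only: issue one cp.async group per iteration, then wait
// until at most (stages - 2) groups remain outstanding, mirroring the bound
// described in the comment above.
#include <cstdio>

int main() {
  const int stages = 4;                        // assumed pipeline depth
  int issued = 0, completed = 0;
  for (int iter = 0; iter < 6; iter++) {
    issued++;                                  // one fetch group per iteration
    while (issued - completed > stages - 2)    // effect of cp_async_wait<stages - 2>
      completed++;                             // oldest outstanding group retires
    printf("iter %d: %d group(s) still in flight\n", iter, issued - completed);
  }
  return 0;
}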
994 |
+
|
995 |
+
// Load the next sub-tile from the current location in the shared memory pipe
|
996 |
+
// into the current register buffer.
|
997 |
+
auto fetch_to_registers = [&](int k, int pipe) {
|
998 |
+
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
|
999 |
+
#pragma unroll
|
1000 |
+
for (int i = 0; i < thread_m_blocks; i++)
|
1001 |
+
ldsm4<scalar_t>(frag_a[k % 2][i],
|
1002 |
+
&sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
|
1003 |
+
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
|
1004 |
+
|
1005 |
+
#pragma unroll
|
1006 |
+
for (int i = 0; i < b_thread_vecs; i++) {
|
1007 |
+
frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>(
|
1008 |
+
&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);
|
1009 |
+
}
|
1010 |
+
};
|
1011 |
+
|
1012 |
+
bool is_same_group[stages];
|
1013 |
+
int same_group_id[stages];
|
1014 |
+
|
1015 |
+
auto init_same_group = [&](int pipe) {
|
1016 |
+
if constexpr (!has_act_order) {
|
1017 |
+
is_same_group[pipe] = false;
|
1018 |
+
same_group_id[pipe] = 0;
|
1019 |
+
return;
|
1020 |
+
}
|
1021 |
+
|
1022 |
+
int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
|
1023 |
+
int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);
|
1024 |
+
|
1025 |
+
int group_id_1 = sh_g_idx_int_ptr[0];
|
1026 |
+
int group_id_2 = sh_g_idx_int_ptr[tb_k - 1];
|
1027 |
+
|
1028 |
+
is_same_group[pipe] = group_id_1 == group_id_2;
|
1029 |
+
same_group_id[pipe] = group_id_1;
|
1030 |
+
};
|
1031 |
+
|
1032 |
+
auto fetch_scales_to_registers = [&](int k, int full_pipe) {
|
1033 |
+
int pipe = full_pipe % stages;
|
1034 |
+
|
1035 |
+
if constexpr (!has_act_order) {
|
1036 |
+
// No act-order case
|
1037 |
+
if constexpr (group_blocks != -1) {
|
1038 |
+
if constexpr (group_blocks >= thread_k_blocks) {
|
1039 |
+
int4* sh_s_stage =
|
1040 |
+
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
|
1041 |
+
(pipe / (group_blocks / thread_k_blocks)));
|
1042 |
+
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
|
1043 |
+
} else {
|
1044 |
+
int warp_id = threadIdx.x / 32;
|
1045 |
+
int n_warps = thread_n_blocks / 4;
|
1046 |
+
|
1047 |
+
int warp_row = warp_id / n_warps;
|
1048 |
+
|
1049 |
+
int cur_k = warp_row * 16;
|
1050 |
+
cur_k += k_iter_size * (k % b_sh_wr_iters);
|
1051 |
+
|
1052 |
+
int k_blocks = cur_k / 16;
|
1053 |
+
int cur_group_id = k_blocks / group_blocks;
|
1054 |
+
|
1055 |
+
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
|
1056 |
+
|
1057 |
+
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
|
1058 |
+
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
|
1059 |
+
}
|
1060 |
+
}
|
1061 |
+
|
1062 |
+
return;
|
1063 |
+
}
|
1064 |
+
|
1065 |
+
// Act-order case
|
1066 |
+
|
1067 |
+
// Determine K of the "current" thread-block
|
1068 |
+
int cur_k = slice_k_start + tb_k * full_pipe;
|
1069 |
+
if (cur_k >= prob_k || cur_k >= slice_k_finish) {
|
1070 |
+
return;
|
1071 |
+
}
|
1072 |
+
|
1073 |
+
// Reset (to current thread-block) since we read g_idx portion from the
|
1074 |
+
// shared memory
|
1075 |
+
cur_k = 0;
|
1076 |
+
|
1077 |
+
// Progress to current iteration
|
1078 |
+
cur_k += k_iter_size * (k % b_sh_wr_iters);
|
1079 |
+
|
1080 |
+
// Determine "position" inside the thread-block (based on warp and
|
1081 |
+
// thread-id)
|
1082 |
+
int warp_id = threadIdx.x / 32;
|
1083 |
+
int n_warps =
|
1084 |
+
thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N
|
1085 |
+
|
1086 |
+
int warp_row = warp_id / n_warps;
|
1087 |
+
int warp_col = warp_id % n_warps;
|
1088 |
+
|
1089 |
+
cur_k += warp_row * 16;
|
1090 |
+
|
1091 |
+
int th_id = threadIdx.x % 32;
|
1092 |
+
cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix
|
1093 |
+
|
1094 |
+
int s_col_shift =
|
1095 |
+
/*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) +
|
1096 |
+
(th_id / 4) * act_s_col_stride;
|
1097 |
+
|
1098 |
+
if (is_same_group[pipe]) {
|
1099 |
+
if (k % 2 == 0) {
|
1100 |
+
*(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) =
|
1101 |
+
sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride +
|
1102 |
+
s_col_shift];
|
1103 |
+
} else {
|
1104 |
+
*(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) =
|
1105 |
+
*(reinterpret_cast<int4*>(&(act_frag_s[(k - 1) % 2][0][0])));
|
1106 |
+
}
|
1107 |
+
|
1108 |
+
for (int i = 1; i < 4; i++) {
|
1109 |
+
*(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) =
|
1110 |
+
*(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0])));
|
1111 |
+
}
|
1112 |
+
return;
|
1113 |
+
}
|
1114 |
+
|
1115 |
+
int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
|
1116 |
+
int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);
|
1117 |
+
|
1118 |
+
constexpr int k_frag_offsets[4] = {0, 1, 8,
|
1119 |
+
9}; // Tensor core offsets per thread
|
1120 |
+
|
1121 |
+
#pragma unroll
|
1122 |
+
for (int i = 0; i < 4; i++) {
|
1123 |
+
int actual_k = cur_k + k_frag_offsets[i];
|
1124 |
+
|
1125 |
+
int group_id = sh_g_idx_int_ptr[actual_k];
|
1126 |
+
int rel_group_id = group_id - sh_first_group_id;
|
1127 |
+
|
1128 |
+
*(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) =
|
1129 |
+
sh_s[rel_group_id * s_sh_stride + s_col_shift];
|
1130 |
+
}
|
1131 |
+
};
|
1132 |
+
|
1133 |
+
auto fetch_zp_to_registers = [&](int k, int full_pipe) {
|
1134 |
+
// This code does not handle group_blocks == 0,
|
1135 |
+
// which signifies act_order.
|
1136 |
+
// has_zp implies AWQ, which doesn't have act_order,
|
1137 |
+
static_assert(!has_zp || group_blocks != 0);
|
1138 |
+
|
1139 |
+
if constexpr (has_zp && !is_zp_float) {
|
1140 |
+
int pipe = full_pipe % stages;
|
1141 |
+
|
1142 |
+
if constexpr (group_blocks == -1) {
|
1143 |
+
for (int i = 0; i < num_ints_per_thread; i++) {
|
1144 |
+
frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp))[zp_sh_rd + i];
|
1145 |
+
}
|
1146 |
+
|
1147 |
+
} else if constexpr (group_blocks >= thread_k_blocks) {
|
1148 |
+
int4* sh_zp_stage =
|
1149 |
+
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
|
1150 |
+
(pipe / (group_blocks / thread_k_blocks)));
|
1151 |
+
for (int i = 0; i < num_ints_per_thread; i++) {
|
1152 |
+
frag_qzp[k % 2][i] =
|
1153 |
+
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
|
1154 |
+
}
|
1155 |
+
} else {
|
1156 |
+
int warp_id = threadIdx.x / 32;
|
1157 |
+
int n_warps = thread_n_blocks / 4;
|
1158 |
+
|
1159 |
+
int warp_row = warp_id / n_warps;
|
1160 |
+
|
1161 |
+
int cur_k = warp_row * 16;
|
1162 |
+
cur_k += k_iter_size * (k % b_sh_wr_iters);
|
1163 |
+
|
1164 |
+
int k_blocks = cur_k / 16;
|
1165 |
+
int cur_group_id = 0;
|
1166 |
+
|
1167 |
+
// Suppress bogus and persistent divide-by-zero warning
|
1168 |
+
#pragma nv_diagnostic push
|
1169 |
+
#pragma nv_diag_suppress divide_by_zero
|
1170 |
+
cur_group_id = k_blocks / group_blocks;
|
1171 |
+
#pragma nv_diagnostic pop
|
1172 |
+
|
1173 |
+
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
|
1174 |
+
|
1175 |
+
sh_zp_stage += cur_group_id * zp_sh_stride;
|
1176 |
+
|
1177 |
+
for (int i = 0; i < num_ints_per_thread; i++) {
|
1178 |
+
frag_qzp[k % 2][i] =
|
1179 |
+
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
|
1180 |
+
}
|
1181 |
+
}
|
1182 |
+
}
|
1183 |
+
|
1184 |
+
else if constexpr (has_zp && is_zp_float) {
|
1185 |
+
int pipe = full_pipe % stages;
|
1186 |
+
|
1187 |
+
if constexpr (group_blocks != -1) {
|
1188 |
+
if constexpr (group_blocks >= thread_k_blocks) {
|
1189 |
+
int4* sh_zp_stage =
|
1190 |
+
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
|
1191 |
+
(pipe / (group_blocks / thread_k_blocks)));
|
1192 |
+
reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd];
|
1193 |
+
} else {
|
1194 |
+
int warp_id = threadIdx.x / 32;
|
1195 |
+
int n_warps = thread_n_blocks / 4;
|
1196 |
+
|
1197 |
+
int warp_row = warp_id / n_warps;
|
1198 |
+
|
1199 |
+
int cur_k = warp_row * 16;
|
1200 |
+
cur_k += k_iter_size * (k % b_sh_wr_iters);
|
1201 |
+
|
1202 |
+
int k_blocks = cur_k / 16;
|
1203 |
+
// Suppress bogus and persistent divide-by-zero warning
|
1204 |
+
#pragma nv_diagnostic push
|
1205 |
+
#pragma nv_diag_suppress divide_by_zero
|
1206 |
+
int cur_group_id = k_blocks / group_blocks;
|
1207 |
+
#pragma nv_diagnostic pop
|
1208 |
+
|
1209 |
+
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
|
1210 |
+
|
1211 |
+
reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] =
|
1212 |
+
sh_zp_stage[zp_sh_rd + cur_group_id * zp_sh_stride];
|
1213 |
+
}
|
1214 |
+
}
|
1215 |
+
}
|
1216 |
+
};
|
1217 |
+
|
1218 |
+
// Execute the actual tensor core matmul of a sub-tile.
|
1219 |
+
auto matmul = [&](int k) {
|
1220 |
+
if constexpr (has_zp && !is_zp_float) {
|
1221 |
+
FragB frag_zp_0;
|
1222 |
+
FragB frag_zp_1;
|
1223 |
+
int zp_quant_0, zp_quant_1;
|
1224 |
+
|
1225 |
+
if constexpr (w_type.size_bits() == 4) {
|
1226 |
+
zp_quant_0 = frag_qzp[k % 2][0];
|
1227 |
+
zp_quant_1 = zp_quant_0 >> 8;
|
1228 |
+
} else {
|
1229 |
+
static_assert(w_type.size_bits() == 8);
|
1230 |
+
zp_quant_0 = frag_qzp[k % 2][0];
|
1231 |
+
zp_quant_1 = frag_qzp[k % 2][1];
|
1232 |
+
}
|
1233 |
+
|
1234 |
+
frag_zp_0 = dequant<scalar_t, w_type_id>(zp_quant_0);
|
1235 |
+
frag_zp_1 = dequant<scalar_t, w_type_id>(zp_quant_1);
|
1236 |
+
|
1237 |
+
frag_zp[0] = frag_zp_0[0];
|
1238 |
+
frag_zp[1] = frag_zp_0[1];
|
1239 |
+
frag_zp[2] = frag_zp_1[0];
|
1240 |
+
frag_zp[3] = frag_zp_1[1];
|
1241 |
+
}
|
1242 |
+
|
1243 |
+
// We have the m dimension as the inner loop in order to encourage overlapping
|
1244 |
+
// dequantization and matmul operations.
|
1245 |
+
#pragma unroll
|
1246 |
+
for (int j = 0; j < 4; j++) {
|
1247 |
+
FragB frag_b0;
|
1248 |
+
FragB frag_b1;
|
1249 |
+
int b_quant_0, b_quant_1;
|
1250 |
+
|
1251 |
+
if constexpr (w_type.size_bits() == 4) {
|
1252 |
+
b_quant_0 = frag_b_quant[k % 2][0][j];
|
1253 |
+
b_quant_1 = b_quant_0 >> 8;
|
1254 |
+
} else {
|
1255 |
+
static_assert(w_type.size_bits() == 8);
|
1256 |
+
int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);
|
1257 |
+
b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
|
1258 |
+
b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
|
1259 |
+
}
|
1260 |
+
|
1261 |
+
frag_b0 = dequant<scalar_t, w_type_id>(b_quant_0);
|
1262 |
+
frag_b1 = dequant<scalar_t, w_type_id>(b_quant_1);
|
1263 |
+
|
1264 |
+
// Apply zero-point to frag_b0
|
1265 |
+
if constexpr (has_zp && !is_zp_float) {
|
1266 |
+
sub_zp<scalar_t>(frag_b0, frag_zp[j], 0);
|
1267 |
+
}
|
1268 |
+
|
1269 |
+
else if constexpr (has_zp && is_zp_float && group_blocks != -1) {
|
1270 |
+
sub_zp<scalar_t>(frag_b0, frag_zpf[k % 2][j], 0);
|
1271 |
+
}
|
1272 |
+
|
1273 |
+
// Apply scale to frag_b0
|
1274 |
+
if constexpr (has_act_order) {
|
1275 |
+
scale4<scalar_t>(frag_b0, act_frag_s[k % 2][0][j],
|
1276 |
+
act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j],
|
1277 |
+
act_frag_s[k % 2][3][j], 0);
|
1278 |
+
} else {
|
1279 |
+
if constexpr (group_blocks != -1) {
|
1280 |
+
scale<scalar_t>(frag_b0, frag_s[k % 2][j], 0);
|
1281 |
+
}
|
1282 |
+
}
|
1283 |
+
|
1284 |
+
// Apply zero-point to frag_b1
|
1285 |
+
if constexpr (has_zp && !is_zp_float) {
|
1286 |
+
sub_zp<scalar_t>(frag_b1, frag_zp[j], 1);
|
1287 |
+
}
|
1288 |
+
|
1289 |
+
else if constexpr (has_zp && is_zp_float && group_blocks != -1) {
|
1290 |
+
sub_zp<scalar_t>(frag_b1, frag_zpf[k % 2][j], 1);
|
1291 |
+
}
|
1292 |
+
|
1293 |
+
// Apply scale to frag_b1
|
1294 |
+
if constexpr (has_act_order) {
|
1295 |
+
scale4<scalar_t>(frag_b1, act_frag_s[k % 2][0][j],
|
1296 |
+
act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j],
|
1297 |
+
act_frag_s[k % 2][3][j], 1);
|
1298 |
+
|
1299 |
+
} else {
|
1300 |
+
if constexpr (group_blocks != -1) {
|
1301 |
+
scale<scalar_t>(frag_b1, frag_s[k % 2][j], 1);
|
1302 |
+
}
|
1303 |
+
}
|
1304 |
+
|
1305 |
+
#pragma unroll
|
1306 |
+
for (int i = 0; i < thread_m_blocks; i++) {
|
1307 |
+
mma<scalar_t>(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
|
1308 |
+
mma<scalar_t>(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
|
1309 |
+
}
|
1310 |
+
}
|
1311 |
+
};
|
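The `>> 8` used for zp_quant_1 and b_quant_1 above relies on the repacked weight layout placing the second fragment's quantized values one byte above the first's inside the same 32-bit word. Below is a small sketch of that extraction pattern; the packed word and the nibble mask are illustrative, not the exact repacked format.

// Illustrative sketch only: peeling two groups of 4-bit values out of one
// 32-bit word, once directly and once after a >> 8, before dequantization.
#include <cstdio>

int main() {
  unsigned int packed = 0x7C5A3E19u;                   // arbitrary example word
  unsigned int group0 = packed & 0x000F000Fu;          // nibbles used directly
  unsigned int group1 = (packed >> 8) & 0x000F000Fu;   // nibbles reached via >> 8
  printf("group0 = 0x%08X, group1 = 0x%08X\n", group0, group1);
  return 0;
}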
1312 |
+
|
1313 |
+
// Since we slice across the k dimension of a tile in order to increase the
|
1314 |
+
// number of warps while keeping the n dimension of a tile reasonable, we have
|
1315 |
+
// multiple warps that accumulate their partial sums of the same output
|
1316 |
+
// location, which we have to reduce over in the end. We do this in shared memory.
|
1317 |
+
auto thread_block_reduce = [&]() {
|
1318 |
+
constexpr int red_off = threads / b_sh_stride_threads / 2;
|
1319 |
+
if (red_off >= 1) {
|
1320 |
+
int red_idx = threadIdx.x / b_sh_stride_threads;
|
1321 |
+
constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
|
1322 |
+
constexpr int red_sh_delta = b_sh_stride_threads;
|
1323 |
+
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
|
1324 |
+
(threadIdx.x % b_sh_stride_threads);
|
1325 |
+
|
1326 |
+
// Parallel logarithmic shared memory reduction. We make sure to avoid any
|
1327 |
+
// unnecessary read or write iterations, e.g., for two warps we write only
|
1328 |
+
// once by warp 1 and read only once by warp 0.
|
1329 |
+
|
1330 |
+
#pragma unroll
|
1331 |
+
for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
|
1332 |
+
#pragma unroll
|
1333 |
+
for (int i = red_off; i > 0; i /= 2) {
|
1334 |
+
if (i <= red_idx && red_idx < 2 * i) {
|
1335 |
+
#pragma unroll
|
1336 |
+
for (int j = 0; j < 4 * 2; j++) {
|
1337 |
+
int red_sh_wr =
|
1338 |
+
red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
|
1339 |
+
if (i < red_off) {
|
1340 |
+
float* c_rd =
|
1341 |
+
reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
|
1342 |
+
float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
|
1343 |
+
#pragma unroll
|
1344 |
+
for (int k = 0; k < 4; k++)
|
1345 |
+
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
|
1346 |
+
c_rd[k] + c_wr[k];
|
1347 |
+
}
|
1348 |
+
sh[red_sh_wr] =
|
1349 |
+
reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
|
1350 |
+
}
|
1351 |
+
}
|
1352 |
+
__syncthreads();
|
1353 |
+
}
|
1354 |
+
if (red_idx == 0) {
|
1355 |
+
#pragma unroll
|
1356 |
+
for (int i = 0; i < 4 * 2; i++) {
|
1357 |
+
float* c_rd =
|
1358 |
+
reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
|
1359 |
+
#pragma unroll
|
1360 |
+
for (int j = 0; j < 4; j++)
|
1361 |
+
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
|
1362 |
+
c_rd[j];
|
1363 |
+
}
|
1364 |
+
}
|
1365 |
+
__syncthreads();
|
1366 |
+
}
|
1367 |
+
}
|
1368 |
+
};
|
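The loop structure above is a standard tree reduction: the stride starts at red_off and halves each step, with the upper half of the partial results folded into the lower half. A plain C++ sketch of the same pattern over an array of per-warp partial sums (the values are arbitrary):

// Illustrative sketch only: logarithmic fold of 8 partial sums, mirroring
// `for (int i = red_off; i > 0; i /= 2)` above.
#include <cstdio>

int main() {
  float partials[8] = {1, 2, 3, 4, 5, 6, 7, 8};  // one partial sum per "warp"
  for (int i = 8 / 2; i > 0; i /= 2)
    for (int j = 0; j < i; j++)
      partials[j] += partials[j + i];            // upper half folds into lower half
  printf("reduced sum = %f\n", partials[0]);     // prints 36
  return 0;
}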
1369 |
+
|
1370 |
+
// Since multiple threadblocks may process parts of the same column slice, we
|
1371 |
+
// finally have to globally reduce over the results. As the striped
|
1372 |
+
// partitioning minimizes the number of such reductions and our outputs are
|
1373 |
+
// usually rather small, we perform this reduction serially in L2 cache.
|
1374 |
+
auto global_reduce_fp16 = [&](bool first = false, bool last = false) {
|
1375 |
+
// We are very careful here to reduce directly in the output buffer to
|
1376 |
+
// maximize L2 cache utilization in this step. To do this, we write out
|
1377 |
+
// results in FP16 (but still reduce with FP32 compute).
|
1378 |
+
constexpr int active_threads = 32 * thread_n_blocks / 4;
|
1379 |
+
if (threadIdx.x < active_threads) {
|
1380 |
+
int c_gl_stride = prob_n / 8;
|
1381 |
+
int c_gl_wr_delta_o = 8 * c_gl_stride;
|
1382 |
+
int c_gl_wr_delta_i = 4 * (active_threads / 32);
|
1383 |
+
int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +
|
1384 |
+
4 * (threadIdx.x / 32) + threadIdx.x % 4;
|
1385 |
+
c_gl_wr += (2 * thread_n_blocks) * slice_col;
|
1386 |
+
constexpr int c_sh_wr_delta = active_threads;
|
1387 |
+
int c_sh_wr = threadIdx.x;
|
1388 |
+
|
1389 |
+
int row = (threadIdx.x % 32) / 4;
|
1390 |
+
|
1391 |
+
if (!first) {
|
1392 |
+
// Interestingly, doing direct global accesses here really seems to mess up
|
1393 |
+
// the compiler and lead to slowdowns, hence we also use async-copies even
|
1394 |
+
// though these fetches are not actually asynchronous.
|
1395 |
+
#pragma unroll
|
1396 |
+
for (int i = 0; i < thread_m_blocks * 4; i++) {
|
1397 |
+
cp_async4_pred(
|
1398 |
+
&sh[c_sh_wr + c_sh_wr_delta * i],
|
1399 |
+
&C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
|
1400 |
+
c_gl_wr_delta_i * (i % 2)],
|
1401 |
+
i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
|
1402 |
+
}
|
1403 |
+
cp_async_fence();
|
1404 |
+
cp_async_wait<0>();
|
1405 |
+
}
|
1406 |
+
|
1407 |
+
#pragma unroll
|
1408 |
+
for (int i = 0; i < thread_m_blocks * 4; i++) {
|
1409 |
+
if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
|
1410 |
+
if (!first) {
|
1411 |
+
int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
|
1412 |
+
#pragma unroll
|
1413 |
+
for (int j = 0; j < 2 * 4; j++) {
|
1414 |
+
reinterpret_cast<float*>(
|
1415 |
+
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
|
1416 |
+
Dtype::num2float(reinterpret_cast<scalar_t*>(&c_red)[j]);
|
1417 |
+
}
|
1418 |
+
}
|
1419 |
+
if (!last) {
|
1420 |
+
int4 c;
|
1421 |
+
#pragma unroll
|
1422 |
+
for (int j = 0; j < 2 * 4; j++) {
|
1423 |
+
reinterpret_cast<scalar_t*>(&c)[j] =
|
1424 |
+
Dtype::float2num(reinterpret_cast<float*>(
|
1425 |
+
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);
|
1426 |
+
}
|
1427 |
+
C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =
|
1428 |
+
c;
|
1429 |
+
}
|
1430 |
+
}
|
1431 |
+
}
|
1432 |
+
}
|
1433 |
+
};
|
1434 |
+
|
1435 |
+
// Globally reduce over threadblocks that compute the same column block.
|
1436 |
+
// We use a tmp C buffer to reduce in full fp32 precision.
|
1437 |
+
auto global_reduce_fp32 = [&](bool first = false, bool last = false) {
|
1438 |
+
constexpr int tb_m = thread_m_blocks * 16;
|
1439 |
+
constexpr int tb_n = thread_n_blocks * 16;
|
1440 |
+
|
1441 |
+
constexpr int c_size = tb_m * tb_n * sizeof(float) / 16;
|
1442 |
+
|
1443 |
+
constexpr int active_threads = 32 * thread_n_blocks / 4;
|
1444 |
+
bool is_th_active = threadIdx.x < active_threads;
|
1445 |
+
|
1446 |
+
int par_offset = c_size * n_tiles * par_id;
|
1447 |
+
int slice_offset = c_size * slice_col;
|
1448 |
+
|
1449 |
+
constexpr int num_floats = thread_m_blocks * 4 * 2 * 4;
|
1450 |
+
constexpr int th_size = num_floats * sizeof(float) / 16;
|
1451 |
+
|
1452 |
+
int c_cur_offset = par_offset + slice_offset;
|
1453 |
+
|
1454 |
+
if (!is_th_active) {
|
1455 |
+
return;
|
1456 |
+
}
|
1457 |
+
|
1458 |
+
if (!first) {
|
1459 |
+
float* frag_c_ptr = reinterpret_cast<float*>(&frag_c);
|
1460 |
+
#pragma unroll
|
1461 |
+
for (int k = 0; k < th_size; k++) {
|
1462 |
+
sh[threadIdx.x] =
|
1463 |
+
C_tmp[c_cur_offset + active_threads * k + threadIdx.x];
|
1464 |
+
|
1465 |
+
float* sh_c_ptr = reinterpret_cast<float*>(&sh[threadIdx.x]);
|
1466 |
+
#pragma unroll
|
1467 |
+
for (int f = 0; f < 4; f++) {
|
1468 |
+
frag_c_ptr[k * 4 + f] += sh_c_ptr[f];
|
1469 |
+
}
|
1470 |
+
}
|
1471 |
+
}
|
1472 |
+
|
1473 |
+
if (!last) {
|
1474 |
+
int4* frag_c_ptr = reinterpret_cast<int4*>(&frag_c);
|
1475 |
+
#pragma unroll
|
1476 |
+
for (int k = 0; k < th_size; k++) {
|
1477 |
+
C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k];
|
1478 |
+
}
|
1479 |
+
}
|
1480 |
+
};
|
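The two reduction paths differ only in where the intermediate partial sums live: global_reduce_fp16 round-trips them through the fp16 output buffer C, while global_reduce_fp32 keeps them in the fp32 scratch buffer C_tmp. The sketch below mimics that difference with a crude 10-bit mantissa rounding helper; it illustrates the rounding effect only and is not the kernel's arithmetic.

// Illustrative sketch only: accumulating in fp32 vs. round-tripping every
// partial sum through an fp16-like representation.
#include <cstdio>
#include <cmath>

float to_fp16ish(float x) {          // crude 10-bit mantissa truncation (assumed)
  int e;
  float m = std::frexp(x, &e);
  m = std::round(m * 1024.0f) / 1024.0f;
  return std::ldexp(m, e);
}

int main() {
  float exact = 0.f, via_fp16 = 0.f;
  for (int i = 0; i < 1000; i++) {
    exact += 0.1f;                               // fp32 accumulation (C_tmp path)
    via_fp16 = to_fp16ish(via_fp16 + 0.1f);      // round-trip through fp16 output
  }
  printf("fp32 reduce: %f, fp16 round-trip reduce: %f\n", exact, via_fp16);
  return 0;
}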
1481 |
+
|
1482 |
+
// Write out the final reduction result in the correct layout. We only actually
|
1483 |
+
// reshuffle matrix fragments in this step, the reduction above is performed
|
1484 |
+
// in fragment layout.
|
1485 |
+
auto write_result = [&]() {
|
1486 |
+
int c_gl_stride = prob_n / 8;
|
1487 |
+
constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
|
1488 |
+
int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
|
1489 |
+
constexpr int c_sh_rd_delta =
|
1490 |
+
c_sh_stride * (threads / (2 * thread_n_blocks));
|
1491 |
+
|
1492 |
+
int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +
|
1493 |
+
(threadIdx.x % (2 * thread_n_blocks));
|
1494 |
+
c_gl_wr += (2 * thread_n_blocks) * slice_col;
|
1495 |
+
int c_sh_wr =
|
1496 |
+
(4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
|
1497 |
+
c_sh_wr += 32 * (threadIdx.x / 32);
|
1498 |
+
int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +
|
1499 |
+
(threadIdx.x % (2 * thread_n_blocks));
|
1500 |
+
|
1501 |
+
int c_gl_wr_end = c_gl_stride * prob_m;
|
1502 |
+
|
1503 |
+
// We first reorder in shared memory to guarantee the most efficient final
|
1504 |
+
// global write patterns
|
1505 |
+
auto write = [&](int idx, float c0, float c1, FragS& s) {
|
1506 |
+
scalar_t2 res =
|
1507 |
+
Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));
|
1508 |
+
|
1509 |
+
// For per-column quantization we finally apply the scale here (only for
|
1510 |
+
// 4-bit)
|
1511 |
+
if constexpr (!has_act_order && group_blocks == -1 &&
|
1512 |
+
w_type.size_bits() == 4) {
|
1513 |
+
res = __hmul2(res, s[0]);
|
1514 |
+
}
|
1515 |
+
|
1516 |
+
((scalar_t2*)sh)[idx] = res;
|
1517 |
+
};
|
1518 |
+
|
1519 |
+
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
1520 |
+
#pragma unroll
|
1521 |
+
for (int i = 0; i < thread_m_blocks; i++) {
|
1522 |
+
#pragma unroll
|
1523 |
+
for (int j = 0; j < 4; j++) {
|
1524 |
+
int wr = c_sh_wr + 8 * j;
|
1525 |
+
write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
|
1526 |
+
frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
|
1527 |
+
write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
|
1528 |
+
frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
|
1529 |
+
write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
|
1530 |
+
frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
|
1531 |
+
write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
|
1532 |
+
frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
|
1533 |
+
}
|
1534 |
+
c_sh_wr += 16 * (4 * c_sh_stride);
|
1535 |
+
}
|
1536 |
+
}
|
1537 |
+
__syncthreads();
|
1538 |
+
|
1539 |
+
#pragma unroll
|
1540 |
+
for (int i = 0;
|
1541 |
+
i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
|
1542 |
+
i++) {
|
1543 |
+
if (c_gl_wr < c_gl_wr_end) {
|
1544 |
+
C[c_gl_wr] = sh[c_sh_rd];
|
1545 |
+
c_gl_wr += c_gl_wr_delta;
|
1546 |
+
c_sh_rd += c_sh_rd_delta;
|
1547 |
+
}
|
1548 |
+
}
|
1549 |
+
};
|
1550 |
+
|
1551 |
+
// Start global fetch and register load pipelines.
|
1552 |
+
auto start_pipes = [&]() {
|
1553 |
+
|
1554 |
+
#pragma unroll
|
1555 |
+
for (int i = 0; i < stages - 1; i++) {
|
1556 |
+
if (has_act_order && i == 0) {
|
1557 |
+
int last_g_idx = slice_k_start + stages * tb_k * 2;
|
1558 |
+
if (last_g_idx >= prob_k) {
|
1559 |
+
last_g_idx = prob_k - 1;
|
1560 |
+
}
|
1561 |
+
fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]);
|
1562 |
+
}
|
1563 |
+
|
1564 |
+
if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
|
1565 |
+
if (i == 0) {
|
1566 |
+
fetch_zp_to_shared();
|
1567 |
+
}
|
1568 |
+
}
|
1569 |
+
fetch_to_shared(i, i, i < slice_iters);
|
1570 |
+
}
|
1571 |
+
|
1572 |
+
zero_accums();
|
1573 |
+
wait_for_stage();
|
1574 |
+
init_same_group(0);
|
1575 |
+
fetch_to_registers(0, 0);
|
1576 |
+
fetch_scales_to_registers(0, 0);
|
1577 |
+
fetch_zp_to_registers(0, 0);
|
1578 |
+
a_gl_rd += a_gl_rd_delta_o * (stages - 1);
|
1579 |
+
slice_k_start_shared_fetch += tb_k * (stages - 1);
|
1580 |
+
};
|
1581 |
+
if (slice_iters) {
|
1582 |
+
start_pipes();
|
1583 |
+
}
|
1584 |
+
|
1585 |
+
// Main loop.
|
1586 |
+
while (slice_iters) {
|
1587 |
+
// We unroll over both the global fetch and the register load pipeline to
|
1588 |
+
// ensure all shared memory accesses are static. Note that both pipelines
|
1589 |
+
// have even length meaning that the next iteration will always start at
|
1590 |
+
// index 0.
|
1591 |
+
|
1592 |
+
#pragma unroll
|
1593 |
+
for (int pipe = 0; pipe < stages;) {
|
1594 |
+
#pragma unroll
|
1595 |
+
for (int k = 0; k < b_sh_wr_iters; k++) {
|
1596 |
+
fetch_to_registers(k + 1, pipe % stages);
|
1597 |
+
fetch_scales_to_registers(k + 1, pipe);
|
1598 |
+
fetch_zp_to_registers(k + 1, pipe);
|
1599 |
+
if (k == b_sh_wr_iters - 2) {
|
1600 |
+
fetch_to_shared((pipe + stages - 1) % stages, pipe,
|
1601 |
+
slice_iters >= stages);
|
1602 |
+
pipe++;
|
1603 |
+
wait_for_stage();
|
1604 |
+
init_same_group(pipe % stages);
|
1605 |
+
}
|
1606 |
+
matmul(k);
|
1607 |
+
}
|
1608 |
+
slice_iters--;
|
1609 |
+
if (slice_iters == 0) {
|
1610 |
+
break;
|
1611 |
+
}
|
1612 |
+
}
|
1613 |
+
|
1614 |
+
a_gl_rd += a_gl_rd_delta_o * stages;
|
1615 |
+
slice_k_start += tb_k * stages;
|
1616 |
+
slice_k_start_shared_fetch += tb_k * stages;
|
1617 |
+
|
1618 |
+
if constexpr (has_act_order) {
|
1619 |
+
int first_group_id = g_idx[slice_k_start];
|
1620 |
+
int last_g_idx = slice_k_start + stages * tb_k * 2;
|
1621 |
+
if (last_g_idx >= prob_k) {
|
1622 |
+
last_g_idx = prob_k - 1;
|
1623 |
+
}
|
1624 |
+
int last_group_id = g_idx[last_g_idx];
|
1625 |
+
if (last_group_id >= sh_first_group_id + sh_num_groups) {
|
1626 |
+
fetch_scales_to_shared(false, first_group_id, last_group_id);
|
1627 |
+
__syncthreads();
|
1628 |
+
}
|
1629 |
+
}
|
1630 |
+
|
1631 |
+
// Process results and, if necessary, proceed to the next column slice.
|
1632 |
+
// While this pattern may not be the most readable, other ways of writing
|
1633 |
+
// the loop seemed to give noticeably worse performance after compilation.
|
1634 |
+
if (slice_iters == 0) {
|
1635 |
+
cp_async_wait<0>();
|
1636 |
+
bool last = slice_idx == slice_count - 1;
|
1637 |
+
// For per-column scales, we only fetch them here in the final step before
|
1638 |
+
// write-out
|
1639 |
+
if constexpr (!has_act_order && group_blocks == -1) {
|
1640 |
+
if constexpr (w_type.size_bits() == 8) {
|
1641 |
+
if (s_sh_wr_pred) {
|
1642 |
+
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
|
1643 |
+
}
|
1644 |
+
cp_async_fence();
|
1645 |
+
} else {
|
1646 |
+
if (last) {
|
1647 |
+
if (s_sh_wr_pred) {
|
1648 |
+
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
|
1649 |
+
}
|
1650 |
+
cp_async_fence();
|
1651 |
+
}
|
1652 |
+
}
|
1653 |
+
}
|
1654 |
+
|
1655 |
+
thread_block_reduce();
|
1656 |
+
if constexpr (!has_act_order && group_blocks == -1) {
|
1657 |
+
if constexpr (w_type.size_bits() == 8) {
|
1658 |
+
cp_async_wait<0>();
|
1659 |
+
__syncthreads();
|
1660 |
+
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
1661 |
+
reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
|
1662 |
+
reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
|
1663 |
+
}
|
1664 |
+
|
1665 |
+
} else {
|
1666 |
+
if (last) {
|
1667 |
+
cp_async_wait<0>();
|
1668 |
+
__syncthreads();
|
1669 |
+
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
1670 |
+
reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
|
1671 |
+
reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
|
1672 |
+
}
|
1673 |
+
}
|
1674 |
+
}
|
1675 |
+
}
|
1676 |
+
|
1677 |
+
// For 8-bit channelwise, we apply the scale before the global reduction
|
1678 |
+
// that converts the fp32 results to fp16 (so that we avoid possible
|
1679 |
+
// overflow in fp16)
|
1680 |
+
if constexpr (!has_act_order && group_blocks == -1 &&
|
1681 |
+
w_type.size_bits() == 8) {
|
1682 |
+
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
1683 |
+
#pragma unroll
|
1684 |
+
for (int i = 0; i < thread_m_blocks; i++) {
|
1685 |
+
#pragma unroll
|
1686 |
+
for (int j = 0; j < 4; j++) {
|
1687 |
+
scale_float<scalar_t>(
|
1688 |
+
reinterpret_cast<float*>(&frag_c[i][j][0][0]),
|
1689 |
+
frag_s[j / 2][2 * (j % 2) + 0]);
|
1690 |
+
scale_float<scalar_t>(
|
1691 |
+
reinterpret_cast<float*>(&frag_c[i][j][0][2]),
|
1692 |
+
frag_s[j / 2][2 * (j % 2) + 0]);
|
1693 |
+
|
1694 |
+
scale_float<scalar_t>(
|
1695 |
+
reinterpret_cast<float*>(&frag_c[i][j][1][0]),
|
1696 |
+
frag_s[j / 2][2 * (j % 2) + 1]);
|
1697 |
+
scale_float<scalar_t>(
|
1698 |
+
reinterpret_cast<float*>(&frag_c[i][j][1][2]),
|
1699 |
+
frag_s[j / 2][2 * (j % 2) + 1]);
|
1700 |
+
}
|
1701 |
+
}
|
1702 |
+
}
|
1703 |
+
}
|
1704 |
+
|
1705 |
+
if (slice_count > 1) { // only globally reduce if there is more than one
|
1706 |
+
// block in a slice
|
1707 |
+
barrier_acquire(&locks[slice_col], slice_idx);
|
1708 |
+
if (use_fp32_reduce) {
|
1709 |
+
global_reduce_fp32(slice_idx == 0, last);
|
1710 |
+
} else {
|
1711 |
+
global_reduce_fp16(slice_idx == 0, last);
|
1712 |
+
}
|
1713 |
+
barrier_release(&locks[slice_col], last);
|
1714 |
+
}
|
1715 |
+
if (last) // only the last block in a slice actually writes the result
|
1716 |
+
write_result();
|
1717 |
+
slice_row = 0;
|
1718 |
+
slice_col_par++;
|
1719 |
+
slice_col++;
|
1720 |
+
init_slice();
|
1721 |
+
if (slice_iters) {
|
1722 |
+
a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
|
1723 |
+
(threadIdx.x % a_gl_rd_delta_o);
|
1724 |
+
#pragma unroll
|
1725 |
+
for (int i = 0; i < b_sh_wr_iters; i++)
|
1726 |
+
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
|
1727 |
+
if (slice_col == 0) {
|
1728 |
+
#pragma unroll
|
1729 |
+
for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
|
1730 |
+
}
|
1731 |
+
|
1732 |
+
// Update slice k/n for scales loading
|
1733 |
+
if constexpr (has_act_order) {
|
1734 |
+
slice_k_start = tb_k * slice_row;
|
1735 |
+
slice_k_finish = slice_k_start + tb_k * slice_iters;
|
1736 |
+
slice_k_start_shared_fetch = slice_k_start;
|
1737 |
+
slice_n_offset = act_s_col_tb_stride * slice_col;
|
1738 |
+
|
1739 |
+
} else {
|
1740 |
+
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
|
1741 |
+
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
|
1742 |
+
}
|
1743 |
+
|
1744 |
+
start_pipes();
|
1745 |
+
}
|
1746 |
+
}
|
1747 |
+
}
|
1748 |
+
}
|
1749 |
+
|
1750 |
+
#define __CALL_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
1751 |
+
HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS, \
|
1752 |
+
IS_ZP_FLOAT) \
|
1753 |
+
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
|
1754 |
+
thread_n_blocks == THREAD_N_BLOCKS && \
|
1755 |
+
thread_k_blocks == THREAD_K_BLOCKS && \
|
1756 |
+
has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
|
1757 |
+
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
|
1758 |
+
is_zp_float == IS_ZP_FLOAT) { \
|
1759 |
+
if constexpr (!IS_ZP_FLOAT || std::is_same<scalar_t, half>::value) { \
|
1760 |
+
cudaFuncSetAttribute( \
|
1761 |
+
Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
|
1762 |
+
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, \
|
1763 |
+
HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT>, \
|
1764 |
+
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
1765 |
+
Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
|
1766 |
+
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
|
1767 |
+
HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT> \
|
1768 |
+
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
|
1769 |
+
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \
|
1770 |
+
num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \
|
1771 |
+
} \
|
1772 |
+
}
|
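Each __CALL_IF expansion contributes one `else if` arm that matches a specific runtime combination and launches the corresponding fully specialized Marlin instantiation (the macro starting with `else` implies the call site opens with a plain `if`). A stripped-down sketch of that dispatch pattern, with a hypothetical kernel and a single template parameter:

// Illustrative sketch only: runtime-to-compile-time dispatch in the style the
// __CALL_IF macro expands to. fake_kernel and the parameter values are made up.
#include <cstdio>

template <int kNumThreads>
void fake_kernel() {
  printf("would launch specialization with %d threads\n", kNumThreads);
}

void dispatch(int num_threads) {
  if (false) {
  }
  else if (num_threads == 256) {   // one expanded __CALL_IF arm
    fake_kernel<256>();
  }
  else if (num_threads == 128) {   // another expanded arm
    fake_kernel<128>();
  }
}

int main() {
  dispatch(256);
  return 0;
}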
1773 |
+
|
1774 |
+
typedef struct {
|
1775 |
+
int thread_k;
|
1776 |
+
int thread_n;
|
1777 |
+
int num_threads;
|
1778 |
+
} thread_config_t;
|
1779 |
+
|
1780 |
+
typedef struct {
|
1781 |
+
int max_m_blocks;
|
1782 |
+
thread_config_t tb_cfg;
|
1783 |
+
} exec_config_t;
|
1784 |
+
|
1785 |
+
thread_config_t small_batch_thread_configs[] = {
|
1786 |
+
// Ordered by priority
|
1787 |
+
|
1788 |
+
// thread_k, thread_n, num_threads
|
1789 |
+
{128, 128, 256},
|
1790 |
+
{64, 128, 128},
|
1791 |
+
{128, 64, 128},
|
1792 |
+
};
|
1793 |
+
|
1794 |
+
thread_config_t large_batch_thread_configs[] = {
|
1795 |
+
// Ordered by priority
|
1796 |
+
|
1797 |
+
// thread_k, thread_n, num_threads
|
1798 |
+
{64, 256, 256},
|
1799 |
+
{64, 128, 128},
|
1800 |
+
{128, 64, 128},
|
1801 |
+
|
1802 |
+
};
|
1803 |
+
|
1804 |
+
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
|
1805 |
+
int prob_n, int prob_k, int num_bits, int group_size,
|
1806 |
+
bool has_act_order, bool is_k_full) {
|
1807 |
+
bool cache_scales_chunk = has_act_order && !is_k_full;
|
1808 |
+
|
1809 |
+
int tb_n = th_config.thread_n;
|
1810 |
+
int tb_k = th_config.thread_k;
|
1811 |
+
|
1812 |
+
// Get max scale groups per thread-block
|
1813 |
+
int tb_groups;
|
1814 |
+
if (group_size == -1) {
|
1815 |
+
tb_groups = 1;
|
1816 |
+
} else if (group_size == 0) {
|
1817 |
+
tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size
|
1818 |
+
} else {
|
1819 |
+
tb_groups = div_ceil(tb_k, group_size);
|
1820 |
+
}
|
1821 |
+
|
1822 |
+
if (cache_scales_chunk) {
|
1823 |
+
int load_groups =
|
1824 |
+
tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
|
1825 |
+
load_groups = max(load_groups, 32); // We load at least 32 scale groups
|
1826 |
+
return load_groups * tb_n * 2;
|
1827 |
+
|
1828 |
+
} else {
|
1829 |
+
int tb_scales = tb_groups * tb_n * 2;
|
1830 |
+
|
1831 |
+
return tb_scales * pipe_stages;
|
1832 |
+
}
|
1833 |
+
}
|
1834 |
+
|
1835 |
+
bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
|
1836 |
+
int prob_m, int prob_n, int prob_k, int num_bits,
|
1837 |
+
int scales_cache_size, int max_shared_mem) {
|
1838 |
+
int pack_factor = 32 / num_bits;
|
1839 |
+
|
1840 |
+
// Get B size
|
1841 |
+
int tb_k = th_config.thread_k;
|
1842 |
+
int tb_n = th_config.thread_n;
|
1843 |
+
|
1844 |
+
int b_size = (tb_k * tb_n / pack_factor) * 4;
|
1845 |
+
|
1846 |
+
// Get A size
|
1847 |
+
int m_blocks = div_ceil(prob_m, 16);
|
1848 |
+
int tb_max_m = 16;
|
1849 |
+
|
1850 |
+
while (true) {
|
1851 |
+
if (m_blocks >= max_m_blocks) {
|
1852 |
+
tb_max_m *= max_m_blocks;
|
1853 |
+
break;
|
1854 |
+
}
|
1855 |
+
|
1856 |
+
max_m_blocks--;
|
1857 |
+
if (max_m_blocks == 0) {
|
1858 |
+
TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks);
|
1859 |
+
}
|
1860 |
+
}
|
1861 |
+
|
1862 |
+
int a_size = (tb_max_m * tb_k) * 2;
|
1863 |
+
|
1864 |
+
float pipe_size = (a_size + b_size) * pipe_stages;
|
1865 |
+
|
1866 |
+
TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity
|
1867 |
+
|
1868 |
+
return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);
|
1869 |
+
}
|
1870 |
+
|
1871 |
+
bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
|
1872 |
+
int prob_m, int prob_n, int prob_k, int num_bits,
|
1873 |
+
int group_size, bool has_act_order, bool is_k_full,
|
1874 |
+
int max_shared_mem) {
|
1875 |
+
// Sanity
|
1876 |
+
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
|
1877 |
+
th_config.num_threads == -1) {
|
1878 |
+
return false;
|
1879 |
+
}
|
1880 |
+
|
1881 |
+
// Verify K/N are divisible by thread K/N
|
1882 |
+
if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
|
1883 |
+
return false;
|
1884 |
+
}
|
1885 |
+
|
1886 |
+
// Verify min for thread K/N
|
1887 |
+
if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
|
1888 |
+
return false;
|
1889 |
+
}
|
1890 |
+
|
1891 |
+
// num_threads must be at least 128 (= 4 warps)
|
1892 |
+
if (th_config.num_threads < 128) {
|
1893 |
+
return false;
|
1894 |
+
}
|
1895 |
+
|
1896 |
+
// Determine cache for scales
|
1897 |
+
int scales_cache_size =
|
1898 |
+
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
|
1899 |
+
group_size, has_act_order, is_k_full);
|
1900 |
+
|
1901 |
+
// Check that pipeline fits into cache
|
1902 |
+
if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k,
|
1903 |
+
num_bits, scales_cache_size, max_shared_mem)) {
|
1904 |
+
return false;
|
1905 |
+
}
|
1906 |
+
|
1907 |
+
return true;
|
1908 |
+
}
|
1909 |
+
|
1910 |
+
int determine_reduce_max_m(int prob_m, int max_par) {
|
1911 |
+
constexpr int tile_m_size = 16;
|
1912 |
+
|
1913 |
+
if (prob_m <= tile_m_size) {
|
1914 |
+
return tile_m_size;
|
1915 |
+
|
1916 |
+
} else if (prob_m <= tile_m_size * 2) {
|
1917 |
+
return tile_m_size * 2;
|
1918 |
+
|
1919 |
+
} else if (prob_m <= tile_m_size * 3) {
|
1920 |
+
return tile_m_size * 3;
|
1921 |
+
|
1922 |
+
} else if (prob_m <= tile_m_size * 4) {
|
1923 |
+
return tile_m_size * 4;
|
1924 |
+
|
1925 |
+
} else {
|
1926 |
+
int cur_par = min(div_ceil(prob_m, tile_m_size * 4), max_par);
|
1927 |
+
return tile_m_size * 4 * cur_par;
|
1928 |
+
}
|
1929 |
+
}
|
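In other words, the fp32 scratch buffer is sized for at most 4 row tiles (64 rows) times up to max_par parallel problem slices. A host-side sketch that mirrors the rounding above for a few batch sizes; the helper is a re-derivation of the same logic, and max_par = 16 is an assumed value.

// Illustrative sketch only: what the rounding above yields for a few prob_m.
#include <cstdio>

int reduce_max_m(int prob_m, int max_par) {   // mirrors determine_reduce_max_m
  const int tile = 16;
  if (prob_m <= 4 * tile)
    return tile * ((prob_m + tile - 1) / tile);
  int cur_par = (prob_m + 4 * tile - 1) / (4 * tile);
  if (cur_par > max_par) cur_par = max_par;
  return 4 * tile * cur_par;
}

int main() {
  int ms[] = {1, 16, 17, 48, 200, 10000};
  for (int m : ms)
    printf("prob_m = %5d -> reduce buffer rows = %d\n", m, reduce_max_m(m, 16));
  return 0;
}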
1930 |
+
|
1931 |
+
exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
|
1932 |
+
int num_bits, int group_size,
|
1933 |
+
bool has_act_order, bool is_k_full,
|
1934 |
+
int max_shared_mem) {
|
1935 |
+
int max_m_blocks = 4;
|
1936 |
+
while (max_m_blocks > 0) {
|
1937 |
+
if (prob_m <= 16) {
|
1938 |
+
for (auto th_config : small_batch_thread_configs) {
|
1939 |
+
if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
|
1940 |
+
num_bits, group_size, has_act_order, is_k_full,
|
1941 |
+
max_shared_mem)) {
|
1942 |
+
return exec_config_t{max_m_blocks, th_config};
|
1943 |
+
}
|
1944 |
+
}
|
1945 |
+
} else {
|
1946 |
+
for (auto th_config : large_batch_thread_configs) {
|
1947 |
+
if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
|
1948 |
+
num_bits, group_size, has_act_order, is_k_full,
|
1949 |
+
max_shared_mem)) {
|
1950 |
+
return exec_config_t{max_m_blocks, th_config};
|
1951 |
+
}
|
1952 |
+
}
|
1953 |
+
}
|
1954 |
+
|
1955 |
+
max_m_blocks--; // Process fewer M blocks per invocation to reduce cache
|
1956 |
+
// usage
|
1957 |
+
}
|
1958 |
+
|
1959 |
+
return exec_config_t{0, {-1, -1, -1}};
|
1960 |
+
}
|
1961 |
+
|
1962 |
+
#define GPTQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
1963 |
+
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \
|
1964 |
+
false) \
|
1965 |
+
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \
|
1966 |
+
false) \
|
1967 |
+
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \
|
1968 |
+
false) \
|
1969 |
+
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \
|
1970 |
+
false) \
|
1971 |
+
\
|
1972 |
+
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
|
1973 |
+
false) \
|
1974 |
+
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \
|
1975 |
+
false) \
|
1976 |
+
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \
|
1977 |
+
false) \
|
1978 |
+
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \
|
1979 |
+
false) \
|
1980 |
+
\
|
1981 |
+
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
|
1982 |
+
false) \
|
1983 |
+
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \
|
1984 |
+
false) \
|
1985 |
+
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \
|
1986 |
+
false) \
|
1987 |
+
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \
|
1988 |
+
false) \
|
1989 |
+
\
|
1990 |
+
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
|
1991 |
+
false) \
|
1992 |
+
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \
|
1993 |
+
false) \
|
1994 |
+
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \
|
1995 |
+
false) \
|
1996 |
+
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \
|
1997 |
+
false) \
|
1998 |
+
\
|
1999 |
+
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
|
2000 |
+
false) \
|
2001 |
+
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \
|
2002 |
+
false) \
|
2003 |
+
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \
|
2004 |
+
false) \
|
2005 |
+
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \
|
2006 |
+
false)
|
2007 |
+
|
2008 |
+
#define AWQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
2009 |
+
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
|
2010 |
+
false) \
|
2011 |
+
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \
|
2012 |
+
false) \
|
2013 |
+
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
|
2014 |
+
false) \
|
2015 |
+
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \
|
2016 |
+
false) \
|
2017 |
+
\
|
2018 |
+
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
|
2019 |
+
false) \
|
2020 |
+
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \
|
2021 |
+
false) \
|
2022 |
+
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
|
2023 |
+
false) \
|
2024 |
+
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \
|
2025 |
+
false) \
|
2026 |
+
\
|
2027 |
+
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
|
2028 |
+
false) \
|
2029 |
+
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \
|
2030 |
+
false) \
|
2031 |
+
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
|
2032 |
+
false) \
|
2033 |
+
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \
|
2034 |
+
false) \
|
2035 |
+
\
|
2036 |
+
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
|
2037 |
+
false) \
|
2038 |
+
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \
|
2039 |
+
false) \
|
2040 |
+
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
|
2041 |
+
false) \
|
2042 |
+
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, false)
|
2043 |
+
|
2044 |
+
// We currently have 4-bit models only with group_blocks == 4
|
2045 |
+
#define HQQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
2046 |
+
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
|
2047 |
+
true) \
|
2048 |
+
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
|
2049 |
+
true) \
|
2050 |
+
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
|
2051 |
+
true) \
|
2052 |
+
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, true)
|
2053 |
+
|
2054 |
+
template <typename scalar_t>
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
               void* zp, void* g_idx, void* perm, void* a_tmp, int prob_m,
               int prob_n, int prob_k, void* workspace,
               vllm::ScalarType const& q_type, bool has_act_order,
               bool is_k_full, bool has_zp, int num_groups, int group_size,
               int dev, cudaStream_t stream, int thread_k, int thread_n,
               int sms, int max_par, bool use_fp32_reduce, bool is_zp_float) {
  if (has_zp) {
    TORCH_CHECK(
        q_type == vllm::kU4 || q_type == vllm::kU8,
        "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
  } else {
    TORCH_CHECK(
        q_type == vllm::kU4B8 || q_type == vllm::kU8B128,
        "q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
        q_type.str());
  }

  TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
              ", ", prob_n, ", ", prob_k, "]");

  // TODO: remove alias when we start supporting other 8bit types
  int num_bits = q_type.size_bits();
  int tot_m = prob_m;
  int tot_m_blocks = div_ceil(tot_m, 16);
  int pad = 16 * tot_m_blocks - tot_m;

  if (sms == -1) {
    cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
  }

  int max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  TORCH_CHECK(max_shared_mem > 0);

  // Set thread config
  exec_config_t exec_cfg;
  if (thread_k != -1 && thread_n != -1) {
    // User-defined config
    exec_cfg =
        exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}};
  } else {
    // Auto config
    exec_cfg =
        determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size,
                                has_act_order, is_k_full, max_shared_mem);
  }

  TORCH_CHECK(exec_cfg.max_m_blocks > 0 &&
                  is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks,
                                  prob_m, prob_n, prob_k, num_bits, group_size,
                                  has_act_order, is_k_full, max_shared_mem),
              "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks,
              ", thread_k = ", exec_cfg.tb_cfg.thread_k,
              ", thread_n = ", exec_cfg.tb_cfg.thread_n,
              ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [",
              prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
              ", group_size = ", group_size,
              ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
              ", max_shared_mem = ", max_shared_mem);

  int num_threads = exec_cfg.tb_cfg.num_threads;
  thread_k = exec_cfg.tb_cfg.thread_k;
  thread_n = exec_cfg.tb_cfg.thread_n;

  int thread_k_blocks = thread_k / 16;
  int thread_n_blocks = thread_n / 16;

  int blocks = sms;

  TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
              " is not divisible by thread_n = ", thread_n);
  TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
              " is not divisible by thread_k = ", thread_k);

  int group_blocks = 0;
  if (has_act_order) {
    if (is_k_full) {
      TORCH_CHECK(group_size != -1);
      group_blocks = group_size / 16;
      TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
                  " is not divisible by group_blocks = ", group_blocks);
    } else {
      TORCH_CHECK(group_size == 0);
      group_blocks = 0;
    }

  } else {
    if (group_size == -1) {
      group_blocks = -1;
    } else {
      group_blocks = group_size / 16;
      TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
                  " is not divisible by group_blocks = ", group_blocks);
    }
  }

  const int4* A_ptr = (const int4*)A;
  const int4* B_ptr = (const int4*)B;
  int4* C_ptr = (int4*)C;
  int4* C_tmp_ptr = (int4*)C_tmp;
  const int4* s_ptr = (const int4*)s;
  const int4* zp_ptr = (const int4*)zp;
  const int* g_idx_ptr = (const int*)g_idx;
  const int* perm_ptr = (const int*)perm;
  int4* a_tmp_ptr = (int4*)a_tmp;

  int* locks = (int*)workspace;

  if (has_act_order) {
    // Permute A columns
    int block_rows = div_ceil(prob_m, blocks);
    permute_cols_kernel<<<blocks, default_threads, 0, stream>>>(
        A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, block_rows);
    A_ptr = a_tmp_ptr;
  }

  // If we have a full K, then we can run the non-act-order version of Marlin
  // (since the weight rows are reordered by increasing group ids, and by having
  // a full K, we have full original groups)
  if (is_k_full) {
    has_act_order = false;
  }

  // Main loop
  for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) {
    int thread_m_blocks = tot_m_blocks - i;
    prob_m = tot_m - 16 * i;
    int par = 1;
    if (thread_m_blocks > exec_cfg.max_m_blocks) {
      // Note that parallel > 1 currently only works for inputs without any
      // padding
      par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks);
      if (par > max_par) par = max_par;
      prob_m = (16 * exec_cfg.max_m_blocks) * par;
      i += exec_cfg.max_m_blocks * (par - 1);
      thread_m_blocks = exec_cfg.max_m_blocks;
    }

    if (false) {
    }
    GPTQ_CALL_IF(vllm::kU4B8, 16, 4, 256)
    GPTQ_CALL_IF(vllm::kU4B8, 8, 8, 256)
    GPTQ_CALL_IF(vllm::kU4B8, 8, 4, 128)
    GPTQ_CALL_IF(vllm::kU4B8, 4, 8, 128)
    GPTQ_CALL_IF(vllm::kU8B128, 16, 4, 256)
    GPTQ_CALL_IF(vllm::kU8B128, 8, 8, 256)
    GPTQ_CALL_IF(vllm::kU8B128, 8, 4, 128)
    GPTQ_CALL_IF(vllm::kU8B128, 4, 8, 128)

    AWQ_CALL_IF(vllm::kU4, 16, 4, 256)
    AWQ_CALL_IF(vllm::kU4, 8, 8, 256)
    AWQ_CALL_IF(vllm::kU4, 8, 4, 128)
    AWQ_CALL_IF(vllm::kU4, 4, 8, 128)
    AWQ_CALL_IF(vllm::kU8, 16, 4, 256)
    AWQ_CALL_IF(vllm::kU8, 8, 8, 256)
    AWQ_CALL_IF(vllm::kU8, 8, 4, 128)
    AWQ_CALL_IF(vllm::kU8, 4, 8, 128)

    HQQ_CALL_IF(vllm::kU4, 16, 4, 256)
    HQQ_CALL_IF(vllm::kU4, 8, 8, 256)
    HQQ_CALL_IF(vllm::kU4, 8, 4, 128)
    HQQ_CALL_IF(vllm::kU4, 4, 8, 128)
    else {
      TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
                  ", ", prob_k, "]", ", has_act_order = ", has_act_order,
                  ", num_groups = ", num_groups, ", group_size = ", group_size,
                  ", thread_m_blocks = ", thread_m_blocks,
                  ", thread_n_blocks = ", thread_n_blocks,
                  ", thread_k_blocks = ", thread_k_blocks,
                  ", num_bits = ", num_bits);
    }

    A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
    C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
  }
}

} // namespace marlin

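The main loop in `marlin_mm` walks the M dimension in 16-row blocks and, when more blocks remain than `exec_cfg.max_m_blocks` covers, folds several slices into a single launch via `par` (capped at `max_par`). A host-only sketch of that index arithmetic with illustrative numbers (the printing is of course not part of the launcher):

// Sketch of the M-tiling / par arithmetic used by the main loop above.
#include <cstdio>

static int div_ceil(int a, int b) { return (a + b - 1) / b; }

int main() {
  int tot_m = 200;       // rows of A
  int max_m_blocks = 4;  // stands in for exec_cfg.max_m_blocks
  int max_par = 16;

  int tot_m_blocks = div_ceil(tot_m, 16);  // 13 blocks of 16 rows
  int pad = 16 * tot_m_blocks - tot_m;     // 8 padded rows in the last block

  for (int i = 0; i < tot_m_blocks; i += max_m_blocks) {
    int thread_m_blocks = tot_m_blocks - i;
    int prob_m = tot_m - 16 * i;
    int par = 1;
    if (thread_m_blocks > max_m_blocks) {
      // Fold several max_m_blocks-sized slices into one launch.
      par = (16 * thread_m_blocks - pad) / (16 * max_m_blocks);
      if (par > max_par) par = max_par;
      prob_m = (16 * max_m_blocks) * par;
      i += max_m_blocks * (par - 1);
      thread_m_blocks = max_m_blocks;
    }
    // Prints: "launch: rows=192, m_blocks=4, par=3" then "rows=8, m_blocks=1, par=1".
    std::printf("launch: rows=%d, m_blocks=%d, par=%d\n", prob_m,
                thread_m_blocks, par);
  }
}
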
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                               torch::Tensor& b_scales, torch::Tensor& b_zeros,
                               torch::Tensor& g_idx, torch::Tensor& perm,
                               torch::Tensor& workspace,
                               vllm::ScalarTypeId const& b_q_type_id,
                               int64_t size_m, int64_t size_n, int64_t size_k,
                               bool is_k_full, bool has_zp,
                               bool use_fp32_reduce, bool is_zp_float) {
  vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
  if (has_zp) {
    TORCH_CHECK(
        b_q_type == vllm::kU4 || b_q_type == vllm::kU8,
        "b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str());
  } else {
    TORCH_CHECK(
        b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
        "b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
        b_q_type.str());
  }

  if (has_zp && is_zp_float) {
    TORCH_CHECK(a.scalar_type() == at::ScalarType::Half,
                "Computation type must be float16 (half) when using float zero "
                "points.");
  }

  int pack_factor = 32 / b_q_type.size_bits();

  // Verify A
  TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
              ", size_m = ", size_m);
  TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1),
              ", size_k = ", size_k);

  // Verify B
  TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k,
              " is not divisible by tile_size = ", marlin::tile_size);
  TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
              "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
              ", size_k = ", size_k, ", tile_size = ", marlin::tile_size);
  TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
              "b_q_weight.size(1) = ", b_q_weight.size(1),
              " is not divisible by tile_size = ", marlin::tile_size);
  int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor;
  TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
              ", actual_size_n = ", actual_size_n);

  // Verify device and strides
  TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
  TORCH_CHECK(a.is_contiguous(), "A is not contiguous");

  TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");

  TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
  TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");

  TORCH_CHECK(b_zeros.device().is_cuda(), "b_zeros is not on GPU");
  TORCH_CHECK(b_zeros.is_contiguous(), "b_zeros is not contiguous");

  TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU");
  TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous");

  TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU");
  TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");

  // Alloc buffers
  const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
  auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
  torch::Tensor c = torch::empty({size_m, size_n}, options);
  torch::Tensor a_tmp = torch::empty({size_m, size_k}, options);

  // Alloc C tmp buffer that is going to be used for the global reduce
  int reduce_max_m = marlin::determine_reduce_max_m(size_m, marlin::max_par);
  int reduce_n = size_n;
  auto options_fp32 =
      torch::TensorOptions().dtype(at::kFloat).device(a.device());
  if (!use_fp32_reduce) {
    reduce_max_m = 0;
    reduce_n = 0;
  }
  torch::Tensor c_tmp = torch::empty({reduce_max_m, reduce_n}, options_fp32);

  // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
  // auto -1)
  int thread_k = -1;
  // thread_n: `n` size of a thread_tile in `weights` (can usually be left as
  // auto -1)
  int thread_n = -1;
  // sms: number of SMs to use for the kernel (can usually be left as auto -1)
  int sms = -1;

  // Verify g_idx and perm
  TORCH_CHECK((g_idx.size(0) == 0 && perm.size(0) == 0) ||
                  (g_idx.size(0) == size_k && perm.size(0) == size_k),
              "Unexpected g_idx.size(0) = ", g_idx.size(0),
              " and perm.size(0) = ", perm.size(0),
              ", where size_k = ", size_k);

  // Detect groupsize and act_order
  int num_groups = -1;
  int group_size = -1;
  bool has_act_order = g_idx.size(0) != 0;

  int rank = b_scales.sizes().size();
  TORCH_CHECK(rank == 2, "b_scales rank = ", rank, " is not 2");
  TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1),
              " is not size_n = ", size_n);
  num_groups = b_scales.size(0);

  if (has_act_order) {
    if (is_k_full) {
      TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
      TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
                  ", is not divisible by num_groups = ", num_groups);
      group_size = size_k / num_groups;
    } else {
      group_size = 0;
    }

  } else {
    if (num_groups > 1) {
      TORCH_CHECK(
          size_k % num_groups == 0, "size_k = ", size_k,
          ", is not divisible by b_scales.size(0) = ", b_scales.size(0));
      group_size = size_k / num_groups;
    } else {
      group_size = -1;
    }
  }

  // Verify b_zeros
  if (has_zp) {
    int rank = b_zeros.sizes().size();
    TORCH_CHECK(rank == 2, "b_zeros rank = ", rank, " is not 2");
    if (is_zp_float) {
      TORCH_CHECK(b_zeros.size(1) == size_n,
                  "b_zeros dim 1 = ", b_zeros.size(1),
                  " is not size_n = ", size_n);
      TORCH_CHECK(num_groups == b_zeros.size(0),
                  "b_zeros dim 0 = ", b_zeros.size(0),
                  " is not num_groups = ", num_groups);
      TORCH_CHECK(num_groups != -1, "num_groups must be != -1");
    } else {
      TORCH_CHECK(b_zeros.size(0) == num_groups,
                  "b_zeros dim 0 = ", b_zeros.size(0),
                  " is not num_groups = ", num_groups);
      TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor,
                  "b_zeros dim 1 = ", b_zeros.size(1),
                  " is not size_n / pack_factor = ", size_n / pack_factor);
    }
  }

  // Verify workspace size
  TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n,
              ", is not divisible by min_thread_n = ", marlin::min_thread_n);
  int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
  TORCH_CHECK(workspace.numel() >= min_workspace_size,
              "workspace.numel = ", workspace.numel(),
              " is below min_workspace_size = ", min_workspace_size);

  int dev = a.get_device();
  if (a.scalar_type() == at::ScalarType::Half) {
    marlin::marlin_mm<half>(
        a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
        c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
        b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
        a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
        workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
        num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
        thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_zp_float);
  } else if (a.scalar_type() == at::ScalarType::BFloat16) {
    marlin::marlin_mm<nv_bfloat16>(
        a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
        c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
        b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
        perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
        workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
        num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
        thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_zp_float);
  } else {
    TORCH_CHECK(false, "gptq_marlin_gemm only supports bfloat16 and float16");
  }

  return c;
}

#endif
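Taken together, the checks in `gptq_marlin_gemm` pin down the tensor shapes a caller must provide. A small sketch of that arithmetic for one plausible configuration, assuming the usual Marlin constants `tile_size = 16`, `min_thread_n = 64` and `max_par = 16` from `marlin.cuh` (constants not shown in this diff):

// Shape sketch for one gptq_marlin_gemm call; illustrative arithmetic only.
#include <cstdio>

int main() {
  const int tile_size = 16, min_thread_n = 64, max_par = 16;  // assumed
  const int num_bits = 4;                 // e.g. kU4B8 weights
  const int pack_factor = 32 / num_bits;  // 8 values per int32

  int size_m = 32, size_n = 4096, size_k = 4096, group_size = 128;

  // Shapes the checks above expect:
  std::printf("a:          [%d, %d]\n", size_m, size_k);
  std::printf("b_q_weight: [%d, %d]\n", size_k / tile_size,
              size_n * tile_size / pack_factor);
  std::printf("b_scales:   [%d, %d]\n", size_k / group_size, size_n);
  std::printf("workspace:  >= %d ints\n", (size_n / min_thread_n) * max_par);
}
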
gptq_marlin/gptq_marlin_repack.cu
ADDED
@@ -0,0 +1,333 @@
#include "marlin.cuh"

namespace marlin {

template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void gptq_marlin_repack_kernel(
    uint32_t const* __restrict__ b_q_weight_ptr,
    uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
    int size_k, int size_n) {
  constexpr int pack_factor = 32 / num_bits;

  int k_tiles = size_k / tile_k_size;
  int n_tiles = size_n / tile_n_size;
  int block_k_tiles = div_ceil(k_tiles, gridDim.x);

  int start_k_tile = blockIdx.x * block_k_tiles;
  if (start_k_tile >= k_tiles) {
    return;
  }

  int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);

  // Wait until the next thread tile has been loaded to shared memory.
  auto wait_for_stage = [&]() {
    // We only have `stages - 2` active fetches since we are double buffering
    // and can only issue the next fetch when it is guaranteed that the previous
    // shared memory load is fully complete (as it may otherwise be
    // overwritten).
    cp_async_wait<repack_stages - 2>();
    __syncthreads();
  };

  extern __shared__ int4 sh[];

  constexpr int perm_size = tile_k_size / 4;

  int4* sh_perm_ptr = sh;
  int4* sh_pipe_ptr = sh_perm_ptr;
  if constexpr (has_perm) {
    sh_pipe_ptr += perm_size;
  }

  constexpr int tile_ints = tile_k_size / pack_factor;

  constexpr int stage_n_threads = tile_n_size / 4;
  constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints;
  constexpr int stage_size = stage_k_threads * stage_n_threads;

  auto load_perm_to_shared = [&](int k_tile_id) {
    int first_k_int4 = (k_tile_id * tile_k_size) / 4;

    int4 const* perm_int4_ptr = reinterpret_cast<int4 const*>(perm_ptr);

    if (threadIdx.x < perm_size) {
      sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x];
    }
    __syncthreads();
  };

  auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      cp_async_fence();
      return;
    }

    int first_n = n_tile_id * tile_n_size;

    int4* sh_ptr = sh_pipe_ptr + stage_size * pipe;

    if constexpr (has_perm) {
      if (threadIdx.x < stage_size) {
        int k_id = threadIdx.x / stage_n_threads;
        int n_id = threadIdx.x % stage_n_threads;

        uint32_t const* sh_perm_int_ptr =
            reinterpret_cast<uint32_t const*>(sh_perm_ptr);

        int src_k = sh_perm_int_ptr[k_id];
        int src_k_packed = src_k / pack_factor;

        cp_async4(
            &sh_ptr[k_id * stage_n_threads + n_id],
            reinterpret_cast<int4 const*>(&(
                b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)])));
      }

    } else {
      if (threadIdx.x < stage_size) {
        int k_id = threadIdx.x / stage_n_threads;
        int n_id = threadIdx.x % stage_n_threads;

        int first_k = k_tile_id * tile_k_size;
        int first_k_packed = first_k / pack_factor;

        cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
                  reinterpret_cast<int4 const*>(
                      &(b_q_weight_ptr[(first_k_packed + k_id) * size_n +
                                       first_n + (n_id * 4)])));
      }
    }

    cp_async_fence();
  };

  auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      return;
    }

    int warp_id = threadIdx.x / 32;
    int th_id = threadIdx.x % 32;

    if (warp_id >= 4) {
      return;
    }

    int tc_col = th_id / 4;
    int tc_row = (th_id % 4) * 2;

    constexpr int tc_offsets[4] = {0, 1, 8, 9};

    int cur_n = warp_id * 16 + tc_col;

    constexpr int sh_stride = 64;
    constexpr uint32_t mask = (1 << num_bits) - 1;

    int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;
    uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);

    uint32_t* sh_perm_int_ptr = reinterpret_cast<uint32_t*>(sh_perm_ptr);

    uint32_t vals[8];

    if constexpr (has_perm) {
      for (int i = 0; i < 4; i++) {
        int k_idx = tc_row + tc_offsets[i];

        uint32_t src_k = sh_perm_int_ptr[k_idx];
        uint32_t src_k_pos = src_k % pack_factor;

        uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n];
        uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask;

        uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8];
        uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask;

        vals[i] = b1_cur_val;
        vals[4 + i] = b2_cur_val;
      }

    } else {
      uint32_t b1_vals[tile_ints];
      uint32_t b2_vals[tile_ints];

      #pragma unroll
      for (int i = 0; i < tile_ints; i++) {
        b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
        b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
      }

      #pragma unroll
      for (int i = 0; i < 4; i++) {
        int cur_elem = tc_row + tc_offsets[i];
        int cur_int = cur_elem / pack_factor;
        int cur_pos = cur_elem % pack_factor;

        vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask;
        vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask;
      }
    }

    constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
    int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;

    // Result of:
    // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
    if constexpr (num_bits == 4) {
      constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};

      uint32_t res = 0;
      #pragma unroll
      for (int i = 0; i < 8; i++) {
        res |= vals[pack_idx[i]] << (i * 4);
      }

      out_ptr[out_offset + th_id * 4 + warp_id] = res;

    } else {
      constexpr int pack_idx[4] = {0, 2, 1, 3};

      uint32_t res1 = 0;
      uint32_t res2 = 0;
      #pragma unroll
      for (int i = 0; i < 4; i++) {
        res1 |= vals[pack_idx[i]] << (i * 8);
        res2 |= vals[4 + pack_idx[i]] << (i * 8);
      }

      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
    }
  };

  auto start_pipes = [&](int k_tile_id, int n_tile_id) {
    #pragma unroll
    for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
      fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
    }

    wait_for_stage();
  };
  #pragma unroll
  for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
    int n_tile_id = 0;

    if constexpr (has_perm) {
      load_perm_to_shared(k_tile_id);
    }

    start_pipes(k_tile_id, n_tile_id);

    while (n_tile_id < n_tiles) {
      #pragma unroll
      for (int pipe = 0; pipe < repack_stages; pipe++) {
        fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
                        n_tile_id + pipe + repack_stages - 1);
        repack_tile(pipe, k_tile_id, n_tile_id + pipe);
        wait_for_stage();
      }
      n_tile_id += repack_stages;
    }
  }
}

} // namespace marlin

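The `pack_idx` tables in `repack_tile` reproduce the interleaved layout from FasterTransformer's `interleaved_numeric_conversion.h`, so the GEMM can later dequantize with cheap bit tricks. A host-side sketch of the 4-bit case, illustrative only:

// Sketch of the 4-bit repack step: eight unpacked values are re-interleaved
// with pack_idx = {0, 2, 4, 6, 1, 3, 5, 7} before being written back as a
// single uint32_t.
#include <cstdint>
#include <cstdio>

int main() {
  const int num_bits = 4;
  const uint32_t mask = (1u << num_bits) - 1;
  const int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};

  uint32_t vals[8] = {0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7};

  uint32_t res = 0;
  for (int i = 0; i < 8; i++) {
    res |= (vals[pack_idx[i]] & mask) << (i * num_bits);
  }
  // Nibble order in the packed word is 0,2,4,6,1,3,5,7 (low to high).
  std::printf("packed = 0x%08x\n", res);  // prints 0x75316420
}
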
#define CALL_IF(NUM_BITS, HAS_PERM) \
  else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
    cudaFuncSetAttribute( \
        marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
                                          HAS_PERM>, \
        cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
    marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
                                      HAS_PERM> \
        <<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
            b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
  }

torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                 int64_t size_k, int64_t size_n,
                                 int64_t num_bits) {
  // Verify compatibility with marlin tile of 16x64
  TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
              " is not divisible by tile_k_size = ", marlin::tile_k_size);
  TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n,
              " is not divisible by tile_n_size = ", marlin::tile_n_size);

  TORCH_CHECK(num_bits == 4 || num_bits == 8,
              "num_bits must be 4 or 8. Got = ", num_bits);
  int const pack_factor = 32 / num_bits;

  // Verify B
  TORCH_CHECK((size_k / pack_factor) == b_q_weight.size(0),
              "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
              ", size_k = ", size_k, ", pack_factor = ", pack_factor);
  TORCH_CHECK(b_q_weight.size(1) == size_n,
              "b_q_weight.size(1) = ", b_q_weight.size(1),
              " is not size_n = ", size_n);

  // Verify device and strides
  TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
  TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");

  TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU");
  TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");
  TORCH_CHECK(perm.dtype() == at::kInt, "perm type is not at::kInt");

  // Alloc buffers
  const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
  auto options = torch::TensorOptions()
                     .dtype(b_q_weight.dtype())
                     .device(b_q_weight.device());
  torch::Tensor out = torch::empty(
      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
      options);

  // Detect if there is act_order
  bool has_perm = perm.size(0) != 0;

  // Get ptrs
  uint32_t const* b_q_weight_ptr =
      reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
  uint32_t const* perm_ptr = reinterpret_cast<uint32_t const*>(perm.data_ptr());
  uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());

  // Get dev info
  int dev = b_q_weight.get_device();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
  int blocks;
  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);

  int max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  TORCH_CHECK(max_shared_mem > 0);

  if (false) {
  }
  CALL_IF(4, false)
  CALL_IF(4, true)
  CALL_IF(8, false)
  CALL_IF(8, true)
  else {
    TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits,
                ", has_perm = ", has_perm);
  }

  return out;
}

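For reference, `gptq_marlin_repack` converts the row-packed GPTQ layout `[size_k / pack_factor, size_n]` into the Marlin tile layout `[size_k / tile_size, size_n * tile_size / pack_factor]`; the element count is unchanged, only the tiling differs. A worked shape example, assuming `marlin::tile_size == 16` (defined in marlin.cuh, not shown here):

// Worked shape example for a 4-bit GPTQ weight with size_k = size_n = 4096.
#include <cstdio>

int main() {
  const int tile_size = 16;               // assumed marlin::tile_size
  const int num_bits = 4;
  const int pack_factor = 32 / num_bits;  // 8 values per int32

  const int size_k = 4096, size_n = 4096;

  // GPTQ input layout vs. Marlin output layout; both hold
  // size_k * size_n / pack_factor int32 words.
  std::printf("in : [%d, %d]\n", size_k / pack_factor, size_n);
  std::printf("out: [%d, %d]\n", size_k / tile_size,
              size_n * tile_size / pack_factor);
}
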
torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
                                      torch::Tensor& perm, c10::SymInt size_k,
                                      c10::SymInt size_n, int64_t num_bits) {
  int const pack_factor = 32 / num_bits;
  auto options = torch::TensorOptions()
                     .dtype(b_q_weight.dtype())
                     .device(b_q_weight.device());
  return torch::empty_symint(
      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
      options);
}