danieldk committed
Commit c31b5ce · 1 Parent(s): c5018b2

Add GPTQ-Marlin

build.toml CHANGED
@@ -5,6 +5,7 @@ version = "0.0.1"
 name = "quantization"
 src = [
   "core/registration.h",
+  "core/scalar_type.hpp",
   "ext-torch/torch_binding.cpp",
   "ext-torch/torch_binding.h"
 ]
@@ -69,3 +70,16 @@ src = [
 ]
 include = [ "." ]
 depends = [ "torch" ]
+
+[kernel.gptq_marlin]
+capabilities = [ "8.0", "8.6", "8.7", "8.9", "9.0", "9.0a" ]
+src = [
+  "core/scalar_type.hpp",
+  "gptq_marlin/awq_marlin_repack.cu",
+  "gptq_marlin/gptq_marlin.cu",
+  "gptq_marlin/gptq_marlin_repack.cu",
+  "gptq_marlin/marlin.cuh",
+  "gptq_marlin/marlin_dtypes.cuh"
+]
+include = [ "." ]
+depends = [ "torch" ]
ext-torch/torch_binding.cpp CHANGED
@@ -66,6 +66,32 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "Tensor! workspace, int num_bits, SymInt size_m, SymInt size_n, "
       "SymInt size_k) -> Tensor");
   ops.impl("fp8_marlin_gemm", &fp8_marlin_gemm);
+
+  // awq_marlin repack from AWQ.
+  ops.def(
+      "awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
+      "SymInt size_n, int num_bits) -> Tensor");
+  ops.impl("awq_marlin_repack", &awq_marlin_repack);
+
+  // gptq_marlin Optimized Quantized GEMM for GPTQ.
+  ops.impl("gptq_marlin_gemm", &gptq_marlin_gemm);
+  ops.def(
+      "gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
+      "Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
+      "int b_q_type, "
+      "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
+      "bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
+
+  // gptq_marlin repack from GPTQ.
+  ops.def(
+      "gptq_marlin_repack(Tensor b_q_weight, Tensor perm, "
+      "SymInt size_k, SymInt size_n, int num_bits) -> Tensor");
+  ops.impl("gptq_marlin_repack", &gptq_marlin_repack);
+}
+
+TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, ops) {
+  ops.impl("awq_marlin_repack", &awq_marlin_repack_meta);
+  ops.impl("gptq_marlin_repack", &gptq_marlin_repack_meta);
 }
 
 REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
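The Meta-key registrations above let the repack ops run under shape-only (meta / fake tensor) tracing, for example during torch.compile, without launching any CUDA kernel. Below is a minimal C++ sketch of that shape-propagation path; it assumes the extension headers are on the include path, that marlin::tile_size is 16 as in marlin.cuh, and the sizes and main() wrapper are illustrative only.

#include <iostream>
#include <torch/torch.h>
#include "ext-torch/torch_binding.h"

int main() {
  // Hypothetical sizes; an AWQ-packed weight is [size_k, size_n / pack_factor] int32.
  int64_t size_k = 4096, size_n = 4096, num_bits = 4;
  int64_t pack_factor = 32 / num_bits;
  auto w = torch::empty({size_k, size_n / pack_factor},
                        torch::dtype(torch::kInt).device(c10::kMeta));
  // Shape propagation only: meta tensors carry sizes/dtypes but no data.
  auto out = awq_marlin_repack_meta(w, size_k, size_n, num_bits);
  // Expected (tile_size == 16): [size_k / 16, size_n * 16 / pack_factor]
  std::cout << out.sizes() << std::endl;
  return 0;
}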
ext-torch/torch_binding.h CHANGED
@@ -2,6 +2,8 @@
 
 #include <torch/torch.h>
 
+#include <core/scalar_type.hpp>
+
 bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
 
 void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
@@ -46,3 +48,29 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                               torch::Tensor& b_scales, torch::Tensor& workspace,
                               int64_t num_bits, int64_t size_m, int64_t size_n,
                               int64_t size_k);
+
+// GPTQ-Marlin
+
+torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
+                                int64_t size_n, int64_t num_bits);
+
+torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
+                                     c10::SymInt size_k, c10::SymInt size_n,
+                                     int64_t num_bits);
+
+torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
+                               torch::Tensor& b_scales, torch::Tensor& b_zeros,
+                               torch::Tensor& g_idx, torch::Tensor& perm,
+                               torch::Tensor& workspace,
+                               vllm::ScalarTypeId const& b_q_type_id,
+                               int64_t size_m, int64_t size_n, int64_t size_k,
+                               bool is_k_full, bool has_zp,
+                               bool use_fp32_reduce, bool is_zp_float);
+
+torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
+                                 int64_t size_k, int64_t size_n,
+                                 int64_t num_bits);
+
+torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
+                                      torch::Tensor& perm, c10::SymInt size_k,
+                                      c10::SymInt size_n, int64_t num_bits);
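Note that the op schema exposes the weight type as a plain `int b_q_type`, while the C++ entry point takes a `vllm::ScalarTypeId` from core/scalar_type.hpp. A minimal sketch of how a caller might pick that id for GPTQ weights, assuming the kU4B8/kU8B128 constants used by the dequantization specializations later in this commit; the helper itself is hypothetical:

#include <core/scalar_type.hpp>

// Hypothetical helper: map a GPTQ bit width (symmetric quantization, no
// zero points) to the ScalarTypeId passed as b_q_type_id. kU4B8 = uint4
// with bias 8, kU8B128 = uint8 with bias 128, matching the dequant<>
// specializations in gptq_marlin.cu.
inline vllm::ScalarTypeId gptq_weight_type_id(int64_t num_bits) {
  return num_bits == 4 ? vllm::kU4B8.id() : vllm::kU8B128.id();
}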
gptq_marlin/awq_marlin_repack.cu ADDED
@@ -0,0 +1,258 @@
+#include "marlin.cuh"
+
+namespace marlin {
+
+template <int const num_threads, int const num_bits>
+__global__ void awq_marlin_repack_kernel(
+    uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
+    int size_k, int size_n) {
+  constexpr int pack_factor = 32 / num_bits;
+
+  int k_tiles = size_k / tile_k_size;
+  int n_tiles = size_n / tile_n_size;
+  int block_k_tiles = div_ceil(k_tiles, gridDim.x);
+
+  int start_k_tile = blockIdx.x * block_k_tiles;
+  if (start_k_tile >= k_tiles) {
+    return;
+  }
+
+  int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);
+
+  // Wait until the next thread tile has been loaded to shared memory.
+  auto wait_for_stage = [&]() {
+    // We only have `stages - 2` active fetches since we are double buffering
+    // and can only issue the next fetch when it is guaranteed that the previous
+    // shared memory load is fully complete (as it may otherwise be
+    // overwritten).
+    cp_async_wait<repack_stages - 2>();
+    __syncthreads();
+  };
+
+  extern __shared__ int4 sh[];
+
+  constexpr int tile_n_ints = tile_n_size / pack_factor;
+
+  constexpr int stage_n_threads = tile_n_ints / 4;
+  constexpr int stage_k_threads = tile_k_size;
+  constexpr int stage_size = stage_k_threads * stage_n_threads;
+
+  auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
+    if (n_tile_id >= n_tiles) {
+      cp_async_fence();
+      return;
+    }
+
+    int first_n = n_tile_id * tile_n_size;
+    int first_n_packed = first_n / pack_factor;
+
+    int4* sh_ptr = sh + stage_size * pipe;
+
+    if (threadIdx.x < stage_size) {
+      int k_id = threadIdx.x / stage_n_threads;
+      int n_id = threadIdx.x % stage_n_threads;
+
+      int first_k = k_tile_id * tile_k_size;
+
+      cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
+                reinterpret_cast<int4 const*>(
+                    &(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) +
+                                     first_n_packed + (n_id * 4)])));
+    }
+
+    cp_async_fence();
+  };
+
+  auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
+    if (n_tile_id >= n_tiles) {
+      return;
+    }
+
+    int warp_id = threadIdx.x / 32;
+    int th_id = threadIdx.x % 32;
+
+    if (warp_id >= 4) {
+      return;
+    }
+
+    int tc_col = th_id / 4;
+    int tc_row = (th_id % 4) * 2;
+
+    constexpr int tc_offsets[4] = {0, 1, 8, 9};
+
+    int cur_n = warp_id * 16 + tc_col;
+    int cur_n_packed = cur_n / pack_factor;
+    int cur_n_pos = cur_n % pack_factor;
+
+    constexpr int sh_stride = tile_n_ints;
+    constexpr uint32_t mask = (1 << num_bits) - 1;
+
+    int4* sh_stage_ptr = sh + stage_size * pipe;
+    uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);
+
+    // Undo interleaving
+    int cur_n_pos_unpacked;
+    if constexpr (num_bits == 4) {
+      constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7};
+      cur_n_pos_unpacked = undo_pack[cur_n_pos];
+    } else {
+      constexpr int undo_pack[4] = {0, 2, 1, 3};
+      cur_n_pos_unpacked = undo_pack[cur_n_pos];
+    }
+
+    uint32_t vals[8];
+  #pragma unroll
+    for (int i = 0; i < 4; i++) {
+      int cur_elem = tc_row + tc_offsets[i];
+
+      int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem];
+      int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) +
+                                          sh_stride * cur_elem];
+
+      vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
+      vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
+    }
+
+    constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
+    int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
+
+    // Result of:
+    // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
+    if constexpr (num_bits == 4) {
+      constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
+
+      uint32_t res = 0;
+  #pragma unroll
+      for (int i = 0; i < 8; i++) {
+        res |= vals[pack_idx[i]] << (i * 4);
+      }
+
+      out_ptr[out_offset + th_id * 4 + warp_id] = res;
+
+    } else {
+      constexpr int pack_idx[4] = {0, 2, 1, 3};
+
+      uint32_t res1 = 0;
+      uint32_t res2 = 0;
+  #pragma unroll
+      for (int i = 0; i < 4; i++) {
+        res1 |= vals[pack_idx[i]] << (i * 8);
+        res2 |= vals[4 + pack_idx[i]] << (i * 8);
+      }
+
+      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
+      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
+    }
+  };
+
+  auto start_pipes = [&](int k_tile_id, int n_tile_id) {
+  #pragma unroll
+    for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
+      fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
+    }
+
+    wait_for_stage();
+  };
+  #pragma unroll
+  for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
+    int n_tile_id = 0;
+
+    start_pipes(k_tile_id, n_tile_id);
+
+    while (n_tile_id < n_tiles) {
+  #pragma unroll
+      for (int pipe = 0; pipe < repack_stages; pipe++) {
+        fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
+                        n_tile_id + pipe + repack_stages - 1);
+        repack_tile(pipe, k_tile_id, n_tile_id + pipe);
+        wait_for_stage();
+      }
+      n_tile_id += repack_stages;
+    }
+  }
+}
+
+}  // namespace marlin
+
+#define CALL_IF(NUM_BITS)                                                    \
+  else if (num_bits == NUM_BITS) {                                          \
+    cudaFuncSetAttribute(                                                    \
+        marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \
+        cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);       \
+    marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>      \
+        <<<blocks, marlin::repack_threads, max_shared_mem, stream>>>(       \
+            b_q_weight_ptr, out_ptr, size_k, size_n);                       \
+  }
+
+torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
+                                int64_t size_n, int64_t num_bits) {
+  // Verify compatibility with marlin tile of 16x64
+  TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
+              " is not divisible by tile_k_size = ", marlin::tile_k_size);
+  TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n,
+              " is not divisible by tile_n_size = ", marlin::tile_n_size);
+
+  TORCH_CHECK(num_bits == 4 || num_bits == 8,
+              "num_bits must be 4 or 8. Got = ", num_bits);
+  int const pack_factor = 32 / num_bits;
+
+  // Verify B
+  TORCH_CHECK(b_q_weight.size(0) == size_k,
+              "b_q_weight.size(0) = ", b_q_weight.size(0),
+              " is not size_k = ", size_k);
+  TORCH_CHECK((size_n / pack_factor) == b_q_weight.size(1),
+              "Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1),
+              ", size_n = ", size_n, ", pack_factor = ", pack_factor);
+
+  // Verify device and strides
+  TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
+  TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
+  TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");
+
+  // Alloc buffers
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
+  auto options = torch::TensorOptions()
+                     .dtype(b_q_weight.dtype())
+                     .device(b_q_weight.device());
+  torch::Tensor out = torch::empty(
+      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
+      options);
+
+  // Get ptrs
+  uint32_t const* b_q_weight_ptr =
+      reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
+  uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());
+
+  // Get dev info
+  int dev = b_q_weight.get_device();
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
+  int blocks;
+  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
+
+  int max_shared_mem = 0;
+  cudaDeviceGetAttribute(&max_shared_mem,
+                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
+  TORCH_CHECK(max_shared_mem > 0);
+
+  if (false) {
+  }
+  CALL_IF(4)
+  CALL_IF(8)
+  else {
+    TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits);
+  }
+
+  return out;
+}
+
+torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
+                                     c10::SymInt size_k, c10::SymInt size_n,
+                                     int64_t num_bits) {
+  int const pack_factor = 32 / num_bits;
+  auto options = torch::TensorOptions()
+                     .dtype(b_q_weight.dtype())
+                     .device(b_q_weight.device());
+  return torch::empty_symint(
+      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
+      options);
+}
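Putting the checks above together: the caller hands awq_marlin_repack a contiguous int32 CUDA tensor of shape [size_k, size_n / pack_factor] and gets back the Marlin-tiled layout [size_k / tile_size, size_n * tile_size / pack_factor]. The following is a minimal C++ sketch of that call, with illustrative sizes, assuming the extension is linked, a CUDA device is available, and marlin::tile_size is 16 as in marlin.cuh:

#include <iostream>
#include <limits>
#include <torch/torch.h>
#include "ext-torch/torch_binding.h"

int main() {
  int64_t size_k = 4096, size_n = 4096, num_bits = 4;
  int64_t pack_factor = 32 / num_bits;  // 8 int4 values per int32
  // Stand-in for AWQ-packed weights: [size_k, size_n / pack_factor], int32,
  // contiguous, on the GPU (real weights come from an AWQ checkpoint).
  auto b_q_weight = torch::randint(
      0, std::numeric_limits<int32_t>::max(), {size_k, size_n / pack_factor},
      torch::dtype(torch::kInt).device(torch::kCUDA));
  auto marlin_w = awq_marlin_repack(b_q_weight, size_k, size_n, num_bits);
  // Expected: [size_k / 16, size_n * 16 / pack_factor]
  std::cout << marlin_w.sizes() << std::endl;
  return 0;
}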
gptq_marlin/gptq_marlin.cu ADDED
@@ -0,0 +1,2423 @@
1
+ /*
2
+ * Modified by Neural Magic
3
+ * Copyright (C) Marlin.2024 Elias Frantar
4
+ *
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
16
+ */
17
+
18
+ /*
19
+ * Adapted from https://github.com/IST-DASLab/marlin
20
+ */
21
+
22
+ #include "marlin.cuh"
23
+ #include "marlin_dtypes.cuh"
24
+ #include "core/scalar_type.hpp"
25
+
26
+ #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
27
+ static_assert(std::is_same<scalar_t, half>::value || \
28
+ std::is_same<scalar_t, nv_bfloat16>::value, \
29
+ "only float16 and bfloat16 are supported");
30
+
31
+ template <typename T>
32
+ inline std::string str(T x) {
33
+ return std::to_string(x);
34
+ }
35
+
36
+ namespace marlin {
37
+
38
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
39
+
40
+ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
41
+ int const* __restrict__ perm_int_ptr,
42
+ int4* __restrict__ out_int4_ptr, int size_m,
43
+ int size_k, int block_rows) {}
44
+
45
+ template <typename scalar_t, // compute dtype, half or nv_bfloat16
46
+ const vllm::ScalarTypeId w_type_id, // weight ScalarType id
47
+ const int threads, // number of threads in a threadblock
48
+ const int thread_m_blocks, // number of 16x16 blocks in the m
49
+ // dimension (batchsize) of the
50
+ // threadblock
51
+ const int thread_n_blocks, // same for n dimension (output)
52
+ const int thread_k_blocks, // same for k dimension (reduction)
53
+ const int stages, // number of stages for the async global->shared
54
+ // fetch pipeline
55
+ const bool has_act_order, // whether act_order is enabled
56
+ const int group_blocks = -1, // number of consecutive 16x16 blocks
57
+ // with a separate quantization scale
58
+ const bool is_zp_float // is zero point of float16 type?
59
+ >
60
+ __global__ void Marlin(
61
+ const int4* __restrict__ A, // fp16 input matrix of shape mxk
62
+ const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
63
+ int4* __restrict__ C, // fp16 output buffer of shape mxn
64
+ int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
65
+ const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
66
+ // (k/groupsize)xn
67
+ const int* __restrict__ g_idx, // int32 group indices of shape k
68
+ int num_groups, // number of scale groups per output channel
69
+ int prob_m, // batch dimension m
70
+ int prob_n, // output dimension n
71
+ int prob_k, // reduction dimension k
72
+ int* locks, // extra global storage for barrier synchronization
73
+ bool use_fp32_reduce // whether to use fp32 global reduce
74
+ ) {}
75
+
76
+ } // namespace marlin
77
+
78
+ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
79
+ torch::Tensor& b_scales, torch::Tensor& b_zeros,
80
+ torch::Tensor& g_idx, torch::Tensor& perm,
81
+ torch::Tensor& workspace,
82
+ vllm::ScalarTypeId const b_q_type_id,
83
+ int64_t size_m, int64_t size_n, int64_t size_k,
84
+ bool is_k_full, bool has_zp, bool is_zp_float) {
85
+ TORCH_CHECK_NOT_IMPLEMENTED(false,
86
+ "marlin_gemm(..) requires CUDA_ARCH >= 8.0");
87
+ return torch::empty({1, 1});
88
+ }
89
+
90
+ #else
91
+
92
+ // m16n8k16 tensor core mma instruction with fp16 inputs and fp32
93
+ // output/accumulation.
94
+ template <typename scalar_t>
95
+ __device__ inline void mma(const typename ScalarType<scalar_t>::FragA& a_frag,
96
+ const typename ScalarType<scalar_t>::FragB& frag_b,
97
+ typename ScalarType<scalar_t>::FragC& frag_c) {
98
+ const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
99
+ const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
100
+ float* c = reinterpret_cast<float*>(&frag_c);
101
+ if constexpr (std::is_same<scalar_t, half>::value) {
102
+ asm volatile(
103
+ "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
104
+ "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
105
+ : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
106
+ : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
107
+ "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
108
+ } else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
109
+ asm volatile(
110
+ "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
111
+ "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
112
+ : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
113
+ : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
114
+ "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
115
+ } else {
116
+ STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
117
+ }
118
+ }
119
+
120
+ // Instruction for loading a full 16x16 matrix fragment of operand A from shared
121
+ // memory, directly in tensor core layout.
122
+ template <typename scalar_t>
123
+ __device__ inline void ldsm4(typename ScalarType<scalar_t>::FragA& frag_a,
124
+ const void* smem_ptr) {
125
+ uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
126
+ uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
127
+ asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
128
+ : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
129
+ : "r"(smem));
130
+ }
131
+
132
+ // Lookup-table based 3-input logical operation; explicitly used for
133
+ // dequantization as the compiler does not seem to automatically recognize it in
134
+ // all cases.
135
+ template <int lut>
136
+ __device__ inline int lop3(int a, int b, int c) {
137
+ int res;
138
+ asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
139
+ : "=r"(res)
140
+ : "r"(a), "r"(b), "r"(c), "n"(lut));
141
+ return res;
142
+ }
143
+
144
+ // Constructs destination register by taking bytes from 2 sources (based on
145
+ // mask)
146
+ template <int start_byte, int mask>
147
+ __device__ inline uint32_t prmt(uint32_t a) {
148
+ uint32_t res;
149
+ asm volatile("prmt.b32 %0, %1, %2, %3;\n"
150
+ : "=r"(res)
151
+ : "r"(a), "n"(start_byte), "n"(mask));
152
+ return res;
153
+ }
154
+
155
+ template <typename scalar_t, vllm::ScalarTypeId w_type_id>
156
+ __device__ inline typename ScalarType<scalar_t>::FragB dequant(int q);
157
+
158
+ //
159
+ // Efficiently dequantize 4bit values packed in an int32 value into a full
160
+ // B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
161
+ // with some small changes:
162
+ // - FP16:
163
+ // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
164
+ // - BF16:
165
+ // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
166
+ //
167
+ template <>
168
+ __device__ inline typename ScalarType<half>::FragB
169
+ dequant<half, vllm::kU4B8.id()>(int q) {
170
+ const int LO = 0x000f000f;
171
+ const int HI = 0x00f000f0;
172
+ const int EX = 0x64006400;
173
+ // Guarantee that the `(a & b) | c` operations are LOP3s.
174
+ int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
175
+ int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
176
+ // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
177
+ // directly into `SUB` and `ADD`.
178
+ const int SUB = 0x64086408;
179
+ const int MUL = 0x2c002c00;
180
+ const int ADD = 0xd480d480;
181
+ typename ScalarType<half>::FragB frag_b;
182
+ frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
183
+ *reinterpret_cast<const half2*>(&SUB));
184
+ frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
185
+ *reinterpret_cast<const half2*>(&MUL),
186
+ *reinterpret_cast<const half2*>(&ADD));
187
+ return frag_b;
188
+ }
189
+
190
+ template <>
191
+ __device__ inline typename ScalarType<nv_bfloat16>::FragB
192
+ dequant<nv_bfloat16, vllm::kU4B8.id()>(int q) {
193
+ static constexpr uint32_t MASK = 0x000f000f;
194
+ static constexpr uint32_t EX = 0x43004300;
195
+
196
+ // Guarantee that the `(a & b) | c` operations are LOP3s.
197
+
198
+ int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
199
+ q >>= 4;
200
+ int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
201
+
202
+ typename ScalarType<nv_bfloat16>::FragB frag_b;
203
+ static constexpr uint32_t MUL = 0x3F803F80;
204
+ static constexpr uint32_t ADD = 0xC308C308;
205
+
206
+ frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
207
+ *reinterpret_cast<const nv_bfloat162*>(&MUL),
208
+ *reinterpret_cast<const nv_bfloat162*>(&ADD));
209
+ frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
210
+ *reinterpret_cast<const nv_bfloat162*>(&MUL),
211
+ *reinterpret_cast<const nv_bfloat162*>(&ADD));
212
+ return frag_b;
213
+ }
214
+
215
+ template <>
216
+ __device__ inline typename ScalarType<half>::FragB
217
+ dequant<half, vllm::kU4.id()>(int q) {
218
+ const int LO = 0x000f000f;
219
+ const int HI = 0x00f000f0;
220
+ const int EX = 0x64006400;
221
+ // Guarantee that the `(a & b) | c` operations are LOP3s.
222
+ int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
223
+ int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
224
+
225
+ const int SUB = 0x64006400;
226
+ const int MUL = 0x2c002c00;
227
+ const int ADD = 0xd400d400;
228
+ typename ScalarType<half>::FragB frag_b;
229
+ frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
230
+ *reinterpret_cast<const half2*>(&SUB));
231
+ frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
232
+ *reinterpret_cast<const half2*>(&MUL),
233
+ *reinterpret_cast<const half2*>(&ADD));
234
+ return frag_b;
235
+ }
236
+
237
+ template <>
238
+ __device__ inline typename ScalarType<nv_bfloat16>::FragB
239
+ dequant<nv_bfloat16, vllm::kU4.id()>(int q) {
240
+ static constexpr uint32_t MASK = 0x000f000f;
241
+ static constexpr uint32_t EX = 0x43004300;
242
+
243
+ // Guarantee that the `(a & b) | c` operations are LOP3s.
244
+
245
+ int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
246
+ q >>= 4;
247
+ int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
248
+
249
+ typename ScalarType<nv_bfloat16>::FragB frag_b;
250
+ static constexpr uint32_t MUL = 0x3F803F80;
251
+ static constexpr uint32_t ADD = 0xC300C300;
252
+
253
+ frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
254
+ *reinterpret_cast<const nv_bfloat162*>(&MUL),
255
+ *reinterpret_cast<const nv_bfloat162*>(&ADD));
256
+ frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
257
+ *reinterpret_cast<const nv_bfloat162*>(&MUL),
258
+ *reinterpret_cast<const nv_bfloat162*>(&ADD));
259
+ return frag_b;
260
+ }
261
+
262
+ //
263
+ // Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
264
+ // bf16 Reference:
265
+ // - FP16:
266
+ // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
267
+ // - BF16:
268
+ // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
269
+ //
270
+ template <>
271
+ __device__ inline typename ScalarType<half>::FragB
272
+ dequant<half, vllm::kU8B128.id()>(int q) {
273
+ static constexpr uint32_t mask_for_elt_01 = 0x5250;
274
+ static constexpr uint32_t mask_for_elt_23 = 0x5351;
275
+ static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
276
+
277
+ uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
278
+ uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
279
+
280
+ static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
281
+
282
+ typename ScalarType<half>::FragB frag_b;
283
+ frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
284
+ *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
285
+ frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
286
+ *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
287
+ return frag_b;
288
+ }
289
+
290
+ template <>
291
+ __device__ inline typename ScalarType<nv_bfloat16>::FragB
292
+ dequant<nv_bfloat16, vllm::kU8B128.id()>(int q) {
293
+ typename ScalarType<nv_bfloat16>::FragB frag_b;
294
+
295
+ float fp32_intermediates[4];
296
+ uint32_t* fp32_intermediates_casted =
297
+ reinterpret_cast<uint32_t*>(fp32_intermediates);
298
+
299
+ static constexpr uint32_t fp32_base = 0x4B000000;
300
+ fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
301
+ fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
302
+ fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
303
+ fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
304
+
305
+ fp32_intermediates[0] -= 8388736.f;
306
+ fp32_intermediates[1] -= 8388736.f;
307
+ fp32_intermediates[2] -= 8388736.f;
308
+ fp32_intermediates[3] -= 8388736.f;
309
+
310
+ uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b);
311
+ bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
312
+ fp32_intermediates_casted[1], 0x7632);
313
+ bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
314
+ fp32_intermediates_casted[3], 0x7632);
315
+
316
+ return frag_b;
317
+ }
318
+
319
+ template <>
320
+ __device__ inline typename ScalarType<half>::FragB
321
+ dequant<half, vllm::kU8.id()>(int q) {
322
+ static constexpr uint32_t mask_for_elt_01 = 0x5250;
323
+ static constexpr uint32_t mask_for_elt_23 = 0x5351;
324
+ static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
325
+
326
+ uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
327
+ uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
328
+
329
+ static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
330
+
331
+ typename ScalarType<half>::FragB frag_b;
332
+ frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
333
+ *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
334
+ frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
335
+ *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
336
+ return frag_b;
337
+ }
338
+
339
+ template <>
340
+ __device__ inline typename ScalarType<nv_bfloat16>::FragB
341
+ dequant<nv_bfloat16, vllm::kU8.id()>(int q) {
342
+ typename ScalarType<nv_bfloat16>::FragB frag_b;
343
+
344
+ float fp32_intermediates[4];
345
+ uint32_t* fp32_intermediates_casted =
346
+ reinterpret_cast<uint32_t*>(fp32_intermediates);
347
+
348
+ static constexpr uint32_t fp32_base = 0x4B000000;
349
+ fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
350
+ fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
351
+ fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
352
+ fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
353
+
354
+ fp32_intermediates[0] -= 8388608.f;
355
+ fp32_intermediates[1] -= 8388608.f;
356
+ fp32_intermediates[2] -= 8388608.f;
357
+ fp32_intermediates[3] -= 8388608.f;
358
+
359
+ uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b);
360
+ bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
361
+ fp32_intermediates_casted[1], 0x7632);
362
+ bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
363
+ fp32_intermediates_casted[3], 0x7632);
364
+
365
+ return frag_b;
366
+ }
367
+
368
+ // Multiply dequantized values by the corresponding quantization scale; used
369
+ // only for grouped quantization.
370
+ template <typename scalar_t>
371
+ __device__ inline void scale(typename ScalarType<scalar_t>::FragB& frag_b,
372
+ typename ScalarType<scalar_t>::FragS& frag_s,
373
+ int i) {
374
+ using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
375
+ scalar_t2 s =
376
+ ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_s)[i]);
377
+ frag_b[0] = __hmul2(frag_b[0], s);
378
+ frag_b[1] = __hmul2(frag_b[1], s);
379
+ }
380
+
381
+ template <typename scalar_t>
382
+ __device__ inline void sub_zp(typename ScalarType<scalar_t>::FragB& frag_b,
383
+ typename ScalarType<scalar_t>::scalar_t2& frag_zp,
384
+ int i) {
385
+ using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
386
+ scalar_t2 zp =
387
+ ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_zp)[i]);
388
+ frag_b[0] = __hsub2(frag_b[0], zp);
389
+ frag_b[1] = __hsub2(frag_b[1], zp);
390
+ }
391
+
392
+ // Same as above, but for act_order (each K is multiplied individually)
393
+ template <typename scalar_t>
394
+ __device__ inline void scale4(typename ScalarType<scalar_t>::FragB& frag_b,
395
+ typename ScalarType<scalar_t>::FragS& frag_s_1,
396
+ typename ScalarType<scalar_t>::FragS& frag_s_2,
397
+ typename ScalarType<scalar_t>::FragS& frag_s_3,
398
+ typename ScalarType<scalar_t>::FragS& frag_s_4,
399
+ int i) {
400
+ using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
401
+ scalar_t2 s_val_1_2;
402
+ s_val_1_2.x = reinterpret_cast<scalar_t*>(&frag_s_1)[i];
403
+ s_val_1_2.y = reinterpret_cast<scalar_t*>(&frag_s_2)[i];
404
+
405
+ scalar_t2 s_val_3_4;
406
+ s_val_3_4.x = reinterpret_cast<scalar_t*>(&frag_s_3)[i];
407
+ s_val_3_4.y = reinterpret_cast<scalar_t*>(&frag_s_4)[i];
408
+
409
+ frag_b[0] = __hmul2(frag_b[0], s_val_1_2);
410
+ frag_b[1] = __hmul2(frag_b[1], s_val_3_4);
411
+ }
412
+
413
+ // Given 2 floats multiply by 2 scales (halves)
414
+ template <typename scalar_t>
415
+ __device__ inline void scale_float(float* c,
416
+ typename ScalarType<scalar_t>::FragS& s) {
417
+ scalar_t* s_ptr = reinterpret_cast<scalar_t*>(&s);
418
+ c[0] = __fmul_rn(c[0], ScalarType<scalar_t>::num2float(s_ptr[0]));
419
+ c[1] = __fmul_rn(c[1], ScalarType<scalar_t>::num2float(s_ptr[1]));
420
+ }
421
+
422
+ // Wait until barrier reaches `count`, then lock for current threadblock.
423
+ __device__ inline void barrier_acquire(int* lock, int count) {
424
+ if (threadIdx.x == 0) {
425
+ int state = -1;
426
+ do
427
+ // Guarantee that subsequent writes by this threadblock will be visible
428
+ // globally.
429
+ asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
430
+ : "=r"(state)
431
+ : "l"(lock));
432
+ while (state != count);
433
+ }
434
+ __syncthreads();
435
+ }
436
+
437
+ // Release barrier and increment visitation count.
438
+ __device__ inline void barrier_release(int* lock, bool reset = false) {
439
+ __syncthreads();
440
+ if (threadIdx.x == 0) {
441
+ if (reset) {
442
+ lock[0] = 0;
443
+ return;
444
+ }
445
+ int val = 1;
446
+ // Make sure that all writes since acquiring this barrier are visible
447
+ // globally, while releasing the barrier.
448
+ asm volatile("fence.acq_rel.gpu;\n");
449
+ asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
450
+ :
451
+ : "l"(lock), "r"(val));
452
+ }
453
+ }
454
+
455
+ // For a given "a" of size [M,K] performs a permutation of the K columns based
456
+ // on the given "perm" indices.
457
+ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
458
+ int const* __restrict__ perm_int_ptr,
459
+ int4* __restrict__ out_int4_ptr, int size_m,
460
+ int size_k, int block_rows) {
461
+ int start_row = block_rows * blockIdx.x;
462
+ int finish_row = start_row + block_rows;
463
+ if (finish_row > size_m) {
464
+ finish_row = size_m;
465
+ }
466
+ int cur_block_rows = finish_row - start_row;
467
+
468
+ int row_stride = size_k * sizeof(half) / 16;
469
+
470
+ auto permute_row = [&](int row) {
471
+ int iters = size_k / default_threads;
472
+ int rest = size_k % default_threads;
473
+
474
+ int offset = row * row_stride;
475
+
476
+ half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + offset);
477
+ half* out_half = reinterpret_cast<half*>(out_int4_ptr + offset);
478
+
479
+ int base_k = 0;
480
+
481
+ for (int i = 0; i < iters; i++) {
482
+ int cur_k = base_k + threadIdx.x;
483
+ int src_pos = perm_int_ptr[cur_k];
484
+
485
+ out_half[cur_k] = a_row_half[src_pos];
486
+
487
+ base_k += default_threads;
488
+ }
489
+
490
+ if (rest) {
491
+ if (threadIdx.x < rest) {
492
+ int cur_k = base_k + threadIdx.x;
493
+ int src_pos = perm_int_ptr[cur_k];
494
+
495
+ out_half[cur_k] = a_row_half[src_pos];
496
+ }
497
+ }
498
+ };
499
+
500
+ for (int i = 0; i < cur_block_rows; i++) {
501
+ int cur_row = start_row + i;
502
+ if (cur_row < size_m) {
503
+ permute_row(cur_row);
504
+ }
505
+ }
506
+ }
507
+
508
+ template <typename scalar_t, // compute dtype, half or nv_bfloat16
509
+ const vllm::ScalarTypeId w_type_id, // weight ScalarType id
510
+ const int threads, // number of threads in a threadblock
511
+ const int thread_m_blocks, // number of 16x16 blocks in the m
512
+ // dimension (batchsize) of the
513
+ // threadblock
514
+ const int thread_n_blocks, // same for n dimension (output)
515
+ const int thread_k_blocks, // same for k dimension (reduction)
516
+ const int stages, // number of stages for the async global->shared
517
+ // fetch pipeline
518
+ const bool has_act_order, // whether act_order is enabled
519
+ const bool has_zp, // whether zero-points are enabled
520
+ const int group_blocks = -1, // number of consecutive 16x16 blocks
521
+ // with a separate quantization scale
522
+ const bool is_zp_float // is zero point of float16 type?
523
+ >
524
+ __global__ void Marlin(
525
+ const int4* __restrict__ A, // fp16 input matrix of shape mxk
526
+ const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
527
+ int4* __restrict__ C, // fp16 output buffer of shape mxn
528
+ int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
529
+ const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
530
+ // (k/groupsize)xn
531
+ const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
532
+ // (k/groupsize)x(n/pack_factor)
533
+ const int* __restrict__ g_idx, // int32 group indices of shape k
534
+ int num_groups, // number of scale groups per output channel
535
+ int prob_m, // batch dimension m
536
+ int prob_n, // output dimension n
537
+ int prob_k, // reduction dimension k
538
+ int* locks, // extra global storage for barrier synchronization
539
+ bool use_fp32_reduce // whether to use fp32 global reduce
540
+ ) {
541
+ // Each threadblock processes one "stripe" of the B matrix with (roughly) the
542
+ // same size, which might involve multiple column "slices" (of width 16 *
543
+ // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
544
+ // example:
545
+ // 0 1 3
546
+ // 0 2 3
547
+ // 1 2 4
548
+ // While this kind of partitioning makes things somewhat more complicated, it
549
+ // ensures good utilization of all SMs for many kinds of shape and GPU
550
+ // configurations, while requiring as few slow global cross-threadblock
551
+ // reductions as possible.
552
+ using Dtype = ScalarType<scalar_t>;
553
+ using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
554
+ using FragA = typename ScalarType<scalar_t>::FragA;
555
+ using FragB = typename ScalarType<scalar_t>::FragB;
556
+ using FragC = typename ScalarType<scalar_t>::FragC;
557
+ using FragS = typename ScalarType<scalar_t>::FragS;
558
+ using FragZP = typename ScalarType<scalar_t>::FragZP;
559
+
560
+ static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
561
+
562
+ constexpr int pack_factor = 32 / w_type.size_bits();
563
+
564
+ // For larger GEMMs we run multiple batchsize 64 versions in parallel for a
565
+ // better partitioning with less reductions
566
+ int parallel = 1;
567
+ if (prob_m > 16 * thread_m_blocks) {
568
+ parallel = prob_m / (16 * thread_m_blocks);
569
+ prob_m = 16 * thread_m_blocks;
570
+ }
571
+
572
+ int k_tiles = prob_k / 16 / thread_k_blocks;
573
+ int n_tiles = prob_n / 16 / thread_n_blocks;
574
+ int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x);
575
+
576
+ if constexpr (!has_act_order && group_blocks != -1) {
577
+ if (group_blocks >= thread_k_blocks) {
578
+ // Ensure that the number of tiles in each stripe is a multiple of the
579
+ // groupsize; this avoids an annoying special case where a stripe starts
580
+ // in the middle of group.
581
+ iters = (group_blocks / thread_k_blocks) *
582
+ div_ceil(iters, (group_blocks / thread_k_blocks));
583
+ }
584
+ }
585
+
586
+ int slice_row = (iters * blockIdx.x) % k_tiles;
587
+ int slice_col_par = (iters * blockIdx.x) / k_tiles;
588
+ int slice_col = slice_col_par;
589
+ int slice_iters; // number of threadblock tiles in the current slice
590
+ int slice_count =
591
+ 0; // total number of active threadblocks in the current slice
592
+ int slice_idx; // index of threadblock in current slice; numbered bottom to
593
+ // top
594
+
595
+ int par_id = 0;
596
+
597
+ // We can easily implement parallel problem execution by just remapping
598
+ // indices and advancing global pointers
599
+ if (slice_col_par >= n_tiles) {
600
+ A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;
601
+ C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
602
+ locks += (slice_col_par / n_tiles) * n_tiles;
603
+ slice_col = slice_col_par % n_tiles;
604
+ par_id = slice_col_par / n_tiles;
605
+ }
606
+
607
+ // Compute all information about the current slice which is required for
608
+ // synchronization.
609
+ auto init_slice = [&]() {
610
+ slice_iters =
611
+ iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
612
+ if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
613
+ if (slice_iters == 0) return;
614
+ if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
615
+ slice_count = 1;
616
+ slice_idx = 0;
617
+ int col_first = iters * div_ceil(k_tiles * slice_col_par, iters);
618
+ if (col_first <= k_tiles * (slice_col_par + 1)) {
619
+ int col_off = col_first - k_tiles * slice_col_par;
620
+ slice_count = div_ceil(k_tiles - col_off, iters);
621
+ if (col_off > 0) slice_count++;
622
+ int delta_first = iters * blockIdx.x - col_first;
623
+ if (delta_first < 0 || (col_off == 0 && delta_first == 0))
624
+ slice_idx = slice_count - 1;
625
+ else {
626
+ slice_idx = slice_count - 1 - delta_first / iters;
627
+ if (col_off > 0) slice_idx--;
628
+ }
629
+ }
630
+ if (slice_col == n_tiles) {
631
+ A += 16 * thread_m_blocks * prob_k / 8;
632
+ C += 16 * thread_m_blocks * prob_n / 8;
633
+ locks += n_tiles;
634
+ slice_col = 0;
635
+ par_id++;
636
+ }
637
+ };
638
+ init_slice();
639
+
640
+ // A sizes/strides
641
+
642
+ // stride of the A matrix in global memory
643
+ int a_gl_stride = prob_k / 8;
644
+ // stride of an A matrix tile in shared memory
645
+ constexpr int a_sh_stride = 16 * thread_k_blocks / 8;
646
+ // delta between subsequent A tiles in global memory
647
+ constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8;
648
+ // between subsequent accesses within a tile
649
+ int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o);
650
+ // between shared memory writes
651
+ constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o);
652
+ // between shared memory tile reads
653
+ constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4));
654
+ // within a shared memory tile
655
+ constexpr int a_sh_rd_delta_i = a_sh_stride * 16;
656
+ // overall size of a tile
657
+ constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks);
658
+ // number of shared write iterations for a tile
659
+ constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta);
660
+
661
+ // B sizes/strides
662
+ int b_gl_stride = 16 * prob_n / (pack_factor * 4);
663
+ constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;
664
+ constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2;
665
+ constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;
666
+
667
+ int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
668
+ int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads);
669
+ constexpr int b_sh_wr_delta = threads * b_thread_vecs;
670
+ constexpr int b_sh_rd_delta = threads * b_thread_vecs;
671
+ constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
672
+ constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
673
+
674
+ // Scale sizes/strides without act_order
675
+ int s_gl_stride = prob_n / 8;
676
+ constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
677
+ constexpr int s_tb_groups =
678
+ !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
679
+ ? thread_k_blocks / group_blocks
680
+ : 1;
681
+ constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
682
+ int s_gl_rd_delta = s_gl_stride;
683
+
684
+ // Scale size/strides with act_order
685
+ constexpr int tb_k = 16 * thread_k_blocks;
686
+ constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0;
687
+ // constexpr int act_s_row_stride = 1;
688
+ // int act_s_col_stride = act_s_row_stride * num_groups;
689
+ int act_s_col_stride = 1;
690
+ int act_s_col_warp_stride = act_s_col_stride * 8;
691
+ int tb_n_warps = thread_n_blocks / 4;
692
+ int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;
693
+
694
+ // Zero-points sizes/strides
695
+ int zp_gl_stride = is_zp_float ? prob_n / 8 : (prob_n / pack_factor) / 4;
696
+ constexpr int zp_sh_stride = is_zp_float
697
+ ? 16 * thread_n_blocks / 8
698
+ : ((16 * thread_n_blocks) / pack_factor) / 4;
699
+ constexpr int zp_tb_groups = s_tb_groups;
700
+ constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0;
701
+ int zp_gl_rd_delta = zp_gl_stride;
702
+
703
+ // Global A read index of current thread.
704
+ int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
705
+ (threadIdx.x % a_gl_rd_delta_o);
706
+ a_gl_rd += a_gl_rd_delta_o * slice_row;
707
+ // Shared write index of current thread.
708
+ int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
709
+ (threadIdx.x % a_gl_rd_delta_o);
710
+ // Shared read index.
711
+ int a_sh_rd =
712
+ a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
713
+ a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
714
+
715
+ int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) +
716
+ (threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
717
+ b_gl_rd += b_sh_stride * slice_col;
718
+ b_gl_rd += b_gl_rd_delta_o * slice_row;
719
+ int b_sh_wr = threadIdx.x * b_thread_vecs;
720
+ int b_sh_rd = threadIdx.x * b_thread_vecs;
721
+
722
+ // For act_order
723
+ constexpr int k_iter_size = tb_k / b_sh_wr_iters;
724
+ int slice_k_start = tb_k * slice_row;
725
+ int slice_k_finish = slice_k_start + tb_k * slice_iters;
726
+ int slice_k_start_shared_fetch = slice_k_start;
727
+ int slice_n_offset = act_s_col_tb_stride * slice_col;
728
+
729
+ // No act_order
730
+ int s_gl_rd;
731
+ if constexpr (!has_act_order) {
732
+ if constexpr (group_blocks == -1) {
733
+ s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
734
+ } else {
735
+ s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
736
+ s_sh_stride * slice_col + threadIdx.x;
737
+ }
738
+ }
739
+ int s_sh_wr = threadIdx.x;
740
+ bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
741
+
742
+ // Zero-points
743
+ int zp_gl_rd;
744
+ if constexpr (has_zp) {
745
+ if constexpr (group_blocks == -1) {
746
+ zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
747
+ } else {
748
+ zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
749
+ zp_sh_stride * slice_col + threadIdx.x;
750
+ }
751
+ }
752
+ int zp_sh_wr = threadIdx.x;
753
+ bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;
754
+
755
+ // We use a different scale layout for grouped and column-wise quantization as
756
+ // we scale a `half2` tile in column-major layout in the former and in
757
+ // row-major in the latter case.
758
+ int s_sh_rd;
759
+ if constexpr (group_blocks != -1)
760
+ s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
761
+ (threadIdx.x % 32) / 4;
762
+ else
763
+ s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
764
+ (threadIdx.x % 32) % 4;
765
+
766
+ // Zero-points have the same read layout as the scales
767
+ // (without column-wise case)
768
+ constexpr int num_col_threads = 8;
769
+ constexpr int num_row_threads = 4;
770
+ constexpr int num_ints_per_thread = 8 / pack_factor;
771
+ int zp_sh_rd;
772
+ if constexpr (has_zp) {
773
+ if constexpr (is_zp_float) {
774
+ if constexpr (group_blocks != -1) {
775
+ zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
776
+ (threadIdx.x % 32) / 4;
777
+ }
778
+ } else {
779
+ zp_sh_rd = num_ints_per_thread * num_col_threads *
780
+ ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
781
+ num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads);
782
+ }
783
+ }
784
+
785
+ // Precompute which thread should not read memory in which iterations; this is
786
+ // needed if there are more threads than required for a certain tilesize or
787
+ // when the batchsize is not a multiple of 16.
788
+ bool a_sh_wr_pred[a_sh_wr_iters];
789
+ #pragma unroll
790
+ for (int i = 0; i < a_sh_wr_iters; i++)
791
+ a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
792
+
793
+ // To ensure that writing and reading A tiles to/from shared memory, the
794
+ // latter in fragment format, is fully bank conflict free, we need to use a
795
+ // rather fancy XOR-based layout. The key here is that neither reads nor
796
+ // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
797
+ // same shared memory banks. Further, it seems (based on NSight-Compute) that
798
+ // each warp must also write a consecutive memory segment?
799
+ auto transform_a = [&](int i) {
800
+ int row = i / a_gl_rd_delta_o;
801
+ return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
802
+ };
803
+ // Since the computation of this remapping is non-trivial and, due to our main
804
+ // loop unrolls, all shared memory accesses are static, we simply precompute
805
+ // both transformed reads and writes.
806
+ int a_sh_wr_trans[a_sh_wr_iters];
807
+ #pragma unroll
808
+ for (int i = 0; i < a_sh_wr_iters; i++)
809
+ a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
810
+ int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
811
+ #pragma unroll
812
+ for (int i = 0; i < b_sh_wr_iters; i++) {
813
+ #pragma unroll
814
+ for (int j = 0; j < thread_m_blocks; j++)
815
+ a_sh_rd_trans[i][j] =
816
+ transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
817
+ }
818
+
819
+ // Since B-accesses have non-constant stride they have to be computed at
820
+ // runtime; we break dependencies between subsequent accesses with a tile by
821
+ // maintaining multiple pointers (we have enough registers), a tiny
822
+ // optimization.
823
+ const int4* B_ptr[b_sh_wr_iters];
824
+ #pragma unroll
825
+ for (int i = 0; i < b_sh_wr_iters; i++)
826
+ B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
827
+
828
+ extern __shared__ int4 sh[];
829
+ // Shared memory storage for global fetch pipelines.
830
+ int4* sh_a = sh;
831
+ int4* sh_b = sh_a + (stages * a_sh_stage);
832
+ int4* sh_g_idx = sh_b + (stages * b_sh_stage);
833
+ int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
834
+ int4* sh_s = sh_zp + (stages * zp_sh_stage);
835
+
836
+ // Register storage for double buffer of shared memory reads.
837
+ FragA frag_a[2][thread_m_blocks];
838
+ I4 frag_b_quant[2][b_thread_vecs];
839
+ FragC frag_c[thread_m_blocks][4][2];
840
+ FragS frag_s[2][4]; // No act-order
841
+ FragS act_frag_s[2][4][4]; // For act-order
842
+ int frag_qzp[2][num_ints_per_thread]; // Zero-points
843
+ FragZP frag_zp; // Zero-points in fp16
844
+ FragZP frag_zpf[2]; // Zero-points in fp16 in HQQ
845
+
846
+ // Zero accumulators.
847
+ auto zero_accums = [&]() {
848
+ #pragma unroll
849
+ for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
850
+ reinterpret_cast<float*>(frag_c)[i] = 0;
851
+ };
852
+
853
+ int sh_first_group_id = -1;
854
+ int sh_num_groups = -1;
855
+ constexpr int sh_max_num_groups = 32;
856
+
857
+ auto fetch_scales_to_shared = [&](bool is_async, int first_group_id,
858
+ int last_group_id) {
859
+ sh_first_group_id = first_group_id;
860
+ sh_num_groups = last_group_id - first_group_id + 1;
861
+
862
+ if (sh_num_groups < sh_max_num_groups) {
863
+ sh_num_groups = sh_max_num_groups;
864
+ }
865
+
866
+ if (sh_first_group_id + sh_num_groups > num_groups) {
867
+ sh_num_groups = num_groups - sh_first_group_id;
868
+ }
869
+
870
+ int row_offset = first_group_id * s_gl_stride;
871
+
872
+ if (is_async) {
873
+ for (int i = 0; i < sh_num_groups; i++) {
874
+ if (threadIdx.x < s_sh_stride) {
875
+ cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x],
876
+ &scales_ptr[row_offset + (i * s_gl_stride) +
877
+ slice_n_offset + threadIdx.x]);
878
+ }
879
+ }
880
+ } else {
881
+ for (int i = 0; i < sh_num_groups; i++) {
882
+ if (threadIdx.x < s_sh_stride) {
883
+ sh_s[(i * s_sh_stride) + threadIdx.x] =
884
+ scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset +
885
+ threadIdx.x];
886
+ }
887
+ }
888
+ }
889
+ };
890
+ // Asynchronously fetch the next A, B and s tile from global to the next
891
+ // shared memory pipeline location.
892
+ auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
893
+ if (pred) {
894
+ int4* sh_a_stage = sh_a + a_sh_stage * pipe;
895
+ #pragma unroll
896
+ for (int i = 0; i < a_sh_wr_iters; i++) {
897
+ cp_async4_pred(
898
+ &sh_a_stage[a_sh_wr_trans[i]],
899
+ &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
900
+ a_sh_wr_pred[i]);
901
+ }
902
+ int4* sh_b_stage = sh_b + b_sh_stage * pipe;
903
+ #pragma unroll
904
+ for (int i = 0; i < b_sh_wr_iters; i++) {
905
+ #pragma unroll
906
+ for (int j = 0; j < b_thread_vecs; j++) {
907
+ cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);
908
+ }
909
+
910
+ B_ptr[i] += b_gl_rd_delta_o;
911
+ }
912
+
913
+ if constexpr (has_act_order) {
914
+ // Fetch g_idx thread-block portion
915
+ int full_pipe = a_off;
916
+ int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe;
917
+ if (cur_k < prob_k && cur_k < slice_k_finish) {
918
+ int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
919
+
920
+ int4 const* cur_g_idx_stage_ptr =
921
+ reinterpret_cast<int4 const*>(&g_idx[cur_k]);
922
+
923
+ if (threadIdx.x < g_idx_stage) {
924
+ cp_async4_pred(&sh_g_idx_stage[threadIdx.x],
925
+ &cur_g_idx_stage_ptr[threadIdx.x]);
926
+ }
927
+ }
928
+ } else {
929
+ if constexpr (group_blocks != -1) {
930
+ int4* sh_s_stage = sh_s + s_sh_stage * pipe;
931
+
932
+ if constexpr (group_blocks >= thread_k_blocks) {
933
+ // Only fetch scales if this tile starts a new group
934
+ if (pipe % (group_blocks / thread_k_blocks) == 0) {
935
+ if (s_sh_wr_pred) {
936
+ cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
937
+ }
938
+ s_gl_rd += s_gl_rd_delta;
939
+ }
940
+ } else {
941
+ for (int i = 0; i < s_tb_groups; i++) {
942
+ if (s_sh_wr_pred) {
943
+ cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr],
944
+ &scales_ptr[s_gl_rd]);
945
+ }
946
+ s_gl_rd += s_gl_rd_delta;
947
+ }
948
+ }
949
+ }
950
+
951
+ if constexpr (has_zp && group_blocks != -1) {
952
+ int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
953
+
954
+ if constexpr (group_blocks >= thread_k_blocks) {
955
+ // Only fetch zero-points if this tile starts a new group
956
+ if (pipe % (group_blocks / thread_k_blocks) == 0) {
957
+ if (zp_sh_wr_pred) {
958
+ cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]);
959
+ }
960
+ zp_gl_rd += zp_gl_rd_delta;
961
+ }
962
+ } else {
963
+ for (int i = 0; i < zp_tb_groups; i++) {
964
+ if (zp_sh_wr_pred) {
965
+ cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr],
966
+ &zp_ptr[zp_gl_rd]);
967
+ }
968
+ zp_gl_rd += zp_gl_rd_delta;
969
+ }
970
+ }
971
+ }
972
+ }
973
+ }
974
+ // Insert a fence even when we are winding down the pipeline to ensure that
975
+ // waiting is also correct at this point.
976
+ cp_async_fence();
977
+ };
978
+
979
+ auto fetch_zp_to_shared = [&]() {
980
+ if (zp_sh_wr_pred) {
981
+ cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]);
982
+ }
983
+ };
984
+
985
+ // Wait until the next thread tile has been loaded to shared memory.
986
+ auto wait_for_stage = [&]() {
987
+ // We only have `stages - 2` active fetches since we are double buffering
988
+ // and can only issue the next fetch when it is guaranteed that the previous
989
+ // shared memory load is fully complete (as it may otherwise be
990
+ // overwritten).
991
+ cp_async_wait<stages - 2>();
992
+ __syncthreads();
993
+ };
994
+
995
+ // Load the next sub-tile from the current location in the shared memory pipe
996
+ // into the current register buffer.
997
+ auto fetch_to_registers = [&](int k, int pipe) {
998
+ int4* sh_a_stage = sh_a + a_sh_stage * pipe;
999
+ #pragma unroll
1000
+ for (int i = 0; i < thread_m_blocks; i++)
1001
+ ldsm4<scalar_t>(frag_a[k % 2][i],
1002
+ &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
1003
+ int4* sh_b_stage = sh_b + b_sh_stage * pipe;
1004
+
1005
+ #pragma unroll
1006
+ for (int i = 0; i < b_thread_vecs; i++) {
1007
+ frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>(
1008
+ &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);
1009
+ }
1010
+ };
1011
+
1012
+ bool is_same_group[stages];
1013
+ int same_group_id[stages];
1014
+
1015
+ auto init_same_group = [&](int pipe) {
1016
+ if constexpr (!has_act_order) {
1017
+ is_same_group[pipe] = false;
1018
+ same_group_id[pipe] = 0;
1019
+ return;
1020
+ }
1021
+
1022
+ int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
1023
+ int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);
1024
+
1025
+ int group_id_1 = sh_g_idx_int_ptr[0];
1026
+ int group_id_2 = sh_g_idx_int_ptr[tb_k - 1];
1027
+
1028
+ is_same_group[pipe] = group_id_1 == group_id_2;
1029
+ same_group_id[pipe] = group_id_1;
1030
+ };
1031
+
1032
+ auto fetch_scales_to_registers = [&](int k, int full_pipe) {
1033
+ int pipe = full_pipe % stages;
1034
+
1035
+ if constexpr (!has_act_order) {
1036
+ // No act-order case
1037
+ if constexpr (group_blocks != -1) {
1038
+ if constexpr (group_blocks >= thread_k_blocks) {
1039
+ int4* sh_s_stage =
1040
+ sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
1041
+ (pipe / (group_blocks / thread_k_blocks)));
1042
+ reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
1043
+ } else {
1044
+ int warp_id = threadIdx.x / 32;
1045
+ int n_warps = thread_n_blocks / 4;
1046
+
1047
+ int warp_row = warp_id / n_warps;
1048
+
1049
+ int cur_k = warp_row * 16;
1050
+ cur_k += k_iter_size * (k % b_sh_wr_iters);
1051
+
1052
+ int k_blocks = cur_k / 16;
1053
+ int cur_group_id = k_blocks / group_blocks;
1054
+
1055
+ int4* sh_s_stage = sh_s + s_sh_stage * pipe;
1056
+
1057
+ reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
1058
+ sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
1059
+ }
1060
+ }
1061
+
1062
+ return;
1063
+ }
1064
+
1065
+ // Act-order case
1066
+
1067
+ // Determine K of the "current" thread-block
1068
+ int cur_k = slice_k_start + tb_k * full_pipe;
1069
+ if (cur_k >= prob_k || cur_k >= slice_k_finish) {
1070
+ return;
1071
+ }
1072
+
1073
+ // Reset (to current thread-block) since we read g_idx portion from the
1074
+ // shared memory
1075
+ cur_k = 0;
1076
+
1077
+ // Progress to current iteration
1078
+ cur_k += k_iter_size * (k % b_sh_wr_iters);
1079
+
1080
+ // Determine "position" inside the thread-block (based on warp and
1081
+ // thread-id)
1082
+ int warp_id = threadIdx.x / 32;
1083
+ int n_warps =
1084
+ thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N
1085
+
1086
+ int warp_row = warp_id / n_warps;
1087
+ int warp_col = warp_id % n_warps;
1088
+
1089
+ cur_k += warp_row * 16;
1090
+
1091
+ int th_id = threadIdx.x % 32;
1092
+ cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix
1093
+
1094
+ int s_col_shift =
1095
+ /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) +
1096
+ (th_id / 4) * act_s_col_stride;
1097
+
1098
+ if (is_same_group[pipe]) {
1099
+ if (k % 2 == 0) {
1100
+ *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) =
1101
+ sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride +
1102
+ s_col_shift];
1103
+ } else {
1104
+ *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) =
1105
+ *(reinterpret_cast<int4*>(&(act_frag_s[(k - 1) % 2][0][0])));
1106
+ }
1107
+
1108
+ for (int i = 1; i < 4; i++) {
1109
+ *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) =
1110
+ *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0])));
1111
+ }
1112
+ return;
1113
+ }
1114
+
1115
+ int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
1116
+ int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);
1117
+
1118
+ constexpr int k_frag_offsets[4] = {0, 1, 8,
1119
+ 9}; // Tensor core offsets per thread
1120
+
1121
+ #pragma unroll
1122
+ for (int i = 0; i < 4; i++) {
1123
+ int actual_k = cur_k + k_frag_offsets[i];
1124
+
1125
+ int group_id = sh_g_idx_int_ptr[actual_k];
1126
+ int rel_group_id = group_id - sh_first_group_id;
1127
+
1128
+ *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) =
1129
+ sh_s[rel_group_id * s_sh_stride + s_col_shift];
1130
+ }
1131
+ };
1132
+
1133
+ auto fetch_zp_to_registers = [&](int k, int full_pipe) {
1134
+ // This code does not handle group_blocks == 0,
1135
+ // which signifies act_order.
1136
+ // has_zp implies AWQ, which doesn't have act_order.
1137
+ static_assert(!has_zp || group_blocks != 0);
1138
+
1139
+ if constexpr (has_zp && !is_zp_float) {
1140
+ int pipe = full_pipe % stages;
1141
+
1142
+ if constexpr (group_blocks == -1) {
1143
+ for (int i = 0; i < num_ints_per_thread; i++) {
1144
+ frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp))[zp_sh_rd + i];
1145
+ }
1146
+
1147
+ } else if constexpr (group_blocks >= thread_k_blocks) {
1148
+ int4* sh_zp_stage =
1149
+ sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
1150
+ (pipe / (group_blocks / thread_k_blocks)));
1151
+ for (int i = 0; i < num_ints_per_thread; i++) {
1152
+ frag_qzp[k % 2][i] =
1153
+ (reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
1154
+ }
1155
+ } else {
1156
+ int warp_id = threadIdx.x / 32;
1157
+ int n_warps = thread_n_blocks / 4;
1158
+
1159
+ int warp_row = warp_id / n_warps;
1160
+
1161
+ int cur_k = warp_row * 16;
1162
+ cur_k += k_iter_size * (k % b_sh_wr_iters);
1163
+
1164
+ int k_blocks = cur_k / 16;
1165
+ int cur_group_id = 0;
1166
+
1167
+ // Suppress bogus and persistent divide-by-zero warning
1168
+ #pragma nv_diagnostic push
1169
+ #pragma nv_diag_suppress divide_by_zero
1170
+ cur_group_id = k_blocks / group_blocks;
1171
+ #pragma nv_diagnostic pop
1172
+
1173
+ int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
1174
+
1175
+ sh_zp_stage += cur_group_id * zp_sh_stride;
1176
+
1177
+ for (int i = 0; i < num_ints_per_thread; i++) {
1178
+ frag_qzp[k % 2][i] =
1179
+ (reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
1180
+ }
1181
+ }
1182
+ }
1183
+
1184
+ else if constexpr (has_zp && is_zp_float) {
1185
+ int pipe = full_pipe % stages;
1186
+
1187
+ if constexpr (group_blocks != -1) {
1188
+ if constexpr (group_blocks >= thread_k_blocks) {
1189
+ int4* sh_zp_stage =
1190
+ sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
1191
+ (pipe / (group_blocks / thread_k_blocks)));
1192
+ reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd];
1193
+ } else {
1194
+ int warp_id = threadIdx.x / 32;
1195
+ int n_warps = thread_n_blocks / 4;
1196
+
1197
+ int warp_row = warp_id / n_warps;
1198
+
1199
+ int cur_k = warp_row * 16;
1200
+ cur_k += k_iter_size * (k % b_sh_wr_iters);
1201
+
1202
+ int k_blocks = cur_k / 16;
1203
+ // Suppress bogus and persistent divide-by-zero warning
1204
+ #pragma nv_diagnostic push
1205
+ #pragma nv_diag_suppress divide_by_zero
1206
+ int cur_group_id = k_blocks / group_blocks;
1207
+ #pragma nv_diagnostic pop
1208
+
1209
+ int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
1210
+
1211
+ reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] =
1212
+ sh_zp_stage[zp_sh_rd + cur_group_id * zp_sh_stride];
1213
+ }
1214
+ }
1215
+ }
1216
+ };
1217
+
1218
+ // Execute the actual tensor core matmul of a sub-tile.
1219
+ auto matmul = [&](int k) {
1220
+ if constexpr (has_zp && !is_zp_float) {
1221
+ FragB frag_zp_0;
1222
+ FragB frag_zp_1;
1223
+ int zp_quant_0, zp_quant_1;
1224
+
1225
+ if constexpr (w_type.size_bits() == 4) {
1226
+ zp_quant_0 = frag_qzp[k % 2][0];
1227
+ zp_quant_1 = zp_quant_0 >> 8;
1228
+ } else {
1229
+ static_assert(w_type.size_bits() == 8);
1230
+ zp_quant_0 = frag_qzp[k % 2][0];
1231
+ zp_quant_1 = frag_qzp[k % 2][1];
1232
+ }
1233
+
1234
+ frag_zp_0 = dequant<scalar_t, w_type_id>(zp_quant_0);
1235
+ frag_zp_1 = dequant<scalar_t, w_type_id>(zp_quant_1);
1236
+
1237
+ frag_zp[0] = frag_zp_0[0];
1238
+ frag_zp[1] = frag_zp_0[1];
1239
+ frag_zp[2] = frag_zp_1[0];
1240
+ frag_zp[3] = frag_zp_1[1];
1241
+ }
1242
+
1243
+ // We have the m dimension as the inner loop in order to encourage overlapping
1244
+ // dequantization and matmul operations.
1245
+ #pragma unroll
1246
+ for (int j = 0; j < 4; j++) {
1247
+ FragB frag_b0;
1248
+ FragB frag_b1;
1249
+ int b_quant_0, b_quant_1;
1250
+
1251
+ if constexpr (w_type.size_bits() == 4) {
1252
+ b_quant_0 = frag_b_quant[k % 2][0][j];
1253
+ b_quant_1 = b_quant_0 >> 8;
1254
+ } else {
1255
+ static_assert(w_type.size_bits() == 8);
1256
+ int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);
1257
+ b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
1258
+ b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
1259
+ }
1260
+
1261
+ frag_b0 = dequant<scalar_t, w_type_id>(b_quant_0);
1262
+ frag_b1 = dequant<scalar_t, w_type_id>(b_quant_1);
1263
+
1264
+ // Apply zero-point to frag_b0
1265
+ if constexpr (has_zp && !is_zp_float) {
1266
+ sub_zp<scalar_t>(frag_b0, frag_zp[j], 0);
1267
+ }
1268
+
1269
+ else if constexpr (has_zp && is_zp_float && group_blocks != -1) {
1270
+ sub_zp<scalar_t>(frag_b0, frag_zpf[k % 2][j], 0);
1271
+ }
1272
+
1273
+ // Apply scale to frag_b0
1274
+ if constexpr (has_act_order) {
1275
+ scale4<scalar_t>(frag_b0, act_frag_s[k % 2][0][j],
1276
+ act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j],
1277
+ act_frag_s[k % 2][3][j], 0);
1278
+ } else {
1279
+ if constexpr (group_blocks != -1) {
1280
+ scale<scalar_t>(frag_b0, frag_s[k % 2][j], 0);
1281
+ }
1282
+ }
1283
+
1284
+ // Apply zero-point to frag_b1
1285
+ if constexpr (has_zp && !is_zp_float) {
1286
+ sub_zp<scalar_t>(frag_b1, frag_zp[j], 1);
1287
+ }
1288
+
1289
+ else if constexpr (has_zp && is_zp_float && group_blocks != -1) {
1290
+ sub_zp<scalar_t>(frag_b1, frag_zpf[k % 2][j], 1);
1291
+ }
1292
+
1293
+ // Apply scale to frag_b1
1294
+ if constexpr (has_act_order) {
1295
+ scale4<scalar_t>(frag_b1, act_frag_s[k % 2][0][j],
1296
+ act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j],
1297
+ act_frag_s[k % 2][3][j], 1);
1298
+
1299
+ } else {
1300
+ if constexpr (group_blocks != -1) {
1301
+ scale<scalar_t>(frag_b1, frag_s[k % 2][j], 1);
1302
+ }
1303
+ }
1304
+
1305
+ #pragma unroll
1306
+ for (int i = 0; i < thread_m_blocks; i++) {
1307
+ mma<scalar_t>(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
1308
+ mma<scalar_t>(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
1309
+ }
1310
+ }
1311
+ };
1312
+
1313
+ // Since we slice across the k dimension of a tile in order to increase the
1314
+ // number of warps while keeping the n dimension of a tile reasonable, we have
1315
+ // multiple warps that accumulate their partial sums of the same output
1316
+ // location, which we have to reduce over in the end; we do this in shared memory.
1317
+ auto thread_block_reduce = [&]() {
1318
+ constexpr int red_off = threads / b_sh_stride_threads / 2;
1319
+ if (red_off >= 1) {
1320
+ int red_idx = threadIdx.x / b_sh_stride_threads;
1321
+ constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
1322
+ constexpr int red_sh_delta = b_sh_stride_threads;
1323
+ int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
1324
+ (threadIdx.x % b_sh_stride_threads);
1325
+
1326
+ // Parallel logarithmic shared memory reduction. We make sure to avoid any
1327
+ // unnecessary read or write iterations, e.g., for two warps, warp 1 writes
1328
+ // only once and warp 0 reads only once.
1329
+
1330
+ #pragma unroll
1331
+ for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
1332
+ #pragma unroll
1333
+ for (int i = red_off; i > 0; i /= 2) {
1334
+ if (i <= red_idx && red_idx < 2 * i) {
1335
+ #pragma unroll
1336
+ for (int j = 0; j < 4 * 2; j++) {
1337
+ int red_sh_wr =
1338
+ red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
1339
+ if (i < red_off) {
1340
+ float* c_rd =
1341
+ reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
1342
+ float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
1343
+ #pragma unroll
1344
+ for (int k = 0; k < 4; k++)
1345
+ reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
1346
+ c_rd[k] + c_wr[k];
1347
+ }
1348
+ sh[red_sh_wr] =
1349
+ reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
1350
+ }
1351
+ }
1352
+ __syncthreads();
1353
+ }
1354
+ if (red_idx == 0) {
1355
+ #pragma unroll
1356
+ for (int i = 0; i < 4 * 2; i++) {
1357
+ float* c_rd =
1358
+ reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
1359
+ #pragma unroll
1360
+ for (int j = 0; j < 4; j++)
1361
+ reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
1362
+ c_rd[j];
1363
+ }
1364
+ }
1365
+ __syncthreads();
1366
+ }
1367
+ }
1368
+ };
1369
+
1370
+ // Since multiple threadblocks may process parts of the same column slice, we
1371
+ // finally have to globally reduce over the results. As the striped
1372
+ // partitioning minimizes the number of such reductions and our outputs are
1373
+ // usually rather small, we perform this reduction serially in L2 cache.
1374
+ auto global_reduce_fp16 = [&](bool first = false, bool last = false) {
1375
+ // We are very careful here to reduce directly in the output buffer to
1376
+ // maximize L2 cache utilization in this step. To do this, we write out
1377
+ // results in FP16 (but still reduce with FP32 compute).
1378
+ constexpr int active_threads = 32 * thread_n_blocks / 4;
1379
+ if (threadIdx.x < active_threads) {
1380
+ int c_gl_stride = prob_n / 8;
1381
+ int c_gl_wr_delta_o = 8 * c_gl_stride;
1382
+ int c_gl_wr_delta_i = 4 * (active_threads / 32);
1383
+ int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +
1384
+ 4 * (threadIdx.x / 32) + threadIdx.x % 4;
1385
+ c_gl_wr += (2 * thread_n_blocks) * slice_col;
1386
+ constexpr int c_sh_wr_delta = active_threads;
1387
+ int c_sh_wr = threadIdx.x;
1388
+
1389
+ int row = (threadIdx.x % 32) / 4;
1390
+
1391
+ if (!first) {
1392
+ // Interestingly, doing direct global accesses here really seems to mess up
1393
+ // the compiler and lead to slowdowns, hence we also use async-copies even
1394
+ // though these fetches are not actually asynchronous.
1395
+ #pragma unroll
1396
+ for (int i = 0; i < thread_m_blocks * 4; i++) {
1397
+ cp_async4_pred(
1398
+ &sh[c_sh_wr + c_sh_wr_delta * i],
1399
+ &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
1400
+ c_gl_wr_delta_i * (i % 2)],
1401
+ i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
1402
+ }
1403
+ cp_async_fence();
1404
+ cp_async_wait<0>();
1405
+ }
1406
+
1407
+ #pragma unroll
1408
+ for (int i = 0; i < thread_m_blocks * 4; i++) {
1409
+ if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
1410
+ if (!first) {
1411
+ int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
1412
+ #pragma unroll
1413
+ for (int j = 0; j < 2 * 4; j++) {
1414
+ reinterpret_cast<float*>(
1415
+ &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
1416
+ Dtype::num2float(reinterpret_cast<scalar_t*>(&c_red)[j]);
1417
+ }
1418
+ }
1419
+ if (!last) {
1420
+ int4 c;
1421
+ #pragma unroll
1422
+ for (int j = 0; j < 2 * 4; j++) {
1423
+ reinterpret_cast<scalar_t*>(&c)[j] =
1424
+ Dtype::float2num(reinterpret_cast<float*>(
1425
+ &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);
1426
+ }
1427
+ C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =
1428
+ c;
1429
+ }
1430
+ }
1431
+ }
1432
+ }
1433
+ };
1434
+
1435
+ // Globally reduce over threadblocks that compute the same column block.
1436
+ // We use a tmp C buffer to reduce in full fp32 precision.
1437
+ auto global_reduce_fp32 = [&](bool first = false, bool last = false) {
1438
+ constexpr int tb_m = thread_m_blocks * 16;
1439
+ constexpr int tb_n = thread_n_blocks * 16;
1440
+
1441
+ constexpr int c_size = tb_m * tb_n * sizeof(float) / 16;
1442
+
1443
+ constexpr int active_threads = 32 * thread_n_blocks / 4;
1444
+ bool is_th_active = threadIdx.x < active_threads;
1445
+
1446
+ int par_offset = c_size * n_tiles * par_id;
1447
+ int slice_offset = c_size * slice_col;
1448
+
1449
+ constexpr int num_floats = thread_m_blocks * 4 * 2 * 4;
1450
+ constexpr int th_size = num_floats * sizeof(float) / 16;
1451
+
1452
+ int c_cur_offset = par_offset + slice_offset;
1453
+
1454
+ if (!is_th_active) {
1455
+ return;
1456
+ }
1457
+
1458
+ if (!first) {
1459
+ float* frag_c_ptr = reinterpret_cast<float*>(&frag_c);
1460
+ #pragma unroll
1461
+ for (int k = 0; k < th_size; k++) {
1462
+ sh[threadIdx.x] =
1463
+ C_tmp[c_cur_offset + active_threads * k + threadIdx.x];
1464
+
1465
+ float* sh_c_ptr = reinterpret_cast<float*>(&sh[threadIdx.x]);
1466
+ #pragma unroll
1467
+ for (int f = 0; f < 4; f++) {
1468
+ frag_c_ptr[k * 4 + f] += sh_c_ptr[f];
1469
+ }
1470
+ }
1471
+ }
1472
+
1473
+ if (!last) {
1474
+ int4* frag_c_ptr = reinterpret_cast<int4*>(&frag_c);
1475
+ #pragma unroll
1476
+ for (int k = 0; k < th_size; k++) {
1477
+ C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k];
1478
+ }
1479
+ }
1480
+ };
1481
+
1482
+ // Write out the reduced final result in the correct layout. We only actually
1483
+ // reshuffle matrix fragments in this step; the reduction above is performed
1484
+ // in fragment layout.
1485
+ auto write_result = [&]() {
1486
+ int c_gl_stride = prob_n / 8;
1487
+ constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
1488
+ int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
1489
+ constexpr int c_sh_rd_delta =
1490
+ c_sh_stride * (threads / (2 * thread_n_blocks));
1491
+
1492
+ int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +
1493
+ (threadIdx.x % (2 * thread_n_blocks));
1494
+ c_gl_wr += (2 * thread_n_blocks) * slice_col;
1495
+ int c_sh_wr =
1496
+ (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
1497
+ c_sh_wr += 32 * (threadIdx.x / 32);
1498
+ int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +
1499
+ (threadIdx.x % (2 * thread_n_blocks));
1500
+
1501
+ int c_gl_wr_end = c_gl_stride * prob_m;
1502
+
1503
+ // We first reorder in shared memory to guarantee the most efficient final
1504
+ // global write patterns
1505
+ auto write = [&](int idx, float c0, float c1, FragS& s) {
1506
+ scalar_t2 res =
1507
+ Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));
1508
+
1509
+ // For per-column quantization we finally apply the scale here (only for
1510
+ // 4-bit)
1511
+ if constexpr (!has_act_order && group_blocks == -1 &&
1512
+ w_type.size_bits() == 4) {
1513
+ res = __hmul2(res, s[0]);
1514
+ }
1515
+
1516
+ ((scalar_t2*)sh)[idx] = res;
1517
+ };
1518
+
1519
+ if (threadIdx.x / 32 < thread_n_blocks / 4) {
1520
+ #pragma unroll
1521
+ for (int i = 0; i < thread_m_blocks; i++) {
1522
+ #pragma unroll
1523
+ for (int j = 0; j < 4; j++) {
1524
+ int wr = c_sh_wr + 8 * j;
1525
+ write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
1526
+ frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
1527
+ write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
1528
+ frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
1529
+ write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
1530
+ frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
1531
+ write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
1532
+ frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
1533
+ }
1534
+ c_sh_wr += 16 * (4 * c_sh_stride);
1535
+ }
1536
+ }
1537
+ __syncthreads();
1538
+
1539
+ #pragma unroll
1540
+ for (int i = 0;
1541
+ i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
1542
+ i++) {
1543
+ if (c_gl_wr < c_gl_wr_end) {
1544
+ C[c_gl_wr] = sh[c_sh_rd];
1545
+ c_gl_wr += c_gl_wr_delta;
1546
+ c_sh_rd += c_sh_rd_delta;
1547
+ }
1548
+ }
1549
+ };
1550
+
1551
+ // Start global fetch and register load pipelines.
1552
+ auto start_pipes = [&]() {
1553
+
1554
+ #pragma unroll
1555
+ for (int i = 0; i < stages - 1; i++) {
1556
+ if (has_act_order && i == 0) {
1557
+ int last_g_idx = slice_k_start + stages * tb_k * 2;
1558
+ if (last_g_idx >= prob_k) {
1559
+ last_g_idx = prob_k - 1;
1560
+ }
1561
+ fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]);
1562
+ }
1563
+
1564
+ if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
1565
+ if (i == 0) {
1566
+ fetch_zp_to_shared();
1567
+ }
1568
+ }
1569
+ fetch_to_shared(i, i, i < slice_iters);
1570
+ }
1571
+
1572
+ zero_accums();
1573
+ wait_for_stage();
1574
+ init_same_group(0);
1575
+ fetch_to_registers(0, 0);
1576
+ fetch_scales_to_registers(0, 0);
1577
+ fetch_zp_to_registers(0, 0);
1578
+ a_gl_rd += a_gl_rd_delta_o * (stages - 1);
1579
+ slice_k_start_shared_fetch += tb_k * (stages - 1);
1580
+ };
1581
+ if (slice_iters) {
1582
+ start_pipes();
1583
+ }
1584
+
1585
+ // Main loop.
1586
+ while (slice_iters) {
1587
+ // We unroll over both the global fetch and the register load pipeline to
1588
+ // ensure all shared memory accesses are static. Note that both pipelines
1589
+ // have even length meaning that the next iteration will always start at
1590
+ // index 0.
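+ // The outer loop advances the shared-memory stage while the inner loop
+ // toggles the register double buffer (k % 2); the fetch into the stage that
+ // is `stages - 1` slots ahead is issued one sub-tile before the stage
+ // boundary (at k == b_sh_wr_iters - 2).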
1591
+
1592
+ #pragma unroll
1593
+ for (int pipe = 0; pipe < stages;) {
1594
+ #pragma unroll
1595
+ for (int k = 0; k < b_sh_wr_iters; k++) {
1596
+ fetch_to_registers(k + 1, pipe % stages);
1597
+ fetch_scales_to_registers(k + 1, pipe);
1598
+ fetch_zp_to_registers(k + 1, pipe);
1599
+ if (k == b_sh_wr_iters - 2) {
1600
+ fetch_to_shared((pipe + stages - 1) % stages, pipe,
1601
+ slice_iters >= stages);
1602
+ pipe++;
1603
+ wait_for_stage();
1604
+ init_same_group(pipe % stages);
1605
+ }
1606
+ matmul(k);
1607
+ }
1608
+ slice_iters--;
1609
+ if (slice_iters == 0) {
1610
+ break;
1611
+ }
1612
+ }
1613
+
1614
+ a_gl_rd += a_gl_rd_delta_o * stages;
1615
+ slice_k_start += tb_k * stages;
1616
+ slice_k_start_shared_fetch += tb_k * stages;
1617
+
1618
+ if constexpr (has_act_order) {
1619
+ int first_group_id = g_idx[slice_k_start];
1620
+ int last_g_idx = slice_k_start + stages * tb_k * 2;
1621
+ if (last_g_idx >= prob_k) {
1622
+ last_g_idx = prob_k - 1;
1623
+ }
1624
+ int last_group_id = g_idx[last_g_idx];
1625
+ if (last_group_id >= sh_first_group_id + sh_num_groups) {
1626
+ fetch_scales_to_shared(false, first_group_id, last_group_id);
1627
+ __syncthreads();
1628
+ }
1629
+ }
1630
+
1631
+ // Process results and, if necessary, proceed to the next column slice.
1632
+ // While this pattern may not be the most readable, other ways of writing
1633
+ // the loop seemed to result in noticeably worse performance after compilation.
1634
+ if (slice_iters == 0) {
1635
+ cp_async_wait<0>();
1636
+ bool last = slice_idx == slice_count - 1;
1637
+ // For per-column scales, we only fetch them here in the final step before
1638
+ // write-out
1639
+ if constexpr (!has_act_order && group_blocks == -1) {
1640
+ if constexpr (w_type.size_bits() == 8) {
1641
+ if (s_sh_wr_pred) {
1642
+ cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
1643
+ }
1644
+ cp_async_fence();
1645
+ } else {
1646
+ if (last) {
1647
+ if (s_sh_wr_pred) {
1648
+ cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
1649
+ }
1650
+ cp_async_fence();
1651
+ }
1652
+ }
1653
+ }
1654
+
1655
+ thread_block_reduce();
1656
+ if constexpr (!has_act_order && group_blocks == -1) {
1657
+ if constexpr (w_type.size_bits() == 8) {
1658
+ cp_async_wait<0>();
1659
+ __syncthreads();
1660
+ if (threadIdx.x / 32 < thread_n_blocks / 4) {
1661
+ reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
1662
+ reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
1663
+ }
1664
+
1665
+ } else {
1666
+ if (last) {
1667
+ cp_async_wait<0>();
1668
+ __syncthreads();
1669
+ if (threadIdx.x / 32 < thread_n_blocks / 4) {
1670
+ reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
1671
+ reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
1672
+ }
1673
+ }
1674
+ }
1675
+ }
1676
+
1677
+ // For 8-bit channelwise, we apply the scale before the global reduction
1678
+ // that converts the fp32 results to fp16 (so that we avoid possible
1679
+ // overflow in fp16)
1680
+ if constexpr (!has_act_order && group_blocks == -1 &&
1681
+ w_type.size_bits() == 8) {
1682
+ if (threadIdx.x / 32 < thread_n_blocks / 4) {
1683
+ #pragma unroll
1684
+ for (int i = 0; i < thread_m_blocks; i++) {
1685
+ #pragma unroll
1686
+ for (int j = 0; j < 4; j++) {
1687
+ scale_float<scalar_t>(
1688
+ reinterpret_cast<float*>(&frag_c[i][j][0][0]),
1689
+ frag_s[j / 2][2 * (j % 2) + 0]);
1690
+ scale_float<scalar_t>(
1691
+ reinterpret_cast<float*>(&frag_c[i][j][0][2]),
1692
+ frag_s[j / 2][2 * (j % 2) + 0]);
1693
+
1694
+ scale_float<scalar_t>(
1695
+ reinterpret_cast<float*>(&frag_c[i][j][1][0]),
1696
+ frag_s[j / 2][2 * (j % 2) + 1]);
1697
+ scale_float<scalar_t>(
1698
+ reinterpret_cast<float*>(&frag_c[i][j][1][2]),
1699
+ frag_s[j / 2][2 * (j % 2) + 1]);
1700
+ }
1701
+ }
1702
+ }
1703
+ }
1704
+
1705
+ if (slice_count > 1) { // only globally reduce if there is more than one
1706
+ // block in a slice
1707
+ barrier_acquire(&locks[slice_col], slice_idx);
1708
+ if (use_fp32_reduce) {
1709
+ global_reduce_fp32(slice_idx == 0, last);
1710
+ } else {
1711
+ global_reduce_fp16(slice_idx == 0, last);
1712
+ }
1713
+ barrier_release(&locks[slice_col], last);
1714
+ }
1715
+ if (last) // only the last block in a slice actually writes the result
1716
+ write_result();
1717
+ slice_row = 0;
1718
+ slice_col_par++;
1719
+ slice_col++;
1720
+ init_slice();
1721
+ if (slice_iters) {
1722
+ a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
1723
+ (threadIdx.x % a_gl_rd_delta_o);
1724
+ #pragma unroll
1725
+ for (int i = 0; i < b_sh_wr_iters; i++)
1726
+ B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
1727
+ if (slice_col == 0) {
1728
+ #pragma unroll
1729
+ for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
1730
+ }
1731
+
1732
+ // Update slice k/n for scales loading
1733
+ if constexpr (has_act_order) {
1734
+ slice_k_start = tb_k * slice_row;
1735
+ slice_k_finish = slice_k_start + tb_k * slice_iters;
1736
+ slice_k_start_shared_fetch = slice_k_start;
1737
+ slice_n_offset = act_s_col_tb_stride * slice_col;
1738
+
1739
+ } else {
1740
+ s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
1741
+ zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
1742
+ }
1743
+
1744
+ start_pipes();
1745
+ }
1746
+ }
1747
+ }
1748
+ }
1749
+
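+ // Dispatch helper: each __CALL_IF expansion adds one `else if` arm that
+ // matches a compile-time kernel configuration, raises the dynamic
+ // shared-memory limit for that template instantiation, and launches it.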
1750
+ #define __CALL_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
1751
+ HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS, \
1752
+ IS_ZP_FLOAT) \
1753
+ else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
1754
+ thread_n_blocks == THREAD_N_BLOCKS && \
1755
+ thread_k_blocks == THREAD_K_BLOCKS && \
1756
+ has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
1757
+ group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
1758
+ is_zp_float == IS_ZP_FLOAT) { \
1759
+ if constexpr (!IS_ZP_FLOAT || std::is_same<scalar_t, half>::value) { \
1760
+ cudaFuncSetAttribute( \
1761
+ Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
1762
+ THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, \
1763
+ HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT>, \
1764
+ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
1765
+ Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
1766
+ THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
1767
+ HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT> \
1768
+ <<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
1769
+ A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \
1770
+ num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \
1771
+ } \
1772
+ }
1773
+
1774
+ typedef struct {
1775
+ int thread_k;
1776
+ int thread_n;
1777
+ int num_threads;
1778
+ } thread_config_t;
1779
+
1780
+ typedef struct {
1781
+ int max_m_blocks;
1782
+ thread_config_t tb_cfg;
1783
+ } exec_config_t;
1784
+
1785
+ thread_config_t small_batch_thread_configs[] = {
1786
+ // Ordered by priority
1787
+
1788
+ // thread_k, thread_n, num_threads
1789
+ {128, 128, 256},
1790
+ {64, 128, 128},
1791
+ {128, 64, 128},
1792
+ };
1793
+
1794
+ thread_config_t large_batch_thread_configs[] = {
1795
+ // Ordered by priority
1796
+
1797
+ // thread_k, thread_n, num_threads
1798
+ {64, 256, 256},
1799
+ {64, 128, 128},
1800
+ {128, 64, 128},
1801
+
1802
+ };
1803
+
1804
+ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
1805
+ int prob_n, int prob_k, int num_bits, int group_size,
1806
+ bool has_act_order, bool is_k_full) {
1807
+ bool cache_scales_chunk = has_act_order && !is_k_full;
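+ // The returned size is in bytes (the trailing factor of 2 accounts for
+ // 16-bit scales).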
1808
+
1809
+ int tb_n = th_config.thread_n;
1810
+ int tb_k = th_config.thread_k;
1811
+
1812
+ // Get max scale groups per thread-block
1813
+ int tb_groups;
1814
+ if (group_size == -1) {
1815
+ tb_groups = 1;
1816
+ } else if (group_size == 0) {
1817
+ tb_groups = div_ceil(tb_k, 32); // Worst case is a group size of 32
1818
+ } else {
1819
+ tb_groups = div_ceil(tb_k, group_size);
1820
+ }
1821
+
1822
+ if (cache_scales_chunk) {
1823
+ int load_groups =
1824
+ tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
1825
+ load_groups = max(load_groups, 32); // We load at least 32 scale groups
1826
+ return load_groups * tb_n * 2;
1827
+
1828
+ } else {
1829
+ int tb_scales = tb_groups * tb_n * 2;
1830
+
1831
+ return tb_scales * pipe_stages;
1832
+ }
1833
+ }
1834
+
1835
+ bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
1836
+ int prob_m, int prob_n, int prob_k, int num_bits,
1837
+ int scales_cache_size, int max_shared_mem) {
1838
+ int pack_factor = 32 / num_bits;
1839
+
1840
+ // Get B size
1841
+ int tb_k = th_config.thread_k;
1842
+ int tb_n = th_config.thread_n;
1843
+
1844
+ int b_size = (tb_k * tb_n / pack_factor) * 4;
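+ // (tb_k * tb_n / pack_factor) packed 32-bit words at 4 bytes each.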
1845
+
1846
+ // Get A size
1847
+ int m_blocks = div_ceil(prob_m, 16);
1848
+ int tb_max_m = 16;
1849
+
1850
+ while (true) {
1851
+ if (m_blocks >= max_m_blocks) {
1852
+ tb_max_m *= max_m_blocks;
1853
+ break;
1854
+ }
1855
+
1856
+ max_m_blocks--;
1857
+ if (max_m_blocks == 0) {
1858
+ TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks);
1859
+ }
1860
+ }
1861
+
1862
+ int a_size = (tb_max_m * tb_k) * 2;
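+ // tb_max_m * tb_k 16-bit A elements at 2 bytes each.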
1863
+
1864
+ float pipe_size = (a_size + b_size) * pipe_stages;
1865
+
1866
+ TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity
1867
+
1868
+ return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);
1869
+ }
1870
+
1871
+ bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
1872
+ int prob_m, int prob_n, int prob_k, int num_bits,
1873
+ int group_size, bool has_act_order, bool is_k_full,
1874
+ int max_shared_mem) {
1875
+ // Sanity
1876
+ if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
1877
+ th_config.num_threads == -1) {
1878
+ return false;
1879
+ }
1880
+
1881
+ // Verify K/N are divisible by thread K/N
1882
+ if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
1883
+ return false;
1884
+ }
1885
+
1886
+ // Verify min for thread K/N
1887
+ if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
1888
+ return false;
1889
+ }
1890
+
1891
+ // num_threads must be at least 128 (= 4 warps)
1892
+ if (th_config.num_threads < 128) {
1893
+ return false;
1894
+ }
1895
+
1896
+ // Determine cache for scales
1897
+ int scales_cache_size =
1898
+ get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
1899
+ group_size, has_act_order, is_k_full);
1900
+
1901
+ // Check that pipeline fits into cache
1902
+ if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k,
1903
+ num_bits, scales_cache_size, max_shared_mem)) {
1904
+ return false;
1905
+ }
1906
+
1907
+ return true;
1908
+ }
1909
+
1910
+ int determine_reduce_max_m(int prob_m, int max_par) {
1911
+ constexpr int tile_m_size = 16;
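+ // Worked example: prob_m = 200, max_par = 8 falls through to the last
+ // branch: cur_par = min(div_ceil(200, 64), 8) = 4, so the fp32 reduce
+ // buffer is sized for 16 * 4 * 4 = 256 rows.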
1912
+
1913
+ if (prob_m <= tile_m_size) {
1914
+ return tile_m_size;
1915
+
1916
+ } else if (prob_m <= tile_m_size * 2) {
1917
+ return tile_m_size * 2;
1918
+
1919
+ } else if (prob_m <= tile_m_size * 3) {
1920
+ return tile_m_size * 3;
1921
+
1922
+ } else if (prob_m <= tile_m_size * 4) {
1923
+ return tile_m_size * 4;
1924
+
1925
+ } else {
1926
+ int cur_par = min(div_ceil(prob_m, tile_m_size * 4), max_par);
1927
+ return tile_m_size * 4 * cur_par;
1928
+ }
1929
+ }
1930
+
1931
+ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
1932
+ int num_bits, int group_size,
1933
+ bool has_act_order, bool is_k_full,
1934
+ int max_shared_mem) {
1935
+ int max_m_blocks = 4;
1936
+ while (max_m_blocks > 0) {
1937
+ if (prob_m <= 16) {
1938
+ for (auto th_config : small_batch_thread_configs) {
1939
+ if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
1940
+ num_bits, group_size, has_act_order, is_k_full,
1941
+ max_shared_mem)) {
1942
+ return exec_config_t{max_m_blocks, th_config};
1943
+ }
1944
+ }
1945
+ } else {
1946
+ for (auto th_config : large_batch_thread_configs) {
1947
+ if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
1948
+ num_bits, group_size, has_act_order, is_k_full,
1949
+ max_shared_mem)) {
1950
+ return exec_config_t{max_m_blocks, th_config};
1951
+ }
1952
+ }
1953
+ }
1954
+
1955
+ max_m_blocks--; // Process fewer M blocks per invocation to reduce cache
1956
+ // usage
1957
+ }
1958
+
1959
+ return exec_config_t{0, {-1, -1, -1}};
1960
+ }
1961
+
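+ // GPTQ configurations: act-order (GROUP_BLOCKS == 0) plus non-act-order
+ // with channelwise scales (-1) and group sizes 32/64/128 (GROUP_BLOCKS
+ // 2/4/8), all without zero-points, for 1-4 thread_m_blocks.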
1962
+ #define GPTQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
1963
+ __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \
1964
+ false) \
1965
+ __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \
1966
+ false) \
1967
+ __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \
1968
+ false) \
1969
+ __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \
1970
+ false) \
1971
+ \
1972
+ __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
1973
+ false) \
1974
+ __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \
1975
+ false) \
1976
+ __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \
1977
+ false) \
1978
+ __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \
1979
+ false) \
1980
+ \
1981
+ __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
1982
+ false) \
1983
+ __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \
1984
+ false) \
1985
+ __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \
1986
+ false) \
1987
+ __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \
1988
+ false) \
1989
+ \
1990
+ __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
1991
+ false) \
1992
+ __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \
1993
+ false) \
1994
+ __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \
1995
+ false) \
1996
+ __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \
1997
+ false) \
1998
+ \
1999
+ __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
2000
+ false) \
2001
+ __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \
2002
+ false) \
2003
+ __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \
2004
+ false) \
2005
+ __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \
2006
+ false)
2007
+
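+ // AWQ configurations: zero-points enabled, no act-order; channelwise (-1)
+ // and group sizes 32/64/128, for 1-4 thread_m_blocks.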
2008
+ #define AWQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
2009
+ __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
2010
+ false) \
2011
+ __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \
2012
+ false) \
2013
+ __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
2014
+ false) \
2015
+ __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \
2016
+ false) \
2017
+ \
2018
+ __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
2019
+ false) \
2020
+ __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \
2021
+ false) \
2022
+ __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
2023
+ false) \
2024
+ __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \
2025
+ false) \
2026
+ \
2027
+ __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
2028
+ false) \
2029
+ __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \
2030
+ false) \
2031
+ __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
2032
+ false) \
2033
+ __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \
2034
+ false) \
2035
+ \
2036
+ __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
2037
+ false) \
2038
+ __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \
2039
+ false) \
2040
+ __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
2041
+ false) \
2042
+ __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, false)
2043
+
2044
+ // We currently have 4-bit models only with group_blocks == 4
2045
+ #define HQQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
2046
+ __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
2047
+ true) \
2048
+ __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
2049
+ true) \
2050
+ __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
2051
+ true) \
2052
+ __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, true)
2053
+
2054
+ template <typename scalar_t>
2055
+ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
2056
+ void* zp, void* g_idx, void* perm, void* a_tmp, int prob_m,
2057
+ int prob_n, int prob_k, void* workspace,
2058
+ vllm::ScalarType const& q_type, bool has_act_order,
2059
+ bool is_k_full, bool has_zp, int num_groups, int group_size,
2060
+ int dev, cudaStream_t stream, int thread_k, int thread_n,
2061
+ int sms, int max_par, bool use_fp32_reduce, bool is_zp_float) {
2062
+ if (has_zp) {
2063
+ TORCH_CHECK(
2064
+ q_type == vllm::kU4 || q_type == vllm::kU8,
2065
+ "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
2066
+ } else {
2067
+ TORCH_CHECK(
2068
+ q_type == vllm::kU4B8 || q_type == vllm::kU8B128,
2069
+ "q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
2070
+ q_type.str());
2071
+ }
2072
+
2073
+ TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
2074
+ ", ", prob_n, ", ", prob_k, "]");
2075
+
2076
+ // TODO: remove alias when we start supporting other 8bit types
2077
+ int num_bits = q_type.size_bits();
2078
+ int tot_m = prob_m;
2079
+ int tot_m_blocks = div_ceil(tot_m, 16);
2080
+ int pad = 16 * tot_m_blocks - tot_m;
2081
+
2082
+ if (sms == -1) {
2083
+ cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
2084
+ }
2085
+
2086
+ int max_shared_mem = 0;
2087
+ cudaDeviceGetAttribute(&max_shared_mem,
2088
+ cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
2089
+ TORCH_CHECK(max_shared_mem > 0);
2090
+
2091
+ // Set thread config
2092
+ exec_config_t exec_cfg;
2093
+ if (thread_k != -1 && thread_n != -1) {
2094
+ // User-defined config
2095
+ exec_cfg =
2096
+ exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}};
2097
+ } else {
2098
+ // Auto config
2099
+ exec_cfg =
2100
+ determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size,
2101
+ has_act_order, is_k_full, max_shared_mem);
2102
+ }
2103
+
2104
+ TORCH_CHECK(exec_cfg.max_m_blocks > 0 &&
2105
+ is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks,
2106
+ prob_m, prob_n, prob_k, num_bits, group_size,
2107
+ has_act_order, is_k_full, max_shared_mem),
2108
+ "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks,
2109
+ ", thread_k = ", exec_cfg.tb_cfg.thread_k,
2110
+ ", thread_n = ", exec_cfg.tb_cfg.thread_n,
2111
+ ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [",
2112
+ prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
2113
+ ", group_size = ", group_size,
2114
+ ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
2115
+ ", max_shared_mem = ", max_shared_mem);
2116
+
2117
+ int num_threads = exec_cfg.tb_cfg.num_threads;
2118
+ thread_k = exec_cfg.tb_cfg.thread_k;
2119
+ thread_n = exec_cfg.tb_cfg.thread_n;
2120
+
2121
+ int thread_k_blocks = thread_k / 16;
2122
+ int thread_n_blocks = thread_n / 16;
2123
+
2124
+ int blocks = sms;
2125
+
2126
+ TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
2127
+ " is not divisible by thread_n = ", thread_n);
2128
+ TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
2129
+ " is not divisible by thread_k = ", thread_k);
2130
+
2131
+ int group_blocks = 0;
2132
+ if (has_act_order) {
2133
+ if (is_k_full) {
2134
+ TORCH_CHECK(group_size != -1);
2135
+ group_blocks = group_size / 16;
2136
+ TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
2137
+ " is not divisible by group_blocks = ", group_blocks);
2138
+ } else {
2139
+ TORCH_CHECK(group_size == 0);
2140
+ group_blocks = 0;
2141
+ }
2142
+
2143
+ } else {
2144
+ if (group_size == -1) {
2145
+ group_blocks = -1;
2146
+ } else {
2147
+ group_blocks = group_size / 16;
2148
+ TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
2149
+ " is not divisible by group_blocks = ", group_blocks);
2150
+ }
2151
+ }
2152
+
2153
+ const int4* A_ptr = (const int4*)A;
2154
+ const int4* B_ptr = (const int4*)B;
2155
+ int4* C_ptr = (int4*)C;
2156
+ int4* C_tmp_ptr = (int4*)C_tmp;
2157
+ const int4* s_ptr = (const int4*)s;
2158
+ const int4* zp_ptr = (const int4*)zp;
2159
+ const int* g_idx_ptr = (const int*)g_idx;
2160
+ const int* perm_ptr = (const int*)perm;
2161
+ int4* a_tmp_ptr = (int4*)a_tmp;
2162
+
2163
+ int* locks = (int*)workspace;
2164
+
2165
+ if (has_act_order) {
2166
+ // Permute A columns
2167
+ int block_rows = div_ceil(prob_m, blocks);
2168
+ permute_cols_kernel<<<blocks, default_threads, 0, stream>>>(
2169
+ A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, block_rows);
2170
+ A_ptr = a_tmp_ptr;
2171
+ }
2172
+
2173
+ // If we have a full K, then we can run the non-act-order version of Marlin
2174
+ // (since the weight rows are reordered by increasing group ids, and by having
2175
+ // a full K, we have full original groups)
2176
+ if (is_k_full) {
2177
+ has_act_order = false;
2178
+ }
2179
+
2180
+ // Main loop
2181
+ for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) {
2182
+ int thread_m_blocks = tot_m_blocks - i;
2183
+ prob_m = tot_m - 16 * i;
2184
+ int par = 1;
2185
+ if (thread_m_blocks > exec_cfg.max_m_blocks) {
2186
+ // Note that parallel > 1 currently only works for inputs without any
2187
+ // padding
2188
+ par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks);
2189
+ if (par > max_par) par = max_par;
2190
+ prob_m = (16 * exec_cfg.max_m_blocks) * par;
2191
+ i += exec_cfg.max_m_blocks * (par - 1);
2192
+ thread_m_blocks = exec_cfg.max_m_blocks;
2193
+ }
2194
+
2195
+ if (false) {
2196
+ }
2197
+ GPTQ_CALL_IF(vllm::kU4B8, 16, 4, 256)
2198
+ GPTQ_CALL_IF(vllm::kU4B8, 8, 8, 256)
2199
+ GPTQ_CALL_IF(vllm::kU4B8, 8, 4, 128)
2200
+ GPTQ_CALL_IF(vllm::kU4B8, 4, 8, 128)
2201
+ GPTQ_CALL_IF(vllm::kU8B128, 16, 4, 256)
2202
+ GPTQ_CALL_IF(vllm::kU8B128, 8, 8, 256)
2203
+ GPTQ_CALL_IF(vllm::kU8B128, 8, 4, 128)
2204
+ GPTQ_CALL_IF(vllm::kU8B128, 4, 8, 128)
2205
+
2206
+ AWQ_CALL_IF(vllm::kU4, 16, 4, 256)
2207
+ AWQ_CALL_IF(vllm::kU4, 8, 8, 256)
2208
+ AWQ_CALL_IF(vllm::kU4, 8, 4, 128)
2209
+ AWQ_CALL_IF(vllm::kU4, 4, 8, 128)
2210
+ AWQ_CALL_IF(vllm::kU8, 16, 4, 256)
2211
+ AWQ_CALL_IF(vllm::kU8, 8, 8, 256)
2212
+ AWQ_CALL_IF(vllm::kU8, 8, 4, 128)
2213
+ AWQ_CALL_IF(vllm::kU8, 4, 8, 128)
2214
+
2215
+ HQQ_CALL_IF(vllm::kU4, 16, 4, 256)
2216
+ HQQ_CALL_IF(vllm::kU4, 8, 8, 256)
2217
+ HQQ_CALL_IF(vllm::kU4, 8, 4, 128)
2218
+ HQQ_CALL_IF(vllm::kU4, 4, 8, 128)
2219
+ else {
2220
+ TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
2221
+ ", ", prob_k, "]", ", has_act_order = ", has_act_order,
2222
+ ", num_groups = ", num_groups, ", group_size = ", group_size,
2223
+ ", thread_m_blocks = ", thread_m_blocks,
2224
+ ", thread_n_blocks = ", thread_n_blocks,
2225
+ ", thread_k_blocks = ", thread_k_blocks,
2226
+ ", num_bits = ", num_bits);
2227
+ }
2228
+
2229
+ A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
2230
+ C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
2231
+ }
2232
+ }
2233
+
2234
+ } // namespace marlin
2235
+
2236
+ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
2237
+ torch::Tensor& b_scales, torch::Tensor& b_zeros,
2238
+ torch::Tensor& g_idx, torch::Tensor& perm,
2239
+ torch::Tensor& workspace,
2240
+ vllm::ScalarTypeId const& b_q_type_id,
2241
+ int64_t size_m, int64_t size_n, int64_t size_k,
2242
+ bool is_k_full, bool has_zp,
2243
+ bool use_fp32_reduce, bool is_zp_float) {
2244
+ vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
2245
+ if (has_zp) {
2246
+ TORCH_CHECK(
2247
+ b_q_type == vllm::kU4 || b_q_type == vllm::kU8,
2248
+ "b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str());
2249
+ } else {
2250
+ TORCH_CHECK(
2251
+ b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
2252
+ "b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
2253
+ b_q_type.str());
2254
+ }
2255
+
2256
+ if (has_zp && is_zp_float) {
2257
+ TORCH_CHECK(a.scalar_type() == at::ScalarType::Half,
2258
+ "Computation type must be float16 (half) when using float zero "
2259
+ "points.");
2260
+ }
2261
+
2262
+ int pack_factor = 32 / b_q_type.size_bits();
2263
+
2264
+ // Verify A
2265
+ TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
2266
+ ", size_m = ", size_m);
2267
+ TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1),
2268
+ ", size_k = ", size_k);
2269
+
2270
+ // Verify B
2271
+ TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k,
2272
+ " is not divisible by tile_size = ", marlin::tile_size);
2273
+ TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
2274
+ "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
2275
+ ", size_k = ", size_k, ", tile_size = ", marlin::tile_size);
2276
+ TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
2277
+ "b_q_weight.size(1) = ", b_q_weight.size(1),
2278
+ " is not divisible by tile_size = ", marlin::tile_size);
2279
+ int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor;
2280
+ TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
2281
+ ", actual_size_n = ", actual_size_n);
2282
+
2283
+ // Verify device and strides
2284
+ TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
2285
+ TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
2286
+
2287
+ TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
2288
+ TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
2289
+
2290
+ TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
2291
+ TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
2292
+
2293
+ TORCH_CHECK(b_zeros.device().is_cuda(), "b_zeros is not on GPU");
2294
+ TORCH_CHECK(b_zeros.is_contiguous(), "b_zeros is not contiguous");
2295
+
2296
+ TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU");
2297
+ TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous");
2298
+
2299
+ TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU");
2300
+ TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");
2301
+
2302
+ // Alloc buffers
2303
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
2304
+ auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
2305
+ torch::Tensor c = torch::empty({size_m, size_n}, options);
2306
+ torch::Tensor a_tmp = torch::empty({size_m, size_k}, options);
2307
+
2308
+ // Alloc C tmp buffer that is going to be used for the global reduce
2309
+ int reduce_max_m = marlin::determine_reduce_max_m(size_m, marlin::max_par);
2310
+ int reduce_n = size_n;
2311
+ auto options_fp32 =
2312
+ torch::TensorOptions().dtype(at::kFloat).device(a.device());
2313
+ if (!use_fp32_reduce) {
2314
+ reduce_max_m = 0;
2315
+ reduce_n = 0;
2316
+ }
2317
+ torch::Tensor c_tmp = torch::empty({reduce_max_m, reduce_n}, options_fp32);
2318
+
2319
+ // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
2320
+ // auto -1)
2321
+ int thread_k = -1;
2322
+ // thread_n: `n` size of a thread_tile in `weights` (can usually be left as
2323
+ // auto -1)
2324
+ int thread_n = -1;
2325
+ // sms: number of SMs to use for the kernel (can usually be left as auto -1)
2326
+ int sms = -1;
2327
+
2328
+ // Verify g_idx and perm
2329
+ TORCH_CHECK((g_idx.size(0) == 0 && perm.size(0) == 0) ||
2330
+ (g_idx.size(0) == size_k && perm.size(0) == size_k),
2331
+ "Unexpected g_idx.size(0) = ", g_idx.size(0),
2332
+ " and perm.size(0) = ", perm.size(0),
2333
+ ", where size_k = ", size_k);
2334
+
2335
+ // Detect groupsize and act_order
2336
+ int num_groups = -1;
2337
+ int group_size = -1;
2338
+ bool has_act_order = g_idx.size(0) != 0;
2339
+
2340
+ int rank = b_scales.sizes().size();
2341
+ TORCH_CHECK(rank == 2, "b_scales rank = ", rank, " is not 2");
2342
+ TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1),
2343
+ " is not size_n = ", size_n);
2344
+ num_groups = b_scales.size(0);
2345
+
2346
+ if (has_act_order) {
2347
+ if (is_k_full) {
2348
+ TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
2349
+ TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
2350
+ ", is not divisible by num_groups = ", num_groups);
2351
+ group_size = size_k / num_groups;
2352
+ } else {
2353
+ group_size = 0;
2354
+ }
2355
+
2356
+ } else {
2357
+ if (num_groups > 1) {
2358
+ TORCH_CHECK(
2359
+ size_k % num_groups == 0, "size_k = ", size_k,
2360
+ ", is not divisible by b_scales.size(0) = ", b_scales.size(0));
2361
+ group_size = size_k / num_groups;
2362
+ } else {
2363
+ group_size = -1;
2364
+ }
2365
+ }
2366
+
2367
+ // Verify b_zeros
2368
+ if (has_zp) {
2369
+ int rank = b_zeros.sizes().size();
2370
+ TORCH_CHECK(rank == 2, "b_zeros rank = ", rank, " is not 2");
2371
+ if (is_zp_float) {
2372
+ TORCH_CHECK(b_zeros.size(1) == size_n,
2373
+ "b_zeros dim 1 = ", b_zeros.size(1),
2374
+ " is not size_n = ", size_n);
2375
+ TORCH_CHECK(num_groups == b_zeros.size(0),
2376
+ "b_zeros dim 0 = ", b_zeros.size(0),
2377
+ " is not num_groups = ", num_groups);
2378
+ TORCH_CHECK(num_groups != -1, "num_groups must be != -1");
2379
+ } else {
2380
+ TORCH_CHECK(b_zeros.size(0) == num_groups,
2381
+ "b_zeros dim 0 = ", b_zeros.size(0),
2382
+ " is not num_groups = ", num_groups);
2383
+ TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor,
2384
+ "b_zeros dim 1 = ", b_zeros.size(1),
2385
+ " is not size_n / pack_factor = ", size_n / pack_factor);
2386
+ }
2387
+ }
2388
+
2389
+ // Verify workspace size
2390
+ TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n,
2391
+ ", is not divisible by min_thread_n = ", marlin::min_thread_n);
2392
+ int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
2393
+ TORCH_CHECK(workspace.numel() >= min_workspace_size,
2394
+ "workspace.numel = ", workspace.numel(),
2395
+ " is below min_workspace_size = ", min_workspace_size);
2396
+
2397
+ int dev = a.get_device();
2398
+ if (a.scalar_type() == at::ScalarType::Half) {
2399
+ marlin::marlin_mm<half>(
2400
+ a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
2401
+ c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
2402
+ b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
2403
+ a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
2404
+ workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
2405
+ num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
2406
+ thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_zp_float);
2407
+ } else if (a.scalar_type() == at::ScalarType::BFloat16) {
2408
+ marlin::marlin_mm<nv_bfloat16>(
2409
+ a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
2410
+ c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
2411
+ b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
2412
+ perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
2413
+ workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
2414
+ num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
2415
+ thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_zp_float);
2416
+ } else {
2417
+ TORCH_CHECK(false, "gptq_marlin_gemm only supports bfloat16 and float16");
2418
+ }
2419
+
2420
+ return c;
2421
+ }
2422
+
2423
+ #endif
gptq_marlin/gptq_marlin_repack.cu ADDED
@@ -0,0 +1,333 @@
1
+ #include "marlin.cuh"
2
+
3
+ namespace marlin {
4
+
5
+ template <int const num_threads, int const num_bits, bool const has_perm>
6
+ __global__ void gptq_marlin_repack_kernel(
7
+ uint32_t const* __restrict__ b_q_weight_ptr,
8
+ uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
9
+ int size_k, int size_n) {
10
+ constexpr int pack_factor = 32 / num_bits;
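+ // Number of quantized values packed into one 32-bit word
+ // (8 for 4-bit, 4 for 8-bit weights).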
+
+ int k_tiles = size_k / tile_k_size;
+ int n_tiles = size_n / tile_n_size;
+ int block_k_tiles = div_ceil(k_tiles, gridDim.x);
+
+ int start_k_tile = blockIdx.x * block_k_tiles;
+ if (start_k_tile >= k_tiles) {
+ return;
+ }
+
+ int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);
+
+ // Wait until the next thread tile has been loaded to shared memory.
+ auto wait_for_stage = [&]() {
+ // We only have `stages - 2` active fetches since we are double buffering
+ // and can only issue the next fetch when it is guaranteed that the previous
+ // shared memory load is fully complete (as it may otherwise be
+ // overwritten).
+ cp_async_wait<repack_stages - 2>();
+ __syncthreads();
+ };
+
+ extern __shared__ int4 sh[];
+
+ constexpr int perm_size = tile_k_size / 4;
+
+ int4* sh_perm_ptr = sh;
+ int4* sh_pipe_ptr = sh_perm_ptr;
+ if constexpr (has_perm) {
+ sh_pipe_ptr += perm_size;
+ }
+
+ constexpr int tile_ints = tile_k_size / pack_factor;
+
+ constexpr int stage_n_threads = tile_n_size / 4;
+ constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints;
+ constexpr int stage_size = stage_k_threads * stage_n_threads;
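+ // Shared-memory staging: each int4 covers four packed uint32 columns, so one
+ // pipeline stage holds stage_k_threads x (tile_n_size / 4) int4 loads. With
+ // act_order (has_perm) every source row of the k tile is gathered separately;
+ // without it only the tile_ints already-packed rows need to be fetched.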
+
+ auto load_perm_to_shared = [&](int k_tile_id) {
+ int first_k_int4 = (k_tile_id * tile_k_size) / 4;
+
+ int4 const* perm_int4_ptr = reinterpret_cast<int4 const*>(perm_ptr);
+
+ if (threadIdx.x < perm_size) {
+ sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x];
+ }
+ __syncthreads();
+ };
+
+ auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
+ if (n_tile_id >= n_tiles) {
+ cp_async_fence();
+ return;
+ }
+
+ int first_n = n_tile_id * tile_n_size;
+
+ int4* sh_ptr = sh_pipe_ptr + stage_size * pipe;
+
+ if constexpr (has_perm) {
+ if (threadIdx.x < stage_size) {
+ int k_id = threadIdx.x / stage_n_threads;
+ int n_id = threadIdx.x % stage_n_threads;
+
+ uint32_t const* sh_perm_int_ptr =
+ reinterpret_cast<uint32_t const*>(sh_perm_ptr);
+
+ int src_k = sh_perm_int_ptr[k_id];
+ int src_k_packed = src_k / pack_factor;
+
+ cp_async4(
+ &sh_ptr[k_id * stage_n_threads + n_id],
+ reinterpret_cast<int4 const*>(&(
+ b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)])));
+ }
+
+ } else {
+ if (threadIdx.x < stage_size) {
+ int k_id = threadIdx.x / stage_n_threads;
+ int n_id = threadIdx.x % stage_n_threads;
+
+ int first_k = k_tile_id * tile_k_size;
+ int first_k_packed = first_k / pack_factor;
+
+ cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
+ reinterpret_cast<int4 const*>(
+ &(b_q_weight_ptr[(first_k_packed + k_id) * size_n +
+ first_n + (n_id * 4)])));
+ }
+ }
+
+ cp_async_fence();
+ };
+
+ auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
+ if (n_tile_id >= n_tiles) {
+ return;
+ }
+
+ int warp_id = threadIdx.x / 32;
+ int th_id = threadIdx.x % 32;
+
+ if (warp_id >= 4) {
+ return;
+ }
+
+ int tc_col = th_id / 4;
+ int tc_row = (th_id % 4) * 2;
+
+ constexpr int tc_offsets[4] = {0, 1, 8, 9};
+
+ int cur_n = warp_id * 16 + tc_col;
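+ // Only the first 4 warps repack; each covers 16 output columns. Thread th_id
+ // reads rows tc_row + {0, 1, 8, 9} at columns cur_n and cur_n + 8, i.e. the
+ // fragment layout the Marlin GEMM dequantizes into tensor-core operands.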
+
+ constexpr int sh_stride = 64;
+ constexpr uint32_t mask = (1 << num_bits) - 1;
+
+ int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;
+ uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);
+
+ uint32_t* sh_perm_int_ptr = reinterpret_cast<uint32_t*>(sh_perm_ptr);
+
+ uint32_t vals[8];
+
+ if constexpr (has_perm) {
+ for (int i = 0; i < 4; i++) {
+ int k_idx = tc_row + tc_offsets[i];
+
+ uint32_t src_k = sh_perm_int_ptr[k_idx];
+ uint32_t src_k_pos = src_k % pack_factor;
+
+ uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n];
+ uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask;
+
+ uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8];
+ uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask;
+
+ vals[i] = b1_cur_val;
+ vals[4 + i] = b2_cur_val;
+ }
+
+ } else {
+ uint32_t b1_vals[tile_ints];
+ uint32_t b2_vals[tile_ints];
+
+ #pragma unroll
+ for (int i = 0; i < tile_ints; i++) {
+ b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
+ b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
+ }
+
+ #pragma unroll
+ for (int i = 0; i < 4; i++) {
+ int cur_elem = tc_row + tc_offsets[i];
+ int cur_int = cur_elem / pack_factor;
+ int cur_pos = cur_elem % pack_factor;
+
+ vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask;
+ vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask;
+ }
+ }
+
+ constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
+ int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
+
+ // Result of:
+ // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
+ if constexpr (num_bits == 4) {
+ constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
+
+ uint32_t res = 0;
+ #pragma unroll
+ for (int i = 0; i < 8; i++) {
+ res |= vals[pack_idx[i]] << (i * 4);
+ }
+
+ out_ptr[out_offset + th_id * 4 + warp_id] = res;
+
+ } else {
+ constexpr int pack_idx[4] = {0, 2, 1, 3};
+
+ uint32_t res1 = 0;
+ uint32_t res2 = 0;
+ #pragma unroll
+ for (int i = 0; i < 4; i++) {
+ res1 |= vals[pack_idx[i]] << (i * 8);
+ res2 |= vals[4 + pack_idx[i]] << (i * 8);
+ }
+
+ out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
+ out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
+ }
+ };
+
+ auto start_pipes = [&](int k_tile_id, int n_tile_id) {
+ #pragma unroll
+ for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
+ fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
+ }
+
+ wait_for_stage();
+ };
+ #pragma unroll
+ for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
+ int n_tile_id = 0;
+
+ if constexpr (has_perm) {
+ load_perm_to_shared(k_tile_id);
+ }
+
+ start_pipes(k_tile_id, n_tile_id);
+
+ while (n_tile_id < n_tiles) {
+ #pragma unroll
+ for (int pipe = 0; pipe < repack_stages; pipe++) {
+ fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
+ n_tile_id + pipe + repack_stages - 1);
+ repack_tile(pipe, k_tile_id, n_tile_id + pipe);
+ wait_for_stage();
+ }
+ n_tile_id += repack_stages;
+ }
+ }
+ }
+
+ } // namespace marlin
+
+ #define CALL_IF(NUM_BITS, HAS_PERM) \
+ else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
+ cudaFuncSetAttribute( \
+ marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
+ HAS_PERM>, \
+ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
+ marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
+ HAS_PERM> \
+ <<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
+ b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
+ }
+
+ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
+ int64_t size_k, int64_t size_n,
+ int64_t num_bits) {
+ // Verify compatibility with marlin tile of 16x64
+ TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
+ " is not divisible by tile_k_size = ", marlin::tile_k_size);
+ TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n,
+ " is not divisible by tile_n_size = ", marlin::tile_n_size);
+
+ TORCH_CHECK(num_bits == 4 || num_bits == 8,
+ "num_bits must be 4 or 8. Got = ", num_bits);
+ int const pack_factor = 32 / num_bits;
+
+ // Verify B
+ TORCH_CHECK((size_k / pack_factor) == b_q_weight.size(0),
+ "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
+ ", size_k = ", size_k, ", pack_factor = ", pack_factor);
+ TORCH_CHECK(b_q_weight.size(1) == size_n,
+ "b_q_weight.size(1) = ", b_q_weight.size(1),
+ " is not size_n = ", size_n);
+
+ // Verify device and strides
+ TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
+ TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
+ TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");
+
+ TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU");
+ TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");
+ TORCH_CHECK(perm.dtype() == at::kInt, "perm type is not at::kInt");
+
+ // Alloc buffers
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
+ auto options = torch::TensorOptions()
+ .dtype(b_q_weight.dtype())
+ .device(b_q_weight.device());
+ torch::Tensor out = torch::empty(
+ {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
+ options);
+
+ // Detect if there is act_order
+ bool has_perm = perm.size(0) != 0;
+
+ // Get ptrs
+ uint32_t const* b_q_weight_ptr =
+ reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
+ uint32_t const* perm_ptr = reinterpret_cast<uint32_t const*>(perm.data_ptr());
+ uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());
+
+ // Get dev info
+ int dev = b_q_weight.get_device();
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
+ int blocks;
+ cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
+
+ int max_shared_mem = 0;
+ cudaDeviceGetAttribute(&max_shared_mem,
+ cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
+ TORCH_CHECK(max_shared_mem > 0);
+
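+ // `if (false)` is a dispatch trick: every CALL_IF expansion begins with
+ // `else if`, so exactly one (num_bits, has_perm) instantiation runs below.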
308
+ if (false) {
309
+ }
310
+ CALL_IF(4, false)
311
+ CALL_IF(4, true)
312
+ CALL_IF(8, false)
313
+ CALL_IF(8, true)
314
+ else {
315
+ TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits,
316
+ ", has_perm = ", has_perm);
317
+ }
318
+
319
+ return out;
320
+ }
321
+
322
+ torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
323
+ torch::Tensor& perm, c10::SymInt size_k,
324
+ c10::SymInt size_n, int64_t num_bits) {
325
+ int const pack_factor = 32 / num_bits;
326
+ auto options = torch::TensorOptions()
327
+ .dtype(b_q_weight.dtype())
328
+ .device(b_q_weight.device());
329
+ return torch::empty_symint(
330
+ {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
331
+ options);
332
+ }
333
+
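For reference, a minimal sketch of how the new repack entry point might be driven from host C++ once the extension is built. The header path follows build.toml and the empty perm tensor signals "no act_order" per the has_perm check above; the wrapper name is purely illustrative:

#include <torch/torch.h>

#include "ext-torch/torch_binding.h"  // declares gptq_marlin_repack (path assumed from build.toml)

// Repack a 4-bit GPTQ weight of shape [size_k / 8, size_n] (int32, CUDA,
// contiguous) into the Marlin tile layout. An empty int32 perm tensor makes
// has_perm false, i.e. no activation reordering.
torch::Tensor repack_gptq_weight_4bit(torch::Tensor b_q_weight, int64_t size_k,
                                      int64_t size_n) {
  auto perm = torch::empty({0}, b_q_weight.options());  // kInt, same device
  return gptq_marlin_repack(b_q_weight, perm, size_k, size_n, /*num_bits=*/4);
  // Result shape, as allocated above:
  // {size_k / tile_size, size_n * tile_size / pack_factor}
}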