danieldk (HF Staff) committed
Commit 5c6fb68 · 1 Parent(s): a77838d

Add `scaled_(int|fp8)_quant` and `fp8_marlin_gemm`

build.toml CHANGED
@@ -4,10 +4,11 @@ version = "0.0.1"
  [torch]
  name = "quantization"
  src = [
-   "ext-torch/registration.h",
+   "core/registration.h",
    "ext-torch/torch_binding.cpp",
    "ext-torch/torch_binding.h"
  ]
+ include = [ "." ]
  pysrc = [
    "ext-torch/__init__.py"
  ]
@@ -39,3 +40,32 @@ src = [
  ]
  include = [ "." ]
  depends = [ "cutlass", "torch" ]
+
+ [kernel.fp8_common]
+ capabilities = [ "7.5", "8.0", "8.6", "8.7", "8.9", "9.0", "9.0a" ]
+ src = [
+   "fp8/common.cu",
+   "fp8/common.cuh",
+   "dispatch_utils.h"
+ ]
+ include = [ "." ]
+ depends = [ "torch" ]
+
+ [kernel.fp8_marlin]
+ capabilities = [ "8.0", "8.6", "8.7", "8.9", "9.0", "9.0a" ]
+ src = [
+   "fp8/fp8_marlin.cu",
+   "gptq_marlin/marlin.cuh",
+   "gptq_marlin/marlin_dtypes.cuh",
+ ]
+ #include = [ "." ]
+ depends = [ "torch" ]
+
+ [kernel.int8_common]
+ capabilities = [ "7.5", "8.0", "8.6", "8.7", "8.9", "9.0", "9.0a" ]
+ src = [
+   "compressed_tensors/int8_quant_kernels.cu",
+   "dispatch_utils.h"
+ ]
+ include = [ "." ]
+ depends = [ "torch" ]
compressed_tensors/int8_quant_kernels.cu ADDED
@@ -0,0 +1,286 @@
1
+ #include <ATen/cuda/CUDAContext.h>
2
+ #include <torch/all.h>
3
+ #include <cmath>
4
+
5
+ #include "dispatch_utils.h"
6
+
7
+ #ifndef USE_ROCM
8
+ #include <cub/util_type.cuh>
9
+ #include <cub/cub.cuh>
10
+ #else
11
+ #include <hipcub/util_type.hpp>
12
+ #include <hipcub/hipcub.hpp>
13
+ #endif
14
+
15
+ static inline __device__ int8_t float_to_int8_rn(float x) {
16
+ #ifdef USE_ROCM
17
+ static constexpr auto i8_min =
18
+ static_cast<float>(std::numeric_limits<int8_t>::min());
19
+ static constexpr auto i8_max =
20
+ static_cast<float>(std::numeric_limits<int8_t>::max());
21
+
22
+ // To match the rounding mode of CUDA, we use nearbyint.
23
+ // It uses the current rounding mode, which is always FE_TONEAREST on HIP.
24
+ // If that changes in the future, we may need to set the rounding mode
25
+ // explicitly, either at runtime or compile time.
26
+ float dst = std::nearbyint(x);
27
+
28
+ // saturate
29
+ dst = std::clamp(dst, i8_min, i8_max);
30
+ return static_cast<int8_t>(dst);
31
+ #else
32
+ // CUDA path
33
+ uint32_t dst;
34
+ asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
35
+ return reinterpret_cast<const int8_t&>(dst);
36
+ #endif
37
+ }
38
+
39
+ static inline __device__ int32_t float_to_int32_rn(float x) {
40
+ #ifdef USE_ROCM
41
+ // int32_max is not exactly representable as float.
42
+ // Therefore, we need to be careful and manually return int32_max on overflow.
43
+ // For symmetry, we also do the same for int32_min, even though it is exactly
44
+ // representable as float and the conversion should be exact.
45
+ static constexpr auto i32_min = std::numeric_limits<int32_t>::min();
46
+ static constexpr auto i32_min_f = static_cast<float>(i32_min);
47
+ static constexpr auto i32_max = std::numeric_limits<int32_t>::max();
48
+ static constexpr auto i32_max_f = static_cast<float>(i32_max);
49
+
50
+ // To match the rounding mode of CUDA, we use nearbyint.
51
+ // It uses the current rounding mode, which is always FE_TONEAREST on HIP.
52
+ // If that changes in the future, we may need to set the rounding mode
53
+ // explicitly, either at runtime or compile time.
54
+ float dst = std::nearbyint(x);
55
+
56
+ // saturate on the higher end.
57
+ if (dst >= i32_max_f) {
58
+ return i32_max;
59
+ }
60
+ // saturate on the lower end.
61
+ if (dst <= i32_min_f) {
62
+ return i32_min;
63
+ }
64
+
65
+ return static_cast<int32_t>(dst);
66
+ #else
67
+ // CUDA path
68
+ uint32_t dst;
69
+ asm volatile("cvt.rni.sat.s32.f32 %0, %1;" : "=r"(dst) : "f"(x));
70
+ return reinterpret_cast<const int32_t&>(dst);
71
+ #endif
72
+ }
73
+
74
+ static inline __device__ int8_t int32_to_int8(int32_t x) {
75
+ #ifdef USE_ROCM
76
+ static constexpr auto i8_min =
77
+ static_cast<int32_t>(std::numeric_limits<int8_t>::min());
78
+ static constexpr auto i8_max =
79
+ static_cast<int32_t>(std::numeric_limits<int8_t>::max());
80
+
81
+ // saturate
82
+ int32_t dst = std::clamp(x, i8_min, i8_max);
83
+ return static_cast<int8_t>(dst);
84
+ #else
85
+ // CUDA path
86
+ uint32_t dst;
87
+ asm volatile("cvt.sat.s8.s32 %0, %1;" : "=r"(dst) : "r"(x));
88
+ return reinterpret_cast<const int8_t&>(dst);
89
+ #endif
90
+ }
91
+
92
+ namespace vllm {
93
+
94
+ template <typename scalar_t, typename scale_type>
95
+ __global__ void static_scaled_int8_quant_kernel(
96
+ scalar_t const* __restrict__ input, int8_t* __restrict__ out,
97
+ scale_type const* scale_ptr, const int hidden_size) {
98
+ int const tid = threadIdx.x;
99
+ int64_t const token_idx = blockIdx.x;
100
+ scale_type const scale = *scale_ptr;
101
+
102
+ // Must be performed using 64-bit math to avoid integer overflow.
103
+ out += token_idx * hidden_size;
104
+ input += token_idx * hidden_size;
105
+
106
+ for (int i = tid; i < hidden_size; i += blockDim.x) {
107
+ out[i] = float_to_int8_rn(static_cast<float>(input[i]) / scale);
108
+ }
109
+ }
110
+
111
+ template <typename scalar_t, typename scale_type, typename azp_type>
112
+ __global__ void static_scaled_int8_azp_quant_kernel(
113
+ scalar_t const* __restrict__ input, int8_t* __restrict__ out,
114
+ scale_type const* scale_ptr, azp_type const* azp_ptr,
115
+ const int hidden_size) {
116
+ int const tid = threadIdx.x;
117
+ int64_t const token_idx = blockIdx.x;
118
+ scale_type const scale = *scale_ptr;
119
+ azp_type const azp = *azp_ptr;
120
+
121
+ // Must be performed using 64-bit math to avoid integer overflow.
122
+ out += token_idx * hidden_size;
123
+ input += token_idx * hidden_size;
124
+
125
+ for (int i = tid; i < hidden_size; i += blockDim.x) {
126
+ auto const val = static_cast<float>(input[i]);
127
+ auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp);
128
+ out[i] = quant_val;
129
+ }
130
+ }
131
+
132
+ template <typename scalar_t, typename scale_type>
133
+ __global__ void dynamic_scaled_int8_quant_kernel(
134
+ scalar_t const* __restrict__ input, int8_t* __restrict__ out,
135
+ scale_type* scale, const int hidden_size) {
136
+ int const tid = threadIdx.x;
137
+ int64_t const token_idx = blockIdx.x;
138
+ float absmax_val = 0.0f;
139
+ float const zero = 0.0f;
140
+
141
+ // Must be performed using 64-bit math to avoid integer overflow.
142
+ out += token_idx * hidden_size;
143
+ input += token_idx * hidden_size;
144
+
145
+ for (int i = tid; i < hidden_size; i += blockDim.x) {
146
+ float val = static_cast<float>(input[i]);
147
+ val = val > zero ? val : -val;
148
+ absmax_val = val > absmax_val ? val : absmax_val;
149
+ }
150
+
151
+ using BlockReduce = cub::BlockReduce<float, 1024>;
152
+ __shared__ typename BlockReduce::TempStorage reduceStorage;
153
+ float const block_absmax_val_maybe =
154
+ BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
155
+ __shared__ float block_absmax_val;
156
+ if (tid == 0) {
157
+ block_absmax_val = block_absmax_val_maybe;
158
+ scale[token_idx] = block_absmax_val / 127.0f;
159
+ }
160
+ __syncthreads();
161
+
162
+ float const tmp_scale = 127.0f / block_absmax_val;
163
+ for (int i = tid; i < hidden_size; i += blockDim.x) {
164
+ out[i] = float_to_int8_rn(static_cast<float>(input[i]) * tmp_scale);
165
+ }
166
+ }
167
+
168
+ template <typename scalar_t, typename scale_type, typename azp_type>
169
+ __global__ void dynamic_scaled_int8_azp_quant_kernel(
170
+ scalar_t const* __restrict__ input, int8_t* __restrict__ out,
171
+ scale_type* scale, azp_type* azp, const int hidden_size) {
172
+ int64_t const token_idx = blockIdx.x;
173
+
174
+ // Must be performed using 64-bit math to avoid integer overflow.
175
+ out += token_idx * hidden_size;
176
+ input += token_idx * hidden_size;
177
+
178
+ // Scan for the min and max value for this token
179
+ float max_val = std::numeric_limits<float>::min();
180
+ float min_val = std::numeric_limits<float>::max();
181
+ for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
182
+ auto val = static_cast<float>(input[i]);
183
+ max_val = std::max(max_val, val);
184
+ min_val = std::min(min_val, val);
185
+ }
186
+
187
+ // Reduce the max and min values across the block
188
+ using BlockReduce = cub::BlockReduce<float, 1024>;
189
+ __shared__ typename BlockReduce::TempStorage reduceStorage;
190
+ max_val = BlockReduce(reduceStorage).Reduce(max_val, cub::Max{}, blockDim.x);
191
+ __syncthreads(); // Make sure min doesn't mess with max shared memory
192
+ min_val = BlockReduce(reduceStorage).Reduce(min_val, cub::Min{}, blockDim.x);
193
+
194
+ __shared__ scale_type scale_sh;
195
+ __shared__ azp_type azp_sh;
196
+
197
+ // Compute the scale and zero point and store them, only on the first thread
198
+ if (threadIdx.x == 0) {
199
+ float const scale_val = (max_val - min_val) / 255.0f;
200
+ // Use rounding to even (same as torch.round)
201
+ auto const azp_float = std::nearbyint(-128.0f - min_val / scale_val);
202
+ auto const azp_val = static_cast<azp_type>(azp_float);
203
+
204
+ // Store the scale and azp into shared and global
205
+ scale[token_idx] = scale_sh = scale_val;
206
+ azp[token_idx] = azp_sh = azp_val;
207
+ }
208
+
209
+ // Wait for the scale and azp to be computed
210
+ __syncthreads();
211
+
212
+ float const scale_val = scale_sh;
213
+ azp_type const azp_val = azp_sh;
214
+
215
+ // Quantize the values
216
+ for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
217
+ auto const val = static_cast<float>(input[i]);
218
+ auto const quant_val =
219
+ int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val);
220
+ out[i] = quant_val;
221
+ }
222
+ }
223
+
224
+ } // namespace vllm
225
+
226
+ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
227
+ torch::Tensor const& input, // [..., hidden_size]
228
+ torch::Tensor const& scale,
229
+ c10::optional<torch::Tensor> const& azp) {
230
+ TORCH_CHECK(input.is_contiguous());
231
+ TORCH_CHECK(out.is_contiguous());
232
+ TORCH_CHECK(scale.numel() == 1);
233
+ TORCH_CHECK(!azp || azp->numel() == 1);
234
+
235
+ int const hidden_size = input.size(-1);
236
+ int const num_tokens = input.numel() / hidden_size;
237
+ dim3 const grid(num_tokens);
238
+ dim3 const block(std::min(hidden_size, 1024));
239
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
240
+ VLLM_DISPATCH_FLOATING_TYPES(
241
+ input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
242
+ if (!azp) {
243
+ vllm::static_scaled_int8_quant_kernel<scalar_t, float>
244
+ <<<grid, block, 0, stream>>>(
245
+ input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
246
+ scale.data_ptr<float>(), hidden_size);
247
+ } else {
248
+ vllm::static_scaled_int8_azp_quant_kernel<scalar_t, float, int32_t>
249
+ <<<grid, block, 0, stream>>>(
250
+ input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
251
+ scale.data_ptr<float>(), azp->data_ptr<int32_t>(),
252
+ hidden_size);
253
+ }
254
+ });
255
+ }
256
+
257
+ void dynamic_scaled_int8_quant(
258
+ torch::Tensor& out, // [..., hidden_size]
259
+ torch::Tensor const& input, // [..., hidden_size]
260
+ torch::Tensor& scales, c10::optional<torch::Tensor> const& azp) {
261
+ TORCH_CHECK(input.is_contiguous());
262
+ TORCH_CHECK(out.is_contiguous());
263
+ TORCH_CHECK(scales.is_contiguous());
264
+ TORCH_CHECK(!azp || azp->is_contiguous());
265
+
266
+ int const hidden_size = input.size(-1);
267
+ int const num_tokens = input.numel() / hidden_size;
268
+ dim3 const grid(num_tokens);
269
+ dim3 const block(std::min(hidden_size, 1024));
270
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
271
+ VLLM_DISPATCH_FLOATING_TYPES(
272
+ input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
273
+ if (!azp) {
274
+ vllm::dynamic_scaled_int8_quant_kernel<scalar_t, float>
275
+ <<<grid, block, 0, stream>>>(
276
+ input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
277
+ scales.data_ptr<float>(), hidden_size);
278
+ } else {
279
+ vllm::dynamic_scaled_int8_azp_quant_kernel<scalar_t, float, int32_t>
280
+ <<<grid, block, 0, stream>>>(
281
+ input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
282
+ scales.data_ptr<float>(), azp->data_ptr<int32_t>(),
283
+ hidden_size);
284
+ }
285
+ });
286
+ }
{ext-torch → core}/registration.h RENAMED
File without changes
dispatch_utils.h ADDED
@@ -0,0 +1,35 @@
1
+ /*
2
+ * Adapted from
3
+ * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
4
+ */
5
+ #pragma once
6
+
7
+ #include <torch/all.h>
8
+
9
+ #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
10
+ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
11
+ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
12
+ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
13
+
14
+ #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
15
+ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
16
+
17
+ #define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
18
+ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
19
+ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
20
+ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
21
+ AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
22
+
23
+ #define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
24
+ AT_DISPATCH_SWITCH(TYPE, NAME, \
25
+ VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
26
+
27
+ #define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
28
+ AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
29
+ AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
30
+ AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
31
+ AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
32
+ AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
33
+
34
+ #define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
35
+ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
ext-torch/__init__.py CHANGED
@@ -1,4 +1,4 @@
- from typing import Optional
+ from typing import Optional, Tuple, Union

  import torch

@@ -42,3 +42,109 @@ def cutlass_scaled_mm(a: torch.Tensor,

      return out

+ # fp8
+ def scaled_fp8_quant(
+     input: torch.Tensor,
+     scale: Optional[torch.Tensor] = None,
+     num_token_padding: Optional[int] = None,
+     scale_ub: Optional[torch.Tensor] = None,
+     use_per_token_if_dynamic: bool = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """
+     Quantize input tensor to FP8 and return quantized tensor and scale.
+
+     This function supports both static and dynamic quantization: if you
+     provide the scale, it will use static scaling, and if you omit it,
+     the scale will be determined dynamically. The function also allows
+     optional padding of the output tensors for downstream kernels that
+     will benefit from padding.
+
+     Args:
+         input: The input tensor to be quantized to FP8.
+         scale: Optional scaling factor for the FP8 quantization.
+         scale_ub: Optional upper bound for the scaling factor in the
+             dynamic per-token case.
+         num_token_padding: If specified, pad the first dimension
+             of the output to at least this value.
+         use_per_token_if_dynamic: Whether to do per-tensor or per-token
+             scaling in the dynamic quantization case.
+
+     Returns:
+         Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
+             the scaling factor.
+     """
+     # This code assumes batch_dim and num_tokens are flattened.
+     assert input.ndim == 2
+     shape: Union[Tuple[int, int], torch.Size] = input.shape
+     # For ROCm, the output fp8 dtype is torch.float8_e4m3fnuz.
+     # out_dtype: torch.dtype = torch.float8_e4m3fnuz \
+     #     if current_platform.is_rocm() else torch.float8_e4m3fn
+     out_dtype = torch.float8_e4m3fn
+     if num_token_padding:
+         shape = (max(num_token_padding, input.shape[0]), shape[1])
+     output = torch.empty(shape, device=input.device, dtype=out_dtype)
+
+     if scale is None:
+         if use_per_token_if_dynamic:
+             scale = torch.empty((shape[0], 1),
+                                 device=input.device,
+                                 dtype=torch.float32)
+             ops.dynamic_per_token_scaled_fp8_quant(
+                 output, input, scale, scale_ub)
+         else:
+             scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+             ops.dynamic_scaled_fp8_quant(output, input, scale)
+     else:
+         # num_token_padding not implemented for this case
+         assert (scale.numel() == 1 or num_token_padding is None)
+         ops.static_scaled_fp8_quant(output, input, scale)
+
+     return output, scale
+
+ # int8
+ def scaled_int8_quant(
+     input: torch.Tensor,
+     scale: Optional[torch.Tensor] = None,
+     azp: Optional[torch.Tensor] = None,
+     symmetric: bool = True
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+     """
+     Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.
+
+     Args:
+         input: The input tensor to be quantized to int8.
+         scale: Optional scaling factor for the int8 quantization.
+             When not provided, we invoke dynamic-per-token quantization.
+         azp: Optional zero-point for the int8 quantization.
+             Must be provided for asymmetric quantization if `scale` is provided.
+         symmetric: Whether to use symmetric quantization (scale only, azp ignored).
+
+     Returns:
+         Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: Output int8 tensor, scales, and optionally azp.
+     """
+     output = torch.empty_like(input, dtype=torch.int8)
+     if scale is not None:
+         # static-per-tensor quantization.
+         assert symmetric == (
+             azp is
+             None), "azp must only be provided for asymmetric quantization."
+         ops.static_scaled_int8_quant(output, input, scale, azp)
+         return output, scale, azp
+
+     # dynamic-per-token quantization.
+     input_scales = torch.empty((input.numel() // input.shape[-1], 1),
+                                device=input.device,
+                                dtype=torch.float32)
+     input_azp = None if symmetric else torch.empty_like(input_scales,
+                                                         dtype=torch.int32)
+     ops.dynamic_scaled_int8_quant(output, input, input_scales,
+                                   input_azp)
+     return output, input_scales, input_azp
+
+ # fp8 marlin
+ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
+                     b_scales: torch.Tensor, workspace: torch.Tensor,
+                     num_bits: int, size_m: int, size_n: int,
+                     size_k: int) -> torch.Tensor:
+     return ops.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace,
+                                num_bits, size_m, size_n, size_k)
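
A minimal usage sketch of the Python entry points added above. It is not part of the commit; it assumes the built kernel package is importable as `quantization` (the `[torch]` name from build.toml) and that a CUDA device is available.

import torch
import quantization  # assumed import name for the built extension package

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")

# Dynamic per-tensor FP8 quantization: the kernel computes the scale.
x_fp8, fp8_scale = quantization.scaled_fp8_quant(x)

# Static FP8 quantization with a caller-provided per-tensor scale.
static_scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
x_fp8_static, _ = quantization.scaled_fp8_quant(x, scale=static_scale)

# Dynamic per-token symmetric int8 quantization: one scale per token row.
x_i8, i8_scales, azp = quantization.scaled_int8_quant(x)  # azp is None here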
ext-torch/torch_binding.cpp CHANGED
@@ -1,10 +1,9 @@
  #include <torch/library.h>

- #include "registration.h"
+ #include "core/registration.h"
  #include "torch_binding.h"

  TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
-
    // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
    // quantization, as well as bias
    ops.def(
@@ -27,6 +26,46 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
    ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
    ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);

+   // Compute FP8 quantized tensor for given scaling factor.
+   ops.def(
+       "static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "
+       "()");
+   ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
+
+   // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
+   ops.def(
+       "dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
+       "-> "
+       "()");
+   ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
+
+   // Compute dynamic-per-token FP8 quantized tensor and scaling factor.
+   ops.def(
+       "dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
+       "Tensor! scale, Tensor? scale_ub) -> "
+       "()");
+   ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
+            &dynamic_per_token_scaled_fp8_quant);
+
+   // Compute int8 quantized tensor for given scaling factor.
+   ops.def(
+       "static_scaled_int8_quant(Tensor! result, Tensor input, Tensor scale,"
+       "Tensor? azp) -> ()");
+   ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant);
+
+   // Compute int8 quantized tensor and scaling factor
+   ops.def(
+       "dynamic_scaled_int8_quant(Tensor! result, Tensor input, Tensor! scale, "
+       "Tensor!? azp) -> ()");
+   ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
+            &dynamic_scaled_int8_quant);
+
+   // fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
+   ops.def(
+       "fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
+       "Tensor! workspace, int num_bits, SymInt size_m, SymInt size_n, "
+       "SymInt size_k) -> Tensor");
+   ops.impl("fp8_marlin_gemm", &fp8_marlin_gemm);
  }

  REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
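
The `Tensor!` markers in the schemas above denote in-place outputs: the caller allocates `result` (and, for the dynamic variants, `scale`) and the kernel fills them. A hedged sketch of that calling convention through `torch.ops`, assuming the extension namespace expands to `_quantization` (the real name depends on `TORCH_EXTENSION_NAME` at build time):

import torch

ops = torch.ops._quantization  # hypothetical namespace; set by TORCH_EXTENSION_NAME

x = torch.randn(8, 1024, dtype=torch.float16, device="cuda")
out = torch.empty_like(x, dtype=torch.int8)
scales = torch.empty((x.shape[0], 1), dtype=torch.float32, device="cuda")

# dynamic_scaled_int8_quant writes `out` and `scales` in place;
# azp=None selects the symmetric path.
ops.dynamic_scaled_int8_quant(out, x, scales, None)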
ext-torch/torch_binding.h CHANGED
@@ -2,17 +2,47 @@

  #include <torch/torch.h>

- bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
-
- void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
-                        torch::Tensor const& b, torch::Tensor const& a_scales,
-                        torch::Tensor const& b_scales,
-                        c10::optional<torch::Tensor> const& bias);
-
- void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
-                            torch::Tensor const& b,
-                            torch::Tensor const& a_scales,
-                            torch::Tensor const& b_scales,
-                            torch::Tensor const& azp_adj,
-                            c10::optional<torch::Tensor> const& azp,
+ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
+
+ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
+                        torch::Tensor const& b, torch::Tensor const& a_scales,
+                        torch::Tensor const& b_scales,
+                        c10::optional<torch::Tensor> const& bias);
+
+ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
+                            torch::Tensor const& b,
+                            torch::Tensor const& a_scales,
+                            torch::Tensor const& b_scales,
+                            torch::Tensor const& azp_adj,
+                            c10::optional<torch::Tensor> const& azp,
                             c10::optional<torch::Tensor> const& bias);
+
+ void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
+                               torch::Tensor const& scale,
+                               c10::optional<torch::Tensor> const& azp);
+
+ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
+                                torch::Tensor& scales,
+                                c10::optional<torch::Tensor> const& azp);
+
+ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
+                         torch::Tensor b_gptq_qzeros,
+                         torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
+                         bool use_exllama, int64_t bit);
+
+ void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
+
+ void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
+                              torch::Tensor const& scale);
+
+ void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
+                               torch::Tensor& scale);
+
+ void dynamic_per_token_scaled_fp8_quant(
+     torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
+     c10::optional<torch::Tensor> const& scale_ub);
+
+ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
+                               torch::Tensor& b_scales, torch::Tensor& workspace,
+                               int64_t num_bits, int64_t size_m, int64_t size_n,
+                               int64_t size_k);
fp8/amd/hip_float8.h ADDED
@@ -0,0 +1,137 @@
1
+ #pragma once
2
+
3
+ #ifdef __HIPCC__
4
+ #include <hip/hip_runtime.h>
5
+ #else
6
+ #include <type_traits>
7
+ #include <stdint.h>
8
+ #include <math.h>
9
+ #include <iostream>
10
+ #endif
11
+
12
+ #include "hip_float8_impl.h"
13
+
14
+ struct alignas(1) hip_fp8 {
15
+ struct from_bits_t {};
16
+ HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() {
17
+ return from_bits_t();
18
+ }
19
+ uint8_t data;
20
+
21
+ hip_fp8() = default;
22
+ HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default;
23
+ HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete;
24
+ explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t)
25
+ : data(v) {}
26
+
27
+ #ifdef __HIP__MI300__
28
+ // NOTE: ON-DEVICE... always optimal bias
29
+ explicit HIP_FP8_DEVICE hip_fp8(float v)
30
+ : data(hip_fp8_impl::to_fp8_from_fp32(v)) {}
31
+
32
+ explicit HIP_FP8_DEVICE hip_fp8(_Float16 v)
33
+ : hip_fp8(static_cast<float>(v)) {}
34
+
35
+ // Host only implementation using s/w simulation
36
+ explicit HIP_FP8_HOST
37
+ #else // __HIP__MI300__
38
+ // both Host and DEVICE for non-MI300 using s/w simulation
39
+ explicit HIP_FP8_HOST_DEVICE
40
+ #endif // __HIP__MI300__
41
+ hip_fp8(float v) {
42
+ data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/,
43
+ true /*clip*/>(v);
44
+ }
45
+
46
+ explicit HIP_FP8_HOST_DEVICE hip_fp8(double v)
47
+ : hip_fp8(static_cast<float>(v)) {}
48
+
49
+ #ifdef __HIP__MI300__
50
+ // upcast using device specific intrinsic
51
+ explicit inline HIP_FP8_DEVICE operator float() const {
52
+ float fval;
53
+ uint32_t i32val = static_cast<uint32_t>(data);
54
+
55
+ // upcast
56
+ asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0"
57
+ : "=v"(fval)
58
+ : "v"(i32val));
59
+
60
+ return fval;
61
+ }
62
+
63
+ explicit inline HIP_FP8_HOST operator float() const
64
+ #else // __HIP__MI300__
65
+ explicit inline HIP_FP8_HOST_DEVICE operator float() const
66
+ #endif // __HIP__MI300__
67
+ {
68
+ return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(
69
+ data);
70
+ }
71
+ };
72
+
73
+ namespace std {
74
+ inline hip_fp8 sin(hip_fp8 a) { return hip_fp8(sinf(float(a))); }
75
+ inline hip_fp8 cos(hip_fp8 a) { return hip_fp8(cosf(float(a))); }
76
+ HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a) { return a; }
77
+ } // namespace std
78
+
79
+ // Special operator overloading
80
+ inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) {
81
+ return os << float(f8);
82
+ }
83
+
84
+ // all + operator overloading with mixed types
85
+ // mixed types, always converts to f32, does computation in f32, and returns
86
+ // float
87
+ inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) {
88
+ return (fa + float(b));
89
+ }
90
+
91
+ inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) {
92
+ return (float(a) + fb);
93
+ }
94
+
95
+ inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) {
96
+ return hip_fp8(float(a) + float(b));
97
+ }
98
+
99
+ inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) {
100
+ return a = hip_fp8(float(a) + float(b));
101
+ }
102
+
103
+ // overloading multiplication, always returns float,
104
+ inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) {
105
+ return float(a) * float(b);
106
+ }
107
+
108
+ inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) {
109
+ return (a * float(b));
110
+ }
111
+
112
+ inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) {
113
+ return (float(a) * b);
114
+ }
115
+
116
+ inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) {
117
+ return ((float)a * float(b));
118
+ }
119
+
120
+ inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) {
121
+ return ((float)a * float(b));
122
+ }
123
+
124
+ // overloading for compare
125
+ inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) {
126
+ return (a.data == b.data);
127
+ }
128
+ inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) {
129
+ return (a.data != b.data);
130
+ }
131
+
132
+ inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) {
133
+ return static_cast<float>(a) >= static_cast<float>(b);
134
+ }
135
+ inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) {
136
+ return static_cast<float>(a) > static_cast<float>(b);
137
+ }
fp8/amd/hip_float8_impl.h ADDED
@@ -0,0 +1,316 @@
1
+ #pragma once
2
+
3
+ #if defined(__HIPCC__) && \
4
+ (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
5
+ #define __HIP__MI300__
6
+ #endif
7
+
8
+ #ifdef __HIPCC__
9
+ #define HIP_FP8_HOST_DEVICE __host__ __device__
10
+ #define HIP_FP8_HOST __host__
11
+ #define HIP_FP8_DEVICE __device__
12
+ #else
13
+ #define HIP_FP8_HOST_DEVICE
14
+ #define HIP_FP8_HOST
15
+ #define HIP_FP8_DEVICE
16
+ #endif
17
+
18
+ namespace hip_fp8_impl {
19
+
20
+ #ifdef __HIP__MI300__
21
+ HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v) {
22
+ uint8_t i8data;
23
+ union {
24
+ float fval;
25
+ uint32_t i32val;
26
+ uint8_t i8val[4]; // NOTE: not endian independent
27
+ } val;
28
+
29
+ uint32_t ival = 0;
30
+ val.fval = v;
31
+
32
+ if ((val.i32val & 0x7F800000) !=
33
+ 0x7F800000) { /// propagate NAN/INF, no clipping
34
+ val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
35
+ }
36
+
37
+ ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival,
38
+ false); // false -> WORD0
39
+ val.i32val = ival;
40
+ i8data = val.i8val[0];
41
+
42
+ return i8data;
43
+ }
44
+ #endif // __HIP__MI300__
45
+
46
+ HIP_FP8_HOST inline int clz(uint32_t x) { return __builtin_clz(x); }
47
+ #if defined(__HIPCC__) || defined(__CUDA_ARCH__)
48
+ HIP_FP8_DEVICE inline int clz(uint32_t x) { return __clz(x); }
49
+ #endif
50
+
51
+ template <int we, int wm, typename T, bool negative_zero_nan, bool clip>
52
+ HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false,
53
+ uint32_t rng = 0) {
54
+ #ifdef __HIPCC__
55
+ constexpr bool is_half = std::is_same<T, _Float16>::value;
56
+ #else
57
+ constexpr bool is_half = false;
58
+ #endif
59
+ constexpr bool is_float = std::is_same<T, float>::value;
60
+ static_assert(wm + we == 7, "wm+we==7");
61
+ static_assert(is_half || is_float, "Only half and float can be cast to f8");
62
+
63
+ const int mfmt = (sizeof(T) == 4) ? 23 : 10;
64
+ uint32_t x;
65
+ if (sizeof(T) == 4) {
66
+ x = reinterpret_cast<uint32_t&>(_x);
67
+ } else {
68
+ x = reinterpret_cast<uint16_t&>(_x);
69
+ }
70
+
71
+ uint32_t head, mantissa;
72
+ int exponent, bias;
73
+ uint32_t sign;
74
+
75
+ if (sizeof(T) == 4) {
76
+ head = x & 0xFF800000;
77
+ mantissa = x & 0x7FFFFF;
78
+ exponent = (head >> 23) & 0xFF;
79
+ sign = head >> 31;
80
+ bias = 127;
81
+ } else {
82
+ head = x & 0xFC00;
83
+ mantissa = x & 0x3FF;
84
+ exponent = (head >> 10) & 0x1F;
85
+ sign = head >> 15;
86
+ bias = 15;
87
+ }
88
+
89
+ uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm);
90
+
91
+ // Deal with inf and NaNs
92
+ if (negative_zero_nan) {
93
+ if (sizeof(T) == 4) {
94
+ if ((x & 0x7F800000) == 0x7F800000) {
95
+ return 0x80;
96
+ }
97
+ } else {
98
+ // if(__hisinf(x) || __hisnan(x))
99
+ if ((x & 0x7C00) == 0x7C00) {
100
+ return 0x80;
101
+ }
102
+ }
103
+ } else {
104
+ if (sizeof(T) == 4) {
105
+ if ((x & 0x7F800000) == 0x7F800000) {
106
+ return signed_inf + (mantissa != 0 ? 1 : 0);
107
+ }
108
+ } else {
109
+ if ((x & 0x7C00) == 0x7C00) {
110
+ return signed_inf + (mantissa != 0 ? 1 : 0);
111
+ }
112
+ }
113
+ }
114
+ if (x == 0) {
115
+ return 0;
116
+ }
117
+
118
+ // First need to check if it is normal or denorm as there is a difference of
119
+ // implicit 1 Then need to adjust the exponent to align with the F8 exponent,
120
+ // in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
121
+ // to mantissa and truncate. And for RNE, no need to add rng. Then probably
122
+ // need to check whether there is carry and adjust exponent and mantissa again
123
+
124
+ // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
125
+ // bits
126
+ const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0);
127
+ const int f8_denormal_act_exponent =
128
+ 1 - f8_bias; // actual exponent of f8 denormal
129
+ // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
130
+ // f8_exponent is the converted f8 exponent with bias encoding
131
+ // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
132
+ // the difference needs to be adjusted and mantissa shifted
133
+ int act_exponent, f8_exponent, exponent_diff;
134
+
135
+ if (exponent == 0) { // fp32/fp16 is in denormal.
136
+ /* fp32 denormal is below 2^-127 so it is usually not a concern here, we
137
+ mostly concern fp16 here. In this case, f8 is usually in denormal. But there
138
+ could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has
139
+ exponent bias 16. It means that there are some numbers in fp16 denormal but they
140
+ are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
141
+ where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
142
+ (NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
143
+ act_exponent = exponent - bias + 1;
144
+ exponent_diff =
145
+ f8_denormal_act_exponent -
146
+ act_exponent; // actual exponent is exponent-bias+1 as it is denormal
147
+ } else { // fp32/fp16 is normal with implicit 1
148
+ act_exponent = exponent - bias;
149
+ if (act_exponent <= f8_denormal_act_exponent) {
150
+ /* This is the case where fp32/fp16 is normal but it is in f8 denormal
151
+ range. For example fp8 nanoo mode, denormal exponent is -7, but if the
152
+ fp32/fp16 actual exponent is -7, it is actually larger due to the implicit 1,
153
+ Therefore it needs to be adjust to -6 and mantissa shift right by 1.
154
+ So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
155
+ exponent_diff = f8_denormal_act_exponent - act_exponent;
156
+ } else { // both fp32/fp16 and f8 are in normal range
157
+ exponent_diff = 0; // exponent_diff=0 does not mean there is no
158
+ // difference for this case, act_exponent could be
159
+ // larger. Just that it does not need shift mantissa
160
+ }
161
+ mantissa += (1 << mfmt); // Add the implicit 1 into mantissa
162
+ }
163
+
164
+ bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) ==
165
+ static_cast<uint32_t>(1 << (mfmt - wm + exponent_diff - 1));
166
+ /* This part is a bit tricky. The judgment of whether it is a tie needs to be
167
+ done before we shift right as shift right could rip off some residual part
168
+ and make something not midpoint look like midpoint. For example, the fp16
169
+ number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after
170
+ shift right by 4 bits, it would look like midpoint.
171
+ */
172
+
173
+ if (exponent_diff > 0) {
174
+ mantissa >>= exponent_diff;
175
+ } else if (exponent_diff == -1) {
176
+ mantissa <<= -exponent_diff;
177
+ }
178
+ bool implicit_one = mantissa & (1 << mfmt);
179
+ // if there is no implicit 1, it means the f8 is denormal and need to adjust
180
+ // to denorm exponent
181
+ f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ +
182
+ f8_bias - (implicit_one ? 0 : 1);
183
+
184
+ // Now we have the exponent and mantissa adjusted
185
+ uint32_t drop_mask = (1 << (mfmt - wm)) - 1;
186
+ bool odd = mantissa & (1 << (mfmt - wm)); // if the least significant bit
187
+ // that is not truncated is 1
188
+ mantissa +=
189
+ (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) &
190
+ drop_mask;
191
+
192
+ // Now we deal with overflow
193
+ if (f8_exponent == 0) {
194
+ if ((1 << mfmt) & mantissa) {
195
+ f8_exponent = 1; // denormal overflow to become normal, promote exponent
196
+ }
197
+ } else {
198
+ if ((1 << (mfmt + 1)) & mantissa) {
199
+ mantissa >>= 1;
200
+ f8_exponent++;
201
+ }
202
+ }
203
+
204
+ mantissa >>= (mfmt - wm);
205
+
206
+ // above range: quantize to maximum possible float of the same sign
207
+ const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2);
208
+ if (f8_exponent > max_exp) {
209
+ if (clip) {
210
+ mantissa = (1 << wm) - 1;
211
+ f8_exponent = max_exp;
212
+ } else {
213
+ return signed_inf;
214
+ }
215
+ }
216
+
217
+ if (f8_exponent == 0 && mantissa == 0) {
218
+ return negative_zero_nan ? 0 : (sign << 7);
219
+ }
220
+ mantissa &= (1 << wm) - 1;
221
+ return (sign << 7) | (f8_exponent << wm) | mantissa;
222
+ }
223
+
224
+ template <int we, int wm, typename T = float, bool negative_zero_nan = true>
225
+ inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x) {
226
+ #ifdef __HIPCC__
227
+ constexpr bool is_half = std::is_same<T, _Float16>::value;
228
+ #else
229
+ constexpr bool is_half = false;
230
+ #endif
231
+ constexpr bool is_float = std::is_same<T, float>::value;
232
+ static_assert(is_half || is_float, "only half and float are supported");
233
+
234
+ constexpr int weo = is_half ? 5 : 8;
235
+ constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7);
236
+
237
+ T fInf, fNegInf, fNaN, fNeg0;
238
+
239
+ #ifdef __HIPCC__
240
+ if (is_half) {
241
+ const uint16_t ihInf = 0x7C00;
242
+ const uint16_t ihNegInf = 0xFC00;
243
+ const uint16_t ihNaN = 0x7C01;
244
+ const uint16_t ihNeg0 = 0x8000;
245
+ fInf = reinterpret_cast<const _Float16&>(ihInf);
246
+ fNegInf = reinterpret_cast<const _Float16&>(ihNegInf);
247
+ fNaN = reinterpret_cast<const _Float16&>(ihNaN);
248
+ fNeg0 = reinterpret_cast<const _Float16&>(ihNeg0);
249
+ } else
250
+ #endif
251
+ if (is_float) {
252
+ const uint32_t ifInf = 0x7F800000;
253
+ const uint32_t ifNegInf = 0xFF800000;
254
+ const uint32_t ifNaN = 0x7F800001;
255
+ const uint32_t ifNeg0 = 0x80000000;
256
+ fInf = reinterpret_cast<const float&>(ifInf);
257
+ fNegInf = reinterpret_cast<const float&>(ifNegInf);
258
+ fNaN = reinterpret_cast<const float&>(ifNaN);
259
+ fNeg0 = reinterpret_cast<const float&>(ifNeg0);
260
+ }
261
+
262
+ if (x == 0) {
263
+ return 0;
264
+ }
265
+
266
+ uint32_t sign = x >> 7;
267
+ uint32_t mantissa = x & ((1 << wm) - 1);
268
+ int exponent = (x & 0x7F) >> wm;
269
+ if (negative_zero_nan) {
270
+ if (x == 0x80) {
271
+ return fNaN;
272
+ }
273
+ } else {
274
+ if (x == 0x80) {
275
+ return fNeg0;
276
+ }
277
+ if (exponent == ((1 << we) - 1)) {
278
+ return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
279
+ }
280
+ }
281
+ typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
282
+ if (we == 5 && is_half && !negative_zero_nan) {
283
+ retval = x << 8;
284
+ return reinterpret_cast<const T&>(retval);
285
+ }
286
+
287
+ const int exp_low_cutoff =
288
+ (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);
289
+
290
+ // subnormal input
291
+ if (exponent == 0) {
292
+ // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
293
+ int sh = 1 + clz(mantissa) - (32 - wm);
294
+ mantissa <<= sh;
295
+ exponent += 1 - sh;
296
+ mantissa &= ((1 << wm) - 1);
297
+ }
298
+ exponent += exp_low_cutoff - 1;
299
+ mantissa <<= wmo - wm;
300
+
301
+ // subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
302
+ if (exponent <= 0) {
303
+ mantissa |= 1 << wmo;
304
+ mantissa >>= 1 - exponent;
305
+ exponent = 0;
306
+ }
307
+
308
+ if (sizeof(T) == 2) {
309
+ retval = (sign << 15) | (exponent << 10) | mantissa;
310
+ } else {
311
+ retval = (sign << 31) | (exponent << 23) | mantissa;
312
+ }
313
+ return reinterpret_cast<const T&>(retval);
314
+ }
315
+
316
+ } // namespace hip_fp8_impl
fp8/amd/quant_utils.cuh ADDED
@@ -0,0 +1,577 @@
1
+ #pragma once
2
+ #include "hip_float8.h"
3
+
4
+ #include <hip/hip_fp16.h>
5
+ #include <hip/hip_bf16.h>
6
+ #include <hip/hip_bfloat16.h>
7
+
8
+ #include "../../../attention/dtype_fp8.cuh"
9
+ #include "../../../attention/dtype_float32.cuh"
10
+ #include "../../../attention/dtype_bfloat16.cuh"
11
+
12
+ namespace vllm {
13
+ #ifdef USE_ROCM
14
+
15
+ namespace fp8 {
16
+ #ifdef ENABLE_FP8
17
+
18
+ template <typename Tout, typename Tin>
19
+ __inline__ __device__ Tout vec_conversion(const Tin& x) {
20
+ return x;
21
+ }
22
+
23
+ template <typename Tout, typename Tin>
24
+ __inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
25
+ const float scale) {
26
+ return x;
27
+ }
28
+
29
+ // fp8 -> half
30
+ template <>
31
+ __inline__ __device__ uint16_t
32
+ vec_conversion<uint16_t, uint8_t>(const uint8_t& a) {
33
+ hip_fp8 f8{a, hip_fp8::from_bits()};
34
+ __half_raw res;
35
+ res.data = static_cast<float>(f8);
36
+ return res.x;
37
+ }
38
+
39
+ // fp8x2 -> half2
40
+ template <>
41
+ __inline__ __device__ uint32_t
42
+ vec_conversion<uint32_t, uint16_t>(const uint16_t& a) {
43
+ #if defined(__HIP__MI300__) && \
44
+ defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
45
+ const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
46
+ union {
47
+ __half2_raw h2r;
48
+ uint32_t ui32;
49
+ } tmp;
50
+ tmp.h2r.x.data = f2[0];
51
+ tmp.h2r.y.data = f2[1];
52
+ return tmp.ui32;
53
+ #else
54
+ union {
55
+ uint16_t u16[2];
56
+ uint32_t u32;
57
+ } tmp;
58
+
59
+ tmp.u16[0] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a));
60
+ tmp.u16[1] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U));
61
+ return tmp.u32;
62
+ #endif
63
+ }
64
+
65
+ // fp8x4 -> half2x2
66
+ template <>
67
+ __inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a) {
68
+ union {
69
+ uint2 u32x2;
70
+ uint32_t u32[2];
71
+ } tmp;
72
+ tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
73
+ tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
74
+ return tmp.u32x2;
75
+ }
76
+
77
+ // fp8x8 -> half2x4
78
+ template <>
79
+ __inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a) {
80
+ union {
81
+ uint4 u64x2;
82
+ uint2 u64[2];
83
+ } tmp;
84
+ tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
85
+ tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
86
+ return tmp.u64x2;
87
+ }
88
+
89
+ using __nv_bfloat16 = __hip_bfloat16;
90
+
91
+ // fp8 -> __nv_bfloat16
92
+ template <>
93
+ __inline__ __device__ __nv_bfloat16
94
+ vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) {
95
+ hip_fp8 f8{a, hip_fp8::from_bits()};
96
+ float f{f8};
97
+ return __float2bfloat16(f);
98
+ }
99
+
100
+ using __nv_bfloat162 = __hip_bfloat162;
101
+
102
+ // fp8x2 -> __nv_bfloat162
103
+ template <>
104
+ __inline__ __device__ __nv_bfloat162
105
+ vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) {
106
+ __nv_bfloat162 res;
107
+ res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
108
+ res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
109
+ return res;
110
+ }
111
+
112
+ // fp8x4 -> bf16_4_t
113
+ template <>
114
+ __inline__ __device__ bf16_4_t
115
+ vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a) {
116
+ bf16_4_t res;
117
+ res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
118
+ res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
119
+ return res;
120
+ }
121
+
122
+ // fp8x8 -> bf16_8_t
123
+ template <>
124
+ __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
125
+ bf16_4_t tmp1, tmp2;
126
+ tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
127
+ tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
128
+ bf16_8_t res;
129
+ res.x = tmp1.x;
130
+ res.y = tmp1.y;
131
+ res.z = tmp2.x;
132
+ res.w = tmp2.y;
133
+ return res;
134
+ }
135
+
136
+ // fp8 -> float
137
+ template <>
138
+ __inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) {
139
+ hip_fp8 fp8{a, hip_fp8::from_bits()};
140
+ return static_cast<float>(fp8);
141
+ }
142
+
143
+ // fp8x2 -> float2
144
+ template <>
145
+ __inline__ __device__ float2
146
+ vec_conversion<float2, uint16_t>(const uint16_t& a) {
147
+ #if defined(__HIP__MI300__) && \
148
+ defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
149
+ float2 res;
150
+ const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
151
+ res.x = f2[0];
152
+ res.y = f2[1];
153
+ return res;
154
+ #else
155
+ float2 res;
156
+ res.x = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a));
157
+ res.y = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U));
158
+ return res;
159
+ #endif
160
+ }
161
+
162
+ // fp8x4 -> float4
163
+ template <>
164
+ __inline__ __device__ Float4_
165
+ vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
166
+ Float4_ res;
167
+ res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
168
+ res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
169
+ return res;
170
+ }
171
+
172
+ // fp8x8 -> float8
173
+ template <>
174
+ __inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) {
175
+ Float4_ tmp1, tmp2;
176
+ tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
177
+ tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
178
+ Float8_ res;
179
+ res.x = tmp1.x;
180
+ res.y = tmp1.y;
181
+ res.z = tmp2.x;
182
+ res.w = tmp2.y;
183
+ return res;
184
+ }
185
+
186
+ // half -> fp8
187
+ template <>
188
+ __inline__ __device__ uint8_t
189
+ vec_conversion<uint8_t, uint16_t>(const uint16_t& a) {
190
+ __half_raw tmp;
191
+ tmp.x = a;
192
+
193
+ hip_fp8 f8{static_cast<float>(tmp.data)};
194
+ return f8.data;
195
+ }
196
+
197
+ // bf16 -> fp8
198
+ template <>
199
+ __inline__ __device__ uint8_t
200
+ vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) {
201
+ hip_fp8 res{__bfloat162float(a)};
202
+ return res.data;
203
+ }
204
+
205
+ // float -> fp8
206
+ template <>
207
+ __inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) {
208
+ hip_fp8 f8(a);
209
+ return f8.data;
210
+ }
211
+
212
+ // fp8x4 -> float4
213
+ template <>
214
+ __inline__ __device__ float4
215
+ vec_conversion<float4, uint32_t>(const uint32_t& a) {
216
+ Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
217
+ float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
218
+ return res;
219
+ }
220
+
221
+ // float2 -> half2
222
+ template <>
223
+ __inline__ __device__ uint32_t
224
+ vec_conversion<uint32_t, float2>(const float2& a) {
225
+ union {
226
+ half2 float16;
227
+ uint32_t uint32;
228
+ };
229
+
230
+ float16 = __float22half2_rn(a);
231
+ return uint32;
232
+ }
233
+
234
+ // Float4 -> half2x2
235
+ template <>
236
+ __inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) {
237
+ uint2 b;
238
+ float2 val;
239
+ val.x = a.x.x;
240
+ val.y = a.x.y;
241
+ b.x = vec_conversion<uint32_t, float2>(val);
242
+
243
+ val.x = a.y.x;
244
+ val.y = a.y.y;
245
+ b.y = vec_conversion<uint32_t, float2>(val);
246
+ return b;
247
+ }
248
+
249
+ // Float4 -> float4
250
+ template <>
251
+ __inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a) {
252
+ float4 b;
253
+ b.x = a.x.x;
254
+ b.y = a.x.y;
255
+ b.z = a.y.x;
256
+ b.w = a.y.y;
257
+ return b;
258
+ }
259
+
260
+ // Float8 -> half2x4
261
+ template <>
262
+ __inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a) {
263
+ uint4 b;
264
+ b.x = vec_conversion<uint32_t, float2>(a.x);
265
+ b.y = vec_conversion<uint32_t, float2>(a.y);
266
+ b.z = vec_conversion<uint32_t, float2>(a.z);
267
+ b.w = vec_conversion<uint32_t, float2>(a.w);
268
+ return b;
269
+ }
270
+
271
+ // float2 -> bfloat162
272
+ template <>
273
+ __inline__ __device__ __nv_bfloat162
274
+ vec_conversion<__nv_bfloat162, float2>(const float2& a) {
275
+ __nv_bfloat162 b = __float22bfloat162_rn(a);
276
+ return b;
277
+ }
278
+
279
+ // Float4 -> bfloat162x2
280
+ template <>
281
+ __inline__ __device__ bf16_4_t
282
+ vec_conversion<bf16_4_t, Float4_>(const Float4_& a) {
283
+ bf16_4_t b;
284
+ b.x = __float22bfloat162_rn(a.x);
285
+ b.y = __float22bfloat162_rn(a.y);
286
+ return b;
287
+ }
288
+
289
+ // Float8 -> bfloat162x4
290
+ template <>
291
+ __inline__ __device__ bf16_8_t
292
+ vec_conversion<bf16_8_t, Float8_>(const Float8_& a) {
293
+ bf16_8_t b;
294
+ b.x = __float22bfloat162_rn(a.x);
295
+ b.y = __float22bfloat162_rn(a.y);
296
+ b.z = __float22bfloat162_rn(a.z);
297
+ b.w = __float22bfloat162_rn(a.w);
298
+ return b;
299
+ }
300
+
301
+ /* Scaled and vectorized conversions, for data exchange between high and low
302
+ precision domains
303
+
304
+ Convention of the scale in API, e.g: FP8_data = Quantization(
305
+ High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8 Dequant(FP8) *
306
+ scale => HP
307
+
308
+ */
309
+
310
+ // fp8 -> half
311
+ template <>
312
+ __inline__ __device__ uint16_t
313
+ scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, const float scale) {
314
+ hip_fp8 f8{a, hip_fp8::from_bits()};
315
+ __half_raw res;
316
+ res.data = static_cast<float>(f8) * scale;
317
+ return res.x;
318
+ }
319
+
320
+ // fp8x2 -> half2
321
+ template <>
322
+ __inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
323
+ const uint16_t& a, const float scale) {
324
+ #if defined(__HIP__MI300__) && \
325
+ defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
326
+ const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
327
+ union {
328
+ __half2_raw h2r;
329
+ uint32_t ui32;
330
+ } tmp;
331
+ tmp.h2r.x.data = f2[0] * scale;
332
+ tmp.h2r.y.data = f2[1] * scale;
333
+ return tmp.ui32;
334
+ #else
335
+ union {
336
+ uint16_t u16[2];
337
+ uint32_t u32;
338
+ } tmp;
339
+
340
+ tmp.u16[0] =
341
+ scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a), scale);
342
+ tmp.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>(
343
+ static_cast<uint8_t>(a >> 8U), scale);
344
+ return tmp.u32;
345
+ #endif
346
+ }
347
+
348
+ // fp8x4 -> half2x2
349
+ template <>
350
+ __inline__ __device__ uint2
351
+ scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, const float scale) {
352
+ union {
353
+ uint2 u32x2;
354
+ uint32_t u32[2];
355
+ } tmp;
356
+ tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
357
+ tmp.u32[1] =
358
+ scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
359
+ return tmp.u32x2;
360
+ }
361
+
362
+ // fp8x8 -> half2x4
363
+ template <>
364
+ __inline__ __device__ uint4
365
+ scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale) {
366
+ union {
367
+ uint4 u64x2;
368
+ uint2 u64[2];
369
+ } tmp;
370
+ tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
371
+ tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
372
+ return tmp.u64x2;
373
+ }
374
+
375
+ using __nv_bfloat16 = __hip_bfloat16;
376
+
377
+ // fp8 -> __nv_bfloat16
378
+ template <>
379
+ __inline__ __device__ __nv_bfloat16
380
+ scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a,
381
+ const float scale) {
382
+ hip_fp8 f8{a, hip_fp8::from_bits()};
383
+ float f{f8};
384
+ return __float2bfloat16(f * scale);
385
+ }
386
+
387
+ using __nv_bfloat162 = __hip_bfloat162;
388
+
389
+ // fp8x2 -> __nv_bfloat162
390
+ template <>
391
+ __inline__ __device__ __nv_bfloat162
392
+ scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
393
+ const float scale) {
394
+ __nv_bfloat162 res;
395
+ res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale);
396
+ res.y =
397
+ scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale);
398
+ return res;
399
+ }
400
+
401
+ // fp8x4 -> bf16_4_t
402
+ template <>
403
+ __inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
404
+ const uint32_t& a, const float scale) {
405
+ bf16_4_t res;
406
+ res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale);
407
+ res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
408
+ scale);
409
+ return res;
410
+ }
411
+
412
+ // fp8x8 -> bf16_8_t
413
+ template <>
414
+ __inline__ __device__ bf16_8_t
415
+ scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale) {
416
+ bf16_4_t tmp1, tmp2;
417
+ tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
418
+ tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
419
+ bf16_8_t res;
420
+ res.x = tmp1.x;
421
+ res.y = tmp1.y;
422
+ res.z = tmp2.x;
423
+ res.w = tmp2.y;
424
+ return res;
425
+ }
426
+
427
+ // fp8 -> float
428
+ template <>
429
+ __inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
430
+ const uint8_t& a, const float scale) {
431
+ hip_fp8 fp8{a, hip_fp8::from_bits()};
432
+ return static_cast<float>(fp8) * scale;
433
+ }
434
+
435
+ // fp8x2 -> float2
436
+ template <>
437
+ __inline__ __device__ float2
438
+ scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, const float scale) {
439
+ #if defined(__HIP__MI300__) && \
440
+ defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
441
+ float2 res;
442
+ const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
443
+ res.x = f2[0] * scale;
444
+ res.y = f2[1] * scale;
445
+ return res;
446
+ #else
447
+ float2 res;
448
+ res.x = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a), scale);
449
+ res.y = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U),
450
+ scale);
451
+ return res;
452
+ #endif
453
+ }
454
+
455
+ // fp8x4 -> float4
456
+ template <>
457
+ __inline__ __device__ Float4_
458
+ scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) {
459
+ Float4_ res;
460
+ res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale);
461
+ res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale);
462
+ return res;
463
+ }
464
+
465
+ // fp8x8 -> float8
466
+ template <>
467
+ __inline__ __device__ Float8_
468
+ scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale) {
469
+ Float4_ tmp1, tmp2;
470
+ tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
471
+ tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
472
+ Float8_ res;
473
+ res.x = tmp1.x;
474
+ res.y = tmp1.y;
475
+ res.z = tmp2.x;
476
+ res.w = tmp2.y;
477
+ return res;
478
+ }
479
+
480
+ /* Quantize(HP / scale) => FP8 */
481
+
482
+ // TODO(Hai): add vectorized versions of these conversions
483
+
484
+ // half -> fp8
485
+ template <>
486
+ __inline__ __device__ uint8_t
487
+ scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, const float scale) {
488
+ __half_raw tmp;
489
+ tmp.x = a;
490
+
491
+ hip_fp8 f8{static_cast<float>(tmp.data) / scale};
492
+ return f8.data;
493
+ }
494
+
495
+ // bf16 -> fp8
496
+ template <>
497
+ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
498
+ const __nv_bfloat16& a, const float scale) {
499
+ hip_fp8 res{__bfloat162float(a) / scale};
500
+ return res.data;
501
+ }
502
+
503
+ // float -> fp8
504
+ template <>
505
+ __inline__ __device__ uint8_t
506
+ scaled_vec_conversion<uint8_t, float>(const float& a, const float scale) {
507
+ hip_fp8 f8(a / scale);
508
+ return f8.data;
509
+ }
510
+
511
+ // fp8x4 -> float4
512
+ template <>
513
+ __inline__ __device__ float4
514
+ scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, const float scale) {
515
+ Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
516
+ float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
517
+ return res;
518
+ }
519
+ #endif // ENABLE_FP8
520
+
521
+ template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
522
+ __inline__ __device__ Tout convert(const Tin& x) {
523
+ #ifdef ENABLE_FP8
524
+ if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
525
+ return vec_conversion<Tout, Tin>(x);
526
+ }
527
+ #endif
528
+ assert(false);
529
+ return {}; // Squash missing return statement warning
530
+ }
531
+
532
+ template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
533
+ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
534
+ #ifdef ENABLE_FP8
535
+ if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
536
+ return scaled_vec_conversion<Tout, Tin>(x, scale);
537
+ }
538
+ #endif
539
+ assert(false);
540
+ return {}; // Squash missing return statement warning
541
+ }
542
+
543
+ // The following macro is used to dispatch the conversion function based on
544
+ // the data type of the key and value cache. The FN is a macro that calls a
545
+ // function with template<typename scalar_t, typename cache_t,
546
+ // Fp8KVCacheDataType kv_dt>.
547
+ #define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
548
+ if (KV_DTYPE == "auto") { \
549
+ if (SRC_DTYPE == at::ScalarType::Float) { \
550
+ FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
551
+ } else if (SRC_DTYPE == at::ScalarType::Half) { \
552
+ FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \
553
+ } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
554
+ FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \
555
+ } else { \
556
+ TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
557
+ } \
558
+ } else { \
559
+ if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
560
+ if (SRC_DTYPE == at::ScalarType::Float) { \
561
+ FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
562
+ } else if (SRC_DTYPE == at::ScalarType::Half) { \
563
+ FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
564
+ } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
565
+ FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
566
+ } else { \
567
+ TORCH_CHECK(false, \
568
+ "Unsupported input type of kv cache: ", SRC_DTYPE); \
569
+ } \
570
+ } else { \
571
+ TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
572
+ } \
573
+ }
574
+
575
+ } // namespace fp8
576
+ #endif // USE_ROCM
577
+ } // namespace vllm
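For reference, a standalone sketch (not taken from these sources) of the string handling that DISPATCH_BY_KV_CACHE_DTYPE encodes: "auto" keeps the source dtype for the cache, "fp8"/"fp8_e4m3" switches the cache side to 8-bit e4m3 storage, and anything else is rejected. The macro itself additionally selects the (scalar_t, cache_t) template pair from the source dtype, which this sketch omits.

#include <stdexcept>
#include <string>

enum class KvCacheKind { Auto, Fp8E4M3 };

// Mirrors the branch structure of DISPATCH_BY_KV_CACHE_DTYPE for the cache
// dtype string only.
KvCacheKind select_kv_cache_kind(const std::string& kv_dtype) {
  if (kv_dtype == "auto") return KvCacheKind::Auto;
  if (kv_dtype == "fp8" || kv_dtype == "fp8_e4m3") return KvCacheKind::Fp8E4M3;
  throw std::invalid_argument("Unsupported data type of kv cache: " + kv_dtype);
}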
fp8/common.cu ADDED
@@ -0,0 +1,149 @@
1
+ #include "common.cuh"
2
+ #include "dispatch_utils.h"
3
+
4
+ #include <c10/cuda/CUDAGuard.h>
5
+
6
+ #ifndef USE_ROCM
7
+ #include <cub/cub.cuh>
8
+ #else
9
+ #include <hipcub/hipcub.hpp>
10
+ #endif
11
+
12
+ namespace vllm {
13
+
14
+ template <typename scalar_t>
15
+ __global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
16
+ const scalar_t* __restrict__ input,
17
+ const float* __restrict__ scale,
18
+ int64_t num_elems) {
19
+ int tid = blockDim.x * blockIdx.x + threadIdx.x;
20
+
21
+ // Invert the scale so that we can use multiplications to avoid expensive
22
+ // division.
23
+ const float inverted_scale = 1.0f / (*scale);
24
+ scaled_fp8_conversion_vec<scalar_t, true>(
25
+ out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x);
26
+ }
27
+
28
+ template <typename scalar_t>
29
+ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
30
+ FP8_TYPE* __restrict__ out, float* __restrict__ scale,
31
+ scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
32
+ const int hidden_size) {
33
+ float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
34
+
35
+ int const tid = threadIdx.x;
36
+ int const token_idx = blockIdx.x;
37
+
38
+ // Use int64 to avoid overflowing an int32 when calculating this offset
39
+ int64_t offset = static_cast<int64_t>(token_idx) * hidden_size;
40
+ scalar_t const* __restrict__ token_input = &input[offset];
41
+ FP8_TYPE* __restrict__ token_output = &out[offset];
42
+
43
+ // For vectorization, token_input and token_output pointers need to be
44
+ // aligned at 8-byte and 4-byte addresses respectively.
45
+ bool const can_vectorize = hidden_size % 4 == 0;
46
+
47
+ float absmax_val = 0.0f;
48
+ if (can_vectorize) {
49
+ absmax_val = thread_max_vec(token_input, hidden_size, tid, blockDim.x);
50
+ } else {
51
+ for (int i = tid; i < hidden_size; i += blockDim.x) {
52
+ float const x = static_cast<float>(token_input[i]);
53
+ absmax_val = max(absmax_val, fabs(x));
54
+ }
55
+ }
56
+
57
+ using BlockReduce = cub::BlockReduce<float, 1024>;
58
+ __shared__ typename BlockReduce::TempStorage reduceStorage;
59
+ float const block_absmax_val_maybe =
60
+ BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
61
+ __shared__ float token_scale;
62
+ if (tid == 0) {
63
+ if (scale_ub) {
64
+ token_scale = min(block_absmax_val_maybe, *scale_ub);
65
+ } else {
66
+ token_scale = block_absmax_val_maybe;
67
+ }
68
+ // token scale computation
69
+ token_scale = max(token_scale / FP8_E4M3_MAX, min_scaling_factor);
70
+ scale[token_idx] = token_scale;
71
+ }
72
+ __syncthreads();
73
+
74
+ // Note that we don't use inverted scales so we can match FBGemm impl.
75
+ if (can_vectorize) {
76
+ scaled_fp8_conversion_vec<scalar_t, false>(
77
+ token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
78
+ } else {
79
+ for (int i = tid; i < hidden_size; i += blockDim.x) {
80
+ token_output[i] = scaled_fp8_conversion<false>(
81
+ static_cast<float>(token_input[i]), token_scale);
82
+ }
83
+ }
84
+ }
85
+
86
+ } // namespace vllm
87
+
88
+ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
89
+ torch::Tensor const& input, // [..., d]
90
+ torch::Tensor const& scale) // [1]
91
+ {
92
+ int64_t num_tokens = input.numel() / input.size(-1);
93
+ int64_t num_elems = input.numel();
94
+ dim3 grid(num_tokens);
95
+ dim3 block(1024);
96
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
97
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
98
+ VLLM_DISPATCH_FLOATING_TYPES(
99
+ input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
100
+ vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
101
+ out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
102
+ scale.data_ptr<float>(), num_elems);
103
+ });
104
+ }
105
+
106
+ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
107
+ torch::Tensor const& input, // [..., d]
108
+ torch::Tensor& scale) // [1]
109
+ {
110
+ int64_t num_tokens = input.numel() / input.size(-1);
111
+ int64_t num_elems = input.numel();
112
+ dim3 grid(num_tokens);
113
+ dim3 block(1024);
114
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
115
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
116
+ VLLM_DISPATCH_FLOATING_TYPES(
117
+ input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
118
+ vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
119
+ scale.data_ptr<float>(), input.data_ptr<scalar_t>(), num_elems);
120
+ vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
121
+ out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
122
+ scale.data_ptr<float>(), num_elems);
123
+ });
124
+ }
125
+
126
+ void dynamic_per_token_scaled_fp8_quant(
127
+ torch::Tensor& out, // [..., d]
128
+ torch::Tensor const& input, // [..., d]
129
+ torch::Tensor& scales, std::optional<at::Tensor> const& scale_ub) {
130
+ TORCH_CHECK(input.is_contiguous());
131
+ TORCH_CHECK(out.is_contiguous());
132
+
133
+ int const hidden_size = input.size(-1);
134
+ int const num_tokens = input.numel() / hidden_size;
135
+ dim3 const grid(num_tokens);
136
+ dim3 const block(std::min(hidden_size, 1024));
137
+
138
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
139
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
140
+ VLLM_DISPATCH_FLOATING_TYPES(
141
+ input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
142
+ vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
143
+ <<<grid, block, 0, stream>>>(
144
+ out.data_ptr<FP8_TYPE>(), scales.data_ptr<float>(),
145
+ input.data_ptr<scalar_t>(),
146
+ scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
147
+ hidden_size);
148
+ });
149
+ }
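A hedged host-side usage sketch of the three entry points defined above. It assumes the declarations from the torch binding header are in scope, a PyTorch build that exposes at::ScalarType::Float8_e4m3fn, and a CUDA device; shapes follow the comments on the signatures.

#include <torch/torch.h>
#include <optional>

void fp8_quant_example() {
  auto opts = torch::dtype(torch::kFloat16).device(torch::kCUDA);
  torch::Tensor input = torch::randn({16, 4096}, opts);
  torch::Tensor out =
      torch::empty_like(input, opts.dtype(at::ScalarType::Float8_e4m3fn));

  // Static per-tensor quantization with a precomputed scale of shape [1].
  torch::Tensor scale = torch::full({1}, 0.05, opts.dtype(torch::kFloat32));
  static_scaled_fp8_quant(out, input, scale);

  // Dynamic per-tensor quantization: the scale buffer must start at a value
  // <= 0.0 because segmented_max_reduction updates it with atomic max ops.
  torch::Tensor dyn_scale = torch::zeros({1}, opts.dtype(torch::kFloat32));
  dynamic_scaled_fp8_quant(out, input, dyn_scale);

  // Dynamic per-token quantization: one scale per token, optional upper bound.
  torch::Tensor token_scales =
      torch::empty({input.size(0), 1}, opts.dtype(torch::kFloat32));
  dynamic_per_token_scaled_fp8_quant(out, input, token_scales, std::nullopt);
}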
fp8/common.cuh ADDED
@@ -0,0 +1,172 @@
1
+ #pragma once
2
+
3
+ #include <cmath>
4
+
5
+ #ifndef USE_ROCM
6
+ #include <c10/util/Float8_e4m3fn.h>
7
+ using FP8_TYPE = c10::Float8_e4m3fn;
8
+ C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
9
+ std::numeric_limits<FP8_TYPE>::max();
10
+ #else
11
+ #include <c10/util/Float8_e4m3fnuz.h>
12
+ #include "amd/hip_float8.h"
13
+ using FP8_TYPE = c10::Float8_e4m3fnuz;
14
+ // Using the default max value from pytorch (240.0) will cause accuracy
15
+ // issues when running dynamic quantization. Here we use 224.0f for ROCm.
16
+ constexpr auto FP8_E4M3_MAX = 224.0f;
17
+ #endif
18
+
19
+ namespace vllm {
20
+
21
+ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
22
+ float old;
23
+ old = (value >= 0)
24
+ ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
25
+ : __uint_as_float(
26
+ atomicMin((unsigned int*)addr, __float_as_uint(value)));
27
+
28
+ return old;
29
+ }
30
+
31
+ template <bool is_scale_inverted>
32
+ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
33
+ float const scale) {
34
+ float x = 0.0f;
35
+ if constexpr (is_scale_inverted) {
36
+ x = val * scale;
37
+ } else {
38
+ x = val / scale;
39
+ }
40
+
41
+ float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
42
+ #ifndef USE_ROCM
43
+ return static_cast<c10::Float8_e4m3fn>(r);
44
+ #else
45
+ // Use hardware cvt instruction for fp8 on rocm
46
+ return c10::Float8_e4m3fnuz(hip_fp8(r).data,
47
+ c10::Float8_e4m3fnuz::from_bits());
48
+ #endif
49
+ }
50
+
51
+ // Compute the absolute maximum m of the input tensor and store
52
+ // m / float8_e4m3::max() in *scale. Each thread block performs a
53
+ // reduction tree and the memory in scale is atomically updated.
54
+ // So to get the right answer, *scale needs to be initialized to
55
+ // a value <= 0.0 and we need to wait for all thread blocks to
56
+ // finish before consuming *scale.
57
+ template <typename scalar_t>
58
+ __global__ void segmented_max_reduction(float* __restrict__ scale,
59
+ const scalar_t* __restrict__ input,
60
+ int64_t num_elems) {
61
+ __shared__ float cache[1024];
62
+ int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
63
+
64
+ // First store the maximum over all values processed by
65
+ // the current thread in cache[threadIdx.x]
66
+ scalar_t tmp = 0.0;
67
+ while (i < num_elems) {
68
+ float x = static_cast<float>(input[i]);
69
+ tmp = max(tmp, fabs(x));
70
+ i += blockDim.x * gridDim.x;
71
+ }
72
+ cache[threadIdx.x] = tmp;
73
+
74
+ __syncthreads();
75
+
76
+ // Now perform parallel reduction within the thread block
77
+ int ib = blockDim.x / 2;
78
+ while (ib != 0) {
79
+ if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
80
+ cache[threadIdx.x] = cache[threadIdx.x + ib];
81
+ }
82
+ __syncthreads();
83
+ ib /= 2;
84
+ }
85
+ // Finally, since cache[0] contains the maximum for this thread block,
86
+ // atomically write the max to the target location
87
+ if (threadIdx.x == 0) {
88
+ atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX);
89
+ }
90
+ }
91
+
92
+ template <typename scalar_t>
93
+ struct __align__(8) vec4_t {
94
+ scalar_t x;
95
+ scalar_t y;
96
+ scalar_t z;
97
+ scalar_t w;
98
+ };
99
+
100
+ typedef struct __align__(4) {
101
+ FP8_TYPE x;
102
+ FP8_TYPE y;
103
+ FP8_TYPE z;
104
+ FP8_TYPE w;
105
+ }
106
+ float8x4_t;
107
+
108
+ template <typename scalar_t>
109
+ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
110
+ int64_t const num_elems, int const tid,
111
+ int const step) {
112
+ // Vectorized input/output to better utilize memory bandwidth.
113
+ vec4_t<scalar_t> const* vectorized_in =
114
+ reinterpret_cast<vec4_t<scalar_t> const*>(input);
115
+
116
+ int64_t const num_vec_elems = num_elems >> 2;
117
+ float absmax_val = 0.0f;
118
+
119
+ #pragma unroll 4
120
+ for (int64_t i = tid; i < num_vec_elems; i += step) {
121
+ vec4_t<scalar_t> in_vec = vectorized_in[i];
122
+ absmax_val = max(absmax_val, fabs(in_vec.x));
123
+ absmax_val = max(absmax_val, fabs(in_vec.y));
124
+ absmax_val = max(absmax_val, fabs(in_vec.z));
125
+ absmax_val = max(absmax_val, fabs(in_vec.w));
126
+ }
127
+
128
+ // Handle the remaining elements if num_elems is not divisible by 4
129
+ for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
130
+ absmax_val = max(absmax_val, fabs(input[i]));
131
+ }
132
+
133
+ return absmax_val;
134
+ }
135
+
136
+ template <typename scalar_t, bool is_scale_inverted>
137
+ __device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
138
+ scalar_t const* __restrict__ input,
139
+ float const scale,
140
+ int64_t const num_elems,
141
+ int const tid, int const step) {
142
+ // Vectorized input/output to better utilize memory bandwidth.
143
+ vec4_t<scalar_t> const* vectorized_in =
144
+ reinterpret_cast<vec4_t<scalar_t> const*>(input);
145
+ float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);
146
+
147
+ int64_t const num_vec_elems = num_elems >> 2;
148
+
149
+ #pragma unroll 4
150
+ for (int64_t i = tid; i < num_vec_elems; i += step) {
151
+ vec4_t<scalar_t> in_vec = vectorized_in[i];
152
+ float8x4_t out_vec;
153
+
154
+ out_vec.x = scaled_fp8_conversion<is_scale_inverted>(
155
+ static_cast<float>(in_vec.x), scale);
156
+ out_vec.y = scaled_fp8_conversion<is_scale_inverted>(
157
+ static_cast<float>(in_vec.y), scale);
158
+ out_vec.z = scaled_fp8_conversion<is_scale_inverted>(
159
+ static_cast<float>(in_vec.z), scale);
160
+ out_vec.w = scaled_fp8_conversion<is_scale_inverted>(
161
+ static_cast<float>(in_vec.w), scale);
162
+ vectorized_out[i] = out_vec;
163
+ }
164
+
165
+ // Handle the remaining elements if num_elems is not divisible by 4
166
+ for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
167
+ out[i] = scaled_fp8_conversion<is_scale_inverted>(
168
+ static_cast<float>(input[i]), scale);
169
+ }
170
+ }
171
+
172
+ } // namespace vllm
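In scalar form, the dynamic per-tensor path above boils down to two steps: compute absmax / FP8_E4M3_MAX as the scale, then saturate val / scale before the narrowing cast. A CPU-only reference sketch (assuming the CUDA e4m3fn maximum of 448.0f rather than the ROCm value defined above):

#include <algorithm>
#include <cmath>
#include <vector>

constexpr float FP8_MAX = 448.0f;  // e4m3fn max on CUDA; ROCm uses 224.0f

// Scale such that the largest magnitude maps onto the fp8 range.
float reference_dynamic_scale(const std::vector<float>& x) {
  float absmax = 0.0f;
  for (float v : x) absmax = std::max(absmax, std::fabs(v));
  return absmax / FP8_MAX;
}

// Mirrors scaled_fp8_conversion<false>: divide by the (non-inverted) scale and
// clamp; the real kernel additionally casts the result to the 8-bit type.
float reference_quantize(float val, float scale) {
  float x = val / scale;
  return std::max(-FP8_MAX, std::min(x, FP8_MAX));
}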
fp8/fp8_marlin.cu ADDED
@@ -0,0 +1,1306 @@
1
+ /*
2
+ * Modified by Neural Magic
3
+ * Copyright (C) Marlin.2024 Elias Frantar
4
+ *
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
16
+ */
17
+
18
+ /*
19
+ * Adapted from https://github.com/IST-DASLab/marlin
20
+ */
21
+
22
+ #include "../gptq_marlin/marlin.cuh"
23
+ #include "../gptq_marlin/marlin_dtypes.cuh"
24
+
25
+ using namespace marlin;
26
+
27
+ #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
28
+ static_assert(std::is_same<scalar_t, half>::value || \
29
+ std::is_same<scalar_t, nv_bfloat16>::value, \
30
+ "only float16 and bfloat16 is supported");
31
+
32
+ template <typename T>
33
+ inline std::string str(T x) {
34
+ return std::to_string(x);
35
+ }
36
+
37
+ namespace fp8_marlin {
38
+
39
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
40
+
41
+ template <typename scalar_t, // compute dtype, half or nv_bfloat16
42
+ const int num_bits, // number of bits used for weights
43
+ const int threads, // number of threads in a threadblock
44
+ const int thread_m_blocks, // number of 16x16 blocks in the m
45
+ // dimension (batchsize) of the
46
+ // threadblock
47
+ const int thread_n_blocks, // same for n dimension (output)
48
+ const int thread_k_blocks, // same for k dimension (reduction)
49
+ const int stages, // number of stages for the async global->shared
50
+ // fetch pipeline
51
+ const int group_blocks = -1 // number of consecutive 16x16 blocks
52
+ // with a separate quantization scale
53
+ >
54
+ __global__ void Marlin(
55
+ const int4* __restrict__ A, // fp16 input matrix of shape mxk
56
+ const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
57
+ int4* __restrict__ C, // fp16 output buffer of shape mxn
58
+ const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
59
+ // (k/groupsize)xn
60
+ int num_groups, // number of scale groups per output channel
61
+ int prob_m, // batch dimension m
62
+ int prob_n, // output dimension n
63
+ int prob_k, // reduction dimension k
64
+ int* locks // extra global storage for barrier synchronization
65
+ ) {}
66
+
67
+ } // namespace fp8_marlin
68
+
69
+ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
70
+ torch::Tensor& b_scales, torch::Tensor& workspace,
71
+ int64_t num_bits, int64_t size_m, int64_t size_n,
72
+ int64_t size_k) {
73
+ TORCH_CHECK_NOT_IMPLEMENTED(false,
74
+ "marlin_gemm(..) requires CUDA_ARCH >= 8.0");
75
+ return torch::empty({1, 1});
76
+ }
77
+
78
+ #else
79
+
80
+ // m16n8k16 tensor core mma instruction with fp16 inputs and fp32
81
+ // output/accumulation.
82
+ template <typename scalar_t>
83
+ __device__ inline void mma(const typename ScalarType<scalar_t>::FragA& a_frag,
84
+ const typename ScalarType<scalar_t>::FragB& frag_b,
85
+ typename ScalarType<scalar_t>::FragC& frag_c) {
86
+ const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
87
+ const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
88
+ float* c = reinterpret_cast<float*>(&frag_c);
89
+ if constexpr (std::is_same<scalar_t, half>::value) {
90
+ asm volatile(
91
+ "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
92
+ "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
93
+ : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
94
+ : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
95
+ "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
96
+ } else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
97
+ asm volatile(
98
+ "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
99
+ "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
100
+ : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
101
+ : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
102
+ "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
103
+ } else {
104
+ STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
105
+ }
106
+ }
107
+
108
+ // Instruction for loading a full 16x16 matrix fragment of operand A from shared
109
+ // memory, directly in tensor core layout.
110
+ template <typename scalar_t>
111
+ __device__ inline void ldsm4(typename ScalarType<scalar_t>::FragA& frag_a,
112
+ const void* smem_ptr) {
113
+ uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
114
+ uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
115
+ asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
116
+ : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
117
+ : "r"(smem));
118
+ }
119
+
120
+ // Fast FP8ToFp16/FP8ToBf16: Efficiently dequantize 8bit fp8_e4m3 values to fp16
121
+ // or bf16. Reference:
122
+ // - FP16:
123
+ // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
124
+ // - BF16:
125
+ // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
126
+ template <typename scalar_t>
127
+ __device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit(int q) {
128
+ STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
129
+ }
130
+
131
+ template <>
132
+ __device__ inline typename ScalarType<half>::FragB dequant_8bit<half>(int q) {
133
+ // Constants for FP8 (E4M3) and FP16 formats
134
+ constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, FP16_EXPONENT = 5;
135
+ constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT;
136
+
137
+ // Calculate MASK for extracting mantissa and exponent
138
+ constexpr int MASK1 = 0x80000000;
139
+ constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA);
140
+ constexpr int MASK3 = MASK2 & 0x7fffffff;
141
+ constexpr int MASK = MASK3 | (MASK3 >> 16);
142
+ // Final MASK value: 0x7F007F00
143
+
144
+ // Extract and shift FP8 values to FP16 format
145
+ int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
146
+ int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT);
147
+
148
+ // Construct and apply exponent bias
149
+ constexpr int BIAS_OFFSET =
150
+ (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
151
+ const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));
152
+
153
+ // Convert to half2 and apply bias
154
+ typename ScalarType<half>::FragB frag_b;
155
+ // Note: reverse indexing is intentional because weights are permuted
156
+ frag_b[1] = __hmul2(*reinterpret_cast<const half2*>(&Out1), bias_reg);
157
+ frag_b[0] = __hmul2(*reinterpret_cast<const half2*>(&Out2), bias_reg);
158
+ return frag_b;
159
+ }
160
+
161
+ template <>
162
+ __device__ inline typename ScalarType<nv_bfloat16>::FragB
163
+ dequant_8bit<nv_bfloat16>(int q) {
164
+ // Constants for FP8 (E4M3) and BF16 formats
165
+ constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, BF16_EXPONENT = 8;
166
+ constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
167
+
168
+ // Calculate MASK for extracting mantissa and exponent
169
+ constexpr int MASK1 = 0x80000000;
170
+ constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA);
171
+ constexpr int MASK3 = MASK2 & 0x7fffffff;
172
+ constexpr int MASK = MASK3 | (MASK3 >> 16);
173
+ // Final MASK value: 0x7F007F00
174
+
175
+ // Extract and shift FP8 values to BF16 format
176
+ int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
177
+ int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT);
178
+
179
+ // Construct and apply exponent bias
180
+ constexpr int BIAS_OFFSET =
181
+ (1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
182
+ // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent
183
+ // position
184
+ constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;
185
+ const nv_bfloat162 bias_reg =
186
+ __float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));
187
+
188
+ // Convert to bfloat162 and apply bias
189
+ typename ScalarType<nv_bfloat16>::FragB frag_b;
190
+ // Note: reverse indexing is intentional because weights are permuted
191
+ frag_b[1] = __hmul2(*reinterpret_cast<const nv_bfloat162*>(&Out1), bias_reg);
192
+ frag_b[0] = __hmul2(*reinterpret_cast<const nv_bfloat162*>(&Out2), bias_reg);
193
+ return frag_b;
194
+ }
195
+
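// A small host-only sanity check (illustrative, not part of the kernel sources)
// that reproduces the constant derivation used by dequant_8bit<half> above and
// confirms the 0x7F007F00 mask and the 2^8 exponent re-biasing factor quoted in
// the comments. It mirrors the device code's integer arithmetic verbatim.
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
  constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, FP16_EXPONENT = 5;
  constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT;  // == 1

  constexpr int MASK1 = 0x80000000;
  constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA);
  constexpr int MASK3 = MASK2 & 0x7fffffff;
  constexpr int MASK = MASK3 | (MASK3 >> 16);

  // Moving from a 4-bit exponent (bias 7) to a 5-bit exponent (bias 15) means
  // every dequantized value must be multiplied by 2^(15 - 7) = 256.
  constexpr int BIAS_OFFSET =
      (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));

  assert(RIGHT_SHIFT == 1);
  assert(static_cast<uint32_t>(MASK) == 0x7F007F00u);
  assert(BIAS_OFFSET == 8 && (1 << BIAS_OFFSET) == 256);
  std::printf("MASK=0x%08X, bias multiplier=%d\n",
              static_cast<unsigned>(MASK), 1 << BIAS_OFFSET);
  return 0;
}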
196
+ // Multiply dequantized values by the corresponding quantization scale; used
197
+ // only for grouped quantization.
198
+ template <typename scalar_t>
199
+ __device__ inline void scale(typename ScalarType<scalar_t>::FragB& frag_b,
200
+ typename ScalarType<scalar_t>::FragS& frag_s,
201
+ int i) {
202
+ using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
203
+ scalar_t2 s =
204
+ ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_s)[i]);
205
+ frag_b[0] = __hmul2(frag_b[0], s);
206
+ frag_b[1] = __hmul2(frag_b[1], s);
207
+ }
208
+
209
+ // Given 2 floats multiply by 2 scales (halves)
210
+ template <typename scalar_t>
211
+ __device__ inline void scale_float(float* c,
212
+ typename ScalarType<scalar_t>::FragS& s) {
213
+ scalar_t* s_ptr = reinterpret_cast<scalar_t*>(&s);
214
+ c[0] = __fmul_rn(c[0], ScalarType<scalar_t>::num2float(s_ptr[0]));
215
+ c[1] = __fmul_rn(c[1], ScalarType<scalar_t>::num2float(s_ptr[1]));
216
+ }
217
+
218
+ // Wait until barrier reaches `count`, then lock for current threadblock.
219
+ __device__ inline void barrier_acquire(int* lock, int count) {
220
+ if (threadIdx.x == 0) {
221
+ int state = -1;
222
+ do
223
+ // Guarantee that subsequent writes by this threadblock will be visible
224
+ // globally.
225
+ asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
226
+ : "=r"(state)
227
+ : "l"(lock));
228
+ while (state != count);
229
+ }
230
+ __syncthreads();
231
+ }
232
+
233
+ // Release barrier and increment visitation count.
234
+ __device__ inline void barrier_release(int* lock, bool reset = false) {
235
+ __syncthreads();
236
+ if (threadIdx.x == 0) {
237
+ if (reset) {
238
+ lock[0] = 0;
239
+ return;
240
+ }
241
+ int val = 1;
242
+ // Make sure that all writes since acquiring this barrier are visible
243
+ // globally, while releasing the barrier.
244
+ asm volatile("fence.acq_rel.gpu;\n");
245
+ asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
246
+ :
247
+ : "l"(lock), "r"(val));
248
+ }
249
+ }
250
+
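// A host-side analogue (illustrative only) of the two barrier helpers above,
// expressed with std::atomic: the releasing threadblock publishes its results
// and then increments the lock with release semantics, while the acquiring one
// spins until the lock reaches the expected count, making the producer's writes
// visible. The PTX above expresses the same ordering with ld.global.acquire and
// fence.acq_rel followed by red.relaxed.
#include <atomic>

inline void cpu_barrier_acquire(std::atomic<int>* lock, int count) {
  while (lock->load(std::memory_order_acquire) != count) {
    // spin until the expected number of releases has been observed
  }
}

inline void cpu_barrier_release(std::atomic<int>* lock, bool reset = false) {
  if (reset) {
    lock->store(0, std::memory_order_relaxed);
    return;
  }
  lock->fetch_add(1, std::memory_order_release);
}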
251
+ template <typename scalar_t, // compute dtype, half or nv_bfloat16
252
+ const int num_bits, // number of bits used for weights
253
+ const int threads, // number of threads in a threadblock
254
+ const int thread_m_blocks, // number of 16x16 blocks in the m
255
+ // dimension (batchsize) of the
256
+ // threadblock
257
+ const int thread_n_blocks, // same for n dimension (output)
258
+ const int thread_k_blocks, // same for k dimension (reduction)
259
+ const int stages, // number of stages for the async global->shared
260
+ // fetch pipeline
261
+ const int group_blocks = -1 // number of consecutive 16x16 blocks
262
+ // with a separate quantization scale
263
+ >
264
+ __global__ void Marlin(
265
+ const int4* __restrict__ A, // fp16 input matrix of shape mxk
266
+ const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
267
+ int4* __restrict__ C, // fp16 output buffer of shape mxn
268
+ const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
269
+ // (k/groupsize)xn
270
+ int num_groups, // number of scale groups per output channel
271
+ int prob_m, // batch dimension m
272
+ int prob_n, // output dimension n
273
+ int prob_k, // reduction dimension k
274
+ int* locks // extra global storage for barrier synchronization
275
+ ) {
276
+ // Each threadblock processes one "stripe" of the B matrix with (roughly) the
277
+ // same size, which might involve multiple column "slices" (of width 16 *
278
+ // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
279
+ // example:
280
+ // 0 1 3
281
+ // 0 2 3
282
+ // 1 2 4
283
+ // While this kind of partitioning makes things somewhat more complicated, it
284
+ // ensures good utilization of all SMs for many kinds of shape and GPU
285
+ // configurations, while requiring as few slow global cross-threadblock
286
+ // reductions as possible.
287
+ using Dtype = ScalarType<scalar_t>;
288
+ using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
289
+ using FragA = typename ScalarType<scalar_t>::FragA;
290
+ using FragB = typename ScalarType<scalar_t>::FragB;
291
+ using FragC = typename ScalarType<scalar_t>::FragC;
292
+ using FragS = typename ScalarType<scalar_t>::FragS;
293
+
294
+ constexpr int pack_factor = 32 / num_bits;
295
+
296
+ // For larger GEMMs we run multiple batchsize 64 versions in parallel for a
297
+ // better partitioning with less reductions
298
+ int parallel = 1;
299
+ if (prob_m > 16 * thread_m_blocks) {
300
+ parallel = prob_m / (16 * thread_m_blocks);
301
+ prob_m = 16 * thread_m_blocks;
302
+ }
303
+
304
+ int k_tiles = prob_k / 16 / thread_k_blocks;
305
+ int n_tiles = prob_n / 16 / thread_n_blocks;
306
+ int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x);
307
+
308
+ int slice_row = (iters * blockIdx.x) % k_tiles;
309
+ int slice_col_par = (iters * blockIdx.x) / k_tiles;
310
+ int slice_col = slice_col_par;
311
+ int slice_iters; // number of threadblock tiles in the current slice
312
+ int slice_count =
313
+ 0; // total number of active threadblocks in the current slice
314
+ int slice_idx; // index of threadblock in current slice; numbered bottom to
315
+ // top
316
+
317
+ // We can easily implement parallel problem execution by just remapping
318
+ // indices and advancing global pointers
319
+ if (slice_col_par >= n_tiles) {
320
+ A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;
321
+ C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
322
+ locks += (slice_col_par / n_tiles) * n_tiles;
323
+ slice_col = slice_col_par % n_tiles;
324
+ }
325
+
326
+ // Compute all information about the current slice which is required for
327
+ // synchronization.
328
+ auto init_slice = [&]() {
329
+ slice_iters =
330
+ iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
331
+ if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
332
+ if (slice_iters == 0) return;
333
+ if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
334
+ slice_count = 1;
335
+ slice_idx = 0;
336
+ int col_first = iters * div_ceil(k_tiles * slice_col_par, iters);
337
+ if (col_first <= k_tiles * (slice_col_par + 1)) {
338
+ int col_off = col_first - k_tiles * slice_col_par;
339
+ slice_count = div_ceil(k_tiles - col_off, iters);
340
+ if (col_off > 0) slice_count++;
341
+ int delta_first = iters * blockIdx.x - col_first;
342
+ if (delta_first < 0 || (col_off == 0 && delta_first == 0))
343
+ slice_idx = slice_count - 1;
344
+ else {
345
+ slice_idx = slice_count - 1 - delta_first / iters;
346
+ if (col_off > 0) slice_idx--;
347
+ }
348
+ }
349
+ if (slice_col == n_tiles) {
350
+ A += 16 * thread_m_blocks * prob_k / 8;
351
+ C += 16 * thread_m_blocks * prob_n / 8;
352
+ locks += n_tiles;
353
+ slice_col = 0;
354
+ }
355
+ };
356
+ init_slice();
357
+
358
+ // A sizes/strides
359
+
360
+ // stride of the A matrix in global memory
361
+ int a_gl_stride = prob_k / 8;
362
+ // stride of an A matrix tile in shared memory
363
+ constexpr int a_sh_stride = 16 * thread_k_blocks / 8;
364
+ // delta between subsequent A tiles in global memory
365
+ constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8;
366
+ // between subsequent accesses within a tile
367
+ int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o);
368
+ // between shared memory writes
369
+ constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o);
370
+ // between shared memory tile reads
371
+ constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4));
372
+ // within a shared memory tile
373
+ constexpr int a_sh_rd_delta_i = a_sh_stride * 16;
374
+ // overall size of a tile
375
+ constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks);
376
+ // number of shared write iterations for a tile
377
+ constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta);
378
+
379
+ // B sizes/strides
380
+ int b_gl_stride = 16 * prob_n / (pack_factor * 4);
381
+ constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;
382
+ constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2;
383
+ constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;
384
+
385
+ int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
386
+ int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads);
387
+ constexpr int b_sh_wr_delta = threads * b_thread_vecs;
388
+ constexpr int b_sh_rd_delta = threads * b_thread_vecs;
389
+ constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
390
+ constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
391
+
392
+ // Scale sizes/strides without act_order
393
+ int s_gl_stride = prob_n / 8;
394
+ constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
395
+
396
+ // Scale size/strides with act_order
397
+ constexpr int tb_k = 16 * thread_k_blocks;
398
+ constexpr int g_idx_stage = 0;
399
+ // constexpr int act_s_row_stride = 1;
400
+ // int act_s_col_stride = act_s_row_stride * num_groups;
401
+ int act_s_col_stride = 1;
402
+ int act_s_col_warp_stride = act_s_col_stride * 8;
403
+ int tb_n_warps = thread_n_blocks / 4;
404
+ int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;
405
+
406
+ // Global A read index of current thread.
407
+ int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
408
+ (threadIdx.x % a_gl_rd_delta_o);
409
+ a_gl_rd += a_gl_rd_delta_o * slice_row;
410
+ // Shared write index of current thread.
411
+ int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
412
+ (threadIdx.x % a_gl_rd_delta_o);
413
+ // Shared read index.
414
+ int a_sh_rd =
415
+ a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
416
+ a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
417
+
418
+ int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) +
419
+ (threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
420
+ b_gl_rd += b_sh_stride * slice_col;
421
+ b_gl_rd += b_gl_rd_delta_o * slice_row;
422
+ int b_sh_wr = threadIdx.x * b_thread_vecs;
423
+ int b_sh_rd = threadIdx.x * b_thread_vecs;
424
+
425
+ // For act_order
426
+ int slice_k_start = tb_k * slice_row;
427
+ int slice_k_start_shared_fetch = slice_k_start;
428
+ int slice_n_offset = act_s_col_tb_stride * slice_col;
429
+
430
+ // No act_order
431
+ int s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
432
+ int s_sh_wr = threadIdx.x;
433
+ bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
434
+
435
+ // We scale a `half2` tile in row-major layout for column-wise quantization.
436
+ int s_sh_rd =
437
+ 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4;
438
+
439
+ // Precompute which thread should not read memory in which iterations; this is
440
+ // needed if there are more threads than required for a certain tilesize or
441
+ // when the batchsize is not a multiple of 16.
442
+ bool a_sh_wr_pred[a_sh_wr_iters];
443
+ #pragma unroll
444
+ for (int i = 0; i < a_sh_wr_iters; i++)
445
+ a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
446
+
447
+ // To ensure that writing and reading A tiles to/from shared memory, the
448
+ // latter in fragment format, is fully bank conflict free, we need to use a
449
+ // rather fancy XOR-based layout. The key here is that neither reads nor
450
+ // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
451
+ // same shared memory banks. Further, it seems (based on NSight-Compute) that
452
+ // each warp must also write a consecutive memory segment?
453
+ auto transform_a = [&](int i) {
454
+ int row = i / a_gl_rd_delta_o;
455
+ return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
456
+ };
457
+ // Since the computation of this remapping is non-trivial and, due to our main
458
+ // loop unrolls, all shared memory accesses are static, we simply precompute
459
+ // both transformed reads and writes.
460
+ int a_sh_wr_trans[a_sh_wr_iters];
461
+ #pragma unroll
462
+ for (int i = 0; i < a_sh_wr_iters; i++)
463
+ a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
464
+ int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
465
+ #pragma unroll
466
+ for (int i = 0; i < b_sh_wr_iters; i++) {
467
+ #pragma unroll
468
+ for (int j = 0; j < thread_m_blocks; j++)
469
+ a_sh_rd_trans[i][j] =
470
+ transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
471
+ }
472
+
473
+ // Since B-accesses have non-constant stride they have to be computed at
474
+ // runtime; we break dependencies between subsequent accesses within a tile by
475
+ // maintaining multiple pointers (we have enough registers), a tiny
476
+ // optimization.
477
+ const int4* B_ptr[b_sh_wr_iters];
478
+ #pragma unroll
479
+ for (int i = 0; i < b_sh_wr_iters; i++)
480
+ B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
481
+
482
+ extern __shared__ int4 sh[];
483
+ // Shared memory storage for global fetch pipelines.
484
+ int4* sh_a = sh;
485
+ int4* sh_b = sh_a + (stages * a_sh_stage);
486
+ int4* sh_g_idx = sh_b + (stages * b_sh_stage);
487
+ int4* sh_s = sh_g_idx + (stages * g_idx_stage);
488
+
489
+ // Register storage for double buffer of shared memory reads.
490
+ FragA frag_a[2][thread_m_blocks];
491
+ I4 frag_b_quant[2][b_thread_vecs];
492
+ FragC frag_c[thread_m_blocks][4][2];
493
+ FragS frag_s[2][4];
494
+
495
+ // Zero accumulators.
496
+ auto zero_accums = [&]() {
497
+ #pragma unroll
498
+ for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
499
+ reinterpret_cast<float*>(frag_c)[i] = 0;
500
+ };
501
+
502
+ int sh_first_group_id = -1;
503
+ int sh_num_groups = -1;
504
+ constexpr int sh_max_num_groups = 32;
505
+
506
+ auto fetch_scales_to_shared = [&](bool is_async, int first_group_id,
507
+ int last_group_id) {
508
+ sh_first_group_id = first_group_id;
509
+ sh_num_groups = last_group_id - first_group_id + 1;
510
+
511
+ if (sh_num_groups < sh_max_num_groups) {
512
+ sh_num_groups = sh_max_num_groups;
513
+ }
514
+
515
+ if (sh_first_group_id + sh_num_groups > num_groups) {
516
+ sh_num_groups = num_groups - sh_first_group_id;
517
+ }
518
+
519
+ int row_offset = first_group_id * s_gl_stride;
520
+
521
+ if (is_async) {
522
+ for (int i = 0; i < sh_num_groups; i++) {
523
+ if (threadIdx.x < s_sh_stride) {
524
+ cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x],
525
+ &scales_ptr[row_offset + (i * s_gl_stride) +
526
+ slice_n_offset + threadIdx.x]);
527
+ }
528
+ }
529
+ } else {
530
+ for (int i = 0; i < sh_num_groups; i++) {
531
+ if (threadIdx.x < s_sh_stride) {
532
+ sh_s[(i * s_sh_stride) + threadIdx.x] =
533
+ scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset +
534
+ threadIdx.x];
535
+ }
536
+ }
537
+ }
538
+ };
539
+ // Asynchronously fetch the next A, B and s tile from global to the next
540
+ // shared memory pipeline location.
541
+ auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
542
+ if (pred) {
543
+ int4* sh_a_stage = sh_a + a_sh_stage * pipe;
544
+ #pragma unroll
545
+ for (int i = 0; i < a_sh_wr_iters; i++) {
546
+ cp_async4_pred(
547
+ &sh_a_stage[a_sh_wr_trans[i]],
548
+ &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
549
+ a_sh_wr_pred[i]);
550
+ }
551
+ int4* sh_b_stage = sh_b + b_sh_stage * pipe;
552
+ #pragma unroll
553
+ for (int i = 0; i < b_sh_wr_iters; i++) {
554
+ #pragma unroll
555
+ for (int j = 0; j < b_thread_vecs; j++) {
556
+ cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);
557
+ }
558
+
559
+ B_ptr[i] += b_gl_rd_delta_o;
560
+ }
561
+ }
562
+ // Insert a fence even when we are winding down the pipeline to ensure that
563
+ // waiting is also correct at this point.
564
+ cp_async_fence();
565
+ };
566
+
567
+ // Wait until the next thread tile has been loaded to shared memory.
568
+ auto wait_for_stage = [&]() {
569
+ // We only have `stages - 2` active fetches since we are double buffering
570
+ // and can only issue the next fetch when it is guaranteed that the previous
571
+ // shared memory load is fully complete (as it may otherwise be
572
+ // overwritten).
573
+ cp_async_wait<stages - 2>();
574
+ __syncthreads();
575
+ };
576
+
577
+ // Load the next sub-tile from the current location in the shared memory pipe
578
+ // into the current register buffer.
579
+ auto fetch_to_registers = [&](int k, int pipe) {
580
+ int4* sh_a_stage = sh_a + a_sh_stage * pipe;
581
+ #pragma unroll
582
+ for (int i = 0; i < thread_m_blocks; i++)
583
+ ldsm4<scalar_t>(frag_a[k % 2][i],
584
+ &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
585
+ int4* sh_b_stage = sh_b + b_sh_stage * pipe;
586
+
587
+ #pragma unroll
588
+ for (int i = 0; i < b_thread_vecs; i++) {
589
+ frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>(
590
+ &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);
591
+ }
592
+ };
593
+
594
+ bool is_same_group[stages];
595
+ int same_group_id[stages];
596
+
597
+ auto init_same_group = [&](int pipe) {
598
+ is_same_group[pipe] = false;
599
+ same_group_id[pipe] = 0;
600
+ return;
601
+ };
602
+
603
+ // Execute the actual tensor core matmul of a sub-tile.
604
+ auto matmul = [&](int k) {
605
+ // We have the m dimension as the inner loop in order to encourage overlapping
606
+ // dequantization and matmul operations.
607
+ #pragma unroll
608
+ for (int j = 0; j < 4; j++) {
609
+ FragB frag_b0;
610
+ FragB frag_b1;
611
+
612
+ int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);
613
+ int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
614
+ int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
615
+
616
+ frag_b0 = dequant_8bit<scalar_t>(b_quant_0);
617
+ frag_b1 = dequant_8bit<scalar_t>(b_quant_1);
618
+
619
+ #pragma unroll
620
+ for (int i = 0; i < thread_m_blocks; i++) {
621
+ mma<scalar_t>(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
622
+ mma<scalar_t>(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
623
+ }
624
+ }
625
+ };
626
+
627
+ // Since we slice across the k dimension of a tile in order to increase the
628
+ // number of warps while keeping the n dimension of a tile reasonable, we have
629
+ // multiple warps that accumulate their partial sums of the same output
630
+ // location; which we have to reduce over in the end. We do in shared memory.
631
+ auto thread_block_reduce = [&]() {
632
+ constexpr int red_off = threads / b_sh_stride_threads / 2;
633
+ if (red_off >= 1) {
634
+ int red_idx = threadIdx.x / b_sh_stride_threads;
635
+ constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
636
+ constexpr int red_sh_delta = b_sh_stride_threads;
637
+ int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
638
+ (threadIdx.x % b_sh_stride_threads);
639
+
640
+ // Parallel logarithmic shared memory reduction. We make sure to avoid any
641
+ // unnecessary read or write iterations, e.g., for two warps we write only
642
+ // once by warp 1 and read only once by warp 0.
643
+
644
+ #pragma unroll
645
+ for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
646
+ #pragma unroll
647
+ for (int i = red_off; i > 0; i /= 2) {
648
+ if (i <= red_idx && red_idx < 2 * i) {
649
+ #pragma unroll
650
+ for (int j = 0; j < 4 * 2; j++) {
651
+ int red_sh_wr =
652
+ red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
653
+ if (i < red_off) {
654
+ float* c_rd =
655
+ reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
656
+ float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
657
+ #pragma unroll
658
+ for (int k = 0; k < 4; k++)
659
+ reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
660
+ c_rd[k] + c_wr[k];
661
+ }
662
+ sh[red_sh_wr] =
663
+ reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
664
+ }
665
+ }
666
+ __syncthreads();
667
+ }
668
+ if (red_idx == 0) {
669
+ #pragma unroll
670
+ for (int i = 0; i < 4 * 2; i++) {
671
+ float* c_rd =
672
+ reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
673
+ #pragma unroll
674
+ for (int j = 0; j < 4; j++)
675
+ reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
676
+ c_rd[j];
677
+ }
678
+ }
679
+ __syncthreads();
680
+ }
681
+ }
682
+ };
683
+
684
+ // Since multiple threadblocks may process parts of the same column slice, we
685
+ // finally have to globally reduce over the results. As the striped
686
+ // partitioning minimizes the number of such reductions and our outputs are
687
+ // usually rather small, we perform this reduction serially in L2 cache.
688
+ auto global_reduce = [&](bool first = false, bool last = false) {
689
+ // We are very careful here to reduce directly in the output buffer to
690
+ // maximize L2 cache utilization in this step. To do this, we write out
691
+ // results in FP16 (but still reduce with FP32 compute).
692
+ constexpr int active_threads = 32 * thread_n_blocks / 4;
693
+ if (threadIdx.x < active_threads) {
694
+ int c_gl_stride = prob_n / 8;
695
+ int c_gl_wr_delta_o = 8 * c_gl_stride;
696
+ int c_gl_wr_delta_i = 4 * (active_threads / 32);
697
+ int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +
698
+ 4 * (threadIdx.x / 32) + threadIdx.x % 4;
699
+ c_gl_wr += (2 * thread_n_blocks) * slice_col;
700
+ constexpr int c_sh_wr_delta = active_threads;
701
+ int c_sh_wr = threadIdx.x;
702
+
703
+ int row = (threadIdx.x % 32) / 4;
704
+
705
+ if (!first) {
706
+ // Interestingly, doing direct global accesses here really seems to mess up
707
+ // the compiler and lead to slowdowns, hence we also use async-copies even
708
+ // though these fetches are not actually asynchronous.
709
+ #pragma unroll
710
+ for (int i = 0; i < thread_m_blocks * 4; i++) {
711
+ cp_async4_pred(
712
+ &sh[c_sh_wr + c_sh_wr_delta * i],
713
+ &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
714
+ c_gl_wr_delta_i * (i % 2)],
715
+ i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
716
+ }
717
+ cp_async_fence();
718
+ cp_async_wait<0>();
719
+ }
720
+
721
+ #pragma unroll
722
+ for (int i = 0; i < thread_m_blocks * 4; i++) {
723
+ if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
724
+ if (!first) {
725
+ int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
726
+ #pragma unroll
727
+ for (int j = 0; j < 2 * 4; j++) {
728
+ reinterpret_cast<float*>(
729
+ &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
730
+ Dtype::num2float(reinterpret_cast<scalar_t*>(&c_red)[j]);
731
+ }
732
+ }
733
+ if (!last) {
734
+ int4 c;
735
+ #pragma unroll
736
+ for (int j = 0; j < 2 * 4; j++) {
737
+ reinterpret_cast<scalar_t*>(&c)[j] =
738
+ Dtype::float2num(reinterpret_cast<float*>(
739
+ &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);
740
+ }
741
+ C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =
742
+ c;
743
+ }
744
+ }
745
+ }
746
+ }
747
+ };
748
+
749
+ // Write out the final reduced result in the correct layout. We only actually
750
+ // reshuffle matrix fragments in this step, the reduction above is performed
751
+ // in fragment layout.
752
+ auto write_result = [&]() {
753
+ int c_gl_stride = prob_n / 8;
754
+ constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
755
+ int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
756
+ constexpr int c_sh_rd_delta =
757
+ c_sh_stride * (threads / (2 * thread_n_blocks));
758
+
759
+ int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +
760
+ (threadIdx.x % (2 * thread_n_blocks));
761
+ c_gl_wr += (2 * thread_n_blocks) * slice_col;
762
+ int c_sh_wr =
763
+ (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
764
+ c_sh_wr += 32 * (threadIdx.x / 32);
765
+ int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +
766
+ (threadIdx.x % (2 * thread_n_blocks));
767
+
768
+ int c_gl_wr_end = c_gl_stride * prob_m;
769
+
770
+ // We first reorder in shared memory to guarantee the most efficient final
771
+ // global write patterns
772
+ auto write = [&](int idx, float c0, float c1, FragS& s) {
773
+ scalar_t2 res =
774
+ Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));
775
+
776
+ ((scalar_t2*)sh)[idx] = res;
777
+ };
778
+
779
+ if (threadIdx.x / 32 < thread_n_blocks / 4) {
780
+ #pragma unroll
781
+ for (int i = 0; i < thread_m_blocks; i++) {
782
+ #pragma unroll
783
+ for (int j = 0; j < 4; j++) {
784
+ int wr = c_sh_wr + 8 * j;
785
+ write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
786
+ frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
787
+ write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
788
+ frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
789
+ write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
790
+ frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
791
+ write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
792
+ frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
793
+ }
794
+ c_sh_wr += 16 * (4 * c_sh_stride);
795
+ }
796
+ }
797
+ __syncthreads();
798
+
799
+ #pragma unroll
800
+ for (int i = 0;
801
+ i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
802
+ i++) {
803
+ if (c_gl_wr < c_gl_wr_end) {
804
+ C[c_gl_wr] = sh[c_sh_rd];
805
+ c_gl_wr += c_gl_wr_delta;
806
+ c_sh_rd += c_sh_rd_delta;
807
+ }
808
+ }
809
+ };
810
+
811
+ // Start global fetch and register load pipelines.
812
+ auto start_pipes = [&]() {
813
+
814
+ #pragma unroll
815
+ for (int i = 0; i < stages - 1; i++) {
816
+ fetch_to_shared(i, i, i < slice_iters);
817
+ }
818
+
819
+ zero_accums();
820
+ wait_for_stage();
821
+ init_same_group(0);
822
+ fetch_to_registers(0, 0);
823
+ a_gl_rd += a_gl_rd_delta_o * (stages - 1);
824
+ slice_k_start_shared_fetch += tb_k * (stages - 1);
825
+ };
826
+ if (slice_iters) {
827
+ start_pipes();
828
+ }
829
+
830
+ // Main loop.
831
+ while (slice_iters) {
832
+ // We unroll over both the global fetch and the register load pipeline to
833
+ // ensure all shared memory accesses are static. Note that both pipelines
834
+ // have even length meaning that the next iteration will always start at
835
+ // index 0.
836
+
837
+ #pragma unroll
838
+ for (int pipe = 0; pipe < stages;) {
839
+ #pragma unroll
840
+ for (int k = 0; k < b_sh_wr_iters; k++) {
841
+ fetch_to_registers(k + 1, pipe % stages);
842
+ if (k == b_sh_wr_iters - 2) {
843
+ fetch_to_shared((pipe + stages - 1) % stages, pipe,
844
+ slice_iters >= stages);
845
+ pipe++;
846
+ wait_for_stage();
847
+ init_same_group(pipe % stages);
848
+ }
849
+ matmul(k);
850
+ }
851
+ slice_iters--;
852
+ if (slice_iters == 0) {
853
+ break;
854
+ }
855
+ }
856
+
857
+ a_gl_rd += a_gl_rd_delta_o * stages;
858
+ slice_k_start += tb_k * stages;
859
+ slice_k_start_shared_fetch += tb_k * stages;
860
+
861
+ // Process results and, if necessary, proceed to the next column slice.
862
+ // While this pattern may not be the most readable, other ways of writing
863
+ // the loop seemed to yield noticeably worse performance after compilation.
864
+ if (slice_iters == 0) {
865
+ cp_async_wait<0>();
866
+ bool last = slice_idx == slice_count - 1;
867
+ // For per-column scales, we only fetch them here in the final step before
868
+ // write-out
869
+ if (s_sh_wr_pred) {
870
+ cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
871
+ }
872
+ cp_async_fence();
873
+
874
+ thread_block_reduce();
875
+
876
+ cp_async_wait<0>();
877
+ __syncthreads();
878
+ if (threadIdx.x / 32 < thread_n_blocks / 4) {
879
+ reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
880
+ reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
881
+ }
882
+
883
+ // For 8-bit channelwise, we apply the scale before the global reduction
884
+ // that converts the fp32 results to fp16 (so that we avoid possible
885
+ // overflow in fp16)
886
+ if (threadIdx.x / 32 < thread_n_blocks / 4) {
887
+ #pragma unroll
888
+ for (int i = 0; i < thread_m_blocks; i++) {
889
+ #pragma unroll
890
+ for (int j = 0; j < 4; j++) {
891
+ scale_float<scalar_t>(reinterpret_cast<float*>(&frag_c[i][j][0][0]),
892
+ frag_s[j / 2][2 * (j % 2) + 0]);
893
+ scale_float<scalar_t>(reinterpret_cast<float*>(&frag_c[i][j][0][2]),
894
+ frag_s[j / 2][2 * (j % 2) + 0]);
895
+
896
+ scale_float<scalar_t>(reinterpret_cast<float*>(&frag_c[i][j][1][0]),
897
+ frag_s[j / 2][2 * (j % 2) + 1]);
898
+ scale_float<scalar_t>(reinterpret_cast<float*>(&frag_c[i][j][1][2]),
899
+ frag_s[j / 2][2 * (j % 2) + 1]);
900
+ }
901
+ }
902
+ }
903
+
904
+ if (slice_count > 1) { // only globally reduce if there is more than one
905
+ // block in a slice
906
+ barrier_acquire(&locks[slice_col], slice_idx);
907
+ global_reduce(slice_idx == 0, last);
908
+ barrier_release(&locks[slice_col], last);
909
+ }
910
+ if (last) // only the last block in a slice actually writes the result
911
+ write_result();
912
+ slice_row = 0;
913
+ slice_col_par++;
914
+ slice_col++;
915
+ init_slice();
916
+ if (slice_iters) {
917
+ a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
918
+ (threadIdx.x % a_gl_rd_delta_o);
919
+ #pragma unroll
920
+ for (int i = 0; i < b_sh_wr_iters; i++)
921
+ B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
922
+ if (slice_col == 0) {
923
+ #pragma unroll
924
+ for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
925
+ }
926
+
927
+ // Update slice k/n for scales loading
928
+ s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
929
+
930
+ start_pipes();
931
+ }
932
+ }
933
+ }
934
+ }
935
+
936
+ #define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
937
+ THREAD_K_BLOCKS, GROUP_BLOCKS, NUM_THREADS) \
938
+ else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
939
+ thread_n_blocks == THREAD_N_BLOCKS && \
940
+ thread_k_blocks == THREAD_K_BLOCKS && \
941
+ group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
942
+ cudaFuncSetAttribute( \
943
+ Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
944
+ THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, GROUP_BLOCKS>, \
945
+ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
946
+ Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
947
+ THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, GROUP_BLOCKS> \
948
+ <<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
949
+ A_ptr, B_ptr, C_ptr, s_ptr, num_groups, prob_m, prob_n, prob_k, \
950
+ locks); \
951
+ }
952
+
953
+ typedef struct {
954
+ int thread_k;
955
+ int thread_n;
956
+ int num_threads;
957
+ } thread_config_t;
958
+
959
+ typedef struct {
960
+ int max_m_blocks;
961
+ thread_config_t tb_cfg;
962
+ } exec_config_t;
963
+
964
+ thread_config_t small_batch_thread_configs[] = {
965
+ // Ordered by priority
966
+
967
+ // thread_k, thread_n, num_threads
968
+ {128, 128, 256},
969
+ {64, 128, 128},
970
+ {128, 64, 128},
971
+ };
972
+
973
+ thread_config_t large_batch_thread_configs[] = {
974
+ // Ordered by priority
975
+
976
+ // thread_k, thread_n, num_threads
977
+ {64, 256, 256},
978
+ {64, 128, 128},
979
+ {128, 64, 128},
980
+
981
+ };
982
+
983
+ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
984
+ int prob_n, int prob_k, int num_bits,
985
+ int group_size) {
986
+ int tb_n = th_config.thread_n;
987
+
988
+ // Get max scale groups per thread-block
989
+ // Fixed for channelwise
990
+ int tb_groups = 1;
991
+ int tb_scales = tb_groups * tb_n * 2;
992
+
993
+ return tb_scales * pipe_stages;
994
+ }
995
+
996
+ bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
997
+ int prob_m, int prob_n, int prob_k, int num_bits,
998
+ int scales_cache_size, int max_shared_mem) {
999
+ int pack_factor = 32 / num_bits;
1000
+
1001
+ // Get B size
1002
+ int tb_k = th_config.thread_k;
1003
+ int tb_n = th_config.thread_n;
1004
+
1005
+ int b_size = (tb_k * tb_n / pack_factor) * 4;
1006
+
1007
+ // Get A size
1008
+ int m_blocks = div_ceil(prob_m, 16);
1009
+ int tb_max_m = 16;
1010
+
1011
+ while (true) {
1012
+ if (m_blocks >= max_m_blocks) {
1013
+ tb_max_m *= max_m_blocks;
1014
+ break;
1015
+ }
1016
+
1017
+ max_m_blocks--;
1018
+ if (max_m_blocks == 0) {
1019
+ TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks);
1020
+ }
1021
+ }
1022
+
1023
+ int a_size = (tb_max_m * tb_k) * 2;
1024
+
1025
+ float pipe_size = (a_size + b_size) * pipe_stages;
1026
+
1027
+ TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity
1028
+
1029
+ return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);
1030
+ }
1031
+
1032
+ bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
1033
+ int prob_m, int prob_n, int prob_k, int num_bits,
1034
+ int group_size, int max_shared_mem) {
1035
+ // Sanity
1036
+ if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
1037
+ th_config.num_threads == -1) {
1038
+ return false;
1039
+ }
1040
+
1041
+ // Verify K/N are divisible by thread K/N
1042
+ if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
1043
+ return false;
1044
+ }
1045
+
1046
+ // Verify min for thread K/N
1047
+ if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
1048
+ return false;
1049
+ }
1050
+
1051
+ // num_threads must be at least 128 (= 4 warps)
1052
+ if (th_config.num_threads < 128) {
1053
+ return false;
1054
+ }
1055
+
1056
+ // Determine cache for scales
1057
+ int scales_cache_size = get_scales_cache_size(th_config, prob_m, prob_n,
1058
+ prob_k, num_bits, group_size);
1059
+
1060
+ // Check that pipeline fits into cache
1061
+ if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k,
1062
+ num_bits, scales_cache_size, max_shared_mem)) {
1063
+ return false;
1064
+ }
1065
+
1066
+ return true;
1067
+ }
1068
+
1069
+ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
1070
+ int num_bits, int group_size,
1071
+ int max_shared_mem) {
1072
+ int max_m_blocks = 4;
1073
+ while (max_m_blocks > 0) {
1074
+ if (prob_m <= 16) {
1075
+ for (auto th_config : small_batch_thread_configs) {
1076
+ if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
1077
+ num_bits, group_size, max_shared_mem)) {
1078
+ return exec_config_t{max_m_blocks, th_config};
1079
+ }
1080
+ }
1081
+ } else {
1082
+ for (auto th_config : large_batch_thread_configs) {
1083
+ if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
1084
+ num_bits, group_size, max_shared_mem)) {
1085
+ return exec_config_t{max_m_blocks, th_config};
1086
+ }
1087
+ }
1088
+ }
1089
+
1090
+ max_m_blocks--; // Process fewer M blocks per invocation to reduce cache
1091
+ // usage
1092
+ }
1093
+
1094
+ return exec_config_t{0, {-1, -1, -1}};
1095
+ }
1096
+
1097
+ #define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
1098
+ __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
1099
+ __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
1100
+ __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
1101
+ __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS)
1102
+
1103
+ template <typename scalar_t>
1104
+ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, int prob_m,
1105
+ int prob_n, int prob_k, void* workspace, int num_bits,
1106
+ int num_groups, int group_size, int dev,
1107
+ cudaStream_t stream, int thread_k, int thread_n, int sms,
1108
+ int max_par) {
1109
+ TORCH_CHECK(num_bits == 8, "num_bits must be 8. Got = ", num_bits);
1110
+ TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
1111
+ ", ", prob_n, ", ", prob_k, "]");
1112
+
1113
+ int tot_m = prob_m;
1114
+ int tot_m_blocks = div_ceil(tot_m, 16);
1115
+ int pad = 16 * tot_m_blocks - tot_m;
1116
+
1117
+ if (sms == -1) {
1118
+ cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
1119
+ }
1120
+
1121
+ int max_shared_mem = 0;
1122
+ cudaDeviceGetAttribute(&max_shared_mem,
1123
+ cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
1124
+ TORCH_CHECK(max_shared_mem > 0);
1125
+
1126
+ // Set thread config
1127
+ exec_config_t exec_cfg;
1128
+ if (thread_k != -1 && thread_n != -1) {
1129
+ // User-defined config
1130
+ exec_cfg =
1131
+ exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}};
1132
+ } else {
1133
+ // Auto config
1134
+ exec_cfg = determine_thread_config(prob_m, prob_n, prob_k, num_bits,
1135
+ group_size, max_shared_mem);
1136
+ }
1137
+
1138
+ TORCH_CHECK(
1139
+ exec_cfg.max_m_blocks > 0 &&
1140
+ is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, prob_m,
1141
+ prob_n, prob_k, num_bits, group_size, max_shared_mem),
1142
+ "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks,
1143
+ ", thread_k = ", exec_cfg.tb_cfg.thread_k,
1144
+ ", thread_n = ", exec_cfg.tb_cfg.thread_n,
1145
+ ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", prob_m,
1146
+ ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
1147
+ ", group_size = ", group_size, ", max_shared_mem = ", max_shared_mem);
1148
+
1149
+ int num_threads = exec_cfg.tb_cfg.num_threads;
1150
+ thread_k = exec_cfg.tb_cfg.thread_k;
1151
+ thread_n = exec_cfg.tb_cfg.thread_n;
1152
+
1153
+ int thread_k_blocks = thread_k / 16;
1154
+ int thread_n_blocks = thread_n / 16;
1155
+
1156
+ int blocks = sms;
1157
+
1158
+ TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
1159
+ " is not divisible by thread_n = ", thread_n);
1160
+ TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
1161
+ " is not divisible by thread_k = ", thread_k);
1162
+
1163
+ int group_blocks = -1;
1164
+
1165
+ const int4* A_ptr = (const int4*)A;
1166
+ const int4* B_ptr = (const int4*)B;
1167
+ int4* C_ptr = (int4*)C;
1168
+ const int4* s_ptr = (const int4*)s;
1169
+
1170
+ int* locks = (int*)workspace;
1171
+
1172
+ // Main loop
1173
+ for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) {
1174
+ int thread_m_blocks = tot_m_blocks - i;
1175
+ prob_m = tot_m - 16 * i;
1176
+ int par = 1;
1177
+ if (thread_m_blocks > exec_cfg.max_m_blocks) {
1178
+ // Note that parallel > 1 currently only works for inputs without any
1179
+ // padding
1180
+ par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks);
1181
+ if (par > max_par) par = max_par;
1182
+ prob_m = (16 * exec_cfg.max_m_blocks) * par;
1183
+ i += exec_cfg.max_m_blocks * (par - 1);
1184
+ thread_m_blocks = exec_cfg.max_m_blocks;
1185
+ }
1186
+
1187
+ // Define kernel configurations
1188
+ if (false) {
1189
+ }
1190
+ CALL_IF(8, 32, 2, 256)
1191
+ CALL_IF(8, 16, 4, 256)
1192
+ CALL_IF(8, 8, 8, 256)
1193
+ CALL_IF(8, 8, 4, 128)
1194
+ CALL_IF(8, 4, 8, 128)
1195
+ else {
1196
+ TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
1197
+ str(prob_n) + ", " + str(prob_k) + "]" +
1198
+ ", num_groups = " + str(num_groups) +
1199
+ ", group_size = " + str(group_size) +
1200
+ ", thread_m_blocks = " + str(thread_m_blocks) +
1201
+ ", thread_n_blocks = " + str(thread_n_blocks) +
1202
+ ", thread_k_blocks = " + str(thread_k_blocks));
1203
+ }
1204
+
1205
+ A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
1206
+ C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
1207
+ }
1208
+ }
1209
+
1210
+ } // namespace fp8_marlin
1211
+
1212
+ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
1213
+ torch::Tensor& b_scales, torch::Tensor& workspace,
1214
+ int64_t num_bits, int64_t size_m, int64_t size_n,
1215
+ int64_t size_k) {
1216
+ // Verify num_bits
1217
+ TORCH_CHECK(num_bits == 8, "num_bits must be 8. Got = ", num_bits);
1218
+ int pack_factor = 32 / num_bits;
1219
+
1220
+ // Verify A
1221
+ TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
1222
+ ", size_m = ", size_m);
1223
+ TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1),
1224
+ ", size_k = ", size_k);
1225
+
1226
+ // Verify B
1227
+ TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k,
1228
+ " is not divisible by tile_size = ", marlin::tile_size);
1229
+ TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
1230
+ "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
1231
+ ", size_k = ", size_k, ", tile_size = ", marlin::tile_size);
1232
+ TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
1233
+ "b_q_weight.size(1) = ", b_q_weight.size(1),
1234
+ " is not divisible by tile_size = ", marlin::tile_size);
1235
+ int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor;
1236
+ TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
1237
+ ", actual_size_n = ", actual_size_n);
1238
+
1239
+ // Verify device and strides
1240
+ TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
1241
+ TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
1242
+
1243
+ TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
1244
+ TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
1245
+
1246
+ TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
1247
+ TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
1248
+
1249
+ // Alloc buffers
1250
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
1251
+ auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
1252
+ torch::Tensor c = torch::empty({size_m, size_n}, options);
1253
+
1254
+ // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
1255
+ // auto -1)
1256
+ int thread_k = -1;
1257
+ // thread_n: `n` size of a thread_tile in `weights` (can usually be left as
1258
+ // auto -1)
1259
+ int thread_n = -1;
1260
+ // sms: number of SMs to use for the kernel (can usually be left as auto -1)
1261
+ int sms = -1;
1262
+
1263
+ // Detect groupsize and act_order
1264
+ int num_groups = -1;
1265
+ int group_size = -1;
1266
+
1267
+ int b_rank = b_scales.sizes().size();
1268
+ TORCH_CHECK(b_rank == 2, "b_scales rank = ", b_rank, " is not 2");
1269
+ TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1),
1270
+ " is not size_n = ", size_n);
1271
+ // Channelwise only for FP8
1272
+ TORCH_CHECK(b_scales.size(0) == 1, "b_scales dim 0 = ", b_scales.size(0),
+ " is not 1 (channelwise only for FP8)");
1273
+ num_groups = b_scales.size(0);
1274
+
1275
+ // Verify workspace size
1276
+ TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n,
1277
+ ", is not divisible by min_thread_n = ", marlin::min_thread_n);
1278
+ int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
1279
+ TORCH_CHECK(workspace.numel() >= min_workspace_size,
1280
+ "workspace.numel = ", workspace.numel(),
1281
+ " is below min_workspace_size = ", min_workspace_size);
1282
+
1283
+ int dev = a.get_device();
1284
+ if (a.scalar_type() == at::ScalarType::Half) {
1285
+ fp8_marlin::marlin_mm_f16i4<half>(
1286
+ a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
1287
+ b_scales.data_ptr<at::Half>(), size_m, size_n, size_k,
1288
+ workspace.data_ptr(), num_bits, num_groups, group_size, dev,
1289
+ at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
1290
+ marlin::max_par);
1291
+ } else if (a.scalar_type() == at::ScalarType::BFloat16) {
1292
+ fp8_marlin::marlin_mm_f16i4<nv_bfloat16>(
1293
+ a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
1294
+ c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(), size_m,
1295
+ size_n, size_k, workspace.data_ptr(), num_bits, num_groups, group_size,
1296
+ dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
1297
+ marlin::max_par);
1298
+ } else {
1299
+ TORCH_CHECK(false, "fp8_marlin_gemm only supports bfloat16 and float16");
1300
+ }
1301
+
1302
+ return c;
1303
+ }
1304
+
1305
+ #endif
1306
+
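
Note: below is a minimal host-side sketch of how `fp8_marlin_gemm` could be called from C++ (libtorch), assuming its declaration is in scope and that `b_q_weight` has already been repacked into the Marlin tile layout. The shapes and the int32 workspace sizing follow the checks inside the function; the wrapper itself is illustrative and not part of this commit.

#include <torch/torch.h>

// Illustrative wrapper; not part of this commit.
torch::Tensor run_fp8_marlin(torch::Tensor& a,          // [size_m, size_k], fp16 or bf16
                             torch::Tensor& b_q_weight, // [size_k / 16, size_n * 4], packed fp8
                             torch::Tensor& b_scales,   // [1, size_n], channelwise scales
                             int64_t size_m, int64_t size_n, int64_t size_k) {
  // The kernel needs zero-initialized int32 locks: at least
  // (size_n / min_thread_n) * max_par entries (min_thread_n = 64, max_par = 16).
  auto workspace = torch::zeros({(size_n / 64) * 16},
                                torch::dtype(torch::kInt32).device(a.device()));
  return fp8_marlin_gemm(a, b_q_weight, b_scales, workspace,
                         /*num_bits=*/8, size_m, size_n, size_k);
}
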
fp8/nvidia/quant_utils.cuh ADDED
@@ -0,0 +1,573 @@
1
+ #pragma once
2
+
3
+ #include "../../../attention/attention_dtypes.h"
4
+ #include <assert.h>
5
+ #include <float.h>
6
+ #include <stdint.h>
7
+ #include <type_traits>
8
+
9
+ namespace vllm {
10
+ #ifndef USE_ROCM
11
+
12
+ namespace fp8 {
13
+ #ifdef ENABLE_FP8
14
+
15
+ #if 0 // Disable the following code to reduce the binary size.
16
+ template <typename Tout, typename Tin>
17
+ __inline__ __device__ Tout
18
+ vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) {
19
+ return x;
20
+ }
21
+
22
+ // fp8 -> half
23
+ template <>
24
+ __inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(
25
+ const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) {
26
+ __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
27
+ return res.x;
28
+ }
29
+
30
+ // fp8x2 -> half2
31
+ template <>
32
+ __inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(
33
+ const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
34
+ union {
35
+ uint16_t u16[2];
36
+ uint32_t u32;
37
+ } tmp;
38
+ __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
39
+ tmp.u16[0] = res.x;
40
+ tmp.u16[1] = res.y;
41
+ return tmp.u32;
42
+ }
43
+
44
+ // fp8x4 -> half2x2
45
+ template <>
46
+ __inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(
47
+ const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
48
+ union {
49
+ uint2 u32x2;
50
+ uint32_t u32[2];
51
+ } tmp;
52
+ tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a, fp8_type);
53
+ tmp.u32[1] =
54
+ vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), fp8_type);
55
+ return tmp.u32x2;
56
+ }
57
+
58
+ // fp8x8 -> half2x4
59
+ template <>
60
+ __inline__ __device__ uint4 vec_conversion<uint4, uint2>(
61
+ const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
62
+ union {
63
+ uint4 u64x2;
64
+ uint2 u64[2];
65
+ } tmp;
66
+ tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x, fp8_type);
67
+ tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y, fp8_type);
68
+ return tmp.u64x2;
69
+ }
70
+
71
+ // fp8 -> __nv_bfloat16
72
+ template <>
73
+ __inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(
74
+ const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) {
75
+ // Note there is no direct convert function from fp8 to bf16.
76
+ // fp8 -> half
77
+ __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
78
+ // half -> float -> bf16
79
+ float tmp = half_to_float(res.x);
80
+ return __float2bfloat16(tmp);
81
+ }
82
+
83
+ // fp8x2 -> __nv_bfloat162
84
+ template <>
85
+ __inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(
86
+ const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
87
+ __nv_bfloat162 res;
88
+ res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, fp8_type);
89
+ res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), fp8_type);
90
+ return res;
91
+ }
92
+
93
+ // fp8x4 -> bf16_4_t
94
+ template <>
95
+ __inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(
96
+ const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
97
+ bf16_4_t res;
98
+ res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, fp8_type);
99
+ res.y =
100
+ vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), fp8_type);
101
+ return res;
102
+ }
103
+
104
+ // fp8x8 -> bf16_8_t
105
+ template <>
106
+ __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(
107
+ const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
108
+ bf16_4_t tmp1, tmp2;
109
+ tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x, fp8_type);
110
+ tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y, fp8_type);
111
+ bf16_8_t res;
112
+ res.x = tmp1.x;
113
+ res.y = tmp1.y;
114
+ res.z = tmp2.x;
115
+ res.w = tmp2.y;
116
+ return res;
117
+ }
118
+
119
+ // fp8 -> float
120
+ template <>
121
+ __inline__ __device__ float
122
+ vec_conversion<float, uint8_t>(const uint8_t &a,
123
+ const __nv_fp8_interpretation_t fp8_type) {
124
+ // fp8 -> half
125
+ uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a, fp8_type);
126
+ // half -> float
127
+ return half_to_float(tmp);
128
+ }
129
+
130
+ // fp8x2 -> float2
131
+ template <>
132
+ __inline__ __device__ float2 vec_conversion<float2, uint16_t>(
133
+ const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
134
+ // fp8x2 -> half2
135
+ uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a, fp8_type);
136
+ // half2 -> float2
137
+ return half2_to_float2(tmp);
138
+ }
139
+
140
+ // fp8x4 -> float4
141
+ template <>
142
+ __inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(
143
+ const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
144
+ Float4_ res;
145
+ res.x = vec_conversion<float2, uint16_t>((uint16_t)a, fp8_type);
146
+ res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), fp8_type);
147
+ return res;
148
+ }
149
+
150
+ // fp8x8 -> float8
151
+ template <>
152
+ __inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(
153
+ const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
154
+ Float4_ tmp1, tmp2;
155
+ tmp1 = vec_conversion<Float4_, uint32_t>(a.x, fp8_type);
156
+ tmp2 = vec_conversion<Float4_, uint32_t>(a.y, fp8_type);
157
+ Float8_ res;
158
+ res.x = tmp1.x;
159
+ res.y = tmp1.y;
160
+ res.z = tmp2.x;
161
+ res.w = tmp2.y;
162
+ return res;
163
+ }
164
+
165
+ // half -> fp8
166
+ template <>
167
+ __inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(
168
+ const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
169
+ __half_raw tmp;
170
+ tmp.x = a;
171
+ __nv_fp8_storage_t res =
172
+ __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, fp8_type);
173
+ return (uint8_t)res;
174
+ }
175
+
176
+ // bf16 -> fp8
177
+ template <>
178
+ __inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(
179
+ const __nv_bfloat16 &a, const __nv_fp8_interpretation_t fp8_type) {
180
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
181
+ assert(false);
182
+ #else
183
+ __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(
184
+ __nv_bfloat16_raw(a), __NV_SATFINITE, fp8_type);
185
+ return (uint8_t)res;
186
+ #endif
187
+ }
188
+
189
+ // float -> fp8
190
+ template <>
191
+ __inline__ __device__ uint8_t vec_conversion<uint8_t, float>(
192
+ const float &a, const __nv_fp8_interpretation_t fp8_type) {
193
+ __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, fp8_type);
194
+ return (uint8_t)res;
195
+ }
196
+
197
+ // fp8x4 -> float4
198
+ template <>
199
+ __inline__ __device__ float4 vec_conversion<float4, uint32_t>(
200
+ const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
201
+ Float4_ tmp = vec_conversion<Float4_, uint32_t>(a, fp8_type);
202
+ float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
203
+ return res;
204
+ }
205
+
206
+ template <>
207
+ __inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(
208
+ const float2 &a, const __nv_fp8_interpretation_t fp8_type) {
209
+ union {
210
+ half2 float16;
211
+ uint32_t uint32;
212
+ };
213
+
214
+ float16 = __float22half2_rn(a);
215
+ return uint32;
216
+ }
217
+
218
+ template <>
219
+ __inline__ __device__ uint2 vec_conversion<uint2, Float4_>(
220
+ const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
221
+ uint2 b;
222
+ float2 val;
223
+ val.x = a.x.x;
224
+ val.y = a.x.y;
225
+ b.x = vec_conversion<uint32_t, float2>(val, fp8_type);
226
+
227
+ val.x = a.y.x;
228
+ val.y = a.y.y;
229
+ b.y = vec_conversion<uint32_t, float2>(val, fp8_type);
230
+
231
+ return b;
232
+ }
233
+
234
+ template <>
235
+ __inline__ __device__ float4 vec_conversion<float4, Float4_>(
236
+ const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
237
+ float4 b;
238
+ b.x = a.x.x;
239
+ b.y = a.x.y;
240
+ b.z = a.y.x;
241
+ b.w = a.y.y;
242
+ return b;
243
+ }
244
+
245
+ template <>
246
+ __inline__ __device__ uint4 vec_conversion<uint4, Float8_>(
247
+ const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) {
248
+ uint4 b;
249
+ b.x = vec_conversion<uint32_t, float2>(a.x, fp8_type);
250
+ b.y = vec_conversion<uint32_t, float2>(a.y, fp8_type);
251
+ b.z = vec_conversion<uint32_t, float2>(a.z, fp8_type);
252
+ b.w = vec_conversion<uint32_t, float2>(a.w, fp8_type);
253
+ return b;
254
+ }
255
+
256
+ template <>
257
+ __inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(
258
+ const float2 &a, const __nv_fp8_interpretation_t fp8_type) {
259
+ __nv_bfloat162 b;
260
+ from_float(b, a);
261
+ return b;
262
+ }
263
+
264
+ template <>
265
+ __inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(
266
+ const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
267
+ bf16_4_t b;
268
+ from_float(b, a);
269
+ return b;
270
+ }
271
+
272
+ template <>
273
+ __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(
274
+ const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) {
275
+ bf16_8_t b;
276
+ from_float(b, a);
277
+ return b;
278
+ }
279
+ #endif
280
+
281
+ /* Scaled and vectorized conversions, for data exchange between high and low
282
+ precision domains. Convention of the scale in the API, e.g.: FP8_data =
283
+ Quantization( High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8
284
+ Dequant(FP8) * scale => HP
285
+ */
286
+
287
+ template <typename Tout, typename Tin>
288
+ __inline__ __device__ Tout scaled_vec_conversion(
289
+ const Tin& x, const float scale, const __nv_fp8_interpretation_t fp8_type) {
290
+ return x;
291
+ }
292
+
293
+ // fp8 -> half
294
+ template <>
295
+ __inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(
296
+ const uint8_t& a, const float scale,
297
+ const __nv_fp8_interpretation_t fp8_type) {
298
+ __half_raw tmp = __nv_cvt_fp8_to_halfraw(a, fp8_type);
299
+ return float_to_half(half_to_float(tmp.x) * scale);
300
+ }
301
+
302
+ // fp8x2 -> half2
303
+ template <>
304
+ __inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
305
+ const uint16_t& a, const float scale,
306
+ const __nv_fp8_interpretation_t fp8_type) {
307
+ union {
308
+ uint16_t u16[2];
309
+ uint32_t u32;
310
+ } tmp;
311
+ __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
312
+ tmp.u16[0] = float_to_half(half_to_float(res.x) * scale);
313
+ tmp.u16[1] = float_to_half(half_to_float(res.y) * scale);
314
+ return tmp.u32;
315
+ }
316
+
317
+ // fp8x4 -> half2x2
318
+ template <>
319
+ __inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(
320
+ const uint32_t& a, const float scale,
321
+ const __nv_fp8_interpretation_t fp8_type) {
322
+ union {
323
+ uint2 u32x2;
324
+ uint32_t u32[2];
325
+ } tmp;
326
+ tmp.u32[0] =
327
+ scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale, fp8_type);
328
+ tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U),
329
+ scale, fp8_type);
330
+ return tmp.u32x2;
331
+ }
332
+
333
+ // fp8x8 -> half2x4
334
+ template <>
335
+ __inline__ __device__ uint4
336
+ scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale,
337
+ const __nv_fp8_interpretation_t fp8_type) {
338
+ union {
339
+ uint4 u64x2;
340
+ uint2 u64[2];
341
+ } tmp;
342
+ tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale, fp8_type);
343
+ tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale, fp8_type);
344
+ return tmp.u64x2;
345
+ }
346
+
347
+ // fp8 -> __nv_bfloat16
348
+ template <>
349
+ __inline__ __device__ __nv_bfloat16
350
+ scaled_vec_conversion<__nv_bfloat16, uint8_t>(
351
+ const uint8_t& a, const float scale,
352
+ const __nv_fp8_interpretation_t fp8_type) {
353
+ // Note there is no direct convert function from fp8 to bf16.
354
+ // fp8 -> half
355
+ __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
356
+ // half -> float -> bf16
357
+ float tmp = half_to_float(res.x);
358
+ return __float2bfloat16(tmp * scale);
359
+ }
360
+
361
+ // fp8x2 -> __nv_bfloat162
362
+ template <>
363
+ __inline__ __device__ __nv_bfloat162
364
+ scaled_vec_conversion<__nv_bfloat162, uint16_t>(
365
+ const uint16_t& a, const float scale,
366
+ const __nv_fp8_interpretation_t fp8_type) {
367
+ __nv_bfloat162 res;
368
+ res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale,
369
+ fp8_type);
370
+ res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U),
371
+ scale, fp8_type);
372
+ return res;
373
+ }
374
+
375
+ // fp8x4 -> bf16_4_t
376
+ template <>
377
+ __inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
378
+ const uint32_t& a, const float scale,
379
+ const __nv_fp8_interpretation_t fp8_type) {
380
+ bf16_4_t res;
381
+ res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale,
382
+ fp8_type);
383
+ res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
384
+ scale, fp8_type);
385
+ return res;
386
+ }
387
+
388
+ // fp8x8 -> bf16_8_t
389
+ template <>
390
+ __inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(
391
+ const uint2& a, const float scale,
392
+ const __nv_fp8_interpretation_t fp8_type) {
393
+ bf16_4_t tmp1, tmp2;
394
+ tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale, fp8_type);
395
+ tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale, fp8_type);
396
+ bf16_8_t res;
397
+ res.x = tmp1.x;
398
+ res.y = tmp1.y;
399
+ res.z = tmp2.x;
400
+ res.w = tmp2.y;
401
+ return res;
402
+ }
403
+
404
+ // fp8 -> float
405
+ template <>
406
+ __inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
407
+ const uint8_t& a, const float scale,
408
+ const __nv_fp8_interpretation_t fp8_type) {
409
+ // fp8 -> half
410
+ __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
411
+ uint16_t tmp = res.x;
412
+
413
+ // half -> float
414
+ return half_to_float(tmp) * scale;
415
+ }
416
+
417
+ // fp8x2 -> float2
418
+ template <>
419
+ __inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(
420
+ const uint16_t& a, const float scale,
421
+ const __nv_fp8_interpretation_t fp8_type) {
422
+ // fp8x2 -> half2
423
+ uint32_t tmp = scaled_vec_conversion<uint32_t, uint16_t>(a, scale, fp8_type);
424
+ // half2 -> float2
425
+ return half2_to_float2(tmp);
426
+ }
427
+
428
+ // fp8x4 -> float4
429
+ template <>
430
+ __inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(
431
+ const uint32_t& a, const float scale,
432
+ const __nv_fp8_interpretation_t fp8_type) {
433
+ Float4_ res;
434
+ res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale, fp8_type);
435
+ res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale,
436
+ fp8_type);
437
+ return res;
438
+ }
439
+
440
+ // fp8x8 -> float8
441
+ template <>
442
+ __inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(
443
+ const uint2& a, const float scale,
444
+ const __nv_fp8_interpretation_t fp8_type) {
445
+ Float4_ tmp1, tmp2;
446
+ tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale, fp8_type);
447
+ tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale, fp8_type);
448
+ Float8_ res;
449
+ res.x = tmp1.x;
450
+ res.y = tmp1.y;
451
+ res.z = tmp2.x;
452
+ res.w = tmp2.y;
453
+ return res;
454
+ }
455
+
456
+ // half -> fp8
457
+ template <>
458
+ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(
459
+ const uint16_t& a, const float scale,
460
+ const __nv_fp8_interpretation_t fp8_type) {
461
+ __nv_fp8_storage_t res =
462
+ __nv_cvt_float_to_fp8(half_to_float(a) / scale, __NV_SATFINITE, fp8_type);
463
+ return (uint8_t)res;
464
+ }
465
+
466
+ // bf16 -> fp8
467
+ template <>
468
+ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
469
+ const __nv_bfloat16& a, const float scale,
470
+ const __nv_fp8_interpretation_t fp8_type) {
471
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
472
+ assert(false);
473
+ #else
474
+ __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(__bfloat162float(a) / scale,
475
+ __NV_SATFINITE, fp8_type);
476
+ return (uint8_t)res;
477
+ #endif
478
+ __builtin_unreachable(); // Suppress missing return statement warning
479
+ }
480
+
481
+ // float -> fp8
482
+ template <>
483
+ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>(
484
+ const float& a, const float scale,
485
+ const __nv_fp8_interpretation_t fp8_type) {
486
+ __nv_fp8_storage_t res =
487
+ __nv_cvt_float_to_fp8(a / scale, __NV_SATFINITE, fp8_type);
488
+ return (uint8_t)res;
489
+ }
490
+
491
+ // fp8x4 -> float4
492
+ template <>
493
+ __inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(
494
+ const uint32_t& a, const float scale,
495
+ const __nv_fp8_interpretation_t fp8_type) {
496
+ Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale, fp8_type);
497
+ float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
498
+ return res;
499
+ }
500
+ #endif // ENABLE_FP8
501
+
502
+ template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
503
+ __inline__ __device__ Tout convert(const Tin& x) {
504
+ #if 0 // Disable the following code to reduce the binary size.
505
+ if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
506
+ return vec_conversion<Tout, Tin>(x, __NV_E4M3);
507
+ } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
508
+ return vec_conversion<Tout, Tin>(x, __NV_E5M2);
509
+ }
510
+ #endif
511
+ assert(false);
512
+ __builtin_unreachable(); // Suppress missing return statement warning
513
+ }
514
+
515
+ template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
516
+ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
517
+ #ifdef ENABLE_FP8
518
+ if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
519
+ return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E4M3);
520
+ } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
521
+ return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E5M2);
522
+ }
523
+ #endif
524
+ assert(false);
525
+ __builtin_unreachable(); // Suppress missing return statement warning
526
+ }
527
+
528
+ // The following macro is used to dispatch the conversion function based on
529
+ // the data type of the key and value cache. The FN is a macro that calls a
530
+ // function with template<typename scalar_t, typename cache_t,
531
+ // Fp8KVCacheDataType kv_dt>.
532
+ #define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
533
+ if (KV_DTYPE == "auto") { \
534
+ if (SRC_DTYPE == at::ScalarType::Float) { \
535
+ FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
536
+ } else if (SRC_DTYPE == at::ScalarType::Half) { \
537
+ FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \
538
+ } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
539
+ FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \
540
+ } else { \
541
+ TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
542
+ } \
543
+ } else { \
544
+ if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
545
+ if (SRC_DTYPE == at::ScalarType::Float) { \
546
+ FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
547
+ } else if (SRC_DTYPE == at::ScalarType::Half) { \
548
+ FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
549
+ } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
550
+ FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
551
+ } else { \
552
+ TORCH_CHECK(false, \
553
+ "Unsupported input type of kv cache: ", SRC_DTYPE); \
554
+ } \
555
+ } else if (KV_DTYPE == "fp8_e5m2") { \
556
+ if (SRC_DTYPE == at::ScalarType::Float) { \
557
+ FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
558
+ } else if (SRC_DTYPE == at::ScalarType::Half) { \
559
+ FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
560
+ } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
561
+ FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
562
+ } else { \
563
+ TORCH_CHECK(false, \
564
+ "Unsupported input type of kv cache: ", SRC_DTYPE); \
565
+ } \
566
+ } else { \
567
+ TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
568
+ } \
569
+ }
570
+
571
+ } // namespace fp8
572
+ #endif // not USE_ROCM
573
+ } // namespace vllm
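
A small device-side sketch of the scale convention documented in this header (quantization divides by the scale, dequantization multiplies it back), using the `scaled_convert` helper defined above. The kernel name and launch setup are invented for illustration, and it assumes `ENABLE_FP8` is defined and the attention dtype headers are on the include path; it is not part of this commit.

// Illustrative only; not part of this commit.
__global__ void fp8_scale_roundtrip(const float* in, float* out, float scale, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
  // FP8_data = Quantize(HP / scale)
  uint8_t q = vllm::fp8::scaled_convert<uint8_t, float,
                                        vllm::Fp8KVCacheDataType::kFp8E4M3>(in[i], scale);
  // HP is recovered (approximately) as Dequantize(FP8) * scale
  out[i] = vllm::fp8::scaled_convert<float, uint8_t,
                                     vllm::Fp8KVCacheDataType::kFp8E4M3>(q, scale);
}
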
gptq_marlin/marlin.cuh ADDED
@@ -0,0 +1,87 @@
1
+ #pragma once
2
+
3
+ #include <torch/all.h>
4
+
5
+ #include <ATen/cuda/CUDAContext.h>
6
+ #include <c10/cuda/CUDAGuard.h>
7
+ #include <cuda.h>
8
+ #include <cuda_fp16.h>
9
+ #include <cuda_runtime.h>
10
+ #include <iostream>
11
+
12
+ namespace marlin {
13
+
14
+ // Marlin params
15
+
16
+ // 8 warps are a good choice since every SM has 4 schedulers and having more
17
+ // than 1 warp per schedule allows some more latency hiding. At the same time,
18
+ // we want relatively few warps to have many registers per warp and small tiles.
19
+ static constexpr int default_threads = 256;
20
+
21
+ static constexpr int pipe_stages =
22
+ 4; // 4 pipeline stages fit into shared memory
23
+
24
+ static constexpr int min_thread_n = 64;
25
+ static constexpr int min_thread_k = 64;
26
+
27
+ static constexpr int tile_size = 16;
28
+ static constexpr int max_par = 16;
29
+
30
+ // Repack params
31
+ static constexpr int repack_stages = 8;
32
+
33
+ static constexpr int repack_threads = 256;
34
+
35
+ static constexpr int tile_k_size = tile_size;
36
+ static constexpr int tile_n_size = tile_k_size * 4;
37
+
38
+ // Helpers
39
+ template <typename T, int n>
40
+ struct Vec {
41
+ T elems[n];
42
+ __device__ T& operator[](int i) { return elems[i]; }
43
+ };
44
+
45
+ using I4 = Vec<int, 4>;
46
+
47
+ constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
48
+
49
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
50
+ // No support for async
51
+ #else
52
+
53
+ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
54
+ bool pred = true) {
55
+ const int BYTES = 16;
56
+ uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
57
+ asm volatile(
58
+ "{\n"
59
+ " .reg .pred p;\n"
60
+ " setp.ne.b32 p, %0, 0;\n"
61
+ " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
62
+ "}\n" ::"r"((int)pred),
63
+ "r"(smem), "l"(glob_ptr), "n"(BYTES));
64
+ }
65
+
66
+ __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
67
+ const int BYTES = 16;
68
+ uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
69
+ asm volatile(
70
+ "{\n"
71
+ " cp.async.cg.shared.global [%0], [%1], %2;\n"
72
+ "}\n" ::"r"(smem),
73
+ "l"(glob_ptr), "n"(BYTES));
74
+ }
75
+
76
+ __device__ inline void cp_async_fence() {
77
+ asm volatile("cp.async.commit_group;\n" ::);
78
+ }
79
+
80
+ template <int n>
81
+ __device__ inline void cp_async_wait() {
82
+ asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
83
+ }
84
+
85
+ #endif
86
+
87
+ } // namespace marlin
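
The cp.async helpers above are meant to be used in an issue / commit / wait pattern, which is how the Marlin main loop drives its fetch_to_shared / wait_for_stage pipeline. A minimal standalone sketch (kernel and buffer names are invented; requires sm_80+, as the #if above indicates):

// Illustrative only; not part of this commit.
__global__ void stage_tile_demo(const int4* gmem) {
  __shared__ int4 smem[marlin::default_threads];
  marlin::cp_async4(&smem[threadIdx.x], &gmem[threadIdx.x]);  // issue 16-byte async copy
  marlin::cp_async_fence();                                   // commit the copy group
  marlin::cp_async_wait<0>();                                 // wait for all committed groups
  __syncthreads();
  int4 v = smem[threadIdx.x];                                 // safe to read now
  (void)v;
}
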
gptq_marlin/marlin_dtypes.cuh ADDED
@@ -0,0 +1,79 @@
1
+
2
+ #ifndef _data_types_cuh
3
+ #define _data_types_cuh
4
+ #include "marlin.cuh"
5
+ #include <cuda_fp16.h>
6
+ #include <cuda_bf16.h>
7
+
8
+ namespace marlin {
9
+
10
+ template <typename scalar_t>
11
+ class ScalarType {};
12
+
13
+ template <>
14
+ class ScalarType<half> {
15
+ public:
16
+ using scalar_t = half;
17
+ using scalar_t2 = half2;
18
+
19
+ // Matrix fragments for tensor core instructions; their precise layout is
20
+ // documented here:
21
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
22
+ using FragA = Vec<half2, 4>;
23
+ using FragB = Vec<half2, 2>;
24
+ using FragC = Vec<float, 4>;
25
+ using FragS = Vec<half2, 1>;
26
+ using FragZP = Vec<half2, 4>;
27
+
28
+ static __device__ float inline num2float(const half x) {
29
+ return __half2float(x);
30
+ }
31
+
32
+ static __device__ half2 inline num2num2(const half x) {
33
+ return __half2half2(x);
34
+ }
35
+
36
+ static __device__ half2 inline nums2num2(const half x1, const half x2) {
37
+ return __halves2half2(x1, x2);
38
+ }
39
+
40
+ static __host__ __device__ half inline float2num(const float x) {
41
+ return __float2half(x);
42
+ }
43
+ };
44
+
45
+ template <>
46
+ class ScalarType<nv_bfloat16> {
47
+ public:
48
+ using scalar_t = nv_bfloat16;
49
+ using scalar_t2 = nv_bfloat162;
50
+
51
+ using FragA = Vec<nv_bfloat162, 4>;
52
+ using FragB = Vec<nv_bfloat162, 2>;
53
+ using FragC = Vec<float, 4>;
54
+ using FragS = Vec<nv_bfloat162, 1>;
55
+ using FragZP = Vec<nv_bfloat162, 4>;
56
+
57
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
58
+ static __device__ float inline num2float(const nv_bfloat16 x) {
59
+ return __bfloat162float(x);
60
+ }
61
+
62
+ static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
63
+ return __bfloat162bfloat162(x);
64
+ }
65
+
66
+ static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1,
67
+ const nv_bfloat16 x2) {
68
+ return __halves2bfloat162(x1, x2);
69
+ }
70
+
71
+ static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
72
+ return __float2bfloat16(x);
73
+ }
74
+ #endif
75
+ };
76
+
77
+ } // namespace marlin
78
+
79
+ #endif
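
The ScalarType traits above are what let the Marlin kernels compile once for half and once for nv_bfloat16. A hedged sketch of the pattern, in the spirit of the scale_float calls in fp8_marlin.cu (the helper name here is invented; the bf16 path needs sm_80+ per the #if above):

// Illustrative only; not part of this commit.
template <typename scalar_t>
__device__ inline void apply_scale2(float* c,
                                    typename marlin::ScalarType<scalar_t>::scalar_t2 s) {
  using Dtype = marlin::ScalarType<scalar_t>;
  // s packs two scales (half2 or nv_bfloat162); apply them to two fp32 partials.
  c[0] *= Dtype::num2float(s.x);
  c[1] *= Dtype::num2float(s.y);
}
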