Add `scaled_(int|fp8)_quant` and `fp8_marlin_gemm`
- build.toml +31 -1
- compressed_tensors/int8_quant_kernels.cu +286 -0
- {ext-torch → core}/registration.h +0 -0
- dispatch_utils.h +35 -0
- ext-torch/__init__.py +107 -1
- ext-torch/torch_binding.cpp +41 -2
- ext-torch/torch_binding.h +43 -13
- fp8/amd/hip_float8.h +137 -0
- fp8/amd/hip_float8_impl.h +316 -0
- fp8/amd/quant_utils.cuh +577 -0
- fp8/common.cu +149 -0
- fp8/common.cuh +172 -0
- fp8/fp8_marlin.cu +1306 -0
- fp8/nvidia/quant_utils.cuh +573 -0
- gptq_marlin/marlin.cuh +87 -0
- gptq_marlin/marlin_dtypes.cuh +79 -0
build.toml
CHANGED
```diff
@@ -4,10 +4,11 @@ version = "0.0.1"
 [torch]
 name = "quantization"
 src = [
-  "ext-torch/registration.h",
+  "core/registration.h",
   "ext-torch/torch_binding.cpp",
   "ext-torch/torch_binding.h"
 ]
+include = [ "." ]
 pysrc = [
   "ext-torch/__init__.py"
 ]
@@ -39,3 +40,32 @@ src = [
 ]
 include = [ "." ]
 depends = [ "cutlass", "torch" ]
+
+[kernel.fp8_common]
+capabilities = [ "7.5", "8.0", "8.6", "8.7", "8.9", "9.0", "9.0a" ]
+src = [
+  "fp8/common.cu",
+  "fp8/common.cuh",
+  "dispatch_utils.h"
+]
+include = [ "." ]
+depends = [ "torch" ]
+
+[kernel.fp8_marlin]
+capabilities = [ "8.0", "8.6", "8.7", "8.9", "9.0", "9.0a" ]
+src = [
+  "fp8/fp8_marlin.cu",
+  "gptq_marlin/marlin.cuh",
+  "gptq_marlin/marlin_dtypes.cuh",
+]
+#include = [ "." ]
+depends = [ "torch" ]
+
+[kernel.int8_common]
+capabilities = [ "7.5", "8.0", "8.6", "8.7", "8.9", "9.0", "9.0a" ]
+src = [
+  "compressed_tensors/int8_quant_kernels.cu",
+  "dispatch_utils.h"
+]
+include = [ "." ]
+depends = [ "torch" ]
```
compressed_tensors/int8_quant_kernels.cu
ADDED
```cuda
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <cmath>

#include "dispatch_utils.h"

#ifndef USE_ROCM
  #include <cub/util_type.cuh>
  #include <cub/cub.cuh>
#else
  #include <hipcub/util_type.hpp>
  #include <hipcub/hipcub.hpp>
#endif

static inline __device__ int8_t float_to_int8_rn(float x) {
#ifdef USE_ROCM
  static constexpr auto i8_min =
      static_cast<float>(std::numeric_limits<int8_t>::min());
  static constexpr auto i8_max =
      static_cast<float>(std::numeric_limits<int8_t>::max());

  // To match the rounding mode of CUDA, we use nearbyint.
  // It uses the current rounding mode, which is always FE_TONEAREST on HIP.
  // If that changes in the future, we may need to set the rounding mode
  // explicitly, either at runtime or compile time.
  float dst = std::nearbyint(x);

  // saturate
  dst = std::clamp(dst, i8_min, i8_max);
  return static_cast<int8_t>(dst);
#else
  // CUDA path
  uint32_t dst;
  asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
  return reinterpret_cast<const int8_t&>(dst);
#endif
}

static inline __device__ int32_t float_to_int32_rn(float x) {
#ifdef USE_ROCM
  // int32_max is not exactly representable as float.
  // Therefore, we need to be careful and manually return int32_max on overflow.
  // For symmetry, we also do the same for int32_min, even though it is exactly
  // representable as float and the conversion should be exact.
  static constexpr auto i32_min = std::numeric_limits<int32_t>::min();
  static constexpr auto i32_min_f = static_cast<float>(i32_min);
  static constexpr auto i32_max = std::numeric_limits<int32_t>::max();
  static constexpr auto i32_max_f = static_cast<float>(i32_max);

  // To match the rounding mode of CUDA, we use nearbyint.
  // It uses the current rounding mode, which is always FE_TONEAREST on HIP.
  // If that changes in the future, we may need to set the rounding mode
  // explicitly, either at runtime or compile time.
  float dst = std::nearbyint(x);

  // saturate on the higher end.
  if (dst >= i32_max_f) {
    return i32_max;
  }
  // saturate on the lower end.
  if (dst <= i32_min_f) {
    return i32_min;
  }

  return static_cast<int32_t>(dst);
#else
  // CUDA path
  uint32_t dst;
  asm volatile("cvt.rni.sat.s32.f32 %0, %1;" : "=r"(dst) : "f"(x));
  return reinterpret_cast<const int32_t&>(dst);
#endif
}

static inline __device__ int8_t int32_to_int8(int32_t x) {
#ifdef USE_ROCM
  static constexpr auto i8_min =
      static_cast<int32_t>(std::numeric_limits<int8_t>::min());
  static constexpr auto i8_max =
      static_cast<int32_t>(std::numeric_limits<int8_t>::max());

  // saturate
  int32_t dst = std::clamp(x, i8_min, i8_max);
  return static_cast<int8_t>(dst);
#else
  // CUDA path
  uint32_t dst;
  asm volatile("cvt.sat.s8.s32 %0, %1;" : "=r"(dst) : "r"(x));
  return reinterpret_cast<const int8_t&>(dst);
#endif
}

namespace vllm {

template <typename scalar_t, typename scale_type>
__global__ void static_scaled_int8_quant_kernel(
    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
    scale_type const* scale_ptr, const int hidden_size) {
  int const tid = threadIdx.x;
  int64_t const token_idx = blockIdx.x;
  scale_type const scale = *scale_ptr;

  // Must be performed using 64-bit math to avoid integer overflow.
  out += token_idx * hidden_size;
  input += token_idx * hidden_size;

  for (int i = tid; i < hidden_size; i += blockDim.x) {
    out[i] = float_to_int8_rn(static_cast<float>(input[i]) / scale);
  }
}

template <typename scalar_t, typename scale_type, typename azp_type>
__global__ void static_scaled_int8_azp_quant_kernel(
    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
    scale_type const* scale_ptr, azp_type const* azp_ptr,
    const int hidden_size) {
  int const tid = threadIdx.x;
  int64_t const token_idx = blockIdx.x;
  scale_type const scale = *scale_ptr;
  azp_type const azp = *azp_ptr;

  // Must be performed using 64-bit math to avoid integer overflow.
  out += token_idx * hidden_size;
  input += token_idx * hidden_size;

  for (int i = tid; i < hidden_size; i += blockDim.x) {
    auto const val = static_cast<float>(input[i]);
    auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp);
    out[i] = quant_val;
  }
}

template <typename scalar_t, typename scale_type>
__global__ void dynamic_scaled_int8_quant_kernel(
    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
    scale_type* scale, const int hidden_size) {
  int const tid = threadIdx.x;
  int64_t const token_idx = blockIdx.x;
  float absmax_val = 0.0f;
  float const zero = 0.0f;

  // Must be performed using 64-bit math to avoid integer overflow.
  out += token_idx * hidden_size;
  input += token_idx * hidden_size;

  for (int i = tid; i < hidden_size; i += blockDim.x) {
    float val = static_cast<float>(input[i]);
    val = val > zero ? val : -val;
    absmax_val = val > absmax_val ? val : absmax_val;
  }

  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStorage;
  float const block_absmax_val_maybe =
      BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
  __shared__ float block_absmax_val;
  if (tid == 0) {
    block_absmax_val = block_absmax_val_maybe;
    scale[token_idx] = block_absmax_val / 127.0f;
  }
  __syncthreads();

  float const tmp_scale = 127.0f / block_absmax_val;
  for (int i = tid; i < hidden_size; i += blockDim.x) {
    out[i] = float_to_int8_rn(static_cast<float>(input[i]) * tmp_scale);
  }
}

template <typename scalar_t, typename scale_type, typename azp_type>
__global__ void dynamic_scaled_int8_azp_quant_kernel(
    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
    scale_type* scale, azp_type* azp, const int hidden_size) {
  int64_t const token_idx = blockIdx.x;

  // Must be performed using 64-bit math to avoid integer overflow.
  out += token_idx * hidden_size;
  input += token_idx * hidden_size;

  // Scan for the min and max value for this token
  float max_val = std::numeric_limits<float>::min();
  float min_val = std::numeric_limits<float>::max();
  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    auto val = static_cast<float>(input[i]);
    max_val = std::max(max_val, val);
    min_val = std::min(min_val, val);
  }

  // Reduce the max and min values across the block
  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStorage;
  max_val = BlockReduce(reduceStorage).Reduce(max_val, cub::Max{}, blockDim.x);
  __syncthreads();  // Make sure min doesn't mess with max shared memory
  min_val = BlockReduce(reduceStorage).Reduce(min_val, cub::Min{}, blockDim.x);

  __shared__ scale_type scale_sh;
  __shared__ azp_type azp_sh;

  // Compute the scale and zero point and store them, only on the first thread
  if (threadIdx.x == 0) {
    float const scale_val = (max_val - min_val) / 255.0f;
    // Use rounding to even (same as torch.round)
    auto const azp_float = std::nearbyint(-128.0f - min_val / scale_val);
    auto const azp_val = static_cast<azp_type>(azp_float);

    // Store the scale and azp into shared and global
    scale[token_idx] = scale_sh = scale_val;
    azp[token_idx] = azp_sh = azp_val;
  }

  // Wait for the scale and azp to be computed
  __syncthreads();

  float const scale_val = scale_sh;
  azp_type const azp_val = azp_sh;

  // Quantize the values
  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    auto const val = static_cast<float>(input[i]);
    auto const quant_val =
        int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val);
    out[i] = quant_val;
  }
}

}  // namespace vllm

void static_scaled_int8_quant(torch::Tensor& out,          // [..., hidden_size]
                              torch::Tensor const& input,  // [..., hidden_size]
                              torch::Tensor const& scale,
                              c10::optional<torch::Tensor> const& azp) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(scale.numel() == 1);
  TORCH_CHECK(!azp || azp->numel() == 1);

  int const hidden_size = input.size(-1);
  int const num_tokens = input.numel() / hidden_size;
  dim3 const grid(num_tokens);
  dim3 const block(std::min(hidden_size, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
        if (!azp) {
          vllm::static_scaled_int8_quant_kernel<scalar_t, float>
              <<<grid, block, 0, stream>>>(
                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
                  scale.data_ptr<float>(), hidden_size);
        } else {
          vllm::static_scaled_int8_azp_quant_kernel<scalar_t, float, int32_t>
              <<<grid, block, 0, stream>>>(
                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
                  scale.data_ptr<float>(), azp->data_ptr<int32_t>(),
                  hidden_size);
        }
      });
}

void dynamic_scaled_int8_quant(
    torch::Tensor& out,          // [..., hidden_size]
    torch::Tensor const& input,  // [..., hidden_size]
    torch::Tensor& scales, c10::optional<torch::Tensor> const& azp) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(scales.is_contiguous());
  TORCH_CHECK(!azp || azp->is_contiguous());

  int const hidden_size = input.size(-1);
  int const num_tokens = input.numel() / hidden_size;
  dim3 const grid(num_tokens);
  dim3 const block(std::min(hidden_size, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
        if (!azp) {
          vllm::dynamic_scaled_int8_quant_kernel<scalar_t, float>
              <<<grid, block, 0, stream>>>(
                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
                  scales.data_ptr<float>(), hidden_size);
        } else {
          vllm::dynamic_scaled_int8_azp_quant_kernel<scalar_t, float, int32_t>
              <<<grid, block, 0, stream>>>(
                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
                  scales.data_ptr<float>(), azp->data_ptr<int32_t>(),
                  hidden_size);
        }
      });
}
```
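As a rough aside (not part of the diff), the dynamic symmetric path above amounts to a per-token absmax scale followed by round-to-nearest-even and saturation. The sketch below is a hedged PyTorch emulation of that math, useful as a reference when checking the kernel's output; `emulate_dynamic_scaled_int8_quant` is a made-up helper name.

```python
# Hedged reference emulation of the symmetric dynamic_scaled_int8_quant path.
# torch.round uses round-half-to-even, which matches the cvt.rni / nearbyint
# rounding used by float_to_int8_rn in the kernel above.
import torch

def emulate_dynamic_scaled_int8_quant(x: torch.Tensor):
    # x: [num_tokens, hidden_size]
    xf = x.float()
    absmax = xf.abs().amax(dim=-1, keepdim=True)
    scales = absmax / 127.0                       # what the kernel stores per token
    q = torch.clamp(torch.round(xf * (127.0 / absmax)), -128, 127).to(torch.int8)
    return q, scales

x = torch.randn(4, 16, dtype=torch.float16)
q, scales = emulate_dynamic_scaled_int8_quant(x)
print((q.float() * scales - x.float()).abs().max())  # bounded by ~scale/2 per token
```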
{ext-torch → core}/registration.h
RENAMED
File without changes
dispatch_utils.h
ADDED
```cpp
/*
 * Adapted from
 * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
 */
#pragma once

#include <torch/all.h>

#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)         \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...)   \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)    \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)     \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME,                               \
                     VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...)         \
  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)   \
  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)

#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
```
ext-torch/__init__.py
CHANGED
```diff
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
 
@@ -42,3 +42,109 @@ def cutlass_scaled_mm(a: torch.Tensor,
 
     return out
 
+# fp8
+def scaled_fp8_quant(
+    input: torch.Tensor,
+    scale: Optional[torch.Tensor] = None,
+    num_token_padding: Optional[int] = None,
+    scale_ub: Optional[torch.Tensor] = None,
+    use_per_token_if_dynamic: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to FP8 and return quantized tensor and scale.
+
+    This function supports both static and dynamic quantization: If you
+    provide the scale, it will use static scaling and if you omit it,
+    the scale will be determined dynamically. The function also allows
+    optional padding of the output tensors for downstream kernels that
+    will benefit from padding.
+
+    Args:
+        input: The input tensor to be quantized to FP8
+        scale: Optional scaling factor for the FP8 quantization
+        scale_ub: Optional upper bound for scaling factor in dynamic
+            per token case
+        num_token_padding: If specified, pad the first dimension
+            of the output to at least this value.
+        use_per_token_if_dynamic: Whether to do per_tensor or per_token
+            in the dynamic quantization case.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
+            scaling factor.
+    """
+    # This code assumes batch_dim and num_tokens are flattened
+    assert (input.ndim == 2)
+    shape: Union[Tuple[int, int], torch.Size] = input.shape
+    # For rocm, the output fp8 dtype is torch.float_e3m3fnuz
+    #out_dtype: torch.dtype = torch.float8_e4m3fnuz \
+    #    if current_platform.is_rocm() else torch.float8_e4m3fn
+    out_dtype = torch.float8_e4m3fn
+    if num_token_padding:
+        shape = (max(num_token_padding, input.shape[0]), shape[1])
+    output = torch.empty(shape, device=input.device, dtype=out_dtype)
+
+    if scale is None:
+        if use_per_token_if_dynamic:
+            scale = torch.empty((shape[0], 1),
+                                device=input.device,
+                                dtype=torch.float32)
+            ops.dynamic_per_token_scaled_fp8_quant(
+                output, input, scale, scale_ub)
+        else:
+            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+            ops.dynamic_scaled_fp8_quant(output, input, scale)
+    else:
+        # num_token_padding not implemented for this case
+        assert (scale.numel() == 1 or num_token_padding is None)
+        ops.static_scaled_fp8_quant(output, input, scale)
+
+    return output, scale
+
+# int8
+def scaled_int8_quant(
+    input: torch.Tensor,
+    scale: Optional[torch.Tensor] = None,
+    azp: Optional[torch.Tensor] = None,
+    symmetric: bool = True
+) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+    """
+    Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.
+
+    Args:
+        input: The input tensor to be quantized to int8.
+        scale: Optional scaling factor for the int8 quantization.
+            When not provided, we invoke dynamic-per-token quantization.
+        azp: Optional zero-point for the int8 quantization.
+            Must be provided for asymmetric quantization if `scale` is provided.
+        symmetric: Whether to use symmetric quantization (scale only, azp ignored).
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
+    """
+    output = torch.empty_like(input, dtype=torch.int8)
+    if scale is not None:
+        # static-per-tensor quantization.
+        assert symmetric == (
+            azp is
+            None), "azp must only be provided for asymmetric quantization."
+        ops.static_scaled_int8_quant(output, input, scale, azp)
+        return output, scale, azp
+
+    # dynamic-per-token quantization.
+    input_scales = torch.empty((input.numel() // input.shape[-1], 1),
+                               device=input.device,
+                               dtype=torch.float32)
+    input_azp = None if symmetric else torch.empty_like(input_scales,
+                                                        dtype=torch.int32)
+    ops.dynamic_scaled_int8_quant(output, input, input_scales,
+                                  input_azp)
+    return output, input_scales, input_azp
+
+# fp8 marlin
+def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
+                    b_scales: torch.Tensor, workspace: torch.Tensor,
+                    num_bits: int, size_m: int, size_n: int,
+                    size_k: int) -> torch.Tensor:
+    return ops.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace,
+                               num_bits, size_m, size_n, size_k)
```
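For orientation, a hedged usage sketch of the new Python surface (not part of the diff). The import name `quantization` is an assumption taken from the `[torch]` `name` field in build.toml; substitute however the built extension is exposed in your environment.

```python
# Hedged usage sketch of the functions added to ext-torch/__init__.py.
import torch
import quantization as q  # assumed import name for the built extension

x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")

# Dynamic per-tensor FP8 quantization: the kernel computes the scale.
x_fp8, fp8_scale = q.scaled_fp8_quant(x)

# Dynamic per-token symmetric INT8 quantization: one scale per row, no zero point.
x_i8, i8_scales, azp = q.scaled_int8_quant(x)
assert azp is None  # symmetric=True (the default) returns no zero point
```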
ext-torch/torch_binding.cpp
CHANGED
```diff
@@ -1,10 +1,9 @@
 #include <torch/library.h>
 
-#include "registration.h"
+#include "core/registration.h"
 #include "torch_binding.h"
 
 TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
-
   // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
   // quantization, as well as bias
   ops.def(
@@ -27,6 +26,46 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
   ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
 
+  // Compute FP8 quantized tensor for given scaling factor.
+  ops.def(
+      "static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "
+      "()");
+  ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
+
+  // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
+  ops.def(
+      "dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
+      "-> "
+      "()");
+  ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
+
+  // Compute dynamic-per-token FP8 quantized tensor and scaling factor.
+  ops.def(
+      "dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
+      "Tensor! scale, Tensor? scale_ub) -> "
+      "()");
+  ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
+           &dynamic_per_token_scaled_fp8_quant);
+
+  // Compute int8 quantized tensor for given scaling factor.
+  ops.def(
+      "static_scaled_int8_quant(Tensor! result, Tensor input, Tensor scale,"
+      "Tensor? azp) -> ()");
+  ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant);
+
+  // Compute int8 quantized tensor and scaling factor
+  ops.def(
+      "dynamic_scaled_int8_quant(Tensor! result, Tensor input, Tensor! scale, "
+      "Tensor!? azp) -> ()");
+  ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
+           &dynamic_scaled_int8_quant);
+
+  // fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
+  ops.def(
+      "fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
+      "Tensor! workspace, int num_bits, SymInt size_m, SymInt size_n, "
+      "SymInt size_k) -> Tensor");
+  ops.impl("fp8_marlin_gemm", &fp8_marlin_gemm);
 }
 
 REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
```
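Once these schemas are registered, the ops are reachable through the `torch.ops` namespace whose name equals `TORCH_EXTENSION_NAME` for this build. The sketch below is hedged: `_quantization` is only a placeholder namespace, and the tensors are sized to match the in-place `Tensor!` arguments in the schema.

```python
# Hedged sketch of calling a registered op directly via torch.ops.
import torch

ns = getattr(torch.ops, "_quantization")  # placeholder; actual name is build-dependent

x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
out = torch.empty_like(x, dtype=torch.int8)
scales = torch.empty(8, 1, dtype=torch.float32, device="cuda")

# Matches "dynamic_scaled_int8_quant(Tensor! result, Tensor input, Tensor! scale, Tensor!? azp) -> ()":
# results are written in place into `out` and `scales`; azp=None selects the symmetric kernel.
ns.dynamic_scaled_int8_quant(out, x, scales, None)
```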
ext-torch/torch_binding.h
CHANGED
```diff
@@ -2,17 +2,47 @@
 
 #include <torch/torch.h>
 
-bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
-
-void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
-                       torch::Tensor const& b, torch::Tensor const& a_scales,
-                       torch::Tensor const& b_scales,
-                       c10::optional<torch::Tensor> const& bias);
-
-void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
-                           torch::Tensor const& b,
-                           torch::Tensor const& a_scales,
-                           torch::Tensor const& b_scales,
-                           torch::Tensor const& azp_adj,
-                           c10::optional<torch::Tensor> const& azp,
+bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
+
+void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
+                       torch::Tensor const& b, torch::Tensor const& a_scales,
+                       torch::Tensor const& b_scales,
+                       c10::optional<torch::Tensor> const& bias);
+
+void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
+                           torch::Tensor const& b,
+                           torch::Tensor const& a_scales,
+                           torch::Tensor const& b_scales,
+                           torch::Tensor const& azp_adj,
+                           c10::optional<torch::Tensor> const& azp,
                            c10::optional<torch::Tensor> const& bias);
+
+void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
+                              torch::Tensor const& scale,
+                              c10::optional<torch::Tensor> const& azp);
+
+void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
+                               torch::Tensor& scales,
+                               c10::optional<torch::Tensor> const& azp);
+
+torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
+                        torch::Tensor b_gptq_qzeros,
+                        torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
+                        bool use_exllama, int64_t bit);
+
+void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
+
+void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
+                             torch::Tensor const& scale);
+
+void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
+                              torch::Tensor& scale);
+
+void dynamic_per_token_scaled_fp8_quant(
+    torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
+    c10::optional<torch::Tensor> const& scale_ub);
+
+torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
+                              torch::Tensor& b_scales, torch::Tensor& workspace,
+                              int64_t num_bits, int64_t size_m, int64_t size_n,
+                              int64_t size_k);
```
fp8/amd/hip_float8.h
ADDED
```cpp
#pragma once

#ifdef __HIPCC__
  #include <hip/hip_runtime.h>
#else
  #include <type_traits>
  #include <stdint.h>
  #include <math.h>
  #include <iostream>
#endif

#include "hip_float8_impl.h"

struct alignas(1) hip_fp8 {
  struct from_bits_t {};
  HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() {
    return from_bits_t();
  }
  uint8_t data;

  hip_fp8() = default;
  HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default;
  HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete;
  explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t)
      : data(v) {}

#ifdef __HIP__MI300__
  // NOTE: ON-DEVICE... always optimal bias
  explicit HIP_FP8_DEVICE hip_fp8(float v)
      : data(hip_fp8_impl::to_fp8_from_fp32(v)) {}

  explicit HIP_FP8_DEVICE hip_fp8(_Float16 v)
      : hip_fp8(static_cast<float>(v)) {}

  // Host only implementation using s/w simulation
  explicit HIP_FP8_HOST
#else   // __HIP__MI300__
  // both Host and DEVICE for non-MI300 using s/w simulation
  explicit HIP_FP8_HOST_DEVICE
#endif  // __HIP__MI300__
      hip_fp8(float v) {
    data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/,
                                   true /*clip*/>(v);
  }

  explicit HIP_FP8_HOST_DEVICE hip_fp8(double v)
      : hip_fp8(static_cast<float>(v)) {}

#ifdef __HIP__MI300__
  // upcast using device specific intrinsic
  explicit inline HIP_FP8_DEVICE operator float() const {
    float fval;
    uint32_t i32val = static_cast<uint32_t>(data);

    // upcast
    asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0"
                 : "=v"(fval)
                 : "v"(i32val));

    return fval;
  }

  explicit inline HIP_FP8_HOST operator float() const
#else   // __HIP__MI300__
  explicit inline HIP_FP8_HOST_DEVICE operator float() const
#endif  // __HIP__MI300__
  {
    return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(
        data);
  }
};

namespace std {
inline hip_fp8 sin(hip_fp8 a) { return hip_fp8(sinf(float(a))); }
inline hip_fp8 cos(hip_fp8 a) { return hip_fp8(cosf(float(a))); }
HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a) { return a; }
}  // namespace std

// Special operator overloading
inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) {
  return os << float(f8);
}

// all + operator overloading with mixed types
// mixed types, always converts to f32, does computation in f32, and returns
// float
inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) {
  return (fa + float(b));
}

inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) {
  return (float(a) + fb);
}

inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) {
  return hip_fp8(float(a) + float(b));
}

inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) {
  return a = hip_fp8(float(a) + float(b));
}

// overloading multiplication, always returns float,
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) {
  return float(a) * float(b);
}

inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) {
  return (a * float(b));
}

inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) {
  return (float(a) * b);
}

inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) {
  return ((float)a * float(b));
}

inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) {
  return ((float)a * float(b));
}

// overloading for compare
inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) {
  return (a.data == b.data);
}
inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) {
  return (a.data != b.data);
}

inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) {
  return static_cast<float>(a) >= static_cast<float>(b);
}
inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) {
  return static_cast<float>(a) > static_cast<float>(b);
}
```
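A small hedged aside (not part of the diff): the `<4, 3, ...>` template arguments above mean a 4-bit exponent and 3-bit mantissa, so every fp8 value snaps to a coarse grid. The sketch uses `torch.float8_e4m3fn`, the closely related OCP e4m3 dtype, purely to visualize that quantization; `hip_fp8` itself implements the "nanoo" (fnuz) flavor with `negative_zero_nan=true`.

```python
# Hedged illustration of how lossy an e4m3 round-trip is.
import torch

x = torch.tensor([0.1234, 1.0, 3.1416, 300.0])
roundtrip = x.to(torch.float8_e4m3fn).float()
print(roundtrip)  # each value snaps to the nearest representable fp8 number
```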
fp8/amd/hip_float8_impl.h
ADDED
```cpp
#pragma once

#if defined(__HIPCC__) && \
    (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
  #define __HIP__MI300__
#endif

#ifdef __HIPCC__
  #define HIP_FP8_HOST_DEVICE __host__ __device__
  #define HIP_FP8_HOST __host__
  #define HIP_FP8_DEVICE __device__
#else
  #define HIP_FP8_HOST_DEVICE
  #define HIP_FP8_HOST
  #define HIP_FP8_DEVICE
#endif

namespace hip_fp8_impl {

#ifdef __HIP__MI300__
HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v) {
  uint8_t i8data;
  union {
    float fval;
    uint32_t i32val;
    uint8_t i8val[4];  // NOTE: not endian independent
  } val;

  uint32_t ival = 0;
  val.fval = v;

  if ((val.i32val & 0x7F800000) !=
      0x7F800000) {  /// propagate NAN/INF, no clipping
    val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
  }

  ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival,
                                         false);  // false -> WORD0
  val.i32val = ival;
  i8data = val.i8val[0];

  return i8data;
}
#endif  // __HIP__MI300__

HIP_FP8_HOST inline int clz(uint32_t x) { return __builtin_clz(x); }
#if defined(__HIPCC__) || defined(__CUDA_ARCH__)
HIP_FP8_DEVICE inline int clz(uint32_t x) { return __clz(x); }
#endif

template <int we, int wm, typename T, bool negative_zero_nan, bool clip>
HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false,
                                      uint32_t rng = 0) {
#ifdef __HIPCC__
  constexpr bool is_half = std::is_same<T, _Float16>::value;
#else
  constexpr bool is_half = false;
#endif
  constexpr bool is_float = std::is_same<T, float>::value;
  static_assert(wm + we == 7, "wm+we==7");
  static_assert(is_half || is_float, "Only half and float can be cast to f8");

  const int mfmt = (sizeof(T) == 4) ? 23 : 10;
  uint32_t x;
  if (sizeof(T) == 4) {
    x = reinterpret_cast<uint32_t&>(_x);
  } else {
    x = reinterpret_cast<uint16_t&>(_x);
  }

  uint32_t head, mantissa;
  int exponent, bias;
  uint32_t sign;

  if (sizeof(T) == 4) {
    head = x & 0xFF800000;
    mantissa = x & 0x7FFFFF;
    exponent = (head >> 23) & 0xFF;
    sign = head >> 31;
    bias = 127;
  } else {
    head = x & 0xFC00;
    mantissa = x & 0x3FF;
    exponent = (head >> 10) & 0x1F;
    sign = head >> 15;
    bias = 15;
  }

  uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm);

  // Deal with inf and NaNs
  if (negative_zero_nan) {
    if (sizeof(T) == 4) {
      if ((x & 0x7F800000) == 0x7F800000) {
        return 0x80;
      }
    } else {
      // if(__hisinf(x) || __hisnan(x))
      if ((x & 0x7C00) == 0x7C00) {
        return 0x80;
      }
    }
  } else {
    if (sizeof(T) == 4) {
      if ((x & 0x7F800000) == 0x7F800000) {
        return signed_inf + (mantissa != 0 ? 1 : 0);
      }
    } else {
      if ((x & 0x7C00) == 0x7C00) {
        return signed_inf + (mantissa != 0 ? 1 : 0);
      }
    }
  }
  if (x == 0) {
    return 0;
  }

  // First need to check if it is normal or denorm as there is a difference of
  // implicit 1 Then need to adjust the exponent to align with the F8 exponent,
  // in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
  // to mantissa and truncate. And for RNE, no need to add rng. Then probably
  // need to check whether there is carry and adjust exponent and mantissa again

  // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
  // bits
  const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0);
  const int f8_denormal_act_exponent =
      1 - f8_bias;  // actual exponent of f8 denormal
  // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
  // f8_exponent is the converted f8 exponent with bias encoding
  // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
  // the difference needs to be adjusted and mantissa shifted
  int act_exponent, f8_exponent, exponent_diff;

  if (exponent == 0) {  // fp32/fp16 is in denormal.
    /* fp32 denormal is below 2^-127 so it is usually not a concern here, we
    mostly concern fp16 here. In this case, f8 is usually in denormal. But there
    could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has
    exponent bias 16. It means that there are some numbers in fp16 denormal but they
    are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
    where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
    (NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
    act_exponent = exponent - bias + 1;
    exponent_diff =
        f8_denormal_act_exponent -
        act_exponent;  // actual exponent is exponent-bias+1 as it is denormal
  } else {  // fp32/fp16 is normal with implicit 1
    act_exponent = exponent - bias;
    if (act_exponent <= f8_denormal_act_exponent) {
      /* This is the case where fp32/fp16 is normal but it is in f8 denormal
      range. For example fp8 nanoo mode, denormal exponent is -7, but if the
      fp32/fp16 actual exponent is -7, it is actually larger due to the implicit 1,
      Therefore it needs to be adjust to -6 and mantissa shift right by 1.
      So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
      exponent_diff = f8_denormal_act_exponent - act_exponent;
    } else {  // both fp32/fp16 and f8 are in normal range
      exponent_diff = 0;  // exponent_diff=0 does not mean there is no
                          // difference for this case, act_exponent could be
                          // larger. Just that it does not need shift mantissa
    }
    mantissa += (1 << mfmt);  // Add the implicit 1 into mantissa
  }

  bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) ==
                  static_cast<uint32_t>(1 << (mfmt - wm + exponent_diff - 1));
  /* This part is a bit tricky. The judgment of whether it is a tie needs to be
  done before we shift right as shift right could rip off some residual part
  and make something not midpoint look like midpoint. For example, the fp16
  number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after
  shift right by 4 bits, it would look like midpoint.
  */

  if (exponent_diff > 0) {
    mantissa >>= exponent_diff;
  } else if (exponent_diff == -1) {
    mantissa <<= -exponent_diff;
  }
  bool implicit_one = mantissa & (1 << mfmt);
  // if there is no implicit 1, it means the f8 is denormal and need to adjust
  // to denorm exponent
  f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ +
                f8_bias - (implicit_one ? 0 : 1);

  // Now we have the exponent and mantissa adjusted
  uint32_t drop_mask = (1 << (mfmt - wm)) - 1;
  bool odd = mantissa & (1 << (mfmt - wm));  // if the least significant bit
                                             // that is not truncated is 1
  mantissa +=
      (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) &
      drop_mask;

  // Now we deal with overflow
  if (f8_exponent == 0) {
    if ((1 << mfmt) & mantissa) {
      f8_exponent = 1;  // denormal overflow to become normal, promote exponent
    }
  } else {
    if ((1 << (mfmt + 1)) & mantissa) {
      mantissa >>= 1;
      f8_exponent++;
    }
  }

  mantissa >>= (mfmt - wm);

  // above range: quantize to maximum possible float of the same sign
  const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2);
  if (f8_exponent > max_exp) {
    if (clip) {
      mantissa = (1 << wm) - 1;
      f8_exponent = max_exp;
    } else {
      return signed_inf;
    }
  }

  if (f8_exponent == 0 && mantissa == 0) {
    return negative_zero_nan ? 0 : (sign << 7);
  }
  mantissa &= (1 << wm) - 1;
  return (sign << 7) | (f8_exponent << wm) | mantissa;
}

template <int we, int wm, typename T = float, bool negative_zero_nan = true>
inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x) {
#ifdef __HIPCC__
  constexpr bool is_half = std::is_same<T, _Float16>::value;
#else
  constexpr bool is_half = false;
#endif
  constexpr bool is_float = std::is_same<T, float>::value;
  static_assert(is_half || is_float, "only half and float are supported");

  constexpr int weo = is_half ? 5 : 8;
  constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7);

  T fInf, fNegInf, fNaN, fNeg0;

#ifdef __HIPCC__
  if (is_half) {
    const uint16_t ihInf = 0x7C00;
    const uint16_t ihNegInf = 0xFC00;
    const uint16_t ihNaN = 0x7C01;
    const uint16_t ihNeg0 = 0x8000;
    fInf = reinterpret_cast<const _Float16&>(ihInf);
    fNegInf = reinterpret_cast<const _Float16&>(ihNegInf);
    fNaN = reinterpret_cast<const _Float16&>(ihNaN);
    fNeg0 = reinterpret_cast<const _Float16&>(ihNeg0);
  } else
#endif
      if (is_float) {
    const uint32_t ifInf = 0x7F800000;
    const uint32_t ifNegInf = 0xFF800000;
    const uint32_t ifNaN = 0x7F800001;
    const uint32_t ifNeg0 = 0x80000000;
    fInf = reinterpret_cast<const float&>(ifInf);
    fNegInf = reinterpret_cast<const float&>(ifNegInf);
    fNaN = reinterpret_cast<const float&>(ifNaN);
    fNeg0 = reinterpret_cast<const float&>(ifNeg0);
  }

  if (x == 0) {
    return 0;
  }

  uint32_t sign = x >> 7;
  uint32_t mantissa = x & ((1 << wm) - 1);
  int exponent = (x & 0x7F) >> wm;
  if (negative_zero_nan) {
    if (x == 0x80) {
      return fNaN;
    }
  } else {
    if (x == 0x80) {
      return fNeg0;
    }
    if (exponent == ((1 << we) - 1)) {
      return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
    }
  }
  typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
  if (we == 5 && is_half && !negative_zero_nan) {
    retval = x << 8;
    return reinterpret_cast<const T&>(retval);
  }

  const int exp_low_cutoff =
      (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);

  // subnormal input
  if (exponent == 0) {
    // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
    int sh = 1 + clz(mantissa) - (32 - wm);
    mantissa <<= sh;
    exponent += 1 - sh;
    mantissa &= ((1 << wm) - 1);
  }
  exponent += exp_low_cutoff - 1;
  mantissa <<= wmo - wm;

  // subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
  if (exponent <= 0) {
    mantissa |= 1 << wmo;
    mantissa >>= 1 - exponent;
    exponent = 0;
  }

  if (sizeof(T) == 2) {
    retval = (sign << 15) | (exponent << 10) | mantissa;
  } else {
    retval = (sign << 31) | (exponent << 23) | mantissa;
  }
  return reinterpret_cast<const T&>(retval);
}

}  // namespace hip_fp8_impl
```
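A hedged aside (not part of the diff): the `midpoint ? (odd ? mantissa : mantissa - 1) : mantissa` trick in `to_float8` above is ordinary round-to-nearest-even tie-breaking, the same rule Python's built-in `round()` applies to integer ties.

```python
# Ties are broken toward the even neighbor (round-half-to-even).
for v in (0.5, 1.5, 2.5, 3.5):
    print(v, "->", round(v))   # 0, 2, 2, 4
```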
fp8/amd/quant_utils.cuh
ADDED
@@ -0,0 +1,577 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#pragma once
|
2 |
+
#include "hip_float8.h"
|
3 |
+
|
4 |
+
#include <hip/hip_fp16.h>
|
5 |
+
#include <hip/hip_bf16.h>
|
6 |
+
#include <hip/hip_bfloat16.h>
|
7 |
+
|
8 |
+
#include "../../../attention/dtype_fp8.cuh"
|
9 |
+
#include "../../../attention/dtype_float32.cuh"
|
10 |
+
#include "../../../attention/dtype_bfloat16.cuh"
|
11 |
+
|
12 |
+
namespace vllm {
|
13 |
+
#ifdef USE_ROCM
|
14 |
+
|
15 |
+
namespace fp8 {
|
16 |
+
#ifdef ENABLE_FP8
|
17 |
+
|
18 |
+
template <typename Tout, typename Tin>
|
19 |
+
__inline__ __device__ Tout vec_conversion(const Tin& x) {
|
20 |
+
return x;
|
21 |
+
}
|
22 |
+
|
23 |
+
template <typename Tout, typename Tin>
|
24 |
+
__inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
|
25 |
+
const float scale) {
|
26 |
+
return x;
|
27 |
+
}
|
28 |
+
|
29 |
+
// fp8 -> half
|
30 |
+
template <>
|
31 |
+
__inline__ __device__ uint16_t
|
32 |
+
vec_conversion<uint16_t, uint8_t>(const uint8_t& a) {
|
33 |
+
hip_fp8 f8{a, hip_fp8::from_bits()};
|
34 |
+
__half_raw res;
|
35 |
+
res.data = static_cast<float>(f8);
|
36 |
+
return res.x;
|
37 |
+
}
|
38 |
+
|
39 |
+
// fp8x2 -> half2
|
40 |
+
template <>
|
41 |
+
__inline__ __device__ uint32_t
|
42 |
+
vec_conversion<uint32_t, uint16_t>(const uint16_t& a) {
|
43 |
+
#if defined(__HIP__MI300__) && \
|
44 |
+
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
45 |
+
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
46 |
+
union {
|
47 |
+
__half2_raw h2r;
|
48 |
+
uint32_t ui32;
|
49 |
+
} tmp;
|
50 |
+
tmp.h2r.x.data = f2[0];
|
51 |
+
tmp.h2r.y.data = f2[1];
|
52 |
+
return tmp.ui32;
|
53 |
+
#else
|
54 |
+
union {
|
55 |
+
uint16_t u16[2];
|
56 |
+
uint32_t u32;
|
57 |
+
} tmp;
|
58 |
+
|
59 |
+
tmp.u16[0] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a));
|
60 |
+
tmp.u16[1] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U));
|
61 |
+
return tmp.u32;
|
62 |
+
#endif
|
63 |
+
}
|
64 |
+
|
65 |
+
// fp8x4 -> half2x2
|
66 |
+
template <>
|
67 |
+
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a) {
|
68 |
+
union {
|
69 |
+
uint2 u32x2;
|
70 |
+
uint32_t u32[2];
|
71 |
+
} tmp;
|
72 |
+
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
|
73 |
+
tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
|
74 |
+
return tmp.u32x2;
|
75 |
+
}
|
76 |
+
|
77 |
+
// fp8x8 -> half2x4
|
78 |
+
template <>
|
79 |
+
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a) {
|
80 |
+
union {
|
81 |
+
uint4 u64x2;
|
82 |
+
uint2 u64[2];
|
83 |
+
} tmp;
|
84 |
+
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
|
85 |
+
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
|
86 |
+
return tmp.u64x2;
|
87 |
+
}
|
88 |
+
|
89 |
+
using __nv_bfloat16 = __hip_bfloat16;
|
90 |
+
|
91 |
+
// fp8 -> __nv_bfloat16
|
92 |
+
template <>
|
93 |
+
__inline__ __device__ __nv_bfloat16
|
94 |
+
vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) {
|
95 |
+
hip_fp8 f8{a, hip_fp8::from_bits()};
|
96 |
+
float f{f8};
|
97 |
+
return __float2bfloat16(f);
|
98 |
+
}
|
99 |
+
|
100 |
+
using __nv_bfloat162 = __hip_bfloat162;
|
101 |
+
|
102 |
+
// fp8x2 -> __nv_bfloat162
|
103 |
+
template <>
|
104 |
+
__inline__ __device__ __nv_bfloat162
|
105 |
+
vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) {
|
106 |
+
__nv_bfloat162 res;
|
107 |
+
res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
|
108 |
+
res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
|
109 |
+
return res;
|
110 |
+
}
|
111 |
+
|
112 |
+
// fp8x4 -> bf16_4_t
|
113 |
+
template <>
|
114 |
+
__inline__ __device__ bf16_4_t
|
115 |
+
vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a) {
|
116 |
+
bf16_4_t res;
|
117 |
+
res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
|
118 |
+
res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
|
119 |
+
return res;
|
120 |
+
}
|
121 |
+
|
122 |
+
// fp8x8 -> bf16_8_t
|
123 |
+
template <>
|
124 |
+
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
|
125 |
+
bf16_4_t tmp1, tmp2;
|
126 |
+
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
|
127 |
+
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
|
128 |
+
bf16_8_t res;
|
129 |
+
res.x = tmp1.x;
|
130 |
+
res.y = tmp1.y;
|
131 |
+
res.z = tmp2.x;
|
132 |
+
res.w = tmp2.y;
|
133 |
+
return res;
|
134 |
+
}
|
135 |
+
|
136 |
+
// fp8 -> float
|
137 |
+
template <>
|
138 |
+
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) {
|
139 |
+
hip_fp8 fp8{a, hip_fp8::from_bits()};
|
140 |
+
return static_cast<float>(fp8);
|
141 |
+
}
|
142 |
+
|
143 |
+
// fp8x2 -> float2
|
144 |
+
template <>
|
145 |
+
__inline__ __device__ float2
|
146 |
+
vec_conversion<float2, uint16_t>(const uint16_t& a) {
|
147 |
+
#if defined(__HIP__MI300__) && \
|
148 |
+
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
149 |
+
float2 res;
|
150 |
+
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
151 |
+
res.x = f2[0];
|
152 |
+
res.y = f2[1];
|
153 |
+
return res;
|
154 |
+
#else
|
155 |
+
float2 res;
|
156 |
+
res.x = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a));
|
157 |
+
res.y = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U));
|
158 |
+
return res;
|
159 |
+
#endif
|
160 |
+
}
|
161 |
+
|
162 |
+
// fp8x4 -> float4
|
163 |
+
template <>
|
164 |
+
__inline__ __device__ Float4_
|
165 |
+
vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
|
166 |
+
Float4_ res;
|
167 |
+
res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
|
168 |
+
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
|
169 |
+
return res;
|
170 |
+
}
|
171 |
+
|
172 |
+
// fp8x8 -> float8
|
173 |
+
template <>
|
174 |
+
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) {
|
175 |
+
Float4_ tmp1, tmp2;
|
176 |
+
tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
|
177 |
+
tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
|
178 |
+
Float8_ res;
|
179 |
+
res.x = tmp1.x;
|
180 |
+
res.y = tmp1.y;
|
181 |
+
res.z = tmp2.x;
|
182 |
+
res.w = tmp2.y;
|
183 |
+
return res;
|
184 |
+
}
|
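The fp8x2/fp8x4/fp8x8 specializations above all follow the same pattern: split the packed register into bytes with shifts and casts, then convert each byte independently. A minimal Python sketch of just that byte-extraction step (the helper name is illustrative, not part of this header):

```python
def split_fp8x4(packed: int) -> list[int]:
    """Extract four fp8 bytes from a 32-bit word, lowest byte first,
    mirroring the (uint8_t)(a >> k) casts used above."""
    return [(packed >> shift) & 0xFF for shift in (0, 8, 16, 24)]

# Example: 0x44C03800 unpacks to [0x00, 0x38, 0xC0, 0x44].
assert split_fp8x4(0x44C03800) == [0x00, 0x38, 0xC0, 0x44]
```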
185 |
+
|
186 |
+
// half -> fp8
|
187 |
+
template <>
|
188 |
+
__inline__ __device__ uint8_t
|
189 |
+
vec_conversion<uint8_t, uint16_t>(const uint16_t& a) {
|
190 |
+
__half_raw tmp;
|
191 |
+
tmp.x = a;
|
192 |
+
|
193 |
+
hip_fp8 f8{static_cast<float>(tmp.data)};
|
194 |
+
return f8.data;
|
195 |
+
}
|
196 |
+
|
197 |
+
// bf16 -> fp8
|
198 |
+
template <>
|
199 |
+
__inline__ __device__ uint8_t
|
200 |
+
vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) {
|
201 |
+
hip_fp8 res{__bfloat162float(a)};
|
202 |
+
return res.data;
|
203 |
+
}
|
204 |
+
|
205 |
+
// float -> fp8
|
206 |
+
template <>
|
207 |
+
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) {
|
208 |
+
hip_fp8 f8(a);
|
209 |
+
return f8.data;
|
210 |
+
}
|
211 |
+
|
212 |
+
// fp8x4 -> float4
|
213 |
+
template <>
|
214 |
+
__inline__ __device__ float4
|
215 |
+
vec_conversion<float4, uint32_t>(const uint32_t& a) {
|
216 |
+
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
|
217 |
+
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
218 |
+
return res;
|
219 |
+
}
|
220 |
+
|
221 |
+
// float2 -> half2
|
222 |
+
template <>
|
223 |
+
__inline__ __device__ uint32_t
|
224 |
+
vec_conversion<uint32_t, float2>(const float2& a) {
|
225 |
+
union {
|
226 |
+
half2 float16;
|
227 |
+
uint32_t uint32;
|
228 |
+
};
|
229 |
+
|
230 |
+
float16 = __float22half2_rn(a);
|
231 |
+
return uint32;
|
232 |
+
}
|
233 |
+
|
234 |
+
// Float4 -> half2x2
|
235 |
+
template <>
|
236 |
+
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) {
|
237 |
+
uint2 b;
|
238 |
+
float2 val;
|
239 |
+
val.x = a.x.x;
|
240 |
+
val.y = a.x.y;
|
241 |
+
b.x = vec_conversion<uint32_t, float2>(val);
|
242 |
+
|
243 |
+
val.x = a.y.x;
|
244 |
+
val.y = a.y.y;
|
245 |
+
b.y = vec_conversion<uint32_t, float2>(val);
|
246 |
+
return b;
|
247 |
+
}
|
248 |
+
|
249 |
+
// Float4 -> float4
|
250 |
+
template <>
|
251 |
+
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a) {
|
252 |
+
float4 b;
|
253 |
+
b.x = a.x.x;
|
254 |
+
b.y = a.x.y;
|
255 |
+
b.z = a.y.x;
|
256 |
+
b.w = a.y.y;
|
257 |
+
return b;
|
258 |
+
}
|
259 |
+
|
260 |
+
// Float8 -> half2x4
|
261 |
+
template <>
|
262 |
+
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a) {
|
263 |
+
uint4 b;
|
264 |
+
b.x = vec_conversion<uint32_t, float2>(a.x);
|
265 |
+
b.y = vec_conversion<uint32_t, float2>(a.y);
|
266 |
+
b.z = vec_conversion<uint32_t, float2>(a.z);
|
267 |
+
b.w = vec_conversion<uint32_t, float2>(a.w);
|
268 |
+
return b;
|
269 |
+
}
|
270 |
+
|
271 |
+
// float2 -> bfloat162
|
272 |
+
template <>
|
273 |
+
__inline__ __device__ __nv_bfloat162
|
274 |
+
vec_conversion<__nv_bfloat162, float2>(const float2& a) {
|
275 |
+
__nv_bfloat162 b = __float22bfloat162_rn(a);
|
276 |
+
return b;
|
277 |
+
}
|
278 |
+
|
279 |
+
// Float4 -> bfloat162x2
|
280 |
+
template <>
|
281 |
+
__inline__ __device__ bf16_4_t
|
282 |
+
vec_conversion<bf16_4_t, Float4_>(const Float4_& a) {
|
283 |
+
bf16_4_t b;
|
284 |
+
b.x = __float22bfloat162_rn(a.x);
|
285 |
+
b.y = __float22bfloat162_rn(a.y);
|
286 |
+
return b;
|
287 |
+
}
|
288 |
+
|
289 |
+
// Float8 -> bfloat162x4
|
290 |
+
template <>
|
291 |
+
__inline__ __device__ bf16_8_t
|
292 |
+
vec_conversion<bf16_8_t, Float8_>(const Float8_& a) {
|
293 |
+
bf16_8_t b;
|
294 |
+
b.x = __float22bfloat162_rn(a.x);
|
295 |
+
b.y = __float22bfloat162_rn(a.y);
|
296 |
+
b.z = __float22bfloat162_rn(a.z);
|
297 |
+
b.w = __float22bfloat162_rn(a.w);
|
298 |
+
return b;
|
299 |
+
}
|
300 |
+
|
301 |
+
/* Scaled and vectorized conversions, for data exchange between high and low
|
302 |
+
precision domains
|
303 |
+
|
304 |
+
Convention of the scale in API, e.g: FP8_data = Quantization(
|
305 |
+
High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8 Dequant(FP8) *
|
306 |
+
scale => HP
|
307 |
+
|
308 |
+
*/
|
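A small end-to-end illustration of this scale convention (a sketch in PyTorch, assuming a build that exposes `torch.float8_e4m3fn`; `torch.float8_e4m3fnuz` would be the ROCm analogue):

```python
import torch

hp = torch.tensor([0.5, -3.0, 100.0, 300.0])                    # high-precision data
scale = hp.abs().max() / torch.finfo(torch.float8_e4m3fn).max   # per-tensor scale

fp8 = (hp / scale).to(torch.float8_e4m3fn)        # Quantize(HP / scale) => FP8
recovered = fp8.to(torch.float32) * scale         # Dequant(FP8) * scale  => HP
max_rel_err = ((recovered - hp).abs() / hp.abs()).max()   # small for E4M3
```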
309 |
+
|
310 |
+
// fp8 -> half
|
311 |
+
template <>
|
312 |
+
__inline__ __device__ uint16_t
|
313 |
+
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, const float scale) {
|
314 |
+
hip_fp8 f8{a, hip_fp8::from_bits()};
|
315 |
+
__half_raw res;
|
316 |
+
res.data = static_cast<float>(f8) * scale;
|
317 |
+
return res.x;
|
318 |
+
}
|
319 |
+
|
320 |
+
// fp8x2 -> half2
|
321 |
+
template <>
|
322 |
+
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
|
323 |
+
const uint16_t& a, const float scale) {
|
324 |
+
#if defined(__HIP__MI300__) && \
|
325 |
+
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
326 |
+
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
327 |
+
union {
|
328 |
+
__half2_raw h2r;
|
329 |
+
uint32_t ui32;
|
330 |
+
} tmp;
|
331 |
+
tmp.h2r.x.data = f2[0] * scale;
|
332 |
+
tmp.h2r.y.data = f2[1] * scale;
|
333 |
+
return tmp.ui32;
|
334 |
+
#else
|
335 |
+
union {
|
336 |
+
uint16_t u16[2];
|
337 |
+
uint32_t u32;
|
338 |
+
} tmp;
|
339 |
+
|
340 |
+
tmp.u16[0] =
|
341 |
+
scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a), scale);
|
342 |
+
tmp.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>(
|
343 |
+
static_cast<uint8_t>(a >> 8U), scale);
|
344 |
+
return tmp.u32;
|
345 |
+
#endif
|
346 |
+
}
|
347 |
+
|
348 |
+
// fp8x4 -> half2x2
|
349 |
+
template <>
|
350 |
+
__inline__ __device__ uint2
|
351 |
+
scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, const float scale) {
|
352 |
+
union {
|
353 |
+
uint2 u32x2;
|
354 |
+
uint32_t u32[2];
|
355 |
+
} tmp;
|
356 |
+
tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
|
357 |
+
tmp.u32[1] =
|
358 |
+
scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
|
359 |
+
return tmp.u32x2;
|
360 |
+
}
|
361 |
+
|
362 |
+
// fp8x8 -> half2x4
|
363 |
+
template <>
|
364 |
+
__inline__ __device__ uint4
|
365 |
+
scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale) {
|
366 |
+
union {
|
367 |
+
uint4 u64x2;
|
368 |
+
uint2 u64[2];
|
369 |
+
} tmp;
|
370 |
+
tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
|
371 |
+
tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
|
372 |
+
return tmp.u64x2;
|
373 |
+
}
|
374 |
+
|
375 |
+
using __nv_bfloat16 = __hip_bfloat16;
|
376 |
+
|
377 |
+
// fp8 -> __nv_bfloat16
|
378 |
+
template <>
|
379 |
+
__inline__ __device__ __nv_bfloat16
|
380 |
+
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a,
|
381 |
+
const float scale) {
|
382 |
+
hip_fp8 f8{a, hip_fp8::from_bits()};
|
383 |
+
float f{f8};
|
384 |
+
return __float2bfloat16(f * scale);
|
385 |
+
}
|
386 |
+
|
387 |
+
using __nv_bfloat162 = __hip_bfloat162;
|
388 |
+
|
389 |
+
// fp8x2 -> __nv_bfloat162
|
390 |
+
template <>
|
391 |
+
__inline__ __device__ __nv_bfloat162
|
392 |
+
scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
|
393 |
+
const float scale) {
|
394 |
+
__nv_bfloat162 res;
|
395 |
+
res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale);
|
396 |
+
res.y =
|
397 |
+
scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale);
|
398 |
+
return res;
|
399 |
+
}
|
400 |
+
|
401 |
+
// fp8x4 -> bf16_4_t
|
402 |
+
template <>
|
403 |
+
__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
|
404 |
+
const uint32_t& a, const float scale) {
|
405 |
+
bf16_4_t res;
|
406 |
+
res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale);
|
407 |
+
res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
|
408 |
+
scale);
|
409 |
+
return res;
|
410 |
+
}
|
411 |
+
|
412 |
+
// fp8x8 -> bf16_8_t
|
413 |
+
template <>
|
414 |
+
__inline__ __device__ bf16_8_t
|
415 |
+
scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale) {
|
416 |
+
bf16_4_t tmp1, tmp2;
|
417 |
+
tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
|
418 |
+
tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
|
419 |
+
bf16_8_t res;
|
420 |
+
res.x = tmp1.x;
|
421 |
+
res.y = tmp1.y;
|
422 |
+
res.z = tmp2.x;
|
423 |
+
res.w = tmp2.y;
|
424 |
+
return res;
|
425 |
+
}
|
426 |
+
|
427 |
+
// fp8 -> float
|
428 |
+
template <>
|
429 |
+
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
|
430 |
+
const uint8_t& a, const float scale) {
|
431 |
+
hip_fp8 fp8{a, hip_fp8::from_bits()};
|
432 |
+
return static_cast<float>(fp8) * scale;
|
433 |
+
}
|
434 |
+
|
435 |
+
// fp8x2 -> float2
|
436 |
+
template <>
|
437 |
+
__inline__ __device__ float2
|
438 |
+
scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, const float scale) {
|
439 |
+
#if defined(__HIP__MI300__) && \
|
440 |
+
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
441 |
+
float2 res;
|
442 |
+
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
443 |
+
res.x = f2[0] * scale;
|
444 |
+
res.y = f2[1] * scale;
|
445 |
+
return res;
|
446 |
+
#else
|
447 |
+
float2 res;
|
448 |
+
res.x = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a), scale);
|
449 |
+
res.y = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U),
|
450 |
+
scale);
|
451 |
+
return res;
|
452 |
+
#endif
|
453 |
+
}
|
454 |
+
|
455 |
+
// fp8x4 -> float4
|
456 |
+
template <>
|
457 |
+
__inline__ __device__ Float4_
|
458 |
+
scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) {
|
459 |
+
Float4_ res;
|
460 |
+
res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale);
|
461 |
+
res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale);
|
462 |
+
return res;
|
463 |
+
}
|
464 |
+
|
465 |
+
// fp8x8 -> float8
|
466 |
+
template <>
|
467 |
+
__inline__ __device__ Float8_
|
468 |
+
scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale) {
|
469 |
+
Float4_ tmp1, tmp2;
|
470 |
+
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
|
471 |
+
tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
|
472 |
+
Float8_ res;
|
473 |
+
res.x = tmp1.x;
|
474 |
+
res.y = tmp1.y;
|
475 |
+
res.z = tmp2.x;
|
476 |
+
res.w = tmp2.y;
|
477 |
+
return res;
|
478 |
+
}
|
479 |
+
|
480 |
+
/* Quantize(HP / scale) => FP8 */
|
481 |
+
|
482 |
+
// TODO(Hai): vectorized to add
|
483 |
+
|
484 |
+
// half -> fp8
|
485 |
+
template <>
|
486 |
+
__inline__ __device__ uint8_t
|
487 |
+
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, const float scale) {
|
488 |
+
__half_raw tmp;
|
489 |
+
tmp.x = a;
|
490 |
+
|
491 |
+
hip_fp8 f8{static_cast<float>(tmp.data) / scale};
|
492 |
+
return f8.data;
|
493 |
+
}
|
494 |
+
|
495 |
+
// bf16 -> fp8
|
496 |
+
template <>
|
497 |
+
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
|
498 |
+
const __nv_bfloat16& a, const float scale) {
|
499 |
+
hip_fp8 res{__bfloat162float(a) / scale};
|
500 |
+
return res.data;
|
501 |
+
}
|
502 |
+
|
503 |
+
// float -> fp8
|
504 |
+
template <>
|
505 |
+
__inline__ __device__ uint8_t
|
506 |
+
scaled_vec_conversion<uint8_t, float>(const float& a, const float scale) {
|
507 |
+
hip_fp8 f8(a / scale);
|
508 |
+
return f8.data;
|
509 |
+
}
|
510 |
+
|
511 |
+
// fp8x4 -> float4
|
512 |
+
template <>
|
513 |
+
__inline__ __device__ float4
|
514 |
+
scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, const float scale) {
|
515 |
+
Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
|
516 |
+
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
517 |
+
return res;
|
518 |
+
}
|
519 |
+
#endif // ENABLE_FP8
|
520 |
+
|
521 |
+
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
522 |
+
__inline__ __device__ Tout convert(const Tin& x) {
|
523 |
+
#ifdef ENABLE_FP8
|
524 |
+
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
|
525 |
+
return vec_conversion<Tout, Tin>(x);
|
526 |
+
}
|
527 |
+
#endif
|
528 |
+
assert(false);
|
529 |
+
return {}; // Squash missing return statement warning
|
530 |
+
}
|
531 |
+
|
532 |
+
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
533 |
+
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
|
534 |
+
#ifdef ENABLE_FP8
|
535 |
+
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
|
536 |
+
return scaled_vec_conversion<Tout, Tin>(x, scale);
|
537 |
+
}
|
538 |
+
#endif
|
539 |
+
assert(false);
|
540 |
+
return {}; // Squash missing return statement warning
|
541 |
+
}
|
542 |
+
|
543 |
+
// The following macro is used to dispatch the conversion function based on
|
544 |
+
// the data type of the key and value cache. The FN is a macro that calls a
|
545 |
+
// function with template<typename scalar_t, typename cache_t,
|
546 |
+
// Fp8KVCacheDataType kv_dt>.
|
547 |
+
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
|
548 |
+
if (KV_DTYPE == "auto") { \
|
549 |
+
if (SRC_DTYPE == at::ScalarType::Float) { \
|
550 |
+
FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
|
551 |
+
} else if (SRC_DTYPE == at::ScalarType::Half) { \
|
552 |
+
FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \
|
553 |
+
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
|
554 |
+
FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \
|
555 |
+
} else { \
|
556 |
+
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
|
557 |
+
} \
|
558 |
+
} else { \
|
559 |
+
if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
|
560 |
+
if (SRC_DTYPE == at::ScalarType::Float) { \
|
561 |
+
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
562 |
+
} else if (SRC_DTYPE == at::ScalarType::Half) { \
|
563 |
+
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
564 |
+
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
|
565 |
+
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
566 |
+
} else { \
|
567 |
+
TORCH_CHECK(false, \
|
568 |
+
"Unsupported input type of kv cache: ", SRC_DTYPE); \
|
569 |
+
} \
|
570 |
+
} else { \
|
571 |
+
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
|
572 |
+
} \
|
573 |
+
}
|
574 |
+
|
575 |
+
} // namespace fp8
|
576 |
+
#endif // USE_ROCM
|
577 |
+
} // namespace vllm
|
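The DISPATCH_BY_KV_CACHE_DTYPE macro above encodes a simple mapping from the `kv_cache_dtype` string to the cache's storage type. The same rule expressed at the Python level (the function name is illustrative, not part of the extension's API):

```python
import torch

def kv_cache_storage_dtype(kv_cache_dtype: str, src_dtype: torch.dtype) -> torch.dtype:
    if kv_cache_dtype == "auto":
        if src_dtype not in (torch.float32, torch.float16, torch.bfloat16):
            raise ValueError(f"Unsupported input type of kv cache: {src_dtype}")
        return src_dtype                 # cache keeps the model dtype
    if kv_cache_dtype in ("fp8", "fp8_e4m3"):
        return torch.uint8               # raw fp8 bits, dequantized with a scale
    raise ValueError(f"Unsupported data type of kv cache: {kv_cache_dtype}")

assert kv_cache_storage_dtype("fp8", torch.float16) is torch.uint8
```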
fp8/common.cu
ADDED
@@ -0,0 +1,149 @@
#include "common.cuh"
#include "dispatch_utils.h"

#include <c10/cuda/CUDAGuard.h>

#ifndef USE_ROCM
  #include <cub/cub.cuh>
#else
  #include <hipcub/hipcub.hpp>
#endif

namespace vllm {

template <typename scalar_t>
__global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
                                        const scalar_t* __restrict__ input,
                                        const float* __restrict__ scale,
                                        int64_t num_elems) {
  int tid = blockDim.x * blockIdx.x + threadIdx.x;

  // Invert the scale so that we can use multiplications to avoid expensive
  // division.
  const float inverted_scale = 1.0f / (*scale);
  scaled_fp8_conversion_vec<scalar_t, true>(
      out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x);
}

template <typename scalar_t>
__global__ void dynamic_per_token_scaled_fp8_quant_kernel(
    FP8_TYPE* __restrict__ out, float* __restrict__ scale,
    scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
    const int hidden_size) {
  float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);

  int const tid = threadIdx.x;
  int const token_idx = blockIdx.x;

  // Use int64 to avoid overflowing an int32 when calculating this offset
  int64_t offset = static_cast<int64_t>(token_idx) * hidden_size;
  scalar_t const* __restrict__ token_input = &input[offset];
  FP8_TYPE* __restrict__ token_output = &out[offset];

  // For vectorization, token_input and token_output pointers need to be
  // aligned at 8-byte and 4-byte addresses respectively.
  bool const can_vectorize = hidden_size % 4 == 0;

  float absmax_val = 0.0f;
  if (can_vectorize) {
    absmax_val = thread_max_vec(token_input, hidden_size, tid, blockDim.x);
  } else {
    for (int i = tid; i < hidden_size; i += blockDim.x) {
      float const x = static_cast<float>(token_input[i]);
      absmax_val = max(absmax_val, fabs(x));
    }
  }

  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStorage;
  float const block_absmax_val_maybe =
      BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
  __shared__ float token_scale;
  if (tid == 0) {
    if (scale_ub) {
      token_scale = min(block_absmax_val_maybe, *scale_ub);
    } else {
      token_scale = block_absmax_val_maybe;
    }
    // token scale computation
    token_scale = max(token_scale / FP8_E4M3_MAX, min_scaling_factor);
    scale[token_idx] = token_scale;
  }
  __syncthreads();

  // Note that we don't use inverted scales so we can match FBGemm impl.
  if (can_vectorize) {
    scaled_fp8_conversion_vec<scalar_t, false>(
        token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
  } else {
    for (int i = tid; i < hidden_size; i += blockDim.x) {
      token_output[i] = scaled_fp8_conversion<false>(
          static_cast<float>(token_input[i]), token_scale);
    }
  }
}

}  // namespace vllm

void static_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
                             torch::Tensor const& input,  // [..., d]
                             torch::Tensor const& scale)  // [1]
{
  int64_t num_tokens = input.numel() / input.size(-1);
  int64_t num_elems = input.numel();
  dim3 grid(num_tokens);
  dim3 block(1024);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
        vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
            out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
            scale.data_ptr<float>(), num_elems);
      });
}

void dynamic_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
                              torch::Tensor const& input,  // [..., d]
                              torch::Tensor& scale)        // [1]
{
  int64_t num_tokens = input.numel() / input.size(-1);
  int64_t num_elems = input.numel();
  dim3 grid(num_tokens);
  dim3 block(1024);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
        vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
            scale.data_ptr<float>(), input.data_ptr<scalar_t>(), num_elems);
        vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
            out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
            scale.data_ptr<float>(), num_elems);
      });
}

void dynamic_per_token_scaled_fp8_quant(
    torch::Tensor& out,          // [..., d]
    torch::Tensor const& input,  // [..., d]
    torch::Tensor& scales, std::optional<at::Tensor> const& scale_ub) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());

  int const hidden_size = input.size(-1);
  int const num_tokens = input.numel() / hidden_size;
  dim3 const grid(num_tokens);
  dim3 const block(std::min(hidden_size, 1024));

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
        vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
            <<<grid, block, 0, stream>>>(
                out.data_ptr<FP8_TYPE>(), scales.data_ptr<float>(),
                input.data_ptr<scalar_t>(),
                scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
                hidden_size);
      });
}
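For reference, the two per-tensor entry points above compute the following (an eager-mode sketch, assuming `torch.float8_e4m3fn` is available; the `_ref` names are illustrative and not the ops registered by this extension):

```python
import torch

FP8 = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8).max

def static_scaled_fp8_quant_ref(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # scaled_fp8_conversion with a pre-computed scale: divide, clamp, narrow to fp8.
    return (x.float() / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8)

def dynamic_scaled_fp8_quant_ref(x: torch.Tensor):
    # segmented_max_reduction derives scale = absmax / fp8_max before quantizing.
    scale = x.float().abs().max() / FP8_MAX
    return static_scaled_fp8_quant_ref(x, scale), scale

x = torch.randn(16, 64, dtype=torch.float16)
q, s = dynamic_scaled_fp8_quant_ref(x)
max_abs_err = (q.float() * s - x.float()).abs().max()
```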
fp8/common.cuh
ADDED
@@ -0,0 +1,172 @@
#pragma once

#include <cmath>

#ifndef USE_ROCM
  #include <c10/util/Float8_e4m3fn.h>
using FP8_TYPE = c10::Float8_e4m3fn;
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
    std::numeric_limits<FP8_TYPE>::max();
#else
  #include <c10/util/Float8_e4m3fnuz.h>
  #include "amd/hip_float8.h"
using FP8_TYPE = c10::Float8_e4m3fnuz;
// Using the default max value from pytorch (240.0) will cause accuracy
// issue when running dynamic quantization. Here use 224.0f for rocm.
constexpr auto FP8_E4M3_MAX = 224.0f;
#endif

namespace vllm {

__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  float old;
  old = (value >= 0)
            ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
            : __uint_as_float(
                  atomicMin((unsigned int*)addr, __float_as_uint(value)));

  return old;
}

template <bool is_scale_inverted>
__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
                                                          float const scale) {
  float x = 0.0f;
  if constexpr (is_scale_inverted) {
    x = val * scale;
  } else {
    x = val / scale;
  }

  float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
#ifndef USE_ROCM
  return static_cast<c10::Float8_e4m3fn>(r);
#else
  // Use hardware cvt instruction for fp8 on rocm
  return c10::Float8_e4m3fnuz(hip_fp8(r).data,
                              c10::Float8_e4m3fnuz::from_bits());
#endif
}

// Compute the absolute maximum m of the input tensor and store
// m / float8_e4m3::max() in *scale. Each thread block performs a
// reduction tree and the memory in scale is atomically updated.
// So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
template <typename scalar_t>
__global__ void segmented_max_reduction(float* __restrict__ scale,
                                        const scalar_t* __restrict__ input,
                                        int64_t num_elems) {
  __shared__ float cache[1024];
  int64_t i = blockDim.x * blockIdx.x + threadIdx.x;

  // First store maximum for all values processed by
  // the current thread in cache[threadIdx.x]
  scalar_t tmp = 0.0;
  while (i < num_elems) {
    float x = static_cast<float>(input[i]);
    tmp = max(tmp, fabs(x));
    i += blockDim.x * gridDim.x;
  }
  cache[threadIdx.x] = tmp;

  __syncthreads();

  // Now perform parallel reduction within the thread block
  int ib = blockDim.x / 2;
  while (ib != 0) {
    if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
      cache[threadIdx.x] = cache[threadIdx.x + ib];
    }
    __syncthreads();
    ib /= 2;
  }
  // Finally, since cache[0] contains the maximum for this thread block,
  // atomically write the max to the target location
  if (threadIdx.x == 0) {
    atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX);
  }
}

template <typename scalar_t>
struct __align__(8) vec4_t {
  scalar_t x;
  scalar_t y;
  scalar_t z;
  scalar_t w;
};

typedef struct __align__(4) {
  FP8_TYPE x;
  FP8_TYPE y;
  FP8_TYPE z;
  FP8_TYPE w;
}
float8x4_t;

template <typename scalar_t>
__device__ float thread_max_vec(scalar_t const* __restrict__ input,
                                int64_t const num_elems, int const tid,
                                int const step) {
  // Vectorized input/output to better utilize memory bandwidth.
  vec4_t<scalar_t> const* vectorized_in =
      reinterpret_cast<vec4_t<scalar_t> const*>(input);

  int64_t const num_vec_elems = num_elems >> 2;
  float absmax_val = 0.0f;

#pragma unroll 4
  for (int64_t i = tid; i < num_vec_elems; i += step) {
    vec4_t<scalar_t> in_vec = vectorized_in[i];
    absmax_val = max(absmax_val, fabs(in_vec.x));
    absmax_val = max(absmax_val, fabs(in_vec.y));
    absmax_val = max(absmax_val, fabs(in_vec.z));
    absmax_val = max(absmax_val, fabs(in_vec.w));
  }

  // Handle the remaining elements if num_elems is not divisible by 4
  for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
    absmax_val = max(absmax_val, fabs(input[i]));
  }

  return absmax_val;
}

template <typename scalar_t, bool is_scale_inverted>
__device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
                                          scalar_t const* __restrict__ input,
                                          float const scale,
                                          int64_t const num_elems,
                                          int const tid, int const step) {
  // Vectorized input/output to better utilize memory bandwidth.
  vec4_t<scalar_t> const* vectorized_in =
      reinterpret_cast<vec4_t<scalar_t> const*>(input);
  float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);

  int64_t const num_vec_elems = num_elems >> 2;

#pragma unroll 4
  for (int64_t i = tid; i < num_vec_elems; i += step) {
    vec4_t<scalar_t> in_vec = vectorized_in[i];
    float8x4_t out_vec;

    out_vec.x = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(in_vec.x), scale);
    out_vec.y = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(in_vec.y), scale);
    out_vec.z = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(in_vec.z), scale);
    out_vec.w = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(in_vec.w), scale);
    vectorized_out[i] = out_vec;
  }

  // Handle the remaining elements if num_elems is not divisible by 4
  for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
    out[i] = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(input[i]), scale);
  }
}

}  // namespace vllm
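The per-token variant derives one scale per row from that row's absolute maximum, optionally capped by `scale_ub` and floored by the minimum scaling factor used in the kernel. A hedged eager-mode reference (again assuming `torch.float8_e4m3fn`; not the extension API):

```python
import torch

FP8 = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8).max
MIN_SCALE = 1.0 / (FP8_MAX * 512.0)   # same floor as min_scaling_factor above

def per_token_fp8_quant_ref(x: torch.Tensor, scale_ub: float | None = None):
    absmax = x.float().abs().amax(dim=-1, keepdim=True)   # one absmax per token
    if scale_ub is not None:
        absmax = absmax.clamp(max=scale_ub)
    scale = (absmax / FP8_MAX).clamp(min=MIN_SCALE)
    q = (x.float() / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8)
    return q, scale
```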
fp8/fp8_marlin.cu
ADDED
@@ -0,0 +1,1306 @@
1 |
+
/*
|
2 |
+
* Modified by Neural Magic
|
3 |
+
* Copyright (C) Marlin.2024 Elias Frantar
|
4 |
+
*
|
5 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
* you may not use this file except in compliance with the License.
|
7 |
+
* You may obtain a copy of the License at
|
8 |
+
*
|
9 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
*
|
11 |
+
* Unless required by applicable law or agreed to in writing, software
|
12 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
* See the License for the specific language governing permissions and
|
15 |
+
* limitations under the License.
|
16 |
+
*/
|
17 |
+
|
18 |
+
/*
|
19 |
+
* Adapted from https://github.com/IST-DASLab/marlin
|
20 |
+
*/
|
21 |
+
|
22 |
+
#include "../gptq_marlin/marlin.cuh"
|
23 |
+
#include "../gptq_marlin/marlin_dtypes.cuh"
|
24 |
+
|
25 |
+
using namespace marlin;
|
26 |
+
|
27 |
+
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
28 |
+
static_assert(std::is_same<scalar_t, half>::value || \
|
29 |
+
std::is_same<scalar_t, nv_bfloat16>::value, \
|
30 |
+
"only float16 and bfloat16 is supported");
|
31 |
+
|
32 |
+
template <typename T>
|
33 |
+
inline std::string str(T x) {
|
34 |
+
return std::to_string(x);
|
35 |
+
}
|
36 |
+
|
37 |
+
namespace fp8_marlin {
|
38 |
+
|
39 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
40 |
+
|
41 |
+
template <typename scalar_t, // compute dtype, half or nv_float16
|
42 |
+
const int num_bits, // number of bits used for weights
|
43 |
+
const int threads, // number of threads in a threadblock
|
44 |
+
const int thread_m_blocks, // number of 16x16 blocks in the m
|
45 |
+
// dimension (batchsize) of the
|
46 |
+
// threadblock
|
47 |
+
const int thread_n_blocks, // same for n dimension (output)
|
48 |
+
const int thread_k_blocks, // same for k dimension (reduction)
|
49 |
+
const int stages, // number of stages for the async global->shared
|
50 |
+
// fetch pipeline
|
51 |
+
const int group_blocks = -1 // number of consecutive 16x16 blocks
|
52 |
+
// with a separate quantization scale
|
53 |
+
>
|
54 |
+
__global__ void Marlin(
|
55 |
+
const int4* __restrict__ A, // fp16 input matrix of shape mxk
|
56 |
+
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
|
57 |
+
int4* __restrict__ C, // fp16 output buffer of shape mxn
|
58 |
+
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
|
59 |
+
// (k/groupsize)xn
|
60 |
+
int num_groups, // number of scale groups per output channel
|
61 |
+
int prob_m, // batch dimension m
|
62 |
+
int prob_n, // output dimension n
|
63 |
+
int prob_k, // reduction dimension k
|
64 |
+
int* locks // extra global storage for barrier synchronization
|
65 |
+
) {}
|
66 |
+
|
67 |
+
} // namespace fp8_marlin
|
68 |
+
|
69 |
+
torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
70 |
+
torch::Tensor& b_scales, torch::Tensor& workspace,
|
71 |
+
int64_t num_bits, int64_t size_m, int64_t size_n,
|
72 |
+
int64_t size_k) {
|
73 |
+
TORCH_CHECK_NOT_IMPLEMENTED(false,
|
74 |
+
"marlin_gemm(..) requires CUDA_ARCH >= 8.0");
|
75 |
+
return torch::empty({1, 1});
|
76 |
+
}
|
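Numerically, a weight-only fp8 GEMM of this kind multiplies the activations by the dequantized weight, i.e. the quantization scales are folded back into the fp8 weight before the matmul. A rough reference under simplifying assumptions (it ignores Marlin's packed/permuted weight layout and the `workspace`/`num_bits` plumbing, and assumes per-output-channel or grouped scales):

```python
import torch

def fp8_gemm_ref(a: torch.Tensor,        # [m, k] fp16/bf16 activations
                 w_fp8: torch.Tensor,    # [k, n] torch.float8_e4m3fn weights (unpacked)
                 scales: torch.Tensor):  # [1, n] channel-wise or [k // g, n] grouped
    w = w_fp8.to(a.dtype)
    if scales.shape[0] > 1:
        group = w.shape[0] // scales.shape[0]
        w = w * scales.repeat_interleave(group, dim=0).to(a.dtype)
    else:
        w = w * scales.to(a.dtype)
    return a @ w
```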
77 |
+
|
78 |
+
#else
|
79 |
+
|
80 |
+
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
|
81 |
+
// output/accumulation.
|
82 |
+
template <typename scalar_t>
|
83 |
+
__device__ inline void mma(const typename ScalarType<scalar_t>::FragA& a_frag,
|
84 |
+
const typename ScalarType<scalar_t>::FragB& frag_b,
|
85 |
+
typename ScalarType<scalar_t>::FragC& frag_c) {
|
86 |
+
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
|
87 |
+
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
88 |
+
float* c = reinterpret_cast<float*>(&frag_c);
|
89 |
+
if constexpr (std::is_same<scalar_t, half>::value) {
|
90 |
+
asm volatile(
|
91 |
+
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
92 |
+
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
93 |
+
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
94 |
+
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
95 |
+
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
96 |
+
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
97 |
+
asm volatile(
|
98 |
+
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
99 |
+
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
100 |
+
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
101 |
+
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
102 |
+
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
103 |
+
} else {
|
104 |
+
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
|
105 |
+
}
|
106 |
+
}
|
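Ignoring how the operands are spread across the 32 threads of a warp, each of these m16n8k16 instructions computes D = A·B + C on a 16x16 and a 16x8 half/bf16 tile with fp32 accumulation; in NumPy terms:

```python
import numpy as np

a = np.random.rand(16, 16).astype(np.float16)   # FragA tile
b = np.random.rand(16, 8).astype(np.float16)    # FragB tile
c = np.zeros((16, 8), dtype=np.float32)         # FragC accumulator

d = a.astype(np.float32) @ b.astype(np.float32) + c   # fp32 accumulate
```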
107 |
+
|
108 |
+
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
|
109 |
+
// memory, directly in tensor core layout.
|
110 |
+
template <typename scalar_t>
|
111 |
+
__device__ inline void ldsm4(typename ScalarType<scalar_t>::FragA& frag_a,
|
112 |
+
const void* smem_ptr) {
|
113 |
+
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
|
114 |
+
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
115 |
+
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
|
116 |
+
: "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
|
117 |
+
: "r"(smem));
|
118 |
+
}
|
119 |
+
|
120 |
+
// Fast FP8ToFp16/FP8ToBf16: Efficiently dequantize 8bit fp8_e4m3 values to fp16
|
121 |
+
// bf16 Reference:
|
122 |
+
// - FP16:
|
123 |
+
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
|
124 |
+
// - BF16:
|
125 |
+
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
|
126 |
+
template <typename scalar_t>
|
127 |
+
__device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit(int q) {
|
128 |
+
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
|
129 |
+
}
|
130 |
+
|
131 |
+
template <>
|
132 |
+
__device__ inline typename ScalarType<half>::FragB dequant_8bit<half>(int q) {
|
133 |
+
// Constants for FP8 (E4M3) and FP16 formats
|
134 |
+
constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, FP16_EXPONENT = 5;
|
135 |
+
constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT;
|
136 |
+
|
137 |
+
// Calculate MASK for extracting mantissa and exponent
|
138 |
+
constexpr int MASK1 = 0x80000000;
|
139 |
+
constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA);
|
140 |
+
constexpr int MASK3 = MASK2 & 0x7fffffff;
|
141 |
+
constexpr int MASK = MASK3 | (MASK3 >> 16);
|
142 |
+
// Final MASK value: 0x7F007F00
|
143 |
+
|
144 |
+
// Extract and shift FP8 values to FP16 format
|
145 |
+
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
146 |
+
int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT);
|
147 |
+
|
148 |
+
// Construct and apply exponent bias
|
149 |
+
constexpr int BIAS_OFFSET =
|
150 |
+
(1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
|
151 |
+
const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));
|
152 |
+
|
153 |
+
// Convert to half2 and apply bias
|
154 |
+
typename ScalarType<half>::FragB frag_b;
|
155 |
+
// Note: reverse indexing is intentional because weights are permuted
|
156 |
+
frag_b[1] = __hmul2(*reinterpret_cast<const half2*>(&Out1), bias_reg);
|
157 |
+
frag_b[0] = __hmul2(*reinterpret_cast<const half2*>(&Out2), bias_reg);
|
158 |
+
return frag_b;
|
159 |
+
}
|
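A scalar check of the per-lane trick above: place the E4M3 byte in the high byte of a 16-bit lane, keep the sign, shift exponent+mantissa right by (5 - 4) bits, and multiply by 2**(15 - 7) = 256 to correct the exponent-bias difference (a NumPy sketch handling one lane at a time rather than two packed half2 lanes):

```python
import numpy as np

def fp8_e4m3_byte_to_fp16(byte: int) -> float:
    lane = (byte << 8) & 0xFFFF                       # fp8 byte in the high byte
    out = (lane & 0x8000) | ((lane & 0x7F00) >> 1)    # sign | shifted exp+mantissa
    half = np.array([out], dtype=np.uint16).view(np.float16)[0]
    return float(half) * 2.0 ** 8                     # exponent-bias correction

assert fp8_e4m3_byte_to_fp16(0x38) == 1.0    # E4M3 encoding of 1.0
assert fp8_e4m3_byte_to_fp16(0xC0) == -2.0   # E4M3 encoding of -2.0
```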
160 |
+
|
161 |
+
template <>
|
162 |
+
__device__ inline typename ScalarType<nv_bfloat16>::FragB
|
163 |
+
dequant_8bit<nv_bfloat16>(int q) {
|
164 |
+
// Constants for FP8 (E4M3) and BF16 formats
|
165 |
+
constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, BF16_EXPONENT = 8;
|
166 |
+
constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
|
167 |
+
|
168 |
+
// Calculate MASK for extracting mantissa and exponent
|
169 |
+
constexpr int MASK1 = 0x80000000;
|
170 |
+
constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA);
|
171 |
+
constexpr int MASK3 = MASK2 & 0x7fffffff;
|
172 |
+
constexpr int MASK = MASK3 | (MASK3 >> 16);
|
173 |
+
// Final MASK value: 0x7F007F00
|
174 |
+
|
175 |
+
// Extract and shift FP8 values to BF16 format
|
176 |
+
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
177 |
+
int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT);
|
178 |
+
|
179 |
+
// Construct and apply exponent bias
|
180 |
+
constexpr int BIAS_OFFSET =
|
181 |
+
(1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
|
182 |
+
// Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent
|
183 |
+
// position
|
184 |
+
constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;
|
185 |
+
const nv_bfloat162 bias_reg =
|
186 |
+
__float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));
|
187 |
+
|
188 |
+
// Convert to bfloat162 and apply bias
|
189 |
+
typename ScalarType<nv_bfloat16>::FragB frag_b;
|
190 |
+
// Note: reverse indexing is intentional because weights are permuted
|
191 |
+
frag_b[1] = __hmul2(*reinterpret_cast<const nv_bfloat162*>(&Out1), bias_reg);
|
192 |
+
frag_b[0] = __hmul2(*reinterpret_cast<const nv_bfloat162*>(&Out2), bias_reg);
|
193 |
+
return frag_b;
|
194 |
+
}
|
195 |
+
|
196 |
+
// Multiply dequantized values by the corresponding quantization scale; used
|
197 |
+
// only for grouped quantization.
|
198 |
+
template <typename scalar_t>
|
199 |
+
__device__ inline void scale(typename ScalarType<scalar_t>::FragB& frag_b,
|
200 |
+
typename ScalarType<scalar_t>::FragS& frag_s,
|
201 |
+
int i) {
|
202 |
+
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
|
203 |
+
scalar_t2 s =
|
204 |
+
ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_s)[i]);
|
205 |
+
frag_b[0] = __hmul2(frag_b[0], s);
|
206 |
+
frag_b[1] = __hmul2(frag_b[1], s);
|
207 |
+
}
|
208 |
+
|
209 |
+
// Given 2 floats multiply by 2 scales (halves)
|
210 |
+
template <typename scalar_t>
|
211 |
+
__device__ inline void scale_float(float* c,
|
212 |
+
typename ScalarType<scalar_t>::FragS& s) {
|
213 |
+
scalar_t* s_ptr = reinterpret_cast<scalar_t*>(&s);
|
214 |
+
c[0] = __fmul_rn(c[0], ScalarType<scalar_t>::num2float(s_ptr[0]));
|
215 |
+
c[1] = __fmul_rn(c[1], ScalarType<scalar_t>::num2float(s_ptr[1]));
|
216 |
+
}
|
217 |
+
|
218 |
+
// Wait until barrier reaches `count`, then lock for current threadblock.
|
219 |
+
__device__ inline void barrier_acquire(int* lock, int count) {
|
220 |
+
if (threadIdx.x == 0) {
|
221 |
+
int state = -1;
|
222 |
+
do
|
223 |
+
// Guarantee that subsequent writes by this threadblock will be visible
|
224 |
+
// globally.
|
225 |
+
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
|
226 |
+
: "=r"(state)
|
227 |
+
: "l"(lock));
|
228 |
+
while (state != count);
|
229 |
+
}
|
230 |
+
__syncthreads();
|
231 |
+
}
|
232 |
+
|
233 |
+
// Release barrier and increment visitation count.
|
234 |
+
__device__ inline void barrier_release(int* lock, bool reset = false) {
|
235 |
+
__syncthreads();
|
236 |
+
if (threadIdx.x == 0) {
|
237 |
+
if (reset) {
|
238 |
+
lock[0] = 0;
|
239 |
+
return;
|
240 |
+
}
|
241 |
+
int val = 1;
|
242 |
+
// Make sure that all writes since acquiring this barrier are visible
|
243 |
+
// globally, while releasing the barrier.
|
244 |
+
asm volatile("fence.acq_rel.gpu;\n");
|
245 |
+
asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
|
246 |
+
:
|
247 |
+
: "l"(lock), "r"(val));
|
248 |
+
}
|
249 |
+
}
|
250 |
+
|
251 |
+
template <typename scalar_t, // compute dtype, half or nv_float16
|
252 |
+
const int num_bits, // number of bits used for weights
|
253 |
+
const int threads, // number of threads in a threadblock
|
254 |
+
const int thread_m_blocks, // number of 16x16 blocks in the m
|
255 |
+
// dimension (batchsize) of the
|
256 |
+
// threadblock
|
257 |
+
const int thread_n_blocks, // same for n dimension (output)
|
258 |
+
const int thread_k_blocks, // same for k dimension (reduction)
|
259 |
+
const int stages, // number of stages for the async global->shared
|
260 |
+
// fetch pipeline
|
261 |
+
const int group_blocks = -1 // number of consecutive 16x16 blocks
|
262 |
+
// with a separate quantization scale
|
263 |
+
>
|
264 |
+
__global__ void Marlin(
|
265 |
+
const int4* __restrict__ A, // fp16 input matrix of shape mxk
|
266 |
+
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
|
267 |
+
int4* __restrict__ C, // fp16 output buffer of shape mxn
|
268 |
+
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
|
269 |
+
// (k/groupsize)xn
|
270 |
+
int num_groups, // number of scale groups per output channel
|
271 |
+
int prob_m, // batch dimension m
|
272 |
+
int prob_n, // output dimension n
|
273 |
+
int prob_k, // reduction dimension k
|
274 |
+
int* locks // extra global storage for barrier synchronization
|
275 |
+
) {
|
276 |
+
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
|
277 |
+
// same size, which might involve multiple column "slices" (of width 16 *
|
278 |
+
// `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
|
279 |
+
// example:
|
280 |
+
// 0 1 3
|
281 |
+
// 0 2 3
|
282 |
+
// 1 2 4
|
283 |
+
// While this kind of partitioning makes things somewhat more complicated, it
|
284 |
+
// ensures good utilization of all SMs for many kinds of shape and GPU
|
285 |
+
// configurations, while requiring as few slow global cross-threadblock
|
286 |
+
// reductions as possible.
|
287 |
+
using Dtype = ScalarType<scalar_t>;
|
288 |
+
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
|
289 |
+
using FragA = typename ScalarType<scalar_t>::FragA;
|
290 |
+
using FragB = typename ScalarType<scalar_t>::FragB;
|
291 |
+
using FragC = typename ScalarType<scalar_t>::FragC;
|
292 |
+
using FragS = typename ScalarType<scalar_t>::FragS;
|
293 |
+
|
294 |
+
constexpr int pack_factor = 32 / num_bits;
|
295 |
+
|
296 |
+
// For larger GEMMs we run multiple batchsize 64 versions in parallel for a
|
297 |
+
// better partitioning with less reductions
|
298 |
+
int parallel = 1;
|
299 |
+
if (prob_m > 16 * thread_m_blocks) {
|
300 |
+
parallel = prob_m / (16 * thread_m_blocks);
|
301 |
+
prob_m = 16 * thread_m_blocks;
|
302 |
+
}
|
303 |
+
|
304 |
+
int k_tiles = prob_k / 16 / thread_k_blocks;
|
305 |
+
int n_tiles = prob_n / 16 / thread_n_blocks;
|
306 |
+
int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x);
|
307 |
+
|
308 |
+
int slice_row = (iters * blockIdx.x) % k_tiles;
|
309 |
+
int slice_col_par = (iters * blockIdx.x) / k_tiles;
|
310 |
+
int slice_col = slice_col_par;
|
311 |
+
int slice_iters; // number of threadblock tiles in the current slice
|
312 |
+
int slice_count =
|
313 |
+
0; // total number of active threadblocks in the current slice
|
314 |
+
int slice_idx; // index of threadblock in current slice; numbered bottom to
|
315 |
+
// top
|
316 |
+
|
317 |
+
// We can easily implement parallel problem execution by just remapping
|
318 |
+
// indices and advancing global pointers
|
319 |
+
if (slice_col_par >= n_tiles) {
|
320 |
+
A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;
|
321 |
+
C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
|
322 |
+
locks += (slice_col_par / n_tiles) * n_tiles;
|
323 |
+
slice_col = slice_col_par % n_tiles;
|
324 |
+
}
|
325 |
+
|
326 |
+
// Compute all information about the current slice which is required for
|
327 |
+
// synchronization.
|
328 |
+
auto init_slice = [&]() {
|
329 |
+
slice_iters =
|
330 |
+
iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
|
331 |
+
if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
|
332 |
+
if (slice_iters == 0) return;
|
333 |
+
if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
|
334 |
+
slice_count = 1;
|
335 |
+
slice_idx = 0;
|
336 |
+
int col_first = iters * div_ceil(k_tiles * slice_col_par, iters);
|
337 |
+
if (col_first <= k_tiles * (slice_col_par + 1)) {
|
338 |
+
int col_off = col_first - k_tiles * slice_col_par;
|
339 |
+
slice_count = div_ceil(k_tiles - col_off, iters);
|
340 |
+
if (col_off > 0) slice_count++;
|
341 |
+
int delta_first = iters * blockIdx.x - col_first;
|
342 |
+
if (delta_first < 0 || (col_off == 0 && delta_first == 0))
|
343 |
+
slice_idx = slice_count - 1;
|
344 |
+
else {
|
345 |
+
slice_idx = slice_count - 1 - delta_first / iters;
|
346 |
+
if (col_off > 0) slice_idx--;
|
347 |
+
}
|
348 |
+
}
|
349 |
+
if (slice_col == n_tiles) {
|
350 |
+
A += 16 * thread_m_blocks * prob_k / 8;
|
351 |
+
C += 16 * thread_m_blocks * prob_n / 8;
|
352 |
+
locks += n_tiles;
|
353 |
+
slice_col = 0;
|
354 |
+
}
|
355 |
+
};
|
356 |
+
init_slice();
|
357 |
+
|
358 |
+
// A sizes/strides
|
359 |
+
|
360 |
+
// stride of the A matrix in global memory
|
361 |
+
int a_gl_stride = prob_k / 8;
|
362 |
+
// stride of an A matrix tile in shared memory
|
363 |
+
constexpr int a_sh_stride = 16 * thread_k_blocks / 8;
|
364 |
+
// delta between subsequent A tiles in global memory
|
365 |
+
constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8;
|
366 |
+
// between subsequent accesses within a tile
|
367 |
+
int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o);
|
368 |
+
// between shared memory writes
|
369 |
+
constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o);
|
370 |
+
// between shared memory tile reads
|
371 |
+
constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4));
|
372 |
+
// within a shared memory tile
|
373 |
+
constexpr int a_sh_rd_delta_i = a_sh_stride * 16;
|
374 |
+
// overall size of a tile
|
375 |
+
constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks);
|
376 |
+
// number of shared write iterations for a tile
|
377 |
+
constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta);
|
378 |
+
|
379 |
+
// B sizes/strides
|
380 |
+
int b_gl_stride = 16 * prob_n / (pack_factor * 4);
|
381 |
+
constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;
|
382 |
+
constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2;
|
383 |
+
constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;
|
384 |
+
|
385 |
+
int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
|
386 |
+
int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads);
|
387 |
+
constexpr int b_sh_wr_delta = threads * b_thread_vecs;
|
388 |
+
constexpr int b_sh_rd_delta = threads * b_thread_vecs;
|
389 |
+
constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
|
390 |
+
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
|
391 |
+
|
392 |
+
// Scale sizes/strides without act_order
|
393 |
+
int s_gl_stride = prob_n / 8;
|
394 |
+
constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
|
395 |
+
|
396 |
+
// Scale size/strides with act_order
|
397 |
+
constexpr int tb_k = 16 * thread_k_blocks;
|
398 |
+
constexpr int g_idx_stage = 0;
|
399 |
+
// constexpr int act_s_row_stride = 1;
|
400 |
+
// int act_s_col_stride = act_s_row_stride * num_groups;
|
401 |
+
int act_s_col_stride = 1;
|
402 |
+
int act_s_col_warp_stride = act_s_col_stride * 8;
|
403 |
+
int tb_n_warps = thread_n_blocks / 4;
|
404 |
+
int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;
|
405 |
+
|
406 |
+
// Global A read index of current thread.
|
407 |
+
int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
|
408 |
+
(threadIdx.x % a_gl_rd_delta_o);
|
409 |
+
a_gl_rd += a_gl_rd_delta_o * slice_row;
|
410 |
+
// Shared write index of current thread.
|
411 |
+
int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
|
412 |
+
(threadIdx.x % a_gl_rd_delta_o);
|
413 |
+
// Shared read index.
|
414 |
+
int a_sh_rd =
|
415 |
+
a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
|
416 |
+
a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
|
417 |
+
|
418 |
+
int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) +
|
419 |
+
(threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
|
420 |
+
b_gl_rd += b_sh_stride * slice_col;
|
421 |
+
b_gl_rd += b_gl_rd_delta_o * slice_row;
|
422 |
+
int b_sh_wr = threadIdx.x * b_thread_vecs;
|
423 |
+
int b_sh_rd = threadIdx.x * b_thread_vecs;
|
424 |
+
|
425 |
+
// For act_order
|
426 |
+
int slice_k_start = tb_k * slice_row;
|
427 |
+
int slice_k_start_shared_fetch = slice_k_start;
|
428 |
+
int slice_n_offset = act_s_col_tb_stride * slice_col;
|
429 |
+
|
430 |
+
// No act_order
|
431 |
+
int s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
|
432 |
+
int s_sh_wr = threadIdx.x;
|
433 |
+
bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
|
434 |
+
|
435 |
+
// We scale a `half2` tile in row-major layout for column-wise quantization.
|
436 |
+
int s_sh_rd =
|
437 |
+
8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4;
|
438 |
+
|
439 |
+
// Precompute which thread should not read memory in which iterations; this is
|
440 |
+
// needed if there are more threads than required for a certain tilesize or
|
441 |
+
// when the batchsize is not a multiple of 16.
|
442 |
+
bool a_sh_wr_pred[a_sh_wr_iters];
|
443 |
+
#pragma unroll
|
444 |
+
for (int i = 0; i < a_sh_wr_iters; i++)
|
445 |
+
a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
|
446 |
+
|
447 |
+
// To ensure that writing and reading A tiles to/from shared memory, the
|
448 |
+
// latter in fragment format, is fully bank conflict free, we need to use a
|
449 |
+
// rather fancy XOR-based layout. The key here is that neither reads nor
|
450 |
+
// writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
|
451 |
+
// same shared memory banks. Further, it seems (based on NSight-Compute) that
|
452 |
+
// each warp must also write a consecutive memory segment?
|
453 |
+
auto transform_a = [&](int i) {
|
454 |
+
int row = i / a_gl_rd_delta_o;
|
455 |
+
return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
|
456 |
+
};
|
457 |
+
// Since the computation of this remapping is non-trivial and, due to our main
|
458 |
+
// loop unrolls, all shared memory accesses are static, we simply precompute
|
459 |
+
// both transformed reads and writes.
|
460 |
+
int a_sh_wr_trans[a_sh_wr_iters];
|
461 |
+
#pragma unroll
|
462 |
+
for (int i = 0; i < a_sh_wr_iters; i++)
|
463 |
+
a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
|
464 |
+
int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
|
465 |
+
#pragma unroll
|
466 |
+
for (int i = 0; i < b_sh_wr_iters; i++) {
|
467 |
+
#pragma unroll
|
468 |
+
for (int j = 0; j < thread_m_blocks; j++)
|
469 |
+
a_sh_rd_trans[i][j] =
|
470 |
+
transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
|
471 |
+
}
|
472 |
+
|
473 |
+
// Since B-accesses have non-constant stride they have to be computed at
|
474 |
+
// runtime; we break dependencies between subsequent accesses with a tile by
|
475 |
+
// maintaining multiple pointers (we have enough registers), a tiny
|
476 |
+
// optimization.
|
477 |
+
const int4* B_ptr[b_sh_wr_iters];
|
478 |
+
#pragma unroll
|
479 |
+
for (int i = 0; i < b_sh_wr_iters; i++)
|
480 |
+
B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
|
481 |
+
|
482 |
+
extern __shared__ int4 sh[];
|
483 |
+
// Shared memory storage for global fetch pipelines.
|
484 |
+
int4* sh_a = sh;
|
485 |
+
int4* sh_b = sh_a + (stages * a_sh_stage);
|
486 |
+
int4* sh_g_idx = sh_b + (stages * b_sh_stage);
|
487 |
+
int4* sh_s = sh_g_idx + (stages * g_idx_stage);
|
488 |
+
|
489 |
+
// Register storage for double buffer of shared memory reads.
|
490 |
+
FragA frag_a[2][thread_m_blocks];
|
491 |
+
I4 frag_b_quant[2][b_thread_vecs];
|
492 |
+
FragC frag_c[thread_m_blocks][4][2];
|
493 |
+
FragS frag_s[2][4];
|
494 |
+
|
495 |
+
// Zero accumulators.
|
496 |
+
auto zero_accums = [&]() {
|
497 |
+
#pragma unroll
|
498 |
+
for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
|
499 |
+
reinterpret_cast<float*>(frag_c)[i] = 0;
|
500 |
+
};
|
501 |
+
|
502 |
+
int sh_first_group_id = -1;
|
503 |
+
int sh_num_groups = -1;
|
504 |
+
constexpr int sh_max_num_groups = 32;
|
505 |
+
|
506 |
+
auto fetch_scales_to_shared = [&](bool is_async, int first_group_id,
|
507 |
+
int last_group_id) {
|
508 |
+
sh_first_group_id = first_group_id;
|
509 |
+
sh_num_groups = last_group_id - first_group_id + 1;
|
510 |
+
|
511 |
+
if (sh_num_groups < sh_max_num_groups) {
|
512 |
+
sh_num_groups = sh_max_num_groups;
|
513 |
+
}
|
514 |
+
|
515 |
+
if (sh_first_group_id + sh_num_groups > num_groups) {
|
516 |
+
sh_num_groups = num_groups - sh_first_group_id;
|
517 |
+
}
|
518 |
+
|
519 |
+
int row_offset = first_group_id * s_gl_stride;
|
520 |
+
|
521 |
+
if (is_async) {
|
522 |
+
for (int i = 0; i < sh_num_groups; i++) {
|
523 |
+
if (threadIdx.x < s_sh_stride) {
|
524 |
+
cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x],
|
525 |
+
&scales_ptr[row_offset + (i * s_gl_stride) +
|
526 |
+
slice_n_offset + threadIdx.x]);
|
527 |
+
}
|
528 |
+
}
|
529 |
+
} else {
|
530 |
+
for (int i = 0; i < sh_num_groups; i++) {
|
531 |
+
if (threadIdx.x < s_sh_stride) {
|
532 |
+
sh_s[(i * s_sh_stride) + threadIdx.x] =
|
533 |
+
scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset +
|
534 |
+
threadIdx.x];
|
535 |
+
}
|
536 |
+
}
|
537 |
+
}
|
538 |
+
};
|
539 |
+
// Asynchronously fetch the next A, B and s tile from global to the next
|
540 |
+
// shared memory pipeline location.
|
541 |
+
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
|
542 |
+
if (pred) {
|
543 |
+
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
|
544 |
+
#pragma unroll
|
545 |
+
for (int i = 0; i < a_sh_wr_iters; i++) {
|
546 |
+
cp_async4_pred(
|
547 |
+
&sh_a_stage[a_sh_wr_trans[i]],
|
548 |
+
&A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
|
549 |
+
a_sh_wr_pred[i]);
|
550 |
+
}
|
551 |
+
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
|
552 |
+
#pragma unroll
|
553 |
+
for (int i = 0; i < b_sh_wr_iters; i++) {
|
554 |
+
#pragma unroll
|
555 |
+
for (int j = 0; j < b_thread_vecs; j++) {
|
556 |
+
cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);
|
557 |
+
}
|
558 |
+
|
559 |
+
B_ptr[i] += b_gl_rd_delta_o;
|
560 |
+
}
|
561 |
+
}
|
562 |
+
// Insert a fence even when we are winding down the pipeline to ensure that
|
563 |
+
// waiting is also correct at this point.
|
564 |
+
cp_async_fence();
|
565 |
+
};
|
566 |
+
|
567 |
+
// Wait until the next thread tile has been loaded to shared memory.
|
568 |
+
auto wait_for_stage = [&]() {
|
569 |
+
// We only have `stages - 2` active fetches since we are double buffering
|
570 |
+
// and can only issue the next fetch when it is guaranteed that the previous
|
571 |
+
// shared memory load is fully complete (as it may otherwise be
|
572 |
+
// overwritten).
|
573 |
+
cp_async_wait<stages - 2>();
|
574 |
+
__syncthreads();
|
575 |
+
};
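The `stages - 2` wait above is the usual cp.async group-counting discipline: with double buffering, a buffer may only be reused once every group except the most recent `stages - 2` has completed. A hedged sketch of the same idea using CUDA's generic pipeline primitives (cuda_pipeline.h, compute capability 8.0+ for truly asynchronous copies) is shown below; STAGES, TILE_INT4S and the kernel name are illustrative and not taken from this file.

#include <cuda_pipeline.h>

// Launch with dynamic shared memory of STAGES * TILE_INT4S * sizeof(int4).
template <int STAGES, int TILE_INT4S>
__global__ void staged_copy(const int4* __restrict__ gmem, int n_tiles) {
  extern __shared__ int4 smem[];

  // Prefill STAGES - 1 buffers, committing one async group per buffer.
  for (int s = 0; s < STAGES - 1 && s < n_tiles; ++s) {
    for (int i = threadIdx.x; i < TILE_INT4S; i += blockDim.x)
      __pipeline_memcpy_async(&smem[s * TILE_INT4S + i],
                              &gmem[s * TILE_INT4S + i], sizeof(int4));
    __pipeline_commit();
  }

  for (int t = 0; t < n_tiles; ++t) {
    // Equivalent of cp_async_wait<STAGES - 2>(): allow at most STAGES - 2
    // groups to stay in flight, so the buffer about to be read is complete.
    __pipeline_wait_prior(STAGES - 2);
    __syncthreads();

    // ... consume smem[(t % STAGES) * TILE_INT4S ...] here ...

    int next = t + STAGES - 1;
    if (next < n_tiles) {
      int dst = next % STAGES;
      for (int i = threadIdx.x; i < TILE_INT4S; i += blockDim.x)
        __pipeline_memcpy_async(&smem[dst * TILE_INT4S + i],
                                &gmem[next * TILE_INT4S + i], sizeof(int4));
    }
    // Commit even when nothing was issued, mirroring the wind-down fence in
    // fetch_to_shared() so the outstanding-group count stays consistent.
    __pipeline_commit();
    __syncthreads();
  }
}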
|
576 |
+
|
577 |
+
// Load the next sub-tile from the current location in the shared memory pipe
|
578 |
+
// into the current register buffer.
|
579 |
+
auto fetch_to_registers = [&](int k, int pipe) {
|
580 |
+
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
|
581 |
+
#pragma unroll
|
582 |
+
for (int i = 0; i < thread_m_blocks; i++)
|
583 |
+
ldsm4<scalar_t>(frag_a[k % 2][i],
|
584 |
+
&sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
|
585 |
+
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
|
586 |
+
|
587 |
+
#pragma unroll
|
588 |
+
for (int i = 0; i < b_thread_vecs; i++) {
|
589 |
+
frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>(
|
590 |
+
&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);
|
591 |
+
}
|
592 |
+
};
|
593 |
+
|
594 |
+
bool is_same_group[stages];
|
595 |
+
int same_group_id[stages];
|
596 |
+
|
597 |
+
auto init_same_group = [&](int pipe) {
|
598 |
+
is_same_group[pipe] = false;
|
599 |
+
same_group_id[pipe] = 0;
|
600 |
+
return;
|
601 |
+
};
|
602 |
+
|
603 |
+
// Execute the actual tensor core matmul of a sub-tile.
|
604 |
+
auto matmul = [&](int k) {
|
605 |
+
// We have the m dimension as the inner loop in order to encourage overlapping
|
606 |
+
// dequantization and matmul operations.
|
607 |
+
#pragma unroll
|
608 |
+
for (int j = 0; j < 4; j++) {
|
609 |
+
FragB frag_b0;
|
610 |
+
FragB frag_b1;
|
611 |
+
|
612 |
+
int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);
|
613 |
+
int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
|
614 |
+
int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
|
615 |
+
|
616 |
+
frag_b0 = dequant_8bit<scalar_t>(b_quant_0);
|
617 |
+
frag_b1 = dequant_8bit<scalar_t>(b_quant_1);
|
618 |
+
|
619 |
+
#pragma unroll
|
620 |
+
for (int i = 0; i < thread_m_blocks; i++) {
|
621 |
+
mma<scalar_t>(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
|
622 |
+
mma<scalar_t>(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
|
623 |
+
}
|
624 |
+
}
|
625 |
+
};
|
626 |
+
|
627 |
+
// Since we slice across the k dimension of a tile in order to increase the
|
628 |
+
// number of warps while keeping the n dimension of a tile reasonable, we have
|
629 |
+
// multiple warps that accumulate their partial sums of the same output
|
630 |
+
// location, which we have to reduce over in the end. We do this in shared memory.
|
631 |
+
auto thread_block_reduce = [&]() {
|
632 |
+
constexpr int red_off = threads / b_sh_stride_threads / 2;
|
633 |
+
if (red_off >= 1) {
|
634 |
+
int red_idx = threadIdx.x / b_sh_stride_threads;
|
635 |
+
constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
|
636 |
+
constexpr int red_sh_delta = b_sh_stride_threads;
|
637 |
+
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
|
638 |
+
(threadIdx.x % b_sh_stride_threads);
|
639 |
+
|
640 |
+
// Parallel logarithmic shared memory reduction. We make sure to avoid any
|
641 |
+
// unnecessary read or write iterations, e.g., for two warps we write only
|
642 |
+
// once by warp 1 and read only once by warp 0.
|
643 |
+
|
644 |
+
#pragma unroll
|
645 |
+
for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
|
646 |
+
#pragma unroll
|
647 |
+
for (int i = red_off; i > 0; i /= 2) {
|
648 |
+
if (i <= red_idx && red_idx < 2 * i) {
|
649 |
+
#pragma unroll
|
650 |
+
for (int j = 0; j < 4 * 2; j++) {
|
651 |
+
int red_sh_wr =
|
652 |
+
red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
|
653 |
+
if (i < red_off) {
|
654 |
+
float* c_rd =
|
655 |
+
reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
|
656 |
+
float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
|
657 |
+
#pragma unroll
|
658 |
+
for (int k = 0; k < 4; k++)
|
659 |
+
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
|
660 |
+
c_rd[k] + c_wr[k];
|
661 |
+
}
|
662 |
+
sh[red_sh_wr] =
|
663 |
+
reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
|
664 |
+
}
|
665 |
+
}
|
666 |
+
__syncthreads();
|
667 |
+
}
|
668 |
+
if (red_idx == 0) {
|
669 |
+
#pragma unroll
|
670 |
+
for (int i = 0; i < 4 * 2; i++) {
|
671 |
+
float* c_rd =
|
672 |
+
reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
|
673 |
+
#pragma unroll
|
674 |
+
for (int j = 0; j < 4; j++)
|
675 |
+
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
|
676 |
+
c_rd[j];
|
677 |
+
}
|
678 |
+
}
|
679 |
+
__syncthreads();
|
680 |
+
}
|
681 |
+
}
|
682 |
+
};
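For readers unfamiliar with the pattern, the lambda above is a tree reduction over warps that accumulate the same output columns, halving the number of active warps each step. A hedged, much-simplified sketch of the same idea (one float per column, none of the fragment bookkeeping) looks like this:

#include <cuda_runtime.h>

// Launch with NUM_WARPS * 32 threads per block; COLS <= 32 is assumed.
// partials holds NUM_WARPS * COLS floats, one partial sum per (warp, column).
template <int NUM_WARPS, int COLS>
__global__ void warp_partial_reduce(const float* __restrict__ partials,
                                    float* __restrict__ out) {
  __shared__ float red[NUM_WARPS][COLS];
  int warp = threadIdx.x / 32;
  int lane = threadIdx.x % 32;

  if (lane < COLS) red[warp][lane] = partials[warp * COLS + lane];
  __syncthreads();

  // Logarithmic reduction: each step folds the upper half of the surviving
  // warps into the lower half, one read and one write per partial.
  for (int off = NUM_WARPS / 2; off > 0; off /= 2) {
    if (warp < off && lane < COLS) red[warp][lane] += red[warp + off][lane];
    __syncthreads();
  }
  if (warp == 0 && lane < COLS) out[lane] = red[0][lane];
}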
|
683 |
+
|
684 |
+
// Since multiple threadblocks may process parts of the same column slice, we
|
685 |
+
// finally have to globally reduce over the results. As the striped
|
686 |
+
// partitioning minimizes the number of such reductions and our outputs are
|
687 |
+
// usually rather small, we perform this reduction serially in L2 cache.
|
688 |
+
auto global_reduce = [&](bool first = false, bool last = false) {
|
689 |
+
// We are very careful here to reduce directly in the output buffer to
|
690 |
+
// maximize L2 cache utilization in this step. To do this, we write out
|
691 |
+
// results in FP16 (but still reduce with FP32 compute).
|
692 |
+
constexpr int active_threads = 32 * thread_n_blocks / 4;
|
693 |
+
if (threadIdx.x < active_threads) {
|
694 |
+
int c_gl_stride = prob_n / 8;
|
695 |
+
int c_gl_wr_delta_o = 8 * c_gl_stride;
|
696 |
+
int c_gl_wr_delta_i = 4 * (active_threads / 32);
|
697 |
+
int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +
|
698 |
+
4 * (threadIdx.x / 32) + threadIdx.x % 4;
|
699 |
+
c_gl_wr += (2 * thread_n_blocks) * slice_col;
|
700 |
+
constexpr int c_sh_wr_delta = active_threads;
|
701 |
+
int c_sh_wr = threadIdx.x;
|
702 |
+
|
703 |
+
int row = (threadIdx.x % 32) / 4;
|
704 |
+
|
705 |
+
if (!first) {
|
706 |
+
// Interestingly, doing direct global accesses here really seems to mess up
|
707 |
+
// the compiler and lead to slowdowns, hence we also use async-copies even
|
708 |
+
// though these fetches are not actually asynchronous.
|
709 |
+
#pragma unroll
|
710 |
+
for (int i = 0; i < thread_m_blocks * 4; i++) {
|
711 |
+
cp_async4_pred(
|
712 |
+
&sh[c_sh_wr + c_sh_wr_delta * i],
|
713 |
+
&C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
|
714 |
+
c_gl_wr_delta_i * (i % 2)],
|
715 |
+
i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
|
716 |
+
}
|
717 |
+
cp_async_fence();
|
718 |
+
cp_async_wait<0>();
|
719 |
+
}
|
720 |
+
|
721 |
+
#pragma unroll
|
722 |
+
for (int i = 0; i < thread_m_blocks * 4; i++) {
|
723 |
+
if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
|
724 |
+
if (!first) {
|
725 |
+
int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
|
726 |
+
#pragma unroll
|
727 |
+
for (int j = 0; j < 2 * 4; j++) {
|
728 |
+
reinterpret_cast<float*>(
|
729 |
+
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
|
730 |
+
Dtype::num2float(reinterpret_cast<scalar_t*>(&c_red)[j]);
|
731 |
+
}
|
732 |
+
}
|
733 |
+
if (!last) {
|
734 |
+
int4 c;
|
735 |
+
#pragma unroll
|
736 |
+
for (int j = 0; j < 2 * 4; j++) {
|
737 |
+
reinterpret_cast<scalar_t*>(&c)[j] =
|
738 |
+
Dtype::float2num(reinterpret_cast<float*>(
|
739 |
+
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);
|
740 |
+
}
|
741 |
+
C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =
|
742 |
+
c;
|
743 |
+
}
|
744 |
+
}
|
745 |
+
}
|
746 |
+
}
|
747 |
+
};
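The essential trick in global_reduce is the storage/compute split: partial results travel through the half-precision output buffer, but every addition happens in fp32. A hedged stand-alone illustration, without the locking, striping and async copies of the real kernel:

#include <cuda_fp16.h>

// One block accumulates its fp32 partials into the fp16 output in place.
// Storage stays fp16 (small, L2 friendly); arithmetic stays fp32 so rounding
// error does not compound across repeated reductions.
__global__ void accumulate_into_fp16(half* __restrict__ C,
                                     const float* __restrict__ partial,
                                     int n, bool first) {
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    float acc = first ? 0.0f : __half2float(C[i]);
    C[i] = __float2half(acc + partial[i]);
  }
}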
|
748 |
+
|
749 |
+
// Write out the reduced final result in the correct layout. We only actually
|
750 |
+
// reshuffle matrix fragments in this step, the reduction above is performed
|
751 |
+
// in fragment layout.
|
752 |
+
auto write_result = [&]() {
|
753 |
+
int c_gl_stride = prob_n / 8;
|
754 |
+
constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
|
755 |
+
int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
|
756 |
+
constexpr int c_sh_rd_delta =
|
757 |
+
c_sh_stride * (threads / (2 * thread_n_blocks));
|
758 |
+
|
759 |
+
int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +
|
760 |
+
(threadIdx.x % (2 * thread_n_blocks));
|
761 |
+
c_gl_wr += (2 * thread_n_blocks) * slice_col;
|
762 |
+
int c_sh_wr =
|
763 |
+
(4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
|
764 |
+
c_sh_wr += 32 * (threadIdx.x / 32);
|
765 |
+
int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +
|
766 |
+
(threadIdx.x % (2 * thread_n_blocks));
|
767 |
+
|
768 |
+
int c_gl_wr_end = c_gl_stride * prob_m;
|
769 |
+
|
770 |
+
// We first reorder in shared memory to guarantee the most efficient final
|
771 |
+
// global write patterns
|
772 |
+
auto write = [&](int idx, float c0, float c1, FragS& s) {
|
773 |
+
scalar_t2 res =
|
774 |
+
Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));
|
775 |
+
|
776 |
+
((scalar_t2*)sh)[idx] = res;
|
777 |
+
};
|
778 |
+
|
779 |
+
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
780 |
+
#pragma unroll
|
781 |
+
for (int i = 0; i < thread_m_blocks; i++) {
|
782 |
+
#pragma unroll
|
783 |
+
for (int j = 0; j < 4; j++) {
|
784 |
+
int wr = c_sh_wr + 8 * j;
|
785 |
+
write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
|
786 |
+
frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
|
787 |
+
write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
|
788 |
+
frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
|
789 |
+
write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
|
790 |
+
frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
|
791 |
+
write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
|
792 |
+
frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
|
793 |
+
}
|
794 |
+
c_sh_wr += 16 * (4 * c_sh_stride);
|
795 |
+
}
|
796 |
+
}
|
797 |
+
__syncthreads();
|
798 |
+
|
799 |
+
#pragma unroll
|
800 |
+
for (int i = 0;
|
801 |
+
i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
|
802 |
+
i++) {
|
803 |
+
if (c_gl_wr < c_gl_wr_end) {
|
804 |
+
C[c_gl_wr] = sh[c_sh_rd];
|
805 |
+
c_gl_wr += c_gl_wr_delta;
|
806 |
+
c_sh_rd += c_sh_rd_delta;
|
807 |
+
}
|
808 |
+
}
|
809 |
+
};
|
810 |
+
|
811 |
+
// Start global fetch and register load pipelines.
|
812 |
+
auto start_pipes = [&]() {
|
813 |
+
|
814 |
+
#pragma unroll
|
815 |
+
for (int i = 0; i < stages - 1; i++) {
|
816 |
+
fetch_to_shared(i, i, i < slice_iters);
|
817 |
+
}
|
818 |
+
|
819 |
+
zero_accums();
|
820 |
+
wait_for_stage();
|
821 |
+
init_same_group(0);
|
822 |
+
fetch_to_registers(0, 0);
|
823 |
+
a_gl_rd += a_gl_rd_delta_o * (stages - 1);
|
824 |
+
slice_k_start_shared_fetch += tb_k * (stages - 1);
|
825 |
+
};
|
826 |
+
if (slice_iters) {
|
827 |
+
start_pipes();
|
828 |
+
}
|
829 |
+
|
830 |
+
// Main loop.
|
831 |
+
while (slice_iters) {
|
832 |
+
// We unroll over both the global fetch and the register load pipeline to
|
833 |
+
// ensure all shared memory accesses are static. Note that both pipelines
|
834 |
+
// have even length meaning that the next iteration will always start at
|
835 |
+
// index 0.
|
836 |
+
|
837 |
+
#pragma unroll
|
838 |
+
for (int pipe = 0; pipe < stages;) {
|
839 |
+
#pragma unroll
|
840 |
+
for (int k = 0; k < b_sh_wr_iters; k++) {
|
841 |
+
fetch_to_registers(k + 1, pipe % stages);
|
842 |
+
if (k == b_sh_wr_iters - 2) {
|
843 |
+
fetch_to_shared((pipe + stages - 1) % stages, pipe,
|
844 |
+
slice_iters >= stages);
|
845 |
+
pipe++;
|
846 |
+
wait_for_stage();
|
847 |
+
init_same_group(pipe % stages);
|
848 |
+
}
|
849 |
+
matmul(k);
|
850 |
+
}
|
851 |
+
slice_iters--;
|
852 |
+
if (slice_iters == 0) {
|
853 |
+
break;
|
854 |
+
}
|
855 |
+
}
|
856 |
+
|
857 |
+
a_gl_rd += a_gl_rd_delta_o * stages;
|
858 |
+
slice_k_start += tb_k * stages;
|
859 |
+
slice_k_start_shared_fetch += tb_k * stages;
|
860 |
+
|
861 |
+
// Process results and, if necessary, proceed to the next column slice.
|
862 |
+
// While this pattern may not be the most readable, other ways of writing
|
863 |
+
// the loop seemed to noticeably worsen performance after compilation.
|
864 |
+
if (slice_iters == 0) {
|
865 |
+
cp_async_wait<0>();
|
866 |
+
bool last = slice_idx == slice_count - 1;
|
867 |
+
// For per-column scales, we only fetch them here in the final step before
|
868 |
+
// write-out
|
869 |
+
if (s_sh_wr_pred) {
|
870 |
+
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
|
871 |
+
}
|
872 |
+
cp_async_fence();
|
873 |
+
|
874 |
+
thread_block_reduce();
|
875 |
+
|
876 |
+
cp_async_wait<0>();
|
877 |
+
__syncthreads();
|
878 |
+
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
879 |
+
reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
|
880 |
+
reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
|
881 |
+
}
|
882 |
+
|
883 |
+
// For 8-bit channelwise, we apply the scale before the global reduction
|
884 |
+
// that converts the fp32 results to fp16 (so that we avoid possible
|
885 |
+
// overflow in fp16)
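// Hedged numeric illustration (not from the original source): with fp32
// accumulators around 7.0e4 and a channel scale of 2^-4, scaling first gives
// about 4.4e3, which converts to fp16 safely; converting the raw 7.0e4 first
// would overflow fp16's ~65504 maximum.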
|
886 |
+
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
887 |
+
#pragma unroll
|
888 |
+
for (int i = 0; i < thread_m_blocks; i++) {
|
889 |
+
#pragma unroll
|
890 |
+
for (int j = 0; j < 4; j++) {
|
891 |
+
scale_float<scalar_t>(reinterpret_cast<float*>(&frag_c[i][j][0][0]),
|
892 |
+
frag_s[j / 2][2 * (j % 2) + 0]);
|
893 |
+
scale_float<scalar_t>(reinterpret_cast<float*>(&frag_c[i][j][0][2]),
|
894 |
+
frag_s[j / 2][2 * (j % 2) + 0]);
|
895 |
+
|
896 |
+
scale_float<scalar_t>(reinterpret_cast<float*>(&frag_c[i][j][1][0]),
|
897 |
+
frag_s[j / 2][2 * (j % 2) + 1]);
|
898 |
+
scale_float<scalar_t>(reinterpret_cast<float*>(&frag_c[i][j][1][2]),
|
899 |
+
frag_s[j / 2][2 * (j % 2) + 1]);
|
900 |
+
}
|
901 |
+
}
|
902 |
+
}
|
903 |
+
|
904 |
+
if (slice_count > 1) { // only globally reduce if there is more than one
|
905 |
+
// block in a slice
|
906 |
+
barrier_acquire(&locks[slice_col], slice_idx);
|
907 |
+
global_reduce(slice_idx == 0, last);
|
908 |
+
barrier_release(&locks[slice_col], last);
|
909 |
+
}
|
910 |
+
if (last) // only the last block in a slice actually writes the result
|
911 |
+
write_result();
|
912 |
+
slice_row = 0;
|
913 |
+
slice_col_par++;
|
914 |
+
slice_col++;
|
915 |
+
init_slice();
|
916 |
+
if (slice_iters) {
|
917 |
+
a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
|
918 |
+
(threadIdx.x % a_gl_rd_delta_o);
|
919 |
+
#pragma unroll
|
920 |
+
for (int i = 0; i < b_sh_wr_iters; i++)
|
921 |
+
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
|
922 |
+
if (slice_col == 0) {
|
923 |
+
#pragma unroll
|
924 |
+
for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
|
925 |
+
}
|
926 |
+
|
927 |
+
// Update slice k/n for scales loading
|
928 |
+
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
|
929 |
+
|
930 |
+
start_pipes();
|
931 |
+
}
|
932 |
+
}
|
933 |
+
}
|
934 |
+
}
|
935 |
+
|
936 |
+
#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
|
937 |
+
THREAD_K_BLOCKS, GROUP_BLOCKS, NUM_THREADS) \
|
938 |
+
else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
|
939 |
+
thread_n_blocks == THREAD_N_BLOCKS && \
|
940 |
+
thread_k_blocks == THREAD_K_BLOCKS && \
|
941 |
+
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
|
942 |
+
cudaFuncSetAttribute( \
|
943 |
+
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
|
944 |
+
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, GROUP_BLOCKS>, \
|
945 |
+
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
946 |
+
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
|
947 |
+
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, GROUP_BLOCKS> \
|
948 |
+
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
|
949 |
+
A_ptr, B_ptr, C_ptr, s_ptr, num_groups, prob_m, prob_n, prob_k, \
|
950 |
+
locks); \
|
951 |
+
}
|
952 |
+
|
953 |
+
typedef struct {
|
954 |
+
int thread_k;
|
955 |
+
int thread_n;
|
956 |
+
int num_threads;
|
957 |
+
} thread_config_t;
|
958 |
+
|
959 |
+
typedef struct {
|
960 |
+
int max_m_blocks;
|
961 |
+
thread_config_t tb_cfg;
|
962 |
+
} exec_config_t;
|
963 |
+
|
964 |
+
thread_config_t small_batch_thread_configs[] = {
|
965 |
+
// Ordered by priority
|
966 |
+
|
967 |
+
// thread_k, thread_n, num_threads
|
968 |
+
{128, 128, 256},
|
969 |
+
{64, 128, 128},
|
970 |
+
{128, 64, 128},
|
971 |
+
};
|
972 |
+
|
973 |
+
thread_config_t large_batch_thread_configs[] = {
|
974 |
+
// Ordered by priority
|
975 |
+
|
976 |
+
// thread_k, thread_n, num_threads
|
977 |
+
{64, 256, 256},
|
978 |
+
{64, 128, 128},
|
979 |
+
{128, 64, 128},
|
980 |
+
|
981 |
+
};
|
982 |
+
|
983 |
+
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
|
984 |
+
int prob_n, int prob_k, int num_bits,
|
985 |
+
int group_size) {
|
986 |
+
int tb_n = th_config.thread_n;
|
987 |
+
|
988 |
+
// Get max scale groups per thread-block
|
989 |
+
// Fixed for channelwise
|
990 |
+
int tb_groups = 1;
|
991 |
+
int tb_scales = tb_groups * tb_n * 2;
|
992 |
+
|
993 |
+
return tb_scales * pipe_stages;
|
994 |
+
}
|
995 |
+
|
996 |
+
bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
|
997 |
+
int prob_m, int prob_n, int prob_k, int num_bits,
|
998 |
+
int scales_cache_size, int max_shared_mem) {
|
999 |
+
int pack_factor = 32 / num_bits;
|
1000 |
+
|
1001 |
+
// Get B size
|
1002 |
+
int tb_k = th_config.thread_k;
|
1003 |
+
int tb_n = th_config.thread_n;
|
1004 |
+
|
1005 |
+
int b_size = (tb_k * tb_n / pack_factor) * 4;
|
1006 |
+
|
1007 |
+
// Get A size
|
1008 |
+
int m_blocks = div_ceil(prob_m, 16);
|
1009 |
+
int tb_max_m = 16;
|
1010 |
+
|
1011 |
+
while (true) {
|
1012 |
+
if (m_blocks >= max_m_blocks) {
|
1013 |
+
tb_max_m *= max_m_blocks;
|
1014 |
+
break;
|
1015 |
+
}
|
1016 |
+
|
1017 |
+
max_m_blocks--;
|
1018 |
+
if (max_m_blocks == 0) {
|
1019 |
+
TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks);
|
1020 |
+
}
|
1021 |
+
}
|
1022 |
+
|
1023 |
+
int a_size = (tb_max_m * tb_k) * 2;
|
1024 |
+
|
1025 |
+
float pipe_size = (a_size + b_size) * pipe_stages;
|
1026 |
+
|
1027 |
+
TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity
|
1028 |
+
|
1029 |
+
return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);
|
1030 |
+
}
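As a concrete, hedged check of this budget: take the {thread_k = 64, thread_n = 256, num_threads = 256} config with num_bits = 8, max_m_blocks = 4 and pipe_stages = 4, and assume roughly 164 KB of opt-in shared memory (an A100-class figure, not stated in this file). The snippet below just replays the formulas above and confirms the pipeline fits.

#include <cstdio>

int main() {
  const int pipe_stages = 4, num_bits = 8;
  const int thread_k = 64, thread_n = 256, max_m_blocks = 4;
  const int max_shared_mem = 164 * 1024;                  // assumed device limit

  int pack_factor = 32 / num_bits;                        // 4 weights per int32
  int b_size = (thread_k * thread_n / pack_factor) * 4;   // 16384 B per stage
  int a_size = (16 * max_m_blocks * thread_k) * 2;        // 8192 B per stage
  int scales = 1 * thread_n * 2 * pipe_stages;            // 2048 B (channelwise)
  float pipe_size = float(a_size + b_size) * pipe_stages; // 98304 B

  printf("pipe = %.0f B, budget = %.0f B, fits = %d\n", pipe_size,
         0.95f * (max_shared_mem - scales),
         pipe_size < 0.95f * (max_shared_mem - scales));
  return 0;
}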
|
1031 |
+
|
1032 |
+
bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
|
1033 |
+
int prob_m, int prob_n, int prob_k, int num_bits,
|
1034 |
+
int group_size, int max_shared_mem) {
|
1035 |
+
// Sanity
|
1036 |
+
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
|
1037 |
+
th_config.num_threads == -1) {
|
1038 |
+
return false;
|
1039 |
+
}
|
1040 |
+
|
1041 |
+
// Verify K/N are divisible by thread K/N
|
1042 |
+
if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
|
1043 |
+
return false;
|
1044 |
+
}
|
1045 |
+
|
1046 |
+
// Verify min for thread K/N
|
1047 |
+
if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
|
1048 |
+
return false;
|
1049 |
+
}
|
1050 |
+
|
1051 |
+
// num_threads must be at least 128 (= 4 warps)
|
1052 |
+
if (th_config.num_threads < 128) {
|
1053 |
+
return false;
|
1054 |
+
}
|
1055 |
+
|
1056 |
+
// Determine cache for scales
|
1057 |
+
int scales_cache_size = get_scales_cache_size(th_config, prob_m, prob_n,
|
1058 |
+
prob_k, num_bits, group_size);
|
1059 |
+
|
1060 |
+
// Check that pipeline fits into cache
|
1061 |
+
if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k,
|
1062 |
+
num_bits, scales_cache_size, max_shared_mem)) {
|
1063 |
+
return false;
|
1064 |
+
}
|
1065 |
+
|
1066 |
+
return true;
|
1067 |
+
}
|
1068 |
+
|
1069 |
+
exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
|
1070 |
+
int num_bits, int group_size,
|
1071 |
+
int max_shared_mem) {
|
1072 |
+
int max_m_blocks = 4;
|
1073 |
+
while (max_m_blocks > 0) {
|
1074 |
+
if (prob_m <= 16) {
|
1075 |
+
for (auto th_config : small_batch_thread_configs) {
|
1076 |
+
if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
|
1077 |
+
num_bits, group_size, max_shared_mem)) {
|
1078 |
+
return exec_config_t{max_m_blocks, th_config};
|
1079 |
+
}
|
1080 |
+
}
|
1081 |
+
} else {
|
1082 |
+
for (auto th_config : large_batch_thread_configs) {
|
1083 |
+
if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
|
1084 |
+
num_bits, group_size, max_shared_mem)) {
|
1085 |
+
return exec_config_t{max_m_blocks, th_config};
|
1086 |
+
}
|
1087 |
+
}
|
1088 |
+
}
|
1089 |
+
|
1090 |
+
max_m_blocks--;  // Process fewer M blocks per invocation to reduce cache
|
1091 |
+
// usage
|
1092 |
+
}
|
1093 |
+
|
1094 |
+
return exec_config_t{0, {-1, -1, -1}};
|
1095 |
+
}
|
1096 |
+
|
1097 |
+
#define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
1098 |
+
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
|
1099 |
+
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
|
1100 |
+
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
|
1101 |
+
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS)
|
1102 |
+
|
1103 |
+
template <typename scalar_t>
|
1104 |
+
void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, int prob_m,
|
1105 |
+
int prob_n, int prob_k, void* workspace, int num_bits,
|
1106 |
+
int num_groups, int group_size, int dev,
|
1107 |
+
cudaStream_t stream, int thread_k, int thread_n, int sms,
|
1108 |
+
int max_par) {
|
1109 |
+
TORCH_CHECK(num_bits == 8, "num_bits must be 8. Got = ", num_bits);
|
1110 |
+
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
|
1111 |
+
", ", prob_n, ", ", prob_k, "]");
|
1112 |
+
|
1113 |
+
int tot_m = prob_m;
|
1114 |
+
int tot_m_blocks = div_ceil(tot_m, 16);
|
1115 |
+
int pad = 16 * tot_m_blocks - tot_m;
|
1116 |
+
|
1117 |
+
if (sms == -1) {
|
1118 |
+
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
|
1119 |
+
}
|
1120 |
+
|
1121 |
+
int max_shared_mem = 0;
|
1122 |
+
cudaDeviceGetAttribute(&max_shared_mem,
|
1123 |
+
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
|
1124 |
+
TORCH_CHECK(max_shared_mem > 0);
|
1125 |
+
|
1126 |
+
// Set thread config
|
1127 |
+
exec_config_t exec_cfg;
|
1128 |
+
if (thread_k != -1 && thread_n != -1) {
|
1129 |
+
// User-defined config
|
1130 |
+
exec_cfg =
|
1131 |
+
exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}};
|
1132 |
+
} else {
|
1133 |
+
// Auto config
|
1134 |
+
exec_cfg = determine_thread_config(prob_m, prob_n, prob_k, num_bits,
|
1135 |
+
group_size, max_shared_mem);
|
1136 |
+
}
|
1137 |
+
|
1138 |
+
TORCH_CHECK(
|
1139 |
+
exec_cfg.max_m_blocks > 0 &&
|
1140 |
+
is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, prob_m,
|
1141 |
+
prob_n, prob_k, num_bits, group_size, max_shared_mem),
|
1142 |
+
"Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks,
|
1143 |
+
", thread_k = ", exec_cfg.tb_cfg.thread_k,
|
1144 |
+
", thread_n = ", exec_cfg.tb_cfg.thread_n,
|
1145 |
+
", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", prob_m,
|
1146 |
+
", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
|
1147 |
+
", group_size = ", group_size, ", max_shared_mem = ", max_shared_mem);
|
1148 |
+
|
1149 |
+
int num_threads = exec_cfg.tb_cfg.num_threads;
|
1150 |
+
thread_k = exec_cfg.tb_cfg.thread_k;
|
1151 |
+
thread_n = exec_cfg.tb_cfg.thread_n;
|
1152 |
+
|
1153 |
+
int thread_k_blocks = thread_k / 16;
|
1154 |
+
int thread_n_blocks = thread_n / 16;
|
1155 |
+
|
1156 |
+
int blocks = sms;
|
1157 |
+
|
1158 |
+
TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
|
1159 |
+
" is not divisible by thread_n = ", thread_n);
|
1160 |
+
TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
|
1161 |
+
" is not divisible by thread_k = ", thread_k);
|
1162 |
+
|
1163 |
+
int group_blocks = -1;
|
1164 |
+
|
1165 |
+
const int4* A_ptr = (const int4*)A;
|
1166 |
+
const int4* B_ptr = (const int4*)B;
|
1167 |
+
int4* C_ptr = (int4*)C;
|
1168 |
+
const int4* s_ptr = (const int4*)s;
|
1169 |
+
|
1170 |
+
int* locks = (int*)workspace;
|
1171 |
+
|
1172 |
+
// Main loop
|
1173 |
+
for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) {
|
1174 |
+
int thread_m_blocks = tot_m_blocks - i;
|
1175 |
+
prob_m = tot_m - 16 * i;
|
1176 |
+
int par = 1;
|
1177 |
+
if (thread_m_blocks > exec_cfg.max_m_blocks) {
|
1178 |
+
// Note that parallel > 1 currently only works for inputs without any
|
1179 |
+
// padding
|
1180 |
+
par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks);
|
1181 |
+
if (par > max_par) par = max_par;
|
1182 |
+
prob_m = (16 * exec_cfg.max_m_blocks) * par;
|
1183 |
+
i += exec_cfg.max_m_blocks * (par - 1);
|
1184 |
+
thread_m_blocks = exec_cfg.max_m_blocks;
|
1185 |
+
}
|
1186 |
+
|
1187 |
+
// Define kernel configurations
|
1188 |
+
if (false) {
|
1189 |
+
}
|
1190 |
+
CALL_IF(8, 32, 2, 256)
|
1191 |
+
CALL_IF(8, 16, 4, 256)
|
1192 |
+
CALL_IF(8, 8, 8, 256)
|
1193 |
+
CALL_IF(8, 8, 4, 128)
|
1194 |
+
CALL_IF(8, 4, 8, 128)
|
1195 |
+
else {
|
1196 |
+
TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
|
1197 |
+
str(prob_n) + ", " + str(prob_k) + "]" +
|
1198 |
+
", num_groups = " + str(num_groups) +
|
1199 |
+
", group_size = " + str(group_size) +
|
1200 |
+
", thread_m_blocks = " + str(thread_m_blocks) +
|
1201 |
+
", thread_n_blocks = " + str(thread_n_blocks) +
|
1202 |
+
", thread_k_blocks = " + str(thread_k_blocks));
|
1203 |
+
}
|
1204 |
+
|
1205 |
+
A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
|
1206 |
+
C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
|
1207 |
+
}
|
1208 |
+
}
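A hedged worked example of the m-partitioning in the loop above, for an illustrative prob_m = 200 with max_m_blocks = 4 and max_par = 16: tot_m_blocks = 13 and pad = 8, so the first launch covers par = 3 chunks of 64 rows (192 rows total) and a second launch handles the remaining 8 rows. The snippet below just replays that arithmetic on the CPU.

#include <cstdio>

static int div_ceil(int a, int b) { return (a + b - 1) / b; }

int main() {
  const int max_m_blocks = 4, max_par = 16;
  int tot_m = 200;                         // illustrative batch size
  int tot_m_blocks = div_ceil(tot_m, 16);  // 13
  int pad = 16 * tot_m_blocks - tot_m;     // 8

  for (int i = 0; i < tot_m_blocks; i += max_m_blocks) {
    int thread_m_blocks = tot_m_blocks - i;
    int prob_m = tot_m - 16 * i;
    int par = 1;
    if (thread_m_blocks > max_m_blocks) {
      par = (16 * thread_m_blocks - pad) / (16 * max_m_blocks);
      if (par > max_par) par = max_par;
      prob_m = 16 * max_m_blocks * par;
      i += max_m_blocks * (par - 1);
      thread_m_blocks = max_m_blocks;
    }
    printf("launch: rows = %d, thread_m_blocks = %d, par = %d\n", prob_m,
           thread_m_blocks, par);
  }
  return 0;
}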
|
1209 |
+
|
1210 |
+
} // namespace fp8_marlin
|
1211 |
+
|
1212 |
+
torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
1213 |
+
torch::Tensor& b_scales, torch::Tensor& workspace,
|
1214 |
+
int64_t num_bits, int64_t size_m, int64_t size_n,
|
1215 |
+
int64_t size_k) {
|
1216 |
+
// Verify num_bits
|
1217 |
+
TORCH_CHECK(num_bits == 8, "num_bits must be 8. Got = ", num_bits);
|
1218 |
+
int pack_factor = 32 / num_bits;
|
1219 |
+
|
1220 |
+
// Verify A
|
1221 |
+
TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
|
1222 |
+
", size_m = ", size_m);
|
1223 |
+
TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1),
|
1224 |
+
", size_k = ", size_k);
|
1225 |
+
|
1226 |
+
// Verify B
|
1227 |
+
TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k,
|
1228 |
+
" is not divisible by tile_size = ", marlin::tile_size);
|
1229 |
+
TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
|
1230 |
+
"Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
|
1231 |
+
", size_k = ", size_k, ", tile_size = ", marlin::tile_size);
|
1232 |
+
TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
|
1233 |
+
"b_q_weight.size(1) = ", b_q_weight.size(1),
|
1234 |
+
" is not divisible by tile_size = ", marlin::tile_size);
|
1235 |
+
int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor;
|
1236 |
+
TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
|
1237 |
+
", actual_size_n = ", actual_size_n);
|
1238 |
+
|
1239 |
+
// Verify device and strides
|
1240 |
+
TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
|
1241 |
+
TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
|
1242 |
+
|
1243 |
+
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
|
1244 |
+
TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
|
1245 |
+
|
1246 |
+
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
|
1247 |
+
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
|
1248 |
+
|
1249 |
+
// Alloc buffers
|
1250 |
+
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
|
1251 |
+
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
|
1252 |
+
torch::Tensor c = torch::empty({size_m, size_n}, options);
|
1253 |
+
|
1254 |
+
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
|
1255 |
+
// auto -1)
|
1256 |
+
int thread_k = -1;
|
1257 |
+
// thread_n: `n` size of a thread_tile in `weights` (can usually be left as
|
1258 |
+
// auto -1)
|
1259 |
+
int thread_n = -1;
|
1260 |
+
// sms: number of SMs to use for the kernel (can usually be left as auto -1)
|
1261 |
+
int sms = -1;
|
1262 |
+
|
1263 |
+
// Detect groupsize and act_order
|
1264 |
+
int num_groups = -1;
|
1265 |
+
int group_size = -1;
|
1266 |
+
|
1267 |
+
int b_rank = b_scales.sizes().size();
|
1268 |
+
TORCH_CHECK(b_rank == 2, "b_scales rank = ", b_rank, " is not 2");
|
1269 |
+
TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1),
|
1270 |
+
" is not size_n = ", size_n);
|
1271 |
+
// Channelwise only for FP8
|
1272 |
+
TORCH_CHECK(b_scales.size(0) == 1);
|
1273 |
+
num_groups = b_scales.size(0);
|
1274 |
+
|
1275 |
+
// Verify workspace size
|
1276 |
+
TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n,
|
1277 |
+
", is not divisible by min_thread_n = ", marlin::min_thread_n);
|
1278 |
+
int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
|
1279 |
+
TORCH_CHECK(workspace.numel() >= min_workspace_size,
|
1280 |
+
"workspace.numel = ", workspace.numel(),
|
1281 |
+
" is below min_workspace_size = ", min_workspace_size);
|
1282 |
+
|
1283 |
+
int dev = a.get_device();
|
1284 |
+
if (a.scalar_type() == at::ScalarType::Half) {
|
1285 |
+
fp8_marlin::marlin_mm_f16i4<half>(
|
1286 |
+
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
|
1287 |
+
b_scales.data_ptr<at::Half>(), size_m, size_n, size_k,
|
1288 |
+
workspace.data_ptr(), num_bits, num_groups, group_size, dev,
|
1289 |
+
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
|
1290 |
+
marlin::max_par);
|
1291 |
+
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
|
1292 |
+
fp8_marlin::marlin_mm_f16i4<nv_bfloat16>(
|
1293 |
+
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
|
1294 |
+
c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(), size_m,
|
1295 |
+
size_n, size_k, workspace.data_ptr(), num_bits, num_groups, group_size,
|
1296 |
+
dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
|
1297 |
+
marlin::max_par);
|
1298 |
+
} else {
|
1299 |
+
TORCH_CHECK(false, "fp8_marlin_gemm only supports bfloat16 and float16");
|
1300 |
+
}
|
1301 |
+
|
1302 |
+
return c;
|
1303 |
+
}
|
1304 |
+
|
1305 |
+
#endif
|
1306 |
+
|
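For reference, a hedged host-side sketch of calling the entry point above through the declaration in ext-torch/torch_binding.h. It only illustrates the shape and workspace contract that fp8_marlin_gemm checks; producing b_q_weight in the Marlin FP8 repacked layout (and the matching channelwise b_scales) is a separate step not shown in this diff, so the tensors here are placeholders rather than a numerically meaningful GEMM, and the helper name and include path are illustrative.

#include <torch/torch.h>
#include "torch_binding.h"  // declares fp8_marlin_gemm (path illustrative)

torch::Tensor run_fp8_marlin(torch::Tensor a,           // [size_m, size_k], fp16/bf16
                             torch::Tensor b_q_weight,  // [size_k/16, size_n*4], int32
                             torch::Tensor b_scales) {  // [1, size_n], dtype of a
  int64_t size_m = a.size(0);
  int64_t size_k = a.size(1);
  int64_t pack_factor = 32 / 8;  // num_bits == 8
  int64_t size_n = b_q_weight.size(1) / 16 * pack_factor;

  // Workspace: at least (size_n / min_thread_n) * max_par ints, zero-filled.
  auto workspace = torch::zeros({size_n / 64 * 16},
                                torch::dtype(torch::kInt).device(a.device()));
  return fp8_marlin_gemm(a, b_q_weight, b_scales, workspace,
                         /*num_bits=*/8, size_m, size_n, size_k);
}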
fp8/nvidia/quant_utils.cuh
ADDED
@@ -0,0 +1,573 @@
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include "../../../attention/attention_dtypes.h"
|
4 |
+
#include <assert.h>
|
5 |
+
#include <float.h>
|
6 |
+
#include <stdint.h>
|
7 |
+
#include <type_traits>
|
8 |
+
|
9 |
+
namespace vllm {
|
10 |
+
#ifndef USE_ROCM
|
11 |
+
|
12 |
+
namespace fp8 {
|
13 |
+
#ifdef ENABLE_FP8
|
14 |
+
|
15 |
+
#if 0 // Disable the following code to reduce the binary size.
|
16 |
+
template <typename Tout, typename Tin>
|
17 |
+
__inline__ __device__ Tout
|
18 |
+
vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) {
|
19 |
+
return x;
|
20 |
+
}
|
21 |
+
|
22 |
+
// fp8 -> half
|
23 |
+
template <>
|
24 |
+
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(
|
25 |
+
const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
26 |
+
__half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
|
27 |
+
return res.x;
|
28 |
+
}
|
29 |
+
|
30 |
+
// fp8x2 -> half2
|
31 |
+
template <>
|
32 |
+
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(
|
33 |
+
const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
34 |
+
union {
|
35 |
+
uint16_t u16[2];
|
36 |
+
uint32_t u32;
|
37 |
+
} tmp;
|
38 |
+
__half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
|
39 |
+
tmp.u16[0] = res.x;
|
40 |
+
tmp.u16[1] = res.y;
|
41 |
+
return tmp.u32;
|
42 |
+
}
|
43 |
+
|
44 |
+
// fp8x4 -> half2x2
|
45 |
+
template <>
|
46 |
+
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(
|
47 |
+
const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
48 |
+
union {
|
49 |
+
uint2 u32x2;
|
50 |
+
uint32_t u32[2];
|
51 |
+
} tmp;
|
52 |
+
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a, fp8_type);
|
53 |
+
tmp.u32[1] =
|
54 |
+
vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), fp8_type);
|
55 |
+
return tmp.u32x2;
|
56 |
+
}
|
57 |
+
|
58 |
+
// fp8x8 -> half2x4
|
59 |
+
template <>
|
60 |
+
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(
|
61 |
+
const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
|
62 |
+
union {
|
63 |
+
uint4 u64x2;
|
64 |
+
uint2 u64[2];
|
65 |
+
} tmp;
|
66 |
+
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x, fp8_type);
|
67 |
+
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y, fp8_type);
|
68 |
+
return tmp.u64x2;
|
69 |
+
}
|
70 |
+
|
71 |
+
// fp8 -> __nv_bfloat16
|
72 |
+
template <>
|
73 |
+
__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(
|
74 |
+
const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
75 |
+
// Note there is no direct convert function from fp8 to bf16.
|
76 |
+
// fp8 -> half
|
77 |
+
__half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
|
78 |
+
// half -> float -> bf16
|
79 |
+
float tmp = half_to_float(res.x);
|
80 |
+
return __float2bfloat16(tmp);
|
81 |
+
}
|
82 |
+
|
83 |
+
// fp8x2 -> __nv_bfloat162
|
84 |
+
template <>
|
85 |
+
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(
|
86 |
+
const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
87 |
+
__nv_bfloat162 res;
|
88 |
+
res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, fp8_type);
|
89 |
+
res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), fp8_type);
|
90 |
+
return res;
|
91 |
+
}
|
92 |
+
|
93 |
+
// fp8x4 -> bf16_4_t
|
94 |
+
template <>
|
95 |
+
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(
|
96 |
+
const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
97 |
+
bf16_4_t res;
|
98 |
+
res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, fp8_type);
|
99 |
+
res.y =
|
100 |
+
vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), fp8_type);
|
101 |
+
return res;
|
102 |
+
}
|
103 |
+
|
104 |
+
// fp8x8 -> bf16_8_t
|
105 |
+
template <>
|
106 |
+
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(
|
107 |
+
const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
|
108 |
+
bf16_4_t tmp1, tmp2;
|
109 |
+
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x, fp8_type);
|
110 |
+
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y, fp8_type);
|
111 |
+
bf16_8_t res;
|
112 |
+
res.x = tmp1.x;
|
113 |
+
res.y = tmp1.y;
|
114 |
+
res.z = tmp2.x;
|
115 |
+
res.w = tmp2.y;
|
116 |
+
return res;
|
117 |
+
}
|
118 |
+
|
119 |
+
// fp8 -> float
|
120 |
+
template <>
|
121 |
+
__inline__ __device__ float
|
122 |
+
vec_conversion<float, uint8_t>(const uint8_t &a,
|
123 |
+
const __nv_fp8_interpretation_t fp8_type) {
|
124 |
+
// fp8 -> half
|
125 |
+
uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a, fp8_type);
|
126 |
+
// half -> float
|
127 |
+
return half_to_float(tmp);
|
128 |
+
}
|
129 |
+
|
130 |
+
// fp8x2 -> float2
|
131 |
+
template <>
|
132 |
+
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(
|
133 |
+
const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
134 |
+
// fp8x2 -> half2
|
135 |
+
uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a, fp8_type);
|
136 |
+
// half2 -> float2
|
137 |
+
return half2_to_float2(tmp);
|
138 |
+
}
|
139 |
+
|
140 |
+
// fp8x4 -> float4
|
141 |
+
template <>
|
142 |
+
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(
|
143 |
+
const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
144 |
+
Float4_ res;
|
145 |
+
res.x = vec_conversion<float2, uint16_t>((uint16_t)a, fp8_type);
|
146 |
+
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), fp8_type);
|
147 |
+
return res;
|
148 |
+
}
|
149 |
+
|
150 |
+
// fp8x8 -> float8
|
151 |
+
template <>
|
152 |
+
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(
|
153 |
+
const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
|
154 |
+
Float4_ tmp1, tmp2;
|
155 |
+
tmp1 = vec_conversion<Float4_, uint32_t>(a.x, fp8_type);
|
156 |
+
tmp2 = vec_conversion<Float4_, uint32_t>(a.y, fp8_type);
|
157 |
+
Float8_ res;
|
158 |
+
res.x = tmp1.x;
|
159 |
+
res.y = tmp1.y;
|
160 |
+
res.z = tmp2.x;
|
161 |
+
res.w = tmp2.y;
|
162 |
+
return res;
|
163 |
+
}
|
164 |
+
|
165 |
+
// half -> fp8
|
166 |
+
template <>
|
167 |
+
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(
|
168 |
+
const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
169 |
+
__half_raw tmp;
|
170 |
+
tmp.x = a;
|
171 |
+
__nv_fp8_storage_t res =
|
172 |
+
__nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, fp8_type);
|
173 |
+
return (uint8_t)res;
|
174 |
+
}
|
175 |
+
|
176 |
+
// bf16 -> fp8
|
177 |
+
template <>
|
178 |
+
__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(
|
179 |
+
const __nv_bfloat16 &a, const __nv_fp8_interpretation_t fp8_type) {
|
180 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
181 |
+
assert(false);
|
182 |
+
#else
|
183 |
+
__nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(
|
184 |
+
__nv_bfloat16_raw(a), __NV_SATFINITE, fp8_type);
|
185 |
+
return (uint8_t)res;
|
186 |
+
#endif
|
187 |
+
}
|
188 |
+
|
189 |
+
// float -> fp8
|
190 |
+
template <>
|
191 |
+
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(
|
192 |
+
const float &a, const __nv_fp8_interpretation_t fp8_type) {
|
193 |
+
__nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, fp8_type);
|
194 |
+
return (uint8_t)res;
|
195 |
+
}
|
196 |
+
|
197 |
+
// fp8x4 -> float4
|
198 |
+
template <>
|
199 |
+
__inline__ __device__ float4 vec_conversion<float4, uint32_t>(
|
200 |
+
const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
201 |
+
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a, fp8_type);
|
202 |
+
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
203 |
+
return res;
|
204 |
+
}
|
205 |
+
|
206 |
+
template <>
|
207 |
+
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(
|
208 |
+
const float2 &a, const __nv_fp8_interpretation_t fp8_type) {
|
209 |
+
union {
|
210 |
+
half2 float16;
|
211 |
+
uint32_t uint32;
|
212 |
+
};
|
213 |
+
|
214 |
+
float16 = __float22half2_rn(a);
|
215 |
+
return uint32;
|
216 |
+
}
|
217 |
+
|
218 |
+
template <>
|
219 |
+
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(
|
220 |
+
const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
|
221 |
+
uint2 b;
|
222 |
+
float2 val;
|
223 |
+
val.x = a.x.x;
|
224 |
+
val.y = a.x.y;
|
225 |
+
b.x = vec_conversion<uint32_t, float2>(val, fp8_type);
|
226 |
+
|
227 |
+
val.x = a.y.x;
|
228 |
+
val.y = a.y.y;
|
229 |
+
b.y = vec_conversion<uint32_t, float2>(val, fp8_type);
|
230 |
+
|
231 |
+
return b;
|
232 |
+
}
|
233 |
+
|
234 |
+
template <>
|
235 |
+
__inline__ __device__ float4 vec_conversion<float4, Float4_>(
|
236 |
+
const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
|
237 |
+
float4 b;
|
238 |
+
b.x = a.x.x;
|
239 |
+
b.y = a.x.y;
|
240 |
+
b.z = a.y.x;
|
241 |
+
b.w = a.y.y;
|
242 |
+
return b;
|
243 |
+
}
|
244 |
+
|
245 |
+
template <>
|
246 |
+
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(
|
247 |
+
const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) {
|
248 |
+
uint4 b;
|
249 |
+
b.x = vec_conversion<uint32_t, float2>(a.x, fp8_type);
|
250 |
+
b.y = vec_conversion<uint32_t, float2>(a.y, fp8_type);
|
251 |
+
b.z = vec_conversion<uint32_t, float2>(a.z, fp8_type);
|
252 |
+
b.w = vec_conversion<uint32_t, float2>(a.w, fp8_type);
|
253 |
+
return b;
|
254 |
+
}
|
255 |
+
|
256 |
+
template <>
|
257 |
+
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(
|
258 |
+
const float2 &a, const __nv_fp8_interpretation_t fp8_type) {
|
259 |
+
__nv_bfloat162 b;
|
260 |
+
from_float(b, a);
|
261 |
+
return b;
|
262 |
+
}
|
263 |
+
|
264 |
+
template <>
|
265 |
+
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(
|
266 |
+
const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
|
267 |
+
bf16_4_t b;
|
268 |
+
from_float(b, a);
|
269 |
+
return b;
|
270 |
+
}
|
271 |
+
|
272 |
+
template <>
|
273 |
+
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(
|
274 |
+
const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) {
|
275 |
+
bf16_8_t b;
|
276 |
+
from_float(b, a);
|
277 |
+
return b;
|
278 |
+
}
|
279 |
+
#endif
|
280 |
+
|
281 |
+
/* Scaled and vectorized conversions, for data exchange between high and low
|
282 |
+
precision domains. Convention of the scale in the API, e.g.: FP8_data =
|
283 |
+
Quantization( High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8
|
284 |
+
Dequant(FP8) * scale => HP
|
285 |
+
*/
|
286 |
+
|
287 |
+
template <typename Tout, typename Tin>
|
288 |
+
__inline__ __device__ Tout scaled_vec_conversion(
|
289 |
+
const Tin& x, const float scale, const __nv_fp8_interpretation_t fp8_type) {
|
290 |
+
return x;
|
291 |
+
}
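A hedged device-side sketch of this scale convention using the same cuda_fp8.h intrinsics this header relies on: quantize as fp8 = cvt(x / scale), dequantize as x' = cvt_back(fp8) * scale. The kernel name and values are illustrative, and an FP8-capable toolchain (the ENABLE_FP8 build path) is assumed.

#include <cuda_fp8.h>
#include <cuda_fp16.h>
#include <cstdio>

__global__ void fp8_scale_roundtrip(float x, float scale) {
  // HP -> FP8: divide by the scale, then saturate-convert to e4m3.
  __nv_fp8_storage_t q =
      __nv_cvt_float_to_fp8(x / scale, __NV_SATFINITE, __NV_E4M3);
  // FP8 -> HP: convert back through half, then multiply the scale back in.
  __half_raw hr = __nv_cvt_fp8_to_halfraw(q, __NV_E4M3);
  float back = __half2float(__ushort_as_half(hr.x)) * scale;
  printf("x = %f, round-trip = %f\n", x, back);
}

int main() {
  fp8_scale_roundtrip<<<1, 1>>>(3.1f, 0.25f);
  cudaDeviceSynchronize();
  return 0;
}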
|
292 |
+
|
293 |
+
// fp8 -> half
|
294 |
+
template <>
|
295 |
+
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(
|
296 |
+
const uint8_t& a, const float scale,
|
297 |
+
const __nv_fp8_interpretation_t fp8_type) {
|
298 |
+
__half_raw tmp = __nv_cvt_fp8_to_halfraw(a, fp8_type);
|
299 |
+
return float_to_half(half_to_float(tmp.x) * scale);
|
300 |
+
}
|
301 |
+
|
302 |
+
// fp8x2 -> half2
|
303 |
+
template <>
|
304 |
+
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
|
305 |
+
const uint16_t& a, const float scale,
|
306 |
+
const __nv_fp8_interpretation_t fp8_type) {
|
307 |
+
union {
|
308 |
+
uint16_t u16[2];
|
309 |
+
uint32_t u32;
|
310 |
+
} tmp;
|
311 |
+
__half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
|
312 |
+
tmp.u16[0] = float_to_half(half_to_float(res.x) * scale);
|
313 |
+
tmp.u16[1] = float_to_half(half_to_float(res.y) * scale);
|
314 |
+
return tmp.u32;
|
315 |
+
}
|
316 |
+
|
317 |
+
// fp8x4 -> half2x2
|
318 |
+
template <>
|
319 |
+
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(
|
320 |
+
const uint32_t& a, const float scale,
|
321 |
+
const __nv_fp8_interpretation_t fp8_type) {
|
322 |
+
union {
|
323 |
+
uint2 u32x2;
|
324 |
+
uint32_t u32[2];
|
325 |
+
} tmp;
|
326 |
+
tmp.u32[0] =
|
327 |
+
scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale, fp8_type);
|
328 |
+
tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U),
|
329 |
+
scale, fp8_type);
|
330 |
+
return tmp.u32x2;
|
331 |
+
}
|
332 |
+
|
333 |
+
// fp8x8 -> half2x4
|
334 |
+
template <>
|
335 |
+
__inline__ __device__ uint4
|
336 |
+
scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale,
|
337 |
+
const __nv_fp8_interpretation_t fp8_type) {
|
338 |
+
union {
|
339 |
+
uint4 u64x2;
|
340 |
+
uint2 u64[2];
|
341 |
+
} tmp;
|
342 |
+
tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale, fp8_type);
|
343 |
+
tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale, fp8_type);
|
344 |
+
return tmp.u64x2;
|
345 |
+
}
|
346 |
+
|
347 |
+
// fp8 -> __nv_bfloat16
|
348 |
+
template <>
|
349 |
+
__inline__ __device__ __nv_bfloat16
|
350 |
+
scaled_vec_conversion<__nv_bfloat16, uint8_t>(
|
351 |
+
const uint8_t& a, const float scale,
|
352 |
+
const __nv_fp8_interpretation_t fp8_type) {
|
353 |
+
// Note there is no direct convert function from fp8 to bf16.
|
354 |
+
// fp8 -> half
|
355 |
+
__half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
|
356 |
+
// half -> float -> bf16
|
357 |
+
float tmp = half_to_float(res.x);
|
358 |
+
return __float2bfloat16(tmp * scale);
|
359 |
+
}
|
360 |
+
|
361 |
+
// fp8x2 -> __nv_bfloat162
|
362 |
+
template <>
|
363 |
+
__inline__ __device__ __nv_bfloat162
|
364 |
+
scaled_vec_conversion<__nv_bfloat162, uint16_t>(
|
365 |
+
const uint16_t& a, const float scale,
|
366 |
+
const __nv_fp8_interpretation_t fp8_type) {
|
367 |
+
__nv_bfloat162 res;
|
368 |
+
res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale,
|
369 |
+
fp8_type);
|
370 |
+
res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U),
|
371 |
+
scale, fp8_type);
|
372 |
+
return res;
|
373 |
+
}
|
374 |
+
|
375 |
+
// fp8x4 -> bf16_4_t
|
376 |
+
template <>
|
377 |
+
__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
|
378 |
+
const uint32_t& a, const float scale,
|
379 |
+
const __nv_fp8_interpretation_t fp8_type) {
|
380 |
+
bf16_4_t res;
|
381 |
+
res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale,
|
382 |
+
fp8_type);
|
383 |
+
res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
|
384 |
+
scale, fp8_type);
|
385 |
+
return res;
|
386 |
+
}
|
387 |
+
|
388 |
+
// fp8x8 -> bf16_8_t
|
389 |
+
template <>
|
390 |
+
__inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(
|
391 |
+
const uint2& a, const float scale,
|
392 |
+
const __nv_fp8_interpretation_t fp8_type) {
|
393 |
+
bf16_4_t tmp1, tmp2;
|
394 |
+
tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale, fp8_type);
|
395 |
+
tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale, fp8_type);
|
396 |
+
bf16_8_t res;
|
397 |
+
res.x = tmp1.x;
|
398 |
+
res.y = tmp1.y;
|
399 |
+
res.z = tmp2.x;
|
400 |
+
res.w = tmp2.y;
|
401 |
+
return res;
|
402 |
+
}
|
403 |
+
|
404 |
+
// fp8 -> float
|
405 |
+
template <>
|
406 |
+
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
|
407 |
+
const uint8_t& a, const float scale,
|
408 |
+
const __nv_fp8_interpretation_t fp8_type) {
|
409 |
+
// fp8 -> half
|
410 |
+
__half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
|
411 |
+
uint16_t tmp = res.x;
|
412 |
+
|
413 |
+
// half -> float
|
414 |
+
return half_to_float(tmp) * scale;
|
415 |
+
}
|
416 |
+
|
417 |
+
// fp8x2 -> float2
|
418 |
+
template <>
|
419 |
+
__inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(
|
420 |
+
const uint16_t& a, const float scale,
|
421 |
+
const __nv_fp8_interpretation_t fp8_type) {
|
422 |
+
// fp8x2 -> half2
|
423 |
+
uint32_t tmp = scaled_vec_conversion<uint32_t, uint16_t>(a, scale, fp8_type);
|
424 |
+
// half2 -> float2
|
425 |
+
return half2_to_float2(tmp);
|
426 |
+
}
|
427 |
+
|
428 |
+
// fp8x4 -> float4
|
429 |
+
template <>
|
430 |
+
__inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(
|
431 |
+
const uint32_t& a, const float scale,
|
432 |
+
const __nv_fp8_interpretation_t fp8_type) {
|
433 |
+
Float4_ res;
|
434 |
+
res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale, fp8_type);
|
435 |
+
res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale,
|
436 |
+
fp8_type);
|
437 |
+
return res;
|
438 |
+
}
|
439 |
+
|
440 |
+
// fp8x8 -> float8
|
441 |
+
template <>
|
442 |
+
__inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(
|
443 |
+
const uint2& a, const float scale,
|
444 |
+
const __nv_fp8_interpretation_t fp8_type) {
|
445 |
+
Float4_ tmp1, tmp2;
|
446 |
+
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale, fp8_type);
|
447 |
+
tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale, fp8_type);
|
448 |
+
Float8_ res;
|
449 |
+
res.x = tmp1.x;
|
450 |
+
res.y = tmp1.y;
|
451 |
+
res.z = tmp2.x;
|
452 |
+
res.w = tmp2.y;
|
453 |
+
return res;
|
454 |
+
}
|
455 |
+
|
456 |
+
// half -> fp8
|
457 |
+
template <>
|
458 |
+
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(
|
459 |
+
    const uint16_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  __nv_fp8_storage_t res =
      __nv_cvt_float_to_fp8(half_to_float(a) / scale, __NV_SATFINITE, fp8_type);
  return (uint8_t)res;
}

// bf16 -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
    const __nv_bfloat16& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(__bfloat162float(a) / scale,
                                                 __NV_SATFINITE, fp8_type);
  return (uint8_t)res;
#endif
  __builtin_unreachable();  // Suppress missing return statement warning
}

// float -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>(
    const float& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  __nv_fp8_storage_t res =
      __nv_cvt_float_to_fp8(a / scale, __NV_SATFINITE, fp8_type);
  return (uint8_t)res;
}

// fp8x4 -> float4
template <>
__inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(
    const uint32_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale, fp8_type);
  float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
  return res;
}
#endif  // ENABLE_FP8

template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout convert(const Tin& x) {
#if 0  // Disable the following code to reduce the binary size.
  if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
    return vec_conversion<Tout, Tin>(x, __NV_E4M3);
  } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
    return vec_conversion<Tout, Tin>(x, __NV_E5M2);
  }
#endif
  assert(false);
  __builtin_unreachable();  // Suppress missing return statement warning
}

template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
#ifdef ENABLE_FP8
  if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
    return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E4M3);
  } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
    return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E5M2);
  }
#endif
  assert(false);
  __builtin_unreachable();  // Suppress missing return statement warning
}

// The following macro is used to dispatch the conversion function based on
// the data type of the key and value cache. The FN is a macro that calls a
// function with template<typename scalar_t, typename cache_t,
// Fp8KVCacheDataType kv_dt>.
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN)                   \
  if (KV_DTYPE == "auto") {                                                   \
    if (SRC_DTYPE == at::ScalarType::Float) {                                 \
      FN(float, float, vllm::Fp8KVCacheDataType::kAuto);                      \
    } else if (SRC_DTYPE == at::ScalarType::Half) {                           \
      FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto);                \
    } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                       \
      FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto);      \
    } else {                                                                  \
      TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE);  \
    }                                                                         \
  } else {                                                                    \
    if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") {                        \
      if (SRC_DTYPE == at::ScalarType::Float) {                               \
        FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);               \
      } else if (SRC_DTYPE == at::ScalarType::Half) {                         \
        FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);            \
      } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                     \
        FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);       \
      } else {                                                                \
        TORCH_CHECK(false,                                                    \
                    "Unsupported input type of kv cache: ", SRC_DTYPE);       \
      }                                                                       \
    } else if (KV_DTYPE == "fp8_e5m2") {                                      \
      if (SRC_DTYPE == at::ScalarType::Float) {                               \
        FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2);               \
      } else if (SRC_DTYPE == at::ScalarType::Half) {                         \
        FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2);            \
      } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                     \
        FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2);       \
      } else {                                                                \
        TORCH_CHECK(false,                                                    \
                    "Unsupported input type of kv cache: ", SRC_DTYPE);       \
      }                                                                       \
    } else {                                                                  \
      TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE);    \
    }                                                                         \
  }

}  // namespace fp8
#endif  // not USE_ROCM
}  // namespace vllm
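Usage sketch only (not part of this commit): one way a host-side launcher could combine DISPATCH_BY_KV_CACHE_DTYPE with scaled_convert. The kernel, launcher, and CALL_SCALED_CONVERT names below are hypothetical, and the include path assumes the build root from build.toml.

#include <string>

#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>

#include "fp8/nvidia/quant_utils.cuh"  // assumed include path for this repo

namespace example {

template <typename scalar_t, typename cache_t, vllm::Fp8KVCacheDataType kv_dt>
__global__ void scaled_convert_kernel(const scalar_t* __restrict__ src,
                                      cache_t* __restrict__ dst,
                                      const float scale, const int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
  if constexpr (kv_dt == vllm::Fp8KVCacheDataType::kAuto) {
    // "auto" dispatches with cache_t == scalar_t, so no conversion is needed.
    dst[i] = src[i];
  } else {
    dst[i] = vllm::fp8::scaled_convert<cache_t, scalar_t, kv_dt>(src[i], scale);
  }
}

void scaled_convert_launcher(torch::Tensor& dst, const torch::Tensor& src,
                             const std::string& kv_cache_dtype, float scale) {
  const int n = static_cast<int>(src.numel());
  const dim3 grid((n + 255) / 256);
  const dim3 block(256);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // FN(scalar_t, cache_t, kv_dt), as expected by DISPATCH_BY_KV_CACHE_DTYPE.
#define CALL_SCALED_CONVERT(SCALAR_T, CACHE_T, KV_DT)         \
  scaled_convert_kernel<SCALAR_T, CACHE_T, KV_DT>             \
      <<<grid, block, 0, stream>>>(                           \
          reinterpret_cast<const SCALAR_T*>(src.data_ptr()),  \
          reinterpret_cast<CACHE_T*>(dst.data_ptr()), scale, n);

  DISPATCH_BY_KV_CACHE_DTYPE(src.scalar_type(), kv_cache_dtype,
                             CALL_SCALED_CONVERT);
#undef CALL_SCALED_CONVERT
}

}  // namespace example

Note that the E4M3/E5M2 branches only produce real conversions when ENABLE_FP8 is defined; otherwise scaled_convert falls through to the assert shown above.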
gptq_marlin/marlin.cuh
ADDED
@@ -0,0 +1,87 @@
#pragma once

#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>

namespace marlin {

// Marlin params

// 8 warps are a good choice since every SM has 4 schedulers and having more
// than 1 warp per scheduler allows some more latency hiding. At the same time,
// we want relatively few warps to have many registers per warp and small tiles.
static constexpr int default_threads = 256;

static constexpr int pipe_stages =
    4;  // 4 pipeline stages fit into shared memory

static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;

static constexpr int tile_size = 16;
static constexpr int max_par = 16;

// Repack params
static constexpr int repack_stages = 8;

static constexpr int repack_threads = 256;

static constexpr int tile_k_size = tile_size;
static constexpr int tile_n_size = tile_k_size * 4;

// Helpers
template <typename T, int n>
struct Vec {
  T elems[n];
  __device__ T& operator[](int i) { return elems[i]; }
};

using I4 = Vec<int, 4>;

constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// No support for async
#else

__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                      bool pred = true) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   .reg .pred p;\n"
      "   setp.ne.b32 p, %0, 0;\n"
      "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
      "}\n" ::"r"((int)pred),
      "r"(smem), "l"(glob_ptr), "n"(BYTES));
}

__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   cp.async.cg.shared.global [%0], [%1], %2;\n"
      "}\n" ::"r"(smem),
      "l"(glob_ptr), "n"(BYTES));
}

__device__ inline void cp_async_fence() {
  asm volatile("cp.async.commit_group;\n" ::);
}

template <int n>
__device__ inline void cp_async_wait() {
  asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}

#endif

}  // namespace marlin
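The cp.async helpers above are normally chained into a multi-stage prefetch loop sized by pipe_stages. The kernel below is a hedged illustration only, not part of this commit: its name, tile layout, and launch shape are hypothetical, it assumes a 256-thread block and an sm_80+ compilation target (the helpers are not defined below that), and the include path assumes the build root from build.toml.

#include <cuda_runtime.h>

#include "gptq_marlin/marlin.cuh"  // assumed include path for this repo

namespace marlin_example {

__global__ void prefetch_demo(const int4* __restrict__ gmem, int n_tiles) {
  // One 16-byte fragment per thread per pipeline stage.
  __shared__ int4 smem[marlin::pipe_stages][marlin::default_threads];

  int fetch = 0;
  // Fill the pipeline: commit pipe_stages copy groups up front (a group is
  // committed even when there is nothing left to copy, to keep the group
  // arithmetic uniform in the tail).
  for (int s = 0; s < marlin::pipe_stages; ++s) {
    if (fetch < n_tiles) {
      marlin::cp_async4(&smem[s][threadIdx.x],
                        &gmem[fetch * blockDim.x + threadIdx.x]);
      ++fetch;
    }
    marlin::cp_async_fence();
  }

  for (int t = 0; t < n_tiles; ++t) {
    // All but the two most recent groups are complete, so the copy for tile t
    // (committed pipe_stages groups earlier) has landed in shared memory.
    marlin::cp_async_wait<marlin::pipe_stages - 2>();
    __syncthreads();

    int4 frag = smem[t % marlin::pipe_stages][threadIdx.x];
    (void)frag;  // ... compute with the fragment here ...

    __syncthreads();
    // Refill the slot just consumed and commit the next group.
    if (fetch < n_tiles) {
      marlin::cp_async4(&smem[t % marlin::pipe_stages][threadIdx.x],
                        &gmem[fetch * blockDim.x + threadIdx.x]);
      ++fetch;
    }
    marlin::cp_async_fence();
  }
}

}  // namespace marlin_example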
gptq_marlin/marlin_dtypes.cuh
ADDED
@@ -0,0 +1,79 @@
#ifndef _data_types_cuh
#define _data_types_cuh
#include "marlin.cuh"
#include <cuda_fp16.h>
#include <cuda_bf16.h>

namespace marlin {

template <typename scalar_t>
class ScalarType {};

template <>
class ScalarType<half> {
 public:
  using scalar_t = half;
  using scalar_t2 = half2;

  // Matrix fragments for tensor core instructions; their precise layout is
  // documented here:
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
  using FragA = Vec<half2, 4>;
  using FragB = Vec<half2, 2>;
  using FragC = Vec<float, 4>;
  using FragS = Vec<half2, 1>;
  using FragZP = Vec<half2, 4>;

  static __device__ float inline num2float(const half x) {
    return __half2float(x);
  }

  static __device__ half2 inline num2num2(const half x) {
    return __half2half2(x);
  }

  static __device__ half2 inline nums2num2(const half x1, const half x2) {
    return __halves2half2(x1, x2);
  }

  static __host__ __device__ half inline float2num(const float x) {
    return __float2half(x);
  }
};

template <>
class ScalarType<nv_bfloat16> {
 public:
  using scalar_t = nv_bfloat16;
  using scalar_t2 = nv_bfloat162;

  using FragA = Vec<nv_bfloat162, 4>;
  using FragB = Vec<nv_bfloat162, 2>;
  using FragC = Vec<float, 4>;
  using FragS = Vec<nv_bfloat162, 1>;
  using FragZP = Vec<nv_bfloat162, 4>;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  static __device__ float inline num2float(const nv_bfloat16 x) {
    return __bfloat162float(x);
  }

  static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
    return __bfloat162bfloat162(x);
  }

  static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1,
                                                  const nv_bfloat16 x2) {
    return __halves2bfloat162(x1, x2);
  }

  static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
    return __float2bfloat16(x);
  }
#endif
};

}  // namespace marlin

#endif
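ScalarType<scalar_t> is the trait that lets a single kernel template cover both the half and nv_bfloat16 paths. The device helper below is a small illustration only, not part of this commit: its name is hypothetical and the include path assumes the build root from build.toml.

#include <cuda_fp16.h>
#include <cuda_bf16.h>

#include "gptq_marlin/marlin_dtypes.cuh"  // assumed include path for this repo

namespace marlin_example {

// Scales a float accumulator fragment by a scalar of the kernel's compute
// dtype, without naming half or nv_bfloat16 directly.
template <typename scalar_t>
__device__ inline void scale_frag_c(
    typename marlin::ScalarType<scalar_t>::FragC& frag_c, scalar_t s) {
  // FragC is Vec<float, 4> for both specializations, so the scale is first
  // widened to float via the trait's num2float.
  const float s_f = marlin::ScalarType<scalar_t>::num2float(s);
#pragma unroll
  for (int i = 0; i < 4; ++i) {
    frag_c[i] *= s_f;
  }
}

}  // namespace marlin_example

This would be instantiated as scale_frag_c<half>(...) on all supported archs, or as scale_frag_c<nv_bfloat16>(...) on sm_80+ only, since the bf16 trait's conversion helpers are guarded by __CUDA_ARCH__ >= 800.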