File size: 18,590 Bytes
8aa00a3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 |
/*
Fast Dequantization (Converting INT4/INT8/FP4/FP8 to FP16/BF16)
The process of fast dequantization can be summarized as a combination
of bitwise operations and floating-point computations:
weight =>(bit_op / bitwise operations)=>
f16_value =>(flop / floating-point computation)=>
dequantized_weight
Since the dequantized weights typically require subtracting the zero point and
applying a scale factor, the floating-point computation step can be fused with
the zero-point subtraction and scaling operations.
The following are the parts that need to be modified for the fused operation
of zero-point subtraction and scaling.
## INT4 => FP16/BF16 or INT8 => FP16
The floating-point computation is `__hsub2`
If there are (integer) zero points:
flop(bit_op(weight)) - flop(bit_op(zp))
= sub(bit_op(weight), bias) - sub(bit_op(zp), bias)
= bit_op(weight) - bit_op(zp)
so no additional modification is needed.
If there are float zero points:
flop(bit_op(weight)) - fzp
= sub(bit_op(weight), bias) - fzp
= bit_op(weight) - (fzp + bias)
where `fzp + bias` can be computed at weight loading. However, this
may cause accuracy issues, so it should not be used in most cases.
If there are no zero points:
scale(flop(bit_op(weight)))
= scale(sub(bit_op(weight), bias))
= scale(bit_op(weight)) - scale(bias)
= fma(bit_op(weight), scale_factor, -scale(bias))
where `-scale(bias)` can be cached. However, this may cause accuracy issues,
so it should not be used in most cases.
## INT8 => BF16
INT8 => BF16 is a special case: it uses byte_perm instead of flop.
byte_perm cannot be fused with scaling.
## FP4/FP8 => FP16/BF16
scale(flop(bit_op(weight)))
= scale(mul(bit_op(weight), multiplier))
= mul(bit_op(weight), scale_factor * multiplier)
where `scale_factor * multiplier` can be computed at weight loading.
*/
#include "marlin_dtypes.cuh"
namespace MARLIN_NAMESPACE_NAME {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
// Lookup-table based 3-input logical operation, emitted as a single PTX
// `lop3.b32` instruction. `lut` is the 8-bit truth table of the desired
// function; it is written by evaluating that function on the constants 0xf0
// (standing for `a`), 0xcc (for `b`) and 0xaa (for `c`), e.g.
// `(0xf0 & 0xcc) | 0xaa` encodes `(a & b) | c`.
// Explicitly used for dequantization as the compiler does not seem to
// automatically recognize it in all cases.
// Returns the bitwise result of applying the `lut` function to (a, b, c).
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
int res;
// `lut` is bound with the "n" (immediate) constraint, so it must be a
// compile-time constant -- hence the template parameter.
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(res)
: "r"(a), "r"(b), "r"(c), "n"(lut));
return res;
}
// Constructs the destination register by taking bytes from 2 sources using
// the PTX `prmt.b32` (byte permute) instruction.
//   a          - first source, a runtime 32-bit register value.
//   start_byte - second source, passed as a 32-bit immediate. Despite its
//                name it is a full 4-byte operand, not a byte index.
//   mask       - immediate selector; each 4-bit field picks the source byte
//                copied into the corresponding destination byte.
// NOTE(review): per the PTX ISA, selector values 0-3 address the bytes of `a`
// and 4-7 the bytes of the second operand -- confirm against the PTX docs
// before changing any selector constant.
template <int start_byte, int mask>
__device__ inline uint32_t prmt(uint32_t a) {
uint32_t res;
// Both template parameters are bound with the "n" (immediate) constraint.
asm volatile("prmt.b32 %0, %1, %2, %3;\n"
: "=r"(res)
: "r"(a), "n"(start_byte), "n"(mask));
return res;
}
// Primary template for fast dequantization. It is only declared here; the
// behavior comes from the explicit specializations that follow.
//   scalar_t2 - packed two-element output type (half2 or nv_bfloat162 in the
//               specializations below).
//   w_type_id - id of the quantized weight type, e.g. vllm::kU4B8.id().
//   skip_flop - when true, the specializations below perform only the bitwise
//               step and leave out the floating-point correction (bias
//               subtraction or exponent multiplier), so that it can be fused
//               with zero-point subtraction / scaling as described in the
//               file header. When false, the correction is applied here.
//   q         - 32-bit word holding the packed quantized values.
//   frag_b    - output; the specializations write frag_b[0] and frag_b[1].
template <typename scalar_t2, vllm::ScalarTypeId w_type_id,
bool skip_flop = false>
__device__ inline void dequant(int q, scalar_t2* frag_b);
//
// Efficiently dequantize 4bit values packed in an int32 value into a full
// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
// with some small changes:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
//
template <>
__device__ inline void dequant<half2, vllm::kU4B8.id(), true>(int q,
                                                              half2* frag_b) {
  // Bit-op-only unpack of four 4-bit values. Each nibble is OR-ed into the low
  // mantissa bits of fp16 1024.0 (0x6400), so every output lane holds
  // (1024 + nibble); the bias is NOT removed here.
  constexpr int kNibbleMask = 0x000f000f;
  constexpr int kFp16Magic = 0x64006400;
  // Truth table of `(a & b) | c`; guarantees one LOP3 per extraction.
  constexpr int kAndOrLut = (0xf0 & 0xcc) | 0xaa;

  // Nibbles at bits [3:0] and [19:16].
  int first_pair = lop3<kAndOrLut>(q, kNibbleMask, kFp16Magic);
  // Nibbles at bits [7:4] and [23:20], shifted down to the same positions.
  q >>= 4;
  int second_pair = lop3<kAndOrLut>(q, kNibbleMask, kFp16Magic);

  frag_b[0] = *reinterpret_cast<half2*>(&first_pair);
  frag_b[1] = *reinterpret_cast<half2*>(&second_pair);
}
template <>
__device__ inline void dequant<half2, vllm::kU4B8.id(), false>(int q,
                                                               half2* frag_b) {
  // Full dequantization of four unsigned 4-bit values with the symmetric
  // zero point 8 folded in: every output lane equals (nibble - 8).
  constexpr int kLowNibbles = 0x000f000f;   // bits [3:0] of each 16-bit lane
  constexpr int kHighNibbles = 0x00f000f0;  // bits [7:4] of each 16-bit lane
  constexpr int kFp16Magic = 0x64006400;    // fp16 1024.0 in both lanes
  // Truth table of `(a & b) | c`; guarantees the extraction is one LOP3.
  constexpr int kAndOrLut = (0xf0 & 0xcc) | 0xaa;

  // clang-format off
  int low_bits  = lop3<kAndOrLut>(q, kLowNibbles,  kFp16Magic);  // 1024 + n
  int high_bits = lop3<kAndOrLut>(q, kHighNibbles, kFp16Magic);  // 1024 + 16n
  // clang-format on

  // Low nibbles:  (1024 + n) - 1032            = n - 8
  const int kLowBias = 0x64086408;     // fp16 1032.0
  // High nibbles: (1024 + 16n) * (1/16) - 72   = n - 8
  const int kHighScale = 0x2c002c00;   // fp16 0.0625
  const int kHighOffset = 0xd480d480;  // fp16 -72.0

  const half2 low_bias = *reinterpret_cast<const half2*>(&kLowBias);
  const half2 high_scale = *reinterpret_cast<const half2*>(&kHighScale);
  const half2 high_offset = *reinterpret_cast<const half2*>(&kHighOffset);

  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&low_bits), low_bias);
  frag_b[1] =
      __hfma2(*reinterpret_cast<half2*>(&high_bits), high_scale, high_offset);
}
template <>
__device__ inline void dequant<half2, vllm::kU4.id(), true>(int q,
                                                            half2* frag_b) {
  // With the floating-point step skipped, the unpack does not involve the
  // zero point at all, so the unbiased 4-bit type reuses the kU4B8 bit-level
  // unpack unchanged.
  constexpr vllm::ScalarTypeId kSharedUnpackId = vllm::kU4B8.id();
  dequant<half2, kSharedUnpackId, true>(q, frag_b);
}
template <>
__device__ inline void dequant<half2, vllm::kU4.id(), false>(int q,
                                                             half2* frag_b) {
  // Full dequantization of four unsigned 4-bit values WITHOUT a built-in zero
  // point: only the fp16 magic bias is removed, so every output lane equals
  // the raw nibble value.
  constexpr int kLowNibbles = 0x000f000f;   // bits [3:0] of each 16-bit lane
  constexpr int kHighNibbles = 0x00f000f0;  // bits [7:4] of each 16-bit lane
  constexpr int kFp16Magic = 0x64006400;    // fp16 1024.0 in both lanes
  // Truth table of `(a & b) | c`; guarantees the extraction is one LOP3.
  constexpr int kAndOrLut = (0xf0 & 0xcc) | 0xaa;

  // clang-format off
  int low_bits  = lop3<kAndOrLut>(q, kLowNibbles,  kFp16Magic);  // 1024 + n
  int high_bits = lop3<kAndOrLut>(q, kHighNibbles, kFp16Magic);  // 1024 + 16n
  // clang-format on

  // Low nibbles:  (1024 + n) - 1024            = n
  const int kLowBias = 0x64006400;     // fp16 1024.0
  // High nibbles: (1024 + 16n) * (1/16) - 64   = n
  const int kHighScale = 0x2c002c00;   // fp16 0.0625
  const int kHighOffset = 0xd400d400;  // fp16 -64.0

  const half2 low_bias = *reinterpret_cast<const half2*>(&kLowBias);
  const half2 high_scale = *reinterpret_cast<const half2*>(&kHighScale);
  const half2 high_offset = *reinterpret_cast<const half2*>(&kHighOffset);

  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&low_bits), low_bias);
  frag_b[1] =
      __hfma2(*reinterpret_cast<half2*>(&high_bits), high_scale, high_offset);
}
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU4B8.id(), true>(
    int q, nv_bfloat162* frag_b) {
  // Bit-op-only unpack of four 4-bit values. Each nibble is OR-ed into the low
  // mantissa bits of bf16 128.0 (0x4300), so every output lane holds
  // (128 + nibble); the bias is NOT removed here.
  constexpr int kNibbleMask = 0x000f000f;
  constexpr int kBf16Magic = 0x43004300;
  // Truth table of `(a & b) | c`; guarantees one LOP3 per extraction.
  constexpr int kAndOrLut = (0xf0 & 0xcc) | 0xaa;

  // clang-format off
  // Nibbles at bits [3:0] and [19:16].
  int first_pair  = lop3<kAndOrLut>(q, kNibbleMask, kBf16Magic);
  // Nibbles at bits [7:4] and [23:20], shifted down to the same positions.
  q >>= 4;
  int second_pair = lop3<kAndOrLut>(q, kNibbleMask, kBf16Magic);
  // clang-format on

  frag_b[0] = *reinterpret_cast<nv_bfloat162*>(&first_pair);
  frag_b[1] = *reinterpret_cast<nv_bfloat162*>(&second_pair);
}
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU4B8.id(), false>(
    int q, nv_bfloat162* frag_b) {
  // Step 1: bit-level unpack; every lane becomes bf16(128 + nibble).
  dequant<nv_bfloat162, vllm::kU4B8.id(), true>(q, frag_b);

  // Step 2: subtract bf16 136.0 (0x4308) = 128 (magic bias) + 8 (symmetric
  // zero point), leaving (nibble - 8) in every lane.
  static constexpr uint32_t kBiasPlusZeroPoint = 0x43084308;
  const nv_bfloat162 offset =
      *reinterpret_cast<const nv_bfloat162*>(&kBiasPlusZeroPoint);
  frag_b[0] = __hsub2(frag_b[0], offset);
  frag_b[1] = __hsub2(frag_b[1], offset);
}
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU4.id(), true>(
    int q, nv_bfloat162* frag_b) {
  // With the floating-point step skipped, the unpack does not involve the
  // zero point at all, so the unbiased 4-bit type reuses the kU4B8 bit-level
  // unpack unchanged.
  constexpr vllm::ScalarTypeId kSharedUnpackId = vllm::kU4B8.id();
  dequant<nv_bfloat162, kSharedUnpackId, true>(q, frag_b);
}
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU4.id(), false>(
    int q, nv_bfloat162* frag_b) {
  // Step 1: bit-level unpack; every lane becomes bf16(128 + nibble).
  dequant<nv_bfloat162, vllm::kU4.id(), true>(q, frag_b);

  // Step 2: subtract the bf16 magic bias 128.0 (0x4300). There is no zero
  // point for this type, so every lane ends up holding the raw nibble value.
  static constexpr uint32_t kMagicBias = 0x43004300;
  const nv_bfloat162 offset =
      *reinterpret_cast<const nv_bfloat162*>(&kMagicBias);
  frag_b[0] = __hsub2(frag_b[0], offset);
  frag_b[1] = __hsub2(frag_b[1], offset);
}
//
// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
// bf16. Reference:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
//
template <>
__device__ inline void dequant<half2, vllm::kU8B128.id(), true>(int q,
                                                                half2* frag_b) {
  // Bit-op-only unpack of four 8-bit values. Each byte of `q` becomes the low
  // byte of an fp16 whose high byte is 0x64, i.e. the value (1024 + byte);
  // the bias is NOT removed here.
  constexpr int kFp16HighBytes = 0x64646464;
  // Selectors: destination = {0x64, q.byte2, 0x64, q.byte0} and
  //                          {0x64, q.byte3, 0x64, q.byte1}.
  constexpr int kPickBytes0And2 = 0x5250;
  constexpr int kPickBytes1And3 = 0x5351;

  uint32_t even_bytes = prmt<kFp16HighBytes, kPickBytes0And2>(q);
  uint32_t odd_bytes = prmt<kFp16HighBytes, kPickBytes1And3>(q);

  frag_b[0] = *reinterpret_cast<half2*>(&even_bytes);
  frag_b[1] = *reinterpret_cast<half2*>(&odd_bytes);
}
template <>
__device__ inline void dequant<half2, vllm::kU8B128.id(), false>(
    int q, half2* frag_b) {
  // Step 1: bit-level unpack; every lane becomes fp16(1024 + byte).
  dequant<half2, vllm::kU8B128.id(), true>(q, frag_b);

  // Step 2: subtract fp16 1152.0 (0x6480) = 1024 (magic bias) + 128 (zero
  // point), leaving (byte - 128) in every lane.
  static constexpr uint32_t kBiasPlusZeroPoint = 0x64806480;
  const half2 offset = *reinterpret_cast<const half2*>(&kBiasPlusZeroPoint);
  frag_b[0] = __hsub2(frag_b[0], offset);
  frag_b[1] = __hsub2(frag_b[1], offset);
}
template <>
__device__ inline void dequant<half2, vllm::kU8.id(), true>(int q,
                                                            half2* frag_b) {
  // With the floating-point step skipped, the unpack does not involve the
  // zero point at all, so the unbiased 8-bit type reuses the kU8B128
  // bit-level unpack unchanged.
  constexpr vllm::ScalarTypeId kSharedUnpackId = vllm::kU8B128.id();
  dequant<half2, kSharedUnpackId, true>(q, frag_b);
}
template <>
__device__ inline void dequant<half2, vllm::kU8.id(), false>(int q,
                                                             half2* frag_b) {
  // Step 1: bit-level unpack; every lane becomes fp16(1024 + byte).
  dequant<half2, vllm::kU8.id(), true>(q, frag_b);

  // Step 2: subtract the fp16 magic bias 1024.0 (0x6400). There is no zero
  // point for this type, so every lane ends up holding the raw byte value.
  static constexpr uint32_t kMagicBias = 0x64006400;
  const half2 offset = *reinterpret_cast<const half2*>(&kMagicBias);
  frag_b[0] = __hsub2(frag_b[0], offset);
  frag_b[1] = __hsub2(frag_b[1], offset);
}
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU8B128.id(), false>(
    int q, nv_bfloat162* frag_b) {
  // Dequantizes four 8-bit values with zero point 128 to bf16 via an fp32
  // detour: each byte is planted in the low mantissa byte of fp32 2^23
  // (0x4B000000), the bias is removed in fp32, and the upper 16 bits of each
  // fp32 result are kept as the bf16 value.
  constexpr uint32_t kFp32Magic = 0x4B000000;       // fp32 8388608.0
  constexpr float kMagicPlusZeroPoint = 8388736.f;  // 8388608 + 128

  float widened[4];
  uint32_t* widened_bits = reinterpret_cast<uint32_t*>(widened);
  // Selector 0x765N: low byte = byte N of q, upper three bytes from the magic.
  widened_bits[0] = __byte_perm(q, kFp32Magic, 0x7650);
  widened_bits[1] = __byte_perm(q, kFp32Magic, 0x7652);
  widened_bits[2] = __byte_perm(q, kFp32Magic, 0x7651);
  widened_bits[3] = __byte_perm(q, kFp32Magic, 0x7653);

  // (2^23 + byte) - (2^23 + 128) = byte - 128.
#pragma unroll
  for (int i = 0; i < 4; ++i) {
    widened[i] -= kMagicPlusZeroPoint;
  }

  // Selector 0x7632 packs the upper halves of two fp32 words into one word.
  uint32_t* packed = reinterpret_cast<uint32_t*>(frag_b);
  packed[0] = __byte_perm(widened_bits[0], widened_bits[1], 0x7632);
  packed[1] = __byte_perm(widened_bits[2], widened_bits[3], 0x7632);
}
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU8.id(), false>(
    int q, nv_bfloat162* frag_b) {
  // Dequantizes four unsigned 8-bit values (no zero point) to bf16 via an
  // fp32 detour: each byte is planted in the low mantissa byte of fp32 2^23
  // (0x4B000000), the magic bias is removed in fp32, and the upper 16 bits of
  // each fp32 result are kept as the bf16 value.
  constexpr uint32_t kFp32Magic = 0x4B000000;  // fp32 8388608.0
  constexpr float kMagicBias = 8388608.f;

  float widened[4];
  uint32_t* widened_bits = reinterpret_cast<uint32_t*>(widened);
  // Selector 0x765N: low byte = byte N of q, upper three bytes from the magic.
  widened_bits[0] = __byte_perm(q, kFp32Magic, 0x7650);
  widened_bits[1] = __byte_perm(q, kFp32Magic, 0x7652);
  widened_bits[2] = __byte_perm(q, kFp32Magic, 0x7651);
  widened_bits[3] = __byte_perm(q, kFp32Magic, 0x7653);

  // (2^23 + byte) - 2^23 = byte.
#pragma unroll
  for (int i = 0; i < 4; ++i) {
    widened[i] -= kMagicBias;
  }

  // Selector 0x7632 packs the upper halves of two fp32 words into one word.
  uint32_t* packed = reinterpret_cast<uint32_t*>(frag_b);
  packed[0] = __byte_perm(widened_bits[0], widened_bits[1], 0x7632);
  packed[1] = __byte_perm(widened_bits[2], widened_bits[3], 0x7632);
}
template <>
__device__ inline void dequant<half2, vllm::kFE4M3fn.id(), true>(
    int q, half2* frag_b) {
  // Bit-op-only repack of four FP8 (E4M3) bytes into fp16 bit layout. The sign
  // bit stays in place and the 7 exponent/mantissa bits move right by the
  // difference in exponent width; the exponent-bias multiplier is NOT applied.
  constexpr int kFp8ExpBits = 4, kFp16ExpBits = 5;
  constexpr int kExpShift = kFp16ExpBits - kFp8ExpBits;
  constexpr uint32_t kSignBits = 0x80008000;
  constexpr int kMagnitudeBits = 0x7F007F00;

  // Bytes 3 and 1 already sit in the upper byte of each 16-bit lane.
  const int upper_src = q;
  int from_bytes_3_1 =
      (upper_src & kSignBits) | ((upper_src & kMagnitudeBits) >> kExpShift);
  // Bytes 2 and 0 are moved into the upper byte of each lane first.
  const int lower_src = q << 8;
  int from_bytes_2_0 =
      (lower_src & kSignBits) | ((lower_src & kMagnitudeBits) >> kExpShift);

  // Note: reverse indexing is intentional because weights are permuted
  frag_b[1] = *reinterpret_cast<const half2*>(&from_bytes_3_1);
  frag_b[0] = *reinterpret_cast<const half2*>(&from_bytes_2_0);
}
template <>
__device__ inline void dequant<half2, vllm::kFE4M3fn.id(), false>(
    int q, half2* frag_b) {
  // Step 1: bit-level repack of the FP8 (E4M3) bytes into fp16 layout.
  dequant<half2, vllm::kFE4M3fn.id(), true>(q, frag_b);

  // Step 2: compensate for the different exponent biases by multiplying with
  // 2^kBiasGap, where kBiasGap = 2^(5-1) - 2^(4-1) = 8.
  constexpr int kFp8ExpBits = 4, kFp16ExpBits = 5;
  constexpr int kBiasGap =
      (1 << (kFp16ExpBits - 1)) - (1 << (kFp8ExpBits - 1));
  const half2 rescale = __float2half2_rn(static_cast<float>(1 << kBiasGap));

  frag_b[1] = __hmul2(frag_b[1], rescale);
  frag_b[0] = __hmul2(frag_b[0], rescale);
}
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kFE4M3fn.id(), true>(
    int q, nv_bfloat162* frag_b) {
  // Bit-op-only repack of four FP8 (E4M3) bytes into bf16 bit layout. The sign
  // bit stays in place and the 7 exponent/mantissa bits move right by the
  // difference in exponent width; the exponent-bias multiplier is NOT applied.
  constexpr int kFp8ExpBits = 4, kBf16ExpBits = 8;
  constexpr int kExpShift = kBf16ExpBits - kFp8ExpBits;
  constexpr uint32_t kSignBits = 0x80008000;
  constexpr int kMagnitudeBits = 0x7F007F00;

  // Bytes 3 and 1 already sit in the upper byte of each 16-bit lane.
  const int upper_src = q;
  int from_bytes_3_1 =
      (upper_src & kSignBits) | ((upper_src & kMagnitudeBits) >> kExpShift);
  // Bytes 2 and 0 are moved into the upper byte of each lane first.
  const int lower_src = q << 8;
  int from_bytes_2_0 =
      (lower_src & kSignBits) | ((lower_src & kMagnitudeBits) >> kExpShift);

  // Note: reverse indexing is intentional because weights are permuted
  frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&from_bytes_3_1);
  frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&from_bytes_2_0);
}
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kFE4M3fn.id(), false>(
    int q, nv_bfloat162* frag_b) {
  // Step 1: bit-level repack of the FP8 (E4M3) bytes into bf16 layout.
  dequant<nv_bfloat162, vllm::kFE4M3fn.id(), true>(q, frag_b);

  // Step 2: compensate for the different exponent biases by multiplying with
  // 2^kBiasGap, where kBiasGap = 2^(8-1) - 2^(4-1) = 120.
  constexpr int kFp8ExpBits = 4, kBf16ExpBits = 8;
  constexpr int kBiasGap =
      (1 << (kBf16ExpBits - 1)) - (1 << (kFp8ExpBits - 1));
  // 2^kBiasGap is written directly as fp32 bits: biased exponent
  // (kBiasGap + 127) placed at bit 23, zero mantissa.
  constexpr uint32_t kRescaleFp32Bits = (kBiasGap + 127) << 23;
  const nv_bfloat162 rescale =
      __float2bfloat162_rn(*reinterpret_cast<const float*>(&kRescaleFp32Bits));

  frag_b[1] = __hmul2(frag_b[1], rescale);
  frag_b[0] = __hmul2(frag_b[0], rescale);
}
template <>
__device__ inline void dequant<half2, vllm::kFE2M1f.id(), true>(int q,
                                                                half2* frag_b) {
  // Bit-op-only repack of four FP4 (E2M1) nibbles into fp16 bit layout. The
  // sign bit stays in place and the 3 exponent/mantissa bits move right by the
  // difference in exponent width; the exponent-bias multiplier is NOT applied.
  constexpr int kFp4ExpBits = 2, kFp16ExpBits = 5;
  constexpr int kExpShift = kFp16ExpBits - kFp4ExpBits;
  constexpr uint32_t kSignBits = 0x80008000;
  constexpr int kMagnitudeBits = 0x70007000;

  // Nibbles already in the top 4 bits of each 16-bit lane.
  const int top_src = q;
  int from_top_nibbles =
      (top_src & kSignBits) | ((top_src & kMagnitudeBits) >> kExpShift);
  // The next-lower nibble of each lane, moved up into the top 4 bits.
  const int next_src = q << 4;
  int from_next_nibbles =
      (next_src & kSignBits) | ((next_src & kMagnitudeBits) >> kExpShift);

  // Note: reverse indexing is intentional because weights are permuted
  frag_b[1] = *reinterpret_cast<const half2*>(&from_top_nibbles);
  frag_b[0] = *reinterpret_cast<const half2*>(&from_next_nibbles);
}
template <>
__device__ inline void dequant<half2, vllm::kFE2M1f.id(), false>(
    int q, half2* frag_b) {
  // Step 1: bit-level repack of the FP4 (E2M1) nibbles into fp16 layout.
  dequant<half2, vllm::kFE2M1f.id(), true>(q, frag_b);

  // Step 2: compensate for the different exponent biases by multiplying with
  // 2^kBiasGap, where kBiasGap = 2^(5-1) - 2^(2-1) = 14.
  constexpr int kFp4ExpBits = 2, kFp16ExpBits = 5;
  constexpr int kBiasGap =
      (1 << (kFp16ExpBits - 1)) - (1 << (kFp4ExpBits - 1));
  const half2 rescale = __float2half2_rn(static_cast<float>(1 << kBiasGap));

  frag_b[1] = __hmul2(frag_b[1], rescale);
  frag_b[0] = __hmul2(frag_b[0], rescale);
}
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kFE2M1f.id(), true>(
    int q, nv_bfloat162* frag_b) {
  // Bit-op-only repack of four FP4 (E2M1) nibbles into bf16 bit layout. The
  // sign bit stays in place and the 3 exponent/mantissa bits move right by the
  // difference in exponent width; the exponent-bias multiplier is NOT applied.
  constexpr int kFp4ExpBits = 2, kBf16ExpBits = 8;
  constexpr int kExpShift = kBf16ExpBits - kFp4ExpBits;
  constexpr uint32_t kSignBits = 0x80008000;
  constexpr int kMagnitudeBits = 0x70007000;

  // Nibbles already in the top 4 bits of each 16-bit lane.
  const int top_src = q;
  int from_top_nibbles =
      (top_src & kSignBits) | ((top_src & kMagnitudeBits) >> kExpShift);
  // The next-lower nibble of each lane, moved up into the top 4 bits.
  const int next_src = q << 4;
  int from_next_nibbles =
      (next_src & kSignBits) | ((next_src & kMagnitudeBits) >> kExpShift);

  // Note: reverse indexing is intentional because weights are permuted
  frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&from_top_nibbles);
  frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&from_next_nibbles);
}
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kFE2M1f.id(), false>(
    int q, nv_bfloat162* frag_b) {
  // Step 1: bit-level repack of the FP4 (E2M1) nibbles into bf16 layout.
  dequant<nv_bfloat162, vllm::kFE2M1f.id(), true>(q, frag_b);

  // Step 2: compensate for the different exponent biases by multiplying with
  // 2^kBiasGap, where kBiasGap = 2^(8-1) - 2^(2-1) = 126.
  constexpr int kFp4ExpBits = 2, kBf16ExpBits = 8;
  constexpr int kBiasGap =
      (1 << (kBf16ExpBits - 1)) - (1 << (kFp4ExpBits - 1));
  // 2^kBiasGap is written directly as fp32 bits: biased exponent
  // (kBiasGap + 127) placed at bit 23, zero mantissa.
  constexpr uint32_t kRescaleFp32Bits = (kBiasGap + 127) << 23;
  const nv_bfloat162 rescale =
      __float2bfloat162_rn(*reinterpret_cast<const float*>(&kRescaleFp32Bits));

  frag_b[1] = __hmul2(frag_b[1], rescale);
  frag_b[0] = __hmul2(frag_b[0], rescale);
}
// Primary template for unpacking four 8-bit scale values packed in the 32-bit
// word `q` into two packed pairs, written to frag_b[0] and frag_b[1]. It is
// only declared here; the half2 and nv_bfloat162 specializations follow.
// The specializations are pure bit manipulation (no floating-point step).
// NOTE(review): the exact 8-bit encoding expected in `q` is not visible in
// this file -- confirm against the code that prepares the scales.
template <typename scalar_t2>
__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b);
template <>
__device__ inline void dequant_fp8_scales<half2>(int q, half2* frag_b) {
  // Pure bit repack of four 8-bit scales into fp16 lanes. Each selected byte
  // sits in the upper byte of a 16-bit lane; shifting it right by one places
  // its 8 bits in lane bits [14:7] and leaves the fp16 sign bit clear.
  // NOTE(review): this treats all 8 bits as exponent/mantissa (no sign), so
  // the scales are presumably stored pre-processed as non-negative values --
  // confirm against the code that prepares them.

  // Bytes 3 and 1 are already in the upper byte of each lane.
  int Out1 = (q & 0xFF00FF00) >> 1;
  // Move bytes 2 and 0 into the upper byte of each lane, then repeat.
  q <<= 8;
  int Out2 = (q & 0xFF00FF00) >> 1;

  // Note: reverse indexing is intentional because weights are permuted
  frag_b[1] = *reinterpret_cast<const half2*>(&Out1);
  frag_b[0] = *reinterpret_cast<const half2*>(&Out2);
}
template <>
__device__ inline void dequant_fp8_scales<nv_bfloat162>(int q,
                                                        nv_bfloat162* frag_b) {
  // Pure bit repack of four 8-bit scales into bf16 lanes. The top bit of each
  // byte is moved one position down (into lane bit 14, leaving the bf16 sign
  // bit clear) and the remaining 7 bits are shifted right by the difference
  // in exponent width.
  constexpr int kFp8ExpBits = 4, kBf16ExpBits = 8;
  constexpr int kExpShift = kBf16ExpBits - kFp8ExpBits;
  constexpr uint32_t kTopBits = 0x80008000;
  constexpr int kLowSevenBits = 0x7F007F00;

  // Bytes 3 and 1 are already in the upper byte of each 16-bit lane.
  const int upper_src = q;
  int from_bytes_3_1 = ((upper_src & kTopBits) >> 1) |
                       ((upper_src & kLowSevenBits) >> kExpShift);
  // Bytes 2 and 0 are moved into the upper byte of each lane first.
  const int lower_src = q << 8;
  int from_bytes_2_0 = ((lower_src & kTopBits) >> 1) |
                       ((lower_src & kLowSevenBits) >> kExpShift);

  // Note: reverse indexing is intentional because weights are permuted
  frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&from_bytes_3_1);
  frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&from_bytes_2_0);
}
#endif
} // namespace MARLIN_NAMESPACE_NAME
|