Build
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- build/torch24-cxx11-cu118-x86_64-linux/quantization/__init__.py +30 -150
- build/torch24-cxx11-cu118-x86_64-linux/quantization/_ops.py +6 -0
- build/torch24-cxx11-cu118-x86_64-linux/quantization/_quantization_0_0_1.abi3.so +2 -2
- build/torch24-cxx11-cu118-x86_64-linux/quantization/compressed_tensors.py +110 -0
- build/torch24-cxx11-cu118-x86_64-linux/quantization/cutlass.py +75 -0
- build/torch24-cxx11-cu118-x86_64-linux/quantization/marlin.py +208 -0
- build/torch24-cxx11-cu118-x86_64-linux/quantization/scalar_type.py +330 -0
- build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/__init__.py +0 -0
- build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils.py +391 -0
- build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_fp8.py +100 -0
- build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test.py +162 -0
- build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test_24.py +473 -0
- build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py +125 -0
- build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/quant_utils.py +470 -0
- build/torch24-cxx11-cu121-x86_64-linux/quantization/__init__.py +30 -150
- build/torch24-cxx11-cu121-x86_64-linux/quantization/_ops.py +6 -0
- build/torch24-cxx11-cu121-x86_64-linux/quantization/_quantization_0_0_1.abi3.so +2 -2
- build/torch24-cxx11-cu121-x86_64-linux/quantization/compressed_tensors.py +110 -0
- build/torch24-cxx11-cu121-x86_64-linux/quantization/cutlass.py +75 -0
- build/torch24-cxx11-cu121-x86_64-linux/quantization/marlin.py +208 -0
- build/torch24-cxx11-cu121-x86_64-linux/quantization/scalar_type.py +330 -0
- build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/__init__.py +0 -0
- build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils.py +391 -0
- build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_fp8.py +100 -0
- build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test.py +162 -0
- build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test_24.py +473 -0
- build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py +125 -0
- build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/quant_utils.py +470 -0
- build/torch24-cxx11-cu124-x86_64-linux/quantization/__init__.py +30 -150
- build/torch24-cxx11-cu124-x86_64-linux/quantization/_ops.py +6 -0
- build/torch24-cxx11-cu124-x86_64-linux/quantization/_quantization_0_0_1.abi3.so +2 -2
- build/torch24-cxx11-cu124-x86_64-linux/quantization/compressed_tensors.py +110 -0
- build/torch24-cxx11-cu124-x86_64-linux/quantization/cutlass.py +75 -0
- build/torch24-cxx11-cu124-x86_64-linux/quantization/marlin.py +208 -0
- build/torch24-cxx11-cu124-x86_64-linux/quantization/scalar_type.py +330 -0
- build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/__init__.py +0 -0
- build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils.py +391 -0
- build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_fp8.py +100 -0
- build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test.py +162 -0
- build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test_24.py +473 -0
- build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py +125 -0
- build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/quant_utils.py +470 -0
- build/torch24-cxx98-cu118-x86_64-linux/quantization/__init__.py +30 -150
- build/torch24-cxx98-cu118-x86_64-linux/quantization/_ops.py +6 -0
- build/torch24-cxx98-cu118-x86_64-linux/quantization/_quantization_0_0_1.abi3.so +2 -2
- build/torch24-cxx98-cu118-x86_64-linux/quantization/compressed_tensors.py +110 -0
- build/torch24-cxx98-cu118-x86_64-linux/quantization/cutlass.py +75 -0
- build/torch24-cxx98-cu118-x86_64-linux/quantization/marlin.py +208 -0
- build/torch24-cxx98-cu118-x86_64-linux/quantization/scalar_type.py +330 -0
- build/torch24-cxx98-cu118-x86_64-linux/quantization/utils/__init__.py +0 -0
build/torch24-cxx11-cu118-x86_64-linux/quantization/__init__.py
CHANGED
@@ -1,150 +1,30 @@
|
|
1 |
-
from
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
#if current_platform.is_rocm():
|
33 |
-
# triton_scaled_mm_module = importlib.import_module(
|
34 |
-
# "vllm.model_executor.layers.quantization.compressed_tensors."
|
35 |
-
# "triton_scaled_mm")
|
36 |
-
# triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
|
37 |
-
# return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
38 |
-
|
39 |
-
out = torch.empty((m, n), dtype=out_dtype, device=a.device)
|
40 |
-
|
41 |
-
ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
|
42 |
-
|
43 |
-
return out
|
44 |
-
|
45 |
-
# fp8
|
46 |
-
def scaled_fp8_quant(
|
47 |
-
input: torch.Tensor,
|
48 |
-
scale: Optional[torch.Tensor] = None,
|
49 |
-
num_token_padding: Optional[int] = None,
|
50 |
-
scale_ub: Optional[torch.Tensor] = None,
|
51 |
-
use_per_token_if_dynamic: bool = False,
|
52 |
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
53 |
-
"""
|
54 |
-
Quantize input tensor to FP8 and return quantized tensor and scale.
|
55 |
-
|
56 |
-
This function supports both static and dynamic quantization: If you
|
57 |
-
provide the scale, it will use static scaling and if you omit it,
|
58 |
-
the scale will be determined dynamically. The function also allows
|
59 |
-
optional padding of the output tensors for downstream kernels that
|
60 |
-
will benefit from padding.
|
61 |
-
|
62 |
-
Args:
|
63 |
-
input: The input tensor to be quantized to FP8
|
64 |
-
scale: Optional scaling factor for the FP8 quantization
|
65 |
-
scale_ub: Optional upper bound for scaling factor in dynamic
|
66 |
-
per token case
|
67 |
-
num_token_padding: If specified, pad the first dimension
|
68 |
-
of the output to at least this value.
|
69 |
-
use_per_token_if_dynamic: Whether to do per_tensor or per_token
|
70 |
-
in the dynamic quantization case.
|
71 |
-
|
72 |
-
Returns:
|
73 |
-
Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
|
74 |
-
scaling factor.
|
75 |
-
"""
|
76 |
-
# This code assumes batch_dim and num_tokens are flattened
|
77 |
-
assert (input.ndim == 2)
|
78 |
-
shape: Union[Tuple[int, int], torch.Size] = input.shape
|
79 |
-
# For rocm, the output fp8 dtype is torch.float_e3m3fnuz
|
80 |
-
#out_dtype: torch.dtype = torch.float8_e4m3fnuz \
|
81 |
-
# if current_platform.is_rocm() else torch.float8_e4m3fn
|
82 |
-
out_dtype = torch.float8_e4m3fn
|
83 |
-
if num_token_padding:
|
84 |
-
shape = (max(num_token_padding, input.shape[0]), shape[1])
|
85 |
-
output = torch.empty(shape, device=input.device, dtype=out_dtype)
|
86 |
-
|
87 |
-
if scale is None:
|
88 |
-
if use_per_token_if_dynamic:
|
89 |
-
scale = torch.empty((shape[0], 1),
|
90 |
-
device=input.device,
|
91 |
-
dtype=torch.float32)
|
92 |
-
ops.dynamic_per_token_scaled_fp8_quant(
|
93 |
-
output, input, scale, scale_ub)
|
94 |
-
else:
|
95 |
-
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
|
96 |
-
ops.dynamic_scaled_fp8_quant(output, input, scale)
|
97 |
-
else:
|
98 |
-
# num_token_padding not implemented for this case
|
99 |
-
assert (scale.numel() == 1 or num_token_padding is None)
|
100 |
-
ops.static_scaled_fp8_quant(output, input, scale)
|
101 |
-
|
102 |
-
return output, scale
|
103 |
-
|
104 |
-
# int8
|
105 |
-
def scaled_int8_quant(
|
106 |
-
input: torch.Tensor,
|
107 |
-
scale: Optional[torch.Tensor] = None,
|
108 |
-
azp: Optional[torch.Tensor] = None,
|
109 |
-
symmetric: bool = True
|
110 |
-
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
111 |
-
"""
|
112 |
-
Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.
|
113 |
-
|
114 |
-
Args:
|
115 |
-
input: The input tensor to be quantized to int8.
|
116 |
-
scale: Optional scaling factor for the int8 quantization.
|
117 |
-
When not provided, we invoke dynamic-per-token quantization.
|
118 |
-
azp: Optional zero-point for the int8 quantization.
|
119 |
-
Must be provided for asymmetric quantization if `scale` is provided.
|
120 |
-
symmetric: Whether to use symmetric quantization (scale only, azp ignored).
|
121 |
-
|
122 |
-
Returns:
|
123 |
-
Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
|
124 |
-
"""
|
125 |
-
output = torch.empty_like(input, dtype=torch.int8)
|
126 |
-
if scale is not None:
|
127 |
-
# static-per-tensor quantization.
|
128 |
-
assert symmetric == (
|
129 |
-
azp is
|
130 |
-
None), "azp must only be provided for asymmetric quantization."
|
131 |
-
ops.static_scaled_int8_quant(output, input, scale, azp)
|
132 |
-
return output, scale, azp
|
133 |
-
|
134 |
-
# dynamic-per-token quantization.
|
135 |
-
input_scales = torch.empty((input.numel() // input.shape[-1], 1),
|
136 |
-
device=input.device,
|
137 |
-
dtype=torch.float32)
|
138 |
-
input_azp = None if symmetric else torch.empty_like(input_scales,
|
139 |
-
dtype=torch.int32)
|
140 |
-
ops.dynamic_scaled_int8_quant(output, input, input_scales,
|
141 |
-
input_azp)
|
142 |
-
return output, input_scales, input_azp
|
143 |
-
|
144 |
-
# fp8 marlin
|
145 |
-
def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
|
146 |
-
b_scales: torch.Tensor, workspace: torch.Tensor,
|
147 |
-
num_bits: int, size_m: int, size_n: int,
|
148 |
-
size_k: int) -> torch.Tensor:
|
149 |
-
return ops.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace,
|
150 |
-
num_bits, size_m, size_n, size_k)
|
|
|
1 |
+
from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant
|
2 |
+
from .cutlass import (
|
3 |
+
cutlass_scaled_mm_supports_fp8,
|
4 |
+
cutlass_scaled_mm,
|
5 |
+
cutlass_scaled_mm_azp,
|
6 |
+
)
|
7 |
+
from .marlin import (
|
8 |
+
awq_marlin_repack,
|
9 |
+
fp8_marlin_gemm,
|
10 |
+
gptq_marlin_gemm,
|
11 |
+
gptq_marlin_repack,
|
12 |
+
gptq_marlin_24_gemm,
|
13 |
+
marlin_qqq_gemm,
|
14 |
+
marlin_gemm,
|
15 |
+
)
|
16 |
+
|
17 |
+
__all__ = [
|
18 |
+
"awq_marlin_repack",
|
19 |
+
"cutlass_scaled_mm",
|
20 |
+
"cutlass_scaled_mm_azp",
|
21 |
+
"cutlass_scaled_mm_supports_fp8",
|
22 |
+
"fp8_marlin_gemm",
|
23 |
+
"gptq_marlin_24_gemm",
|
24 |
+
"gptq_marlin_gemm",
|
25 |
+
"gptq_marlin_repack",
|
26 |
+
"marlin_gemm",
|
27 |
+
"marlin_qqq_gemm",
|
28 |
+
"scaled_fp8_quant",
|
29 |
+
"scaled_int8_quant",
|
30 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch24-cxx11-cu118-x86_64-linux/quantization/_ops.py
CHANGED
@@ -1,3 +1,9 @@
|
|
1 |
import torch
|
2 |
from . import _quantization_0_0_1
|
3 |
ops = torch.ops._quantization_0_0_1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import torch
|
2 |
from . import _quantization_0_0_1
|
3 |
ops = torch.ops._quantization_0_0_1
|
4 |
+
|
5 |
+
def add_op_namespace_prefix(op_name: str):
|
6 |
+
"""
|
7 |
+
Prefix op by namespace.
|
8 |
+
"""
|
9 |
+
return f"_quantization_0_0_1::{op_name}"
|
build/torch24-cxx11-cu118-x86_64-linux/quantization/_quantization_0_0_1.abi3.so
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:35967133ffbd0cac32aafc9e70f441264b1f41710f4f86d68723d2eb9a59cfe8
|
3 |
+
size 85717704
|
build/torch24-cxx11-cu118-x86_64-linux/quantization/compressed_tensors.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Tuple
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
try:
|
6 |
+
from ._ops import ops
|
7 |
+
except ImportError as e:
|
8 |
+
# Fallback for local development.
|
9 |
+
try:
|
10 |
+
import _quantization
|
11 |
+
|
12 |
+
ops = torch.ops._quantization
|
13 |
+
except ImportError:
|
14 |
+
raise e
|
15 |
+
|
16 |
+
|
17 |
+
# fp8
|
18 |
+
def scaled_fp8_quant(
|
19 |
+
input: torch.Tensor,
|
20 |
+
scale: Optional[torch.Tensor] = None,
|
21 |
+
num_token_padding: Optional[int] = None,
|
22 |
+
scale_ub: Optional[torch.Tensor] = None,
|
23 |
+
use_per_token_if_dynamic: bool = False,
|
24 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
25 |
+
"""
|
26 |
+
Quantize input tensor to FP8 and return quantized tensor and scale.
|
27 |
+
|
28 |
+
This function supports both static and dynamic quantization: If you
|
29 |
+
provide the scale, it will use static scaling and if you omit it,
|
30 |
+
the scale will be determined dynamically. The function also allows
|
31 |
+
optional padding of the output tensors for downstream kernels that
|
32 |
+
will benefit from padding.
|
33 |
+
|
34 |
+
Args:
|
35 |
+
input: The input tensor to be quantized to FP8
|
36 |
+
scale: Optional scaling factor for the FP8 quantization
|
37 |
+
scale_ub: Optional upper bound for scaling factor in dynamic
|
38 |
+
per token case
|
39 |
+
num_token_padding: If specified, pad the first dimension
|
40 |
+
of the output to at least this value.
|
41 |
+
use_per_token_if_dynamic: Whether to do per_tensor or per_token
|
42 |
+
in the dynamic quantization case.
|
43 |
+
|
44 |
+
Returns:
|
45 |
+
Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
|
46 |
+
scaling factor.
|
47 |
+
"""
|
48 |
+
# This code assumes batch_dim and num_tokens are flattened
|
49 |
+
assert input.ndim == 2
|
50 |
+
shape: Union[Tuple[int, int], torch.Size] = input.shape
|
51 |
+
# For rocm, the output fp8 dtype is torch.float_e3m3fnuz
|
52 |
+
# out_dtype: torch.dtype = torch.float8_e4m3fnuz \
|
53 |
+
# if current_platform.is_rocm() else torch.float8_e4m3fn
|
54 |
+
out_dtype = torch.float8_e4m3fn
|
55 |
+
if num_token_padding:
|
56 |
+
shape = (max(num_token_padding, input.shape[0]), shape[1])
|
57 |
+
output = torch.empty(shape, device=input.device, dtype=out_dtype)
|
58 |
+
|
59 |
+
if scale is None:
|
60 |
+
if use_per_token_if_dynamic:
|
61 |
+
scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
|
62 |
+
ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub)
|
63 |
+
else:
|
64 |
+
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
|
65 |
+
ops.dynamic_scaled_fp8_quant(output, input, scale)
|
66 |
+
else:
|
67 |
+
# num_token_padding not implemented for this case
|
68 |
+
assert scale.numel() == 1 or num_token_padding is None
|
69 |
+
ops.static_scaled_fp8_quant(output, input, scale)
|
70 |
+
|
71 |
+
return output, scale
|
72 |
+
|
73 |
+
|
74 |
+
# int8
|
75 |
+
def scaled_int8_quant(
|
76 |
+
input: torch.Tensor,
|
77 |
+
scale: Optional[torch.Tensor] = None,
|
78 |
+
azp: Optional[torch.Tensor] = None,
|
79 |
+
symmetric: bool = True,
|
80 |
+
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
81 |
+
"""
|
82 |
+
Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.
|
83 |
+
|
84 |
+
Args:
|
85 |
+
input: The input tensor to be quantized to int8.
|
86 |
+
scale: Optional scaling factor for the int8 quantization.
|
87 |
+
When not provided, we invoke dynamic-per-token quantization.
|
88 |
+
azp: Optional zero-point for the int8 quantization.
|
89 |
+
Must be provided for asymmetric quantization if `scale` is provided.
|
90 |
+
symmetric: Whether to use symmetric quantization (scale only, azp ignored).
|
91 |
+
|
92 |
+
Returns:
|
93 |
+
Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
|
94 |
+
"""
|
95 |
+
output = torch.empty_like(input, dtype=torch.int8)
|
96 |
+
if scale is not None:
|
97 |
+
# static-per-tensor quantization.
|
98 |
+
assert symmetric == (
|
99 |
+
azp is None
|
100 |
+
), "azp must only be provided for asymmetric quantization."
|
101 |
+
ops.static_scaled_int8_quant(output, input, scale, azp)
|
102 |
+
return output, scale, azp
|
103 |
+
|
104 |
+
# dynamic-per-token quantization.
|
105 |
+
input_scales = torch.empty(
|
106 |
+
(input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32
|
107 |
+
)
|
108 |
+
input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32)
|
109 |
+
ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp)
|
110 |
+
return output, input_scales, input_azp
|
build/torch24-cxx11-cu118-x86_64-linux/quantization/cutlass.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
try:
|
6 |
+
from ._ops import ops
|
7 |
+
except ImportError as e:
|
8 |
+
# Fallback for local development.
|
9 |
+
try:
|
10 |
+
import _quantization
|
11 |
+
|
12 |
+
ops = torch.ops._quantization
|
13 |
+
except ImportError:
|
14 |
+
raise e
|
15 |
+
|
16 |
+
|
17 |
+
def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
|
18 |
+
return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
|
19 |
+
|
20 |
+
|
21 |
+
def cutlass_scaled_mm(
|
22 |
+
a: torch.Tensor,
|
23 |
+
b: torch.Tensor,
|
24 |
+
scale_a: torch.Tensor,
|
25 |
+
scale_b: torch.Tensor,
|
26 |
+
out_dtype: torch.dtype,
|
27 |
+
bias: Optional[torch.Tensor] = None,
|
28 |
+
) -> torch.Tensor:
|
29 |
+
assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
|
30 |
+
assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
|
31 |
+
assert bias is None or bias.shape[0] == b.shape[1] and bias.dtype == out_dtype
|
32 |
+
|
33 |
+
m = a.shape[0]
|
34 |
+
n = b.shape[1]
|
35 |
+
|
36 |
+
# if current_platform.is_rocm():
|
37 |
+
# triton_scaled_mm_module = importlib.import_module(
|
38 |
+
# "vllm.model_executor.layers.quantization.compressed_tensors."
|
39 |
+
# "triton_scaled_mm")
|
40 |
+
# triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
|
41 |
+
# return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
42 |
+
|
43 |
+
out = torch.empty((m, n), dtype=out_dtype, device=a.device)
|
44 |
+
|
45 |
+
ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
|
46 |
+
|
47 |
+
return out
|
48 |
+
|
49 |
+
|
50 |
+
def cutlass_scaled_mm_azp(
|
51 |
+
a: torch.Tensor,
|
52 |
+
b: torch.Tensor,
|
53 |
+
scale_a: torch.Tensor,
|
54 |
+
scale_b: torch.Tensor,
|
55 |
+
out_dtype: torch.dtype,
|
56 |
+
azp_adj: torch.Tensor,
|
57 |
+
azp: Optional[torch.Tensor] = None,
|
58 |
+
bias: Optional[torch.Tensor] = None,
|
59 |
+
) -> torch.Tensor:
|
60 |
+
"""
|
61 |
+
:param azp_adj: In the per-tensor case, this should include the azp.
|
62 |
+
Always per-channel.
|
63 |
+
:param azp: Only set in the per-token case. Per-token if set.
|
64 |
+
"""
|
65 |
+
assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
|
66 |
+
assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
|
67 |
+
assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype
|
68 |
+
assert azp is None or azp.numel() == a.shape[0]
|
69 |
+
|
70 |
+
m = a.shape[0]
|
71 |
+
n = b.shape[1]
|
72 |
+
out = torch.empty((m, n), dtype=out_dtype, device=a.device)
|
73 |
+
|
74 |
+
ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias)
|
75 |
+
return out
|
build/torch24-cxx11-cu118-x86_64-linux/quantization/marlin.py
ADDED
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import TYPE_CHECKING
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
# neuron has torch version that doesn't even have impl_abstract
|
6 |
+
if TYPE_CHECKING:
|
7 |
+
def register_fake(fn):
|
8 |
+
return lambda name: fn
|
9 |
+
else:
|
10 |
+
try:
|
11 |
+
from torch.library import register_fake
|
12 |
+
except ImportError:
|
13 |
+
from torch.library import impl_abstract as register_fake
|
14 |
+
|
15 |
+
try:
|
16 |
+
from ._ops import ops, add_op_namespace_prefix
|
17 |
+
except ImportError as e:
|
18 |
+
# Fallback for local development.
|
19 |
+
try:
|
20 |
+
import _quantization
|
21 |
+
|
22 |
+
ops = torch.ops._quantization
|
23 |
+
|
24 |
+
def add_op_namespace_prefix(op_name: str):
|
25 |
+
return f"_quantization::{op_name}"
|
26 |
+
except ImportError:
|
27 |
+
raise e
|
28 |
+
|
29 |
+
|
30 |
+
from .scalar_type import ScalarType
|
31 |
+
|
32 |
+
|
33 |
+
# fp8 marlin
|
34 |
+
def fp8_marlin_gemm(
|
35 |
+
a: torch.Tensor,
|
36 |
+
b_q_weight: torch.Tensor,
|
37 |
+
b_scales: torch.Tensor,
|
38 |
+
workspace: torch.Tensor,
|
39 |
+
num_bits: int,
|
40 |
+
size_m: int,
|
41 |
+
size_n: int,
|
42 |
+
size_k: int,
|
43 |
+
) -> torch.Tensor:
|
44 |
+
return ops.fp8_marlin_gemm(
|
45 |
+
a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k
|
46 |
+
)
|
47 |
+
|
48 |
+
|
49 |
+
# gptq_marlin
|
50 |
+
def gptq_marlin_gemm(
|
51 |
+
a: torch.Tensor,
|
52 |
+
b_q_weight: torch.Tensor,
|
53 |
+
b_scales: torch.Tensor,
|
54 |
+
b_zeros: torch.Tensor,
|
55 |
+
g_idx: torch.Tensor,
|
56 |
+
perm: torch.Tensor,
|
57 |
+
workspace: torch.Tensor,
|
58 |
+
b_q_type: ScalarType,
|
59 |
+
size_m: int,
|
60 |
+
size_n: int,
|
61 |
+
size_k: int,
|
62 |
+
is_k_full: bool,
|
63 |
+
has_zp: bool = False,
|
64 |
+
use_fp32_reduce: bool = False,
|
65 |
+
is_zp_float: bool = False,
|
66 |
+
) -> torch.Tensor:
|
67 |
+
return ops.gptq_marlin_gemm(
|
68 |
+
a,
|
69 |
+
b_q_weight,
|
70 |
+
b_scales,
|
71 |
+
b_zeros,
|
72 |
+
g_idx,
|
73 |
+
perm,
|
74 |
+
workspace,
|
75 |
+
b_q_type.id,
|
76 |
+
size_m,
|
77 |
+
size_n,
|
78 |
+
size_k,
|
79 |
+
is_k_full,
|
80 |
+
has_zp,
|
81 |
+
use_fp32_reduce,
|
82 |
+
is_zp_float,
|
83 |
+
)
|
84 |
+
|
85 |
+
|
86 |
+
# gptq_marlin
|
87 |
+
def gptq_marlin_repack(
|
88 |
+
b_q_weight: torch.Tensor,
|
89 |
+
perm: torch.Tensor,
|
90 |
+
size_k: int,
|
91 |
+
size_n: int,
|
92 |
+
num_bits: int,
|
93 |
+
) -> torch.Tensor:
|
94 |
+
return ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits)
|
95 |
+
|
96 |
+
|
97 |
+
# gptq_marlin
|
98 |
+
def awq_marlin_repack(
|
99 |
+
b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
|
100 |
+
) -> torch.Tensor:
|
101 |
+
return ops.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)
|
102 |
+
|
103 |
+
|
104 |
+
# marlin
|
105 |
+
def marlin_gemm(
|
106 |
+
a: torch.Tensor,
|
107 |
+
b_q_weight: torch.Tensor,
|
108 |
+
b_scales: torch.Tensor,
|
109 |
+
workspace: torch.Tensor,
|
110 |
+
size_m: int,
|
111 |
+
size_n: int,
|
112 |
+
size_k: int,
|
113 |
+
) -> torch.Tensor:
|
114 |
+
return ops.marlin_gemm(
|
115 |
+
a, b_q_weight, b_scales, workspace, size_m, size_n, size_k
|
116 |
+
)
|
117 |
+
|
118 |
+
|
119 |
+
# marlin_24
|
120 |
+
def gptq_marlin_24_gemm(
|
121 |
+
a: torch.Tensor,
|
122 |
+
b_q_weight: torch.Tensor,
|
123 |
+
b_meta: torch.Tensor,
|
124 |
+
b_scales: torch.Tensor,
|
125 |
+
workspace: torch.Tensor,
|
126 |
+
b_q_type: ScalarType,
|
127 |
+
size_m: int,
|
128 |
+
size_n: int,
|
129 |
+
size_k: int,
|
130 |
+
) -> torch.Tensor:
|
131 |
+
return ops.gptq_marlin_24_gemm(
|
132 |
+
a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, size_m, size_n, size_k
|
133 |
+
)
|
134 |
+
|
135 |
+
|
136 |
+
# qqq ops
|
137 |
+
def marlin_qqq_gemm(
|
138 |
+
a: torch.Tensor,
|
139 |
+
b_q_weight: torch.Tensor,
|
140 |
+
s_tok: torch.Tensor,
|
141 |
+
s_ch: torch.Tensor,
|
142 |
+
s_group: torch.Tensor,
|
143 |
+
workspace: torch.Tensor,
|
144 |
+
size_m: int,
|
145 |
+
size_n: int,
|
146 |
+
size_k: int,
|
147 |
+
) -> torch.Tensor:
|
148 |
+
return ops.marlin_qqq_gemm(
|
149 |
+
a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k
|
150 |
+
)
|
151 |
+
|
152 |
+
|
153 |
+
# Fake ops
|
154 |
+
|
155 |
+
if hasattr(ops, "gptq_marlin_24_gemm"):
|
156 |
+
@register_fake(add_op_namespace_prefix("fp8_marlin_gemm"))
|
157 |
+
def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
|
158 |
+
b_scales: torch.Tensor, workspace: torch.Tensor,
|
159 |
+
num_bits: int, size_m: torch.SymInt,
|
160 |
+
size_n: torch.SymInt,
|
161 |
+
size_k: torch.SymInt) -> torch.Tensor:
|
162 |
+
return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)
|
163 |
+
|
164 |
+
@register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm"))
|
165 |
+
def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
|
166 |
+
b_meta: torch.Tensor, b_scales: torch.Tensor,
|
167 |
+
workspace: torch.Tensor,
|
168 |
+
b_q_type: ScalarType, size_m: torch.SymInt,
|
169 |
+
size_n: torch.SymInt,
|
170 |
+
size_k: torch.SymInt) -> torch.Tensor:
|
171 |
+
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
|
172 |
+
|
173 |
+
@register_fake(add_op_namespace_prefix("gptq_marlin_gemm"))
|
174 |
+
def _gptq_marlin_gemm_fake(a: torch.Tensor,
|
175 |
+
b_q_weight: torch.Tensor,
|
176 |
+
b_scales: torch.Tensor,
|
177 |
+
b_zeros: torch.Tensor,
|
178 |
+
g_idx: torch.Tensor,
|
179 |
+
perm: torch.Tensor,
|
180 |
+
workspace: torch.Tensor,
|
181 |
+
b_q_type: ScalarType,
|
182 |
+
size_m: torch.SymInt,
|
183 |
+
size_n: torch.SymInt,
|
184 |
+
size_k: torch.SymInt,
|
185 |
+
is_k_full: bool,
|
186 |
+
has_zp: bool = False,
|
187 |
+
use_fp32_reduce: bool = False,
|
188 |
+
is_zp_float: bool = False) -> torch.Tensor:
|
189 |
+
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
|
190 |
+
|
191 |
+
@register_fake(add_op_namespace_prefix("marlin_qqq_gemm"))
|
192 |
+
def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
|
193 |
+
s_tok: torch.Tensor, s_ch: torch.Tensor,
|
194 |
+
s_group: torch.Tensor, workspace: torch.Tensor,
|
195 |
+
size_m: torch.SymInt, size_n: torch.SymInt,
|
196 |
+
size_k: torch.SymInt) -> torch.Tensor:
|
197 |
+
return torch.empty((size_m, size_n),
|
198 |
+
dtype=torch.float16,
|
199 |
+
device=a.device)
|
200 |
+
|
201 |
+
@register_fake(add_op_namespace_prefix("marlin_gemm"))
|
202 |
+
def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
|
203 |
+
b_scales: torch.Tensor, workspace: torch.Tensor,
|
204 |
+
size_m: torch.SymInt, size_n: torch.SymInt,
|
205 |
+
size_k: torch.SymInt) -> torch.Tensor:
|
206 |
+
return torch.empty((size_m, size_n),
|
207 |
+
dtype=torch.float16,
|
208 |
+
device=a.device)
|
build/torch24-cxx11-cu118-x86_64-linux/quantization/scalar_type.py
ADDED
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
import struct
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from enum import Enum
|
5 |
+
from typing import Optional, Union
|
6 |
+
|
7 |
+
|
8 |
+
# Mirrors enum in `core/scalar_type.hpp`
|
9 |
+
class NanRepr(Enum):
|
10 |
+
NONE = 0 # nans are not supported
|
11 |
+
IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s
|
12 |
+
EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s
|
13 |
+
|
14 |
+
|
15 |
+
# This ScalarType class is a parallel implementation of the C++ ScalarType
|
16 |
+
# class found in csrc/core/scalar_type.hpp. These two classes should be kept
# in sync until the inductor fully supports custom C++ classes.
@dataclass(frozen=True)
class ScalarType:
    """
    ScalarType can represent a wide range of floating point and integer
    types, in particular it can be used to represent sub-byte data types
    (something that torch.dtype currently does not support). It is also
    capable of representing types with a bias, i.e.:
      `stored_value = value + bias`,
    this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
    of 8). The implementation for this class can be found in
    csrc/core/scalar_type.hpp, these type signatures should be kept in sync
    with that file.
    """

    exponent: int
    """
    Number of bits in the exponent if this is a floating point type
    (zero if this an integer type)
    """

    mantissa: int
    """
    Number of bits in the mantissa if this is a floating point type,
    or the number bits representing an integer excluding the sign bit if
    this an integer type.
    """

    signed: bool
    "If the type is signed (i.e. has a sign bit)"

    bias: int
    """
    bias used to encode the values in this scalar type
    (value = stored_value - bias, default 0) for example if we store the
    type as an unsigned integer with a bias of 128 then the value 0 will be
    stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
    """

    _finite_values_only: bool = False
    """
    Private: if infs are supported, use `has_infs()` instead.
    """

    nan_repr: NanRepr = NanRepr.IEEE_754
    """
    How NaNs are represent in this scalar type, returns NanRepr value.
    (not applicable for integer types)
    """

    def _floating_point_max_int(self) -> int:
        """Bit pattern (as int) of this type's max finite value, laid out as
        an IEEE-754 double."""
        assert (
            self.mantissa <= 52 and self.exponent <= 11
        ), f"Cannot represent max/min as a double for type {self.__str__()}"

        max_mantissa = (1 << self.mantissa) - 1
        if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
            # the all-ones mantissa pattern is reserved for NaN, so the max
            # representable value has a mantissa one below it
            max_mantissa = max_mantissa - 1

        max_exponent = (1 << self.exponent) - 2
        if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN
                or self.nan_repr == NanRepr.NONE):
            # all-ones exponent encodes finite values (no inf/nan block),
            # so one more exponent value is usable than in IEEE 754
            assert (
                self.exponent < 11
            ), f"Cannot represent max/min as a double for type {self.__str__()}"
            max_exponent = max_exponent + 1

        # adjust the exponent to match that of a double
        # for now we assume the exponent bias is the standard 2^(e-1) -1, (where
        # e is the exponent bits), there is some precedent for non-standard
        # biases, example `float8_e4m3b11fnuz` here:
        # https://github.com/jax-ml/ml_dtypes but to avoid premature over
        # complication we are just assuming the standard exponent bias until
        # there is a need to support non-standard biases
        exponent_bias = (1 << (self.exponent - 1)) - 1
        exponent_bias_double = (1 << 10) - 1  # double e = 11

        max_exponent_double = (max_exponent - exponent_bias +
                               exponent_bias_double)

        # shift the mantissa and exponent into the proper positions for an
        # IEEE double and bitwise-or them together.
        return (max_mantissa <<
                (52 - self.mantissa)) | (max_exponent_double << 52)

    def _floating_point_max(self) -> float:
        """Max finite value of this floating point type as a Python float."""
        double_raw = self._floating_point_max_int()
        return struct.unpack('!d', struct.pack('!Q', double_raw))[0]

    def _raw_max(self) -> Union[int, float]:
        """Max stored (pre-bias) value."""
        if self.is_floating_point():
            return self._floating_point_max()
        else:
            assert (self.size_bits < 64 or self.size_bits == 64
                    and self.is_signed()), "Cannot represent max as an int"
            return (1 << self.mantissa) - 1

    def _raw_min(self) -> Union[int, float]:
        """Min stored (pre-bias) value."""
        if self.is_floating_point():
            assert self.is_signed(
            ), "We currently assume all floating point types are signed"
            sign_bit_double = 1 << 63

            # min is the max finite value with the sign bit set
            max_raw = self._floating_point_max_int()
            min_raw = max_raw | sign_bit_double
            return struct.unpack('!d', struct.pack('!Q', min_raw))[0]
        else:
            assert (not self.is_signed() or
                    self.size_bits <= 64), "Cannot represent min as a int64_t"

            if self.is_signed():
                return -(1 << (self.size_bits - 1))
            else:
                return 0

    @functools.cached_property
    def id(self) -> int:
        """
        Convert the ScalarType to an int which can be passed to pytorch custom
        ops. This layout of the int must be kept in sync with the C++
        ScalarType's from_id method.
        """
        val = 0
        offset = 0

        def or_and_advance(member, bit_width):
            nonlocal val
            nonlocal offset
            bit_mask = (1 << bit_width) - 1
            val = val | (int(member) & bit_mask) << offset
            offset = offset + bit_width

        or_and_advance(self.exponent, 8)
        or_and_advance(self.mantissa, 8)
        or_and_advance(self.signed, 1)
        or_and_advance(self.bias, 32)
        or_and_advance(self._finite_values_only, 1)
        or_and_advance(self.nan_repr.value, 8)

        assert offset <= 64, \
            f"ScalarType fields too big {offset} to fit into an int64"

        return val

    @property
    def size_bits(self) -> int:
        """Total storage bits (exponent + mantissa + sign bit)."""
        return self.exponent + self.mantissa + int(self.signed)

    def min(self) -> Union[int, float]:
        """
        Min representable value for this scalar type.
        (accounting for bias if there is one)
        """
        return self._raw_min() - self.bias

    def max(self) -> Union[int, float]:
        """
        Max representable value for this scalar type.
        (accounting for bias if there is one)
        """
        return self._raw_max() - self.bias

    def is_signed(self) -> bool:
        """
        If the type is signed (i.e. has a sign bit), same as `signed`
        added for consistency with:
        https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
        """
        return self.signed

    def is_floating_point(self) -> bool:
        "If the type is a floating point type"
        return self.exponent != 0

    def is_integer(self) -> bool:
        "If the type is an integer type"
        return self.exponent == 0

    def has_bias(self) -> bool:
        "If the type has a non-zero bias"
        return self.bias != 0

    def has_infs(self) -> bool:
        "If the type is floating point and supports infinity"
        return not self._finite_values_only

    def has_nans(self) -> bool:
        "If the type supports NaN values"
        # BUGFIX: compare member to member. The old expression
        # `self.nan_repr != NanRepr.NONE.value` compared an enum member
        # against its underlying int; for a plain (non-Int) Enum that is
        # never equal, making this method constant-True. Member comparison
        # is correct for both Enum and IntEnum, and matches the member
        # comparison already used in __str__.
        return self.nan_repr != NanRepr.NONE

    def is_ieee_754(self) -> bool:
        """
        If the type is a floating point type that follows IEEE 754
        conventions
        """
        # BUGFIX: same enum-vs-.value issue as has_nans; for a plain Enum
        # the old comparison with `NanRepr.IEEE_754.value` was constant
        # False, which also made __str__ append a spurious "n" suffix to
        # IEEE-754 type names (e.g. "float16_e8m7n" for bfloat16).
        return self.nan_repr == NanRepr.IEEE_754 and \
            not self._finite_values_only

    def __str__(self) -> str:
        """
        naming generally follows: https://github.com/jax-ml/ml_dtypes
        for floating point types (leading f) the scheme is:
        `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
        flags:
          - no-flags: means it follows IEEE 754 conventions
          - f: means finite values only (no infinities)
          - n: means nans are supported (non-standard encoding)
        for integer types the scheme is:
        `[u]int<size_bits>[b<bias>]`
          - if bias is not present it means its zero
        """
        if self.is_floating_point():
            ret = "float" + str(self.size_bits) + "_e" + str(
                self.exponent) + "m" + str(self.mantissa)

            if not self.is_ieee_754():
                if self._finite_values_only:
                    ret = ret + "f"
                if self.nan_repr != NanRepr.NONE:
                    ret = ret + "n"

            return ret
        else:
            ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
            if self.has_bias():
                ret = ret + "b" + str(self.bias)
            return ret

    def __repr__(self) -> str:
        return "ScalarType." + self.__str__()

    # __len__ needs to be defined (and has to throw TypeError) for pytorch's
    # opcheck to work.
    def __len__(self) -> int:
        raise TypeError

    #
    # Convenience Constructors
    #

    @classmethod
    def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
        "Create a signed integer scalar type (size_bits includes sign-bit)."
        ret = cls(0, size_bits - 1, True, bias if bias else 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
        """Create a unsigned integer scalar type."""
        ret = cls(0, size_bits, False, bias if bias else 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
        """
        Create a standard floating point type
        (i.e. follows IEEE 754 conventions).
        """
        assert (mantissa > 0 and exponent > 0)
        ret = cls(exponent, mantissa, True, 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
               nan_repr: NanRepr) -> 'ScalarType':
        """
        Create a non-standard floating point type
        (i.e. does not follow IEEE 754 conventions).
        """
        assert (mantissa > 0 and exponent > 0)
        assert (nan_repr != NanRepr.IEEE_754), (
            "use `float_IEEE754` constructor for floating point types that "
            "follow IEEE 754 conventions")
        ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
        ret.id  # noqa B018: make sure the id is cached
        return ret
295 |
+
|
296 |
+
|
297 |
+
# naming generally follows: https://github.com/jax-ml/ml_dtypes
|
298 |
+
# for floating point types (leading f) the scheme is:
|
299 |
+
# `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
|
300 |
+
# flags:
|
301 |
+
# - no-flags: means it follows IEEE 754 conventions
|
302 |
+
# - f: means finite values only (no infinities)
|
303 |
+
# - n: means nans are supported (non-standard encoding)
|
304 |
+
# for integer types the scheme is:
|
305 |
+
# `[u]int<size_bits>[b<bias>]`
|
306 |
+
# - if bias is not present it means its zero
|
307 |
+
|
308 |
+
|
309 |
+
class scalar_types:
    """Namespace of pre-constructed ScalarType instances for common types."""

    int4 = ScalarType.int_(4, None)
    uint4 = ScalarType.uint(4, None)
    int8 = ScalarType.int_(8, None)
    uint8 = ScalarType.uint(8, None)
    float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
    float8_e5m2 = ScalarType.float_IEEE754(5, 2)
    float16_e8m7 = ScalarType.float_IEEE754(8, 7)
    float16_e5m10 = ScalarType.float_IEEE754(5, 10)

    # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
    float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)

    # "gptq" types
    uint2b2 = ScalarType.uint(2, 2)
    uint3b4 = ScalarType.uint(3, 4)
    uint4b8 = ScalarType.uint(4, 8)
    uint8b128 = ScalarType.uint(8, 128)

    # colloquial names
    bfloat16 = float16_e8m7
    float16 = float16_e5m10
build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/__init__.py
ADDED
File without changes
|
build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils.py
ADDED
@@ -0,0 +1,391 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional, Tuple
|
2 |
+
|
3 |
+
import numpy
|
4 |
+
import torch
|
5 |
+
|
6 |
+
import quantization as ops
|
7 |
+
from quantization.scalar_type import ScalarType, scalar_types
|
8 |
+
|
9 |
+
from .quant_utils import pack_cols, unpack_cols
|
10 |
+
|
11 |
+
# Tiling / launch constraints for the dense GPTQ-Marlin kernels.
GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16

# Constraints for the Marlin "24" kernels (presumably 2:4 sparsity —
# TODO confirm against the kernel sources).
GPTQ_MARLIN_24_TILE = 16
GPTQ_MARLIN_24_MIN_THREAD_N = 128
GPTQ_MARLIN_24_MIN_THREAD_K = 128
GPTQ_MARLIN_24_MAX_PARALLEL = 64

GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]

# Constraints for the Marlin QQQ kernels.
MARLIN_QQQ_TILE = 16
MARLIN_QQQ_MIN_THREAD_N = 64
MARLIN_QQQ_MIN_THREAD_K = 128
MARLIN_QQQ_MAX_PARALLEL = 16

MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]
MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128]
MARLIN_QQQ_SUPPORTED_SYM = [True]

# Supported quantization group sizes (-1 is used for channelwise scales
# elsewhere in this module, see marlin_repeat_scales_on_all_ranks).
MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]

# In case there is a performance issue with Marlin, the variable below can be
# changed to False, which allows Marlin to perform global reductions in fp16
# precision (instead of fp32), and therefore, save on some memory movements.
USE_FP32_REDUCE_DEFAULT = True
39 |
+
|
40 |
+
|
41 |
+
# For binary size and compile time, we don't support the same types for with and
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
# TODO: we may want to move this into the C++ so its closer to the actual impl
def query_marlin_supported_quant_types(
    has_zp: bool, device_capability: Optional[int] = None
):
    """Return the list of quant types Marlin supports for this device.

    When `device_capability` is None it is queried from the current CUDA
    device (major * 10 + minor).
    """
    if device_capability is None:
        major, minor = torch.cuda.get_device_capability()
        device_capability = major * 10 + minor

    # Marlin kernels require compute capability >= 8.0.
    if device_capability < 80:
        return []

    if has_zp:
        # AWQ style, unsigned + runtime zero-point
        return [scalar_types.uint4, scalar_types.uint8]

    # GPTQ style, unsigned + symmetric bias
    # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able
    # to add `scalar_types.float8_e4m3fn` here
    return [scalar_types.uint4b8, scalar_types.uint8b128]
|
62 |
+
|
63 |
+
|
64 |
+
def _check_marlin_supported(
    quant_type: ScalarType,
    group_size: Optional[int],
    has_zp: bool,
    device_capability: Optional[int] = None,
) -> Tuple[bool, Optional[str]]:
    """Internal check; returns (supported, error_message_or_None)."""
    if device_capability is None:
        major, minor = torch.cuda.get_device_capability()
        device_capability = major * 10 + minor

    supported_types = query_marlin_supported_quant_types(has_zp, device_capability)

    # Quant type must be one of the kernel-compiled types.
    if quant_type not in supported_types:
        msg = (
            f"Marlin does not support weight_bits = {quant_type}. "
            f"Only types = {supported_types} "
            f"are supported (for group_size = {group_size}, "
            f"device_capability = {device_capability}, zp = {has_zp})."
        )
        return False, msg

    # Group size must be one of the kernel-supported values.
    if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
        msg = (
            f"Marlin does not support group_size = {group_size}. "
            f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
            "are supported."
        )
        return False, msg

    return True, None
|
94 |
+
|
95 |
+
|
96 |
+
def check_marlin_supported(
    quant_type: ScalarType,
    group_size: int,
    has_zp: bool = False,
    device_capability: Optional[int] = None,
) -> bool:
    """Boolean wrapper around _check_marlin_supported (message discarded)."""
    supported, _ = _check_marlin_supported(
        quant_type, group_size, has_zp, device_capability
    )
    return supported
|
104 |
+
|
105 |
+
|
106 |
+
def verify_marlin_supported(
    quant_type: ScalarType, group_size: int, has_zp: bool = False
) -> None:
    """Raise ValueError when the quantization config is not Marlin-compatible."""
    ok, err_msg = _check_marlin_supported(quant_type, group_size, has_zp)
    if ok:
        return
    assert err_msg is not None
    raise ValueError(err_msg)
|
113 |
+
|
114 |
+
|
115 |
+
def verify_marlin_supports_shape(
    output_size_per_partition: int,
    input_size_per_partition: int,
    input_size: int,
    group_size: int,
) -> None:
    """Raise ValueError if the (possibly TP-sharded) GEMM shape cannot be
    handled by the Marlin kernels.

    Checks N divisibility by the minimum thread-N tile, K divisibility by the
    minimum thread-K tile, and — for grouped quantization — that the sharded
    K dim is a whole number of groups.
    """
    # Validate output_size_per_partition
    if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
        # Fixed message: removed stray double space before "min_thread_n".
        raise ValueError(
            f"Weight output_size_per_partition = "
            f"{output_size_per_partition} is not divisible by "
            f"min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
            "Consider reducing tensor_parallel_size or running "
            "with --quantization gptq."
        )

    # Validate input_size_per_partition
    if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
        raise ValueError(
            f"Weight input_size_per_partition = "
            f"{input_size_per_partition} is not divisible "
            f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
            "Consider reducing tensor_parallel_size or running "
            "with --quantization gptq."
        )

    # Only grouped (group_size < input_size) quantization constrains K.
    if group_size < input_size and input_size_per_partition % group_size != 0:
        # Fixed message: added the missing space before "Consider".
        raise ValueError(
            f"Weight input_size_per_partition = {input_size_per_partition}"
            f" is not divisible by group_size = {group_size}. "
            "Consider reducing tensor_parallel_size or running "
            "with --quantization gptq."
        )
|
149 |
+
|
150 |
+
|
151 |
+
def check_marlin_supports_shape(
    output_size_per_partition: int,
    input_size_per_partition: int,
    input_size: int,
    group_size: int,
) -> Tuple[bool, Optional[str]]:
    """Non-raising wrapper: returns (ok, error_message_or_None)."""
    try:
        verify_marlin_supports_shape(
            output_size_per_partition, input_size_per_partition, input_size, group_size
        )
    except ValueError as err:
        return False, str(err)
    return True, None
|
164 |
+
|
165 |
+
|
166 |
+
def marlin_make_workspace(
    output_size_per_partition: int, device: torch.device
) -> torch.Tensor:
    """Allocate the zero-initialized int32 workspace used by Marlin kernels."""
    n_tiles = output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N
    max_workspace_size = n_tiles * GPTQ_MARLIN_MAX_PARALLEL

    return torch.zeros(
        max_workspace_size, dtype=torch.int, device=device, requires_grad=False
    )
|
176 |
+
|
177 |
+
|
178 |
+
def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
    """K is partial only when act_order is combined with row-parallel sharding."""
    return not (act_order and is_row_parallel)
|
180 |
+
|
181 |
+
|
182 |
+
def marlin_repeat_scales_on_all_ranks(
    act_order: bool, group_size: int, is_row_parallel: bool
) -> bool:
    """Whether scales must be replicated on every TP rank.

    Need to repeat scales on every rank if act_ordering or
    channelwise (group_size == -1) combined with RowParallelLinear.
    """
    if act_order:
        return True
    channelwise = group_size == -1
    return channelwise and is_row_parallel
|
189 |
+
|
190 |
+
|
191 |
+
def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
    """Placeholder (zero-element) g_idx for layers without activation reordering."""
    empty = torch.empty(0, dtype=torch.int, device=device)
    return torch.nn.Parameter(empty, requires_grad=False)
|
195 |
+
|
196 |
+
|
197 |
+
def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
    """Placeholder (zero-element) zero-point tensor for symmetric layers."""
    empty = torch.empty(0, dtype=torch.int, device=device)
    return torch.nn.Parameter(empty, requires_grad=False)
|
201 |
+
|
202 |
+
|
203 |
+
def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sort g_idx ascending; also return the int32 permutation that sorts it."""
    sort_indices = torch.argsort(g_idx).to(torch.int)
    return g_idx[sort_indices], sort_indices
|
206 |
+
|
207 |
+
|
208 |
+
def get_scale_perms():
    """Return (scale_perm, scale_perm_single): column permutations applied to
    grouped and channelwise scales respectively before the Marlin kernels."""
    scale_perm: List[int] = [i + 8 * j for i in range(8) for j in range(8)]
    scale_perm_single: List[int] = [
        2 * i + j for i in range(4) for j in [0, 1, 8, 9, 16, 17, 24, 25]
    ]
    return scale_perm, scale_perm_single
|
216 |
+
|
217 |
+
|
218 |
+
def marlin_permute_scales(
    s: torch.Tensor, size_k: int, size_n: int, group_size: int
) -> torch.Tensor:
    """Permute scale columns into the layout expected by the Marlin kernels.

    Grouped scales (group_size in (0, size_k)) use the full permutation;
    channelwise/full-K scales use the "single" permutation.
    """
    scale_perm, scale_perm_single = get_scale_perms()
    grouped = group_size != -1 and group_size < size_k
    perm = scale_perm if grouped else scale_perm_single

    permuted = s.reshape((-1, len(perm)))[:, perm]
    return permuted.reshape((-1, size_n)).contiguous()
|
230 |
+
|
231 |
+
|
232 |
+
def marlin_moe_permute_scales(
    s: torch.Tensor,
    size_k: int,
    size_n: int,
    group_size: int,
):
    """Apply marlin_permute_scales independently to each expert's scales.

    `s` is indexed as (expert, ...); the result has the same shape/dtype.
    """
    num_experts = s.shape[0]
    output = torch.empty(
        (num_experts, s.shape[1], s.shape[2]),
        device=s.device,
        dtype=s.dtype,
    )

    for expert_idx in range(num_experts):
        output[expert_idx] = marlin_permute_scales(
            s[expert_idx], size_k, size_n, group_size
        )
    return output
|
248 |
+
|
249 |
+
|
250 |
+
def marlin_zero_points(
    zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    """Permute zero-points into Marlin layout and pack them to int32 columns.

    Permute zero-points in a similar way to scales, but do not use the
    "single" permutation, since zero-points are applied on every MMA.
    """
    scale_perm, _ = get_scale_perms()
    permuted = zp.reshape((-1, len(scale_perm)))[:, scale_perm]

    # Interleave column dim (for the dequantize code) and pack it to int32
    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    permuted = permuted.reshape((-1, len(interleave)))[:, interleave].ravel()
    permuted = permuted.reshape((-1, size_n)).contiguous()
    return pack_cols(permuted, num_bits, size_k, size_n)
|
271 |
+
|
272 |
+
|
273 |
+
def awq_to_marlin_zero_points(
    q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    """Convert AWQ-packed zero-points into Marlin's zero-point layout.

    AWQ zero-points are quantized and packed on the column dim, and the
    values are permuted based on the dequantizer. Here we undo both of
    these, and then apply the marlin permutation and pack it back.
    """
    q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)

    # Undo interleaving (argsort of the forward perm yields the inverse perm)
    if num_bits == 4:
        forward_interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        forward_interleave = numpy.array([0, 2, 1, 3])
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
    undo_interleave = numpy.argsort(forward_interleave)

    q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
    q_zp = q_zp.reshape((-1, size_n)).contiguous()

    return marlin_zero_points(q_zp, size_k, size_n, num_bits)
|
295 |
+
|
296 |
+
|
297 |
+
def moe_awq_to_marlin_zero_points(
    q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
):
    """Per-expert wrapper around awq_to_marlin_zero_points."""
    num_experts = q_zp_packed.shape[0]
    output = torch.empty(
        (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]),
        device=q_zp_packed.device,
        dtype=q_zp_packed.dtype,
    )
    for expert_idx in range(num_experts):
        output[expert_idx] = awq_to_marlin_zero_points(
            q_zp_packed[expert_idx], size_k, size_n, num_bits
        )
    return output
|
309 |
+
|
310 |
+
|
311 |
+
def apply_gptq_marlin_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    weight_zp: torch.Tensor,
    g_idx: torch.Tensor,
    g_idx_sort_indices: torch.Tensor,
    workspace: torch.Tensor,
    wtype: ScalarType,
    output_size_per_partition: int,
    input_size_per_partition: int,
    is_k_full: bool,
    bias: Optional[torch.Tensor] = None,
    use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
) -> torch.Tensor:
    """Run a GPTQ-Marlin GEMM on `input` and restore its leading batch dims.

    Flattens all leading dims into one M dim, calls the fused kernel,
    optionally adds `bias` in place, then reshapes the result back.
    """
    # The kernel expects a 2-D (M, K) activation.
    flat_x = input.reshape(-1, input.shape[-1])
    result_shape = input.shape[:-1] + (output_size_per_partition,)

    result = ops.gptq_marlin_gemm(
        flat_x,
        weight,
        weight_scale,
        weight_zp,
        g_idx,
        g_idx_sort_indices,
        workspace,
        wtype,
        size_m=flat_x.shape[0],
        size_n=output_size_per_partition,
        size_k=input_size_per_partition,
        is_k_full=is_k_full,
        has_zp=False,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=False,
    )

    if bias is not None:
        result.add_(bias)  # In-place add

    return result.reshape(result_shape)
|
351 |
+
|
352 |
+
|
353 |
+
def apply_awq_marlin_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    weight_zp: torch.Tensor,
    g_idx: torch.Tensor,
    g_idx_sort_indices: torch.Tensor,
    workspace: torch.Tensor,
    quant_type: ScalarType,
    output_size_per_partition: int,
    input_size_per_partition: int,
    bias: Optional[torch.Tensor] = None,
    use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
) -> torch.Tensor:
    """Run an AWQ-Marlin GEMM (runtime zero-points) on `input`.

    Same batch-flattening contract as apply_gptq_marlin_linear, but always
    uses has_zp=True and is_k_full=True.
    """
    flat_x = input.reshape(-1, input.shape[-1])
    result_shape = input.shape[:-1] + (output_size_per_partition,)

    result = ops.gptq_marlin_gemm(
        flat_x,
        weight,
        weight_scale,
        weight_zp,
        g_idx,
        g_idx_sort_indices,
        workspace,
        quant_type,
        size_m=flat_x.shape[0],
        size_n=output_size_per_partition,
        size_k=input_size_per_partition,
        is_k_full=True,
        has_zp=True,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=False,
    )

    if bias is not None:
        result.add_(bias)  # In-place add

    return result.reshape(result_shape)
|
build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_fp8.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
import quantization as ops
|
6 |
+
|
7 |
+
from .marlin_utils import marlin_make_workspace, marlin_permute_scales
|
8 |
+
|
9 |
+
|
10 |
+
def is_fp8_marlin_supported():
    """True when the current CUDA device has compute capability >= 8.0."""
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor >= 80
|
14 |
+
|
15 |
+
|
16 |
+
def apply_fp8_marlin_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    workspace: torch.Tensor,
    size_n: int,
    size_k: int,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    """FP8 weight-only linear via the Marlin kernel.

    For GPUs that lack FP8 hardware support, we can leverage the
    Marlin kernel for fast weight-only FP8 quantization.
    """
    flat_x = input.reshape(-1, input.shape[-1])
    result_shape = input.shape[:-1] + (size_n,)

    result = ops.fp8_marlin_gemm(
        a=flat_x,
        b_q_weight=weight,
        b_scales=weight_scale,
        workspace=workspace,
        num_bits=8,
        size_m=flat_x.shape[0],
        size_n=size_n,
        size_k=size_k,
    )

    if bias is not None:
        result.add_(bias)  # In-place add

    return result.reshape(result_shape)
|
46 |
+
|
47 |
+
|
48 |
+
def prepare_fp8_layer_for_marlin(
    layer: torch.nn.Module, strategy: str = "tensor"
) -> None:
    """Convert an FP8 linear layer's weight/scales to Marlin format in place.

    NOTE(review): `strategy` is not referenced in this body; scales are
    always permuted with group_size=-1 — confirm intent against callers.
    """
    part_size_n = layer.output_size_per_partition
    part_size_k = layer.input_size_per_partition
    device = layer.weight.device

    # WORKSPACE
    layer.workspace = marlin_make_workspace(part_size_n, device)

    # WEIGHT
    # Repack weights to marlin format
    packed_weight = pack_fp8_to_int32(layer.weight)
    marlin_qweight = ops.gptq_marlin_repack(
        b_q_weight=packed_weight,
        perm=torch.empty(0, dtype=torch.int, device=device),
        size_k=part_size_k,
        size_n=part_size_n,
        num_bits=8,
    )
    layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)

    # WEIGHT SCALES
    # Cast back to the layer's original dtype, then permute for Marlin.
    scales = layer.weight_scale.to(layer.orig_dtype)
    marlin_scales = marlin_permute_scales(
        s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1
    )
    layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
|
77 |
+
|
78 |
+
|
79 |
+
def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
|
80 |
+
"""
|
81 |
+
Repack FP8 weights to gptq format (packed int32 elements)
|
82 |
+
"""
|
83 |
+
assert fp8_tensor.dtype == torch.float8_e4m3fn
|
84 |
+
assert fp8_tensor.shape[0] % 4 == 0
|
85 |
+
|
86 |
+
# Reshape to prepare for packing
|
87 |
+
reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
|
88 |
+
|
89 |
+
# Convert fp8 to uint8 (byte) representation
|
90 |
+
byte_tensor = reshaped.view(torch.uint8)
|
91 |
+
|
92 |
+
# Pack 4 uint8 values into one int32
|
93 |
+
packed = (
|
94 |
+
byte_tensor[:, 0].to(torch.int32)
|
95 |
+
| (byte_tensor[:, 1].to(torch.int32) << 8)
|
96 |
+
| (byte_tensor[:, 2].to(torch.int32) << 16)
|
97 |
+
| (byte_tensor[:, 3].to(torch.int32) << 24)
|
98 |
+
)
|
99 |
+
|
100 |
+
return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous()
|
build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Utility functions used for tests and benchmarks"""
|
2 |
+
|
3 |
+
from typing import List, Optional
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from quantization.scalar_type import ScalarType
|
9 |
+
|
10 |
+
from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, marlin_zero_points
|
11 |
+
from .quant_utils import (
|
12 |
+
get_pack_factor,
|
13 |
+
gptq_quantize_weights,
|
14 |
+
quantize_weights,
|
15 |
+
sort_weights,
|
16 |
+
)
|
17 |
+
|
18 |
+
|
19 |
+
class MarlinWorkspace:
    """Scratch buffer for Marlin kernels, sized from the output dimension."""

    def __init__(self, out_features, min_thread_n, max_parallel):
        assert (
            out_features % min_thread_n == 0
        ), "out_features = {} is undivisible by min_thread_n = {}".format(
            out_features, min_thread_n
        )

        num_n_tiles = out_features // min_thread_n
        self.scratch = torch.zeros(
            num_n_tiles * max_parallel, dtype=torch.int, device="cuda"
        )
|
31 |
+
|
32 |
+
|
33 |
+
def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
    """Rearrange a (size_k, size_n) quantized weight matrix into marlin's
    tiled layout, then apply *perm* within each row of tiles.

    Returns a tensor of shape (size_k // tile, size_n * tile).
    """
    assert q_w.shape == (size_k, size_n)
    assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
    # Fix: the original assert message labeled size_n as "size_k".
    assert size_n % tile == 0, f"size_n = {size_n}, tile = {tile}"

    # Permute weights to 16x64 marlin tiles
    q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
    q_w = q_w.permute((0, 2, 1, 3))
    q_w = q_w.reshape((size_k // tile, size_n * tile))

    # Apply the per-row permutation (perm indexes within perm.numel()-wide
    # chunks of each tile row).
    q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)

    return q_w
|
46 |
+
|
47 |
+
|
48 |
+
def marlin_weights(q_w, size_k, size_n, num_bits, perm):
    """Permute *q_w* into marlin tile order and pack ``pack_factor`` values
    into each int32, returning the packed tensor on the original device."""
    # Permute first, then pack.
    q_w = marlin_permute_weights(q_w, size_k, size_n, perm)

    pack_factor = get_pack_factor(num_bits)
    device = q_w.device

    w_np = q_w.cpu().numpy().astype(np.uint32)

    packed = np.zeros((w_np.shape[0], w_np.shape[1] // pack_factor), dtype=np.uint32)
    for shift_idx in range(pack_factor):
        packed |= w_np[:, shift_idx::pack_factor] << num_bits * shift_idx

    return torch.from_numpy(packed.astype(np.int32)).to(device)
|
65 |
+
|
66 |
+
|
67 |
+
def get_weight_perm(num_bits: int):
    """Build the marlin weight permutation (a permutation of 0..1023) for
    the given quantized bit-width.

    Raises:
        ValueError: if num_bits is not 4 or 8.
    """
    perm_list: List[int] = []
    for i in range(32):
        perm1: List[int] = []
        col = i // 4
        for block in [0, 1]:
            for row in [
                2 * (i % 4),
                2 * (i % 4) + 1,
                2 * (i % 4 + 4),
                2 * (i % 4 + 4) + 1,
            ]:
                perm1.append(16 * row + col + 8 * block)
        for j in range(4):
            perm_list.extend([p + 256 * j for p in perm1])

    perm = np.array(perm_list)

    if num_bits == 4:
        interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = np.array([0, 2, 1, 3])
    else:
        # ValueError (not bare Exception) for consistency with
        # get_weight_perm_24 in marlin_utils_test_24.py.
        raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))

    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    perm = torch.from_numpy(perm)
    return perm
|
95 |
+
|
96 |
+
|
97 |
+
def marlin_quantize(
    w: torch.Tensor,
    quant_type: ScalarType,
    group_size: int,
    act_order: bool,
    test_perm: Optional[torch.Tensor] = None,
):
    """GPTQ-quantize *w* (size_k, size_n) and repack it into marlin format.

    Returns [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm],
    all moved to w.device: dequantized reference weights, marlin-packed
    quantized weights, permuted scales, group indices, act-order sort
    indices (empty tensor unless act_order), and the random permutation
    used by gptq_quantize_weights.
    """
    size_k, size_n = w.shape
    num_bits = quant_type.size_bits

    # Normalize group_size (-1 means a single group spanning all of k)
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Quantize (and apply act_order if provided)
    w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
        w, quant_type, group_size, act_order, test_perm
    )

    # For act_order, sort the "weights" and "g_idx" so that group ids are
    # increasing
    sort_indices = torch.empty(0, dtype=torch.int, device=w.device)
    if act_order:
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)

    # Reformat to marlin
    weight_perm = get_weight_perm(num_bits)
    marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)

    # Create result
    res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
    for i in range(len(res_list)):
        res_list[i] = res_list[i].to(w.device)

    return res_list
|
134 |
+
|
135 |
+
|
136 |
+
def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int):
    """AWQ-style (zero-point) quantize *w* (size_k, size_n) and repack for marlin.

    Returns [w_ref, marlin_q_w, marlin_s, marlin_zp], all moved to w.device:
    dequantized reference weights, marlin-packed quantized weights, permuted
    scales, and marlin-packed zero points.
    """
    size_k, size_n = w.shape

    # Normalize group_size (-1 means a single group spanning all of k)
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Detect num groups
    assert size_k % group_size == 0
    num_groups = size_k // group_size

    # Quantize with zp
    w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True)

    # Reformat to marlin
    weight_perm = get_weight_perm(quant_type.size_bits)
    marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm)
    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
    marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits)

    # Create result
    res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp]
    for i in range(len(res_list)):
        res_list[i] = res_list[i].to(w.device)

    return res_list
|
build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test_24.py
ADDED
@@ -0,0 +1,473 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Utility functions used for tests and benchmarks"""
|
2 |
+
|
3 |
+
import random
|
4 |
+
from typing import List
|
5 |
+
|
6 |
+
import numpy
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from quantization.scalar_type import ScalarType
|
10 |
+
|
11 |
+
from .marlin_utils_test import marlin_weights
|
12 |
+
from .quant_utils import gptq_quantize_weights
|
13 |
+
|
14 |
+
|
15 |
+
# This is PyTorch implementation of main part of reorder_meta()
|
16 |
+
# function, from tools/util/include/cutlass/util/host_reorder.h file
|
17 |
+
# of CUTLASS source tree. Furthermore, CUTLASS template for sparse
|
18 |
+
# GEMM decides upon layout of this matrix, and at the moment for the
|
19 |
+
# sparse GEMM executed on tensor cores, this is layout described by
|
20 |
+
# ColumnMajorInterleaved<2> data structure, in
|
21 |
+
# include/cutlass/layout/matrix.h of CUTLASS source tree. The
|
22 |
+
# reordering of meta matrix into meta_reordered matrix calculated
|
23 |
+
# according to these segments of CUTLASS code is re-implemented here.
|
24 |
+
# Note that this calculation produces offsets for scattering metadata
|
25 |
+
# matrix elements into reordered metadata matrix elements (or,
|
26 |
+
# equivalently, for gathering reordered metadata matrix element back
|
27 |
+
# into metadata matrix elements).
|
28 |
+
def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device):
    """Compute flat scatter offsets mapping each (row, col) element of an
    (m, meta_ncols) metadata matrix to its position in the CUTLASS
    ColumnMajorInterleaved<2> reordered layout (re-implemented from
    CUTLASS host_reorder.h; see the module comment above this function).

    Returns a 1-D int tensor of m * meta_ncols offsets on *device*.
    """
    dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)
    dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)

    # Reorder the rows, then swizzle the 2x2 blocks.
    group_x = 64
    # 2-byte meta elements (int16) use 32-row groups, 4-byte (int32) use 16.
    group_y = 32 if meta_dtype.itemsize == 2 else 16

    dst_rows = (
        dst_rows // group_x * group_x
        + (dst_rows % 2) * 2
        + (dst_rows % 8) // 4
        + ((dst_rows % group_y) % 4) // 2 * 32
        + ((dst_rows % group_x) // 8) * 4
    )

    # Swap the two off-diagonal elements of every 2x2 block.
    topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8)
    bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8)
    dst_rows += topright - bottomleft
    dst_cols -= topright - bottomleft

    # Assumed that meta tensor is to be stored in CUTLASS
    # InterleavedColumnMajor layout, and reverse engineered
    # corresponding code to store values into this tensor.
    interleave = 2
    cols_maj = dst_cols // interleave
    cols_min = dst_cols % interleave
    return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1)
|
56 |
+
|
57 |
+
|
58 |
+
# This function converts dense matrix into sparse semi-structured
|
59 |
+
# representation, producing "compressed" matrix, in the layout used by
|
60 |
+
# CUTLASS backend, and corresponding metadata matrix.
|
61 |
+
def sparse_semi_structured_from_dense_cutlass(dense):
    """Convert a dense 2:4-sparse (m, k) matrix into CUTLASS semi-structured
    form.

    Returns (sparse, meta_reordered): the (m, k // 2) compressed values and
    the (m, meta_ncols) metadata matrix, already reordered into the layout
    the CUTLASS sparse-GEMM backend expects.

    Raises RuntimeError for unsupported dtypes or shapes not meeting the
    divisibility requirements below.
    """
    if dense.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor"  # noqa: E501
        )

    m, k = dense.shape
    device = dense.device

    # Metadata element width depends on the value dtype (per CUTLASS).
    meta_dtype = torch.int8
    if dense.dtype == torch.int8:
        meta_dtype = torch.int32
    elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]:
        meta_dtype = torch.int16
    else:
        raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
    # Each 4-bit "quad" encodes one group of 4 logical elements.
    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
    if quadbits_per_meta_elem not in (4, 8):
        raise RuntimeError("Invalid number of elements per meta element calculated")

    if meta_dtype == torch.int32:
        if m % 16 != 0:
            raise RuntimeError(
                f"Number of rows of dense matrix {m} must be divisible by 16"
            )
    else:
        if m % 32 != 0:
            raise RuntimeError(
                f"Number of rows of dense matrix {m} must be divisible by 32"
            )
    if k % (4 * quadbits_per_meta_elem) != 0:
        raise RuntimeError(
            f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}"  # noqa: E501
        )

    # float32 values are treated as pairs of half-width slots (ksparse=2).
    if dense.dtype != torch.float:
        ksparse = 4
        dense_4 = dense.view(-1, k // ksparse, ksparse)
        m0, m1, m2, m3 = (dense_4 != 0).unbind(-1)
    else:
        ksparse = 2
        dense_2 = dense.view(-1, k // ksparse, ksparse)
        m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1)
    meta_ncols = k // (ksparse * quadbits_per_meta_elem)

    # Encoding quadruples of True/False values as follows:
    #     [True,  True,  False, False] -> 0b0100
    #     [True,  False, True,  False] -> 0b1000
    #     [False, True,  True,  False] -> 0b1001
    #     [True,  False, False, True ] -> 0b1100
    #     [False, True,  False, True ] -> 0b1101
    #     [False, False, True,  True ] -> 0b1110
    # Thus, lower two bits in the encoding are index of the True value
    # at the lowest index in the quadruple, and the higher two bits in
    # the encoding are index of the other True value in the quadruple.
    # In case there are less than two True values, than False value or
    # values at some index or indices are considered True for the
    # encoding. In case there are more than two True values, then the
    # excess True value(s) at some indices are considered False for
    # the encoding. The exact encodings used for these cases are as
    # follows:
    #     [False, False, False, False] -> 0b1110
    #     [False, False, False, True ] -> 0b1110
    #     [False, False, True,  False] -> 0b1110
    #     [False, True,  False, False] -> 0b1001
    #     [False, True,  True,  True ] -> 0b1101
    #     [True,  False, False, False] -> 0b1000
    #     [True,  False, True,  True ] -> 0b1100
    #     [True,  True,  False, True ] -> 0b0100
    #     [True,  True,  True,  False] -> 0b0100
    #     [True,  True,  True,  True ] -> 0b0100
    # These particular encodings are chosen, with the help of Espresso
    # logic minimizer software, for the purpose of minimization of
    # corresponding Boolean functions, that translate non-zero flags
    # into encoding bits. Note also possible choices for the first
    # and last of these encodings were limited only to (0b0100,
    # 0b1110), in order to produce valid encodings for 1:2 sparsity
    # case.

    expr0 = m0 & m1
    expr1 = ~m0 & m1
    expr2 = ~m0 & ~m1
    bit0 = expr1
    bit1 = expr2
    bit2 = expr0 | expr2 | m3
    bit3 = expr1 | ~m1
    # idxs0/idxs1: 2-bit indices of the two kept elements in each quad.
    idxs0 = bit0 | (bit1.to(torch.int64) << 1)
    idxs1 = bit2 | (bit3.to(torch.int64) << 1)

    # Gather the two surviving values of each quad into the compressed tensor.
    if dense.dtype != torch.float:
        sparse0 = dense_4.gather(
            -1, idxs0.unsqueeze(-1)
        )  # type: ignore[possibly-undefined]
        sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
        sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
    else:
        sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(
            m, k // 2
        )  # type: ignore[possibly-undefined]

    # Pack the per-quad 4-bit codes into full metadata elements.
    meta_4 = idxs0 | (idxs1 << 2)
    meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)

    if quadbits_per_meta_elem == 4:
        meta = (
            meta_n[:, :, 0]
            | (meta_n[:, :, 1] << 4)
            | (meta_n[:, :, 2] << 8)
            | (meta_n[:, :, 3] << 12)
        )
    elif quadbits_per_meta_elem == 8:
        meta = (
            meta_n[:, :, 0]
            | (meta_n[:, :, 1] << 4)
            | (meta_n[:, :, 2] << 8)
            | (meta_n[:, :, 3] << 12)
            | (meta_n[:, :, 4] << 16)
            | (meta_n[:, :, 5] << 20)
            | (meta_n[:, :, 6] << 24)
            | (meta_n[:, :, 7] << 28)
        )

    # Reorder meta tensor elements.
    meta_reordered = meta.new_empty(
        (m * meta_ncols,)
    )  # type: ignore[possibly-undefined]
    meta_offsets = _calculate_meta_reordering_scatter_offsets(
        m, meta_ncols, meta_dtype, device
    )
    meta_reordered.scatter_(0, meta_offsets, meta.view(-1))

    return (sparse, meta_reordered.view(m, meta_ncols))
|
193 |
+
|
194 |
+
|
195 |
+
# This function performs reverse of the function above - it
|
196 |
+
# reconstructs dense matrix from a pair of "compressed" matrix, given
|
197 |
+
# in the layout used by CUTLASS backend, and accompanying metadata
|
198 |
+
# matrix.
|
199 |
+
def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
    """Inverse of sparse_semi_structured_from_dense_cutlass: reconstruct the
    dense (m, 2 * k) matrix from the (m, k) compressed values and the
    CUTLASS-reordered metadata matrix.

    Raises RuntimeError on dimension, device, dtype, or shape mismatches.
    """
    if sparse.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor"  # noqa: E501
        )

    m, k = sparse.shape
    device = sparse.device

    if meta_reordered.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor"  # noqa: E501
        )
    if meta_reordered.device != device:
        raise RuntimeError(
            f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device"  # noqa: E501
        )

    meta_dtype = meta_reordered.dtype
    if meta_dtype not in (torch.int16, torch.int32):
        raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix")
    # Each 4-bit "quad" of a meta element encodes one group of 4 elements.
    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4

    # float32 is treated as pairs of half-width slots (same as in the
    # compression direction).
    ksparse = 4 if sparse.dtype != torch.float else 2

    meta_nrows, meta_ncols = meta_reordered.shape
    if meta_nrows != m:
        raise RuntimeError(
            f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}"  # noqa: E501
        )
    if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k:
        raise RuntimeError(
            f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, "  # noqa: E501
            "expected according to the number of columns of meta matrix"
        )

    # Undo meta tensor elements reordering.
    meta_offsets = _calculate_meta_reordering_scatter_offsets(
        m, meta_ncols, meta_dtype, device
    )
    meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols)

    # Unpack sparse tensor back to original dense tensor, using
    # information provided by meta tensor. Note that torch.float
    # datatype is handled pretty much the same as
    # torch.half/torch.bfloat16, as metadata for a pair of torch.float
    # value is encoded as if underlying 8 bytes contain four
    # torch.half/torch.bfloat16 values, where either first two or last
    # two are zeros.
    meta_2 = torch.empty(
        (m, meta_ncols, 2 * quadbits_per_meta_elem),
        dtype=meta_dtype,
        device=device,
    )
    # Split each meta element back into its 2-bit position indices.
    if quadbits_per_meta_elem == 4:
        meta_2[:, :, 0] = meta & 0b11
        meta_2[:, :, 1] = (meta >> 2) & 0b11
        meta_2[:, :, 2] = (meta >> 4) & 0b11
        meta_2[:, :, 3] = (meta >> 6) & 0b11
        meta_2[:, :, 4] = (meta >> 8) & 0b11
        meta_2[:, :, 5] = (meta >> 10) & 0b11
        meta_2[:, :, 6] = (meta >> 12) & 0b11
        meta_2[:, :, 7] = (meta >> 14) & 0b11
    elif quadbits_per_meta_elem == 8:
        meta_2[:, :, 0] = meta & 0b11
        meta_2[:, :, 1] = (meta >> 2) & 0b11
        meta_2[:, :, 2] = (meta >> 4) & 0b11
        meta_2[:, :, 3] = (meta >> 6) & 0b11
        meta_2[:, :, 4] = (meta >> 8) & 0b11
        meta_2[:, :, 5] = (meta >> 10) & 0b11
        meta_2[:, :, 6] = (meta >> 12) & 0b11
        meta_2[:, :, 7] = (meta >> 14) & 0b11
        meta_2[:, :, 8] = (meta >> 16) & 0b11
        meta_2[:, :, 9] = (meta >> 18) & 0b11
        meta_2[:, :, 10] = (meta >> 20) & 0b11
        meta_2[:, :, 11] = (meta >> 22) & 0b11
        meta_2[:, :, 12] = (meta >> 24) & 0b11
        meta_2[:, :, 13] = (meta >> 26) & 0b11
        meta_2[:, :, 14] = (meta >> 28) & 0b11
        meta_2[:, :, 15] = (meta >> 30) & 0b11

    # Flat destination index of every kept value inside the dense output.
    dense_offsets = meta_2.view(-1) + (
        torch.arange(0, 2 * m * k // ksparse, device=device) * 4
    ).view(-1, 1).repeat(1, 2).view(-1)

    dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device)
    if sparse.dtype != torch.float:
        # dense.scatter_(0, dense_offsets, sparse.view(-1))
        dense.scatter_(0, dense_offsets, sparse.reshape(-1))
    else:
        # Scatter float values as pairs of half-width elements.
        dense.view(torch.half).scatter_(
            0, dense_offsets, sparse.view(torch.half).view(-1)
        )

    return dense.view(m, 2 * k)
|
294 |
+
|
295 |
+
|
296 |
+
def mask_creator(tensor):
    """Create an N:M (2:4) sparsity mask for *tensor*.

    For every consecutive group of M = 4 weights (taken over the flattened
    tensor), the N = 2 largest-magnitude entries are kept (mask value 1.0)
    and the remaining M - N are zeroed (mask value 0.0).

    Args:
        tensor: weight tensor whose element count is divisible by 4.

    Returns:
        A float mask of the same shape as *tensor*.

    Raises:
        ValueError: if tensor.numel() is not divisible by 4.
    """
    N = 2  # weights kept per group
    M = 4  # group size

    if tensor.numel() % M != 0:
        raise ValueError(
            f"Tensor of size {tensor.shape} can't be evenly divided into " f"{M} groups"
        )

    num_groups = tensor.numel() // M

    # Rank weights by magnitude within each group and zero the M - N smallest.
    tensor_temp = tensor.detach().abs().reshape(num_groups, M)
    index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)]

    w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device)
    return w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape)
|
326 |
+
|
327 |
+
|
328 |
+
def inject_24(w, size_k, size_n):
    """Zero out entries of *w* (size_k, size_n) so that, along the mask's
    grouping, a 2:4 sparsity pattern holds. Returns (pruned_w, mask), both
    contiguous; the mask is a bool CUDA tensor."""
    assert w.shape == (size_k, size_n)

    sparsity_mask = mask_creator(w.t()).t().cuda().bool()
    pruned = (sparsity_mask * w).contiguous()

    return pruned, sparsity_mask.contiguous()
|
334 |
+
|
335 |
+
|
336 |
+
def check_24(w, num_rows_to_sample=50, _verbose=False):
    """Spot-check that *w* follows 2:4 structured sparsity.

    Samples ``num_rows_to_sample`` rows of w.t() (with replacement) and
    counts 4-wide blocks containing more than two non-zeros, printing a
    summary. Purely diagnostic; returns None.
    """
    BLOCK_SIZE = 4
    MAX_NON_ZEROS = 2

    w = w.t().contiguous()

    print("check_24: w.shape = {}".format(w.shape))

    num_rows, num_cols = w.shape
    sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample)
    if _verbose:
        print(f"Sampled row idxs = {sampled_row_idxs}")

    total_segments = 0
    non_24_segments = 0
    for i in sampled_row_idxs:
        # Fix: the original bound (num_cols - BLOCK_SIZE) skipped the last
        # 4-wide block of every row; "+ 1" makes the final block inclusive.
        for j in range(0, num_cols - BLOCK_SIZE + 1, BLOCK_SIZE):
            total_segments += 1
            block = w[i, j : j + BLOCK_SIZE]
            num_nonzero = torch.count_nonzero(block)
            if num_nonzero > MAX_NON_ZEROS:
                print("i = {} j = {} block = {}".format(i, j, block))
                non_24_segments += 1

    print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.")
|
361 |
+
|
362 |
+
|
363 |
+
def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType):
    """Compress a quantized (size_k, size_n) matrix with 2:4 sparsity.

    Returns (q_24_comp, meta): the (size_k // 2, size_n) compressed weights
    with the quantization bias restored, and the CUTLASS-reordered metadata
    tensor reshaped in place to its kernel-facing shape.
    """
    assert q_24.shape == (size_k, size_n)

    # Remove bias to normalize over 0
    q_24_no_zp = q_24 - wtype.bias

    # Compress (the CUTLASS helper compresses along the last dim, so work on
    # the transpose and transpose back afterwards).
    q_24_no_zp = q_24_no_zp.t().contiguous()
    q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(q_24_no_zp)
    q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous()

    # Restore bias
    q_24_comp = q_24_no_zp_comp + wtype.bias

    # Resize meta to its actual shape (without moving any data)
    meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)

    return q_24_comp, meta
|
381 |
+
|
382 |
+
|
383 |
+
def get_scale_perms_24():
    """Return (scale_perm, scale_perm_single): the grouped and per-channel
    scale permutations (each a permutation of 0..63) used by marlin 2:4."""
    scale_perm: List[int] = [
        8 * base + off for base in range(8) for off in (0, 4, 1, 5, 2, 6, 3, 7)
    ]
    scale_perm_single: List[int] = [
        8 * base + off for base in range(8) for off in range(8)
    ]
    return scale_perm, scale_perm_single
|
391 |
+
|
392 |
+
|
393 |
+
def get_weight_perm_24(num_bits: int):
    """Build the marlin 2:4 weight permutation (a permutation of 0..1023)
    for the given quantized bit-width.

    Raises:
        ValueError: if num_bits is not 4 or 8.
    """
    flat_perm: List[int] = []
    for lane in range(32):
        col = lane // 4
        col_pair = col // 2
        lane_entries = [
            16 * row + col_pair * 256 + 8 * (col % 2) + 4 * half
            for half in (0, 1)
            for row in (
                2 * (lane % 4),
                2 * (lane % 4) + 1,
                2 * (lane % 4 + 4),
                2 * (lane % 4 + 4) + 1,
            )
        ]
        for offset in range(4):
            flat_perm.extend(entry + offset for entry in lane_entries)

    perm = numpy.array(flat_perm)

    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))

    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    return torch.from_numpy(perm)
|
421 |
+
|
422 |
+
|
423 |
+
def marlin_permute_scales_24(
    s: torch.Tensor, size_k: int, size_n: int, group_size: int
) -> torch.Tensor:
    """Permute quantization scales *s* into marlin 2:4 layout, choosing the
    grouped or per-channel permutation based on *group_size*."""
    grouped_perm, channel_perm = get_scale_perms_24()
    use_grouped = group_size < size_k and group_size != -1
    active_perm = grouped_perm if use_grouped else channel_perm
    s = s.reshape((-1, len(active_perm)))[:, active_perm]
    return s.reshape((-1, size_n)).contiguous()
|
435 |
+
|
436 |
+
|
437 |
+
def marlin_24_quantize(
    w: torch.Tensor,
    quant_type: ScalarType,
    group_size: int,
):
    """Inject 2:4 sparsity into *w* (size_k, size_n), GPTQ-quantize it, and
    repack for the marlin 2:4 kernel.

    Returns [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s], all moved to
    w.device: the dequantized sparse reference weights, the compressed
    marlin-packed weights, the sparsity metadata, and the permuted scales.
    """
    size_k, size_n = w.shape

    # Normalize group_size (-1 means a single group spanning all of k)
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Inject 2:4 sparsity
    w_24, mask_24 = inject_24(w, size_k, size_n)

    # Quantize
    w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights(
        w_24, quant_type, group_size, act_order=False
    )

    # Compress quantized weight
    q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, quant_type)
    size_k_comp = size_k // 2

    # Reformat to marlin
    weight_perm = get_weight_perm_24(quant_type.size_bits)
    marlin_24_q_w_comp = marlin_weights(
        q_w_24_comp, size_k_comp, size_n, quant_type.size_bits, weight_perm
    )
    marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size)

    # Create result
    res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s]
    for i in range(len(res_list)):
        res_list[i] = res_list[i].to(w.device)

    return res_list
|
build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
|
3 |
+
import numpy
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from .marlin_utils_test import marlin_permute_weights
|
7 |
+
from .quant_utils import get_pack_factor, qqq_quantize_weights
|
8 |
+
|
9 |
+
|
10 |
+
def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size):
    """Permute *q_w* into marlin tile order and pack ``pack_factor`` values
    into each int32 for the QQQ kernel. In the per-channel case
    (group_size == size_k) each value is masked to its low nibble first."""
    # Permute
    q_w = marlin_permute_weights(q_w, size_k, size_n, perm)

    # Pack
    pack_factor = get_pack_factor(num_bits)
    device = q_w.device

    w_np = q_w.cpu().numpy().astype(numpy.uint32)

    packed = numpy.zeros((w_np.shape[0], w_np.shape[1] // pack_factor),
                         dtype=numpy.uint32)
    if group_size == size_k:
        for shift_idx in range(pack_factor):
            packed |= (w_np[:, shift_idx::pack_factor] & 0xF) << num_bits * shift_idx
    else:
        for shift_idx in range(pack_factor):
            packed |= w_np[:, shift_idx::pack_factor] << num_bits * shift_idx

    return torch.from_numpy(packed.astype(numpy.int32)).to(device)
|
32 |
+
|
33 |
+
|
34 |
+
def get_qqq_scale_perms():
    """Return (scale_perm, scale_perm_single) for QQQ: a permutation of
    0..63 for grouped scales and a permutation of 0..31 for per-channel
    scales."""
    scale_perm: List[int] = [row + 8 * col for row in range(8) for col in range(8)]
    scale_perm_single: List[int] = [
        2 * pair + off for pair in range(4) for off in (0, 1, 8, 9, 16, 17, 24, 25)
    ]
    return scale_perm, scale_perm_single
|
43 |
+
|
44 |
+
|
45 |
+
# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. # noqa: E501
|
46 |
+
def get_qqq_weight_perm(num_bits: int, quant_type: str):
    """Build the QQQ weight permutation (a permutation of 0..1023).

    QQQ employs different interleavings for per-group and per-channel
    weight quantization (NOTE(HandH1998) in the original source).

    Args:
        num_bits: quantized width; only 4 is supported.
        quant_type: "per-channel" or "per-group".

    Raises:
        AssertionError: if quant_type is not a supported value.
        ValueError: if num_bits is not 4.
    """
    perm_list: List[int] = []
    for i in range(32):
        perm1: List[int] = []
        col = i // 4
        for block in [0, 1]:
            for row in [
                4 * (i % 4),
                4 * (i % 4) + 1,
                4 * (i % 4) + 2,
                4 * (i % 4) + 3,
            ]:
                perm1.append(16 * row + col + 8 * block)
        for j in range(4):
            perm_list.extend([p + 256 * j for p in perm1])

    perm = numpy.array(perm_list)

    assert quant_type in ["per-channel",
                          "per-group"], "not supported quantization type"
    if num_bits == 4:
        if quant_type == "per-channel":
            interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3])
        else:
            interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    else:
        # ValueError (not bare Exception) for consistency with
        # get_weight_perm_24 in marlin_utils_test_24.py.
        raise ValueError("num_bits must be 4, got {}".format(num_bits))

    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    perm = torch.from_numpy(perm)
    return perm
|
77 |
+
|
78 |
+
|
79 |
+
def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size):
    """Permute QQQ scales into marlin layout.

    For per-group quantization (group_size < size_k) both the grouped and
    per-channel scale tensors are permuted; for per-channel quantization
    only s_channel is. Returns (s_group, s_channel).
    """
    scale_perm, scale_perm_single = get_qqq_scale_perms()
    if group_size < size_k and group_size != -1:
        s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm]
        s_channel = s_channel.reshape(
            (-1, len(scale_perm_single)))[:, scale_perm_single]
        s_group = s_group.reshape((-1, size_n)).contiguous()
    else:
        s_channel = s_channel.reshape(
            (-1, len(scale_perm_single)))[:, scale_perm_single]
    s_channel = s_channel.reshape((-1, size_n)).contiguous()

    return s_group, s_channel
|
92 |
+
|
93 |
+
|
94 |
+
def marlin_qqq_quantize(
    w: torch.Tensor,
    num_bits: int,
    group_size: int,
):
    """QQQ-quantize *w* (size_k, size_n) and repack for the marlin_qqq kernel.

    Returns [w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel],
    all moved to w.device.
    """
    size_k, size_n = w.shape

    # Normalize group_size (-1 / size_k means per-channel quantization)
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k
    quant_type = "per-channel" if group_size == size_k else "per-group"

    # Quantize
    w_ref, q_w, s_group, s_channel = qqq_quantize_weights(
        w, num_bits, group_size)

    # Reformat to marlin_qqq
    weight_perm = get_qqq_weight_perm(num_bits, quant_type)
    marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits,
                                        weight_perm, group_size)
    marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales(
        s_group, s_channel, size_k, size_n, group_size)

    # Create result
    res_list = [
        w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel
    ]
    for i in range(len(res_list)):
        res_list[i] = res_list[i].to(w.device)

    return res_list
|
build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/quant_utils.py
ADDED
@@ -0,0 +1,470 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""This file is used for /tests and /benchmarks"""
|
2 |
+
|
3 |
+
from typing import List, Optional
|
4 |
+
|
5 |
+
import numpy
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from quantization.scalar_type import ScalarType, scalar_types
|
9 |
+
|
10 |
+
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
|
11 |
+
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
|
12 |
+
|
13 |
+
MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]
|
14 |
+
|
15 |
+
# Note: this is a hack. We should update each model to register the
|
16 |
+
# stacked params and get it from there instead in a future PR.
|
17 |
+
# fused_name: List[shard_name]
|
18 |
+
FUSED_LAYER_NAME_MAPPING = {
|
19 |
+
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
20 |
+
"gate_up_proj": ["gate_proj", "up_proj"],
|
21 |
+
}
|
22 |
+
|
23 |
+
|
24 |
+
def pack_quantized_values_into_int32(
    w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
):
    """Pack wtype.size_bits-wide values along `packed_dim` into int32 words.

    Values are packed LSB-first: element i of each group of
    32 // wtype.size_bits values lands in bit slot i of the output word.
    """
    ndim = len(w_q.shape)
    # Rotate `packed_dim` to the last axis so packing is a stride over dim -1.
    fwd = tuple(d for d in range(ndim) if d != packed_dim) + (packed_dim,)
    bwd = tuple(fwd.index(d) for d in range(ndim))
    moved = w_q.permute(fwd)

    vals_per_word = 32 // wtype.size_bits
    bit_mask = (1 << wtype.size_bits) - 1

    packed_shape = list(moved.shape)
    assert packed_shape[-1] % vals_per_word == 0
    packed_shape[-1] //= vals_per_word

    packed = torch.zeros(packed_shape, dtype=torch.int32, device=w_q.device)
    for slot in range(vals_per_word):
        packed |= (moved[..., slot::vals_per_word] & bit_mask) << (
            wtype.size_bits * slot
        )

    # Undo the axis rotation.
    return packed.permute(bwd)
|
44 |
+
|
45 |
+
|
46 |
+
def unpack_quantized_values_into_int32(
    w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
):
    """Inverse of pack_quantized_values_into_int32.

    Expands each int32 word along `packed_dim` into 32 // wtype.size_bits
    values, LSB-first.
    """
    ndim = len(w_q.shape)
    # Rotate `packed_dim` to the last axis so unpacking is a stride over dim -1.
    fwd = tuple(d for d in range(ndim) if d != packed_dim) + (packed_dim,)
    bwd = tuple(fwd.index(d) for d in range(ndim))
    packed = w_q.permute(fwd)

    vals_per_word = 32 // wtype.size_bits
    bit_mask = (1 << wtype.size_bits) - 1

    out_shape = list(packed.shape)
    out_shape[-1] *= vals_per_word

    out = torch.zeros(out_shape, dtype=torch.int32, device=w_q.device)
    for slot in range(vals_per_word):
        out[..., slot::vals_per_word] = (packed >> wtype.size_bits * slot) & bit_mask

    # Undo the axis rotation.
    return out.permute(bwd)
|
65 |
+
|
66 |
+
|
67 |
+
def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
    """Return True if the layer at `prefix` should skip quantization.

    Args:
        prefix: dotted layer path, e.g. "model.layers.0.self_attn.q_proj".
        ignored_layers: layer prefixes excluded from quantization.

    For fused layers (see FUSED_LAYER_NAME_MAPPING, e.g. qkv_proj) every
    constituent shard must agree on being skipped or not.

    Raises:
        ValueError: if only some shards of a fused layer are ignored.
    """
    # prefix: model.layers.0.self_attn.q_proj -> proj_name: q_proj
    proj_name = prefix.split(".")[-1]
    if proj_name not in FUSED_LAYER_NAME_MAPPING:
        return prefix in ignored_layers

    shard_prefixes = [
        prefix.replace(proj_name, shard_proj_name)
        for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name]
    ]
    shard_skipped = [p in ignored_layers for p in shard_prefixes]

    # A fused layer must be uniformly quantized or uniformly skipped.
    # (Fix: original message read "...fused layers to have the same
    # precision" — missing the verb.)
    if any(shard_skipped) and not all(shard_skipped):
        raise ValueError(
            f"Detected some but not all shards of {prefix} "
            "are quantized. All shards of fused layers "
            "must have the same precision."
        )
    return shard_skipped[0]
|
94 |
+
|
95 |
+
|
96 |
+
def get_pack_factor(num_bits):
    """Return how many `num_bits`-wide values fit into one 32-bit word."""
    factor, rem = divmod(32, num_bits)
    assert rem == 0, f"Unsupported num_bits = {num_bits}"
    return factor
|
99 |
+
|
100 |
+
|
101 |
+
def permute_rows(
    q_w: torch.Tensor,
    w_ref: torch.Tensor,
    group_size: int,
    test_perm: Optional[torch.Tensor] = None,
):
    """Shuffle the K (row) dimension of q_w/w_ref to emulate GPTQ act_order.

    Returns (w_ref, q_w, g_idx, row_perm), all moved back to q_w's original
    device. A fixed permutation may be supplied via `test_perm` for
    reproducible tests.
    """
    assert q_w.shape == w_ref.shape
    device = q_w.device
    num_rows = q_w.shape[0]

    # Group index of each row before shuffling.
    g_idx = (torch.arange(num_rows) // group_size).to(torch.int32)

    # Random row permutation over K (or the caller-provided one).
    row_perm = torch.randperm(num_rows) if test_perm is None else test_perm

    g_idx = g_idx[row_perm].contiguous()
    shuffled_q = q_w[row_perm, :].contiguous()
    shuffled_ref = w_ref[row_perm, :].contiguous()

    return (
        shuffled_ref.to(device=device),
        shuffled_q.to(device=device),
        g_idx.to(device=device),
        row_perm.to(device=device),
    )
|
129 |
+
|
130 |
+
|
131 |
+
def quantize_weights(
    w: torch.Tensor,
    quant_type: ScalarType,
    group_size: Optional[int],
    zero_points: bool = False,
    ref_zero_points_after_scales: bool = False,
):
    """Quantize a (size_k, size_n) float weight to integers of `quant_type`.

    group_size: rows per scale group along K; -1 means channelwise
    (one group of size_k); None means no scaling at all (w_s returned
    as None).  zero_points enables asymmetric quantization.
    ref_zero_points_after_scales matches kernels (e.g. Machete) that
    apply zero-points after scales when building the dequantized
    reference.

    Returns (w_ref, w_q, w_s or None, maybe_w_zp), all on w's device.
    """
    assert (
        quant_type.is_integer()
    ), "Floating point quantization may work but has not been tested"
    assert not zero_points or group_size is not None, (
        "to have group zero points, group_size must be provided "
        "(-1 group_size is channelwise)"
    )

    orig_device = w.device
    orig_type = w.dtype
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"

    # -1 is shorthand for channelwise: one group covering all of K.
    if group_size == -1:
        group_size = size_k

    # Reshape to [groupsize, -1] so each column is one quantization group.
    if group_size is not None and group_size < size_k:
        w = w.reshape((-1, group_size, size_n))
        w = w.permute(1, 0, 2)
        w = w.reshape((group_size, -1))

    # Compute scale for each group
    max_val = torch.max(w, 0, keepdim=True).values
    min_val = torch.min(w, 0, keepdim=True).values

    max_q_val = quant_type.max()
    min_q_val = quant_type.min()

    w_s = torch.Tensor([1.0]).to(w.device)  # unscaled case
    maybe_w_zp = None
    if group_size is not None:
        if zero_points:
            # Asymmetric: scale spans the full [min, max] range and a
            # per-group integer zero-point is derived from min_val.
            assert not quant_type.is_signed() and quant_type.max() > 0
            w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
            maybe_w_zp = (
                torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
            )
        else:
            # If the bias is such that there are no possible negative/positive
            # values, set the max value to inf to avoid divide by 0
            w_s = torch.max(
                abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
                abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
            )

    # Quantize
    w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
    w_q = torch.clamp(w_q, min_q_val, max_q_val)

    # Compute ref (dequantized)
    # For some kernels (namely Machete) the zero-points are applied after the
    # scales are applied, for this case computing the reference in similar way
    # allows us to use tighter error tolerances in our unit tests.
    if ref_zero_points_after_scales and maybe_w_zp is not None:
        w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
    else:
        w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s

    # Shift into the biased (e.g. uint4b8) representation if required.
    if quant_type.has_bias():
        w_q += quant_type.bias

    # Restore original shapes
    if group_size is not None and group_size < size_k:

        def reshape_w(w):
            # Inverse of the [groupsize, -1] reshape above.
            w = w.reshape((group_size, -1, size_n))
            w = w.permute(1, 0, 2)
            w = w.reshape((size_k, size_n)).contiguous()
            return w

        w_q = reshape_w(w_q)
        w_ref = reshape_w(w_ref)
        w_s = w_s.reshape((-1, size_n)).contiguous()

    if maybe_w_zp is not None:
        maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
        maybe_w_zp = maybe_w_zp.to(device=orig_device)

    return (
        w_ref.to(device=orig_device),
        w_q.to(device=orig_device),
        w_s if group_size is not None else None,
        maybe_w_zp,
    )
|
224 |
+
|
225 |
+
|
226 |
+
def gptq_quantize_weights(
    w: torch.Tensor,
    quant_type: ScalarType,
    group_size: int,
    act_order: bool,
    test_perm: Optional[torch.Tensor] = None,
):
    """GPTQ-style quantization: quantize `w`, optionally shuffling K (act_order).

    Returns (w_ref, w_q, w_s, g_idx, rand_perm); g_idx/rand_perm are empty
    tensors when act_order is False.
    """
    size_k = w.shape[0]

    assert w.is_floating_point(), "w must be float"
    assert (
        quant_type in SUPPORTED_GPTQ_QUANT_TYPES
    ), f"Unsupported gptq type = {quant_type}"
    assert group_size in SUPPORTED_GROUP_SIZES + [
        size_k
    ], f"Unsupported groupsize = {group_size}"

    w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size)

    # Without act_order, g_idx / perm stay empty.
    g_idx = torch.empty(0, dtype=torch.int, device=w.device)
    rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
    if act_order:
        assert (
            group_size < size_k
        ), "For act_order, groupsize = {} must be less than size_k = {}".format(
            group_size, size_k
        )
        w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm)

    return w_ref, w_q, w_s, g_idx, rand_perm
|
258 |
+
|
259 |
+
|
260 |
+
# QQQ employs different quant schemes for per-group and
|
261 |
+
# per-channel quantization.
|
262 |
+
def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int):
    """Quantize a (size_k, size_n) float weight with the QQQ scheme.

    Per-group path (group_size < size_k): unsigned num_bits quantization per
    group, then an int8 per-channel scale is fitted to the dequantized
    result and the group scales are fused with it.  Per-channel path:
    signed symmetric quantization with an unpacking-offset applied to the
    channel scale.

    Returns (w_ref, q_w, s_group, s_channel) on w's device; s_group is an
    empty half tensor in the per-channel case.
    """
    orig_device = w.device
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"
    assert (
        num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS
    ), f"Unsupported num_bits = {num_bits}"
    assert group_size in SUPPORTED_GROUP_SIZES + [
        size_k
    ], f"Unsupported groupsize = {group_size}"

    # -1 is shorthand for channelwise (one group spanning all of K).
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    if group_size < size_k:
        # Reshape to [groupsize, -1]
        w = w.reshape((-1, group_size, size_n))
        w = w.permute(1, 0, 2)
        w = w.reshape((group_size, -1))

        max_q_val = 2**num_bits - 1
        half_q_val = (max_q_val + 1) // 2

        # Compute scale for each group
        s_group = torch.max(torch.abs(w), 0, keepdim=True)[0]
        s_group *= 2 / max_q_val  # 2 => symmetric

        # Quantize
        q_w = torch.round(w / s_group).int()
        q_w += half_q_val
        q_w = torch.clamp(q_w, 0, max_q_val)
        # Compute ref (dequantized)
        w_ref = (q_w - half_q_val).half() * s_group

        # Restore original shapes
        def reshape_w(w):
            # Inverse of the [groupsize, -1] reshape above.
            w = w.reshape((group_size, -1, size_n))
            w = w.permute(1, 0, 2)
            w = w.reshape((size_k, size_n)).contiguous()
            return w

        q_w = reshape_w(q_w)
        w_ref = reshape_w(w_ref)

        # Compute int8 quantization scale for each channel
        s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0]
        s_channel /= 127.0
        t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8)
        w_ref = t_int8.half() * s_channel
        s_channel = s_channel.reshape(1, -1).to(dtype=torch.float)

        # Fuse scales
        s_group = (s_group.reshape(-1, size_n).contiguous() / s_channel).to(
            dtype=torch.half
        )
    else:
        max_q_val = 2 ** (num_bits - 1) - 1

        # Compute scale for each channel
        s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0]
        s_channel /= max_q_val

        # Quantize
        q_w = torch.round(w / s_channel).int()
        q_w = torch.clamp(q_w, -max_q_val, max_q_val)
        # Compute ref (dequantized)
        w_ref = q_w.half() * s_channel

        s_group = torch.tensor([], dtype=torch.half)
        # div 2 ** (8 - self.bits)) to offset right shift in unpacking
        s_channel /= 2 ** (8 - num_bits)
        s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float)

    return (
        w_ref.to(device=orig_device),
        q_w.to(device=orig_device),
        s_group.to(device=orig_device),
        s_channel.to(device=orig_device),
    )
|
343 |
+
|
344 |
+
|
345 |
+
def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
    """Reorder rows of q_w so g_idx becomes non-decreasing (act_order layout).

    Returns (sorted q_w, sorted g_idx, sort_indices) on q_w's device.
    """
    device = q_w.device

    sort_indices = torch.argsort(g_idx).to(dtype=torch.int32)  # Sort based on g_idx
    sorted_g_idx = g_idx[sort_indices].contiguous()
    sorted_q_w = q_w[sort_indices, :].contiguous()

    return (
        sorted_q_w.to(device=device),
        sorted_g_idx.to(device=device),
        sort_indices.to(device=device),
    )
|
358 |
+
|
359 |
+
|
360 |
+
def pack_rows(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Pack q_w along the K (row) axis into int32 words, LSB-first."""
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_k % pack_factor == 0

    device = q_w.device

    # Pack in numpy (uint32 avoids sign-extension issues), then convert back.
    unpacked = q_w.cpu().numpy().astype(numpy.uint32)
    packed = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)
    for slot in range(pack_factor):
        packed |= unpacked[slot::pack_factor, :] << num_bits * slot

    return torch.from_numpy(packed.astype(numpy.int32)).to(device)
|
382 |
+
|
383 |
+
|
384 |
+
def pack_cols(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Pack q_w along the N (column) axis into int32 words, LSB-first."""
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0

    device = q_w.device

    # Pack in numpy (uint32 avoids sign-extension issues), then convert back.
    unpacked = q_w.cpu().numpy().astype(numpy.uint32)
    packed = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
    for slot in range(pack_factor):
        packed |= unpacked[:, slot::pack_factor] << num_bits * slot

    result = torch.from_numpy(packed.astype(numpy.int32)).to(device)
    return result.contiguous()
|
408 |
+
|
409 |
+
|
410 |
+
def unpack_cols(
    packed_q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Inverse of pack_cols: expand int32 words along N back into values."""
    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0
    assert packed_q_w.shape == (
        size_k,
        size_n // pack_factor,
    ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
        packed_q_w.shape, size_k, size_n, pack_factor
    )

    device = packed_q_w.device

    # Unpack in numpy (uint32 gives a logical right shift), then convert back.
    words = packed_q_w.cpu().numpy().astype(numpy.uint32)
    out = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
    mask = (1 << num_bits) - 1
    for slot in range(pack_factor):
        out[:, slot::pack_factor] = words & mask
        words >>= num_bits

    result = torch.from_numpy(out.astype(numpy.int32)).to(device)
    return result.contiguous()
|
440 |
+
|
441 |
+
|
442 |
+
def gptq_pack(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Pack a (size_k, size_n) quantized weight in the GPTQ layout.

    GPTQ packs along the K (row) dimension; this is a thin alias for
    pack_rows.
    """
    return pack_rows(q_w, num_bits, size_k, size_n)
|
449 |
+
|
450 |
+
|
451 |
+
def awq_pack(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Pack a (size_k, size_n) quantized weight in the AWQ layout.

    Columns are first interleaved (to match the AWQ dequantize kernel's
    expected order) and then packed along N into int32 words.

    Raises:
        ValueError: if num_bits is not 4 or 8.
    """
    assert q_w.shape == (size_k, size_n)

    # Interleave column dim (for the dequantize code) and pack it to int32
    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        # Fix: was a bare `Exception`; a bad argument is a ValueError.
        raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))

    q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel()
    q_w = q_w.reshape((-1, size_n)).contiguous()

    return pack_cols(q_w, num_bits, size_k, size_n)
|
build/torch24-cxx11-cu121-x86_64-linux/quantization/__init__.py
CHANGED
@@ -1,150 +1,30 @@
|
|
1 |
-
from
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
#if current_platform.is_rocm():
|
33 |
-
# triton_scaled_mm_module = importlib.import_module(
|
34 |
-
# "vllm.model_executor.layers.quantization.compressed_tensors."
|
35 |
-
# "triton_scaled_mm")
|
36 |
-
# triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
|
37 |
-
# return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
38 |
-
|
39 |
-
out = torch.empty((m, n), dtype=out_dtype, device=a.device)
|
40 |
-
|
41 |
-
ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
|
42 |
-
|
43 |
-
return out
|
44 |
-
|
45 |
-
# fp8
|
46 |
-
def scaled_fp8_quant(
|
47 |
-
input: torch.Tensor,
|
48 |
-
scale: Optional[torch.Tensor] = None,
|
49 |
-
num_token_padding: Optional[int] = None,
|
50 |
-
scale_ub: Optional[torch.Tensor] = None,
|
51 |
-
use_per_token_if_dynamic: bool = False,
|
52 |
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
53 |
-
"""
|
54 |
-
Quantize input tensor to FP8 and return quantized tensor and scale.
|
55 |
-
|
56 |
-
This function supports both static and dynamic quantization: If you
|
57 |
-
provide the scale, it will use static scaling and if you omit it,
|
58 |
-
the scale will be determined dynamically. The function also allows
|
59 |
-
optional padding of the output tensors for downstream kernels that
|
60 |
-
will benefit from padding.
|
61 |
-
|
62 |
-
Args:
|
63 |
-
input: The input tensor to be quantized to FP8
|
64 |
-
scale: Optional scaling factor for the FP8 quantization
|
65 |
-
scale_ub: Optional upper bound for scaling factor in dynamic
|
66 |
-
per token case
|
67 |
-
num_token_padding: If specified, pad the first dimension
|
68 |
-
of the output to at least this value.
|
69 |
-
use_per_token_if_dynamic: Whether to do per_tensor or per_token
|
70 |
-
in the dynamic quantization case.
|
71 |
-
|
72 |
-
Returns:
|
73 |
-
Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
|
74 |
-
scaling factor.
|
75 |
-
"""
|
76 |
-
# This code assumes batch_dim and num_tokens are flattened
|
77 |
-
assert (input.ndim == 2)
|
78 |
-
shape: Union[Tuple[int, int], torch.Size] = input.shape
|
79 |
-
# For rocm, the output fp8 dtype is torch.float_e3m3fnuz
|
80 |
-
#out_dtype: torch.dtype = torch.float8_e4m3fnuz \
|
81 |
-
# if current_platform.is_rocm() else torch.float8_e4m3fn
|
82 |
-
out_dtype = torch.float8_e4m3fn
|
83 |
-
if num_token_padding:
|
84 |
-
shape = (max(num_token_padding, input.shape[0]), shape[1])
|
85 |
-
output = torch.empty(shape, device=input.device, dtype=out_dtype)
|
86 |
-
|
87 |
-
if scale is None:
|
88 |
-
if use_per_token_if_dynamic:
|
89 |
-
scale = torch.empty((shape[0], 1),
|
90 |
-
device=input.device,
|
91 |
-
dtype=torch.float32)
|
92 |
-
ops.dynamic_per_token_scaled_fp8_quant(
|
93 |
-
output, input, scale, scale_ub)
|
94 |
-
else:
|
95 |
-
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
|
96 |
-
ops.dynamic_scaled_fp8_quant(output, input, scale)
|
97 |
-
else:
|
98 |
-
# num_token_padding not implemented for this case
|
99 |
-
assert (scale.numel() == 1 or num_token_padding is None)
|
100 |
-
ops.static_scaled_fp8_quant(output, input, scale)
|
101 |
-
|
102 |
-
return output, scale
|
103 |
-
|
104 |
-
# int8
|
105 |
-
def scaled_int8_quant(
|
106 |
-
input: torch.Tensor,
|
107 |
-
scale: Optional[torch.Tensor] = None,
|
108 |
-
azp: Optional[torch.Tensor] = None,
|
109 |
-
symmetric: bool = True
|
110 |
-
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
111 |
-
"""
|
112 |
-
Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.
|
113 |
-
|
114 |
-
Args:
|
115 |
-
input: The input tensor to be quantized to int8.
|
116 |
-
scale: Optional scaling factor for the int8 quantization.
|
117 |
-
When not provided, we invoke dynamic-per-token quantization.
|
118 |
-
azp: Optional zero-point for the int8 quantization.
|
119 |
-
Must be provided for asymmetric quantization if `scale` is provided.
|
120 |
-
symmetric: Whether to use symmetric quantization (scale only, azp ignored).
|
121 |
-
|
122 |
-
Returns:
|
123 |
-
Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
|
124 |
-
"""
|
125 |
-
output = torch.empty_like(input, dtype=torch.int8)
|
126 |
-
if scale is not None:
|
127 |
-
# static-per-tensor quantization.
|
128 |
-
assert symmetric == (
|
129 |
-
azp is
|
130 |
-
None), "azp must only be provided for asymmetric quantization."
|
131 |
-
ops.static_scaled_int8_quant(output, input, scale, azp)
|
132 |
-
return output, scale, azp
|
133 |
-
|
134 |
-
# dynamic-per-token quantization.
|
135 |
-
input_scales = torch.empty((input.numel() // input.shape[-1], 1),
|
136 |
-
device=input.device,
|
137 |
-
dtype=torch.float32)
|
138 |
-
input_azp = None if symmetric else torch.empty_like(input_scales,
|
139 |
-
dtype=torch.int32)
|
140 |
-
ops.dynamic_scaled_int8_quant(output, input, input_scales,
|
141 |
-
input_azp)
|
142 |
-
return output, input_scales, input_azp
|
143 |
-
|
144 |
-
# fp8 marlin
|
145 |
-
def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
|
146 |
-
b_scales: torch.Tensor, workspace: torch.Tensor,
|
147 |
-
num_bits: int, size_m: int, size_n: int,
|
148 |
-
size_k: int) -> torch.Tensor:
|
149 |
-
return ops.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace,
|
150 |
-
num_bits, size_m, size_n, size_k)
|
|
|
1 |
+
from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant
|
2 |
+
from .cutlass import (
|
3 |
+
cutlass_scaled_mm_supports_fp8,
|
4 |
+
cutlass_scaled_mm,
|
5 |
+
cutlass_scaled_mm_azp,
|
6 |
+
)
|
7 |
+
from .marlin import (
|
8 |
+
awq_marlin_repack,
|
9 |
+
fp8_marlin_gemm,
|
10 |
+
gptq_marlin_gemm,
|
11 |
+
gptq_marlin_repack,
|
12 |
+
gptq_marlin_24_gemm,
|
13 |
+
marlin_qqq_gemm,
|
14 |
+
marlin_gemm,
|
15 |
+
)
|
16 |
+
|
17 |
+
__all__ = [
|
18 |
+
"awq_marlin_repack",
|
19 |
+
"cutlass_scaled_mm",
|
20 |
+
"cutlass_scaled_mm_azp",
|
21 |
+
"cutlass_scaled_mm_supports_fp8",
|
22 |
+
"fp8_marlin_gemm",
|
23 |
+
"gptq_marlin_24_gemm",
|
24 |
+
"gptq_marlin_gemm",
|
25 |
+
"gptq_marlin_repack",
|
26 |
+
"marlin_gemm",
|
27 |
+
"marlin_qqq_gemm",
|
28 |
+
"scaled_fp8_quant",
|
29 |
+
"scaled_int8_quant",
|
30 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch24-cxx11-cu121-x86_64-linux/quantization/_ops.py
CHANGED
@@ -1,3 +1,9 @@
|
|
1 |
import torch
|
2 |
from . import _quantization_0_0_1
|
3 |
ops = torch.ops._quantization_0_0_1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import torch
|
2 |
from . import _quantization_0_0_1
|
3 |
ops = torch.ops._quantization_0_0_1
|
4 |
+
|
5 |
+
def add_op_namespace_prefix(op_name: str):
    """Return `op_name` prefixed with this extension's torch op namespace."""
    return "_quantization_0_0_1::" + op_name
|
build/torch24-cxx11-cu121-x86_64-linux/quantization/_quantization_0_0_1.abi3.so
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c0abd1636906a69e8cf7d85fdfc7b99b6b4f4cc3d753431ad3d49ba674238c27
|
3 |
+
size 105267936
|
build/torch24-cxx11-cu121-x86_64-linux/quantization/compressed_tensors.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Tuple
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
try:
|
6 |
+
from ._ops import ops
|
7 |
+
except ImportError as e:
|
8 |
+
# Fallback for local development.
|
9 |
+
try:
|
10 |
+
import _quantization
|
11 |
+
|
12 |
+
ops = torch.ops._quantization
|
13 |
+
except ImportError:
|
14 |
+
raise e
|
15 |
+
|
16 |
+
|
17 |
+
# fp8
|
18 |
+
def scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    num_token_padding: Optional[int] = None,
    scale_ub: Optional[torch.Tensor] = None,
    use_per_token_if_dynamic: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP8 and return quantized tensor and scale.

    This function supports both static and dynamic quantization: If you
    provide the scale, it will use static scaling and if you omit it,
    the scale will be determined dynamically. The function also allows
    optional padding of the output tensors for downstream kernels that
    will benefit from padding.

    Args:
        input: The input tensor to be quantized to FP8
        scale: Optional scaling factor for the FP8 quantization
        scale_ub: Optional upper bound for scaling factor in dynamic
            per token case
        num_token_padding: If specified, pad the first dimension
            of the output to at least this value.
        use_per_token_if_dynamic: Whether to do per_tensor or per_token
            in the dynamic quantization case.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
        scaling factor.
    """
    # This code assumes batch_dim and num_tokens are flattened
    assert input.ndim == 2
    # Fix: the original annotated this local as Union[Tuple[int, int],
    # torch.Size], but `Union` was never imported (F821); the annotation
    # added nothing, so it is dropped.
    shape = input.shape
    # For rocm, the output fp8 dtype is torch.float_e3m3fnuz
    # out_dtype: torch.dtype = torch.float8_e4m3fnuz \
    #     if current_platform.is_rocm() else torch.float8_e4m3fn
    out_dtype = torch.float8_e4m3fn
    if num_token_padding:
        shape = (max(num_token_padding, input.shape[0]), shape[1])
    output = torch.empty(shape, device=input.device, dtype=out_dtype)

    if scale is None:
        # Dynamic scaling: the kernel fills `scale` in place.
        if use_per_token_if_dynamic:
            scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
            ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub)
        else:
            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
            ops.dynamic_scaled_fp8_quant(output, input, scale)
    else:
        # num_token_padding not implemented for this case
        assert scale.numel() == 1 or num_token_padding is None
        ops.static_scaled_fp8_quant(output, input, scale)

    return output, scale
|
72 |
+
|
73 |
+
|
74 |
+
# int8
|
75 |
+
def scaled_int8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    azp: Optional[torch.Tensor] = None,
    symmetric: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """
    Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.

    Args:
        input: The input tensor to be quantized to int8.
        scale: Optional scaling factor for the int8 quantization.
            When not provided, we invoke dynamic-per-token quantization.
        azp: Optional zero-point for the int8 quantization.
            Must be provided for asymmetric quantization if `scale` is provided.
        symmetric: Whether to use symmetric quantization (scale only, azp ignored).

    Returns:
        Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
    """
    output = torch.empty_like(input, dtype=torch.int8)

    if scale is not None:
        # Static per-tensor quantization path.
        assert symmetric == (
            azp is None
        ), "azp must only be provided for asymmetric quantization."
        ops.static_scaled_int8_quant(output, input, scale, azp)
        return output, scale, azp

    # Dynamic per-token quantization path: one scale (and optionally one
    # zero-point) per token row.
    num_tokens = input.numel() // input.shape[-1]
    token_scales = torch.empty((num_tokens, 1), device=input.device, dtype=torch.float32)
    token_azp = None if symmetric else torch.empty_like(token_scales, dtype=torch.int32)
    ops.dynamic_scaled_int8_quant(output, input, token_scales, token_azp)
    return output, token_scales, token_azp
|
build/torch24-cxx11-cu121-x86_64-linux/quantization/cutlass.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
try:
|
6 |
+
from ._ops import ops
|
7 |
+
except ImportError as e:
|
8 |
+
# Fallback for local development.
|
9 |
+
try:
|
10 |
+
import _quantization
|
11 |
+
|
12 |
+
ops = torch.ops._quantization
|
13 |
+
except ImportError:
|
14 |
+
raise e
|
15 |
+
|
16 |
+
|
17 |
+
def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
    # Thin passthrough to the C++ capability query: True when the CUTLASS
    # fp8 scaled-mm kernels support the given compute capability
    # (encoded as major*10 + minor, e.g. 89 or 90).
    return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
|
19 |
+
|
20 |
+
|
21 |
+
def cutlass_scaled_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Scaled matrix multiply via the CUTLASS custom op.

    Computes the product of `a` and `b` dequantized with `scale_a` /
    `scale_b`, optionally adding `bias`, and returns a new tensor of
    shape (a.shape[0], b.shape[1]) with dtype `out_dtype`.

    Preconditions (asserted):
      - both dims of `b` are multiples of 16
      - `out_dtype` is fp16 or bf16
      - `bias`, if given, has length b.shape[1] and dtype `out_dtype`
    """
    assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
    assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
    assert bias is None or bias.shape[0] == b.shape[1] and bias.dtype == out_dtype

    # NOTE: upstream vllm dispatches to a Triton implementation on ROCm at
    # this point; that path is not wired up in this build.
    num_rows = a.shape[0]
    num_cols = b.shape[1]

    out = torch.empty((num_rows, num_cols), dtype=out_dtype, device=a.device)
    # Custom op writes into `out` in place; keep arguments positional to
    # match the registered schema.
    ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
    return out
|
48 |
+
|
49 |
+
|
50 |
+
def cutlass_scaled_mm_azp(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    azp_adj: torch.Tensor,
    azp: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """CUTLASS scaled matmul with asymmetric-zero-point correction.

    :param azp_adj: zero-point adjustment term; always per-channel. In the
        per-tensor case this should already include the azp itself.
    :param azp: per-token zero points — set only in the per-token case.
    :param bias: optional length-N vector of `out_dtype`.
    :return: (a.shape[0], b.shape[1]) tensor of `out_dtype`.
    """
    assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
    assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
    assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype
    assert azp is None or azp.numel() == a.shape[0]

    result = torch.empty(
        (a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device
    )
    # Output buffer is filled in place by the custom op (positional args
    # must follow the registered schema).
    ops.cutlass_scaled_mm_azp(result, a, b, scale_a, scale_b, azp_adj, azp, bias)
    return result
|
build/torch24-cxx11-cu121-x86_64-linux/quantization/marlin.py
ADDED
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import TYPE_CHECKING
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
# neuron has torch version that doesn't even have impl_abstract
|
6 |
+
if TYPE_CHECKING:
|
7 |
+
def register_fake(fn):
|
8 |
+
return lambda name: fn
|
9 |
+
else:
|
10 |
+
try:
|
11 |
+
from torch.library import register_fake
|
12 |
+
except ImportError:
|
13 |
+
from torch.library import impl_abstract as register_fake
|
14 |
+
|
15 |
+
try:
|
16 |
+
from ._ops import ops, add_op_namespace_prefix
|
17 |
+
except ImportError as e:
|
18 |
+
# Fallback for local development.
|
19 |
+
try:
|
20 |
+
import _quantization
|
21 |
+
|
22 |
+
ops = torch.ops._quantization
|
23 |
+
|
24 |
+
def add_op_namespace_prefix(op_name: str):
|
25 |
+
return f"_quantization::{op_name}"
|
26 |
+
except ImportError:
|
27 |
+
raise e
|
28 |
+
|
29 |
+
|
30 |
+
from .scalar_type import ScalarType
|
31 |
+
|
32 |
+
|
33 |
+
# fp8 marlin
|
34 |
+
def fp8_marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    num_bits: int,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    # Marlin GEMM over an fp8-quantized, Marlin-packed weight matrix.
    # Pure passthrough to the registered custom op; argument order must
    # match the C++ schema exactly. Per the fake-op registration below,
    # the result is a (size_m, size_n) tensor with a's dtype.
    return ops.fp8_marlin_gemm(
        a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k
    )
|
47 |
+
|
48 |
+
|
49 |
+
# gptq_marlin
|
50 |
+
def gptq_marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    b_zeros: torch.Tensor,
    g_idx: torch.Tensor,
    perm: torch.Tensor,
    workspace: torch.Tensor,
    b_q_type: ScalarType,
    size_m: int,
    size_n: int,
    size_k: int,
    is_k_full: bool,
    has_zp: bool = False,
    use_fp32_reduce: bool = False,
    is_zp_float: bool = False,
) -> torch.Tensor:
    # GPTQ/AWQ-style Marlin GEMM: a @ dequant(b_q_weight) using b_scales
    # (and b_zeros when has_zp), with optional activation reordering via
    # g_idx/perm. `b_q_type` is serialized to its integer `.id` because
    # custom ops cannot accept Python classes. Keep arguments positional —
    # they must match the registered C++ schema exactly.
    return ops.gptq_marlin_gemm(
        a,
        b_q_weight,
        b_scales,
        b_zeros,
        g_idx,
        perm,
        workspace,
        b_q_type.id,
        size_m,
        size_n,
        size_k,
        is_k_full,
        has_zp,
        use_fp32_reduce,
        is_zp_float,
    )
|
84 |
+
|
85 |
+
|
86 |
+
# gptq_marlin
|
87 |
+
def gptq_marlin_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
) -> torch.Tensor:
    # Repack a GPTQ-format quantized weight (optionally permuted by `perm`
    # for act-order) into the Marlin kernel's tile layout. Passthrough to
    # the custom op.
    return ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits)
|
95 |
+
|
96 |
+
|
97 |
+
# gptq_marlin
|
98 |
+
def awq_marlin_repack(
    b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    # Repack an AWQ-format quantized weight into the Marlin tile layout
    # (no permutation argument — AWQ has no act-order). Passthrough to the
    # custom op.
    return ops.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)
|
102 |
+
|
103 |
+
|
104 |
+
# marlin
|
105 |
+
def marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    # Original (dense) Marlin GEMM. Per the fake-op registration below the
    # output is (size_m, size_n) fp16. Passthrough to the custom op.
    return ops.marlin_gemm(
        a, b_q_weight, b_scales, workspace, size_m, size_n, size_k
    )
|
117 |
+
|
118 |
+
|
119 |
+
# marlin_24
|
120 |
+
def gptq_marlin_24_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_meta: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    b_q_type: ScalarType,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    # Marlin GEMM for 2:4 structured-sparse GPTQ weights; `b_meta` carries
    # the sparsity metadata. `b_q_type` is serialized to its integer `.id`
    # for the custom-op boundary.
    return ops.gptq_marlin_24_gemm(
        a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, size_m, size_n, size_k
    )
|
134 |
+
|
135 |
+
|
136 |
+
# qqq ops
|
137 |
+
def marlin_qqq_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    s_tok: torch.Tensor,
    s_ch: torch.Tensor,
    s_group: torch.Tensor,
    workspace: torch.Tensor,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    # QQQ (W4A8) Marlin GEMM with per-token (s_tok), per-channel (s_ch) and
    # per-group (s_group) scales. Per the fake-op registration below the
    # output is (size_m, size_n) fp16. Passthrough to the custom op.
    return ops.marlin_qqq_gemm(
        a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k
    )
|
151 |
+
|
152 |
+
|
153 |
+
# Fake ops
|
154 |
+
|
155 |
+
if hasattr(ops, "gptq_marlin_24_gemm"):
    # Register "fake" (meta/abstract) implementations so torch.compile and
    # symbolic-shape tracing can infer output shapes without executing the
    # CUDA kernels. Each fake just allocates an output of the shape/dtype
    # the real kernel produces.
    # NOTE(review): the guard checks only for `gptq_marlin_24_gemm` but then
    # registers fakes for all five ops — presumably the ops are built as one
    # unit; verify against the extension build.

    @register_fake(add_op_namespace_prefix("fp8_marlin_gemm"))
    def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                              b_scales: torch.Tensor, workspace: torch.Tensor,
                              num_bits: int, size_m: torch.SymInt,
                              size_n: torch.SymInt,
                              size_k: torch.SymInt) -> torch.Tensor:
        # Output follows the activation dtype.
        return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)

    @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm"))
    def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                                  b_meta: torch.Tensor, b_scales: torch.Tensor,
                                  workspace: torch.Tensor,
                                  b_q_type: ScalarType, size_m: torch.SymInt,
                                  size_n: torch.SymInt,
                                  size_k: torch.SymInt) -> torch.Tensor:
        return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

    @register_fake(add_op_namespace_prefix("gptq_marlin_gemm"))
    def _gptq_marlin_gemm_fake(a: torch.Tensor,
                               b_q_weight: torch.Tensor,
                               b_scales: torch.Tensor,
                               b_zeros: torch.Tensor,
                               g_idx: torch.Tensor,
                               perm: torch.Tensor,
                               workspace: torch.Tensor,
                               b_q_type: ScalarType,
                               size_m: torch.SymInt,
                               size_n: torch.SymInt,
                               size_k: torch.SymInt,
                               is_k_full: bool,
                               has_zp: bool = False,
                               use_fp32_reduce: bool = False,
                               is_zp_float: bool = False) -> torch.Tensor:
        return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

    @register_fake(add_op_namespace_prefix("marlin_qqq_gemm"))
    def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                              s_tok: torch.Tensor, s_ch: torch.Tensor,
                              s_group: torch.Tensor, workspace: torch.Tensor,
                              size_m: torch.SymInt, size_n: torch.SymInt,
                              size_k: torch.SymInt) -> torch.Tensor:
        # QQQ kernel always produces fp16 output regardless of `a`'s dtype.
        return torch.empty((size_m, size_n),
                           dtype=torch.float16,
                           device=a.device)

    @register_fake(add_op_namespace_prefix("marlin_gemm"))
    def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                          b_scales: torch.Tensor, workspace: torch.Tensor,
                          size_m: torch.SymInt, size_n: torch.SymInt,
                          size_k: torch.SymInt) -> torch.Tensor:
        # Dense Marlin kernel also produces fp16 output.
        return torch.empty((size_m, size_n),
                           dtype=torch.float16,
                           device=a.device)
|
build/torch24-cxx11-cu121-x86_64-linux/quantization/scalar_type.py
ADDED
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
import struct
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from enum import Enum
|
5 |
+
from typing import Optional, Union
|
6 |
+
|
7 |
+
|
8 |
+
# Mirrors enum in `core/scalar_type.hpp`
|
9 |
+
# Mirrors enum in `core/scalar_type.hpp`
class NanRepr(Enum):
    NONE = 0  # nans are not supported
    IEEE_754 = 1  # nans are: Exp all 1s, mantissa not all 0s
    EXTD_RANGE_MAX_MIN = 2  # nans are: Exp all 1s, mantissa all 1s


# This ScalarType class is a parallel implementation of the C++ ScalarType
# class found in csrc/core/scalar_type.hpp. These two classes should be kept
# in sync until the inductor fully supports custom C++ classes.
@dataclass(frozen=True)
class ScalarType:
    """
    ScalarType can represent a wide range of floating point and integer
    types, in particular it can be used to represent sub-byte data types
    (something that torch.dtype currently does not support). It is also
    capable of representing types with a bias, i.e.:
      `stored_value = value + bias`,
    this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
    of 8). The implementation for this class can be found in
    csrc/core/scalar_type.hpp, these type signatures should be kept in sync
    with that file.
    """

    exponent: int
    """
    Number of bits in the exponent if this is a floating point type
    (zero if this an integer type)
    """

    mantissa: int
    """
    Number of bits in the mantissa if this is a floating point type,
    or the number bits representing an integer excluding the sign bit if
    this an integer type.
    """

    signed: bool
    "If the type is signed (i.e. has a sign bit)"

    bias: int
    """
    bias used to encode the values in this scalar type
    (value = stored_value - bias, default 0) for example if we store the
    type as an unsigned integer with a bias of 128 then the value 0 will be
    stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
    """

    _finite_values_only: bool = False
    """
    Private: if infs are supported, used `has_infs()` instead.
    """

    nan_repr: NanRepr = NanRepr.IEEE_754
    """
    How NaNs are represented in this scalar type (a NanRepr member).
    (not applicable for integer types)
    """

    def _floating_point_max_int(self) -> int:
        """Raw bit pattern (as int) of the max finite value, laid out as an
        IEEE-754 double."""
        assert (
            self.mantissa <= 52 and self.exponent <= 11
        ), f"Cannot represent max/min as a double for type {self.__str__()}"

        max_mantissa = (1 << self.mantissa) - 1
        if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
            # all-1s mantissa is reserved for NaN, so max loses one step
            max_mantissa = max_mantissa - 1

        max_exponent = (1 << self.exponent) - 2
        if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN
                or self.nan_repr == NanRepr.NONE):
            assert (
                self.exponent < 11
            ), f"Cannot represent max/min as a double for type {self.__str__()}"
            # all-1s exponent is usable for finite values in these reprs
            max_exponent = max_exponent + 1

        # adjust the exponent to match that of a double
        # for now we assume the exponent bias is the standard 2^(e-1) -1, (where
        # e is the exponent bits), there is some precedent for non-standard
        # biases, example `float8_e4m3b11fnuz` here:
        # https://github.com/jax-ml/ml_dtypes but to avoid premature over
        # complication we are just assuming the standard exponent bias until
        # there is a need to support non-standard biases
        exponent_bias = (1 << (self.exponent - 1)) - 1
        exponent_bias_double = (1 << 10) - 1  # double e = 11

        max_exponent_double = (max_exponent - exponent_bias +
                               exponent_bias_double)

        # shift the mantissa and exponent into the proper positions for an
        # IEEE double and bitwise-or them together.
        return (max_mantissa <<
                (52 - self.mantissa)) | (max_exponent_double << 52)

    def _floating_point_max(self) -> float:
        """Max finite value of this float type, as a Python float."""
        double_raw = self._floating_point_max_int()
        return struct.unpack('!d', struct.pack('!Q', double_raw))[0]

    def _raw_max(self) -> Union[int, float]:
        """Max stored value (before subtracting any bias)."""
        if self.is_floating_point():
            return self._floating_point_max()
        else:
            assert (self.size_bits < 64 or self.size_bits == 64
                    and self.is_signed()), "Cannot represent max as an int"
            return (1 << self.mantissa) - 1

    def _raw_min(self) -> Union[int, float]:
        """Min stored value (before subtracting any bias)."""
        if self.is_floating_point():
            assert self.is_signed(
            ), "We currently assume all floating point types are signed"
            sign_bit_double = 1 << 63

            max_raw = self._floating_point_max_int()
            min_raw = max_raw | sign_bit_double
            return struct.unpack('!d', struct.pack('!Q', min_raw))[0]
        else:
            assert (not self.is_signed() or
                    self.size_bits <= 64), "Cannot represent min as a int64_t"

            if self.is_signed():
                return -(1 << (self.size_bits - 1))
            else:
                return 0

    @functools.cached_property
    def id(self) -> int:
        """
        Convert the ScalarType to an int which can be passed to pytorch custom
        ops. This layout of the int must be kept in sync with the C++
        ScalarType's from_id method.
        """
        val = 0
        offset = 0

        def or_and_advance(member, bit_width):
            nonlocal val
            nonlocal offset
            bit_mask = (1 << bit_width) - 1
            val = val | (int(member) & bit_mask) << offset
            offset = offset + bit_width

        or_and_advance(self.exponent, 8)
        or_and_advance(self.mantissa, 8)
        or_and_advance(self.signed, 1)
        or_and_advance(self.bias, 32)
        or_and_advance(self._finite_values_only, 1)
        or_and_advance(self.nan_repr.value, 8)

        assert offset <= 64, \
            f"ScalarType fields too big {offset} to fit into an int64"

        return val

    @property
    def size_bits(self) -> int:
        """Total storage width in bits."""
        return self.exponent + self.mantissa + int(self.signed)

    def min(self) -> Union[int, float]:
        """
        Min representable value for this scalar type.
        (accounting for bias if there is one)
        """
        return self._raw_min() - self.bias

    def max(self) -> Union[int, float]:
        """
        Max representable value for this scalar type.
        (accounting for bias if there is one)
        """
        return self._raw_max() - self.bias

    def is_signed(self) -> bool:
        """
        If the type is signed (i.e. has a sign bit), same as `signed`
        added for consistency with:
        https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
        """
        return self.signed

    def is_floating_point(self) -> bool:
        "If the type is a floating point type"
        return self.exponent != 0

    def is_integer(self) -> bool:
        "If the type is an integer type"
        return self.exponent == 0

    def has_bias(self) -> bool:
        "If the type has a non-zero bias"
        return self.bias != 0

    def has_infs(self) -> bool:
        "If the type is floating point and supports infinity"
        return not self._finite_values_only

    def has_nans(self) -> bool:
        "If the type supports some NaN representation"
        # BUGFIX: `self.nan_repr` is a NanRepr member, so the original
        # comparison against `NanRepr.NONE.value` (an int) was always True.
        return self.nan_repr != NanRepr.NONE

    def is_ieee_754(self) -> bool:
        """
        If the type is a floating point type that follows IEEE 754
        conventions
        """
        # BUGFIX: same enum-vs-int mismatch as has_nans — comparing against
        # `NanRepr.IEEE_754.value` was always False, which also corrupted
        # __str__ for IEEE types (e.g. "float8_e5m2" rendered as
        # "float8_e5m2n").
        return self.nan_repr == NanRepr.IEEE_754 and \
            not self._finite_values_only

    def __str__(self) -> str:
        """
        naming generally follows: https://github.com/jax-ml/ml_dtypes
        for floating point types (leading f) the scheme is:
        `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
        flags:
          - no-flags: means it follows IEEE 754 conventions
          - f: means finite values only (no infinities)
          - n: means nans are supported (non-standard encoding)
        for integer types the scheme is:
        `[u]int<size_bits>[b<bias>]`
          - if bias is not present it means its zero
        """
        if self.is_floating_point():
            ret = "float" + str(self.size_bits) + "_e" + str(
                self.exponent) + "m" + str(self.mantissa)

            if not self.is_ieee_754():
                if self._finite_values_only:
                    ret = ret + "f"
                if self.nan_repr != NanRepr.NONE:
                    ret = ret + "n"

            return ret
        else:
            ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
            if self.has_bias():
                ret = ret + "b" + str(self.bias)
            return ret

    def __repr__(self) -> str:
        return "ScalarType." + self.__str__()

    # __len__ needs to be defined (and has to throw TypeError) for pytorch's
    # opcheck to work.
    def __len__(self) -> int:
        raise TypeError

    #
    # Convenience Constructors
    #

    @classmethod
    def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
        "Create a signed integer scalar type (size_bits includes sign-bit)."
        ret = cls(0, size_bits - 1, True, bias if bias else 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
        """Create a unsigned integer scalar type."""
        ret = cls(0, size_bits, False, bias if bias else 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
        """
        Create a standard floating point type
        (i.e. follows IEEE 754 conventions).
        """
        assert (mantissa > 0 and exponent > 0)
        ret = cls(exponent, mantissa, True, 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
               nan_repr: NanRepr) -> 'ScalarType':
        """
        Create a non-standard floating point type
        (i.e. does not follow IEEE 754 conventions).
        """
        assert (mantissa > 0 and exponent > 0)
        assert (nan_repr != NanRepr.IEEE_754), (
            "use `float_IEEE754` constructor for floating point types that "
            "follow IEEE 754 conventions")
        ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
        ret.id  # noqa B018: make sure the id is cached
        return ret
|
295 |
+
|
296 |
+
|
297 |
+
# naming generally follows: https://github.com/jax-ml/ml_dtypes
|
298 |
+
# for floating point types (leading f) the scheme is:
|
299 |
+
# `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
|
300 |
+
# flags:
|
301 |
+
# - no-flags: means it follows IEEE 754 conventions
|
302 |
+
# - f: means finite values only (no infinities)
|
303 |
+
# - n: means nans are supported (non-standard encoding)
|
304 |
+
# for integer types the scheme is:
|
305 |
+
# `[u]int<size_bits>[b<bias>]`
|
306 |
+
# - if bias is not present it means its zero
|
307 |
+
|
308 |
+
|
309 |
+
class scalar_types:
    # Namespace of pre-constructed ScalarType instances (never instantiated).
    # Naming generally follows https://github.com/jax-ml/ml_dtypes.
    int4 = ScalarType.int_(4, None)
    uint4 = ScalarType.uint(4, None)
    int8 = ScalarType.int_(8, None)
    uint8 = ScalarType.uint(8, None)
    float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
    float8_e5m2 = ScalarType.float_IEEE754(5, 2)
    float16_e8m7 = ScalarType.float_IEEE754(8, 7)
    float16_e5m10 = ScalarType.float_IEEE754(5, 10)

    # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
    float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)

    # "gptq" types: unsigned storage with a symmetric bias
    # (value = stored - bias), e.g. uint4b8 stores -8..7 as 0..15.
    uint2b2 = ScalarType.uint(2, 2)
    uint3b4 = ScalarType.uint(3, 4)
    uint4b8 = ScalarType.uint(4, 8)
    uint8b128 = ScalarType.uint(8, 128)

    # colloquial names
    bfloat16 = float16_e8m7
    float16 = float16_e5m10
|
build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/__init__.py
ADDED
File without changes
|
build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils.py
ADDED
@@ -0,0 +1,391 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional, Tuple
|
2 |
+
|
3 |
+
import numpy
|
4 |
+
import torch
|
5 |
+
|
6 |
+
import quantization as ops
|
7 |
+
from quantization.scalar_type import ScalarType, scalar_types
|
8 |
+
|
9 |
+
from .quant_utils import pack_cols, unpack_cols
|
10 |
+
|
11 |
+
# Tile size and threading geometry expected by the GPTQ Marlin kernels.
GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16

# Geometry for the 2:4 structured-sparse ("24") Marlin kernels.
GPTQ_MARLIN_24_TILE = 16
GPTQ_MARLIN_24_MIN_THREAD_N = 128
GPTQ_MARLIN_24_MIN_THREAD_K = 128
GPTQ_MARLIN_24_MAX_PARALLEL = 64

# Quant types / group sizes accepted by the sparse kernels
# (group_size == -1 means channelwise).
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]

# Geometry and supported configs for the QQQ (W4A8) Marlin kernels.
MARLIN_QQQ_TILE = 16
MARLIN_QQQ_MIN_THREAD_N = 64
MARLIN_QQQ_MIN_THREAD_K = 128
MARLIN_QQQ_MAX_PARALLEL = 16

MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]
MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128]
MARLIN_QQQ_SUPPORTED_SYM = [True]

MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]

# In case there is a performance issue with Marlin, the variable below can be
# changed to False, which allows Marlin to perform global reductions in fp16
# precision (instead of fp32), and therefore, save on some memory movements.
USE_FP32_REDUCE_DEFAULT = True
|
39 |
+
|
40 |
+
|
41 |
+
# For binary size and compile time, we don't support the same types for with and
|
42 |
+
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
|
43 |
+
# TODO: we may want to move this into the C++ so its closer to the actual impl
|
44 |
+
def query_marlin_supported_quant_types(
    has_zp: bool, device_capability: Optional[int] = None
):
    """Return the list of ScalarTypes the Marlin kernels accept.

    `device_capability` is major*10 + minor; when omitted it is queried
    from the current CUDA device. Pre-Ampere (< 80) devices support
    nothing.
    """
    if device_capability is None:
        major, minor = torch.cuda.get_device_capability()
        device_capability = major * 10 + minor

    if device_capability < 80:
        return []

    if has_zp:
        # AWQ style, unsigned + runtime zero-point
        return [scalar_types.uint4, scalar_types.uint8]

    # GPTQ style, unsigned + symmetric bias
    # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able
    # to add `scalar_types.float8_e4m3fn` here
    return [scalar_types.uint4b8, scalar_types.uint8b128]
|
62 |
+
|
63 |
+
|
64 |
+
def _check_marlin_supported(
    quant_type: ScalarType,
    group_size: Optional[int],
    has_zp: bool,
    device_capability: Optional[int] = None,
) -> Tuple[bool, Optional[str]]:
    """Check whether (quant_type, group_size, zero-point) is runnable on
    Marlin for the given (or current) device capability.

    Returns (True, None) on success, otherwise (False, reason).
    """
    if device_capability is None:
        major, minor = torch.cuda.get_device_capability()
        device_capability = major * 10 + minor

    supported_types = query_marlin_supported_quant_types(has_zp, device_capability)

    if quant_type not in supported_types:
        reason = (
            f"Marlin does not support weight_bits = {quant_type}. "
            f"Only types = {supported_types} "
            f"are supported (for group_size = {group_size}, "
            f"device_capability = {device_capability}, zp = {has_zp})."
        )
        return False, reason

    if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
        reason = (
            f"Marlin does not support group_size = {group_size}. "
            f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
            "are supported."
        )
        return False, reason

    return True, None
|
94 |
+
|
95 |
+
|
96 |
+
def check_marlin_supported(
    quant_type: ScalarType,
    group_size: int,
    has_zp: bool = False,
    device_capability: Optional[int] = None,
) -> bool:
    """Boolean form of `_check_marlin_supported` (drops the reason string)."""
    supported, _ = _check_marlin_supported(
        quant_type, group_size, has_zp, device_capability
    )
    return supported
|
104 |
+
|
105 |
+
|
106 |
+
def verify_marlin_supported(
    quant_type: ScalarType, group_size: int, has_zp: bool = False
) -> None:
    """Raise ValueError when the config is not supported by Marlin."""
    ok, err_msg = _check_marlin_supported(quant_type, group_size, has_zp)
    if ok:
        return
    assert err_msg is not None
    raise ValueError(err_msg)
|
113 |
+
|
114 |
+
|
115 |
+
def verify_marlin_supports_shape(
    output_size_per_partition: int,
    input_size_per_partition: int,
    input_size: int,
    group_size: int,
) -> None:
    """Validate that a weight shard's shape is compatible with the Marlin
    kernel's threading constraints.

    Raises ValueError (with a user-facing message) when:
      - output_size_per_partition is not a multiple of min_thread_n
      - input_size_per_partition is not a multiple of min_thread_k
      - the partition does not divide evenly into quantization groups
    """
    # Validate output_size_per_partition
    if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
        raise ValueError(
            # BUGFIX: was "by  min_thread_n" (double space) in the message.
            f"Weight output_size_per_partition = "
            f"{output_size_per_partition} is not divisible by "
            f"min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
            "Consider reducing tensor_parallel_size or running "
            "with --quantization gptq."
        )

    # Validate input_size_per_partition
    if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
        raise ValueError(
            f"Weight input_size_per_partition = "
            f"{input_size_per_partition} is not divisible "
            f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
            "Consider reducing tensor_parallel_size or running "
            "with --quantization gptq."
        )

    # group_size >= input_size means a single (channelwise) group — nothing
    # to check in that case.
    if group_size < input_size and input_size_per_partition % group_size != 0:
        raise ValueError(
            # BUGFIX: was "...{group_size}.Consider..." — missing space
            # between sentences in the rendered error message.
            f"Weight input_size_per_partition = {input_size_per_partition}"
            f" is not divisible by group_size = {group_size}. "
            "Consider reducing tensor_parallel_size or running "
            "with --quantization gptq."
        )
|
149 |
+
|
150 |
+
|
151 |
+
def check_marlin_supports_shape(
    output_size_per_partition: int,
    input_size_per_partition: int,
    input_size: int,
    group_size: int,
) -> Tuple[bool, Optional[str]]:
    """Non-raising wrapper around verify_marlin_supports_shape.

    Returns:
        (True, None) when the shape is supported, otherwise
        (False, error_message) describing why it is not.
    """
    try:
        verify_marlin_supports_shape(
            output_size_per_partition, input_size_per_partition, input_size, group_size
        )
    except ValueError as e:
        # Idiom fix: use str(e) rather than calling e.__str__() directly.
        return False, str(e)
    return True, None
|
164 |
+
|
165 |
+
|
166 |
+
def marlin_make_workspace(
    output_size_per_partition: int, device: torch.device
) -> torch.Tensor:
    """Allocate the zeroed int32 scratch buffer marlin kernels require.

    One slot per thread-N tile, multiplied by the kernel's max parallelism.
    """
    n_tiles = output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N
    workspace_elems = n_tiles * GPTQ_MARLIN_MAX_PARALLEL
    return torch.zeros(
        workspace_elems, dtype=torch.int, device=device, requires_grad=False
    )
|
176 |
+
|
177 |
+
|
178 |
+
def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
    """K is "full" except when act-order weights are split row-parallel."""
    return not (act_order and is_row_parallel)
|
180 |
+
|
181 |
+
|
182 |
+
def marlin_repeat_scales_on_all_ranks(
    act_order: bool, group_size: int, is_row_parallel: bool
) -> bool:
    """Whether every TP rank needs a full copy of the scales.

    Scales must be repeated for act-ordering, or for channelwise
    quantization (group_size == -1) combined with RowParallelLinear.
    """
    if act_order:
        return True
    return group_size == -1 and is_row_parallel
|
189 |
+
|
190 |
+
|
191 |
+
def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
    """Placeholder (0-element int) g_idx parameter for non-act-order layouts."""
    empty = torch.empty(0, dtype=torch.int, device=device)
    return torch.nn.Parameter(empty, requires_grad=False)
|
195 |
+
|
196 |
+
|
197 |
+
def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
    """Placeholder (0-element int) zero-point parameter for symmetric layouts."""
    empty = torch.empty(0, dtype=torch.int, device=device)
    return torch.nn.Parameter(empty, requires_grad=False)
|
201 |
+
|
202 |
+
|
203 |
+
def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sort group indices ascending; return (sorted_g_idx, sort_indices)."""
    order = torch.argsort(g_idx).to(torch.int)
    return g_idx[order], order
|
206 |
+
|
207 |
+
|
208 |
+
def get_scale_perms():
    """Return (scale_perm, scale_perm_single) index lists for marlin scales.

    scale_perm (64 entries) is used for grouped scales; scale_perm_single
    (32 entries) for channelwise / single-group scales.
    """
    scale_perm: List[int] = [i + 8 * j for i in range(8) for j in range(8)]
    scale_perm_single: List[int] = [
        2 * i + j for i in range(4) for j in [0, 1, 8, 9, 16, 17, 24, 25]
    ]
    return scale_perm, scale_perm_single
|
216 |
+
|
217 |
+
|
218 |
+
def marlin_permute_scales(
    s: torch.Tensor, size_k: int, size_n: int, group_size: int
) -> torch.Tensor:
    """Shuffle scale rows into the layout expected by the marlin kernel.

    Grouped scales use the full permutation; channelwise (or single-group)
    scales use the shorter "single" permutation.
    """
    scale_perm, scale_perm_single = get_scale_perms()
    grouped = group_size < size_k and group_size != -1
    perm = scale_perm if grouped else scale_perm_single
    permuted = s.reshape((-1, len(perm)))[:, perm]
    return permuted.reshape((-1, size_n)).contiguous()
|
230 |
+
|
231 |
+
|
232 |
+
def marlin_moe_permute_scales(
    s: torch.Tensor,
    size_k: int,
    size_n: int,
    group_size: int,
):
    """Apply marlin_permute_scales expert-by-expert to a stacked MoE tensor."""
    num_experts = s.shape[0]
    permuted = torch.empty(
        (num_experts, s.shape[1], s.shape[2]),
        device=s.device,
        dtype=s.dtype,
    )
    for expert_idx in range(num_experts):
        permuted[expert_idx] = marlin_permute_scales(
            s[expert_idx], size_k, size_n, group_size
        )
    return permuted
|
248 |
+
|
249 |
+
|
250 |
+
def marlin_zero_points(
    zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    """Permute and pack zero-points into the marlin layout.

    Zero-points get the same permutation as scales but never the "single"
    variant, since they are applied on every MMA.
    """
    if num_bits not in (4, 8):
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    scale_perm, _ = get_scale_perms()
    zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm]

    # Interleave the column dim for the dequantize code, then pack to int32.
    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    else:
        interleave = numpy.array([0, 2, 1, 3])

    zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
    zp = zp.reshape((-1, size_n)).contiguous()
    return pack_cols(zp, num_bits, size_k, size_n)
|
271 |
+
|
272 |
+
|
273 |
+
def awq_to_marlin_zero_points(
    q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    """Convert AWQ column-packed, interleaved zero-points to marlin format.

    Unpacks the columns, undoes the AWQ dequantizer interleave, then applies
    the marlin zero-point permutation and re-packs.
    """
    if num_bits == 4:
        awq_interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        awq_interleave = numpy.array([0, 2, 1, 3])
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)

    # argsort of a permutation yields its inverse; use it to undo interleave.
    undo_interleave = numpy.argsort(awq_interleave)
    q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
    q_zp = q_zp.reshape((-1, size_n)).contiguous()

    return marlin_zero_points(q_zp, size_k, size_n, num_bits)
|
295 |
+
|
296 |
+
|
297 |
+
def moe_awq_to_marlin_zero_points(
    q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
):
    """Apply awq_to_marlin_zero_points expert-by-expert to a stacked tensor."""
    num_experts = q_zp_packed.shape[0]
    converted = torch.empty(
        (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]),
        device=q_zp_packed.device,
        dtype=q_zp_packed.dtype,
    )
    for expert_idx in range(num_experts):
        converted[expert_idx] = awq_to_marlin_zero_points(
            q_zp_packed[expert_idx], size_k, size_n, num_bits
        )
    return converted
|
309 |
+
|
310 |
+
|
311 |
+
def apply_gptq_marlin_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    weight_zp: torch.Tensor,
    g_idx: torch.Tensor,
    g_idx_sort_indices: torch.Tensor,
    workspace: torch.Tensor,
    wtype: ScalarType,
    output_size_per_partition: int,
    input_size_per_partition: int,
    is_k_full: bool,
    bias: Optional[torch.Tensor] = None,
    use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
) -> torch.Tensor:
    """Run a GPTQ marlin GEMM (symmetric, no zero-points) plus optional bias."""
    # Collapse all leading dims into a single batch dim for the 2D kernel.
    x_2d = input.reshape(-1, input.shape[-1])
    result_shape = input.shape[:-1] + (output_size_per_partition,)

    gemm_out = ops.gptq_marlin_gemm(
        x_2d,
        weight,
        weight_scale,
        weight_zp,
        g_idx,
        g_idx_sort_indices,
        workspace,
        wtype,
        size_m=x_2d.shape[0],
        size_n=output_size_per_partition,
        size_k=input_size_per_partition,
        is_k_full=is_k_full,
        has_zp=False,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=False,
    )

    if bias is not None:
        gemm_out.add_(bias)  # in-place bias add

    return gemm_out.reshape(result_shape)
|
351 |
+
|
352 |
+
|
353 |
+
def apply_awq_marlin_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    weight_zp: torch.Tensor,
    g_idx: torch.Tensor,
    g_idx_sort_indices: torch.Tensor,
    workspace: torch.Tensor,
    quant_type: ScalarType,
    output_size_per_partition: int,
    input_size_per_partition: int,
    bias: Optional[torch.Tensor] = None,
    use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
) -> torch.Tensor:
    """Run an AWQ marlin GEMM (with zero-points, full K) plus optional bias."""
    # Collapse all leading dims into a single batch dim for the 2D kernel.
    x_2d = input.reshape(-1, input.shape[-1])
    result_shape = input.shape[:-1] + (output_size_per_partition,)

    gemm_out = ops.gptq_marlin_gemm(
        x_2d,
        weight,
        weight_scale,
        weight_zp,
        g_idx,
        g_idx_sort_indices,
        workspace,
        quant_type,
        size_m=x_2d.shape[0],
        size_n=output_size_per_partition,
        size_k=input_size_per_partition,
        is_k_full=True,
        has_zp=True,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=False,
    )

    if bias is not None:
        gemm_out.add_(bias)  # in-place bias add

    return gemm_out.reshape(result_shape)
|
build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_fp8.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
import quantization as ops
|
6 |
+
|
7 |
+
from .marlin_utils import marlin_make_workspace, marlin_permute_scales
|
8 |
+
|
9 |
+
|
10 |
+
def is_fp8_marlin_supported():
    """True when the current CUDA device is compute capability 8.0+."""
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor >= 80
|
14 |
+
|
15 |
+
|
16 |
+
def apply_fp8_marlin_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    workspace: torch.Tensor,
    size_n: int,
    size_k: int,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    """FP8 weight-only GEMM via the marlin kernel, plus optional bias.

    For GPUs that lack FP8 hardware support, the marlin kernel provides
    fast weight-only FP8 quantized matmul.
    """
    x_2d = input.reshape(-1, input.shape[-1])
    result_shape = input.shape[:-1] + (size_n,)

    gemm_out = ops.fp8_marlin_gemm(
        a=x_2d,
        b_q_weight=weight,
        b_scales=weight_scale,
        workspace=workspace,
        num_bits=8,
        size_m=x_2d.shape[0],
        size_n=size_n,
        size_k=size_k,
    )

    if bias is not None:
        gemm_out.add_(bias)  # in-place bias add

    return gemm_out.reshape(result_shape)
|
46 |
+
|
47 |
+
|
48 |
+
def prepare_fp8_layer_for_marlin(
    layer: torch.nn.Module, strategy: str = "tensor"
) -> None:
    """Convert an FP8 linear layer's weight/scales to marlin format in place.

    NOTE(review): `strategy` is currently unused in this body — confirm
    whether per-channel handling was intended by callers.
    """
    size_n = layer.output_size_per_partition
    size_k = layer.input_size_per_partition
    device = layer.weight.device

    # WORKSPACE
    layer.workspace = marlin_make_workspace(size_n, device)

    # WEIGHT: pack fp8 bytes into int32 words, then repack into marlin tiles.
    packed_weight = pack_fp8_to_int32(layer.weight)
    marlin_qweight = ops.gptq_marlin_repack(
        b_q_weight=packed_weight,
        perm=torch.empty(0, dtype=torch.int, device=device),
        size_k=size_k,
        size_n=size_n,
        num_bits=8,
    )
    layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)

    # WEIGHT SCALES: cast to the layer's original dtype and permute
    # channelwise (group_size=-1).
    scales = layer.weight_scale.to(layer.orig_dtype)
    marlin_scales = marlin_permute_scales(
        s=scales, size_k=size_k, size_n=size_n, group_size=-1
    )
    layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
|
77 |
+
|
78 |
+
|
79 |
+
def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
|
80 |
+
"""
|
81 |
+
Repack FP8 weights to gptq format (packed int32 elements)
|
82 |
+
"""
|
83 |
+
assert fp8_tensor.dtype == torch.float8_e4m3fn
|
84 |
+
assert fp8_tensor.shape[0] % 4 == 0
|
85 |
+
|
86 |
+
# Reshape to prepare for packing
|
87 |
+
reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
|
88 |
+
|
89 |
+
# Convert fp8 to uint8 (byte) representation
|
90 |
+
byte_tensor = reshaped.view(torch.uint8)
|
91 |
+
|
92 |
+
# Pack 4 uint8 values into one int32
|
93 |
+
packed = (
|
94 |
+
byte_tensor[:, 0].to(torch.int32)
|
95 |
+
| (byte_tensor[:, 1].to(torch.int32) << 8)
|
96 |
+
| (byte_tensor[:, 2].to(torch.int32) << 16)
|
97 |
+
| (byte_tensor[:, 3].to(torch.int32) << 24)
|
98 |
+
)
|
99 |
+
|
100 |
+
return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous()
|
build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Utility functions used for tests and benchmarks"""
|
2 |
+
|
3 |
+
from typing import List, Optional
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from quantization.scalar_type import ScalarType
|
9 |
+
|
10 |
+
from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, marlin_zero_points
|
11 |
+
from .quant_utils import (
|
12 |
+
get_pack_factor,
|
13 |
+
gptq_quantize_weights,
|
14 |
+
quantize_weights,
|
15 |
+
sort_weights,
|
16 |
+
)
|
17 |
+
|
18 |
+
|
19 |
+
class MarlinWorkspace:
    """Owns the zeroed int32 CUDA scratch tensor marlin kernels write to."""

    def __init__(self, out_features, min_thread_n, max_parallel):
        assert (
            out_features % min_thread_n == 0
        ), "out_features = {} is undivisible by min_thread_n = {}".format(
            out_features, min_thread_n
        )

        # One workspace slot per thread-N tile, times max kernel parallelism.
        num_tiles = out_features // min_thread_n
        self.scratch = torch.zeros(
            num_tiles * max_parallel, dtype=torch.int, device="cuda"
        )
|
31 |
+
|
32 |
+
|
33 |
+
def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
    """Rearrange a (size_k, size_n) weight matrix into marlin tiles.

    Returns a (size_k // tile, size_n * tile) tensor with `perm` applied
    within each row-group of perm.numel() elements.
    """
    assert q_w.shape == (size_k, size_n)
    assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
    # Bug fix: the original failure message reported "size_k" for a size_n
    # divisibility failure.
    assert size_n % tile == 0, f"size_n = {size_n}, tile = {tile}"

    # Permute weights to 16x64 marlin tiles
    q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
    q_w = q_w.permute((0, 2, 1, 3))
    q_w = q_w.reshape((size_k // tile, size_n * tile))

    q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)

    return q_w
|
46 |
+
|
47 |
+
|
48 |
+
def marlin_weights(q_w, size_k, size_n, num_bits, perm):
    """Permute then bit-pack quantized weights into the marlin int32 layout."""
    q_w = marlin_permute_weights(q_w, size_k, size_n, perm)

    pack_factor = get_pack_factor(num_bits)
    orig_device = q_w.device

    # Pack on CPU via numpy: pack_factor values per int32, little-end first.
    w_np = q_w.cpu().numpy().astype(np.uint32)
    packed = np.zeros((w_np.shape[0], w_np.shape[1] // pack_factor), dtype=np.uint32)
    for slot in range(pack_factor):
        packed |= w_np[:, slot::pack_factor] << num_bits * slot

    return torch.from_numpy(packed.astype(np.int32)).to(orig_device)
|
65 |
+
|
66 |
+
|
67 |
+
def get_weight_perm(num_bits: int):
    """Build the 1024-element weight permutation used before marlin packing."""
    perm_list: List[int] = []
    for tid in range(32):
        base: List[int] = []
        col = tid // 4
        for block in [0, 1]:
            for row in [
                2 * (tid % 4),
                2 * (tid % 4) + 1,
                2 * (tid % 4 + 4),
                2 * (tid % 4 + 4) + 1,
            ]:
                base.append(16 * row + col + 8 * block)
        for rep in range(4):
            perm_list.extend(p + 256 * rep for p in base)

    perm = np.array(perm_list)

    if num_bits == 4:
        interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = np.array([0, 2, 1, 3])
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    return torch.from_numpy(perm)
|
95 |
+
|
96 |
+
|
97 |
+
def marlin_quantize(
    w: torch.Tensor,
    quant_type: ScalarType,
    group_size: int,
    act_order: bool,
    test_perm: Optional[torch.Tensor] = None,
):
    """GPTQ-quantize `w` and convert the result to marlin layout (test helper).

    Returns [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm],
    all moved to w's device.
    """
    size_k, size_n = w.shape
    num_bits = quant_type.size_bits

    # group_size == -1 means a single group spanning all of K.
    group_size = size_k if group_size == -1 else group_size
    assert group_size <= size_k

    # Quantize (and apply act_order if provided)
    w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
        w, quant_type, group_size, act_order, test_perm
    )

    # With act_order, reorder weights/g_idx so group ids are increasing.
    sort_indices = torch.empty(0, dtype=torch.int, device=w.device)
    if act_order:
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)

    # Reformat to marlin layout.
    weight_perm = get_weight_perm(num_bits)
    marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)

    return [
        t.to(w.device)
        for t in (w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm)
    ]
|
134 |
+
|
135 |
+
|
136 |
+
def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int):
    """AWQ-style quantize `w` (with zero-points) and convert to marlin layout.

    Returns [w_ref, marlin_q_w, marlin_s, marlin_zp], all on w's device.
    """
    size_k, size_n = w.shape

    # group_size == -1 means a single group spanning all of K.
    group_size = size_k if group_size == -1 else group_size
    assert group_size <= size_k

    # Detect num groups
    assert size_k % group_size == 0
    num_groups = size_k // group_size

    # Quantize with zero-points.
    w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True)

    # Reformat to marlin layout.
    weight_perm = get_weight_perm(quant_type.size_bits)
    marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm)
    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
    marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits)

    return [t.to(w.device) for t in (w_ref, marlin_q_w, marlin_s, marlin_zp)]
|
build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test_24.py
ADDED
@@ -0,0 +1,473 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Utility functions used for tests and benchmarks"""
|
2 |
+
|
3 |
+
import random
|
4 |
+
from typing import List
|
5 |
+
|
6 |
+
import numpy
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from quantization.scalar_type import ScalarType
|
10 |
+
|
11 |
+
from .marlin_utils_test import marlin_weights
|
12 |
+
from .quant_utils import gptq_quantize_weights
|
13 |
+
|
14 |
+
|
15 |
+
# This is PyTorch implementation of main part of reorder_meta()
|
16 |
+
# function, from tools/util/include/cutlass/util/host_reorder.h file
|
17 |
+
# of CUTLASS source tree. Furthermore, CUTLASS template for sparse
|
18 |
+
# GEMM decides upon layout of this matrix, and at the moment for the
|
19 |
+
# sparse GEMM executed on tensor cores, this is layout described by
|
20 |
+
# ColumnMajorInterleaved<2> data structure, in
|
21 |
+
# include/cutlass/layout/matrix.h of CUTLASS source tree. The
|
22 |
+
# reordering of meta matrix into meta_reordered matrix calculated
|
23 |
+
# according to these segments of CUTLASS code is re-implemented here.
|
24 |
+
# Note that this calculation produces offsets for scattering metadata
|
25 |
+
# matrix elements into reordered metadata matrix elements (or,
|
26 |
+
# equivalently, for gathering reordered metadata matrix element back
|
27 |
+
# into metadata matrix elements).
|
28 |
+
def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device):
|
29 |
+
dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)
|
30 |
+
dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)
|
31 |
+
|
32 |
+
# Reorder the rows, then swizzle the 2x2 blocks.
|
33 |
+
group_x = 64
|
34 |
+
group_y = 32 if meta_dtype.itemsize == 2 else 16
|
35 |
+
|
36 |
+
dst_rows = (
|
37 |
+
dst_rows // group_x * group_x
|
38 |
+
+ (dst_rows % 2) * 2
|
39 |
+
+ (dst_rows % 8) // 4
|
40 |
+
+ ((dst_rows % group_y) % 4) // 2 * 32
|
41 |
+
+ ((dst_rows % group_x) // 8) * 4
|
42 |
+
)
|
43 |
+
|
44 |
+
topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8)
|
45 |
+
bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8)
|
46 |
+
dst_rows += topright - bottomleft
|
47 |
+
dst_cols -= topright - bottomleft
|
48 |
+
|
49 |
+
# Assumed that meta tensor is to be stored in CUTLASS
|
50 |
+
# InterleavedColumnMajor layout, and reverse engineered
|
51 |
+
# corresponding code to store values into this tensor.
|
52 |
+
interleave = 2
|
53 |
+
cols_maj = dst_cols // interleave
|
54 |
+
cols_min = dst_cols % interleave
|
55 |
+
return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1)
|
56 |
+
|
57 |
+
|
58 |
+
# This function converts dense matrix into sparse semi-structured
|
59 |
+
# representation, producing "compressed" matrix, in the layout used by
|
60 |
+
# CUTLASS backend, and corresponding metadata matrix.
|
61 |
+
def sparse_semi_structured_from_dense_cutlass(dense):
    """Compress a dense 2D matrix into CUTLASS 2:4 semi-structured form.

    Returns (sparse, meta_reordered): `sparse` is (m, k // 2) holding the
    kept values, and `meta_reordered` is the (m, meta_ncols) metadata matrix
    already reordered into the layout the CUTLASS sparse GEMM expects.
    Raises RuntimeError for unsupported dtypes or non-divisible shapes.
    """
    if dense.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor"  # noqa: E501
        )

    m, k = dense.shape
    device = dense.device

    # Metadata element width depends on the dense dtype (per CUTLASS).
    meta_dtype = torch.int8
    if dense.dtype == torch.int8:
        meta_dtype = torch.int32
    elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]:
        meta_dtype = torch.int16
    else:
        raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
    # Each metadata element stores one 4-bit code per quadruple of inputs.
    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
    if quadbits_per_meta_elem not in (4, 8):
        raise RuntimeError("Invalid number of elements per meta element calculated")

    if meta_dtype == torch.int32:
        if m % 16 != 0:
            raise RuntimeError(
                f"Number of rows of dense matrix {m} must be divisible by 16"
            )
    else:
        if m % 32 != 0:
            raise RuntimeError(
                f"Number of rows of dense matrix {m} must be divisible by 32"
            )
    if k % (4 * quadbits_per_meta_elem) != 0:
        raise RuntimeError(
            f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}"  # noqa: E501
        )

    # Build non-zero masks m0..m3, one per position within each sparsity
    # group (groups of 4 except for float, which uses 1:2 groups of 2).
    if dense.dtype != torch.float:
        ksparse = 4
        dense_4 = dense.view(-1, k // ksparse, ksparse)
        m0, m1, m2, m3 = (dense_4 != 0).unbind(-1)
    else:
        ksparse = 2
        dense_2 = dense.view(-1, k // ksparse, ksparse)
        m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1)
    meta_ncols = k // (ksparse * quadbits_per_meta_elem)

    # Encoding quadruples of True/False values as follows:
    # [True, True, False, False] -> 0b0100
    # [True, False, True, False] -> 0b1000
    # [False, True, True, False] -> 0b1001
    # [True, False, False, True ] -> 0b1100
    # [False, True, False, True ] -> 0b1101
    # [False, False, True, True ] -> 0b1110
    # Thus, lower two bits in the encoding are index of the True value
    # at the lowest index in the quadruple, and the higher two bits in
    # the encoding are index of the other True value in the quadruple.
    # In case there are less than two True values, than False value or
    # values at some index or indices are considered True for the
    # encoding. In case there are more than two True values, then the
    # excess True value(s) at some indices are considered False for
    # the encoding. The exact encodings used for these cases are as
    # follows:
    # [False, False, False, False] -> 0b1110
    # [False, False, False, True ] -> 0b1110
    # [False, False, True, False] -> 0b1110
    # [False, True, False, False] -> 0b1001
    # [False, True, True, True ] -> 0b1101
    # [True, False, False, False] -> 0b1000
    # [True, False, True, True ] -> 0b1100
    # [True, True, False, True ] -> 0b0100
    # [True, True, True, False] -> 0b0100
    # [True, True, True, True ] -> 0b0100
    # These particular encodings are chosen, with the help of Espresso
    # logic minimizer software, for the purpose of minimization of
    # corresponding Boolean functions, that translate non-zero flags
    # into encoding bits. Note also possible choices for the first
    # and last of these encodings were limited only to (0b0100,
    # 0b1110), in order to produce valid encodings for 1:2 sparsity
    # case.

    expr0 = m0 & m1
    expr1 = ~m0 & m1
    expr2 = ~m0 & ~m1
    bit0 = expr1
    bit1 = expr2
    bit2 = expr0 | expr2 | m3
    bit3 = expr1 | ~m1
    # idxs0/idxs1: positions (0-3) of the two kept values in each quadruple.
    idxs0 = bit0 | (bit1.to(torch.int64) << 1)
    idxs1 = bit2 | (bit3.to(torch.int64) << 1)

    # Gather the two kept values per group into the compressed matrix.
    if dense.dtype != torch.float:
        sparse0 = dense_4.gather(
            -1, idxs0.unsqueeze(-1)
        )  # type: ignore[possibly-undefined]
        sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
        sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
    else:
        sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(
            m, k // 2
        )  # type: ignore[possibly-undefined]

    # Pack the 4-bit codes into metadata elements (4 or 8 codes each).
    meta_4 = idxs0 | (idxs1 << 2)
    meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)

    if quadbits_per_meta_elem == 4:
        meta = (
            meta_n[:, :, 0]
            | (meta_n[:, :, 1] << 4)
            | (meta_n[:, :, 2] << 8)
            | (meta_n[:, :, 3] << 12)
        )
    elif quadbits_per_meta_elem == 8:
        meta = (
            meta_n[:, :, 0]
            | (meta_n[:, :, 1] << 4)
            | (meta_n[:, :, 2] << 8)
            | (meta_n[:, :, 3] << 12)
            | (meta_n[:, :, 4] << 16)
            | (meta_n[:, :, 5] << 20)
            | (meta_n[:, :, 6] << 24)
            | (meta_n[:, :, 7] << 28)
        )

    # Reorder meta tensor elements into CUTLASS's interleaved layout.
    meta_reordered = meta.new_empty(
        (m * meta_ncols,)
    )  # type: ignore[possibly-undefined]
    meta_offsets = _calculate_meta_reordering_scatter_offsets(
        m, meta_ncols, meta_dtype, device
    )
    meta_reordered.scatter_(0, meta_offsets, meta.view(-1))

    return (sparse, meta_reordered.view(m, meta_ncols))
|
193 |
+
|
194 |
+
|
195 |
+
# This function performs reverse of the function above - it
|
196 |
+
# reconstructs dense matrix from a pair of "compressed" matrix, given
|
197 |
+
# in the layout used by CUTLASS backend, and accompanying metadata
|
198 |
+
# matrix.
|
199 |
+
def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
    """Reconstruct a dense matrix from its CUTLASS 2:4 compressed form.

    Inverse of ``sparse_semi_structured_from_dense_cutlass``: ``sparse``
    holds the kept values (half the columns of the dense matrix) and
    ``meta_reordered`` holds the packed 2-bit column indices of those
    values, in the reordered layout used by the CUTLASS backend.

    :param sparse: 2-D tensor of shape (m, k) with the compressed values
    :param meta_reordered: 2-D int16/int32 metadata tensor on the same device
    :return: dense tensor of shape (m, 2 * k)
    :raises RuntimeError: on dimension, device, dtype or shape mismatches
    """
    if sparse.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor"  # noqa: E501
        )

    m, k = sparse.shape
    device = sparse.device

    if meta_reordered.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor"  # noqa: E501
        )
    if meta_reordered.device != device:
        raise RuntimeError(
            f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device"  # noqa: E501
        )

    meta_dtype = meta_reordered.dtype
    if meta_dtype not in (torch.int16, torch.int32):
        raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix")
    # Each metadata element stores one 2-bit index per "quadbit" nibble.
    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4

    # torch.float values are compressed in pairs, so the effective
    # sparsity block along K is 2 instead of 4.
    ksparse = 4 if sparse.dtype != torch.float else 2

    meta_nrows, meta_ncols = meta_reordered.shape
    if meta_nrows != m:
        raise RuntimeError(
            f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}"  # noqa: E501
        )
    if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k:
        raise RuntimeError(
            f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, "  # noqa: E501
            "expected according to the number of columns of meta matrix"
        )

    # Undo meta tensor elements reordering.
    meta_offsets = _calculate_meta_reordering_scatter_offsets(
        m, meta_ncols, meta_dtype, device
    )
    meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols)

    # Unpack sparse tensor back to original dense tensor, using
    # information provided by meta tensor.  Note that torch.float
    # datatype is handled pretty much the same as
    # torch.half/torch.bfloat16, as metadata for a pair of torch.float
    # value is encoded as if underlying 8 bytes contain four
    # torch.half/torch.bfloat16 values, where either first two or last
    # two are zeros.
    meta_2 = torch.empty(
        (m, meta_ncols, 2 * quadbits_per_meta_elem),
        dtype=meta_dtype,
        device=device,
    )
    # Each metadata word packs 2 * quadbits_per_meta_elem two-bit column
    # indices; extract them with a single uniform shift/mask loop
    # (replaces the hand-unrolled int16/int32 branches, which wrote the
    # exact same sequence of 8 resp. 16 assignments).
    for i in range(2 * quadbits_per_meta_elem):
        meta_2[:, :, i] = (meta >> (2 * i)) & 0b11

    # Turn the per-4-block 2-bit indices into flat element offsets into
    # the dense output.
    dense_offsets = meta_2.view(-1) + (
        torch.arange(0, 2 * m * k // ksparse, device=device) * 4
    ).view(-1, 1).repeat(1, 2).view(-1)

    dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device)
    if sparse.dtype != torch.float:
        dense.scatter_(0, dense_offsets, sparse.reshape(-1))
    else:
        # Scatter float values as pairs of half-words, matching the
        # half-based metadata encoding described above.
        dense.view(torch.half).scatter_(
            0, dense_offsets, sparse.view(torch.half).view(-1)
        )

    return dense.view(m, 2 * k)
|
294 |
+
|
295 |
+
|
296 |
+
def mask_creator(tensor):
    """Build a 2:4 sparsity mask for *tensor*.

    For every contiguous group of 4 weights, the 2 with the smallest
    absolute value are zeroed; the returned mask contains 1.0 for kept
    entries and 0.0 for pruned ones, with the same shape as *tensor*.
    """
    N, M = 2, 4  # keep N out of every M consecutive weights

    if tensor.numel() % M != 0:
        raise ValueError(
            f"Tensor of size {tensor.shape} can't be evenly divided into {M} groups"
        )

    # View the magnitudes as rows of M consecutive weights.
    grouped = tensor.detach().abs().reshape(-1, M)

    # Positions of the M - N smallest magnitudes in each group.
    prune_idx = torch.argsort(grouped, dim=1)[:, : M - N]

    keep = torch.ones(grouped.shape, device=grouped.device)
    return keep.scatter_(dim=1, index=prune_idx, value=0).reshape(tensor.shape)
|
326 |
+
|
327 |
+
|
328 |
+
def inject_24(w, size_k, size_n):
    # Zero out weights so that w satisfies 2:4 structured sparsity:
    # mask_creator picks the two smallest-magnitude weights out of every
    # group of four; the mask is built on w.t() so the groups run along K.
    assert w.shape == (size_k, size_n)

    # NOTE(review): the mask is moved to the default CUDA device
    # unconditionally — presumably w also lives on GPU; confirm against
    # callers before running this on CPU-only setups.
    mask = mask_creator(w.t()).t().cuda().bool()

    # Return the pruned weights and the boolean keep-mask.
    return (mask * w).contiguous(), mask.contiguous()
|
334 |
+
|
335 |
+
|
336 |
+
def check_24(w, num_rows_to_sample=50, _verbose=False):
    """Spot-check that *w* has 2:4 structured sparsity.

    Samples ``num_rows_to_sample`` rows (with replacement) of ``w.t()``
    and counts the blocks of 4 consecutive entries containing more than
    2 non-zeros, printing a summary.  Purely diagnostic — returns
    nothing.

    :param w: weight tensor; blocks are checked along its first (K) dim
    :param num_rows_to_sample: number of rows to sample
    :param _verbose: if True, also print the sampled row indices
    """
    BLOCK_SIZE = 4
    MAX_NON_ZEROS = 2

    w = w.t().contiguous()

    print("check_24: w.shape = {}".format(w.shape))

    num_rows, num_cols = w.shape
    sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample)
    if _verbose:
        print(f"Sampled row idxs = {sampled_row_idxs}")

    total_segments = 0
    non_24_segments = 0
    for i in sampled_row_idxs:
        # Bug fix: the previous bound (num_cols - BLOCK_SIZE) silently
        # skipped the last block of every row; "+ 1" includes it.
        for j in range(0, num_cols - BLOCK_SIZE + 1, BLOCK_SIZE):
            total_segments += 1
            block = w[i, j : j + BLOCK_SIZE]
            num_nonzero = torch.count_nonzero(block)
            if num_nonzero > MAX_NON_ZEROS:
                print("i = {} j = {} block = {}".format(i, j, block))
                non_24_segments += 1

    print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.")
|
361 |
+
|
362 |
+
|
363 |
+
def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType):
    """Compress a 2:4-sparse quantized weight matrix.

    :param q_24: quantized weights (biased representation) of shape
        (size_k, size_n); expected to be 2:4 sparse about the zero
        point (see the inject_24 + gptq_quantize_weights caller)
    :param wtype: quantization type; its ``bias`` is subtracted first
        so that pruned entries are exactly zero for the CUTLASS helper
    :return: (q_24_comp, meta) — compressed weights of shape
        (size_k // 2, size_n) and the CUTLASS sparsity metadata
    """
    assert q_24.shape == (size_k, size_n)

    # Remove bias to normalize over 0
    q_24_no_zp = q_24 - wtype.bias

    # Compress — the CUTLASS helper compresses along the last dim, so
    # work on the transposed layout and transpose back afterwards.
    q_24_no_zp = q_24_no_zp.t().contiguous()
    q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(q_24_no_zp)
    q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous()

    # Restore bias
    q_24_comp = q_24_no_zp_comp + wtype.bias

    # Resize meta to its actual shape (without moving any data)
    meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)

    return q_24_comp, meta
|
381 |
+
|
382 |
+
|
383 |
+
def get_scale_perms_24():
    """Return the (grouped, per-channel) scale permutations for marlin 2:4.

    ``scale_perm`` interleaves each run of 8 scales as 0,4,1,5,2,6,3,7;
    ``scale_perm_single`` is the identity over 64 entries.
    """
    interleaved = [0, 4, 1, 5, 2, 6, 3, 7]
    scale_perm: List[int] = [
        base * 8 + offset for base in range(8) for offset in interleaved
    ]
    scale_perm_single: List[int] = list(range(64))
    return scale_perm, scale_perm_single
|
391 |
+
|
392 |
+
|
393 |
+
def get_weight_perm_24(num_bits: int):
    """Build the marlin 2:4 weight permutation for 4- or 8-bit weights.

    :raises ValueError: if num_bits is neither 4 nor 8
    """
    perm_list: List[int] = []
    for tid in range(32):
        col = tid // 4
        col_pair = col // 2
        rows = [
            2 * (tid % 4),
            2 * (tid % 4) + 1,
            2 * (tid % 4 + 4),
            2 * (tid % 4 + 4) + 1,
        ]
        base: List[int] = [
            16 * row + col_pair * 256 + 8 * (col % 2) + 4 * block
            for block in (0, 1)
            for row in rows
        ]
        for j in range(4):
            perm_list.extend(p + j for p in base)

    perm = numpy.array(perm_list)

    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))

    # Interleave values so consecutive packed nibbles land in the order
    # the kernel expects.
    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    return torch.from_numpy(perm)
|
421 |
+
|
422 |
+
|
423 |
+
def marlin_permute_scales_24(
    s: torch.Tensor, size_k: int, size_n: int, group_size: int
) -> torch.Tensor:
    """Reorder quantization scales into the marlin 2:4 layout.

    Grouped scales (group_size < size_k) and per-channel scales use
    different permutations.
    """
    grouped_perm, channel_perm = get_scale_perms_24()
    use_grouped = group_size != -1 and group_size < size_k
    perm = grouped_perm if use_grouped else channel_perm
    permuted = s.reshape((-1, len(perm)))[:, perm]
    return permuted.reshape((-1, size_n)).contiguous()
|
435 |
+
|
436 |
+
|
437 |
+
def marlin_24_quantize(
    w: torch.Tensor,
    quant_type: ScalarType,
    group_size: int,
):
    """Quantize *w* with GPTQ, inject 2:4 sparsity, and pack for marlin.

    Returns ``[w_ref, q_w_comp, meta, s]``: the dequantized reference
    weights, the compressed/packed quantized weights, the sparsity
    metadata, and the permuted scales — all moved to ``w.device``.
    """
    size_k, size_n = w.shape

    # A group_size of -1 means one group spanning all of K.
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Apply 2:4 structured sparsity before quantizing.
    w_24, _mask = inject_24(w, size_k, size_n)

    # GPTQ-style quantization (no act_order).
    w_24_ref, q_w_24, s, _g_idx, _rand_perm = gptq_quantize_weights(
        w_24, quant_type, group_size, act_order=False
    )

    # Drop the zeroed half of the quantized weights.
    q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, quant_type)

    # Repack compressed weights and scales into the marlin layout; the
    # compressed K dimension is half the original.
    weight_perm = get_weight_perm_24(quant_type.size_bits)
    marlin_24_q_w_comp = marlin_weights(
        q_w_24_comp, size_k // 2, size_n, quant_type.size_bits, weight_perm
    )
    marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size)

    return [
        t.to(w.device)
        for t in (w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s)
    ]
|
build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
|
3 |
+
import numpy
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from .marlin_utils_test import marlin_permute_weights
|
7 |
+
from .quant_utils import get_pack_factor, qqq_quantize_weights
|
8 |
+
|
9 |
+
|
10 |
+
def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size):
    """Permute and bit-pack QQQ-quantized weights into int32 words."""
    # Permute into the marlin layout first.
    q_w = marlin_permute_weights(q_w, size_k, size_n, perm)

    pack_factor = get_pack_factor(num_bits)
    device = q_w.device

    unpacked = q_w.cpu().numpy().astype(numpy.uint32)
    packed = numpy.zeros(
        (unpacked.shape[0], unpacked.shape[1] // pack_factor), dtype=numpy.uint32
    )
    # In the per-channel case (group_size == size_k) only the low 4 bits
    # of each value are packed.
    needs_mask = group_size == size_k
    for i in range(pack_factor):
        lane = unpacked[:, i::pack_factor]
        if needs_mask:
            lane = lane & 0xF
        packed |= lane << num_bits * i

    return torch.from_numpy(packed.astype(numpy.int32)).to(device)
|
32 |
+
|
33 |
+
|
34 |
+
def get_qqq_scale_perms():
    """Return the (per-group, per-channel) scale permutations for marlin QQQ."""
    scale_perm: List[int] = [row + 8 * col for row in range(8) for col in range(8)]
    offsets = [0, 1, 8, 9, 16, 17, 24, 25]
    scale_perm_single: List[int] = [
        2 * pair + off for pair in range(4) for off in offsets
    ]
    return scale_perm, scale_perm_single
|
43 |
+
|
44 |
+
|
45 |
+
# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. # noqa: E501
|
46 |
+
def get_qqq_weight_perm(num_bits: int, quant_type: str):
    """Build the marlin QQQ weight permutation.

    ``quant_type`` selects between the "per-channel" and "per-group"
    interleaving patterns; only 4-bit weights are supported.
    """
    perm_list: List[int] = []
    for tid in range(32):
        col = tid // 4
        rows = [4 * (tid % 4) + r for r in range(4)]
        base: List[int] = [
            16 * row + col + 8 * block for block in (0, 1) for row in rows
        ]
        for j in range(4):
            perm_list.extend(p + 256 * j for p in base)

    perm = numpy.array(perm_list)

    assert quant_type in ["per-channel",
                          "per-group"], "not supported quantization type"
    if num_bits == 4:
        if quant_type == "per-channel":
            interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3])
        else:
            interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    else:
        raise Exception("num_bits must be 4, got {}".format(num_bits))

    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    return torch.from_numpy(perm)
|
77 |
+
|
78 |
+
|
79 |
+
def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size):
    """Reorder QQQ scale tensors into the marlin layout.

    NOTE(review): in the per-group branch ``s_channel`` is permuted but
    NOT reshaped back to (-1, size_n) — only the per-channel branch does
    that.  Presumably the kernel expects different layouts per mode;
    confirm before "fixing" the asymmetry.
    """
    scale_perm, scale_perm_single = get_qqq_scale_perms()
    if group_size < size_k and group_size != -1:
        # Per-group quantization: permute both group and channel scales.
        s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm]
        s_channel = s_channel.reshape(
            (-1, len(scale_perm_single)))[:, scale_perm_single]
        s_group = s_group.reshape((-1, size_n)).contiguous()
    else:
        # Per-channel quantization: s_group passes through unchanged
        # (it is an empty tensor in this case — see qqq_quantize_weights).
        s_channel = s_channel.reshape(
            (-1, len(scale_perm_single)))[:, scale_perm_single]
        s_channel = s_channel.reshape((-1, size_n)).contiguous()

    return s_group, s_channel
|
92 |
+
|
93 |
+
|
94 |
+
def marlin_qqq_quantize(
    w: torch.Tensor,
    num_bits: int,
    group_size: int,
):
    """Quantize *w* with the QQQ scheme and pack it for the marlin kernel.

    Returns ``[w_ref, q_w, s_group, s_channel]``, all on ``w.device``.
    """
    size_k, size_n = w.shape

    # group_size == -1 means a single group spanning all of K, which is
    # the per-channel scheme.
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k
    quant_type = "per-channel" if group_size == size_k else "per-group"

    # Quantize, then repack weights and scales into the marlin layout.
    w_ref, q_w, s_group, s_channel = qqq_quantize_weights(
        w, num_bits, group_size)
    weight_perm = get_qqq_weight_perm(num_bits, quant_type)
    packed_q_w = marlin_qqq_weights(
        q_w, size_k, size_n, num_bits, weight_perm, group_size)
    perm_s_group, perm_s_channel = marlin_qqq_permute_scales(
        s_group, s_channel, size_k, size_n, group_size)

    return [
        t.to(w.device)
        for t in (w_ref, packed_q_w, perm_s_group, perm_s_channel)
    ]
|
build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/quant_utils.py
ADDED
@@ -0,0 +1,470 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""This file is used for /tests and /benchmarks"""
|
2 |
+
|
3 |
+
from typing import List, Optional
|
4 |
+
|
5 |
+
import numpy
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from quantization.scalar_type import ScalarType, scalar_types
|
9 |
+
|
10 |
+
# Quantization types accepted by gptq_quantize_weights.
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
# Valid group sizes; -1 means one group spanning the whole K dimension.
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]

# Bit widths accepted by qqq_quantize_weights.
MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]

# Note: this is a hack. We should update each model to register the
# stacked params and get it from there instead in a future PR.
# fused_name: List[shard_name]
FUSED_LAYER_NAME_MAPPING = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}
|
22 |
+
|
23 |
+
|
24 |
+
def pack_quantized_values_into_int32(
    w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
):
    """Pack the values of *w_q* along *packed_dim* into int32 words.

    Every 32 // wtype.size_bits consecutive values along *packed_dim*
    are masked down to size_bits bits and OR-ed into one int32.
    """
    ndim = len(w_q.shape)
    # Rotate packed_dim to the last axis, remembering how to undo it.
    perm = (*[d for d in range(ndim) if d != packed_dim], packed_dim)
    inv_perm = tuple(perm.index(d) for d in range(ndim))
    moved = w_q.permute(perm)

    bits = wtype.size_bits
    pack_factor = 32 // bits
    value_mask = (1 << bits) - 1

    assert moved.shape[-1] % pack_factor == 0
    packed_shape = list(moved.shape)
    packed_shape[-1] //= pack_factor

    packed = torch.zeros(packed_shape, dtype=torch.int32, device=w_q.device)
    for i in range(pack_factor):
        packed |= (moved[..., i::pack_factor] & value_mask) << bits * i

    return packed.permute(inv_perm)
|
44 |
+
|
45 |
+
|
46 |
+
def unpack_quantized_values_into_int32(
    w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
):
    """Inverse of pack_quantized_values_into_int32.

    Expands each int32 word along *packed_dim* back into
    32 // wtype.size_bits individual values.
    """
    ndim = len(w_q.shape)
    # Rotate packed_dim to the last axis, remembering how to undo it.
    perm = (*[d for d in range(ndim) if d != packed_dim], packed_dim)
    inv_perm = tuple(perm.index(d) for d in range(ndim))
    moved = w_q.permute(perm)

    bits = wtype.size_bits
    pack_factor = 32 // bits
    value_mask = (1 << bits) - 1

    unpacked_shape = list(moved.shape)
    unpacked_shape[-1] *= pack_factor

    out = torch.zeros(unpacked_shape, dtype=torch.int32, device=w_q.device)
    for i in range(pack_factor):
        out[..., i::pack_factor] = (moved >> bits * i) & value_mask

    return out.permute(inv_perm)
|
65 |
+
|
66 |
+
|
67 |
+
def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
    """Return True if the layer named *prefix* should skip quantization.

    For projections that are fused at load time (see
    FUSED_LAYER_NAME_MAPPING), the decision is made from the unfused
    shard names in ``ignored_layers``; all shards must agree.

    :param prefix: full layer name, e.g. ``model.layers.0.self_attn.q_proj``
    :param ignored_layers: layer names excluded from quantization
    :raises ValueError: if only some shards of a fused layer are ignored
    """
    # prefix: model.layers.0.self_attn.q_proj
    # proj_name: q_proj
    proj_name = prefix.split(".")[-1]
    if proj_name in FUSED_LAYER_NAME_MAPPING:
        shard_prefixes = [
            prefix.replace(proj_name, shard_proj_name)
            for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name]
        ]

        is_skipped = None
        for shard_prefix in shard_prefixes:
            is_shard_skipped = shard_prefix in ignored_layers

            if is_skipped is None:
                is_skipped = is_shard_skipped
            elif is_shard_skipped != is_skipped:
                # Fix: message previously read "...fused layers to have
                # the same precision", which was ungrammatical.
                raise ValueError(
                    f"Detected some but not all shards of {prefix} "
                    "are quantized. All shards of fused layers "
                    "must have the same precision."
                )
    else:
        is_skipped = prefix in ignored_layers

    assert is_skipped is not None
    return is_skipped
|
94 |
+
|
95 |
+
|
96 |
+
def get_pack_factor(num_bits):
    """Return how many *num_bits*-wide values fit into one 32-bit word."""
    word_bits = 32
    assert word_bits % num_bits == 0, f"Unsupported num_bits = {num_bits}"
    return word_bits // num_bits
|
99 |
+
|
100 |
+
|
101 |
+
def permute_rows(
    q_w: torch.Tensor,
    w_ref: torch.Tensor,
    group_size: int,
    test_perm: Optional[torch.Tensor] = None,
):
    """Randomly permute the K rows of quantized/reference weights.

    Simulates GPTQ act_order: the rows of ``q_w`` and ``w_ref`` are
    shuffled and ``g_idx`` records which quantization group each
    (permuted) row belongs to.

    :param q_w: quantized weights of shape (size_k, size_n)
    :param w_ref: dequantized reference weights, same shape
    :param group_size: rows per quantization group
    :param test_perm: optional fixed permutation for deterministic tests
    :return: (w_ref, q_w, g_idx, rand_perm), all on the input device
    """
    assert q_w.shape == w_ref.shape

    orig_device = q_w.device
    k_size, _ = q_w.shape

    # Group index of each row, computed vectorized (previously an O(k)
    # Python loop); int32 floor-division keeps the original dtype.
    g_idx = torch.arange(k_size, dtype=torch.int32) // group_size

    # Simulate act_order by doing a random permutation on K
    rand_perm = test_perm if test_perm is not None else torch.randperm(k_size)

    g_idx = g_idx[rand_perm].contiguous()
    q_w = q_w[rand_perm, :].contiguous()
    w_ref = w_ref[rand_perm, :].contiguous()

    return (
        w_ref.to(device=orig_device),
        q_w.to(device=orig_device),
        g_idx.to(device=orig_device),
        rand_perm.to(device=orig_device),
    )
|
129 |
+
|
130 |
+
|
131 |
+
def quantize_weights(
    w: torch.Tensor,
    quant_type: ScalarType,
    group_size: Optional[int],
    zero_points: bool = False,
    ref_zero_points_after_scales: bool = False,
):
    """Quantize *w* along K in groups, returning quantized + reference data.

    :param w: float weights of shape (size_k, size_n)
    :param quant_type: integer target type (must satisfy is_integer())
    :param group_size: rows per scale group; -1 means one group over all
        of K; None means no scaling (scale of 1.0, returned w_s is None)
    :param zero_points: if True, compute per-group integer zero points
        (requires an unsigned quant_type)
    :param ref_zero_points_after_scales: if True, apply zero points after
        scales when building the dequantized reference (matches kernels
        such as Machete, per the comment below)
    :return: (w_ref, w_q, w_s or None, maybe_w_zp)
    """
    assert (
        quant_type.is_integer()
    ), "Floating point quantization may work but has not been tested"
    assert not zero_points or group_size is not None, (
        "to have group zero points, group_size must be provided "
        "(-1 group_size is channelwise)"
    )

    orig_device = w.device
    orig_type = w.dtype
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"

    if group_size == -1:
        group_size = size_k

    # Reshape to [groupsize, -1] so dim 0 spans exactly one group.
    if group_size is not None and group_size < size_k:
        w = w.reshape((-1, group_size, size_n))
        w = w.permute(1, 0, 2)
        w = w.reshape((group_size, -1))

    # Compute scale for each group
    max_val = torch.max(w, 0, keepdim=True).values
    min_val = torch.min(w, 0, keepdim=True).values

    max_q_val = quant_type.max()
    min_q_val = quant_type.min()

    w_s = torch.Tensor([1.0]).to(w.device)  # unscaled case
    maybe_w_zp = None
    if group_size is not None:
        if zero_points:
            assert not quant_type.is_signed() and quant_type.max() > 0
            # Asymmetric: scale covers the full min..max range of the group.
            w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
            maybe_w_zp = (
                torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
            )
        else:
            # If the bias is such that there are no possible negative/positive
            # values, set the max value to inf to avoid divide by 0
            w_s = torch.max(
                abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
                abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
            )

    # Quantize
    w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
    w_q = torch.clamp(w_q, min_q_val, max_q_val)

    # Compute ref (dequantized)
    # For some kernels (namely Machete) the zero-points are applied after the
    # scales are applied, for this case computing the reference in similar way
    # allows us to use tighter error tolerances in our unit tests.
    if ref_zero_points_after_scales and maybe_w_zp is not None:
        w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
    else:
        w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s

    if quant_type.has_bias():
        # Shift into the biased (unsigned) storage representation.
        w_q += quant_type.bias

    # Restore original shapes
    if group_size is not None and group_size < size_k:

        def reshape_w(w):
            w = w.reshape((group_size, -1, size_n))
            w = w.permute(1, 0, 2)
            w = w.reshape((size_k, size_n)).contiguous()
            return w

        w_q = reshape_w(w_q)
        w_ref = reshape_w(w_ref)
        w_s = w_s.reshape((-1, size_n)).contiguous()

    if maybe_w_zp is not None:
        maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
        maybe_w_zp = maybe_w_zp.to(device=orig_device)

    return (
        w_ref.to(device=orig_device),
        w_q.to(device=orig_device),
        w_s if group_size is not None else None,
        maybe_w_zp,
    )
|
224 |
+
|
225 |
+
|
226 |
+
def gptq_quantize_weights(
    w: torch.Tensor,
    quant_type: ScalarType,
    group_size: int,
    act_order: bool,
    test_perm: Optional[torch.Tensor] = None,
):
    """Quantize *w* GPTQ-style, optionally simulating act_order.

    :param w: float weights of shape (size_k, size_n)
    :param quant_type: one of SUPPORTED_GPTQ_QUANT_TYPES
    :param group_size: one of SUPPORTED_GROUP_SIZES (or size_k)
    :param act_order: if True, permute rows via permute_rows (requires
        group_size < size_k)
    :param test_perm: optional fixed permutation for deterministic tests
    :return: (w_ref, w_q, w_s, g_idx, rand_perm); g_idx/rand_perm are
        empty tensors unless act_order is True
    """
    size_k, _ = w.shape

    assert w.is_floating_point(), "w must be float"
    assert (
        quant_type in SUPPORTED_GPTQ_QUANT_TYPES
    ), f"Unsupported gptq type = {quant_type}"
    assert group_size in SUPPORTED_GROUP_SIZES + [
        size_k
    ], f"Unsupported groupsize = {group_size}"

    w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size)

    # Apply act_order (empty placeholders when act_order is off).
    g_idx = torch.empty(0, dtype=torch.int, device=w.device)
    rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
    if act_order:
        assert (
            group_size < size_k
        ), "For act_order, groupsize = {} must be less than size_k = {}".format(
            group_size, size_k
        )

        w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm)

    return w_ref, w_q, w_s, g_idx, rand_perm
|
258 |
+
|
259 |
+
|
260 |
+
# QQQ employs different quant schemes for per-group and
|
261 |
+
# per-channel quantization.
|
262 |
+
def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int):
    """Quantize *w* with the QQQ scheme.

    Per-group path (group_size < size_k): weights are quantized to
    unsigned num_bits values per group, the dequantized reference is
    then re-quantized to int8 per channel, and the group scales are
    fused with the channel scales.
    Per-channel path (group_size == size_k): symmetric signed
    quantization, with s_channel divided by 2**(8 - num_bits) to offset
    the right shift performed during unpacking.

    :return: (w_ref, q_w, s_group, s_channel); s_group is an empty half
        tensor in the per-channel case
    """
    orig_device = w.device
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"
    assert (
        num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS
    ), f"Unsupported num_bits = {num_bits}"
    assert group_size in SUPPORTED_GROUP_SIZES + [
        size_k
    ], f"Unsupported groupsize = {group_size}"

    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    if group_size < size_k:
        # Reshape to [groupsize, -1]
        w = w.reshape((-1, group_size, size_n))
        w = w.permute(1, 0, 2)
        w = w.reshape((group_size, -1))

        max_q_val = 2**num_bits - 1
        half_q_val = (max_q_val + 1) // 2

        # Compute scale for each group
        s_group = torch.max(torch.abs(w), 0, keepdim=True)[0]
        s_group *= 2 / max_q_val  # 2 => symmetric

        # Quantize
        q_w = torch.round(w / s_group).int()
        q_w += half_q_val
        q_w = torch.clamp(q_w, 0, max_q_val)
        # Compute ref (dequantized)
        w_ref = (q_w - half_q_val).half() * s_group

        # Restore original shapes
        def reshape_w(w):
            w = w.reshape((group_size, -1, size_n))
            w = w.permute(1, 0, 2)
            w = w.reshape((size_k, size_n)).contiguous()
            return w

        q_w = reshape_w(q_w)
        w_ref = reshape_w(w_ref)

        # Compute int8 quantization scale for each channel
        s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0]
        s_channel /= 127.0
        t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8)
        w_ref = t_int8.half() * s_channel
        s_channel = s_channel.reshape(1, -1).to(dtype=torch.float)

        # Fuse scales
        s_group = (s_group.reshape(-1, size_n).contiguous() / s_channel).to(
            dtype=torch.half
        )
    else:
        max_q_val = 2 ** (num_bits - 1) - 1

        # Compute scale for each channel
        s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0]
        s_channel /= max_q_val

        # Quantize
        q_w = torch.round(w / s_channel).int()
        q_w = torch.clamp(q_w, -max_q_val, max_q_val)
        # Compute ref (dequantized)
        w_ref = q_w.half() * s_channel

        s_group = torch.tensor([], dtype=torch.half)
        # div 2 ** (8 - self.bits)) to offset right shift in unpacking
        s_channel /= 2 ** (8 - num_bits)
        s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float)

    return (
        w_ref.to(device=orig_device),
        q_w.to(device=orig_device),
        s_group.to(device=orig_device),
        s_channel.to(device=orig_device),
    )
|
343 |
+
|
344 |
+
|
345 |
+
def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
    """Sort the rows of *q_w* (and *g_idx*) by ascending group index.

    :return: (q_w, g_idx, sort_indices), all on the input device
    """
    device = q_w.device

    # Row order that sorts the group indices ascending.
    order = torch.argsort(g_idx).to(dtype=torch.int32)

    return (
        q_w[order, :].contiguous().to(device=device),
        g_idx[order].contiguous().to(device=device),
        order.to(device=device),
    )
|
358 |
+
|
359 |
+
|
360 |
+
def pack_rows(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Pack *q_w* along rows (K): every 32 // num_bits consecutive rows
    are OR-ed into one int32 row."""
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_k % pack_factor == 0

    device = q_w.device
    raw = q_w.cpu().numpy().astype(numpy.uint32)

    packed = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)
    for i in range(pack_factor):
        packed |= raw[i::pack_factor, :] << num_bits * i

    return torch.from_numpy(packed.astype(numpy.int32)).to(device)
|
382 |
+
|
383 |
+
|
384 |
+
def pack_cols(
|
385 |
+
q_w: torch.Tensor,
|
386 |
+
num_bits: int,
|
387 |
+
size_k: int,
|
388 |
+
size_n: int,
|
389 |
+
):
|
390 |
+
assert q_w.shape == (size_k, size_n)
|
391 |
+
|
392 |
+
pack_factor = get_pack_factor(num_bits)
|
393 |
+
assert size_n % pack_factor == 0
|
394 |
+
|
395 |
+
orig_device = q_w.device
|
396 |
+
|
397 |
+
q_w = q_w.cpu().numpy().astype(numpy.uint32)
|
398 |
+
|
399 |
+
q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
|
400 |
+
|
401 |
+
for i in range(pack_factor):
|
402 |
+
q_res |= q_w[:, i::pack_factor] << num_bits * i
|
403 |
+
|
404 |
+
q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
|
405 |
+
q_res = q_res.contiguous()
|
406 |
+
|
407 |
+
return q_res
|
408 |
+
|
409 |
+
|
410 |
+
def unpack_cols(
|
411 |
+
packed_q_w: torch.Tensor,
|
412 |
+
num_bits: int,
|
413 |
+
size_k: int,
|
414 |
+
size_n: int,
|
415 |
+
):
|
416 |
+
pack_factor = get_pack_factor(num_bits)
|
417 |
+
assert size_n % pack_factor == 0
|
418 |
+
assert packed_q_w.shape == (
|
419 |
+
size_k,
|
420 |
+
size_n // pack_factor,
|
421 |
+
), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
|
422 |
+
packed_q_w.shape, size_k, size_n, pack_factor
|
423 |
+
)
|
424 |
+
|
425 |
+
orig_device = packed_q_w.device
|
426 |
+
|
427 |
+
packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
|
428 |
+
q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
|
429 |
+
|
430 |
+
mask = (1 << num_bits) - 1
|
431 |
+
for i in range(pack_factor):
|
432 |
+
vals = packed_q_w_cpu & mask
|
433 |
+
packed_q_w_cpu >>= num_bits
|
434 |
+
q_res[:, i::pack_factor] = vals
|
435 |
+
|
436 |
+
q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
|
437 |
+
q_res = q_res.contiguous()
|
438 |
+
|
439 |
+
return q_res
|
440 |
+
|
441 |
+
|
442 |
+
def gptq_pack(
|
443 |
+
q_w: torch.Tensor,
|
444 |
+
num_bits: int,
|
445 |
+
size_k: int,
|
446 |
+
size_n: int,
|
447 |
+
):
|
448 |
+
return pack_rows(q_w, num_bits, size_k, size_n)
|
449 |
+
|
450 |
+
|
451 |
+
def awq_pack(
|
452 |
+
q_w: torch.Tensor,
|
453 |
+
num_bits: int,
|
454 |
+
size_k: int,
|
455 |
+
size_n: int,
|
456 |
+
):
|
457 |
+
assert q_w.shape == (size_k, size_n)
|
458 |
+
|
459 |
+
# Interleave column dim (for the dequantize code) and pack it to int32
|
460 |
+
if num_bits == 4:
|
461 |
+
interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
|
462 |
+
elif num_bits == 8:
|
463 |
+
interleave = numpy.array([0, 2, 1, 3])
|
464 |
+
else:
|
465 |
+
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
|
466 |
+
|
467 |
+
q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel()
|
468 |
+
q_w = q_w.reshape((-1, size_n)).contiguous()
|
469 |
+
|
470 |
+
return pack_cols(q_w, num_bits, size_k, size_n)
|
build/torch24-cxx11-cu124-x86_64-linux/quantization/__init__.py
CHANGED
@@ -1,150 +1,30 @@
|
|
1 |
-
from
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
#if current_platform.is_rocm():
|
33 |
-
# triton_scaled_mm_module = importlib.import_module(
|
34 |
-
# "vllm.model_executor.layers.quantization.compressed_tensors."
|
35 |
-
# "triton_scaled_mm")
|
36 |
-
# triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
|
37 |
-
# return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
38 |
-
|
39 |
-
out = torch.empty((m, n), dtype=out_dtype, device=a.device)
|
40 |
-
|
41 |
-
ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
|
42 |
-
|
43 |
-
return out
|
44 |
-
|
45 |
-
# fp8
|
46 |
-
def scaled_fp8_quant(
|
47 |
-
input: torch.Tensor,
|
48 |
-
scale: Optional[torch.Tensor] = None,
|
49 |
-
num_token_padding: Optional[int] = None,
|
50 |
-
scale_ub: Optional[torch.Tensor] = None,
|
51 |
-
use_per_token_if_dynamic: bool = False,
|
52 |
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
53 |
-
"""
|
54 |
-
Quantize input tensor to FP8 and return quantized tensor and scale.
|
55 |
-
|
56 |
-
This function supports both static and dynamic quantization: If you
|
57 |
-
provide the scale, it will use static scaling and if you omit it,
|
58 |
-
the scale will be determined dynamically. The function also allows
|
59 |
-
optional padding of the output tensors for downstream kernels that
|
60 |
-
will benefit from padding.
|
61 |
-
|
62 |
-
Args:
|
63 |
-
input: The input tensor to be quantized to FP8
|
64 |
-
scale: Optional scaling factor for the FP8 quantization
|
65 |
-
scale_ub: Optional upper bound for scaling factor in dynamic
|
66 |
-
per token case
|
67 |
-
num_token_padding: If specified, pad the first dimension
|
68 |
-
of the output to at least this value.
|
69 |
-
use_per_token_if_dynamic: Whether to do per_tensor or per_token
|
70 |
-
in the dynamic quantization case.
|
71 |
-
|
72 |
-
Returns:
|
73 |
-
Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
|
74 |
-
scaling factor.
|
75 |
-
"""
|
76 |
-
# This code assumes batch_dim and num_tokens are flattened
|
77 |
-
assert (input.ndim == 2)
|
78 |
-
shape: Union[Tuple[int, int], torch.Size] = input.shape
|
79 |
-
# For rocm, the output fp8 dtype is torch.float_e3m3fnuz
|
80 |
-
#out_dtype: torch.dtype = torch.float8_e4m3fnuz \
|
81 |
-
# if current_platform.is_rocm() else torch.float8_e4m3fn
|
82 |
-
out_dtype = torch.float8_e4m3fn
|
83 |
-
if num_token_padding:
|
84 |
-
shape = (max(num_token_padding, input.shape[0]), shape[1])
|
85 |
-
output = torch.empty(shape, device=input.device, dtype=out_dtype)
|
86 |
-
|
87 |
-
if scale is None:
|
88 |
-
if use_per_token_if_dynamic:
|
89 |
-
scale = torch.empty((shape[0], 1),
|
90 |
-
device=input.device,
|
91 |
-
dtype=torch.float32)
|
92 |
-
ops.dynamic_per_token_scaled_fp8_quant(
|
93 |
-
output, input, scale, scale_ub)
|
94 |
-
else:
|
95 |
-
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
|
96 |
-
ops.dynamic_scaled_fp8_quant(output, input, scale)
|
97 |
-
else:
|
98 |
-
# num_token_padding not implemented for this case
|
99 |
-
assert (scale.numel() == 1 or num_token_padding is None)
|
100 |
-
ops.static_scaled_fp8_quant(output, input, scale)
|
101 |
-
|
102 |
-
return output, scale
|
103 |
-
|
104 |
-
# int8
|
105 |
-
def scaled_int8_quant(
|
106 |
-
input: torch.Tensor,
|
107 |
-
scale: Optional[torch.Tensor] = None,
|
108 |
-
azp: Optional[torch.Tensor] = None,
|
109 |
-
symmetric: bool = True
|
110 |
-
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
111 |
-
"""
|
112 |
-
Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.
|
113 |
-
|
114 |
-
Args:
|
115 |
-
input: The input tensor to be quantized to int8.
|
116 |
-
scale: Optional scaling factor for the int8 quantization.
|
117 |
-
When not provided, we invoke dynamic-per-token quantization.
|
118 |
-
azp: Optional zero-point for the int8 quantization.
|
119 |
-
Must be provided for asymmetric quantization if `scale` is provided.
|
120 |
-
symmetric: Whether to use symmetric quantization (scale only, azp ignored).
|
121 |
-
|
122 |
-
Returns:
|
123 |
-
Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
|
124 |
-
"""
|
125 |
-
output = torch.empty_like(input, dtype=torch.int8)
|
126 |
-
if scale is not None:
|
127 |
-
# static-per-tensor quantization.
|
128 |
-
assert symmetric == (
|
129 |
-
azp is
|
130 |
-
None), "azp must only be provided for asymmetric quantization."
|
131 |
-
ops.static_scaled_int8_quant(output, input, scale, azp)
|
132 |
-
return output, scale, azp
|
133 |
-
|
134 |
-
# dynamic-per-token quantization.
|
135 |
-
input_scales = torch.empty((input.numel() // input.shape[-1], 1),
|
136 |
-
device=input.device,
|
137 |
-
dtype=torch.float32)
|
138 |
-
input_azp = None if symmetric else torch.empty_like(input_scales,
|
139 |
-
dtype=torch.int32)
|
140 |
-
ops.dynamic_scaled_int8_quant(output, input, input_scales,
|
141 |
-
input_azp)
|
142 |
-
return output, input_scales, input_azp
|
143 |
-
|
144 |
-
# fp8 marlin
|
145 |
-
def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
|
146 |
-
b_scales: torch.Tensor, workspace: torch.Tensor,
|
147 |
-
num_bits: int, size_m: int, size_n: int,
|
148 |
-
size_k: int) -> torch.Tensor:
|
149 |
-
return ops.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace,
|
150 |
-
num_bits, size_m, size_n, size_k)
|
|
|
1 |
+
from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant
|
2 |
+
from .cutlass import (
|
3 |
+
cutlass_scaled_mm_supports_fp8,
|
4 |
+
cutlass_scaled_mm,
|
5 |
+
cutlass_scaled_mm_azp,
|
6 |
+
)
|
7 |
+
from .marlin import (
|
8 |
+
awq_marlin_repack,
|
9 |
+
fp8_marlin_gemm,
|
10 |
+
gptq_marlin_gemm,
|
11 |
+
gptq_marlin_repack,
|
12 |
+
gptq_marlin_24_gemm,
|
13 |
+
marlin_qqq_gemm,
|
14 |
+
marlin_gemm,
|
15 |
+
)
|
16 |
+
|
17 |
+
__all__ = [
|
18 |
+
"awq_marlin_repack",
|
19 |
+
"cutlass_scaled_mm",
|
20 |
+
"cutlass_scaled_mm_azp",
|
21 |
+
"cutlass_scaled_mm_supports_fp8",
|
22 |
+
"fp8_marlin_gemm",
|
23 |
+
"gptq_marlin_24_gemm",
|
24 |
+
"gptq_marlin_gemm",
|
25 |
+
"gptq_marlin_repack",
|
26 |
+
"marlin_gemm",
|
27 |
+
"marlin_qqq_gemm",
|
28 |
+
"scaled_fp8_quant",
|
29 |
+
"scaled_int8_quant",
|
30 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch24-cxx11-cu124-x86_64-linux/quantization/_ops.py
CHANGED
@@ -1,3 +1,9 @@
|
|
1 |
import torch
|
2 |
from . import _quantization_0_0_1
|
3 |
ops = torch.ops._quantization_0_0_1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import torch
|
2 |
from . import _quantization_0_0_1
|
3 |
ops = torch.ops._quantization_0_0_1
|
4 |
+
|
5 |
+
def add_op_namespace_prefix(op_name: str):
|
6 |
+
"""
|
7 |
+
Prefix op by namespace.
|
8 |
+
"""
|
9 |
+
return f"_quantization_0_0_1::{op_name}"
|
build/torch24-cxx11-cu124-x86_64-linux/quantization/_quantization_0_0_1.abi3.so
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a458d5efc51f80028811707ee7b9fadb00f3bfc49917c8377188c286c4bd8e12
|
3 |
+
size 109249352
|
build/torch24-cxx11-cu124-x86_64-linux/quantization/compressed_tensors.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Tuple
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
try:
|
6 |
+
from ._ops import ops
|
7 |
+
except ImportError as e:
|
8 |
+
# Fallback for local development.
|
9 |
+
try:
|
10 |
+
import _quantization
|
11 |
+
|
12 |
+
ops = torch.ops._quantization
|
13 |
+
except ImportError:
|
14 |
+
raise e
|
15 |
+
|
16 |
+
|
17 |
+
# fp8
|
18 |
+
def scaled_fp8_quant(
|
19 |
+
input: torch.Tensor,
|
20 |
+
scale: Optional[torch.Tensor] = None,
|
21 |
+
num_token_padding: Optional[int] = None,
|
22 |
+
scale_ub: Optional[torch.Tensor] = None,
|
23 |
+
use_per_token_if_dynamic: bool = False,
|
24 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
25 |
+
"""
|
26 |
+
Quantize input tensor to FP8 and return quantized tensor and scale.
|
27 |
+
|
28 |
+
This function supports both static and dynamic quantization: If you
|
29 |
+
provide the scale, it will use static scaling and if you omit it,
|
30 |
+
the scale will be determined dynamically. The function also allows
|
31 |
+
optional padding of the output tensors for downstream kernels that
|
32 |
+
will benefit from padding.
|
33 |
+
|
34 |
+
Args:
|
35 |
+
input: The input tensor to be quantized to FP8
|
36 |
+
scale: Optional scaling factor for the FP8 quantization
|
37 |
+
scale_ub: Optional upper bound for scaling factor in dynamic
|
38 |
+
per token case
|
39 |
+
num_token_padding: If specified, pad the first dimension
|
40 |
+
of the output to at least this value.
|
41 |
+
use_per_token_if_dynamic: Whether to do per_tensor or per_token
|
42 |
+
in the dynamic quantization case.
|
43 |
+
|
44 |
+
Returns:
|
45 |
+
Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
|
46 |
+
scaling factor.
|
47 |
+
"""
|
48 |
+
# This code assumes batch_dim and num_tokens are flattened
|
49 |
+
assert input.ndim == 2
|
50 |
+
shape: Union[Tuple[int, int], torch.Size] = input.shape
|
51 |
+
# For rocm, the output fp8 dtype is torch.float_e3m3fnuz
|
52 |
+
# out_dtype: torch.dtype = torch.float8_e4m3fnuz \
|
53 |
+
# if current_platform.is_rocm() else torch.float8_e4m3fn
|
54 |
+
out_dtype = torch.float8_e4m3fn
|
55 |
+
if num_token_padding:
|
56 |
+
shape = (max(num_token_padding, input.shape[0]), shape[1])
|
57 |
+
output = torch.empty(shape, device=input.device, dtype=out_dtype)
|
58 |
+
|
59 |
+
if scale is None:
|
60 |
+
if use_per_token_if_dynamic:
|
61 |
+
scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
|
62 |
+
ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub)
|
63 |
+
else:
|
64 |
+
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
|
65 |
+
ops.dynamic_scaled_fp8_quant(output, input, scale)
|
66 |
+
else:
|
67 |
+
# num_token_padding not implemented for this case
|
68 |
+
assert scale.numel() == 1 or num_token_padding is None
|
69 |
+
ops.static_scaled_fp8_quant(output, input, scale)
|
70 |
+
|
71 |
+
return output, scale
|
72 |
+
|
73 |
+
|
74 |
+
# int8
|
75 |
+
def scaled_int8_quant(
|
76 |
+
input: torch.Tensor,
|
77 |
+
scale: Optional[torch.Tensor] = None,
|
78 |
+
azp: Optional[torch.Tensor] = None,
|
79 |
+
symmetric: bool = True,
|
80 |
+
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
81 |
+
"""
|
82 |
+
Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.
|
83 |
+
|
84 |
+
Args:
|
85 |
+
input: The input tensor to be quantized to int8.
|
86 |
+
scale: Optional scaling factor for the int8 quantization.
|
87 |
+
When not provided, we invoke dynamic-per-token quantization.
|
88 |
+
azp: Optional zero-point for the int8 quantization.
|
89 |
+
Must be provided for asymmetric quantization if `scale` is provided.
|
90 |
+
symmetric: Whether to use symmetric quantization (scale only, azp ignored).
|
91 |
+
|
92 |
+
Returns:
|
93 |
+
Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
|
94 |
+
"""
|
95 |
+
output = torch.empty_like(input, dtype=torch.int8)
|
96 |
+
if scale is not None:
|
97 |
+
# static-per-tensor quantization.
|
98 |
+
assert symmetric == (
|
99 |
+
azp is None
|
100 |
+
), "azp must only be provided for asymmetric quantization."
|
101 |
+
ops.static_scaled_int8_quant(output, input, scale, azp)
|
102 |
+
return output, scale, azp
|
103 |
+
|
104 |
+
# dynamic-per-token quantization.
|
105 |
+
input_scales = torch.empty(
|
106 |
+
(input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32
|
107 |
+
)
|
108 |
+
input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32)
|
109 |
+
ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp)
|
110 |
+
return output, input_scales, input_azp
|
build/torch24-cxx11-cu124-x86_64-linux/quantization/cutlass.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
try:
|
6 |
+
from ._ops import ops
|
7 |
+
except ImportError as e:
|
8 |
+
# Fallback for local development.
|
9 |
+
try:
|
10 |
+
import _quantization
|
11 |
+
|
12 |
+
ops = torch.ops._quantization
|
13 |
+
except ImportError:
|
14 |
+
raise e
|
15 |
+
|
16 |
+
|
17 |
+
def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
|
18 |
+
return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
|
19 |
+
|
20 |
+
|
21 |
+
def cutlass_scaled_mm(
|
22 |
+
a: torch.Tensor,
|
23 |
+
b: torch.Tensor,
|
24 |
+
scale_a: torch.Tensor,
|
25 |
+
scale_b: torch.Tensor,
|
26 |
+
out_dtype: torch.dtype,
|
27 |
+
bias: Optional[torch.Tensor] = None,
|
28 |
+
) -> torch.Tensor:
|
29 |
+
assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
|
30 |
+
assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
|
31 |
+
assert bias is None or bias.shape[0] == b.shape[1] and bias.dtype == out_dtype
|
32 |
+
|
33 |
+
m = a.shape[0]
|
34 |
+
n = b.shape[1]
|
35 |
+
|
36 |
+
# if current_platform.is_rocm():
|
37 |
+
# triton_scaled_mm_module = importlib.import_module(
|
38 |
+
# "vllm.model_executor.layers.quantization.compressed_tensors."
|
39 |
+
# "triton_scaled_mm")
|
40 |
+
# triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
|
41 |
+
# return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
42 |
+
|
43 |
+
out = torch.empty((m, n), dtype=out_dtype, device=a.device)
|
44 |
+
|
45 |
+
ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
|
46 |
+
|
47 |
+
return out
|
48 |
+
|
49 |
+
|
50 |
+
def cutlass_scaled_mm_azp(
|
51 |
+
a: torch.Tensor,
|
52 |
+
b: torch.Tensor,
|
53 |
+
scale_a: torch.Tensor,
|
54 |
+
scale_b: torch.Tensor,
|
55 |
+
out_dtype: torch.dtype,
|
56 |
+
azp_adj: torch.Tensor,
|
57 |
+
azp: Optional[torch.Tensor] = None,
|
58 |
+
bias: Optional[torch.Tensor] = None,
|
59 |
+
) -> torch.Tensor:
|
60 |
+
"""
|
61 |
+
:param azp_adj: In the per-tensor case, this should include the azp.
|
62 |
+
Always per-channel.
|
63 |
+
:param azp: Only set in the per-token case. Per-token if set.
|
64 |
+
"""
|
65 |
+
assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
|
66 |
+
assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
|
67 |
+
assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype
|
68 |
+
assert azp is None or azp.numel() == a.shape[0]
|
69 |
+
|
70 |
+
m = a.shape[0]
|
71 |
+
n = b.shape[1]
|
72 |
+
out = torch.empty((m, n), dtype=out_dtype, device=a.device)
|
73 |
+
|
74 |
+
ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias)
|
75 |
+
return out
|
build/torch24-cxx11-cu124-x86_64-linux/quantization/marlin.py
ADDED
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import TYPE_CHECKING
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
# neuron has torch version that doesn't even have impl_abstract
|
6 |
+
if TYPE_CHECKING:
|
7 |
+
def register_fake(fn):
|
8 |
+
return lambda name: fn
|
9 |
+
else:
|
10 |
+
try:
|
11 |
+
from torch.library import register_fake
|
12 |
+
except ImportError:
|
13 |
+
from torch.library import impl_abstract as register_fake
|
14 |
+
|
15 |
+
try:
|
16 |
+
from ._ops import ops, add_op_namespace_prefix
|
17 |
+
except ImportError as e:
|
18 |
+
# Fallback for local development.
|
19 |
+
try:
|
20 |
+
import _quantization
|
21 |
+
|
22 |
+
ops = torch.ops._quantization
|
23 |
+
|
24 |
+
def add_op_namespace_prefix(op_name: str):
|
25 |
+
return f"_quantization::{op_name}"
|
26 |
+
except ImportError:
|
27 |
+
raise e
|
28 |
+
|
29 |
+
|
30 |
+
from .scalar_type import ScalarType
|
31 |
+
|
32 |
+
|
33 |
+
# fp8 marlin
|
34 |
+
def fp8_marlin_gemm(
|
35 |
+
a: torch.Tensor,
|
36 |
+
b_q_weight: torch.Tensor,
|
37 |
+
b_scales: torch.Tensor,
|
38 |
+
workspace: torch.Tensor,
|
39 |
+
num_bits: int,
|
40 |
+
size_m: int,
|
41 |
+
size_n: int,
|
42 |
+
size_k: int,
|
43 |
+
) -> torch.Tensor:
|
44 |
+
return ops.fp8_marlin_gemm(
|
45 |
+
a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k
|
46 |
+
)
|
47 |
+
|
48 |
+
|
49 |
+
# gptq_marlin
|
50 |
+
def gptq_marlin_gemm(
|
51 |
+
a: torch.Tensor,
|
52 |
+
b_q_weight: torch.Tensor,
|
53 |
+
b_scales: torch.Tensor,
|
54 |
+
b_zeros: torch.Tensor,
|
55 |
+
g_idx: torch.Tensor,
|
56 |
+
perm: torch.Tensor,
|
57 |
+
workspace: torch.Tensor,
|
58 |
+
b_q_type: ScalarType,
|
59 |
+
size_m: int,
|
60 |
+
size_n: int,
|
61 |
+
size_k: int,
|
62 |
+
is_k_full: bool,
|
63 |
+
has_zp: bool = False,
|
64 |
+
use_fp32_reduce: bool = False,
|
65 |
+
is_zp_float: bool = False,
|
66 |
+
) -> torch.Tensor:
|
67 |
+
return ops.gptq_marlin_gemm(
|
68 |
+
a,
|
69 |
+
b_q_weight,
|
70 |
+
b_scales,
|
71 |
+
b_zeros,
|
72 |
+
g_idx,
|
73 |
+
perm,
|
74 |
+
workspace,
|
75 |
+
b_q_type.id,
|
76 |
+
size_m,
|
77 |
+
size_n,
|
78 |
+
size_k,
|
79 |
+
is_k_full,
|
80 |
+
has_zp,
|
81 |
+
use_fp32_reduce,
|
82 |
+
is_zp_float,
|
83 |
+
)
|
84 |
+
|
85 |
+
|
86 |
+
# gptq_marlin
|
87 |
+
def gptq_marlin_repack(
|
88 |
+
b_q_weight: torch.Tensor,
|
89 |
+
perm: torch.Tensor,
|
90 |
+
size_k: int,
|
91 |
+
size_n: int,
|
92 |
+
num_bits: int,
|
93 |
+
) -> torch.Tensor:
|
94 |
+
return ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits)
|
95 |
+
|
96 |
+
|
97 |
+
# gptq_marlin
|
98 |
+
def awq_marlin_repack(
|
99 |
+
b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
|
100 |
+
) -> torch.Tensor:
|
101 |
+
return ops.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)
|
102 |
+
|
103 |
+
|
104 |
+
# marlin
|
105 |
+
def marlin_gemm(
|
106 |
+
a: torch.Tensor,
|
107 |
+
b_q_weight: torch.Tensor,
|
108 |
+
b_scales: torch.Tensor,
|
109 |
+
workspace: torch.Tensor,
|
110 |
+
size_m: int,
|
111 |
+
size_n: int,
|
112 |
+
size_k: int,
|
113 |
+
) -> torch.Tensor:
|
114 |
+
return ops.marlin_gemm(
|
115 |
+
a, b_q_weight, b_scales, workspace, size_m, size_n, size_k
|
116 |
+
)
|
117 |
+
|
118 |
+
|
119 |
+
# marlin_24
|
120 |
+
def gptq_marlin_24_gemm(
|
121 |
+
a: torch.Tensor,
|
122 |
+
b_q_weight: torch.Tensor,
|
123 |
+
b_meta: torch.Tensor,
|
124 |
+
b_scales: torch.Tensor,
|
125 |
+
workspace: torch.Tensor,
|
126 |
+
b_q_type: ScalarType,
|
127 |
+
size_m: int,
|
128 |
+
size_n: int,
|
129 |
+
size_k: int,
|
130 |
+
) -> torch.Tensor:
|
131 |
+
return ops.gptq_marlin_24_gemm(
|
132 |
+
a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, size_m, size_n, size_k
|
133 |
+
)
|
134 |
+
|
135 |
+
|
136 |
+
# qqq ops
|
137 |
+
def marlin_qqq_gemm(
|
138 |
+
a: torch.Tensor,
|
139 |
+
b_q_weight: torch.Tensor,
|
140 |
+
s_tok: torch.Tensor,
|
141 |
+
s_ch: torch.Tensor,
|
142 |
+
s_group: torch.Tensor,
|
143 |
+
workspace: torch.Tensor,
|
144 |
+
size_m: int,
|
145 |
+
size_n: int,
|
146 |
+
size_k: int,
|
147 |
+
) -> torch.Tensor:
|
148 |
+
return ops.marlin_qqq_gemm(
|
149 |
+
a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k
|
150 |
+
)
|
151 |
+
|
152 |
+
|
153 |
+
# Fake ops
|
154 |
+
|
155 |
+
if hasattr(ops, "gptq_marlin_24_gemm"):
|
156 |
+
@register_fake(add_op_namespace_prefix("fp8_marlin_gemm"))
|
157 |
+
def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
|
158 |
+
b_scales: torch.Tensor, workspace: torch.Tensor,
|
159 |
+
num_bits: int, size_m: torch.SymInt,
|
160 |
+
size_n: torch.SymInt,
|
161 |
+
size_k: torch.SymInt) -> torch.Tensor:
|
162 |
+
return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)
|
163 |
+
|
164 |
+
@register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm"))
|
165 |
+
def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
|
166 |
+
b_meta: torch.Tensor, b_scales: torch.Tensor,
|
167 |
+
workspace: torch.Tensor,
|
168 |
+
b_q_type: ScalarType, size_m: torch.SymInt,
|
169 |
+
size_n: torch.SymInt,
|
170 |
+
size_k: torch.SymInt) -> torch.Tensor:
|
171 |
+
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
|
172 |
+
|
173 |
+
@register_fake(add_op_namespace_prefix("gptq_marlin_gemm"))
|
174 |
+
def _gptq_marlin_gemm_fake(a: torch.Tensor,
|
175 |
+
b_q_weight: torch.Tensor,
|
176 |
+
b_scales: torch.Tensor,
|
177 |
+
b_zeros: torch.Tensor,
|
178 |
+
g_idx: torch.Tensor,
|
179 |
+
perm: torch.Tensor,
|
180 |
+
workspace: torch.Tensor,
|
181 |
+
b_q_type: ScalarType,
|
182 |
+
size_m: torch.SymInt,
|
183 |
+
size_n: torch.SymInt,
|
184 |
+
size_k: torch.SymInt,
|
185 |
+
is_k_full: bool,
|
186 |
+
has_zp: bool = False,
|
187 |
+
use_fp32_reduce: bool = False,
|
188 |
+
is_zp_float: bool = False) -> torch.Tensor:
|
189 |
+
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
|
190 |
+
|
191 |
+
@register_fake(add_op_namespace_prefix("marlin_qqq_gemm"))
|
192 |
+
def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
|
193 |
+
s_tok: torch.Tensor, s_ch: torch.Tensor,
|
194 |
+
s_group: torch.Tensor, workspace: torch.Tensor,
|
195 |
+
size_m: torch.SymInt, size_n: torch.SymInt,
|
196 |
+
size_k: torch.SymInt) -> torch.Tensor:
|
197 |
+
return torch.empty((size_m, size_n),
|
198 |
+
dtype=torch.float16,
|
199 |
+
device=a.device)
|
200 |
+
|
201 |
+
@register_fake(add_op_namespace_prefix("marlin_gemm"))
|
202 |
+
def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
|
203 |
+
b_scales: torch.Tensor, workspace: torch.Tensor,
|
204 |
+
size_m: torch.SymInt, size_n: torch.SymInt,
|
205 |
+
size_k: torch.SymInt) -> torch.Tensor:
|
206 |
+
return torch.empty((size_m, size_n),
|
207 |
+
dtype=torch.float16,
|
208 |
+
device=a.device)
|
build/torch24-cxx11-cu124-x86_64-linux/quantization/scalar_type.py
ADDED
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
import struct
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from enum import Enum
|
5 |
+
from typing import Optional, Union
|
6 |
+
|
7 |
+
|
8 |
+
# Mirrors enum in `core/scalar_type.hpp`
|
9 |
+
class NanRepr(Enum):
|
10 |
+
NONE = 0 # nans are not supported
|
11 |
+
IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s
|
12 |
+
EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s
|
13 |
+
|
14 |
+
|
15 |
+
# This ScalarType class is a parallel implementation of the C++ ScalarType
|
16 |
+
# class found in csrc/core/scalar_type.hpp. These two classes should be kept
|
17 |
+
# in sync until the inductor fully supports custom C++ classes.
|
18 |
+
@dataclass(frozen=True)
|
19 |
+
class ScalarType:
|
20 |
+
"""
|
21 |
+
ScalarType can represent a wide range of floating point and integer
|
22 |
+
types, in particular it can be used to represent sub-byte data types
|
23 |
+
(something that torch.dtype currently does not support). It is also
|
24 |
+
capable of representing types with a bias, i.e.:
|
25 |
+
`stored_value = value + bias`,
|
26 |
+
this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
|
27 |
+
of 8). The implementation for this class can be found in
|
28 |
+
csrc/core/scalar_type.hpp, these type signatures should be kept in sync
|
29 |
+
with that file.
|
30 |
+
"""
|
31 |
+
|
32 |
+
exponent: int
|
33 |
+
"""
|
34 |
+
Number of bits in the exponent if this is a floating point type
|
35 |
+
(zero if this an integer type)
|
36 |
+
"""
|
37 |
+
|
38 |
+
mantissa: int
|
39 |
+
"""
|
40 |
+
Number of bits in the mantissa if this is a floating point type,
|
41 |
+
or the number bits representing an integer excluding the sign bit if
|
42 |
+
this an integer type.
|
43 |
+
"""
|
44 |
+
|
45 |
+
signed: bool
|
46 |
+
"If the type is signed (i.e. has a sign bit)"
|
47 |
+
|
48 |
+
bias: int
|
49 |
+
"""
|
50 |
+
bias used to encode the values in this scalar type
|
51 |
+
(value = stored_value - bias, default 0) for example if we store the
|
52 |
+
type as an unsigned integer with a bias of 128 then the value 0 will be
|
53 |
+
stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
|
54 |
+
"""
|
55 |
+
|
56 |
+
_finite_values_only: bool = False
|
57 |
+
"""
|
58 |
+
Private: if infs are supported, used `has_infs()` instead.
|
59 |
+
"""
|
60 |
+
|
61 |
+
nan_repr: NanRepr = NanRepr.IEEE_754
|
62 |
+
"""
|
63 |
+
How NaNs are represent in this scalar type, returns NanRepr value.
|
64 |
+
(not applicable for integer types)
|
65 |
+
"""
|
66 |
+
|
67 |
+
def _floating_point_max_int(self) -> int:
|
68 |
+
assert (
|
69 |
+
self.mantissa <= 52 and self.exponent <= 11
|
70 |
+
), f"Cannot represent max/min as a double for type {self.__str__()}"
|
71 |
+
|
72 |
+
max_mantissa = (1 << self.mantissa) - 1
|
73 |
+
if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
|
74 |
+
max_mantissa = max_mantissa - 1
|
75 |
+
|
76 |
+
max_exponent = (1 << self.exponent) - 2
|
77 |
+
if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN
|
78 |
+
or self.nan_repr == NanRepr.NONE):
|
79 |
+
assert (
|
80 |
+
self.exponent < 11
|
81 |
+
), f"Cannot represent max/min as a double for type {self.__str__()}"
|
82 |
+
max_exponent = max_exponent + 1
|
83 |
+
|
84 |
+
# adjust the exponent to match that of a double
|
85 |
+
# for now we assume the exponent bias is the standard 2^(e-1) -1, (where
|
86 |
+
# e is the exponent bits), there is some precedent for non-standard
|
87 |
+
# biases, example `float8_e4m3b11fnuz` here:
|
88 |
+
# https://github.com/jax-ml/ml_dtypes but to avoid premature over
|
89 |
+
# complication we are just assuming the standard exponent bias until
|
90 |
+
# there is a need to support non-standard biases
|
91 |
+
exponent_bias = (1 << (self.exponent - 1)) - 1
|
92 |
+
exponent_bias_double = (1 << 10) - 1 # double e = 11
|
93 |
+
|
94 |
+
max_exponent_double = (max_exponent - exponent_bias +
|
95 |
+
exponent_bias_double)
|
96 |
+
|
97 |
+
# shift the mantissa and exponent into the proper positions for an
|
98 |
+
# IEEE double and bitwise-or them together.
|
99 |
+
return (max_mantissa <<
|
100 |
+
(52 - self.mantissa)) | (max_exponent_double << 52)
|
101 |
+
|
102 |
+
def _floating_point_max(self) -> float:
|
103 |
+
double_raw = self._floating_point_max_int()
|
104 |
+
return struct.unpack('!d', struct.pack('!Q', double_raw))[0]
|
105 |
+
|
106 |
+
    def _raw_max(self) -> Union[int, float]:
        """Max representable value *before* the bias is applied.

        Returns a float for floating point types, an int for integer types.
        """
        if self.is_floating_point():
            return self._floating_point_max()
        else:
            assert (self.size_bits < 64 or self.size_bits == 64
                    and self.is_signed()), "Cannot represent max as an int"
            # For integer types, `mantissa` holds the magnitude bits
            # (size_bits minus the sign bit — see the int_/uint constructors),
            # so the max is 2^mantissa - 1 for both signed and unsigned.
            return (1 << self.mantissa) - 1
|
113 |
+
|
114 |
+
    def _raw_min(self) -> Union[int, float]:
        """Min representable value *before* the bias is applied.

        Returns a float for floating point types, an int for integer types.
        """
        if self.is_floating_point():
            assert self.is_signed(
            ), "We currently assume all floating point types are signed"
            sign_bit_double = 1 << 63

            # The min is the negation of the max finite value: take the
            # double bit-pattern of the max and set its sign bit.
            max_raw = self._floating_point_max_int()
            min_raw = max_raw | sign_bit_double
            return struct.unpack('!d', struct.pack('!Q', min_raw))[0]
        else:
            assert (not self.is_signed() or
                    self.size_bits <= 64), "Cannot represent min as a int64_t"

            if self.is_signed():
                # Two's complement minimum: -(2^(size_bits - 1)).
                return -(1 << (self.size_bits - 1))
            else:
                return 0
|
131 |
+
|
132 |
+
    @functools.cached_property
    def id(self) -> int:
        """
        Convert the ScalarType to an int which can be passed to pytorch custom
        ops. This layout of the int must be kept in sync with the C++
        ScalarType's from_id method.
        """
        val = 0
        offset = 0

        def or_and_advance(member, bit_width):
            # OR `member` into `val` at the current bit offset (masked to
            # `bit_width` bits), then advance the offset.
            nonlocal val
            nonlocal offset
            bit_mask = (1 << bit_width) - 1
            val = val | (int(member) & bit_mask) << offset
            offset = offset + bit_width

        # Field layout, low bits first:
        # exponent(8) | mantissa(8) | signed(1) | bias(32)
        # | finite_values_only(1) | nan_repr(8)  => 58 bits used.
        or_and_advance(self.exponent, 8)
        or_and_advance(self.mantissa, 8)
        or_and_advance(self.signed, 1)
        or_and_advance(self.bias, 32)
        or_and_advance(self._finite_values_only, 1)
        or_and_advance(self.nan_repr.value, 8)

        assert offset <= 64, \
            f"ScalarType fields too big {offset} to fit into an int64"

        return val
|
160 |
+
|
161 |
+
    @property
    def size_bits(self) -> int:
        """Total bit-width of the type: exponent + mantissa + sign bit."""
        return self.exponent + self.mantissa + int(self.signed)
|
164 |
+
|
165 |
+
    def min(self) -> Union[int, float]:
        """
        Min representable value for this scalar type.
        (accounting for bias if there is one)

        Returns a float for floating point types, an int for integer types.
        """
        return self._raw_min() - self.bias
|
171 |
+
|
172 |
+
    def max(self) -> Union[int, float]:
        """
        Max representable value for this scalar type.
        (accounting for bias if there is one)

        Returns a float for floating point types, an int for integer types.
        """
        return self._raw_max() - self.bias
|
178 |
+
|
179 |
+
    def is_signed(self) -> bool:
        """
        If the type is signed (i.e. has a sign bit), same as `signed`
        added for consistency with:
        https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
        """
        # Returns the `signed` field directly (no conversion).
        return self.signed
|
186 |
+
|
187 |
+
    def is_floating_point(self) -> bool:
        "If the type is a floating point type"
        # Integer types are encoded with exponent == 0 (see int_/uint).
        return self.exponent != 0
|
190 |
+
|
191 |
+
    def is_integer(self) -> bool:
        "If the type is an integer type"
        # Complement of is_floating_point(): zero exponent bits => integer.
        return self.exponent == 0
|
194 |
+
|
195 |
+
    def has_bias(self) -> bool:
        "If the type has a non-zero bias"
        return self.bias != 0
|
198 |
+
|
199 |
+
    def has_infs(self) -> bool:
        "If the type is floating point and supports infinity"
        # NOTE(review): this does not itself check is_floating_point(), so it
        # returns True for integer types as well — confirm callers only use
        # it on float types.
        return not self._finite_values_only
|
202 |
+
|
203 |
+
def has_nans(self) -> bool:
|
204 |
+
return self.nan_repr != NanRepr.NONE.value
|
205 |
+
|
206 |
+
def is_ieee_754(self) -> bool:
|
207 |
+
"""
|
208 |
+
If the type is a floating point type that follows IEEE 754
|
209 |
+
conventions
|
210 |
+
"""
|
211 |
+
return self.nan_repr == NanRepr.IEEE_754.value and \
|
212 |
+
not self._finite_values_only
|
213 |
+
|
214 |
+
    def __str__(self) -> str:
        """
        naming generally follows: https://github.com/jax-ml/ml_dtypes
        for floating point types (leading f) the scheme is:
        `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
        flags:
          - no-flags: means it follows IEEE 754 conventions
          - f: means finite values only (no infinities)
          - n: means nans are supported (non-standard encoding)
        for integer types the scheme is:
        `[u]int<size_bits>[b<bias>]`
          - if bias is not present it means its zero
        """
        if self.is_floating_point():
            ret = "float" + str(self.size_bits) + "_e" + str(
                self.exponent) + "m" + str(self.mantissa)

            # Flag suffixes only apply to non-IEEE-754 encodings.
            if not self.is_ieee_754():
                if self._finite_values_only:
                    ret = ret + "f"
                if self.nan_repr != NanRepr.NONE:
                    ret = ret + "n"

            return ret
        else:
            ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
            if self.has_bias():
                ret = ret + "b" + str(self.bias)
            return ret
|
243 |
+
|
244 |
+
def __repr__(self) -> str:
|
245 |
+
return "ScalarType." + self.__str__()
|
246 |
+
|
247 |
+
    # __len__ needs to be defined (and has to throw TypeError) for pytorch's
    # opcheck to work.
    def __len__(self) -> int:
        """Always raises TypeError; required by torch's opcheck machinery."""
        raise TypeError
|
251 |
+
|
252 |
+
#
|
253 |
+
# Convenience Constructors
|
254 |
+
#
|
255 |
+
|
256 |
+
@classmethod
|
257 |
+
def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
|
258 |
+
"Create a signed integer scalar type (size_bits includes sign-bit)."
|
259 |
+
ret = cls(0, size_bits - 1, True, bias if bias else 0)
|
260 |
+
ret.id # noqa B018: make sure the id is cached
|
261 |
+
return ret
|
262 |
+
|
263 |
+
@classmethod
|
264 |
+
def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
|
265 |
+
"""Create a unsigned integer scalar type."""
|
266 |
+
ret = cls(0, size_bits, False, bias if bias else 0)
|
267 |
+
ret.id # noqa B018: make sure the id is cached
|
268 |
+
return ret
|
269 |
+
|
270 |
+
@classmethod
|
271 |
+
def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
|
272 |
+
"""
|
273 |
+
Create a standard floating point type
|
274 |
+
(i.e. follows IEEE 754 conventions).
|
275 |
+
"""
|
276 |
+
assert (mantissa > 0 and exponent > 0)
|
277 |
+
ret = cls(exponent, mantissa, True, 0)
|
278 |
+
ret.id # noqa B018: make sure the id is cached
|
279 |
+
return ret
|
280 |
+
|
281 |
+
@classmethod
|
282 |
+
def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
|
283 |
+
nan_repr: NanRepr) -> 'ScalarType':
|
284 |
+
"""
|
285 |
+
Create a non-standard floating point type
|
286 |
+
(i.e. does not follow IEEE 754 conventions).
|
287 |
+
"""
|
288 |
+
assert (mantissa > 0 and exponent > 0)
|
289 |
+
assert (nan_repr != NanRepr.IEEE_754), (
|
290 |
+
"use `float_IEEE754` constructor for floating point types that "
|
291 |
+
"follow IEEE 754 conventions")
|
292 |
+
ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
|
293 |
+
ret.id # noqa B018: make sure the id is cached
|
294 |
+
return ret
|
295 |
+
|
296 |
+
|
297 |
+
# naming generally follows: https://github.com/jax-ml/ml_dtypes
|
298 |
+
# for floating point types (leading f) the scheme is:
|
299 |
+
# `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
|
300 |
+
# flags:
|
301 |
+
# - no-flags: means it follows IEEE 754 conventions
|
302 |
+
# - f: means finite values only (no infinities)
|
303 |
+
# - n: means nans are supported (non-standard encoding)
|
304 |
+
# for integer types the scheme is:
|
305 |
+
# `[u]int<size_bits>[b<bias>]`
|
306 |
+
# - if bias is not present it means its zero
|
307 |
+
|
308 |
+
|
309 |
+
class scalar_types:
    """Registry of commonly-used ScalarType instances as class attributes."""

    int4 = ScalarType.int_(4, None)
    uint4 = ScalarType.uint(4, None)
    int8 = ScalarType.int_(8, None)
    uint8 = ScalarType.uint(8, None)
    # fn = finite + non-standard NaN (all-ones pattern), matching
    # torch.float8_e4m3fn.
    float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
    float8_e5m2 = ScalarType.float_IEEE754(5, 2)
    float16_e8m7 = ScalarType.float_IEEE754(8, 7)
    float16_e5m10 = ScalarType.float_IEEE754(5, 10)

    # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
    float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)

    # "gptq" types
    uint2b2 = ScalarType.uint(2, 2)
    uint3b4 = ScalarType.uint(3, 4)
    uint4b8 = ScalarType.uint(4, 8)
    uint8b128 = ScalarType.uint(8, 128)

    # colloquial names
    bfloat16 = float16_e8m7
    float16 = float16_e5m10
|
build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/__init__.py
ADDED
File without changes
|
build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils.py
ADDED
@@ -0,0 +1,391 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional, Tuple
|
2 |
+
|
3 |
+
import numpy
|
4 |
+
import torch
|
5 |
+
|
6 |
+
import quantization as ops
|
7 |
+
from quantization.scalar_type import ScalarType, scalar_types
|
8 |
+
|
9 |
+
from .quant_utils import pack_cols, unpack_cols
|
10 |
+
|
11 |
+
# Tiling / dispatch constraints for the GPTQ Marlin kernel.
GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16

# Constraints for the 2:4-sparse Marlin kernel.
GPTQ_MARLIN_24_TILE = 16
GPTQ_MARLIN_24_MIN_THREAD_N = 128
GPTQ_MARLIN_24_MIN_THREAD_K = 128
GPTQ_MARLIN_24_MAX_PARALLEL = 64

GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]

# Constraints for the QQQ (W4A8) Marlin kernel.
MARLIN_QQQ_TILE = 16
MARLIN_QQQ_MIN_THREAD_N = 64
MARLIN_QQQ_MIN_THREAD_K = 128
MARLIN_QQQ_MAX_PARALLEL = 16

MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]
MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128]
MARLIN_QQQ_SUPPORTED_SYM = [True]

# -1 means channelwise quantization (one group spanning all of K).
MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]

# In case there is a performance issue with Marlin, the variable below can be
# changed to False, which allows Marlin to perform global reductions in fp16
# precision (instead of fp32), and therefore, save on some memory movements.
USE_FP32_REDUCE_DEFAULT = True
|
39 |
+
|
40 |
+
|
41 |
+
# For binary size and compile time, we don't support the same types for with and
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
# TODO: we may want to move this into the C++ so its closer to the actual impl
def query_marlin_supported_quant_types(
    has_zp: bool, device_capability: Optional[int] = None
):
    """Return the quant types Marlin supports for the given device/zp mode."""
    if device_capability is None:
        major, minor = torch.cuda.get_device_capability()
        device_capability = major * 10 + minor

    # Marlin kernels require SM80 (Ampere) or newer.
    if device_capability < 80:
        return []

    # AWQ style, unsigned + runtime zero-point
    if has_zp:
        return [scalar_types.uint4, scalar_types.uint8]

    # GPTQ style, unsigned + symmetric bias
    # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able
    # to add `scalar_types.float8_e4m3fn` here
    return [scalar_types.uint4b8, scalar_types.uint8b128]
|
62 |
+
|
63 |
+
|
64 |
+
def _check_marlin_supported(
    quant_type: ScalarType,
    group_size: Optional[int],
    has_zp: bool,
    device_capability: Optional[int] = None,
) -> Tuple[bool, Optional[str]]:
    """Check whether Marlin supports this config.

    Returns (True, None) when supported, else (False, reason).
    """
    if device_capability is None:
        major, minor = torch.cuda.get_device_capability()
        device_capability = major * 10 + minor

    supported_types = query_marlin_supported_quant_types(has_zp, device_capability)

    if quant_type not in supported_types:
        reason = (
            f"Marlin does not support weight_bits = {quant_type}. "
            f"Only types = {supported_types} "
            f"are supported (for group_size = {group_size}, "
            f"device_capability = {device_capability}, zp = {has_zp})."
        )
        return False, reason

    if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
        reason = (
            f"Marlin does not support group_size = {group_size}. "
            f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
            "are supported."
        )
        return False, reason

    return True, None
|
94 |
+
|
95 |
+
|
96 |
+
def check_marlin_supported(
    quant_type: ScalarType,
    group_size: int,
    has_zp: bool = False,
    device_capability: Optional[int] = None,
) -> bool:
    """Boolean-only wrapper around _check_marlin_supported."""
    supported, _reason = _check_marlin_supported(
        quant_type, group_size, has_zp, device_capability
    )
    return supported
|
104 |
+
|
105 |
+
|
106 |
+
def verify_marlin_supported(
    quant_type: ScalarType, group_size: int, has_zp: bool = False
) -> None:
    """Raise ValueError when the config is unsupported by Marlin."""
    supported, err_msg = _check_marlin_supported(quant_type, group_size, has_zp)
    if supported:
        return
    assert err_msg is not None
    raise ValueError(err_msg)
|
113 |
+
|
114 |
+
|
115 |
+
def verify_marlin_supports_shape(
    output_size_per_partition: int,
    input_size_per_partition: int,
    input_size: int,
    group_size: int,
) -> None:
    """Raise ValueError when the weight-shard shape violates Marlin's
    divisibility constraints (thread-block tiling and group alignment)."""

    # Validate output_size_per_partition
    if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
        raise ValueError(
            f"Weight output_size_per_partition = "
            f"{output_size_per_partition} is not divisible by "
            f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
            "Consider reducing tensor_parallel_size or running "
            "with --quantization gptq."
        )

    # Validate input_size_per_partition
    if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
        raise ValueError(
            f"Weight input_size_per_partition = "
            f"{input_size_per_partition} is not divisible "
            f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
            "Consider reducing tensor_parallel_size or running "
            "with --quantization gptq."
        )

    # Group boundaries must not straddle partitions (only relevant when the
    # group does not already cover all of K).
    if group_size < input_size and input_size_per_partition % group_size != 0:
        raise ValueError(
            f"Weight input_size_per_partition = {input_size_per_partition}"
            f" is not divisible by group_size = {group_size}."
            "Consider reducing tensor_parallel_size or running "
            "with --quantization gptq."
        )
|
149 |
+
|
150 |
+
|
151 |
+
def check_marlin_supports_shape(
    output_size_per_partition: int,
    input_size_per_partition: int,
    input_size: int,
    group_size: int,
) -> Tuple[bool, Optional[str]]:
    """Non-raising wrapper: returns (ok, error-message-or-None)."""
    try:
        verify_marlin_supports_shape(
            output_size_per_partition, input_size_per_partition, input_size, group_size
        )
    except ValueError as err:
        return False, str(err)
    return True, None
|
164 |
+
|
165 |
+
|
166 |
+
def marlin_make_workspace(
    output_size_per_partition: int, device: torch.device
) -> torch.Tensor:
    """Allocate the zeroed int32 scratch buffer Marlin kernels require."""
    # One slot per thread-block column, times the max parallelism factor.
    n_slots = output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N
    workspace_len = n_slots * GPTQ_MARLIN_MAX_PARALLEL

    return torch.zeros(
        workspace_len, dtype=torch.int, device=device, requires_grad=False
    )
|
176 |
+
|
177 |
+
|
178 |
+
def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
    """K is "full" except when act_order is combined with row-parallelism."""
    return not (act_order and is_row_parallel)
|
180 |
+
|
181 |
+
|
182 |
+
def marlin_repeat_scales_on_all_ranks(
    act_order: bool, group_size: int, is_row_parallel: bool
) -> bool:
    # Need to repeat scales on every rank if act_ordering or
    # channelwise and RowParallelLinear
    if act_order:
        return True
    channelwise = group_size == -1
    return channelwise and is_row_parallel
|
189 |
+
|
190 |
+
|
191 |
+
def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
    """Placeholder (length-0) g_idx parameter for layers without act-order."""
    empty = torch.empty(0, dtype=torch.int, device=device)
    return torch.nn.Parameter(empty, requires_grad=False)
|
195 |
+
|
196 |
+
|
197 |
+
def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
    """Placeholder (length-0) zero-point parameter for symmetric layers."""
    empty = torch.empty(0, dtype=torch.int, device=device)
    return torch.nn.Parameter(empty, requires_grad=False)
|
201 |
+
|
202 |
+
|
203 |
+
def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sort g_idx ascending; return (sorted_g_idx, int32 permutation used)."""
    order = torch.argsort(g_idx).to(torch.int)
    return g_idx[order], order
|
206 |
+
|
207 |
+
|
208 |
+
def get_scale_perms():
    """Return (scale_perm, scale_perm_single) column permutations for
    grouped and channelwise scales, respectively."""
    # Grouped scales: interleave 8 stripes of stride 8.
    scale_perm: List[int] = [i + 8 * j for i in range(8) for j in range(8)]
    # Channelwise ("single") scales: stride-2 pattern over 32 columns.
    scale_perm_single: List[int] = [
        2 * i + j for i in range(4) for j in [0, 1, 8, 9, 16, 17, 24, 25]
    ]
    return scale_perm, scale_perm_single
|
216 |
+
|
217 |
+
|
218 |
+
def marlin_permute_scales(
    s: torch.Tensor, size_k: int, size_n: int, group_size: int
) -> torch.Tensor:
    """Permute scale columns into the layout Marlin's kernel expects."""
    scale_perm, scale_perm_single = get_scale_perms()

    # Grouped scales use the full permutation; channelwise scales (or a group
    # spanning all of K) use the "single" permutation.
    grouped = group_size < size_k and group_size != -1
    perm = scale_perm if grouped else scale_perm_single

    permuted = s.reshape((-1, len(perm)))[:, perm]
    return permuted.reshape((-1, size_n)).contiguous()
|
230 |
+
|
231 |
+
|
232 |
+
def marlin_moe_permute_scales(
    s: torch.Tensor,
    size_k: int,
    size_n: int,
    group_size: int,
):
    """Apply marlin_permute_scales to each expert's scales independently."""
    num_experts = s.shape[0]
    permuted = torch.empty(
        (num_experts, s.shape[1], s.shape[2]),
        device=s.device,
        dtype=s.dtype,
    )

    for expert_idx in range(num_experts):
        permuted[expert_idx] = marlin_permute_scales(
            s[expert_idx], size_k, size_n, group_size
        )
    return permuted
|
248 |
+
|
249 |
+
|
250 |
+
def marlin_zero_points(
    zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    """Permute and pack unpacked zero-points into Marlin's layout.

    Applies the scale column permutation, interleaves columns to match the
    dequantize order, and packs the values to int32 via `pack_cols`.
    """
    # Permute zero-points in a similar way to scales, but do not use the
    # "single" permutation, since zero-points are applied on every MMA
    scale_perm, _ = get_scale_perms()
    zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm]

    # Interleave column dim (for the dequantize code) and pack it to int32
    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
    zp = zp.reshape((-1, size_n)).contiguous()
    zp = pack_cols(zp, num_bits, size_k, size_n)

    return zp
|
271 |
+
|
272 |
+
|
273 |
+
def awq_to_marlin_zero_points(
    q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    """Convert AWQ-packed zero-points into Marlin's zero-point layout."""
    # AWQ zero-points are quantized and packed on the column dim.
    # In addition, the values are permuted based on dequantizer.
    # Here we undo both of these, and then apply marlin permutation
    # and pack it back.
    q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)

    # Undo interleaving (use argsort(..) to get inverse perm)
    if num_bits == 4:
        undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7]))
    elif num_bits == 8:
        undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3]))
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
    q_zp = q_zp.reshape((-1, size_n)).contiguous()

    # Re-permute and re-pack into Marlin's layout.
    marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
    return marlin_zp
|
295 |
+
|
296 |
+
|
297 |
+
def moe_awq_to_marlin_zero_points(
|
298 |
+
q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
|
299 |
+
):
|
300 |
+
num_experts = q_zp_packed.shape[0]
|
301 |
+
output = torch.empty(
|
302 |
+
(num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]),
|
303 |
+
device=q_zp_packed.device,
|
304 |
+
dtype=q_zp_packed.dtype,
|
305 |
+
)
|
306 |
+
for e in range(num_experts):
|
307 |
+
output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits)
|
308 |
+
return output
|
309 |
+
|
310 |
+
|
311 |
+
def apply_gptq_marlin_linear(
|
312 |
+
input: torch.Tensor,
|
313 |
+
weight: torch.Tensor,
|
314 |
+
weight_scale: torch.Tensor,
|
315 |
+
weight_zp: torch.Tensor,
|
316 |
+
g_idx: torch.Tensor,
|
317 |
+
g_idx_sort_indices: torch.Tensor,
|
318 |
+
workspace: torch.Tensor,
|
319 |
+
wtype: ScalarType,
|
320 |
+
output_size_per_partition: int,
|
321 |
+
input_size_per_partition: int,
|
322 |
+
is_k_full: bool,
|
323 |
+
bias: Optional[torch.Tensor] = None,
|
324 |
+
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
|
325 |
+
) -> torch.Tensor:
|
326 |
+
reshaped_x = input.reshape(-1, input.shape[-1])
|
327 |
+
out_shape = input.shape[:-1] + (output_size_per_partition,)
|
328 |
+
|
329 |
+
output = ops.gptq_marlin_gemm(
|
330 |
+
reshaped_x,
|
331 |
+
weight,
|
332 |
+
weight_scale,
|
333 |
+
weight_zp,
|
334 |
+
g_idx,
|
335 |
+
g_idx_sort_indices,
|
336 |
+
workspace,
|
337 |
+
wtype,
|
338 |
+
size_m=reshaped_x.shape[0],
|
339 |
+
size_n=output_size_per_partition,
|
340 |
+
size_k=input_size_per_partition,
|
341 |
+
is_k_full=is_k_full,
|
342 |
+
has_zp=False,
|
343 |
+
use_fp32_reduce=use_fp32_reduce,
|
344 |
+
is_zp_float=False,
|
345 |
+
)
|
346 |
+
|
347 |
+
if bias is not None:
|
348 |
+
output.add_(bias) # In-place add
|
349 |
+
|
350 |
+
return output.reshape(out_shape)
|
351 |
+
|
352 |
+
|
353 |
+
def apply_awq_marlin_linear(
|
354 |
+
input: torch.Tensor,
|
355 |
+
weight: torch.Tensor,
|
356 |
+
weight_scale: torch.Tensor,
|
357 |
+
weight_zp: torch.Tensor,
|
358 |
+
g_idx: torch.Tensor,
|
359 |
+
g_idx_sort_indices: torch.Tensor,
|
360 |
+
workspace: torch.Tensor,
|
361 |
+
quant_type: ScalarType,
|
362 |
+
output_size_per_partition: int,
|
363 |
+
input_size_per_partition: int,
|
364 |
+
bias: Optional[torch.Tensor] = None,
|
365 |
+
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
|
366 |
+
) -> torch.Tensor:
|
367 |
+
reshaped_x = input.reshape(-1, input.shape[-1])
|
368 |
+
out_shape = input.shape[:-1] + (output_size_per_partition,)
|
369 |
+
|
370 |
+
output = ops.gptq_marlin_gemm(
|
371 |
+
reshaped_x,
|
372 |
+
weight,
|
373 |
+
weight_scale,
|
374 |
+
weight_zp,
|
375 |
+
g_idx,
|
376 |
+
g_idx_sort_indices,
|
377 |
+
workspace,
|
378 |
+
quant_type,
|
379 |
+
size_m=reshaped_x.shape[0],
|
380 |
+
size_n=output_size_per_partition,
|
381 |
+
size_k=input_size_per_partition,
|
382 |
+
is_k_full=True,
|
383 |
+
has_zp=True,
|
384 |
+
use_fp32_reduce=use_fp32_reduce,
|
385 |
+
is_zp_float=False,
|
386 |
+
)
|
387 |
+
|
388 |
+
if bias is not None:
|
389 |
+
output.add_(bias) # In-place add
|
390 |
+
|
391 |
+
return output.reshape(out_shape)
|
build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_fp8.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
import quantization as ops
|
6 |
+
|
7 |
+
from .marlin_utils import marlin_make_workspace, marlin_permute_scales
|
8 |
+
|
9 |
+
|
10 |
+
def is_fp8_marlin_supported():
    """True when the current CUDA device is SM80 (Ampere) or newer."""
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor >= 80
|
14 |
+
|
15 |
+
|
16 |
+
def apply_fp8_marlin_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    workspace: torch.Tensor,
    size_n: int,
    size_k: int,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    """Run an FP8 weight-only linear layer through the Marlin GEMM kernel."""
    # For GPUs that lack FP8 hardware support, we can leverage the
    # Marlin kernel for fast weight-only FP8 quantization

    # Collapse all leading dims into a single "M" dim for the GEMM.
    reshaped_x = input.reshape(-1, input.shape[-1])
    out_shape = input.shape[:-1] + (size_n,)

    output = ops.fp8_marlin_gemm(
        a=reshaped_x,
        b_q_weight=weight,
        b_scales=weight_scale,
        workspace=workspace,
        num_bits=8,
        size_m=reshaped_x.shape[0],
        size_n=size_n,
        size_k=size_k,
    )

    if bias is not None:
        output.add_(bias)  # In-place add

    return output.reshape(out_shape)
|
46 |
+
|
47 |
+
|
48 |
+
def prepare_fp8_layer_for_marlin(
    layer: torch.nn.Module, strategy: str = "tensor"
) -> None:
    """Repack an FP8 linear layer in place for the Marlin kernel.

    Replaces ``layer.weight`` with a Marlin-repacked int32 weight,
    ``layer.weight_scale`` with permuted scales, and attaches a fresh
    ``layer.workspace``. Expects ``layer`` to provide
    ``output_size_per_partition``, ``input_size_per_partition``, ``weight``,
    ``weight_scale`` and ``orig_dtype``.

    NOTE(review): the `strategy` parameter is never read in this body —
    confirm whether non-"tensor" strategies need handling here.
    """
    part_size_n = layer.output_size_per_partition
    part_size_k = layer.input_size_per_partition

    device = layer.weight.device

    # WORKSPACE
    layer.workspace = marlin_make_workspace(part_size_n, device)

    # WEIGHT
    # Repack weights to marlin format
    marlin_qweight = ops.gptq_marlin_repack(
        b_q_weight=pack_fp8_to_int32(layer.weight),
        # Empty perm: no act-order for FP8 weights.
        perm=torch.empty(0, dtype=torch.int, device=device),
        size_k=part_size_k,
        size_n=part_size_n,
        num_bits=8,
    )
    layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)

    # WEIGHT SCALES
    scales = layer.weight_scale.to(layer.orig_dtype)
    # Permute scales (group_size=-1: channelwise layout)
    marlin_scales = marlin_permute_scales(
        s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1
    )
    layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
|
77 |
+
|
78 |
+
|
79 |
+
def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
|
80 |
+
"""
|
81 |
+
Repack FP8 weights to gptq format (packed int32 elements)
|
82 |
+
"""
|
83 |
+
assert fp8_tensor.dtype == torch.float8_e4m3fn
|
84 |
+
assert fp8_tensor.shape[0] % 4 == 0
|
85 |
+
|
86 |
+
# Reshape to prepare for packing
|
87 |
+
reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
|
88 |
+
|
89 |
+
# Convert fp8 to uint8 (byte) representation
|
90 |
+
byte_tensor = reshaped.view(torch.uint8)
|
91 |
+
|
92 |
+
# Pack 4 uint8 values into one int32
|
93 |
+
packed = (
|
94 |
+
byte_tensor[:, 0].to(torch.int32)
|
95 |
+
| (byte_tensor[:, 1].to(torch.int32) << 8)
|
96 |
+
| (byte_tensor[:, 2].to(torch.int32) << 16)
|
97 |
+
| (byte_tensor[:, 3].to(torch.int32) << 24)
|
98 |
+
)
|
99 |
+
|
100 |
+
return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous()
|
build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Utility functions used for tests and benchmarks"""
|
2 |
+
|
3 |
+
from typing import List, Optional
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from quantization.scalar_type import ScalarType
|
9 |
+
|
10 |
+
from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, marlin_zero_points
|
11 |
+
from .quant_utils import (
|
12 |
+
get_pack_factor,
|
13 |
+
gptq_quantize_weights,
|
14 |
+
quantize_weights,
|
15 |
+
sort_weights,
|
16 |
+
)
|
17 |
+
|
18 |
+
|
19 |
+
class MarlinWorkspace:
    """Owns the CUDA scratch buffer Marlin kernels use for reductions."""

    def __init__(self, out_features, min_thread_n, max_parallel):
        # out_features must tile evenly into thread-block columns.
        assert (
            out_features % min_thread_n == 0
        ), "out_features = {} is undivisible by min_thread_n = {}".format(
            out_features, min_thread_n
        )

        # One int32 slot per column tile, times the max parallelism factor.
        max_workspace_size = (out_features // min_thread_n) * max_parallel

        self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda")
|
31 |
+
|
32 |
+
|
33 |
+
def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
    """Rearrange quantized weights into Marlin's tiled layout and apply
    the per-tile column permutation ``perm`` (shape is preserved)."""
    assert q_w.shape == (size_k, size_n)
    assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
    # Fix: the message previously printed size_n but labeled it "size_k".
    assert size_n % tile == 0, f"size_n = {size_n}, tile = {tile}"

    # Permute weights to 16x64 marlin tiles
    q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
    q_w = q_w.permute((0, 2, 1, 3))
    q_w = q_w.reshape((size_k // tile, size_n * tile))

    # Apply the column permutation tile-row by tile-row.
    q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)

    return q_w
|
46 |
+
|
47 |
+
|
48 |
+
def marlin_weights(q_w, size_k, size_n, num_bits, perm):
|
49 |
+
# Permute
|
50 |
+
q_w = marlin_permute_weights(q_w, size_k, size_n, perm)
|
51 |
+
|
52 |
+
# Pack
|
53 |
+
pack_factor = get_pack_factor(num_bits)
|
54 |
+
orig_device = q_w.device
|
55 |
+
|
56 |
+
q_w = q_w.cpu().numpy().astype(np.uint32)
|
57 |
+
|
58 |
+
q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32)
|
59 |
+
for i in range(pack_factor):
|
60 |
+
q_packed |= q_w[:, i::pack_factor] << num_bits * i
|
61 |
+
|
62 |
+
q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device)
|
63 |
+
|
64 |
+
return q_packed
|
65 |
+
|
66 |
+
|
67 |
+
def get_weight_perm(num_bits: int):
    """Build the marlin weight-shuffle permutation for 4- or 8-bit weights.

    Returns a 1-D torch tensor of 1024 indices (a permutation of 0..1023)
    describing where each element of a tiled weight row should be placed.
    Raises ``Exception`` for any other bit width.
    """
    # Enumerate the tile positions covered by each of the 32 threads.
    base: List[int] = []
    for tid in range(32):
        col = tid // 4
        rows = [
            2 * (tid % 4),
            2 * (tid % 4) + 1,
            2 * (tid % 4 + 4),
            2 * (tid % 4 + 4) + 1,
        ]
        thread_positions = [
            16 * row + col + 8 * block for block in (0, 1) for row in rows
        ]
        # Replicate across the four 256-element sub-tiles.
        for rep in range(4):
            base.extend(pos + 256 * rep for pos in thread_positions)

    perm_arr = np.array(base)

    if num_bits == 4:
        interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = np.array([0, 2, 1, 3])
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    # Interleave groups of pack-adjacent values for efficient dequant.
    perm_arr = perm_arr.reshape((-1, len(interleave)))[:, interleave].ravel()
    return torch.from_numpy(perm_arr)
|
95 |
+
|
96 |
+
|
97 |
+
def marlin_quantize(
    w: torch.Tensor,
    quant_type: ScalarType,
    group_size: int,
    act_order: bool,
    test_perm: Optional[torch.Tensor] = None,
):
    """GPTQ-quantize ``w`` and reformat the result into marlin layout.

    Args:
        w: dense weight matrix of shape (size_k, size_n).
        quant_type: target quantized scalar type (its size_bits is used).
        group_size: quantization group size; -1 means one group of size_k.
        act_order: whether to apply activation-order (column) reordering.
        test_perm: optional fixed permutation for deterministic testing,
            forwarded to gptq_quantize_weights.

    Returns:
        [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm],
        all moved to ``w``'s device.  ``sort_indices`` is empty unless
        ``act_order`` is set.
    """
    size_k, size_n = w.shape
    num_bits = quant_type.size_bits

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Quantize (and apply act_order if provided)
    w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
        w, quant_type, group_size, act_order, test_perm
    )

    # For act_order, sort the "weights" and "g_idx" so that group ids are
    # increasing
    sort_indices = torch.empty(0, dtype=torch.int, device=w.device)
    if act_order:
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)

    # Reformat to marlin
    weight_perm = get_weight_perm(num_bits)
    marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)

    # Create result
    res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
    for i in range(len(res_list)):
        res_list[i] = res_list[i].to(w.device)

    return res_list
|
134 |
+
|
135 |
+
|
136 |
+
def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int):
    """AWQ-style (zero-point) quantize ``w`` and reformat to marlin layout.

    Args:
        w: dense weight matrix of shape (size_k, size_n).
        quant_type: target quantized scalar type.
        group_size: quantization group size; -1 means a single group of
            size_k.  size_k must be divisible by the (normalized) group size.

    Returns:
        [w_ref, marlin_q_w, marlin_s, marlin_zp], all on ``w``'s device.
    """
    size_k, size_n = w.shape

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Detect num groups
    assert size_k % group_size == 0
    num_groups = size_k // group_size

    # Quantize with zp
    w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True)

    # Reformat to marlin
    weight_perm = get_weight_perm(quant_type.size_bits)
    marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm)
    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
    marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits)

    # Create result
    res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp]
    for i in range(len(res_list)):
        res_list[i] = res_list[i].to(w.device)

    return res_list
|
build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test_24.py
ADDED
@@ -0,0 +1,473 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Utility functions used for tests and benchmarks"""
|
2 |
+
|
3 |
+
import random
|
4 |
+
from typing import List
|
5 |
+
|
6 |
+
import numpy
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from quantization.scalar_type import ScalarType
|
10 |
+
|
11 |
+
from .marlin_utils_test import marlin_weights
|
12 |
+
from .quant_utils import gptq_quantize_weights
|
13 |
+
|
14 |
+
|
15 |
+
# This is PyTorch implementation of main part of reorder_meta()
|
16 |
+
# function, from tools/util/include/cutlass/util/host_reorder.h file
|
17 |
+
# of CUTLASS source tree. Furthermore, CUTLASS template for sparse
|
18 |
+
# GEMM decides upon layout of this matrix, and at the moment for the
|
19 |
+
# sparse GEMM executed on tensor cores, this is layout described by
|
20 |
+
# ColumnMajorInterleaved<2> data structure, in
|
21 |
+
# include/cutlass/layout/matrix.h of CUTLASS source tree. The
|
22 |
+
# reordering of meta matrix into meta_reordered matrix calculated
|
23 |
+
# according to these segments of CUTLASS code is re-implemented here.
|
24 |
+
# Note that this calculation produces offsets for scattering metadata
|
25 |
+
# matrix elements into reordered metadata matrix elements (or,
|
26 |
+
# equivalently, for gathering reordered metadata matrix element back
|
27 |
+
# into metadata matrix elements).
|
28 |
+
def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device):
|
29 |
+
dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)
|
30 |
+
dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)
|
31 |
+
|
32 |
+
# Reorder the rows, then swizzle the 2x2 blocks.
|
33 |
+
group_x = 64
|
34 |
+
group_y = 32 if meta_dtype.itemsize == 2 else 16
|
35 |
+
|
36 |
+
dst_rows = (
|
37 |
+
dst_rows // group_x * group_x
|
38 |
+
+ (dst_rows % 2) * 2
|
39 |
+
+ (dst_rows % 8) // 4
|
40 |
+
+ ((dst_rows % group_y) % 4) // 2 * 32
|
41 |
+
+ ((dst_rows % group_x) // 8) * 4
|
42 |
+
)
|
43 |
+
|
44 |
+
topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8)
|
45 |
+
bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8)
|
46 |
+
dst_rows += topright - bottomleft
|
47 |
+
dst_cols -= topright - bottomleft
|
48 |
+
|
49 |
+
# Assumed that meta tensor is to be stored in CUTLASS
|
50 |
+
# InterleavedColumnMajor layout, and reverse engineered
|
51 |
+
# corresponding code to store values into this tensor.
|
52 |
+
interleave = 2
|
53 |
+
cols_maj = dst_cols // interleave
|
54 |
+
cols_min = dst_cols % interleave
|
55 |
+
return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1)
|
56 |
+
|
57 |
+
|
58 |
+
# This function converts dense matrix into sparse semi-structured
|
59 |
+
# representation, producing "compressed" matrix, in the layout used by
|
60 |
+
# CUTLASS backend, and corresponding metadata matrix.
|
61 |
+
def sparse_semi_structured_from_dense_cutlass(dense):
    """Convert a dense 2-D matrix into CUTLASS 2:4 semi-structured form.

    Returns a tuple ``(sparse, meta_reordered)`` where ``sparse`` has shape
    (m, k // 2) holding the two kept values of every group, and
    ``meta_reordered`` has shape (m, meta_ncols) holding the 2-bit position
    indices, already reordered into the CUTLASS metadata layout.

    Raises RuntimeError for non-2-D input, unsupported dtypes, or shapes
    not divisible by the layout's row/column granularity.
    """
    if dense.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor"  # noqa: E501
        )

    m, k = dense.shape
    device = dense.device

    # Metadata element width depends on the dense dtype.
    meta_dtype = torch.int8
    if dense.dtype == torch.int8:
        meta_dtype = torch.int32
    elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]:
        meta_dtype = torch.int16
    else:
        raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
    if quadbits_per_meta_elem not in (4, 8):
        raise RuntimeError("Invalid number of elements per meta element calculated")

    if meta_dtype == torch.int32:
        if m % 16 != 0:
            raise RuntimeError(
                f"Number of rows of dense matrix {m} must be divisible by 16"
            )
    else:
        if m % 32 != 0:
            raise RuntimeError(
                f"Number of rows of dense matrix {m} must be divisible by 32"
            )
    if k % (4 * quadbits_per_meta_elem) != 0:
        raise RuntimeError(
            f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}"  # noqa: E501
        )

    # torch.float is treated as pairs (ksparse=2); everything else uses
    # quadruples (ksparse=4) with a nonzero mask per element.
    if dense.dtype != torch.float:
        ksparse = 4
        dense_4 = dense.view(-1, k // ksparse, ksparse)
        m0, m1, m2, m3 = (dense_4 != 0).unbind(-1)
    else:
        ksparse = 2
        dense_2 = dense.view(-1, k // ksparse, ksparse)
        m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1)
    meta_ncols = k // (ksparse * quadbits_per_meta_elem)

    # Encoding quadruples of True/False values as follows:
    #     [True,  True,  False, False] -> 0b0100
    #     [True,  False, True,  False] -> 0b1000
    #     [False, True,  True,  False] -> 0b1001
    #     [True,  False, False, True ] -> 0b1100
    #     [False, True,  False, True ] -> 0b1101
    #     [False, False, True,  True ] -> 0b1110
    # Thus, lower two bits in the encoding are index of the True value
    # at the lowest index in the quadruple, and the higher two bits in
    # the encoding are index of the other True value in the quadruple.
    # In case there are less than two True values, than False value or
    # values at some index or indices are considered True for the
    # encoding. In case there are more than two True values, then the
    # excess True value(s) at some indices are considered False for
    # the encoding. The exact encodings used for these cases are as
    # follows:
    #     [False, False, False, False] -> 0b1110
    #     [False, False, False, True ] -> 0b1110
    #     [False, False, True,  False] -> 0b1110
    #     [False, True,  False, False] -> 0b1001
    #     [False, True,  True,  True ] -> 0b1101
    #     [True,  False, False, False] -> 0b1000
    #     [True,  False, True,  True ] -> 0b1100
    #     [True,  True,  False, True ] -> 0b0100
    #     [True,  True,  True,  False] -> 0b0100
    #     [True,  True,  True,  True ] -> 0b0100
    # These particular encodings are chosen, with the help of Espresso
    # logic minimizer software, for the purpose of minimization of
    # corresponding Boolean functions, that translate non-zero flags
    # into encoding bits. Note also possible choices for the first
    # and last of these encodings were limited only to (0b0100,
    # 0b1110), in order to produce valid encodings for 1:2 sparsity
    # case.

    expr0 = m0 & m1
    expr1 = ~m0 & m1
    expr2 = ~m0 & ~m1
    bit0 = expr1
    bit1 = expr2
    bit2 = expr0 | expr2 | m3
    bit3 = expr1 | ~m1
    # idxs0/idxs1: per-group positions of the two kept elements.
    idxs0 = bit0 | (bit1.to(torch.int64) << 1)
    idxs1 = bit2 | (bit3.to(torch.int64) << 1)

    # Gather the two kept values of each group into the compressed matrix.
    if dense.dtype != torch.float:
        sparse0 = dense_4.gather(
            -1, idxs0.unsqueeze(-1)
        )  # type: ignore[possibly-undefined]
        sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
        sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
    else:
        sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(
            m, k // 2
        )  # type: ignore[possibly-undefined]

    # Pack the two 2-bit indices per group into 4-bit nibbles, then
    # nibbles into full metadata elements.
    meta_4 = idxs0 | (idxs1 << 2)
    meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)

    if quadbits_per_meta_elem == 4:
        meta = (
            meta_n[:, :, 0]
            | (meta_n[:, :, 1] << 4)
            | (meta_n[:, :, 2] << 8)
            | (meta_n[:, :, 3] << 12)
        )
    elif quadbits_per_meta_elem == 8:
        meta = (
            meta_n[:, :, 0]
            | (meta_n[:, :, 1] << 4)
            | (meta_n[:, :, 2] << 8)
            | (meta_n[:, :, 3] << 12)
            | (meta_n[:, :, 4] << 16)
            | (meta_n[:, :, 5] << 20)
            | (meta_n[:, :, 6] << 24)
            | (meta_n[:, :, 7] << 28)
        )

    # Reorder meta tensor elements.
    meta_reordered = meta.new_empty(
        (m * meta_ncols,)
    )  # type: ignore[possibly-undefined]
    meta_offsets = _calculate_meta_reordering_scatter_offsets(
        m, meta_ncols, meta_dtype, device
    )
    meta_reordered.scatter_(0, meta_offsets, meta.view(-1))

    return (sparse, meta_reordered.view(m, meta_ncols))
|
193 |
+
|
194 |
+
|
195 |
+
# This function performs reverse of the function above - it
|
196 |
+
# reconstructs dense matrix from a pair of "compressed" matrix, given
|
197 |
+
# in the layout used by CUTLASS backend, and accompanying metadata
|
198 |
+
# matrix.
|
199 |
+
def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
    """Reconstruct a dense matrix from CUTLASS 2:4 compressed form.

    Inverse of ``sparse_semi_structured_from_dense_cutlass``: ``sparse`` is
    the (m, k) compressed value matrix and ``meta_reordered`` the
    CUTLASS-layout metadata; returns a dense (m, 2*k) matrix with the
    pruned positions restored as zeros.

    Raises RuntimeError on dimension, device, dtype, or shape mismatches
    between the two inputs.
    """
    if sparse.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor"  # noqa: E501
        )

    m, k = sparse.shape
    device = sparse.device

    if meta_reordered.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor"  # noqa: E501
        )
    if meta_reordered.device != device:
        raise RuntimeError(
            f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device"  # noqa: E501
        )

    meta_dtype = meta_reordered.dtype
    if meta_dtype not in (torch.int16, torch.int32):
        raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix")
    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4

    # torch.float values are handled as pairs; see comment further below.
    ksparse = 4 if sparse.dtype != torch.float else 2

    meta_nrows, meta_ncols = meta_reordered.shape
    if meta_nrows != m:
        raise RuntimeError(
            f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}"  # noqa: E501
        )
    if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k:
        raise RuntimeError(
            f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, "  # noqa: E501
            "expected according to the number of columns of meta matrix"
        )

    # Undo meta tensor elements reordering.
    meta_offsets = _calculate_meta_reordering_scatter_offsets(
        m, meta_ncols, meta_dtype, device
    )
    meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols)

    # Unpack sparse tensor back to original dense tensor, using
    # information provided by meta tensor. Note that torch.float
    # datatype is handled pretty much the same as
    # torch.half/torch.bfloat16, as metadata for a pair of torch.float
    # value is encoded as if underlying 8 bytes contain four
    # torch.half/torch.bfloat16 values, where either first two or last
    # two are zeros.
    meta_2 = torch.empty(
        (m, meta_ncols, 2 * quadbits_per_meta_elem),
        dtype=meta_dtype,
        device=device,
    )
    # Extract every 2-bit position index from the packed metadata words.
    if quadbits_per_meta_elem == 4:
        meta_2[:, :, 0] = meta & 0b11
        meta_2[:, :, 1] = (meta >> 2) & 0b11
        meta_2[:, :, 2] = (meta >> 4) & 0b11
        meta_2[:, :, 3] = (meta >> 6) & 0b11
        meta_2[:, :, 4] = (meta >> 8) & 0b11
        meta_2[:, :, 5] = (meta >> 10) & 0b11
        meta_2[:, :, 6] = (meta >> 12) & 0b11
        meta_2[:, :, 7] = (meta >> 14) & 0b11
    elif quadbits_per_meta_elem == 8:
        meta_2[:, :, 0] = meta & 0b11
        meta_2[:, :, 1] = (meta >> 2) & 0b11
        meta_2[:, :, 2] = (meta >> 4) & 0b11
        meta_2[:, :, 3] = (meta >> 6) & 0b11
        meta_2[:, :, 4] = (meta >> 8) & 0b11
        meta_2[:, :, 5] = (meta >> 10) & 0b11
        meta_2[:, :, 6] = (meta >> 12) & 0b11
        meta_2[:, :, 7] = (meta >> 14) & 0b11
        meta_2[:, :, 8] = (meta >> 16) & 0b11
        meta_2[:, :, 9] = (meta >> 18) & 0b11
        meta_2[:, :, 10] = (meta >> 20) & 0b11
        meta_2[:, :, 11] = (meta >> 22) & 0b11
        meta_2[:, :, 12] = (meta >> 24) & 0b11
        meta_2[:, :, 13] = (meta >> 26) & 0b11
        meta_2[:, :, 14] = (meta >> 28) & 0b11
        meta_2[:, :, 15] = (meta >> 30) & 0b11

    # Flat destination offsets: each group of 4 dense slots receives two
    # values at the positions named by the metadata.
    dense_offsets = meta_2.view(-1) + (
        torch.arange(0, 2 * m * k // ksparse, device=device) * 4
    ).view(-1, 1).repeat(1, 2).view(-1)

    dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device)
    if sparse.dtype != torch.float:
        # dense.scatter_(0, dense_offsets, sparse.view(-1))
        dense.scatter_(0, dense_offsets, sparse.reshape(-1))
    else:
        # Scatter through a half-typed view so the float pair occupies the
        # same byte layout the metadata was encoded for.
        dense.view(torch.half).scatter_(
            0, dense_offsets, sparse.view(torch.half).view(-1)
        )

    return dense.view(m, 2 * k)
|
294 |
+
|
295 |
+
|
296 |
+
def mask_creator(tensor):
    """Create an N:M (here 2:4) structured-sparsity mask for ``tensor``.

    Every consecutive group of M = 4 weights (taken over the flattened
    tensor, grouped along the last reshape axis) keeps its N = 2
    largest-magnitude entries and zeroes the rest.

    Returns a float mask of ones/zeros with ``tensor``'s shape.
    Raises ValueError if the element count is not divisible by M.
    """
    N = 2
    M = 4

    if tensor.numel() % M != 0:
        raise ValueError(
            f"Tensor of size {tensor.shape} can't be evenly divided into {M} groups"
        )

    # Rank magnitudes within each group of M; the M - N smallest are pruned.
    grouped = tensor.detach().abs().reshape(tensor.numel() // M, M)
    prune_idx = torch.argsort(grouped, dim=1)[:, : M - N]

    # Mask is float32 regardless of the input dtype (matches torch.ones).
    keep = torch.ones(grouped.shape, device=grouped.device)
    keep.scatter_(dim=1, index=prune_idx, value=0)
    return keep.reshape(tensor.shape)
|
326 |
+
|
327 |
+
|
328 |
+
def inject_24(w, size_k, size_n):
    """Prune ``w`` to a 2:4 sparsity pattern along the size_k dimension.

    Returns ``(pruned_w, mask)`` where ``mask`` is boolean with ``w``'s
    shape and ``pruned_w = mask * w``, both contiguous.
    """
    assert w.shape == (size_k, size_n)

    # mask_creator groups along the last axis, so transpose to build the 2:4
    # pattern across size_k, then transpose back.  NOTE(review): the mask is
    # moved to CUDA unconditionally, so this helper requires a GPU.
    mask = mask_creator(w.t()).t().cuda().bool()

    return (mask * w).contiguous(), mask.contiguous()
|
334 |
+
|
335 |
+
|
336 |
+
def check_24(w, num_rows_to_sample=50, _verbose=False):
    """Spot-check that ``w`` follows a 2:4 sparsity pattern.

    Transposes ``w``, samples ``num_rows_to_sample`` rows (with
    replacement), and counts 4-element segments that contain more than two
    nonzeros, printing a summary.  Diagnostic only: returns None.
    """
    BLOCK_SIZE = 4
    MAX_NON_ZEROS = 2

    w = w.t().contiguous()

    print("check_24: w.shape = {}".format(w.shape))

    num_rows, num_cols = w.shape
    sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample)
    if _verbose:
        print(f"Sampled row idxs = {sampled_row_idxs}")

    total_segments = 0
    non_24_segments = 0
    for i in sampled_row_idxs:
        # Fixed off-by-one: the stop value was num_cols - BLOCK_SIZE, which
        # always skipped the final 4-column segment of every sampled row.
        for j in range(0, num_cols, BLOCK_SIZE):
            total_segments += 1
            block = w[i, j : j + BLOCK_SIZE]
            num_nonzero = torch.count_nonzero(block)
            if num_nonzero > MAX_NON_ZEROS:
                print("i = {} j = {} block = {}".format(i, j, block))
                non_24_segments += 1

    print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.")
|
361 |
+
|
362 |
+
|
363 |
+
def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType):
    """Compress a quantized 2:4-sparse weight matrix into CUTLASS form.

    Args:
        q_24: quantized (biased) weights of shape (size_k, size_n) that
            already follow a 2:4 pattern along size_k.
        wtype: quantized scalar type; its ``bias`` is subtracted so the
            compression sees zero-centered values, then restored.

    Returns:
        (q_24_comp, meta): the compressed weights (size_k // 2, size_n)
        and the CUTLASS metadata tensor.
    """
    assert q_24.shape == (size_k, size_n)

    # Remove bias to normalize over 0
    q_24_no_zp = q_24 - wtype.bias

    # Compress
    q_24_no_zp = q_24_no_zp.t().contiguous()
    q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(q_24_no_zp)
    q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous()

    # Restore bias
    q_24_comp = q_24_no_zp_comp + wtype.bias

    # Resize meta to its actual shape (without moving any data)
    meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)

    return q_24_comp, meta
|
381 |
+
|
382 |
+
|
383 |
+
def get_scale_perms_24():
    """Return the marlin-2:4 scale permutations.

    ``scale_perm`` (grouped quantization) interleaves each run of 8 scales;
    ``scale_perm_single`` (per-channel) is the identity over 64 entries.
    """
    interleaved = [0, 4, 1, 5, 2, 6, 3, 7]
    scale_perm: List[int] = [
        8 * grp + off for grp in range(8) for off in interleaved
    ]
    scale_perm_single: List[int] = [
        8 * grp + off for grp in range(8) for off in range(8)
    ]
    return scale_perm, scale_perm_single
|
391 |
+
|
392 |
+
|
393 |
+
def get_weight_perm_24(num_bits: int):
    """Build the marlin-2:4 weight-shuffle permutation for 4/8-bit weights.

    Returns a 1-D torch tensor of 1024 indices (a permutation of 0..1023).
    Raises ValueError for any other bit width.
    """
    # Enumerate the tile positions covered by each of the 32 threads.
    base: List[int] = []
    for tid in range(32):
        col = tid // 4
        col_pair = col // 2
        rows = [
            2 * (tid % 4),
            2 * (tid % 4) + 1,
            2 * (tid % 4 + 4),
            2 * (tid % 4 + 4) + 1,
        ]
        thread_positions = [
            16 * row + col_pair * 256 + 8 * (col % 2) + 4 * block
            for block in (0, 1)
            for row in rows
        ]
        # The four replicas are offset by 1 (adjacent elements), unlike the
        # dense marlin perm which offsets by 256.
        for rep in range(4):
            base.extend(pos + rep for pos in thread_positions)

    perm_arr = numpy.array(base)

    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))

    perm_arr = perm_arr.reshape((-1, len(interleave)))[:, interleave].ravel()
    return torch.from_numpy(perm_arr)
|
421 |
+
|
422 |
+
|
423 |
+
def marlin_permute_scales_24(
    s: torch.Tensor, size_k: int, size_n: int, group_size: int
) -> torch.Tensor:
    """Permute quantization scales into the marlin-2:4 layout.

    Grouped quantization (group_size < size_k) uses the interleaved
    permutation; per-channel (one group spanning size_k, or -1) uses the
    "single" permutation.  Returns a contiguous (-1, size_n) tensor.
    """
    grouped_perm, single_perm = get_scale_perms_24()
    if group_size < size_k and group_size != -1:
        active = grouped_perm
    else:
        active = single_perm
    shuffled = s.reshape((-1, len(active)))[:, active]
    return shuffled.reshape((-1, size_n)).contiguous()
|
435 |
+
|
436 |
+
|
437 |
+
def marlin_24_quantize(
    w: torch.Tensor,
    quant_type: ScalarType,
    group_size: int,
):
    """Prune ``w`` to 2:4 sparsity, GPTQ-quantize it, and reformat to the
    marlin-2:4 layout.

    Args:
        w: dense weight matrix of shape (size_k, size_n).
        quant_type: target quantized scalar type.
        group_size: quantization group size; -1 means one group of size_k.

    Returns:
        [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s] on ``w``'s device.
    """
    size_k, size_n = w.shape

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Inject 2:4 sparsity
    w_24, mask_24 = inject_24(w, size_k, size_n)

    # Quantize
    w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights(
        w_24, quant_type, group_size, act_order=False
    )

    # Compress quantized weight
    q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, quant_type)
    size_k_comp = size_k // 2

    # Reformat to marlin
    weight_perm = get_weight_perm_24(quant_type.size_bits)
    marlin_24_q_w_comp = marlin_weights(
        q_w_24_comp, size_k_comp, size_n, quant_type.size_bits, weight_perm
    )
    marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size)

    # Create result
    res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s]
    for i in range(len(res_list)):
        res_list[i] = res_list[i].to(w.device)

    return res_list
|
build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
|
3 |
+
import numpy
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from .marlin_utils_test import marlin_permute_weights
|
7 |
+
from .quant_utils import get_pack_factor, qqq_quantize_weights
|
8 |
+
|
9 |
+
|
10 |
+
def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size):
    """Shuffle ``q_w`` into marlin tile order and pack into int32 for QQQ.

    Each packed int32 holds ``get_pack_factor(num_bits)`` values, value i in
    bits [num_bits*i, num_bits*(i+1)).  Result is returned on ``q_w``'s
    original device.
    """
    # Permute
    q_w = marlin_permute_weights(q_w, size_k, size_n, perm)

    # Pack
    pack_factor = get_pack_factor(num_bits)
    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(numpy.uint32)

    q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor),
                           dtype=numpy.uint32)
    # Per-channel quantization (group_size == size_k) masks each value to
    # 4 bits before packing; the per-group path packs values as-is.
    # NOTE(review): the 0xF mask matches num_bits == 4 only — presumably QQQ
    # per-channel packing is 4-bit-only; verify against callers.
    if group_size == size_k:
        for i in range(pack_factor):
            q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i
    else:
        for i in range(pack_factor):
            q_packed |= q_w[:, i::pack_factor] << num_bits * i

    q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device)

    return q_packed
|
32 |
+
|
33 |
+
|
34 |
+
def get_qqq_scale_perms():
    """Return the QQQ scale permutations.

    ``scale_perm`` (per-group scales) is the 8x8 transpose order over 64
    entries; ``scale_perm_single`` (per-channel scales) is a 32-entry
    interleaving pattern.
    """
    scale_perm: List[int] = [
        col + 8 * row for col in range(8) for row in range(8)
    ]
    offsets = [0, 1, 8, 9, 16, 17, 24, 25]
    scale_perm_single: List[int] = [
        2 * pair + off for pair in range(4) for off in offsets
    ]
    return scale_perm, scale_perm_single
|
43 |
+
|
44 |
+
|
45 |
+
# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. # noqa: E501
|
46 |
+
# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. # noqa: E501
def get_qqq_weight_perm(num_bits: int, quant_type: str):
    """Build the QQQ weight-shuffle permutation (4-bit only).

    ``quant_type`` selects the interleave pattern: "per-channel" or
    "per-group".  Returns a 1-D torch tensor of 1024 indices.  Raises
    AssertionError for an unknown quant_type and Exception for num_bits != 4.
    """
    # Enumerate the tile positions covered by each of the 32 threads.
    base: List[int] = []
    for tid in range(32):
        col = tid // 4
        rows = [4 * (tid % 4) + r for r in range(4)]
        thread_positions = [
            16 * row + col + 8 * block for block in (0, 1) for row in rows
        ]
        for rep in range(4):
            base.extend(pos + 256 * rep for pos in thread_positions)

    perm_arr = numpy.array(base)

    assert quant_type in ["per-channel",
                          "per-group"], "not supported quantization type"
    if num_bits == 4:
        if quant_type == "per-channel":
            interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3])
        else:
            interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    else:
        raise Exception("num_bits must be 4, got {}".format(num_bits))

    perm_arr = perm_arr.reshape((-1, len(interleave)))[:, interleave].ravel()
    return torch.from_numpy(perm_arr)
|
77 |
+
|
78 |
+
|
79 |
+
def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size):
    """Apply the QQQ marlin permutations to group and channel scales.

    Returns ``(s_group, s_channel)``.  In per-channel mode
    (group_size == size_k or -1) ``s_group`` is returned untouched.
    """
    group_perm, channel_perm = get_qqq_scale_perms()
    per_group = group_size < size_k and group_size != -1

    # The per-channel scales are interleaved in both modes.
    s_channel = s_channel.reshape(
        (-1, len(channel_perm)))[:, channel_perm]
    if per_group:
        s_group = s_group.reshape((-1, len(group_perm)))[:, group_perm]
        s_group = s_group.reshape((-1, size_n)).contiguous()
        # NOTE(review): in this branch s_channel is left in its permuted
        # (-1, 32) layout — not reshaped to (-1, size_n) or made contiguous —
        # mirroring the original behavior; verify this is intentional.
    else:
        s_channel = s_channel.reshape((-1, size_n)).contiguous()

    return s_group, s_channel
|
92 |
+
|
93 |
+
|
94 |
+
def marlin_qqq_quantize(
    w: torch.Tensor,
    num_bits: int,
    group_size: int,
):
    """QQQ-quantize ``w`` and reformat the result into marlin_qqq layout.

    Args:
        w: dense weight matrix of shape (size_k, size_n).
        num_bits: quantization bit width (QQQ supports 4).
        group_size: group size; -1 or size_k selects per-channel mode,
            anything smaller selects per-group mode.

    Returns:
        [w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel]
        on ``w``'s device.
    """
    size_k, size_n = w.shape

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k
    quant_type = "per-channel" if group_size == size_k else "per-group"

    # Quantize
    w_ref, q_w, s_group, s_channel = qqq_quantize_weights(
        w, num_bits, group_size)

    # Reformat to marlin_qqq
    weight_perm = get_qqq_weight_perm(num_bits, quant_type)
    marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits,
                                        weight_perm, group_size)
    marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales(
        s_group, s_channel, size_k, size_n, group_size)

    # Create result
    res_list = [
        w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel
    ]
    for i in range(len(res_list)):
        res_list[i] = res_list[i].to(w.device)

    return res_list
|
build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/quant_utils.py
ADDED
@@ -0,0 +1,470 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""This file is used for /tests and /benchmarks"""
|
2 |
+
|
3 |
+
from typing import List, Optional
|
4 |
+
|
5 |
+
import numpy
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from quantization.scalar_type import ScalarType, scalar_types
|
9 |
+
|
10 |
+
# Quant types accepted by the GPTQ-style kernels in this package.
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
# Valid quantization group sizes; -1 means one group per column (channelwise).
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]

# Marlin QQQ kernels only support 4-bit weights.
MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]

# Note: this is a hack. We should update each model to register the
# stacked params and get it from there instead in a future PR.
# fused_name: List[shard_name]
FUSED_LAYER_NAME_MAPPING = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}
|
22 |
+
|
23 |
+
|
24 |
+
def pack_quantized_values_into_int32(
    w_q: torch.Tensor, wtype: "ScalarType", packed_dim: int = 0
):
    """Pack quantized values along ``packed_dim`` into int32 words.

    Every ``32 // wtype.size_bits`` consecutive values along ``packed_dim``
    are OR-ed into one int32; earlier values occupy the lower bits. The
    length of ``packed_dim`` must be divisible by the pack factor.
    """
    ndim = len(w_q.shape)
    # Move the dimension being packed to the last position so we can
    # address it with `[..., slot::vals_per_word]`.
    dims_order = tuple(d for d in range(ndim) if d != packed_dim) + (packed_dim,)
    restore_order = tuple(dims_order.index(d) for d in range(ndim))
    moved = w_q.permute(dims_order)

    vals_per_word = 32 // wtype.size_bits
    bit_mask = (1 << wtype.size_bits) - 1

    assert moved.shape[-1] % vals_per_word == 0
    out_shape = list(moved.shape)
    out_shape[-1] //= vals_per_word

    packed = torch.zeros(out_shape, dtype=torch.int32, device=w_q.device)
    for slot in range(vals_per_word):
        packed |= (moved[..., slot::vals_per_word] & bit_mask) << (
            wtype.size_bits * slot
        )

    # Undo the dimension move so the packed dim is back in place.
    return packed.permute(restore_order)
|
44 |
+
|
45 |
+
|
46 |
+
def unpack_quantized_values_into_int32(
    w_q: torch.Tensor, wtype: "ScalarType", packed_dim: int = 0
):
    """Inverse of ``pack_quantized_values_into_int32``.

    Each int32 word along ``packed_dim`` is expanded back into
    ``32 // wtype.size_bits`` values (lowest bits first).
    """
    ndim = len(w_q.shape)
    # Move the packed dimension to the end for strided slicing.
    dims_order = tuple(d for d in range(ndim) if d != packed_dim) + (packed_dim,)
    restore_order = tuple(dims_order.index(d) for d in range(ndim))
    moved = w_q.permute(dims_order)

    vals_per_word = 32 // wtype.size_bits
    bit_mask = (1 << wtype.size_bits) - 1

    out_shape = list(moved.shape)
    out_shape[-1] *= vals_per_word

    unpacked = torch.zeros(out_shape, dtype=torch.int32, device=w_q.device)
    for slot in range(vals_per_word):
        unpacked[..., slot::vals_per_word] = (
            moved >> (wtype.size_bits * slot)
        ) & bit_mask

    return unpacked.permute(restore_order)
|
65 |
+
|
66 |
+
|
67 |
+
def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
    """Return True if the layer named by ``prefix`` is in ``ignored_layers``.

    For fused layers (e.g. ``qkv_proj``) the decision is derived from the
    constituent shard names; all shards of a fused layer must agree on
    whether they are skipped, otherwise a ValueError is raised.
    """
    # prefix: model.layers.0.self_attn.q_proj
    # proj_name: q_proj
    proj_name = prefix.split(".")[-1]

    if proj_name not in FUSED_LAYER_NAME_MAPPING:
        return prefix in ignored_layers

    shard_prefixes = [
        prefix.replace(proj_name, shard_name)
        for shard_name in FUSED_LAYER_NAME_MAPPING[proj_name]
    ]
    shard_flags = [shard_prefix in ignored_layers for shard_prefix in shard_prefixes]

    # Mixed precision across shards of one fused layer is unsupported.
    if any(shard_flags) and not all(shard_flags):
        raise ValueError(
            f"Detected some but not all shards of {prefix} "
            "are quantized. All shards of fused layers "
            "to have the same precision."
        )
    return shard_flags[0]
|
94 |
+
|
95 |
+
|
96 |
+
def get_pack_factor(num_bits):
    """Return how many ``num_bits``-wide values fit into one 32-bit word."""
    factor, remainder = divmod(32, num_bits)
    assert remainder == 0, f"Unsupported num_bits = {num_bits}"
    return factor
|
99 |
+
|
100 |
+
|
101 |
+
def permute_rows(
    q_w: torch.Tensor,
    w_ref: torch.Tensor,
    group_size: int,
    test_perm: Optional[torch.Tensor] = None,
):
    """Randomly permute the rows (K dim) of ``q_w``/``w_ref`` to simulate act_order.

    Args:
        q_w: quantized weight, shape (size_k, size_n).
        w_ref: dequantized reference weight, same shape as ``q_w``.
        group_size: number of consecutive rows sharing one quantization group.
        test_perm: optional fixed permutation used instead of a random one
            (for deterministic tests).

    Returns:
        Tuple ``(w_ref, q_w, g_idx, rand_perm)``, all moved back to q_w's
        original device, where ``g_idx[i]`` is the group index of permuted
        row ``i`` and ``rand_perm`` is the permutation applied.
    """
    assert q_w.shape == w_ref.shape

    orig_device = q_w.device
    k_size, _ = q_w.shape

    # Group index of each (unpermuted) row. Vectorized replacement for the
    # original per-element Python loop `g_idx[i] = i // group_size`.
    g_idx = torch.arange(k_size, dtype=torch.int32) // group_size

    # Simulate act_order by doing a random permutation on K
    rand_perm = test_perm if test_perm is not None else torch.randperm(k_size)

    g_idx = g_idx[rand_perm].contiguous()
    q_w = q_w[rand_perm, :].contiguous()
    w_ref = w_ref[rand_perm, :].contiguous()

    return (
        w_ref.to(device=orig_device),
        q_w.to(device=orig_device),
        g_idx.to(device=orig_device),
        rand_perm.to(device=orig_device),
    )
|
129 |
+
|
130 |
+
|
131 |
+
def quantize_weights(
    w: torch.Tensor,
    quant_type: ScalarType,
    group_size: Optional[int],
    zero_points: bool = False,
    ref_zero_points_after_scales: bool = False,
):
    """Quantize a floating-point weight matrix ``w`` of shape (size_k, size_n).

    Args:
        w: weight to quantize, shape (size_k, size_n), floating point.
        quant_type: integer target type (floating-point targets untested).
        group_size: rows per quantization group; -1 means one group per
            column (channelwise); None means no scaling (scale fixed at 1.0).
        zero_points: if True, compute per-group zero points (asymmetric
            quantization); requires ``group_size`` to be provided.
        ref_zero_points_after_scales: if True, build the dequantized
            reference with zero points applied after scaling (matches
            kernels such as Machete).

    Returns:
        Tuple ``(w_ref, w_q, w_s, maybe_w_zp)``: dequantized reference,
        quantized weights, scales (None when ``group_size`` is None) and
        zero points (None unless ``zero_points`` is True).
    """
    assert (
        quant_type.is_integer()
    ), "Floating point quantization may work but has not been tested"
    assert not zero_points or group_size is not None, (
        "to have group zero points, group_size must be provided "
        "(-1 group_size is channelwise)"
    )

    orig_device = w.device
    orig_type = w.dtype
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"

    # Channelwise: a single group spanning the full K dimension.
    if group_size == -1:
        group_size = size_k

    # Reshape to [groupsize, -1]
    if group_size is not None and group_size < size_k:
        w = w.reshape((-1, group_size, size_n))
        w = w.permute(1, 0, 2)
        w = w.reshape((group_size, -1))

    # Compute scale for each group
    max_val = torch.max(w, 0, keepdim=True).values
    min_val = torch.min(w, 0, keepdim=True).values

    max_q_val = quant_type.max()
    min_q_val = quant_type.min()

    w_s = torch.Tensor([1.0]).to(w.device)  # unscaled case
    maybe_w_zp = None
    if group_size is not None:
        if zero_points:
            # Asymmetric: scale covers the full [min, max] range and the
            # zero point shifts it into the unsigned quant range.
            assert not quant_type.is_signed() and quant_type.max() > 0
            w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
            maybe_w_zp = (
                torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
            )
        else:
            # If the bias is such that there are no possible negative/positive
            # values, set the max value to inf to avoid divide by 0
            w_s = torch.max(
                abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
                abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
            )

    # Quantize
    w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
    w_q = torch.clamp(w_q, min_q_val, max_q_val)

    # Compute ref (dequantized)
    # For some kernels (namely Machete) the zero-points are applied after the
    # scales are applied, for this case computing the reference in similar way
    # allows us to use tighter error tolerances in our unit tests.
    if ref_zero_points_after_scales and maybe_w_zp is not None:
        w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
    else:
        w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s

    if quant_type.has_bias():
        w_q += quant_type.bias

    # Restore original shapes
    if group_size is not None and group_size < size_k:

        def reshape_w(w):
            # Inverse of the [groupsize, -1] reshape done above.
            w = w.reshape((group_size, -1, size_n))
            w = w.permute(1, 0, 2)
            w = w.reshape((size_k, size_n)).contiguous()
            return w

        w_q = reshape_w(w_q)
        w_ref = reshape_w(w_ref)
        w_s = w_s.reshape((-1, size_n)).contiguous()

    if maybe_w_zp is not None:
        maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
        maybe_w_zp = maybe_w_zp.to(device=orig_device)

    return (
        w_ref.to(device=orig_device),
        w_q.to(device=orig_device),
        w_s if group_size is not None else None,
        maybe_w_zp,
    )
|
224 |
+
|
225 |
+
|
226 |
+
def gptq_quantize_weights(
    w: torch.Tensor,
    quant_type: ScalarType,
    group_size: int,
    act_order: bool,
    test_perm: Optional[torch.Tensor] = None,
):
    """Quantize ``w`` GPTQ-style, optionally applying an act_order permutation.

    Returns ``(w_ref, w_q, w_s, g_idx, rand_perm)``; ``g_idx`` and
    ``rand_perm`` are empty tensors when ``act_order`` is False.
    """
    size_k, _ = w.shape

    assert w.is_floating_point(), "w must be float"
    assert (
        quant_type in SUPPORTED_GPTQ_QUANT_TYPES
    ), f"Unsupported gptq type = {quant_type}"
    valid_group_sizes = SUPPORTED_GROUP_SIZES + [size_k]
    assert group_size in valid_group_sizes, f"Unsupported groupsize = {group_size}"

    w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size)

    # Placeholders; only populated when act_order reorders the rows.
    g_idx = torch.empty(0, dtype=torch.int, device=w.device)
    rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
    if act_order:
        assert group_size < size_k, (
            "For act_order, groupsize = {} must be less than size_k = {}".format(
                group_size, size_k
            )
        )
        w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm)

    return w_ref, w_q, w_s, g_idx, rand_perm
|
258 |
+
|
259 |
+
|
260 |
+
# QQQ employs different quant schemes for per-group and
|
261 |
+
# per-channel quantization.
|
262 |
+
def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int):
    """Quantize ``w`` for the Marlin QQQ kernels.

    Per-group quantization (group_size < size_k) uses a symmetric unsigned
    scheme with fused group/channel scales; per-channel quantization
    (group_size == size_k or -1) uses a single signed per-channel scale.

    Returns:
        Tuple ``(w_ref, q_w, s_group, s_channel)`` on w's original device;
        ``s_group`` is an empty tensor in the per-channel case.
    """
    orig_device = w.device
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"
    assert (
        num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS
    ), f"Unsupported num_bits = {num_bits}"
    assert group_size in SUPPORTED_GROUP_SIZES + [
        size_k
    ], f"Unsupported groupsize = {group_size}"

    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    if group_size < size_k:
        # Reshape to [groupsize, -1]
        w = w.reshape((-1, group_size, size_n))
        w = w.permute(1, 0, 2)
        w = w.reshape((group_size, -1))

        # Unsigned range [0, 2^bits - 1] with the midpoint as implicit zero.
        max_q_val = 2**num_bits - 1
        half_q_val = (max_q_val + 1) // 2

        # Compute scale for each group
        s_group = torch.max(torch.abs(w), 0, keepdim=True)[0]
        s_group *= 2 / max_q_val  # 2 => symmetric

        # Quantize
        q_w = torch.round(w / s_group).int()
        q_w += half_q_val
        q_w = torch.clamp(q_w, 0, max_q_val)
        # Compute ref (dequantized)
        w_ref = (q_w - half_q_val).half() * s_group

        # Restore original shapes
        def reshape_w(w):
            # Inverse of the [groupsize, -1] reshape done above.
            w = w.reshape((group_size, -1, size_n))
            w = w.permute(1, 0, 2)
            w = w.reshape((size_k, size_n)).contiguous()
            return w

        q_w = reshape_w(q_w)
        w_ref = reshape_w(w_ref)

        # Compute int8 quantization scale for each channel
        s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0]
        s_channel /= 127.0
        t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8)
        w_ref = t_int8.half() * s_channel
        s_channel = s_channel.reshape(1, -1).to(dtype=torch.float)

        # Fuse scales
        s_group = (s_group.reshape(-1, size_n).contiguous() / s_channel).to(
            dtype=torch.half
        )
    else:
        # Per-channel: symmetric signed range [-(2^(bits-1)-1), 2^(bits-1)-1].
        max_q_val = 2 ** (num_bits - 1) - 1

        # Compute scale for each channel
        s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0]
        s_channel /= max_q_val

        # Quantize
        q_w = torch.round(w / s_channel).int()
        q_w = torch.clamp(q_w, -max_q_val, max_q_val)
        # Compute ref (dequantized)
        w_ref = q_w.half() * s_channel

        s_group = torch.tensor([], dtype=torch.half)
        # div 2 ** (8 - self.bits)) to offset right shift in unpacking
        s_channel /= 2 ** (8 - num_bits)
        s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float)

    return (
        w_ref.to(device=orig_device),
        q_w.to(device=orig_device),
        s_group.to(device=orig_device),
        s_channel.to(device=orig_device),
    )
|
343 |
+
|
344 |
+
|
345 |
+
def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
    """Reorder the rows of ``q_w`` so that ``g_idx`` becomes non-decreasing.

    Returns ``(q_w, g_idx, sort_indices)``, all on q_w's original device;
    ``sort_indices`` (int32) is the row permutation that was applied.
    """
    device = q_w.device

    # Row permutation that sorts by group index (act_order -> grouped layout).
    sort_indices = torch.argsort(g_idx).to(dtype=torch.int32)

    sorted_g_idx = g_idx[sort_indices].contiguous()
    sorted_q_w = q_w[sort_indices, :].contiguous()

    return (
        sorted_q_w.to(device=device),
        sorted_g_idx.to(device=device),
        sort_indices.to(device=device),
    )
|
358 |
+
|
359 |
+
|
360 |
+
def pack_rows(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Pack along K: each group of ``32 // num_bits`` consecutive rows is
    merged into one int32 row, earlier rows occupying the lower bits.
    """
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_k % pack_factor == 0

    device = q_w.device

    raw = q_w.cpu().numpy().astype(numpy.uint32)
    packed = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)

    for slot in range(pack_factor):
        packed |= raw[slot::pack_factor, :] << num_bits * slot

    # Reinterpret the uint32 words as int32 for torch.
    return torch.from_numpy(packed.astype(numpy.int32)).to(device)
|
382 |
+
|
383 |
+
|
384 |
+
def pack_cols(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Pack along N: each group of ``32 // num_bits`` adjacent columns is
    merged into one int32 column, earlier columns occupying the lower bits.
    """
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0

    device = q_w.device

    raw = q_w.cpu().numpy().astype(numpy.uint32)
    packed = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)

    for slot in range(pack_factor):
        packed |= raw[:, slot::pack_factor] << num_bits * slot

    # Reinterpret the uint32 words as int32 for torch.
    result = torch.from_numpy(packed.astype(numpy.int32)).to(device)
    return result.contiguous()
|
408 |
+
|
409 |
+
|
410 |
+
def unpack_cols(
    packed_q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Inverse of ``pack_cols``: expand each int32 column back into
    ``32 // num_bits`` unpacked columns (lowest bits first).
    """
    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0
    assert packed_q_w.shape == (
        size_k,
        size_n // pack_factor,
    ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
        packed_q_w.shape, size_k, size_n, pack_factor
    )

    device = packed_q_w.device

    words = packed_q_w.cpu().numpy().astype(numpy.uint32)
    unpacked = numpy.zeros((size_k, size_n), dtype=numpy.uint32)

    low_bits = (1 << num_bits) - 1
    for slot in range(pack_factor):
        # Peel off the lowest num_bits of every word, then shift them out.
        unpacked[:, slot::pack_factor] = words & low_bits
        words >>= num_bits

    result = torch.from_numpy(unpacked.astype(numpy.int32)).to(device)
    return result.contiguous()
|
440 |
+
|
441 |
+
|
442 |
+
def gptq_pack(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Pack a (size_k, size_n) quantized weight into GPTQ storage layout.

    GPTQ packs along the K (row) dimension, so this is an alias for
    ``pack_rows``.
    """
    return pack_rows(q_w, num_bits, size_k, size_n)
|
449 |
+
|
450 |
+
|
451 |
+
def awq_pack(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Pack a (size_k, size_n) quantized weight into AWQ layout.

    Columns are first interleaved (to match the dequantize code's ordering)
    and then packed along N into int32 words via ``pack_cols``.
    """
    assert q_w.shape == (size_k, size_n)

    # Interleave column dim (for the dequantize code) and pack it to int32
    if num_bits == 4:
        lane_order = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        lane_order = numpy.array([0, 2, 1, 3])
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    q_w = q_w.reshape((-1, len(lane_order)))[:, lane_order].ravel()
    q_w = q_w.reshape((-1, size_n)).contiguous()

    return pack_cols(q_w, num_bits, size_k, size_n)
|
build/torch24-cxx98-cu118-x86_64-linux/quantization/__init__.py
CHANGED
@@ -1,150 +1,30 @@
|
|
1 |
-
from
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
#if current_platform.is_rocm():
|
33 |
-
# triton_scaled_mm_module = importlib.import_module(
|
34 |
-
# "vllm.model_executor.layers.quantization.compressed_tensors."
|
35 |
-
# "triton_scaled_mm")
|
36 |
-
# triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
|
37 |
-
# return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
38 |
-
|
39 |
-
out = torch.empty((m, n), dtype=out_dtype, device=a.device)
|
40 |
-
|
41 |
-
ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
|
42 |
-
|
43 |
-
return out
|
44 |
-
|
45 |
-
# fp8
|
46 |
-
def scaled_fp8_quant(
|
47 |
-
input: torch.Tensor,
|
48 |
-
scale: Optional[torch.Tensor] = None,
|
49 |
-
num_token_padding: Optional[int] = None,
|
50 |
-
scale_ub: Optional[torch.Tensor] = None,
|
51 |
-
use_per_token_if_dynamic: bool = False,
|
52 |
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
53 |
-
"""
|
54 |
-
Quantize input tensor to FP8 and return quantized tensor and scale.
|
55 |
-
|
56 |
-
This function supports both static and dynamic quantization: If you
|
57 |
-
provide the scale, it will use static scaling and if you omit it,
|
58 |
-
the scale will be determined dynamically. The function also allows
|
59 |
-
optional padding of the output tensors for downstream kernels that
|
60 |
-
will benefit from padding.
|
61 |
-
|
62 |
-
Args:
|
63 |
-
input: The input tensor to be quantized to FP8
|
64 |
-
scale: Optional scaling factor for the FP8 quantization
|
65 |
-
scale_ub: Optional upper bound for scaling factor in dynamic
|
66 |
-
per token case
|
67 |
-
num_token_padding: If specified, pad the first dimension
|
68 |
-
of the output to at least this value.
|
69 |
-
use_per_token_if_dynamic: Whether to do per_tensor or per_token
|
70 |
-
in the dynamic quantization case.
|
71 |
-
|
72 |
-
Returns:
|
73 |
-
Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
|
74 |
-
scaling factor.
|
75 |
-
"""
|
76 |
-
# This code assumes batch_dim and num_tokens are flattened
|
77 |
-
assert (input.ndim == 2)
|
78 |
-
shape: Union[Tuple[int, int], torch.Size] = input.shape
|
79 |
-
# For rocm, the output fp8 dtype is torch.float_e3m3fnuz
|
80 |
-
#out_dtype: torch.dtype = torch.float8_e4m3fnuz \
|
81 |
-
# if current_platform.is_rocm() else torch.float8_e4m3fn
|
82 |
-
out_dtype = torch.float8_e4m3fn
|
83 |
-
if num_token_padding:
|
84 |
-
shape = (max(num_token_padding, input.shape[0]), shape[1])
|
85 |
-
output = torch.empty(shape, device=input.device, dtype=out_dtype)
|
86 |
-
|
87 |
-
if scale is None:
|
88 |
-
if use_per_token_if_dynamic:
|
89 |
-
scale = torch.empty((shape[0], 1),
|
90 |
-
device=input.device,
|
91 |
-
dtype=torch.float32)
|
92 |
-
ops.dynamic_per_token_scaled_fp8_quant(
|
93 |
-
output, input, scale, scale_ub)
|
94 |
-
else:
|
95 |
-
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
|
96 |
-
ops.dynamic_scaled_fp8_quant(output, input, scale)
|
97 |
-
else:
|
98 |
-
# num_token_padding not implemented for this case
|
99 |
-
assert (scale.numel() == 1 or num_token_padding is None)
|
100 |
-
ops.static_scaled_fp8_quant(output, input, scale)
|
101 |
-
|
102 |
-
return output, scale
|
103 |
-
|
104 |
-
# int8
|
105 |
-
def scaled_int8_quant(
|
106 |
-
input: torch.Tensor,
|
107 |
-
scale: Optional[torch.Tensor] = None,
|
108 |
-
azp: Optional[torch.Tensor] = None,
|
109 |
-
symmetric: bool = True
|
110 |
-
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
111 |
-
"""
|
112 |
-
Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.
|
113 |
-
|
114 |
-
Args:
|
115 |
-
input: The input tensor to be quantized to int8.
|
116 |
-
scale: Optional scaling factor for the int8 quantization.
|
117 |
-
When not provided, we invoke dynamic-per-token quantization.
|
118 |
-
azp: Optional zero-point for the int8 quantization.
|
119 |
-
Must be provided for asymmetric quantization if `scale` is provided.
|
120 |
-
symmetric: Whether to use symmetric quantization (scale only, azp ignored).
|
121 |
-
|
122 |
-
Returns:
|
123 |
-
Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
|
124 |
-
"""
|
125 |
-
output = torch.empty_like(input, dtype=torch.int8)
|
126 |
-
if scale is not None:
|
127 |
-
# static-per-tensor quantization.
|
128 |
-
assert symmetric == (
|
129 |
-
azp is
|
130 |
-
None), "azp must only be provided for asymmetric quantization."
|
131 |
-
ops.static_scaled_int8_quant(output, input, scale, azp)
|
132 |
-
return output, scale, azp
|
133 |
-
|
134 |
-
# dynamic-per-token quantization.
|
135 |
-
input_scales = torch.empty((input.numel() // input.shape[-1], 1),
|
136 |
-
device=input.device,
|
137 |
-
dtype=torch.float32)
|
138 |
-
input_azp = None if symmetric else torch.empty_like(input_scales,
|
139 |
-
dtype=torch.int32)
|
140 |
-
ops.dynamic_scaled_int8_quant(output, input, input_scales,
|
141 |
-
input_azp)
|
142 |
-
return output, input_scales, input_azp
|
143 |
-
|
144 |
-
# fp8 marlin
|
145 |
-
def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
|
146 |
-
b_scales: torch.Tensor, workspace: torch.Tensor,
|
147 |
-
num_bits: int, size_m: int, size_n: int,
|
148 |
-
size_k: int) -> torch.Tensor:
|
149 |
-
return ops.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace,
|
150 |
-
num_bits, size_m, size_n, size_k)
|
|
|
1 |
+
from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant
|
2 |
+
from .cutlass import (
|
3 |
+
cutlass_scaled_mm_supports_fp8,
|
4 |
+
cutlass_scaled_mm,
|
5 |
+
cutlass_scaled_mm_azp,
|
6 |
+
)
|
7 |
+
from .marlin import (
|
8 |
+
awq_marlin_repack,
|
9 |
+
fp8_marlin_gemm,
|
10 |
+
gptq_marlin_gemm,
|
11 |
+
gptq_marlin_repack,
|
12 |
+
gptq_marlin_24_gemm,
|
13 |
+
marlin_qqq_gemm,
|
14 |
+
marlin_gemm,
|
15 |
+
)
|
16 |
+
|
17 |
+
__all__ = [
|
18 |
+
"awq_marlin_repack",
|
19 |
+
"cutlass_scaled_mm",
|
20 |
+
"cutlass_scaled_mm_azp",
|
21 |
+
"cutlass_scaled_mm_supports_fp8",
|
22 |
+
"fp8_marlin_gemm",
|
23 |
+
"gptq_marlin_24_gemm",
|
24 |
+
"gptq_marlin_gemm",
|
25 |
+
"gptq_marlin_repack",
|
26 |
+
"marlin_gemm",
|
27 |
+
"marlin_qqq_gemm",
|
28 |
+
"scaled_fp8_quant",
|
29 |
+
"scaled_int8_quant",
|
30 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch24-cxx98-cu118-x86_64-linux/quantization/_ops.py
CHANGED
@@ -1,3 +1,9 @@
|
|
1 |
import torch
|
2 |
from . import _quantization_0_0_1
|
3 |
ops = torch.ops._quantization_0_0_1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import torch
|
2 |
from . import _quantization_0_0_1
|
3 |
ops = torch.ops._quantization_0_0_1
|
4 |
+
|
5 |
+
def add_op_namespace_prefix(op_name: str):
    """Return ``op_name`` qualified with this extension's torch op namespace."""
    return "_quantization_0_0_1::" + op_name
|
build/torch24-cxx98-cu118-x86_64-linux/quantization/_quantization_0_0_1.abi3.so
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:734f235fc2749269910ee4e988da205a9442edf73c0f9b3ef41fff100bc66707
|
3 |
+
size 85709024
|
build/torch24-cxx98-cu118-x86_64-linux/quantization/compressed_tensors.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Tuple
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
try:
|
6 |
+
from ._ops import ops
|
7 |
+
except ImportError as e:
|
8 |
+
# Fallback for local development.
|
9 |
+
try:
|
10 |
+
import _quantization
|
11 |
+
|
12 |
+
ops = torch.ops._quantization
|
13 |
+
except ImportError:
|
14 |
+
raise e
|
15 |
+
|
16 |
+
|
17 |
+
# fp8
|
18 |
+
def scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    num_token_padding: Optional[int] = None,
    scale_ub: Optional[torch.Tensor] = None,
    use_per_token_if_dynamic: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP8 and return quantized tensor and scale.

    This function supports both static and dynamic quantization: If you
    provide the scale, it will use static scaling and if you omit it,
    the scale will be determined dynamically. The function also allows
    optional padding of the output tensors for downstream kernels that
    will benefit from padding.

    Args:
        input: The input tensor to be quantized to FP8
        scale: Optional scaling factor for the FP8 quantization
        scale_ub: Optional upper bound for scaling factor in dynamic
            per token case
        num_token_padding: If specified, pad the first dimension
            of the output to at least this value.
        use_per_token_if_dynamic: Whether to do per_tensor or per_token
            in the dynamic quantization case.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
            scaling factor.
    """
    # This code assumes batch_dim and num_tokens are flattened
    assert input.ndim == 2
    # Fix: the original annotation `shape: Union[Tuple[int, int], torch.Size]`
    # referenced `Union`, which is not imported in this module (undefined
    # name for linters/type checkers); the annotation carried no runtime
    # effect, so it is dropped.
    shape = input.shape
    # For rocm, the output fp8 dtype is torch.float_e3m3fnuz
    # out_dtype: torch.dtype = torch.float8_e4m3fnuz \
    #     if current_platform.is_rocm() else torch.float8_e4m3fn
    out_dtype = torch.float8_e4m3fn
    if num_token_padding:
        # Grow the token dimension so the output has at least
        # num_token_padding rows.
        shape = (max(num_token_padding, input.shape[0]), shape[1])
    output = torch.empty(shape, device=input.device, dtype=out_dtype)

    if scale is None:
        # Dynamic quantization: allocate the scale buffer here; the C++ op
        # is expected to populate it in place.
        if use_per_token_if_dynamic:
            scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
            ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub)
        else:
            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
            ops.dynamic_scaled_fp8_quant(output, input, scale)
    else:
        # Static quantization with a caller-provided scale.
        # num_token_padding not implemented for this case
        assert scale.numel() == 1 or num_token_padding is None
        ops.static_scaled_fp8_quant(output, input, scale)

    return output, scale
|
72 |
+
|
73 |
+
|
74 |
+
# int8
|
75 |
+
def scaled_int8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    azp: Optional[torch.Tensor] = None,
    symmetric: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """
    Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.

    Args:
        input: The input tensor to be quantized to int8.
        scale: Optional scaling factor for the int8 quantization.
            When not provided, we invoke dynamic-per-token quantization.
        azp: Optional zero-point for the int8 quantization.
            Must be provided for asymmetric quantization if `scale` is provided.
        symmetric: Whether to use symmetric quantization (scale only, azp ignored).

    Returns:
        Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
    """
    output = torch.empty_like(input, dtype=torch.int8)

    if scale is not None:
        # static-per-tensor quantization.
        assert symmetric == (
            azp is None
        ), "azp must only be provided for asymmetric quantization."
        ops.static_scaled_int8_quant(output, input, scale, azp)
        return output, scale, azp

    # dynamic-per-token quantization: one scale (and optionally one azp)
    # per flattened token row; the op populates these buffers.
    num_tokens = input.numel() // input.shape[-1]
    input_scales = torch.empty(
        (num_tokens, 1), device=input.device, dtype=torch.float32
    )
    if symmetric:
        input_azp = None
    else:
        input_azp = torch.empty_like(input_scales, dtype=torch.int32)
    ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp)
    return output, input_scales, input_azp
|
build/torch24-cxx98-cu118-x86_64-linux/quantization/cutlass.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
try:
|
6 |
+
from ._ops import ops
|
7 |
+
except ImportError as e:
|
8 |
+
# Fallback for local development.
|
9 |
+
try:
|
10 |
+
import _quantization
|
11 |
+
|
12 |
+
ops = torch.ops._quantization
|
13 |
+
except ImportError:
|
14 |
+
raise e
|
15 |
+
|
16 |
+
|
17 |
+
def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
    """Return whether the compiled CUTLASS kernels support FP8 scaled-mm
    for the given CUDA compute capability (delegates to the C++ op)."""
    return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
|
19 |
+
|
20 |
+
|
21 |
+
def cutlass_scaled_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run the CUTLASS scaled-matmul kernel on quantized inputs ``a`` and
    ``b`` and return the (m, n) result in ``out_dtype``.

    Both dimensions of ``b`` must be divisible by 16; ``out_dtype`` must be
    float16 or bfloat16; ``bias``, if given, is per output column and must
    match ``out_dtype``.
    """
    assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
    assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
    assert bias is None or bias.shape[0] == b.shape[1] and bias.dtype == out_dtype

    num_rows = a.shape[0]
    num_cols = b.shape[1]

    # if current_platform.is_rocm():
    #     triton_scaled_mm_module = importlib.import_module(
    #         "vllm.model_executor.layers.quantization.compressed_tensors."
    #         "triton_scaled_mm")
    #     triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
    #     return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

    out = torch.empty((num_rows, num_cols), dtype=out_dtype, device=a.device)
    ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
    return out
|
48 |
+
|
49 |
+
|
50 |
+
def cutlass_scaled_mm_azp(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    azp_adj: torch.Tensor,
    azp: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Scaled matmul with asymmetric zero-point correction.

    :param azp_adj: In the per-tensor case, this should include the azp.
        Always per-channel.
    :param azp: Only set in the per-token case. Per-token if set.
    :return: newly-allocated (m, n) tensor of ``out_dtype`` on ``a``'s device.
    """
    assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
    assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
    assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype
    assert azp is None or azp.numel() == a.shape[0]

    num_rows, num_cols = a.shape[0], b.shape[1]
    out = torch.empty((num_rows, num_cols), dtype=out_dtype, device=a.device)

    ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias)
    return out
|
build/torch24-cxx98-cu118-x86_64-linux/quantization/marlin.py
ADDED
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import TYPE_CHECKING
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
# neuron has torch version that doesn't even have impl_abstract
|
6 |
+
if TYPE_CHECKING:
|
7 |
+
def register_fake(fn):
|
8 |
+
return lambda name: fn
|
9 |
+
else:
|
10 |
+
try:
|
11 |
+
from torch.library import register_fake
|
12 |
+
except ImportError:
|
13 |
+
from torch.library import impl_abstract as register_fake
|
14 |
+
|
15 |
+
try:
|
16 |
+
from ._ops import ops, add_op_namespace_prefix
|
17 |
+
except ImportError as e:
|
18 |
+
# Fallback for local development.
|
19 |
+
try:
|
20 |
+
import _quantization
|
21 |
+
|
22 |
+
ops = torch.ops._quantization
|
23 |
+
|
24 |
+
def add_op_namespace_prefix(op_name: str):
|
25 |
+
return f"_quantization::{op_name}"
|
26 |
+
except ImportError:
|
27 |
+
raise e
|
28 |
+
|
29 |
+
|
30 |
+
from .scalar_type import ScalarType
|
31 |
+
|
32 |
+
|
33 |
+
# fp8 marlin
|
34 |
+
def fp8_marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    num_bits: int,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    """Thin wrapper over the ``fp8_marlin_gemm`` custom op.

    All arguments are forwarded unchanged; ``workspace`` is scratch space
    the kernel uses (sized by the caller).
    """
    call_args = (a, b_q_weight, b_scales, workspace, num_bits,
                 size_m, size_n, size_k)
    return ops.fp8_marlin_gemm(*call_args)
|
47 |
+
|
48 |
+
|
49 |
+
# gptq_marlin
|
50 |
+
def gptq_marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    b_zeros: torch.Tensor,
    g_idx: torch.Tensor,
    perm: torch.Tensor,
    workspace: torch.Tensor,
    b_q_type: ScalarType,
    size_m: int,
    size_n: int,
    size_k: int,
    is_k_full: bool,
    has_zp: bool = False,
    use_fp32_reduce: bool = False,
    is_zp_float: bool = False,
) -> torch.Tensor:
    """Thin wrapper over the ``gptq_marlin_gemm`` custom op.

    The op boundary receives ``b_q_type.id`` (the packed integer encoding of
    the ScalarType) rather than the ScalarType object itself; everything else
    is forwarded unchanged.
    """
    return ops.gptq_marlin_gemm(
        a, b_q_weight, b_scales, b_zeros, g_idx, perm, workspace,
        b_q_type.id,  # integer id crosses the custom-op boundary
        size_m, size_n, size_k,
        is_k_full, has_zp, use_fp32_reduce, is_zp_float,
    )
|
84 |
+
|
85 |
+
|
86 |
+
# gptq_marlin
|
87 |
+
def gptq_marlin_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
) -> torch.Tensor:
    """Thin wrapper over the ``gptq_marlin_repack`` custom op."""
    repacked = ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits)
    return repacked
|
95 |
+
|
96 |
+
|
97 |
+
# gptq_marlin
|
98 |
+
def awq_marlin_repack(
    b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    """Thin wrapper over the ``awq_marlin_repack`` custom op."""
    repacked = ops.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)
    return repacked
|
102 |
+
|
103 |
+
|
104 |
+
# marlin
|
105 |
+
def marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    """Thin wrapper over the ``marlin_gemm`` custom op."""
    call_args = (a, b_q_weight, b_scales, workspace, size_m, size_n, size_k)
    return ops.marlin_gemm(*call_args)
|
117 |
+
|
118 |
+
|
119 |
+
# marlin_24
|
120 |
+
def gptq_marlin_24_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_meta: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    b_q_type: ScalarType,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    """Thin wrapper over the ``gptq_marlin_24_gemm`` (2:4 sparse) custom op.

    Passes ``b_q_type.id`` — the integer encoding of the ScalarType —
    across the custom-op boundary.
    """
    return ops.gptq_marlin_24_gemm(
        a, b_q_weight, b_meta, b_scales, workspace,
        b_q_type.id,
        size_m, size_n, size_k,
    )
|
134 |
+
|
135 |
+
|
136 |
+
# qqq ops
|
137 |
+
def marlin_qqq_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    s_tok: torch.Tensor,
    s_ch: torch.Tensor,
    s_group: torch.Tensor,
    workspace: torch.Tensor,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    """Thin wrapper over the ``marlin_qqq_gemm`` custom op.

    ``s_tok``/``s_ch``/``s_group`` are the per-token, per-channel and
    per-group scale tensors the kernel consumes.
    """
    call_args = (a, b_q_weight, s_tok, s_ch, s_group, workspace,
                 size_m, size_n, size_k)
    return ops.marlin_qqq_gemm(*call_args)
|
151 |
+
|
152 |
+
|
153 |
+
# Fake ops
#
# "Fake" (meta) implementations registered for the custom ops above. They
# allocate an empty tensor of the right shape/device/dtype without running
# the kernel — presumably consumed by torch.compile tracing and opcheck,
# which need shape inference only. TODO confirm against callers.
#
# NOTE(review): the guard probes only `gptq_marlin_24_gemm`, but then
# registers fakes for fp8/gptq/qqq/marlin ops as well — assumes all of
# these ops are present together whenever that one is; verify.

if hasattr(ops, "gptq_marlin_24_gemm"):
    # Output is (size_m, size_n) in the activation dtype.
    @register_fake(add_op_namespace_prefix("fp8_marlin_gemm"))
    def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                              b_scales: torch.Tensor, workspace: torch.Tensor,
                              num_bits: int, size_m: torch.SymInt,
                              size_n: torch.SymInt,
                              size_k: torch.SymInt) -> torch.Tensor:
        return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)

    # 2:4 sparse variant — same (size_m, size_n) output in a's dtype.
    @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm"))
    def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                                  b_meta: torch.Tensor, b_scales: torch.Tensor,
                                  workspace: torch.Tensor,
                                  b_q_type: ScalarType, size_m: torch.SymInt,
                                  size_n: torch.SymInt,
                                  size_k: torch.SymInt) -> torch.Tensor:
        return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

    # Mirrors gptq_marlin_gemm's signature, including the trailing defaults.
    @register_fake(add_op_namespace_prefix("gptq_marlin_gemm"))
    def _gptq_marlin_gemm_fake(a: torch.Tensor,
                               b_q_weight: torch.Tensor,
                               b_scales: torch.Tensor,
                               b_zeros: torch.Tensor,
                               g_idx: torch.Tensor,
                               perm: torch.Tensor,
                               workspace: torch.Tensor,
                               b_q_type: ScalarType,
                               size_m: torch.SymInt,
                               size_n: torch.SymInt,
                               size_k: torch.SymInt,
                               is_k_full: bool,
                               has_zp: bool = False,
                               use_fp32_reduce: bool = False,
                               is_zp_float: bool = False) -> torch.Tensor:
        return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

    # QQQ kernel always emits fp16, hence the hard-coded dtype here.
    @register_fake(add_op_namespace_prefix("marlin_qqq_gemm"))
    def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                              s_tok: torch.Tensor, s_ch: torch.Tensor,
                              s_group: torch.Tensor, workspace: torch.Tensor,
                              size_m: torch.SymInt, size_n: torch.SymInt,
                              size_k: torch.SymInt) -> torch.Tensor:
        return torch.empty((size_m, size_n),
                           dtype=torch.float16,
                           device=a.device)

    # Classic Marlin kernel likewise produces fp16 output.
    @register_fake(add_op_namespace_prefix("marlin_gemm"))
    def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                          b_scales: torch.Tensor, workspace: torch.Tensor,
                          size_m: torch.SymInt, size_n: torch.SymInt,
                          size_k: torch.SymInt) -> torch.Tensor:
        return torch.empty((size_m, size_n),
                           dtype=torch.float16,
                           device=a.device)
|
build/torch24-cxx98-cu118-x86_64-linux/quantization/scalar_type.py
ADDED
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
import struct
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from enum import Enum
|
5 |
+
from typing import Optional, Union
|
6 |
+
|
7 |
+
|
8 |
+
# Mirrors enum in `core/scalar_type.hpp`
class NanRepr(Enum):
    NONE = 0  # nans are not supported
    IEEE_754 = 1  # nans are: Exp all 1s, mantissa not all 0s
    EXTD_RANGE_MAX_MIN = 2  # nans are: Exp all 1s, mantissa all 1s


# This ScalarType class is a parallel implementation of the C++ ScalarType
# class found in csrc/core/scalar_type.hpp. These two classes should be kept
# in sync until the inductor fully supports custom C++ classes.
@dataclass(frozen=True)
class ScalarType:
    """
    ScalarType can represent a wide range of floating point and integer
    types, in particular it can be used to represent sub-byte data types
    (something that torch.dtype currently does not support). It is also
    capable of representing types with a bias, i.e.:
      `stored_value = value + bias`,
    this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
    of 8). The implementation for this class can be found in
    csrc/core/scalar_type.hpp, these type signatures should be kept in sync
    with that file.
    """

    exponent: int
    """
    Number of bits in the exponent if this is a floating point type
    (zero if this an integer type)
    """

    mantissa: int
    """
    Number of bits in the mantissa if this is a floating point type,
    or the number bits representing an integer excluding the sign bit if
    this an integer type.
    """

    signed: bool
    "If the type is signed (i.e. has a sign bit)"

    bias: int
    """
    bias used to encode the values in this scalar type
    (value = stored_value - bias, default 0) for example if we store the
    type as an unsigned integer with a bias of 128 then the value 0 will be
    stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
    """

    _finite_values_only: bool = False
    """
    Private: if infs are supported, used `has_infs()` instead.
    """

    nan_repr: NanRepr = NanRepr.IEEE_754
    """
    How NaNs are represent in this scalar type, returns NanRepr value.
    (not applicable for integer types)
    """

    def _floating_point_max_int(self) -> int:
        """Raw bit pattern of the max value, expressed as an IEEE-754 double."""
        assert (
            self.mantissa <= 52 and self.exponent <= 11
        ), f"Cannot represent max/min as a double for type {self.__str__()}"

        max_mantissa = (1 << self.mantissa) - 1
        if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
            # The all-ones mantissa encodes NaN, so the max finite value
            # gives up one mantissa code point.
            max_mantissa = max_mantissa - 1

        max_exponent = (1 << self.exponent) - 2
        if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN
                or self.nan_repr == NanRepr.NONE):
            # Without IEEE NaN/inf encoding, the all-ones exponent is a
            # regular (finite) exponent value.
            assert (
                self.exponent < 11
            ), f"Cannot represent max/min as a double for type {self.__str__()}"
            max_exponent = max_exponent + 1

        # adjust the exponent to match that of a double
        # for now we assume the exponent bias is the standard 2^(e-1) -1, (where
        # e is the exponent bits), there is some precedent for non-standard
        # biases, example `float8_e4m3b11fnuz` here:
        # https://github.com/jax-ml/ml_dtypes but to avoid premature over
        # complication we are just assuming the standard exponent bias until
        # there is a need to support non-standard biases
        exponent_bias = (1 << (self.exponent - 1)) - 1
        exponent_bias_double = (1 << 10) - 1  # double e = 11

        max_exponent_double = (max_exponent - exponent_bias +
                               exponent_bias_double)

        # shift the mantissa and exponent into the proper positions for an
        # IEEE double and bitwise-or them together.
        return (max_mantissa <<
                (52 - self.mantissa)) | (max_exponent_double << 52)

    def _floating_point_max(self) -> float:
        """Max finite value of a floating point type, as a Python float."""
        double_raw = self._floating_point_max_int()
        return struct.unpack('!d', struct.pack('!Q', double_raw))[0]

    def _raw_max(self) -> Union[int, float]:
        """Max stored value (before the bias is subtracted)."""
        if self.is_floating_point():
            return self._floating_point_max()
        else:
            assert (self.size_bits < 64 or self.size_bits == 64
                    and self.is_signed()), "Cannot represent max as an int"
            return (1 << self.mantissa) - 1

    def _raw_min(self) -> Union[int, float]:
        """Min stored value (before the bias is subtracted)."""
        if self.is_floating_point():
            assert self.is_signed(
            ), "We currently assume all floating point types are signed"
            sign_bit_double = 1 << 63

            # Min is the negation of max: same magnitude, sign bit set.
            max_raw = self._floating_point_max_int()
            min_raw = max_raw | sign_bit_double
            return struct.unpack('!d', struct.pack('!Q', min_raw))[0]
        else:
            assert (not self.is_signed() or
                    self.size_bits <= 64), "Cannot represent min as a int64_t"

            if self.is_signed():
                return -(1 << (self.size_bits - 1))
            else:
                return 0

    @functools.cached_property
    def id(self) -> int:
        """
        Convert the ScalarType to an int which can be passed to pytorch custom
        ops. This layout of the int must be kept in sync with the C++
        ScalarType's from_id method.
        """
        val = 0
        offset = 0

        def or_and_advance(member, bit_width):
            nonlocal val
            nonlocal offset
            bit_mask = (1 << bit_width) - 1
            val = val | (int(member) & bit_mask) << offset
            offset = offset + bit_width

        or_and_advance(self.exponent, 8)
        or_and_advance(self.mantissa, 8)
        or_and_advance(self.signed, 1)
        or_and_advance(self.bias, 32)
        or_and_advance(self._finite_values_only, 1)
        or_and_advance(self.nan_repr.value, 8)

        assert offset <= 64, \
            f"ScalarType fields too big {offset} to fit into an int64"

        return val

    @property
    def size_bits(self) -> int:
        """Total storage width in bits (exponent + mantissa + sign bit)."""
        return self.exponent + self.mantissa + int(self.signed)

    def min(self) -> Union[int, float]:
        """
        Min representable value for this scalar type.
        (accounting for bias if there is one)
        """
        return self._raw_min() - self.bias

    def max(self) -> Union[int, float]:
        """
        Max representable value for this scalar type.
        (accounting for bias if there is one)
        """
        return self._raw_max() - self.bias

    def is_signed(self) -> bool:
        """
        If the type is signed (i.e. has a sign bit), same as `signed`
        added for consistency with:
        https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
        """
        return self.signed

    def is_floating_point(self) -> bool:
        "If the type is a floating point type"
        return self.exponent != 0

    def is_integer(self) -> bool:
        "If the type is an integer type"
        return self.exponent == 0

    def has_bias(self) -> bool:
        "If the type has a non-zero bias"
        return self.bias != 0

    def has_infs(self) -> bool:
        "If the type is floating point and supports infinity"
        return not self._finite_values_only

    def has_nans(self) -> bool:
        "If the type supports NaN values"
        # BUGFIX: was `self.nan_repr != NanRepr.NONE.value`, which compares
        # an enum member against an int — never equal, so this always
        # returned True even for NanRepr.NONE types. Compare enum-to-enum,
        # matching `_floating_point_max_int` and `__str__`.
        return self.nan_repr != NanRepr.NONE

    def is_ieee_754(self) -> bool:
        """
        If the type is a floating point type that follows IEEE 754
        conventions
        """
        # BUGFIX: was `self.nan_repr == NanRepr.IEEE_754.value` (enum vs
        # int — always False), which also made `__str__` append spurious
        # flag suffixes to IEEE types (e.g. "float8_e5m2n").
        return self.nan_repr == NanRepr.IEEE_754 and \
            not self._finite_values_only

    def __str__(self) -> str:
        """
        naming generally follows: https://github.com/jax-ml/ml_dtypes
        for floating point types (leading f) the scheme is:
          `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
          flags:
          - no-flags: means it follows IEEE 754 conventions
          - f: means finite values only (no infinities)
          - n: means nans are supported (non-standard encoding)
        for integer types the scheme is:
          `[u]int<size_bits>[b<bias>]`
          - if bias is not present it means its zero
        """
        if self.is_floating_point():
            ret = "float" + str(self.size_bits) + "_e" + str(
                self.exponent) + "m" + str(self.mantissa)

            if not self.is_ieee_754():
                if self._finite_values_only:
                    ret = ret + "f"
                if self.nan_repr != NanRepr.NONE:
                    ret = ret + "n"

            return ret
        else:
            ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
            if self.has_bias():
                ret = ret + "b" + str(self.bias)
            return ret

    def __repr__(self) -> str:
        return "ScalarType." + self.__str__()

    # __len__ needs to be defined (and has to throw TypeError) for pytorch's
    # opcheck to work.
    def __len__(self) -> int:
        raise TypeError

    #
    # Convenience Constructors
    #

    @classmethod
    def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
        "Create a signed integer scalar type (size_bits includes sign-bit)."
        ret = cls(0, size_bits - 1, True, bias if bias else 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
        """Create a unsigned integer scalar type."""
        ret = cls(0, size_bits, False, bias if bias else 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
        """
        Create a standard floating point type
        (i.e. follows IEEE 754 conventions).
        """
        assert (mantissa > 0 and exponent > 0)
        ret = cls(exponent, mantissa, True, 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
               nan_repr: NanRepr) -> 'ScalarType':
        """
        Create a non-standard floating point type
        (i.e. does not follow IEEE 754 conventions).
        """
        assert (mantissa > 0 and exponent > 0)
        assert (nan_repr != NanRepr.IEEE_754), (
            "use `float_IEEE754` constructor for floating point types that "
            "follow IEEE 754 conventions")
        ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
        ret.id  # noqa B018: make sure the id is cached
        return ret
|
295 |
+
|
296 |
+
|
297 |
+
# naming generally follows: https://github.com/jax-ml/ml_dtypes
|
298 |
+
# for floating point types (leading f) the scheme is:
|
299 |
+
# `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
|
300 |
+
# flags:
|
301 |
+
# - no-flags: means it follows IEEE 754 conventions
|
302 |
+
# - f: means finite values only (no infinities)
|
303 |
+
# - n: means nans are supported (non-standard encoding)
|
304 |
+
# for integer types the scheme is:
|
305 |
+
# `[u]int<size_bits>[b<bias>]`
|
306 |
+
# - if bias is not present it means its zero
|
307 |
+
|
308 |
+
|
309 |
+
class scalar_types:
    """Registry of commonly-used ScalarType instances.

    Acts as a namespace of constants (never instantiated). Naming follows
    https://github.com/jax-ml/ml_dtypes: `float<bits>_e<exp>m<mantissa>[flags]`
    for floats, `[u]int<bits>[b<bias>]` for integers.
    """

    # plain signed/unsigned integers, no bias
    int4 = ScalarType.int_(4, None)
    uint4 = ScalarType.uint(4, None)
    int8 = ScalarType.int_(8, None)
    uint8 = ScalarType.uint(8, None)
    # fp8 variants: e4m3fn is finite-only with extended-range NaN encoding;
    # e5m2 follows IEEE 754 conventions
    float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
    float8_e5m2 = ScalarType.float_IEEE754(5, 2)
    # 16-bit floats: e8m7 has bfloat16's layout, e5m10 has IEEE half's
    float16_e8m7 = ScalarType.float_IEEE754(8, 7)
    float16_e5m10 = ScalarType.float_IEEE754(5, 10)

    # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
    float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)

    # "gptq" types
    # biased unsigned ints: bias = 2^(bits-1), so stored range is symmetric
    uint2b2 = ScalarType.uint(2, 2)
    uint3b4 = ScalarType.uint(3, 4)
    uint4b8 = ScalarType.uint(4, 8)
    uint8b128 = ScalarType.uint(8, 128)

    # colloquial names
    bfloat16 = float16_e8m7
    float16 = float16_e5m10
|
build/torch24-cxx98-cu118-x86_64-linux/quantization/utils/__init__.py
ADDED
File without changes
|