Add full Marlin support and tests for Marlin/CUTLASS
- build.toml +19 -4
- ext-torch/__init__.py +30 -177
- ext-torch/compressed_tensors.py +110 -0
- ext-torch/cutlass.py +75 -0
- ext-torch/marlin.py +208 -0
- ext-torch/scalar_type.py +330 -0
- ext-torch/torch_binding.cpp +29 -3
- ext-torch/torch_binding.h +23 -0
- ext-torch/utils/marlin_utils.py +391 -0
- ext-torch/utils/marlin_utils_fp8.py +100 -0
- ext-torch/utils/marlin_utils_test.py +162 -0
- ext-torch/utils/marlin_utils_test_24.py +473 -0
- ext-torch/utils/marlin_utils_test_qqq.py +125 -0
- ext-torch/utils/quant_utils.py +470 -0
- marlin/dense/LICENSE +209 -0
- marlin/dense/common/base.h +32 -0
- marlin/dense/common/mem.h +89 -0
- marlin/dense/marlin_cuda_kernel.cu +1068 -0
- marlin/qqq/marlin_qqq_gemm_kernel.cu +1243 -0
- marlin/sparse/LICENSE +203 -0
- marlin/sparse/common/base.h +51 -0
- marlin/sparse/common/mem.h +136 -0
- marlin/sparse/common/mma.h +191 -0
- marlin/sparse/marlin_24_cuda_kernel.cu +1140 -0
- tests/kernels/test_marlin_gemm.py +733 -0
- tests/kernels/utils.py +38 -27
build.toml
CHANGED
@@ -10,9 +10,7 @@ src = [
   "ext-torch/torch_binding.h"
 ]
 include = [ "." ]
-
-  "ext-torch/__init__.py"
-]
+pyroot = "ext-torch"
 
 [kernel.cutlass_w8a8]
 capabilities = [ "7.5", "8.0", "8.6", "8.7", "8.9", "9.0", "9.0a" ]
@@ -59,7 +57,6 @@ src = [
   "gptq_marlin/marlin.cuh",
   "gptq_marlin/marlin_dtypes.cuh",
 ]
-#include = [ "." ]
 depends = [ "torch" ]
 
 [kernel.int8_common]
@@ -83,3 +80,21 @@ src = [
 ]
 include = [ "." ]
 depends = [ "torch" ]
+
+[kernel.marlin]
+capabilities = [ "8.0", "8.6", "8.7", "8.9", "9.0", "9.0a" ]
+src = [
+  "core/scalar_type.hpp",
+  "marlin/dense/common/base.h",
+  "marlin/dense/common/mem.h",
+  "marlin/dense/marlin_cuda_kernel.cu",
+  "marlin/qqq/marlin_qqq_gemm_kernel.cu",
+  "marlin/sparse/common/base.h",
+  "marlin/sparse/common/mem.h",
+  "marlin/sparse/common/mma.h",
+  "marlin/sparse/marlin_24_cuda_kernel.cu"
+]
+include = [ "." ]
+depends = [ "torch" ]
ext-torch/__init__.py
CHANGED
@@ -1,177 +1,30 @@
-# if current_platform.is_rocm():
-    # triton_scaled_mm_module = importlib.import_module(
-    #     "vllm.model_executor.layers.quantization.compressed_tensors."
-    #     "triton_scaled_mm")
-    # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
-    # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
-
-    out = torch.empty((m, n), dtype=out_dtype, device=a.device)
-
-    ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
-
-    return out
-
-def cutlass_scaled_mm_azp(a: torch.Tensor,
-                          b: torch.Tensor,
-                          scale_a: torch.Tensor,
-                          scale_b: torch.Tensor,
-                          out_dtype: torch.dtype,
-                          azp_adj: torch.Tensor,
-                          azp: Optional[torch.Tensor] = None,
-                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-    """
-    :param azp_adj: In the per-tensor case, this should include the azp.
-      Always per-channel.
-    :param azp: Only set in the per-token case. Per-token if set.
-    """
-    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
-    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
-    assert bias is None or bias.numel(
-    ) == b.shape[1] and bias.dtype == out_dtype
-    assert azp is None or azp.numel() == a.shape[0]
-
-    m = a.shape[0]
-    n = b.shape[1]
-    out = torch.empty((m, n), dtype=out_dtype, device=a.device)
-
-    ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj,
-                              azp, bias)
-    return out
-
-# fp8
-def scaled_fp8_quant(
-    input: torch.Tensor,
-    scale: Optional[torch.Tensor] = None,
-    num_token_padding: Optional[int] = None,
-    scale_ub: Optional[torch.Tensor] = None,
-    use_per_token_if_dynamic: bool = False,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """
-    Quantize input tensor to FP8 and return quantized tensor and scale.
-
-    This function supports both static and dynamic quantization: If you
-    provide the scale, it will use static scaling and if you omit it,
-    the scale will be determined dynamically. The function also allows
-    optional padding of the output tensors for downstream kernels that
-    will benefit from padding.
-
-    Args:
-        input: The input tensor to be quantized to FP8
-        scale: Optional scaling factor for the FP8 quantization
-        scale_ub: Optional upper bound for scaling factor in dynamic
-            per token case
-        num_token_padding: If specified, pad the first dimension
-            of the output to at least this value.
-        use_per_token_if_dynamic: Whether to do per_tensor or per_token
-            in the dynamic quantization case.
-
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
-            scaling factor.
-    """
-    # This code assumes batch_dim and num_tokens are flattened
-    assert (input.ndim == 2)
-    shape: Union[Tuple[int, int], torch.Size] = input.shape
-    # For rocm, the output fp8 dtype is torch.float_e3m3fnuz
-    #out_dtype: torch.dtype = torch.float8_e4m3fnuz \
-    #    if current_platform.is_rocm() else torch.float8_e4m3fn
-    out_dtype = torch.float8_e4m3fn
-    if num_token_padding:
-        shape = (max(num_token_padding, input.shape[0]), shape[1])
-    output = torch.empty(shape, device=input.device, dtype=out_dtype)
-
-    if scale is None:
-        if use_per_token_if_dynamic:
-            scale = torch.empty((shape[0], 1),
-                                device=input.device,
-                                dtype=torch.float32)
-            ops.dynamic_per_token_scaled_fp8_quant(
-                output, input, scale, scale_ub)
-        else:
-            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
-            ops.dynamic_scaled_fp8_quant(output, input, scale)
-    else:
-        # num_token_padding not implemented for this case
-        assert (scale.numel() == 1 or num_token_padding is None)
-        ops.static_scaled_fp8_quant(output, input, scale)
-
-    return output, scale
-
-# int8
-def scaled_int8_quant(
-    input: torch.Tensor,
-    scale: Optional[torch.Tensor] = None,
-    azp: Optional[torch.Tensor] = None,
-    symmetric: bool = True
-) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-    """
-    Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.
-
-    Args:
-        input: The input tensor to be quantized to int8.
-        scale: Optional scaling factor for the int8 quantization.
-            When not provided, we invoke dynamic-per-token quantization.
-        azp: Optional zero-point for the int8 quantization.
-            Must be provided for asymmetric quantization if `scale` is provided.
-        symmetric: Whether to use symmetric quantization (scale only, azp ignored).
-
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
-    """
-    output = torch.empty_like(input, dtype=torch.int8)
-    if scale is not None:
-        # static-per-tensor quantization.
-        assert symmetric == (
-            azp is
-            None), "azp must only be provided for asymmetric quantization."
-        ops.static_scaled_int8_quant(output, input, scale, azp)
-        return output, scale, azp
-
-    # dynamic-per-token quantization.
-    input_scales = torch.empty((input.numel() // input.shape[-1], 1),
-                               device=input.device,
-                               dtype=torch.float32)
-    input_azp = None if symmetric else torch.empty_like(input_scales,
-                                                        dtype=torch.int32)
-    ops.dynamic_scaled_int8_quant(output, input, input_scales,
-                                  input_azp)
-    return output, input_scales, input_azp
-
-# fp8 marlin
-def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
-                    b_scales: torch.Tensor, workspace: torch.Tensor,
-                    num_bits: int, size_m: int, size_n: int,
-                    size_k: int) -> torch.Tensor:
-    return ops.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace,
-                               num_bits, size_m, size_n, size_k)
+from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant
+from .cutlass import (
+    cutlass_scaled_mm_supports_fp8,
+    cutlass_scaled_mm,
+    cutlass_scaled_mm_azp,
+)
+from .marlin import (
+    awq_marlin_repack,
+    fp8_marlin_gemm,
+    gptq_marlin_gemm,
+    gptq_marlin_repack,
+    gptq_marlin_24_gemm,
+    marlin_qqq_gemm,
+    marlin_gemm,
+)
+
+__all__ = [
+    "awq_marlin_repack",
+    "cutlass_scaled_mm",
+    "cutlass_scaled_mm_azp",
+    "cutlass_scaled_mm_supports_fp8",
+    "fp8_marlin_gemm",
+    "gptq_marlin_24_gemm",
+    "gptq_marlin_gemm",
+    "gptq_marlin_repack",
+    "marlin_gemm",
+    "marlin_qqq_gemm",
+    "scaled_fp8_quant",
+    "scaled_int8_quant",
+]
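The `__init__.py` above is now a thin re-export layer. A minimal usage sketch, assuming the built extension is importable under the package name `quantization` (the name used by the test utilities later in this PR):

# Hedged sketch: the public surface after this change.
import quantization as ops

print(sorted(ops.__all__))
# CUTLASS path: ops.cutlass_scaled_mm, ops.cutlass_scaled_mm_azp
# Marlin path:  ops.marlin_gemm, ops.gptq_marlin_gemm, ops.gptq_marlin_24_gemm,
#               ops.marlin_qqq_gemm, ops.fp8_marlin_gemm, plus the repack helpers.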
ext-torch/compressed_tensors.py
ADDED
@@ -0,0 +1,110 @@
from typing import Optional, Tuple

import torch

try:
    from ._ops import ops
except ImportError as e:
    # Fallback for local development.
    try:
        import _quantization

        ops = torch.ops._quantization
    except ImportError:
        raise e


# fp8
def scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    num_token_padding: Optional[int] = None,
    scale_ub: Optional[torch.Tensor] = None,
    use_per_token_if_dynamic: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP8 and return quantized tensor and scale.

    This function supports both static and dynamic quantization: If you
    provide the scale, it will use static scaling and if you omit it,
    the scale will be determined dynamically. The function also allows
    optional padding of the output tensors for downstream kernels that
    will benefit from padding.

    Args:
        input: The input tensor to be quantized to FP8
        scale: Optional scaling factor for the FP8 quantization
        scale_ub: Optional upper bound for scaling factor in dynamic
            per token case
        num_token_padding: If specified, pad the first dimension
            of the output to at least this value.
        use_per_token_if_dynamic: Whether to do per_tensor or per_token
            in the dynamic quantization case.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
            scaling factor.
    """
    # This code assumes batch_dim and num_tokens are flattened
    assert input.ndim == 2
    shape: Union[Tuple[int, int], torch.Size] = input.shape
    # For rocm, the output fp8 dtype is torch.float_e3m3fnuz
    # out_dtype: torch.dtype = torch.float8_e4m3fnuz \
    #     if current_platform.is_rocm() else torch.float8_e4m3fn
    out_dtype = torch.float8_e4m3fn
    if num_token_padding:
        shape = (max(num_token_padding, input.shape[0]), shape[1])
    output = torch.empty(shape, device=input.device, dtype=out_dtype)

    if scale is None:
        if use_per_token_if_dynamic:
            scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
            ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub)
        else:
            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
            ops.dynamic_scaled_fp8_quant(output, input, scale)
    else:
        # num_token_padding not implemented for this case
        assert scale.numel() == 1 or num_token_padding is None
        ops.static_scaled_fp8_quant(output, input, scale)

    return output, scale


# int8
def scaled_int8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    azp: Optional[torch.Tensor] = None,
    symmetric: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """
    Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.

    Args:
        input: The input tensor to be quantized to int8.
        scale: Optional scaling factor for the int8 quantization.
            When not provided, we invoke dynamic-per-token quantization.
        azp: Optional zero-point for the int8 quantization.
            Must be provided for asymmetric quantization if `scale` is provided.
        symmetric: Whether to use symmetric quantization (scale only, azp ignored).

    Returns:
        Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
    """
    output = torch.empty_like(input, dtype=torch.int8)
    if scale is not None:
        # static-per-tensor quantization.
        assert symmetric == (
            azp is None
        ), "azp must only be provided for asymmetric quantization."
        ops.static_scaled_int8_quant(output, input, scale, azp)
        return output, scale, azp

    # dynamic-per-token quantization.
    input_scales = torch.empty(
        (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32
    )
    input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32)
    ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp)
    return output, input_scales, input_azp
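A short hedged usage sketch for the two quantization entry points above, assuming the package is importable as `quantization` and the extension was built for the current CUDA device:

# Dynamic per-token FP8 and symmetric INT8 quantization of a 2D activation.
import torch
import quantization as ops

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")

# Dynamic per-token FP8: the returned scale has shape (num_tokens, 1).
x_fp8, fp8_scales = ops.scaled_fp8_quant(x, use_per_token_if_dynamic=True)
assert x_fp8.dtype == torch.float8_e4m3fn and fp8_scales.shape == (16, 1)

# Dynamic per-token symmetric INT8: azp is None in the symmetric case.
x_i8, i8_scales, azp = ops.scaled_int8_quant(x)
assert x_i8.dtype == torch.int8 and azp is None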
ext-torch/cutlass.py
ADDED
@@ -0,0 +1,75 @@
from typing import Optional

import torch

try:
    from ._ops import ops
except ImportError as e:
    # Fallback for local development.
    try:
        import _quantization

        ops = torch.ops._quantization
    except ImportError:
        raise e


def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
    return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability)


def cutlass_scaled_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
    assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
    assert bias is None or bias.shape[0] == b.shape[1] and bias.dtype == out_dtype

    m = a.shape[0]
    n = b.shape[1]

    # if current_platform.is_rocm():
    #     triton_scaled_mm_module = importlib.import_module(
    #         "vllm.model_executor.layers.quantization.compressed_tensors."
    #         "triton_scaled_mm")
    #     triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
    #     return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

    out = torch.empty((m, n), dtype=out_dtype, device=a.device)

    ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)

    return out


def cutlass_scaled_mm_azp(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    azp_adj: torch.Tensor,
    azp: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    :param azp_adj: In the per-tensor case, this should include the azp.
      Always per-channel.
    :param azp: Only set in the per-token case. Per-token if set.
    """
    assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
    assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
    assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype
    assert azp is None or azp.numel() == a.shape[0]

    m = a.shape[0]
    n = b.shape[1]
    out = torch.empty((m, n), dtype=out_dtype, device=a.device)

    ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias)
    return out
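A hedged sketch of an INT8 W8A8 matmul through the wrapper above: the activation is quantized dynamically per token, the weight is a toy pre-quantized tensor. The column-major weight layout and per-channel `scale_b` shape are assumptions based on how the tests in this PR exercise the kernel, not a documented contract:

import torch
import quantization as ops

m, k, n = 32, 4096, 4096
a = torch.randn(m, k, dtype=torch.float16, device="cuda")
# Toy int8 weight, laid out column-major as the CUTLASS kernel expects.
b_q = torch.randint(-8, 8, (n, k), dtype=torch.int8, device="cuda").t()
scale_b = torch.rand((1, n), dtype=torch.float32, device="cuda") * 0.01

a_q, scale_a, _ = ops.scaled_int8_quant(a)          # per-token scales (m, 1)
out = ops.cutlass_scaled_mm(a_q, b_q, scale_a, scale_b, torch.float16)
assert out.shape == (m, n) and out.dtype == torch.float16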
ext-torch/marlin.py
ADDED
@@ -0,0 +1,208 @@
from typing import TYPE_CHECKING

import torch

# neuron has torch version that doesn't even have impl_abstract
if TYPE_CHECKING:
    def register_fake(fn):
        return lambda name: fn
else:
    try:
        from torch.library import register_fake
    except ImportError:
        from torch.library import impl_abstract as register_fake

try:
    from ._ops import ops, add_op_namespace_prefix
except ImportError as e:
    # Fallback for local development.
    try:
        import _quantization

        ops = torch.ops._quantization

        def add_op_namespace_prefix(op_name: str):
            return f"_quantization::{op_name}"
    except ImportError:
        raise e


from .scalar_type import ScalarType


# fp8 marlin
def fp8_marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    num_bits: int,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    return ops.fp8_marlin_gemm(
        a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k
    )


# gptq_marlin
def gptq_marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    b_zeros: torch.Tensor,
    g_idx: torch.Tensor,
    perm: torch.Tensor,
    workspace: torch.Tensor,
    b_q_type: ScalarType,
    size_m: int,
    size_n: int,
    size_k: int,
    is_k_full: bool,
    has_zp: bool = False,
    use_fp32_reduce: bool = False,
    is_zp_float: bool = False,
) -> torch.Tensor:
    return ops.gptq_marlin_gemm(
        a,
        b_q_weight,
        b_scales,
        b_zeros,
        g_idx,
        perm,
        workspace,
        b_q_type.id,
        size_m,
        size_n,
        size_k,
        is_k_full,
        has_zp,
        use_fp32_reduce,
        is_zp_float,
    )


# gptq_marlin
def gptq_marlin_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
) -> torch.Tensor:
    return ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits)


# gptq_marlin
def awq_marlin_repack(
    b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    return ops.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)


# marlin
def marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    return ops.marlin_gemm(
        a, b_q_weight, b_scales, workspace, size_m, size_n, size_k
    )


# marlin_24
def gptq_marlin_24_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_meta: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    b_q_type: ScalarType,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    return ops.gptq_marlin_24_gemm(
        a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, size_m, size_n, size_k
    )


# qqq ops
def marlin_qqq_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    s_tok: torch.Tensor,
    s_ch: torch.Tensor,
    s_group: torch.Tensor,
    workspace: torch.Tensor,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    return ops.marlin_qqq_gemm(
        a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k
    )


# Fake ops

if hasattr(ops, "gptq_marlin_24_gemm"):
    @register_fake(add_op_namespace_prefix("fp8_marlin_gemm"))
    def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                              b_scales: torch.Tensor, workspace: torch.Tensor,
                              num_bits: int, size_m: torch.SymInt,
                              size_n: torch.SymInt,
                              size_k: torch.SymInt) -> torch.Tensor:
        return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)

    @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm"))
    def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                                  b_meta: torch.Tensor, b_scales: torch.Tensor,
                                  workspace: torch.Tensor,
                                  b_q_type: ScalarType, size_m: torch.SymInt,
                                  size_n: torch.SymInt,
                                  size_k: torch.SymInt) -> torch.Tensor:
        return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

    @register_fake(add_op_namespace_prefix("gptq_marlin_gemm"))
    def _gptq_marlin_gemm_fake(a: torch.Tensor,
                               b_q_weight: torch.Tensor,
                               b_scales: torch.Tensor,
                               b_zeros: torch.Tensor,
                               g_idx: torch.Tensor,
                               perm: torch.Tensor,
                               workspace: torch.Tensor,
                               b_q_type: ScalarType,
                               size_m: torch.SymInt,
                               size_n: torch.SymInt,
                               size_k: torch.SymInt,
                               is_k_full: bool,
                               has_zp: bool = False,
                               use_fp32_reduce: bool = False,
                               is_zp_float: bool = False) -> torch.Tensor:
        return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

    @register_fake(add_op_namespace_prefix("marlin_qqq_gemm"))
    def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                              s_tok: torch.Tensor, s_ch: torch.Tensor,
                              s_group: torch.Tensor, workspace: torch.Tensor,
                              size_m: torch.SymInt, size_n: torch.SymInt,
                              size_k: torch.SymInt) -> torch.Tensor:
        return torch.empty((size_m, size_n),
                           dtype=torch.float16,
                           device=a.device)

    @register_fake(add_op_namespace_prefix("marlin_gemm"))
    def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                          b_scales: torch.Tensor, workspace: torch.Tensor,
                          size_m: torch.SymInt, size_n: torch.SymInt,
                          size_k: torch.SymInt) -> torch.Tensor:
        return torch.empty((size_m, size_n),
                           dtype=torch.float16,
                           device=a.device)
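The `register_fake` block gives each Marlin op a shape-only "fake" implementation, which is what `torch.compile` and `torch.library.opcheck` use to trace the ops without launching CUDA kernels. A minimal hedged sketch of what that buys (hypothetical shapes; nothing is executed until the compiled function is called with real Marlin-packed weights):

import torch
import quantization as ops

def fused_layer(a, b_q, b_scales, workspace):
    # Eagerly this dispatches to the CUDA kernel; while tracing, the
    # registered fake returns an empty (size_m, size_n) tensor so shape
    # and dtype propagation can proceed without running the kernel.
    return ops.marlin_gemm(a, b_q, b_scales, workspace,
                           a.shape[0], b_scales.shape[1], a.shape[1])

compiled_layer = torch.compile(fused_layer)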
ext-torch/scalar_type.py
ADDED
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
import struct
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from enum import Enum
|
5 |
+
from typing import Optional, Union
|
6 |
+
|
7 |
+
|
8 |
+
# Mirrors enum in `core/scalar_type.hpp`
|
9 |
+
class NanRepr(Enum):
|
10 |
+
NONE = 0 # nans are not supported
|
11 |
+
IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s
|
12 |
+
EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s
|
13 |
+
|
14 |
+
|
15 |
+
# This ScalarType class is a parallel implementation of the C++ ScalarType
|
16 |
+
# class found in csrc/core/scalar_type.hpp. These two classes should be kept
|
17 |
+
# in sync until the inductor fully supports custom C++ classes.
|
18 |
+
@dataclass(frozen=True)
|
19 |
+
class ScalarType:
|
20 |
+
"""
|
21 |
+
ScalarType can represent a wide range of floating point and integer
|
22 |
+
types, in particular it can be used to represent sub-byte data types
|
23 |
+
(something that torch.dtype currently does not support). It is also
|
24 |
+
capable of representing types with a bias, i.e.:
|
25 |
+
`stored_value = value + bias`,
|
26 |
+
this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
|
27 |
+
of 8). The implementation for this class can be found in
|
28 |
+
csrc/core/scalar_type.hpp, these type signatures should be kept in sync
|
29 |
+
with that file.
|
30 |
+
"""
|
31 |
+
|
32 |
+
exponent: int
|
33 |
+
"""
|
34 |
+
Number of bits in the exponent if this is a floating point type
|
35 |
+
(zero if this an integer type)
|
36 |
+
"""
|
37 |
+
|
38 |
+
mantissa: int
|
39 |
+
"""
|
40 |
+
Number of bits in the mantissa if this is a floating point type,
|
41 |
+
or the number bits representing an integer excluding the sign bit if
|
42 |
+
this an integer type.
|
43 |
+
"""
|
44 |
+
|
45 |
+
signed: bool
|
46 |
+
"If the type is signed (i.e. has a sign bit)"
|
47 |
+
|
48 |
+
bias: int
|
49 |
+
"""
|
50 |
+
bias used to encode the values in this scalar type
|
51 |
+
(value = stored_value - bias, default 0) for example if we store the
|
52 |
+
type as an unsigned integer with a bias of 128 then the value 0 will be
|
53 |
+
stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
|
54 |
+
"""
|
55 |
+
|
56 |
+
_finite_values_only: bool = False
|
57 |
+
"""
|
58 |
+
Private: if infs are supported, used `has_infs()` instead.
|
59 |
+
"""
|
60 |
+
|
61 |
+
nan_repr: NanRepr = NanRepr.IEEE_754
|
62 |
+
"""
|
63 |
+
How NaNs are represent in this scalar type, returns NanRepr value.
|
64 |
+
(not applicable for integer types)
|
65 |
+
"""
|
66 |
+
|
67 |
+
def _floating_point_max_int(self) -> int:
|
68 |
+
assert (
|
69 |
+
self.mantissa <= 52 and self.exponent <= 11
|
70 |
+
), f"Cannot represent max/min as a double for type {self.__str__()}"
|
71 |
+
|
72 |
+
max_mantissa = (1 << self.mantissa) - 1
|
73 |
+
if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
|
74 |
+
max_mantissa = max_mantissa - 1
|
75 |
+
|
76 |
+
max_exponent = (1 << self.exponent) - 2
|
77 |
+
if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN
|
78 |
+
or self.nan_repr == NanRepr.NONE):
|
79 |
+
assert (
|
80 |
+
self.exponent < 11
|
81 |
+
), f"Cannot represent max/min as a double for type {self.__str__()}"
|
82 |
+
max_exponent = max_exponent + 1
|
83 |
+
|
84 |
+
# adjust the exponent to match that of a double
|
85 |
+
# for now we assume the exponent bias is the standard 2^(e-1) -1, (where
|
86 |
+
# e is the exponent bits), there is some precedent for non-standard
|
87 |
+
# biases, example `float8_e4m3b11fnuz` here:
|
88 |
+
# https://github.com/jax-ml/ml_dtypes but to avoid premature over
|
89 |
+
# complication we are just assuming the standard exponent bias until
|
90 |
+
# there is a need to support non-standard biases
|
91 |
+
exponent_bias = (1 << (self.exponent - 1)) - 1
|
92 |
+
exponent_bias_double = (1 << 10) - 1 # double e = 11
|
93 |
+
|
94 |
+
max_exponent_double = (max_exponent - exponent_bias +
|
95 |
+
exponent_bias_double)
|
96 |
+
|
97 |
+
# shift the mantissa and exponent into the proper positions for an
|
98 |
+
# IEEE double and bitwise-or them together.
|
99 |
+
return (max_mantissa <<
|
100 |
+
(52 - self.mantissa)) | (max_exponent_double << 52)
|
101 |
+
|
102 |
+
def _floating_point_max(self) -> float:
|
103 |
+
double_raw = self._floating_point_max_int()
|
104 |
+
return struct.unpack('!d', struct.pack('!Q', double_raw))[0]
|
105 |
+
|
106 |
+
def _raw_max(self) -> Union[int, float]:
|
107 |
+
if self.is_floating_point():
|
108 |
+
return self._floating_point_max()
|
109 |
+
else:
|
110 |
+
assert (self.size_bits < 64 or self.size_bits == 64
|
111 |
+
and self.is_signed()), "Cannot represent max as an int"
|
112 |
+
return (1 << self.mantissa) - 1
|
113 |
+
|
114 |
+
def _raw_min(self) -> Union[int, float]:
|
115 |
+
if self.is_floating_point():
|
116 |
+
assert self.is_signed(
|
117 |
+
), "We currently assume all floating point types are signed"
|
118 |
+
sign_bit_double = 1 << 63
|
119 |
+
|
120 |
+
max_raw = self._floating_point_max_int()
|
121 |
+
min_raw = max_raw | sign_bit_double
|
122 |
+
return struct.unpack('!d', struct.pack('!Q', min_raw))[0]
|
123 |
+
else:
|
124 |
+
assert (not self.is_signed() or
|
125 |
+
self.size_bits <= 64), "Cannot represent min as a int64_t"
|
126 |
+
|
127 |
+
if self.is_signed():
|
128 |
+
return -(1 << (self.size_bits - 1))
|
129 |
+
else:
|
130 |
+
return 0
|
131 |
+
|
132 |
+
@functools.cached_property
|
133 |
+
def id(self) -> int:
|
134 |
+
"""
|
135 |
+
Convert the ScalarType to an int which can be passed to pytorch custom
|
136 |
+
ops. This layout of the int must be kept in sync with the C++
|
137 |
+
ScalarType's from_id method.
|
138 |
+
"""
|
139 |
+
val = 0
|
140 |
+
offset = 0
|
141 |
+
|
142 |
+
def or_and_advance(member, bit_width):
|
143 |
+
nonlocal val
|
144 |
+
nonlocal offset
|
145 |
+
bit_mask = (1 << bit_width) - 1
|
146 |
+
val = val | (int(member) & bit_mask) << offset
|
147 |
+
offset = offset + bit_width
|
148 |
+
|
149 |
+
or_and_advance(self.exponent, 8)
|
150 |
+
or_and_advance(self.mantissa, 8)
|
151 |
+
or_and_advance(self.signed, 1)
|
152 |
+
or_and_advance(self.bias, 32)
|
153 |
+
or_and_advance(self._finite_values_only, 1)
|
154 |
+
or_and_advance(self.nan_repr.value, 8)
|
155 |
+
|
156 |
+
assert offset <= 64, \
|
157 |
+
f"ScalarType fields too big {offset} to fit into an int64"
|
158 |
+
|
159 |
+
return val
|
160 |
+
|
161 |
+
@property
|
162 |
+
def size_bits(self) -> int:
|
163 |
+
return self.exponent + self.mantissa + int(self.signed)
|
164 |
+
|
165 |
+
def min(self) -> Union[int, float]:
|
166 |
+
"""
|
167 |
+
Min representable value for this scalar type.
|
168 |
+
(accounting for bias if there is one)
|
169 |
+
"""
|
170 |
+
return self._raw_min() - self.bias
|
171 |
+
|
172 |
+
def max(self) -> Union[int, float]:
|
173 |
+
"""
|
174 |
+
Max representable value for this scalar type.
|
175 |
+
(accounting for bias if there is one)
|
176 |
+
"""
|
177 |
+
return self._raw_max() - self.bias
|
178 |
+
|
179 |
+
def is_signed(self) -> bool:
|
180 |
+
"""
|
181 |
+
If the type is signed (i.e. has a sign bit), same as `signed`
|
182 |
+
added for consistency with:
|
183 |
+
https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
|
184 |
+
"""
|
185 |
+
return self.signed
|
186 |
+
|
187 |
+
def is_floating_point(self) -> bool:
|
188 |
+
"If the type is a floating point type"
|
189 |
+
return self.exponent != 0
|
190 |
+
|
191 |
+
def is_integer(self) -> bool:
|
192 |
+
"If the type is an integer type"
|
193 |
+
return self.exponent == 0
|
194 |
+
|
195 |
+
def has_bias(self) -> bool:
|
196 |
+
"If the type has a non-zero bias"
|
197 |
+
return self.bias != 0
|
198 |
+
|
199 |
+
def has_infs(self) -> bool:
|
200 |
+
"If the type is floating point and supports infinity"
|
201 |
+
return not self._finite_values_only
|
202 |
+
|
203 |
+
def has_nans(self) -> bool:
|
204 |
+
return self.nan_repr != NanRepr.NONE.value
|
205 |
+
|
206 |
+
def is_ieee_754(self) -> bool:
|
207 |
+
"""
|
208 |
+
If the type is a floating point type that follows IEEE 754
|
209 |
+
conventions
|
210 |
+
"""
|
211 |
+
return self.nan_repr == NanRepr.IEEE_754.value and \
|
212 |
+
not self._finite_values_only
|
213 |
+
|
214 |
+
def __str__(self) -> str:
|
215 |
+
"""
|
216 |
+
naming generally follows: https://github.com/jax-ml/ml_dtypes
|
217 |
+
for floating point types (leading f) the scheme is:
|
218 |
+
`float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
|
219 |
+
flags:
|
220 |
+
- no-flags: means it follows IEEE 754 conventions
|
221 |
+
- f: means finite values only (no infinities)
|
222 |
+
- n: means nans are supported (non-standard encoding)
|
223 |
+
for integer types the scheme is:
|
224 |
+
`[u]int<size_bits>[b<bias>]`
|
225 |
+
- if bias is not present it means its zero
|
226 |
+
"""
|
227 |
+
if self.is_floating_point():
|
228 |
+
ret = "float" + str(self.size_bits) + "_e" + str(
|
229 |
+
self.exponent) + "m" + str(self.mantissa)
|
230 |
+
|
231 |
+
if not self.is_ieee_754():
|
232 |
+
if self._finite_values_only:
|
233 |
+
ret = ret + "f"
|
234 |
+
if self.nan_repr != NanRepr.NONE:
|
235 |
+
ret = ret + "n"
|
236 |
+
|
237 |
+
return ret
|
238 |
+
else:
|
239 |
+
ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
|
240 |
+
if self.has_bias():
|
241 |
+
ret = ret + "b" + str(self.bias)
|
242 |
+
return ret
|
243 |
+
|
244 |
+
def __repr__(self) -> str:
|
245 |
+
return "ScalarType." + self.__str__()
|
246 |
+
|
247 |
+
# __len__ needs to be defined (and has to throw TypeError) for pytorch's
|
248 |
+
# opcheck to work.
|
249 |
+
def __len__(self) -> int:
|
250 |
+
raise TypeError
|
251 |
+
|
252 |
+
#
|
253 |
+
# Convenience Constructors
|
254 |
+
#
|
255 |
+
|
256 |
+
@classmethod
|
257 |
+
def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
|
258 |
+
"Create a signed integer scalar type (size_bits includes sign-bit)."
|
259 |
+
ret = cls(0, size_bits - 1, True, bias if bias else 0)
|
260 |
+
ret.id # noqa B018: make sure the id is cached
|
261 |
+
return ret
|
262 |
+
|
263 |
+
@classmethod
|
264 |
+
def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
|
265 |
+
"""Create a unsigned integer scalar type."""
|
266 |
+
ret = cls(0, size_bits, False, bias if bias else 0)
|
267 |
+
ret.id # noqa B018: make sure the id is cached
|
268 |
+
return ret
|
269 |
+
|
270 |
+
@classmethod
|
271 |
+
def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
|
272 |
+
"""
|
273 |
+
Create a standard floating point type
|
274 |
+
(i.e. follows IEEE 754 conventions).
|
275 |
+
"""
|
276 |
+
assert (mantissa > 0 and exponent > 0)
|
277 |
+
ret = cls(exponent, mantissa, True, 0)
|
278 |
+
ret.id # noqa B018: make sure the id is cached
|
279 |
+
return ret
|
280 |
+
|
281 |
+
@classmethod
|
282 |
+
def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
|
283 |
+
nan_repr: NanRepr) -> 'ScalarType':
|
284 |
+
"""
|
285 |
+
Create a non-standard floating point type
|
286 |
+
(i.e. does not follow IEEE 754 conventions).
|
287 |
+
"""
|
288 |
+
assert (mantissa > 0 and exponent > 0)
|
289 |
+
assert (nan_repr != NanRepr.IEEE_754), (
|
290 |
+
"use `float_IEEE754` constructor for floating point types that "
|
291 |
+
"follow IEEE 754 conventions")
|
292 |
+
ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
|
293 |
+
ret.id # noqa B018: make sure the id is cached
|
294 |
+
return ret
|
295 |
+
|
296 |
+
|
297 |
+
# naming generally follows: https://github.com/jax-ml/ml_dtypes
|
298 |
+
# for floating point types (leading f) the scheme is:
|
299 |
+
# `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
|
300 |
+
# flags:
|
301 |
+
# - no-flags: means it follows IEEE 754 conventions
|
302 |
+
# - f: means finite values only (no infinities)
|
303 |
+
# - n: means nans are supported (non-standard encoding)
|
304 |
+
# for integer types the scheme is:
|
305 |
+
# `[u]int<size_bits>[b<bias>]`
|
306 |
+
# - if bias is not present it means its zero
|
307 |
+
|
308 |
+
|
309 |
+
class scalar_types:
|
310 |
+
int4 = ScalarType.int_(4, None)
|
311 |
+
uint4 = ScalarType.uint(4, None)
|
312 |
+
int8 = ScalarType.int_(8, None)
|
313 |
+
uint8 = ScalarType.uint(8, None)
|
314 |
+
float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
|
315 |
+
float8_e5m2 = ScalarType.float_IEEE754(5, 2)
|
316 |
+
float16_e8m7 = ScalarType.float_IEEE754(8, 7)
|
317 |
+
float16_e5m10 = ScalarType.float_IEEE754(5, 10)
|
318 |
+
|
319 |
+
# fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
|
320 |
+
float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)
|
321 |
+
|
322 |
+
# "gptq" types
|
323 |
+
uint2b2 = ScalarType.uint(2, 2)
|
324 |
+
uint3b4 = ScalarType.uint(3, 4)
|
325 |
+
uint4b8 = ScalarType.uint(4, 8)
|
326 |
+
uint8b128 = ScalarType.uint(8, 128)
|
327 |
+
|
328 |
+
# colloquial names
|
329 |
+
bfloat16 = float16_e8m7
|
330 |
+
float16 = float16_e5m10
|
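A small hedged sketch of how the biased GPTQ types above behave: `uint4b8` stores 4-bit unsigned values with a bias of 8, so the representable logical range is [-8, 7], and `.id` is the packed int64 that the torch bindings accept for their `b_q_type` arguments (import path follows the `quantization.scalar_type` usage in the utils below):

from quantization.scalar_type import scalar_types

t = scalar_types.uint4b8
assert t.size_bits == 4 and t.bias == 8
assert (t.min(), t.max()) == (-8, 7)
assert str(t) == "uint4b8"

# exponent, mantissa, sign, bias, finite-only flag and nan representation
# are packed into a single int64 for the custom-op boundary.
print(hex(t.id))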
ext-torch/torch_binding.cpp
CHANGED
@@ -65,16 +65,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
       "Tensor! workspace, int num_bits, SymInt size_m, SymInt size_n, "
       "SymInt size_k) -> Tensor");
-  ops.impl("fp8_marlin_gemm", &fp8_marlin_gemm);
 
   // awq_marlin repack from AWQ.
   ops.def(
       "awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
       "SymInt size_n, int num_bits) -> Tensor");
-  ops.impl("awq_marlin_repack", &awq_marlin_repack);
 
   // gptq_marlin Optimized Quantized GEMM for GPTQ.
-  ops.impl("gptq_marlin_gemm", &gptq_marlin_gemm);
   ops.def(
       "gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
       "Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
@@ -86,7 +83,36 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def(
       "gptq_marlin_repack(Tensor b_q_weight, Tensor perm, "
       "SymInt size_k, SymInt size_n, int num_bits) -> Tensor");
+
+  // Marlin (Dense) Optimized Quantized GEMM for GPTQ.
+  ops.def(
+      "marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
+      "Tensor! workspace, SymInt size_m, SymInt size_n, SymInt size_k) -> "
+      "Tensor");
+
+  // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
+  ops.def(
+      "gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, "
+      "Tensor b_scales, Tensor workspace, "
+      "int b_q_type, "
+      "SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor");
+
+  // marlin_qqq_gemm for QQQ.
+  ops.def(
+      "marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "
+      "Tensor s_tok, Tensor s_ch, Tensor s_group, "
+      "Tensor! workspace, SymInt size_m, SymInt size_n, "
+      "SymInt size_k) -> Tensor");
+}
+
+TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, ops) {
+  ops.impl("awq_marlin_repack", &awq_marlin_repack);
+  ops.impl("fp8_marlin_gemm", &fp8_marlin_gemm);
+  ops.impl("gptq_marlin_24_gemm", &gptq_marlin_24_gemm);
+  ops.impl("gptq_marlin_gemm", &gptq_marlin_gemm);
   ops.impl("gptq_marlin_repack", &gptq_marlin_repack);
+  ops.impl("marlin_gemm", &marlin_gemm);
+  ops.impl("marlin_qqq_gemm", &marlin_qqq_gemm);
 }
 
 TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, ops) {
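Moving every `ops.impl(...)` out of the schema block and into a `TORCH_LIBRARY_IMPL_EXPAND(..., CUDA, ops)` block registers the kernels under the CUDA dispatch key while the `ops.def(...)` schemas stay device-agnostic; the Python-side fakes above then cover tracing. A hedged check from Python, assuming the extension registers under the `_quantization` namespace used by the local-development fallbacks:

import torch
import quantization  # loads the compiled library and registers the ops

# The dispatcher resolves the schema and picks the CUDA impl at call time.
print(torch.ops._quantization.marlin_gemm)
print(torch.ops._quantization.gptq_marlin_24_gemm)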
ext-torch/torch_binding.h
CHANGED
@@ -74,3 +74,26 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
 torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
                                       torch::Tensor& perm, c10::SymInt size_k,
                                       c10::SymInt size_n, int64_t num_bits);
+
+// Marlin
+torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
+                          torch::Tensor& b_scales, torch::Tensor& workspace,
+                          int64_t size_m, int64_t size_n, int64_t size_k);
+
+torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
+                                  torch::Tensor& b_meta,
+                                  torch::Tensor& b_scales,
+                                  torch::Tensor& workspace,
+                                  vllm::ScalarTypeId const b_q_type_id,
+                                  int64_t size_m, int64_t size_n,
+                                  int64_t size_k);
+
+torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
+                              torch::Tensor const& b_q_weight,
+                              torch::Tensor const& s_tok,
+                              torch::Tensor const& s_ch,
+                              torch::Tensor const& s_group,
+                              torch::Tensor& workspace, int64_t size_m,
+                              int64_t size_n, int64_t size_k);
ext-torch/utils/marlin_utils.py
ADDED
@@ -0,0 +1,391 @@
from typing import List, Optional, Tuple

import numpy
import torch

import quantization as ops
from quantization.scalar_type import ScalarType, scalar_types

from .quant_utils import pack_cols, unpack_cols

GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16

GPTQ_MARLIN_24_TILE = 16
GPTQ_MARLIN_24_MIN_THREAD_N = 128
GPTQ_MARLIN_24_MIN_THREAD_K = 128
GPTQ_MARLIN_24_MAX_PARALLEL = 64

GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]

MARLIN_QQQ_TILE = 16
MARLIN_QQQ_MIN_THREAD_N = 64
MARLIN_QQQ_MIN_THREAD_K = 128
MARLIN_QQQ_MAX_PARALLEL = 16

MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]
MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128]
MARLIN_QQQ_SUPPORTED_SYM = [True]

MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]

# In case there is a performance issue with Marlin, the variable below can be
# changed to False, which allows Marlin to perform global reductions in fp16
# precision (instead of fp32), and therefore, save on some memory movements.
USE_FP32_REDUCE_DEFAULT = True


# For binary size and compile time, we don't support the same types for with and
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
# TODO: we may want to move this into the C++ so its closer to the actual impl
def query_marlin_supported_quant_types(
    has_zp: bool, device_capability: Optional[int] = None
):
    if device_capability is None:
        capability_tuple = torch.cuda.get_device_capability()
        device_capability = capability_tuple[0] * 10 + capability_tuple[1]

    if device_capability < 80:
        return []

    if has_zp:
        # AWQ style, unsigned + runtime zero-point
        return [scalar_types.uint4, scalar_types.uint8]
    else:
        # GPTQ style, unsigned + symmetric bias
        # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able
        # to add `scalar_types.float8_e4m3fn` here
        return [scalar_types.uint4b8, scalar_types.uint8b128]


def _check_marlin_supported(
    quant_type: ScalarType,
    group_size: Optional[int],
    has_zp: bool,
    device_capability: Optional[int] = None,
) -> Tuple[bool, Optional[str]]:

    if device_capability is None:
        capability_tuple = torch.cuda.get_device_capability()
        device_capability = capability_tuple[0] * 10 + capability_tuple[1]

    supported_types = query_marlin_supported_quant_types(has_zp, device_capability)

    if quant_type not in supported_types:
        return (
            False,
            f"Marlin does not support weight_bits = {quant_type}. "
            f"Only types = {supported_types} "
            f"are supported (for group_size = {group_size}, "
            f"device_capability = {device_capability}, zp = {has_zp}).",
        )
    if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
        return (
            False,
            f"Marlin does not support group_size = {group_size}. "
            f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
            "are supported.",
        )

    return True, None


def check_marlin_supported(
    quant_type: ScalarType,
    group_size: int,
    has_zp: bool = False,
    device_capability: Optional[int] = None,
) -> bool:
    cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability)
    return cond


def verify_marlin_supported(
    quant_type: ScalarType, group_size: int, has_zp: bool = False
) -> None:
    cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp)
    if not cond:
        assert err_msg is not None
        raise ValueError(err_msg)


def verify_marlin_supports_shape(
    output_size_per_partition: int,
    input_size_per_partition: int,
    input_size: int,
    group_size: int,
) -> None:

    # Validate output_size_per_partition
    if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
        raise ValueError(
            f"Weight output_size_per_partition = "
            f"{output_size_per_partition} is not divisible by "
            f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
            "Consider reducing tensor_parallel_size or running "
            "with --quantization gptq."
        )

    # Validate input_size_per_partition
    if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
        raise ValueError(
            f"Weight input_size_per_partition = "
            f"{input_size_per_partition} is not divisible "
            f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
            "Consider reducing tensor_parallel_size or running "
            "with --quantization gptq."
        )

    if group_size < input_size and input_size_per_partition % group_size != 0:
        raise ValueError(
            f"Weight input_size_per_partition = {input_size_per_partition}"
            f" is not divisible by group_size = {group_size}."
            "Consider reducing tensor_parallel_size or running "
            "with --quantization gptq."
        )


def check_marlin_supports_shape(
    output_size_per_partition: int,
    input_size_per_partition: int,
    input_size: int,
    group_size: int,
) -> Tuple[bool, Optional[str]]:
    try:
        verify_marlin_supports_shape(
            output_size_per_partition, input_size_per_partition, input_size, group_size
        )
    except ValueError as e:
        return False, e.__str__()
    return True, None


def marlin_make_workspace(
    output_size_per_partition: int, device: torch.device
) -> torch.Tensor:
    max_workspace_size = (
        output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N
    ) * GPTQ_MARLIN_MAX_PARALLEL

    return torch.zeros(
        max_workspace_size, dtype=torch.int, device=device, requires_grad=False
    )


def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
    return (not act_order) or (act_order and not is_row_parallel)


def marlin_repeat_scales_on_all_ranks(
    act_order: bool, group_size: int, is_row_parallel: bool
) -> bool:
    # Need to repeat scales on every rank if act_ordering or
    # channelwise and RowParallelLinear
    is_channelwise = group_size == -1
    return act_order or (is_channelwise and is_row_parallel)


def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
    return torch.nn.Parameter(
        torch.empty(0, dtype=torch.int, device=device), requires_grad=False
    )


def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
    return torch.nn.Parameter(
        torch.empty(0, dtype=torch.int, device=device), requires_grad=False
    )


def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
    return g_idx[g_idx_sort_indices], g_idx_sort_indices


def get_scale_perms():
    scale_perm: List[int] = []
    for i in range(8):
        scale_perm.extend([i + 8 * j for j in range(8)])
    scale_perm_single: List[int] = []
    for i in range(4):
        scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
    return scale_perm, scale_perm_single


def marlin_permute_scales(
    s: torch.Tensor, size_k: int, size_n: int, group_size: int
) -> torch.Tensor:

    scale_perm, scale_perm_single = get_scale_perms()
    if group_size < size_k and group_size != -1:
        s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
    else:
        s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
    s = s.reshape((-1, size_n)).contiguous()

    return s


def marlin_moe_permute_scales(
    s: torch.Tensor,
    size_k: int,
    size_n: int,
    group_size: int,
):
    num_experts = s.shape[0]
    output = torch.empty(
        (num_experts, s.shape[1], s.shape[2]),
        device=s.device,
        dtype=s.dtype,
    )

    for e in range(num_experts):
        output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size)
    return output


def marlin_zero_points(
    zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    # Permute zero-points in a similar way to scales, but do not use the
    # "single" permutation, since zero-points are applied on every MMA
    scale_perm, _ = get_scale_perms()
    zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm]

    # Interleave column dim (for the dequantize code) and pack it to int32
    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
    zp = zp.reshape((-1, size_n)).contiguous()
    zp = pack_cols(zp, num_bits, size_k, size_n)

    return zp


def awq_to_marlin_zero_points(
    q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    # AWQ zero-points are quantized and packed on the column dim.
    # In addition, the values are permuted based on dequantizer.
    # Here we undo both of these, and then apply marlin permutation
    # and pack it back.
    q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)

    # Undo interleaving (use argsort(..) to get inverse perm)
    if num_bits == 4:
        undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7]))
    elif num_bits == 8:
        undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3]))
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
    q_zp = q_zp.reshape((-1, size_n)).contiguous()

    marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
    return marlin_zp


def moe_awq_to_marlin_zero_points(
    q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
):
    num_experts = q_zp_packed.shape[0]
    output = torch.empty(
        (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]),
        device=q_zp_packed.device,
        dtype=q_zp_packed.dtype,
    )
    for e in range(num_experts):
        output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits)
    return output


def apply_gptq_marlin_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    weight_zp: torch.Tensor,
    g_idx: torch.Tensor,
    g_idx_sort_indices: torch.Tensor,
    workspace: torch.Tensor,
    wtype: ScalarType,
    output_size_per_partition: int,
    input_size_per_partition: int,
    is_k_full: bool,
    bias: Optional[torch.Tensor] = None,
    use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
) -> torch.Tensor:
    reshaped_x = input.reshape(-1, input.shape[-1])
    out_shape = input.shape[:-1] + (output_size_per_partition,)

    output = ops.gptq_marlin_gemm(
|
330 |
+
reshaped_x,
|
331 |
+
weight,
|
332 |
+
weight_scale,
|
333 |
+
weight_zp,
|
334 |
+
g_idx,
|
335 |
+
g_idx_sort_indices,
|
336 |
+
workspace,
|
337 |
+
wtype,
|
338 |
+
size_m=reshaped_x.shape[0],
|
339 |
+
size_n=output_size_per_partition,
|
340 |
+
size_k=input_size_per_partition,
|
341 |
+
is_k_full=is_k_full,
|
342 |
+
has_zp=False,
|
343 |
+
use_fp32_reduce=use_fp32_reduce,
|
344 |
+
is_zp_float=False,
|
345 |
+
)
|
346 |
+
|
347 |
+
if bias is not None:
|
348 |
+
output.add_(bias) # In-place add
|
349 |
+
|
350 |
+
return output.reshape(out_shape)
|
351 |
+
|
352 |
+
|
353 |
+
def apply_awq_marlin_linear(
|
354 |
+
input: torch.Tensor,
|
355 |
+
weight: torch.Tensor,
|
356 |
+
weight_scale: torch.Tensor,
|
357 |
+
weight_zp: torch.Tensor,
|
358 |
+
g_idx: torch.Tensor,
|
359 |
+
g_idx_sort_indices: torch.Tensor,
|
360 |
+
workspace: torch.Tensor,
|
361 |
+
quant_type: ScalarType,
|
362 |
+
output_size_per_partition: int,
|
363 |
+
input_size_per_partition: int,
|
364 |
+
bias: Optional[torch.Tensor] = None,
|
365 |
+
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
|
366 |
+
) -> torch.Tensor:
|
367 |
+
reshaped_x = input.reshape(-1, input.shape[-1])
|
368 |
+
out_shape = input.shape[:-1] + (output_size_per_partition,)
|
369 |
+
|
370 |
+
output = ops.gptq_marlin_gemm(
|
371 |
+
reshaped_x,
|
372 |
+
weight,
|
373 |
+
weight_scale,
|
374 |
+
weight_zp,
|
375 |
+
g_idx,
|
376 |
+
g_idx_sort_indices,
|
377 |
+
workspace,
|
378 |
+
quant_type,
|
379 |
+
size_m=reshaped_x.shape[0],
|
380 |
+
size_n=output_size_per_partition,
|
381 |
+
size_k=input_size_per_partition,
|
382 |
+
is_k_full=True,
|
383 |
+
has_zp=True,
|
384 |
+
use_fp32_reduce=use_fp32_reduce,
|
385 |
+
is_zp_float=False,
|
386 |
+
)
|
387 |
+
|
388 |
+
if bias is not None:
|
389 |
+
output.add_(bias) # In-place add
|
390 |
+
|
391 |
+
return output.reshape(out_shape)
|
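The helpers above are pure tensor reshuffling, so they can be exercised without the CUDA extension. A minimal sketch (assuming only torch is installed; the values are toy data, not real GPTQ metadata) of what marlin_sort_g_idx computes for an act_order checkpoint:

import torch

# Toy group indices, as a GPTQ act_order checkpoint might store them.
g_idx = torch.tensor([2, 0, 1, 0, 2, 1], dtype=torch.int)

# Same two steps as marlin_sort_g_idx above.
g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
sorted_g_idx = g_idx[g_idx_sort_indices]

print(sorted_g_idx.tolist())        # [0, 0, 1, 1, 2, 2] -> group ids now increasing
print(g_idx_sort_indices.tolist())  # row permutation later handed to the kernel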
ext-torch/utils/marlin_utils_fp8.py
ADDED
@@ -0,0 +1,100 @@
from typing import Optional

import torch

import quantization as ops

from .marlin_utils import marlin_make_workspace, marlin_permute_scales


def is_fp8_marlin_supported():
    capability = torch.cuda.get_device_capability()
    capability = capability[0] * 10 + capability[1]
    return capability >= 80


def apply_fp8_marlin_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    workspace: torch.Tensor,
    size_n: int,
    size_k: int,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    # For GPUs that lack FP8 hardware support, we can leverage the
    # Marlin kernel for fast weight-only FP8 quantization

    reshaped_x = input.reshape(-1, input.shape[-1])
    out_shape = input.shape[:-1] + (size_n,)

    output = ops.fp8_marlin_gemm(
        a=reshaped_x,
        b_q_weight=weight,
        b_scales=weight_scale,
        workspace=workspace,
        num_bits=8,
        size_m=reshaped_x.shape[0],
        size_n=size_n,
        size_k=size_k,
    )

    if bias is not None:
        output.add_(bias)  # In-place add

    return output.reshape(out_shape)


def prepare_fp8_layer_for_marlin(
    layer: torch.nn.Module, strategy: str = "tensor"
) -> None:
    part_size_n = layer.output_size_per_partition
    part_size_k = layer.input_size_per_partition

    device = layer.weight.device

    # WORKSPACE
    layer.workspace = marlin_make_workspace(part_size_n, device)

    # WEIGHT
    # Repack weights to marlin format
    marlin_qweight = ops.gptq_marlin_repack(
        b_q_weight=pack_fp8_to_int32(layer.weight),
        perm=torch.empty(0, dtype=torch.int, device=device),
        size_k=part_size_k,
        size_n=part_size_n,
        num_bits=8,
    )
    layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)

    # WEIGHT SCALES
    scales = layer.weight_scale.to(layer.orig_dtype)
    # Permute scales
    marlin_scales = marlin_permute_scales(
        s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1
    )
    layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)


def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
    """
    Repack FP8 weights to gptq format (packed int32 elements)
    """
    assert fp8_tensor.dtype == torch.float8_e4m3fn
    assert fp8_tensor.shape[0] % 4 == 0

    # Reshape to prepare for packing
    reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])

    # Convert fp8 to uint8 (byte) representation
    byte_tensor = reshaped.view(torch.uint8)

    # Pack 4 uint8 values into one int32
    packed = (
        byte_tensor[:, 0].to(torch.int32)
        | (byte_tensor[:, 1].to(torch.int32) << 8)
        | (byte_tensor[:, 2].to(torch.int32) << 16)
        | (byte_tensor[:, 3].to(torch.int32) << 24)
    )

    return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous()
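pack_fp8_to_int32 simply packs four consecutive rows of FP8 bytes into one int32 row, little-endian. A minimal sketch of the same bit layout (assuming only torch; plain uint8 values stand in for float8_e4m3fn so it runs on any build):

import torch

# Four "fp8" bytes from four consecutive rows of one column.
byte_tensor = torch.tensor([[0x01], [0x02], [0x03], [0x04]], dtype=torch.uint8)
reshaped = byte_tensor.reshape(-1, 4, 1)

packed = (
    reshaped[:, 0].to(torch.int32)
    | (reshaped[:, 1].to(torch.int32) << 8)
    | (reshaped[:, 2].to(torch.int32) << 16)
    | (reshaped[:, 3].to(torch.int32) << 24)
)
print(hex(packed.item()))  # 0x4030201 -> row 0 lands in the lowest byte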
ext-torch/utils/marlin_utils_test.py
ADDED
@@ -0,0 +1,162 @@
"""Utility functions used for tests and benchmarks"""

from typing import List, Optional

import numpy as np
import torch

from quantization.scalar_type import ScalarType

from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, marlin_zero_points
from .quant_utils import (
    get_pack_factor,
    gptq_quantize_weights,
    quantize_weights,
    sort_weights,
)


class MarlinWorkspace:

    def __init__(self, out_features, min_thread_n, max_parallel):
        assert (
            out_features % min_thread_n == 0
        ), "out_features = {} is undivisible by min_thread_n = {}".format(
            out_features, min_thread_n
        )

        max_workspace_size = (out_features // min_thread_n) * max_parallel

        self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda")


def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
    assert q_w.shape == (size_k, size_n)
    assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
    assert size_n % tile == 0, f"size_n = {size_n}, tile = {tile}"

    # Permute weights to 16x64 marlin tiles
    q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
    q_w = q_w.permute((0, 2, 1, 3))
    q_w = q_w.reshape((size_k // tile, size_n * tile))

    q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)

    return q_w


def marlin_weights(q_w, size_k, size_n, num_bits, perm):
    # Permute
    q_w = marlin_permute_weights(q_w, size_k, size_n, perm)

    # Pack
    pack_factor = get_pack_factor(num_bits)
    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(np.uint32)

    q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32)
    for i in range(pack_factor):
        q_packed |= q_w[:, i::pack_factor] << num_bits * i

    q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device)

    return q_packed


def get_weight_perm(num_bits: int):
    perm_list: List[int] = []
    for i in range(32):
        perm1: List[int] = []
        col = i // 4
        for block in [0, 1]:
            for row in [
                2 * (i % 4),
                2 * (i % 4) + 1,
                2 * (i % 4 + 4),
                2 * (i % 4 + 4) + 1,
            ]:
                perm1.append(16 * row + col + 8 * block)
        for j in range(4):
            perm_list.extend([p + 256 * j for p in perm1])

    perm = np.array(perm_list)

    if num_bits == 4:
        interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = np.array([0, 2, 1, 3])
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    perm = torch.from_numpy(perm)
    return perm


def marlin_quantize(
    w: torch.Tensor,
    quant_type: ScalarType,
    group_size: int,
    act_order: bool,
    test_perm: Optional[torch.Tensor] = None,
):
    size_k, size_n = w.shape
    num_bits = quant_type.size_bits

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Quantize (and apply act_order if provided)
    w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
        w, quant_type, group_size, act_order, test_perm
    )

    # For act_order, sort the "weights" and "g_idx" so that group ids are
    # increasing
    sort_indices = torch.empty(0, dtype=torch.int, device=w.device)
    if act_order:
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)

    # Reformat to marlin
    weight_perm = get_weight_perm(num_bits)
    marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)

    # Create result
    res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
    for i in range(len(res_list)):
        res_list[i] = res_list[i].to(w.device)

    return res_list


def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int):
    size_k, size_n = w.shape

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Detect num groups
    assert size_k % group_size == 0
    num_groups = size_k // group_size

    # Quantize with zp
    w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True)

    # Reformat to marlin
    weight_perm = get_weight_perm(quant_type.size_bits)
    marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm)
    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
    marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits)

    # Create result
    res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp]
    for i in range(len(res_list)):
        res_list[i] = res_list[i].to(w.device)

    return res_list
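marlin_weights above does the 16x16 tile permutation first and then a column-wise nibble pack. A minimal sketch of just the packing step (assuming only numpy; the 1x8 toy row stands in for an already-permuted row of 4-bit values):

import numpy as np

num_bits = 4
pack_factor = 32 // num_bits                        # 8 nibbles per uint32 (get_pack_factor)
q_w = np.arange(8, dtype=np.uint32).reshape(1, 8)   # toy already-permuted 4-bit values

q_packed = np.zeros((1, 1), dtype=np.uint32)
for i in range(pack_factor):
    q_packed |= q_w[:, i::pack_factor] << num_bits * i

print(hex(int(q_packed[0, 0])))                     # 0x76543210 -> value j sits in nibble j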
ext-torch/utils/marlin_utils_test_24.py
ADDED
@@ -0,0 +1,473 @@
"""Utility functions used for tests and benchmarks"""

import random
from typing import List

import numpy
import torch

from quantization.scalar_type import ScalarType

from .marlin_utils_test import marlin_weights
from .quant_utils import gptq_quantize_weights


# This is PyTorch implementation of main part of reorder_meta()
# function, from tools/util/include/cutlass/util/host_reorder.h file
# of CUTLASS source tree. Furthermore, CUTLASS template for sparse
# GEMM decides upon layout of this matrix, and at the moment for the
# sparse GEMM executed on tensor cores, this is layout described by
# ColumnMajorInterleaved<2> data structure, in
# include/cutlass/layout/matrix.h of CUTLASS source tree. The
# reordering of meta matrix into meta_reordered matrix calculated
# according to these segments of CUTLASS code is re-implemented here.
# Note that this calculation produces offsets for scattering metadata
# matrix elements into reordered metadata matrix elements (or,
# equivalently, for gathering reordered metadata matrix element back
# into metadata matrix elements).
def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device):
    dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)
    dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)

    # Reorder the rows, then swizzle the 2x2 blocks.
    group_x = 64
    group_y = 32 if meta_dtype.itemsize == 2 else 16

    dst_rows = (
        dst_rows // group_x * group_x
        + (dst_rows % 2) * 2
        + (dst_rows % 8) // 4
        + ((dst_rows % group_y) % 4) // 2 * 32
        + ((dst_rows % group_x) // 8) * 4
    )

    topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8)
    bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8)
    dst_rows += topright - bottomleft
    dst_cols -= topright - bottomleft

    # Assumed that meta tensor is to be stored in CUTLASS
    # InterleavedColumnMajor layout, and reverse engineered
    # corresponding code to store values into this tensor.
    interleave = 2
    cols_maj = dst_cols // interleave
    cols_min = dst_cols % interleave
    return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1)


# This function converts dense matrix into sparse semi-structured
# representation, producing "compressed" matrix, in the layout used by
# CUTLASS backend, and corresponding metadata matrix.
def sparse_semi_structured_from_dense_cutlass(dense):
    if dense.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor"  # noqa: E501
        )

    m, k = dense.shape
    device = dense.device

    meta_dtype = torch.int8
    if dense.dtype == torch.int8:
        meta_dtype = torch.int32
    elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]:
        meta_dtype = torch.int16
    else:
        raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
    if quadbits_per_meta_elem not in (4, 8):
        raise RuntimeError("Invalid number of elements per meta element calculated")

    if meta_dtype == torch.int32:
        if m % 16 != 0:
            raise RuntimeError(
                f"Number of rows of dense matrix {m} must be divisible by 16"
            )
    else:
        if m % 32 != 0:
            raise RuntimeError(
                f"Number of rows of dense matrix {m} must be divisible by 32"
            )
    if k % (4 * quadbits_per_meta_elem) != 0:
        raise RuntimeError(
            f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}"  # noqa: E501
        )

    if dense.dtype != torch.float:
        ksparse = 4
        dense_4 = dense.view(-1, k // ksparse, ksparse)
        m0, m1, m2, m3 = (dense_4 != 0).unbind(-1)
    else:
        ksparse = 2
        dense_2 = dense.view(-1, k // ksparse, ksparse)
        m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1)
    meta_ncols = k // (ksparse * quadbits_per_meta_elem)

    # Encoding quadruples of True/False values as follows:
    # [True, True, False, False] -> 0b0100
    # [True, False, True, False] -> 0b1000
    # [False, True, True, False] -> 0b1001
    # [True, False, False, True ] -> 0b1100
    # [False, True, False, True ] -> 0b1101
    # [False, False, True, True ] -> 0b1110
    # Thus, lower two bits in the encoding are index of the True value
    # at the lowest index in the quadruple, and the higher two bits in
    # the encoding are index of the other True value in the quadruple.
    # In case there are less than two True values, than False value or
    # values at some index or indices are considered True for the
    # encoding. In case there are more than two True values, then the
    # excess True value(s) at some indices are considered False for
    # the encoding. The exact encodings used for these cases are as
    # follows:
    # [False, False, False, False] -> 0b1110
    # [False, False, False, True ] -> 0b1110
    # [False, False, True, False] -> 0b1110
    # [False, True, False, False] -> 0b1001
    # [False, True, True, True ] -> 0b1101
    # [True, False, False, False] -> 0b1000
    # [True, False, True, True ] -> 0b1100
    # [True, True, False, True ] -> 0b0100
    # [True, True, True, False] -> 0b0100
    # [True, True, True, True ] -> 0b0100
    # These particular encodings are chosen, with the help of Espresso
    # logic minimizer software, for the purpose of minimization of
    # corresponding Boolean functions, that translate non-zero flags
    # into encoding bits. Note also possible choices for the first
    # and last of these encodings were limited only to (0b0100,
    # 0b1110), in order to produce valid encodings for 1:2 sparsity
    # case.

    expr0 = m0 & m1
    expr1 = ~m0 & m1
    expr2 = ~m0 & ~m1
    bit0 = expr1
    bit1 = expr2
    bit2 = expr0 | expr2 | m3
    bit3 = expr1 | ~m1
    idxs0 = bit0 | (bit1.to(torch.int64) << 1)
    idxs1 = bit2 | (bit3.to(torch.int64) << 1)

    if dense.dtype != torch.float:
        sparse0 = dense_4.gather(
            -1, idxs0.unsqueeze(-1)
        )  # type: ignore[possibly-undefined]
        sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
        sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
    else:
        sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(
            m, k // 2
        )  # type: ignore[possibly-undefined]

    meta_4 = idxs0 | (idxs1 << 2)
    meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)

    if quadbits_per_meta_elem == 4:
        meta = (
            meta_n[:, :, 0]
            | (meta_n[:, :, 1] << 4)
            | (meta_n[:, :, 2] << 8)
            | (meta_n[:, :, 3] << 12)
        )
    elif quadbits_per_meta_elem == 8:
        meta = (
            meta_n[:, :, 0]
            | (meta_n[:, :, 1] << 4)
            | (meta_n[:, :, 2] << 8)
            | (meta_n[:, :, 3] << 12)
            | (meta_n[:, :, 4] << 16)
            | (meta_n[:, :, 5] << 20)
            | (meta_n[:, :, 6] << 24)
            | (meta_n[:, :, 7] << 28)
        )

    # Reorder meta tensor elements.
    meta_reordered = meta.new_empty(
        (m * meta_ncols,)
    )  # type: ignore[possibly-undefined]
    meta_offsets = _calculate_meta_reordering_scatter_offsets(
        m, meta_ncols, meta_dtype, device
    )
    meta_reordered.scatter_(0, meta_offsets, meta.view(-1))

    return (sparse, meta_reordered.view(m, meta_ncols))


# This function performs reverse of the function above - it
# reconstructs dense matrix from a pair of "compressed" matrix, given
# in the layout used by CUTLASS backend, and accompanying metadata
# matrix.
def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
    if sparse.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor"  # noqa: E501
        )

    m, k = sparse.shape
    device = sparse.device

    if meta_reordered.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor"  # noqa: E501
        )
    if meta_reordered.device != device:
        raise RuntimeError(
            f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device"  # noqa: E501
        )

    meta_dtype = meta_reordered.dtype
    if meta_dtype not in (torch.int16, torch.int32):
        raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix")
    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4

    ksparse = 4 if sparse.dtype != torch.float else 2

    meta_nrows, meta_ncols = meta_reordered.shape
    if meta_nrows != m:
        raise RuntimeError(
            f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}"  # noqa: E501
        )
    if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k:
        raise RuntimeError(
            f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, "  # noqa: E501
            "expected according to the number of columns of meta matrix"
        )

    # Undo meta tensor elements reordering.
    meta_offsets = _calculate_meta_reordering_scatter_offsets(
        m, meta_ncols, meta_dtype, device
    )
    meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols)

    # Unpack sparse tensor back to original dense tensor, using
    # information provided by meta tensor. Note that torch.float
    # datatype is handled pretty much the same as
    # torch.half/torch.bfloat16, as metadata for a pair of torch.float
    # value is encoded as if underlying 8 bytes contain four
    # torch.half/torch.bfloat16 values, where either first two or last
    # two are zeros.
    meta_2 = torch.empty(
        (m, meta_ncols, 2 * quadbits_per_meta_elem),
        dtype=meta_dtype,
        device=device,
    )
    if quadbits_per_meta_elem == 4:
        meta_2[:, :, 0] = meta & 0b11
        meta_2[:, :, 1] = (meta >> 2) & 0b11
        meta_2[:, :, 2] = (meta >> 4) & 0b11
        meta_2[:, :, 3] = (meta >> 6) & 0b11
        meta_2[:, :, 4] = (meta >> 8) & 0b11
        meta_2[:, :, 5] = (meta >> 10) & 0b11
        meta_2[:, :, 6] = (meta >> 12) & 0b11
        meta_2[:, :, 7] = (meta >> 14) & 0b11
    elif quadbits_per_meta_elem == 8:
        meta_2[:, :, 0] = meta & 0b11
        meta_2[:, :, 1] = (meta >> 2) & 0b11
        meta_2[:, :, 2] = (meta >> 4) & 0b11
        meta_2[:, :, 3] = (meta >> 6) & 0b11
        meta_2[:, :, 4] = (meta >> 8) & 0b11
        meta_2[:, :, 5] = (meta >> 10) & 0b11
        meta_2[:, :, 6] = (meta >> 12) & 0b11
        meta_2[:, :, 7] = (meta >> 14) & 0b11
        meta_2[:, :, 8] = (meta >> 16) & 0b11
        meta_2[:, :, 9] = (meta >> 18) & 0b11
        meta_2[:, :, 10] = (meta >> 20) & 0b11
        meta_2[:, :, 11] = (meta >> 22) & 0b11
        meta_2[:, :, 12] = (meta >> 24) & 0b11
        meta_2[:, :, 13] = (meta >> 26) & 0b11
        meta_2[:, :, 14] = (meta >> 28) & 0b11
        meta_2[:, :, 15] = (meta >> 30) & 0b11

    dense_offsets = meta_2.view(-1) + (
        torch.arange(0, 2 * m * k // ksparse, device=device) * 4
    ).view(-1, 1).repeat(1, 2).view(-1)

    dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device)
    if sparse.dtype != torch.float:
        # dense.scatter_(0, dense_offsets, sparse.view(-1))
        dense.scatter_(0, dense_offsets, sparse.reshape(-1))
    else:
        dense.view(torch.half).scatter_(
            0, dense_offsets, sparse.view(torch.half).view(-1)
        )

    return dense.view(m, 2 * k)


def mask_creator(tensor):
    """
    Class for creating N:M sparsity masks.
    Masks will be created using the N:M ratio, where for every block of
    M weights, N will be pruned based on ranked weight value. Each mask
    will correspond to the given tensor.

    :param N: The number of weights in a group to keep
    :param M: The size of a weight group
    """
    N = 2
    M = 4

    mask = None
    # for i, tensor in enumerate(tensors):
    if tensor.numel() % M != 0:
        raise ValueError(
            f"Tensor of size {tensor.shape} can't be evenly divided into " f"{M} groups"
        )

    num_groups = tensor.numel() // M

    # N:M sparsity for linear layers
    tensor_temp = tensor.detach().abs().reshape(num_groups, M)
    index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)]

    w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device)
    mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape)

    return mask


def inject_24(w, size_k, size_n):
    assert w.shape == (size_k, size_n)

    mask = mask_creator(w.t()).t().cuda().bool()

    return (mask * w).contiguous(), mask.contiguous()


def check_24(w, num_rows_to_sample=50, _verbose=False):
    BLOCK_SIZE = 4
    MAX_NON_ZEROS = 2

    w = w.t().contiguous()

    print("check_24: w.shape = {}".format(w.shape))

    num_rows, num_cols = w.shape
    sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample)
    if _verbose:
        print(f"Sampled row idxs = {sampled_row_idxs}")

    total_segments = 0
    non_24_segments = 0
    for i in sampled_row_idxs:
        for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE):
            total_segments += 1
            block = w[i, j : j + BLOCK_SIZE]
            num_nonzero = torch.count_nonzero(block)
            if num_nonzero > MAX_NON_ZEROS:
                print("i = {} j = {} block = {}".format(i, j, block))
                non_24_segments += 1

    print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.")


def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType):
    assert q_24.shape == (size_k, size_n)

    # Remove bias to normalize over 0
    q_24_no_zp = q_24 - wtype.bias

    # Compress
    q_24_no_zp = q_24_no_zp.t().contiguous()
    q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(q_24_no_zp)
    q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous()

    # Restore bias
    q_24_comp = q_24_no_zp_comp + wtype.bias

    # Resize meta to its actual shape (without moving any data)
    meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)

    return q_24_comp, meta


def get_scale_perms_24():
    scale_perm: List[int] = []
    for i in range(8):
        scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]])
    scale_perm_single: List[int] = []
    for i in range(8):
        scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]])
    return scale_perm, scale_perm_single


def get_weight_perm_24(num_bits: int):
    perm_list: List[int] = []
    for i in range(32):
        perm1: List[int] = []
        col = i // 4
        col_o = col // 2
        for block in [0, 1]:
            for row in [
                2 * (i % 4),
                2 * (i % 4) + 1,
                2 * (i % 4 + 4),
                2 * (i % 4 + 4) + 1,
            ]:
                perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block)
        for j in range(4):
            perm_list.extend([p + 1 * j for p in perm1])
    perm = numpy.array(perm_list)

    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))

    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    perm = torch.from_numpy(perm)
    return perm


def marlin_permute_scales_24(
    s: torch.Tensor, size_k: int, size_n: int, group_size: int
) -> torch.Tensor:

    scale_perm, scale_perm_single = get_scale_perms_24()
    if group_size < size_k and group_size != -1:
        s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
    else:
        s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
    s = s.reshape((-1, size_n)).contiguous()

    return s


def marlin_24_quantize(
    w: torch.Tensor,
    quant_type: ScalarType,
    group_size: int,
):
    size_k, size_n = w.shape

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Inject 2:4 sparsity
    w_24, mask_24 = inject_24(w, size_k, size_n)

    # Quantize
    w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights(
        w_24, quant_type, group_size, act_order=False
    )

    # Compress quantized weight
    q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, quant_type)
    size_k_comp = size_k // 2

    # Reformat to marlin
    weight_perm = get_weight_perm_24(quant_type.size_bits)
    marlin_24_q_w_comp = marlin_weights(
        q_w_24_comp, size_k_comp, size_n, quant_type.size_bits, weight_perm
    )
    marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size)

    # Create result
    res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s]
    for i in range(len(res_list)):
        res_list[i] = res_list[i].to(w.device)

    return res_list
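mask_creator is what gives marlin_24_quantize its 2:4 pattern: in every group of four weights the two smallest magnitudes are zeroed. A minimal sketch of that selection (assuming only torch, CPU is enough; inject_24 itself additionally moves the mask to CUDA):

import torch

w = torch.tensor([[0.9, -0.1, 0.5, 0.05, -0.7, 0.2, 0.01, 0.6]])

groups = w.detach().abs().reshape(-1, 4)              # blocks of M = 4
drop = torch.argsort(groups, dim=1)[:, :2]            # indices of the 2 smallest per block
mask = torch.ones_like(groups).scatter_(dim=1, index=drop, value=0).reshape(w.shape)

print(mask.int().tolist())   # [[1, 0, 1, 0, 1, 0, 0, 1]] -> exactly 2 of every 4 survive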
ext-torch/utils/marlin_utils_test_qqq.py
ADDED
@@ -0,0 +1,125 @@
from typing import List

import numpy
import torch

from .marlin_utils_test import marlin_permute_weights
from .quant_utils import get_pack_factor, qqq_quantize_weights


def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size):
    # Permute
    q_w = marlin_permute_weights(q_w, size_k, size_n, perm)

    # Pack
    pack_factor = get_pack_factor(num_bits)
    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(numpy.uint32)

    q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor),
                           dtype=numpy.uint32)
    if group_size == size_k:
        for i in range(pack_factor):
            q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i
    else:
        for i in range(pack_factor):
            q_packed |= q_w[:, i::pack_factor] << num_bits * i

    q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device)

    return q_packed


def get_qqq_scale_perms():
    scale_perm: List[int] = []
    for i in range(8):
        scale_perm.extend([i + 8 * j for j in range(8)])
    scale_perm_single: List[int] = []
    for i in range(4):
        scale_perm_single.extend(
            [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
    return scale_perm, scale_perm_single


# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. # noqa: E501
def get_qqq_weight_perm(num_bits: int, quant_type: str):
    perm_list: List[int] = []
    for i in range(32):
        perm1: List[int] = []
        col = i // 4
        for block in [0, 1]:
            for row in [
                    4 * (i % 4),
                    4 * (i % 4) + 1,
                    4 * (i % 4) + 2,
                    4 * (i % 4) + 3,
            ]:
                perm1.append(16 * row + col + 8 * block)
        for j in range(4):
            perm_list.extend([p + 256 * j for p in perm1])

    perm = numpy.array(perm_list)

    assert quant_type in ["per-channel",
                          "per-group"], "not supported quantization type"
    if num_bits == 4:
        if quant_type == "per-channel":
            interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3])
        else:
            interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    else:
        raise Exception("num_bits must be 4, got {}".format(num_bits))

    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    perm = torch.from_numpy(perm)
    return perm


def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size):
    scale_perm, scale_perm_single = get_qqq_scale_perms()
    if group_size < size_k and group_size != -1:
        s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm]
        s_channel = s_channel.reshape(
            (-1, len(scale_perm_single)))[:, scale_perm_single]
        s_group = s_group.reshape((-1, size_n)).contiguous()
    else:
        s_channel = s_channel.reshape(
            (-1, len(scale_perm_single)))[:, scale_perm_single]
    s_channel = s_channel.reshape((-1, size_n)).contiguous()

    return s_group, s_channel


def marlin_qqq_quantize(
    w: torch.Tensor,
    num_bits: int,
    group_size: int,
):
    size_k, size_n = w.shape

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k
    quant_type = "per-channel" if group_size == size_k else "per-group"

    # Quantize
    w_ref, q_w, s_group, s_channel = qqq_quantize_weights(
        w, num_bits, group_size)

    # Reformat to marlin_qqq
    weight_perm = get_qqq_weight_perm(num_bits, quant_type)
    marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits,
                                        weight_perm, group_size)
    marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales(
        s_group, s_channel, size_k, size_n, group_size)

    # Create result
    res_list = [
        w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel
    ]
    for i in range(len(res_list)):
        res_list[i] = res_list[i].to(w.device)

    return res_list
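Before any of the QQQ permutation above happens, qqq_quantize_weights (defined in quant_utils below) does a symmetric 4-bit per-group quantization. A minimal sketch of that first step only (assuming only torch; one toy group, no scale fusion or int8 re-quantization):

import torch

num_bits = 4
max_q_val = 2**num_bits - 1          # 15
half_q_val = (max_q_val + 1) // 2    # 8

w = torch.randn(16, 1)                                         # one group of 16 values
s_group = torch.max(torch.abs(w), 0, keepdim=True)[0] * 2 / max_q_val
q_w = torch.clamp(torch.round(w / s_group).int() + half_q_val, 0, max_q_val)
w_ref = (q_w - half_q_val).half() * s_group.half()

print((w - w_ref).abs().max().item())   # reconstruction error on the order of s_group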
ext-torch/utils/quant_utils.py
ADDED
@@ -0,0 +1,470 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""This file is used for /tests and /benchmarks"""
|
2 |
+
|
3 |
+
from typing import List, Optional
|
4 |
+
|
5 |
+
import numpy
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from quantization.scalar_type import ScalarType, scalar_types
|
9 |
+
|
10 |
+
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
|
11 |
+
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
|
12 |
+
|
13 |
+
MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]
|
14 |
+
|
15 |
+
# Note: this is a hack. We should update each model to register the
|
16 |
+
# stacked params and get it from there instead in a future PR.
|
17 |
+
# fused_name: List[shard_name]
|
18 |
+
FUSED_LAYER_NAME_MAPPING = {
|
19 |
+
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
20 |
+
"gate_up_proj": ["gate_proj", "up_proj"],
|
21 |
+
}
|
22 |
+
|
23 |
+
|
24 |
+
def pack_quantized_values_into_int32(
|
25 |
+
w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
|
26 |
+
):
|
27 |
+
# move dim to pack to the end
|
28 |
+
perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
|
29 |
+
inv_perm = tuple(perm.index(i) for i in range(len(perm)))
|
30 |
+
w_q_perm = w_q.permute(perm)
|
31 |
+
|
32 |
+
pack_factor = 32 // wtype.size_bits
|
33 |
+
mask = (1 << wtype.size_bits) - 1
|
34 |
+
|
35 |
+
new_shape_perm = list(w_q_perm.shape)
|
36 |
+
assert w_q_perm.shape[-1] % pack_factor == 0
|
37 |
+
new_shape_perm[-1] //= pack_factor
|
38 |
+
|
39 |
+
res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
|
40 |
+
for i in range(pack_factor):
|
41 |
+
res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i
|
42 |
+
|
43 |
+
return res.permute(inv_perm)
|
44 |
+
|
45 |
+
|
46 |
+
def unpack_quantized_values_into_int32(
|
47 |
+
w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
|
48 |
+
):
|
49 |
+
# move dim to pack to the end
|
50 |
+
perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
|
51 |
+
inv_perm = tuple(perm.index(i) for i in range(len(perm)))
|
52 |
+
w_q_perm = w_q.permute(perm)
|
53 |
+
|
54 |
+
pack_factor = 32 // wtype.size_bits
|
55 |
+
mask = (1 << wtype.size_bits) - 1
|
56 |
+
|
57 |
+
new_shape_perm = list(w_q_perm.shape)
|
58 |
+
new_shape_perm[-1] *= pack_factor
|
59 |
+
|
60 |
+
res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
|
61 |
+
for i in range(pack_factor):
|
62 |
+
res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask
|
63 |
+
|
64 |
+
return res.permute(inv_perm)
|
65 |
+
|
66 |
+
|
67 |
+
def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
|
68 |
+
# prefix: model.layers.0.self_attn.q_proj
|
69 |
+
# proj_name: q_proj
|
70 |
+
proj_name = prefix.split(".")[-1]
|
71 |
+
if proj_name in FUSED_LAYER_NAME_MAPPING:
|
72 |
+
shard_prefixes = [
|
73 |
+
prefix.replace(proj_name, shard_proj_name)
|
74 |
+
for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name]
|
75 |
+
]
|
76 |
+
|
77 |
+
is_skipped = None
|
78 |
+
for shard_prefix in shard_prefixes:
|
79 |
+
is_shard_skipped = shard_prefix in ignored_layers
|
80 |
+
|
81 |
+
if is_skipped is None:
|
82 |
+
is_skipped = is_shard_skipped
|
83 |
+
elif is_shard_skipped != is_skipped:
|
84 |
+
raise ValueError(
|
85 |
+
f"Detected some but not all shards of {prefix} "
|
86 |
+
"are quantized. All shards of fused layers "
|
87 |
+
"to have the same precision."
|
88 |
+
)
|
89 |
+
else:
|
90 |
+
is_skipped = prefix in ignored_layers
|
91 |
+
|
92 |
+
assert is_skipped is not None
|
93 |
+
return is_skipped
|
94 |
+
|
95 |
+
|
96 |
+
def get_pack_factor(num_bits):
|
97 |
+
assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
|
98 |
+
return 32 // num_bits
|
99 |
+
|
100 |
+
|
101 |
+
def permute_rows(
|
102 |
+
q_w: torch.Tensor,
|
103 |
+
w_ref: torch.Tensor,
|
104 |
+
group_size: int,
|
105 |
+
test_perm: Optional[torch.Tensor] = None,
|
106 |
+
):
|
107 |
+
assert q_w.shape == w_ref.shape
|
108 |
+
|
109 |
+
orig_device = q_w.device
|
110 |
+
k_size, _ = q_w.shape
|
111 |
+
|
112 |
+
g_idx = torch.zeros((k_size,), dtype=torch.int32)
|
113 |
+
for i in range(k_size):
|
114 |
+
g_idx[i] = i // group_size
|
115 |
+
|
116 |
+
# Simulate act_order by doing a random permutation on K
|
117 |
+
rand_perm = test_perm if test_perm is not None else torch.randperm(k_size)
|
118 |
+
|
119 |
+
g_idx = g_idx[rand_perm].contiguous()
|
120 |
+
q_w = q_w[rand_perm, :].contiguous()
|
121 |
+
w_ref = w_ref[rand_perm, :].contiguous()
|
122 |
+
|
123 |
+
return (
|
124 |
+
w_ref.to(device=orig_device),
|
125 |
+
q_w.to(device=orig_device),
|
126 |
+
g_idx.to(device=orig_device),
|
127 |
+
rand_perm.to(device=orig_device),
|
128 |
+
)
|
129 |
+
|
130 |
+
|
131 |
+
def quantize_weights(
|
132 |
+
w: torch.Tensor,
|
133 |
+
quant_type: ScalarType,
|
134 |
+
group_size: Optional[int],
|
135 |
+
zero_points: bool = False,
|
136 |
+
ref_zero_points_after_scales: bool = False,
|
137 |
+
):
|
138 |
+
assert (
|
139 |
+
quant_type.is_integer()
|
140 |
+
), "Floating point quantization may work but has not been tested"
|
141 |
+
assert not zero_points or group_size is not None, (
|
142 |
+
"to have group zero points, group_size must be provided "
|
143 |
+
"(-1 group_size is channelwise)"
|
144 |
+
)
|
145 |
+
|
146 |
+
orig_device = w.device
|
147 |
+
orig_type = w.dtype
|
148 |
+
size_k, size_n = w.shape
|
149 |
+
|
150 |
+
assert w.is_floating_point(), "w must be float"
|
151 |
+
|
152 |
+
if group_size == -1:
|
153 |
+
group_size = size_k
|
154 |
+
|
155 |
+
# Reshape to [groupsize, -1]
|
156 |
+
if group_size is not None and group_size < size_k:
|
157 |
+
w = w.reshape((-1, group_size, size_n))
|
158 |
+
w = w.permute(1, 0, 2)
|
159 |
+
w = w.reshape((group_size, -1))
|
160 |
+
|
161 |
+
# Compute scale for each group
|
162 |
+
max_val = torch.max(w, 0, keepdim=True).values
|
163 |
+
min_val = torch.min(w, 0, keepdim=True).values
|
164 |
+
|
165 |
+
max_q_val = quant_type.max()
|
166 |
+
min_q_val = quant_type.min()
|
167 |
+
|
168 |
+
w_s = torch.Tensor([1.0]).to(w.device) # unscaled case
|
169 |
+
maybe_w_zp = None
|
170 |
+
if group_size is not None:
|
171 |
+
if zero_points:
|
172 |
+
assert not quant_type.is_signed() and quant_type.max() > 0
|
173 |
+
w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
|
174 |
+
maybe_w_zp = (
|
175 |
+
torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
|
176 |
+
)
|
177 |
+
else:
|
178 |
+
# If the bias is such that there are no possible negative/positive
|
179 |
+
# values, set the max value to inf to avoid divide by 0
|
180 |
+
w_s = torch.max(
|
181 |
+
abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
|
182 |
+
abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
|
183 |
+
)
|
184 |
+
|
185 |
+
# Quantize
|
186 |
+
w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
|
187 |
+
w_q = torch.clamp(w_q, min_q_val, max_q_val)
|
188 |
+
|
189 |
+
# Compute ref (dequantized)
|
190 |
+
# For some kernels (namely Machete) the zero-points are applied after the
|
191 |
+
# scales are applied, for this case computing the reference in similar way
|
192 |
+
# allows us to use tighter error tolerances in our unit tests.
|
193 |
+
if ref_zero_points_after_scales and maybe_w_zp is not None:
|
194 |
+
w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
|
195 |
+
else:
|
196 |
+
w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
|
197 |
+
|
198 |
+
if quant_type.has_bias():
|
199 |
+
w_q += quant_type.bias
|
200 |
+
|
201 |
+
# Restore original shapes
|
202 |
+
if group_size is not None and group_size < size_k:
|
203 |
+
|
204 |
+
def reshape_w(w):
|
205 |
+
w = w.reshape((group_size, -1, size_n))
|
206 |
+
w = w.permute(1, 0, 2)
|
207 |
+
w = w.reshape((size_k, size_n)).contiguous()
|
208 |
+
return w
|
209 |
+
|
210 |
+
w_q = reshape_w(w_q)
|
211 |
+
w_ref = reshape_w(w_ref)
|
212 |
+
w_s = w_s.reshape((-1, size_n)).contiguous()
|
213 |
+
|
214 |
+
if maybe_w_zp is not None:
|
215 |
+
maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
|
216 |
+
maybe_w_zp = maybe_w_zp.to(device=orig_device)
|
217 |
+
|
218 |
+
return (
|
219 |
+
w_ref.to(device=orig_device),
|
220 |
+
w_q.to(device=orig_device),
|
221 |
+
w_s if group_size is not None else None,
|
222 |
+
maybe_w_zp,
|
223 |
+
)
|
224 |
+
|
225 |
+
|
226 |
+
def gptq_quantize_weights(
|
227 |
+
w: torch.Tensor,
|
228 |
+
quant_type: ScalarType,
|
229 |
+
group_size: int,
|
230 |
+
act_order: bool,
|
231 |
+
test_perm: Optional[torch.Tensor] = None,
|
232 |
+
):
|
233 |
+
size_k, _ = w.shape
|
234 |
+
|
235 |
+
assert w.is_floating_point(), "w must be float"
|
236 |
+
assert (
|
237 |
+
quant_type in SUPPORTED_GPTQ_QUANT_TYPES
|
238 |
+
), f"Unsupported gptq type = {quant_type}"
|
239 |
+
assert group_size in SUPPORTED_GROUP_SIZES + [
|
240 |
+
size_k
|
241 |
+
], f"Unsupported groupsize = {group_size}"
|
242 |
+
|
243 |
+
w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size)
|
244 |
+
|
245 |
+
# Apply act_order
|
246 |
+
g_idx = torch.empty(0, dtype=torch.int, device=w.device)
|
247 |
+
rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
|
248 |
+
if act_order:
|
249 |
+
assert (
|
250 |
+
group_size < size_k
|
251 |
+
), "For act_order, groupsize = {} must be less than size_k = {}".format(
|
252 |
+
group_size, size_k
|
253 |
+
)
|
254 |
+
|
255 |
+
w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm)
|
256 |
+
|
257 |
+
return w_ref, w_q, w_s, g_idx, rand_perm
|
258 |
+
|
259 |
+
|
260 |
+
# QQQ employs different quant schemes for per-group and
|
261 |
+
# per-channel quantization.
|
262 |
+
def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int):
|
263 |
+
orig_device = w.device
|
264 |
+
size_k, size_n = w.shape
|
265 |
+
|
266 |
+
assert w.is_floating_point(), "w must be float"
|
267 |
+
assert (
|
268 |
+
num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS
|
269 |
+
), f"Unsupported num_bits = {num_bits}"
|
270 |
+
assert group_size in SUPPORTED_GROUP_SIZES + [
|
271 |
+
size_k
|
272 |
+
], f"Unsupported groupsize = {group_size}"
|
273 |
+
|
274 |
+
if group_size == -1:
|
275 |
+
group_size = size_k
|
276 |
+
assert group_size <= size_k
|
277 |
+
|
278 |
+
if group_size < size_k:
|
279 |
+
# Reshape to [groupsize, -1]
|
280 |
+
w = w.reshape((-1, group_size, size_n))
|
281 |
+
w = w.permute(1, 0, 2)
|
282 |
+
w = w.reshape((group_size, -1))
|
283 |
+
|
284 |
+
max_q_val = 2**num_bits - 1
|
285 |
+
half_q_val = (max_q_val + 1) // 2
|
286 |
+
|
287 |
+
# Compute scale for each group
|
288 |
+
s_group = torch.max(torch.abs(w), 0, keepdim=True)[0]
|
289 |
+
s_group *= 2 / max_q_val # 2 => symmetric
|
290 |
+
|
291 |
+
# Quantize
|
292 |
+
q_w = torch.round(w / s_group).int()
|
293 |
+
q_w += half_q_val
|
294 |
+
q_w = torch.clamp(q_w, 0, max_q_val)
|
295 |
+
# Compute ref (dequantized)
|
296 |
+
w_ref = (q_w - half_q_val).half() * s_group
|
297 |
+
|
298 |
+
# Restore original shapes
|
299 |
+
def reshape_w(w):
|
300 |
+
w = w.reshape((group_size, -1, size_n))
|
301 |
+
w = w.permute(1, 0, 2)
|
302 |
+
w = w.reshape((size_k, size_n)).contiguous()
|
303 |
+
return w
|
304 |
+
|
305 |
+
q_w = reshape_w(q_w)
|
306 |
+
w_ref = reshape_w(w_ref)
|
307 |
+
|
308 |
+
# Compute int8 quantization scale for each channel
|
309 |
+
s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0]
|
310 |
+
s_channel /= 127.0
|
311 |
+
t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8)
|
312 |
+
w_ref = t_int8.half() * s_channel
|
313 |
+
s_channel = s_channel.reshape(1, -1).to(dtype=torch.float)
|
314 |
+
|
315 |
+
# Fuse scales
|
316 |
+
s_group = (s_group.reshape(-1, size_n).contiguous() / s_channel).to(
|
317 |
+
dtype=torch.half
|
318 |
+
)
|
319 |
+
else:
|
320 |
+
max_q_val = 2 ** (num_bits - 1) - 1
|
321 |
+
|
322 |
+
# Compute scale for each channel
|
323 |
+
s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0]
|
324 |
+
s_channel /= max_q_val
|
325 |
+
|
326 |
+
# Quantize
|
327 |
+
q_w = torch.round(w / s_channel).int()
|
328 |
+
q_w = torch.clamp(q_w, -max_q_val, max_q_val)
|
329 |
+
# Compute ref (dequantized)
|
330 |
+
w_ref = q_w.half() * s_channel
|
331 |
+
|
332 |
+
s_group = torch.tensor([], dtype=torch.half)
|
333 |
+
# div 2 ** (8 - self.bits)) to offset right shift in unpacking
|
334 |
+
s_channel /= 2 ** (8 - num_bits)
|
335 |
+
s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float)
|
336 |
+
|
337 |
+
return (
|
338 |
+
w_ref.to(device=orig_device),
|
339 |
+
q_w.to(device=orig_device),
|
340 |
+
s_group.to(device=orig_device),
|
341 |
+
s_channel.to(device=orig_device),
|
342 |
+
)
|
343 |
+
|
344 |
+
|
345 |
+
def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
|
346 |
+
orig_device = q_w.device
|
347 |
+
|
348 |
+
sort_indices = torch.argsort(g_idx).to(dtype=torch.int32) # Sort based on g_idx
|
349 |
+
|
350 |
+
g_idx = g_idx[sort_indices].contiguous()
|
351 |
+
q_w = q_w[sort_indices, :].contiguous()
|
352 |
+
|
353 |
+
return (
|
354 |
+
q_w.to(device=orig_device),
|
355 |
+
g_idx.to(device=orig_device),
|
356 |
+
sort_indices.to(device=orig_device),
|
357 |
+
)
|
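
A small sketch of what sort_weights produces for act_order (illustration only): argsort over g_idx gives a row permutation under which rows sharing a group index become contiguous.

import torch

g_idx = torch.tensor([2, 0, 1, 0, 2, 1], dtype=torch.int32)
q_w = torch.arange(24, dtype=torch.int32).reshape(6, 4)

perm = torch.argsort(g_idx)  # row permutation grouping equal g_idx values
print(g_idx[perm])           # tensor([0, 0, 1, 1, 2, 2], dtype=torch.int32)
print(q_w[perm, :])          # quantized rows reordered into group order
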
358 |
+
|
359 |
+
|
360 |
+
def pack_rows(
|
361 |
+
q_w: torch.Tensor,
|
362 |
+
num_bits: int,
|
363 |
+
size_k: int,
|
364 |
+
size_n: int,
|
365 |
+
):
|
366 |
+
assert q_w.shape == (size_k, size_n)
|
367 |
+
|
368 |
+
pack_factor = get_pack_factor(num_bits)
|
369 |
+
assert size_k % pack_factor == 0
|
370 |
+
|
371 |
+
orig_device = q_w.device
|
372 |
+
|
373 |
+
q_w = q_w.cpu().numpy().astype(numpy.uint32)
|
374 |
+
|
375 |
+
q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)
|
376 |
+
|
377 |
+
for i in range(pack_factor):
|
378 |
+
q_res |= q_w[i::pack_factor, :] << num_bits * i
|
379 |
+
|
380 |
+
q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
|
381 |
+
return q_res
|
382 |
+
|
383 |
+
|
384 |
+
def pack_cols(
|
385 |
+
q_w: torch.Tensor,
|
386 |
+
num_bits: int,
|
387 |
+
size_k: int,
|
388 |
+
size_n: int,
|
389 |
+
):
|
390 |
+
assert q_w.shape == (size_k, size_n)
|
391 |
+
|
392 |
+
pack_factor = get_pack_factor(num_bits)
|
393 |
+
assert size_n % pack_factor == 0
|
394 |
+
|
395 |
+
orig_device = q_w.device
|
396 |
+
|
397 |
+
q_w = q_w.cpu().numpy().astype(numpy.uint32)
|
398 |
+
|
399 |
+
q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
|
400 |
+
|
401 |
+
for i in range(pack_factor):
|
402 |
+
q_res |= q_w[:, i::pack_factor] << num_bits * i
|
403 |
+
|
404 |
+
q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
|
405 |
+
q_res = q_res.contiguous()
|
406 |
+
|
407 |
+
return q_res
|
408 |
+
|
409 |
+
|
410 |
+
def unpack_cols(
|
411 |
+
packed_q_w: torch.Tensor,
|
412 |
+
num_bits: int,
|
413 |
+
size_k: int,
|
414 |
+
size_n: int,
|
415 |
+
):
|
416 |
+
pack_factor = get_pack_factor(num_bits)
|
417 |
+
assert size_n % pack_factor == 0
|
418 |
+
assert packed_q_w.shape == (
|
419 |
+
size_k,
|
420 |
+
size_n // pack_factor,
|
421 |
+
), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
|
422 |
+
packed_q_w.shape, size_k, size_n, pack_factor
|
423 |
+
)
|
424 |
+
|
425 |
+
orig_device = packed_q_w.device
|
426 |
+
|
427 |
+
packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
|
428 |
+
q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
|
429 |
+
|
430 |
+
mask = (1 << num_bits) - 1
|
431 |
+
for i in range(pack_factor):
|
432 |
+
vals = packed_q_w_cpu & mask
|
433 |
+
packed_q_w_cpu >>= num_bits
|
434 |
+
q_res[:, i::pack_factor] = vals
|
435 |
+
|
436 |
+
q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
|
437 |
+
q_res = q_res.contiguous()
|
438 |
+
|
439 |
+
return q_res
|
440 |
+
|
441 |
+
|
442 |
+
def gptq_pack(
|
443 |
+
q_w: torch.Tensor,
|
444 |
+
num_bits: int,
|
445 |
+
size_k: int,
|
446 |
+
size_n: int,
|
447 |
+
):
|
448 |
+
return pack_rows(q_w, num_bits, size_k, size_n)
|
449 |
+
|
450 |
+
|
451 |
+
def awq_pack(
|
452 |
+
q_w: torch.Tensor,
|
453 |
+
num_bits: int,
|
454 |
+
size_k: int,
|
455 |
+
size_n: int,
|
456 |
+
):
|
457 |
+
assert q_w.shape == (size_k, size_n)
|
458 |
+
|
459 |
+
# Interleave column dim (for the dequantize code) and pack it to int32
|
460 |
+
if num_bits == 4:
|
461 |
+
interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
|
462 |
+
elif num_bits == 8:
|
463 |
+
interleave = numpy.array([0, 2, 1, 3])
|
464 |
+
else:
|
465 |
+
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
|
466 |
+
|
467 |
+
q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel()
|
468 |
+
q_w = q_w.reshape((-1, size_n)).contiguous()
|
469 |
+
|
470 |
+
return pack_cols(q_w, num_bits, size_k, size_n)
|
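
A round-trip sketch for the packing helpers above (illustration only): the import path is a placeholder, and get_pack_factor(num_bits) is assumed to return 32 // num_bits, consistent with the uint32 packing loops; pack_rows/gptq_pack use the same bit layout along the k dimension.

import torch
from quant_utils import pack_cols, unpack_cols  # placeholder import path

num_bits, size_k, size_n = 4, 16, 32
q_w = torch.randint(0, 2**num_bits, (size_k, size_n), dtype=torch.int32)

packed = pack_cols(q_w, num_bits, size_k, size_n)         # shape (16, 32 // 8)
restored = unpack_cols(packed, num_bits, size_k, size_n)  # shape (16, 32)
assert torch.equal(restored, q_w)
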
marlin/dense/LICENSE
ADDED
@@ -0,0 +1,209 @@
1 |
+
Contains code from https://github.com/IST-DASLab/marlin
|
2 |
+
|
3 |
+
Apache License
|
4 |
+
Version 2.0, January 2004
|
5 |
+
http://www.apache.org/licenses/
|
6 |
+
|
7 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
8 |
+
|
9 |
+
1. Definitions.
|
10 |
+
|
11 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
12 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
13 |
+
|
14 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
15 |
+
the copyright owner that is granting the License.
|
16 |
+
|
17 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
18 |
+
other entities that control, are controlled by, or are under common
|
19 |
+
control with that entity. For the purposes of this definition,
|
20 |
+
"control" means (i) the power, direct or indirect, to cause the
|
21 |
+
direction or management of such entity, whether by contract or
|
22 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
23 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
24 |
+
|
25 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
26 |
+
exercising permissions granted by this License.
|
27 |
+
|
28 |
+
"Source" form shall mean the preferred form for making modifications,
|
29 |
+
including but not limited to software source code, documentation
|
30 |
+
source, and configuration files.
|
31 |
+
|
32 |
+
"Object" form shall mean any form resulting from mechanical
|
33 |
+
transformation or translation of a Source form, including but
|
34 |
+
not limited to compiled object code, generated documentation,
|
35 |
+
and conversions to other media types.
|
36 |
+
|
37 |
+
"Work" shall mean the work of authorship, whether in Source or
|
38 |
+
Object form, made available under the License, as indicated by a
|
39 |
+
copyright notice that is included in or attached to the work
|
40 |
+
(an example is provided in the Appendix below).
|
41 |
+
|
42 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
43 |
+
form, that is based on (or derived from) the Work and for which the
|
44 |
+
editorial revisions, annotations, elaborations, or other modifications
|
45 |
+
represent, as a whole, an original work of authorship. For the purposes
|
46 |
+
of this License, Derivative Works shall not include works that remain
|
47 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
48 |
+
the Work and Derivative Works thereof.
|
49 |
+
|
50 |
+
"Contribution" shall mean any work of authorship, including
|
51 |
+
the original version of the Work and any modifications or additions
|
52 |
+
to that Work or Derivative Works thereof, that is intentionally
|
53 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
54 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
55 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
56 |
+
means any form of electronic, verbal, or written communication sent
|
57 |
+
to the Licensor or its representatives, including but not limited to
|
58 |
+
communication on electronic mailing lists, source code control systems,
|
59 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
60 |
+
Licensor for the purpose of discussing and improving the Work, but
|
61 |
+
excluding communication that is conspicuously marked or otherwise
|
62 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
63 |
+
|
64 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
65 |
+
on behalf of whom a Contribution has been received by Licensor and
|
66 |
+
subsequently incorporated within the Work.
|
67 |
+
|
68 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
69 |
+
this License, each Contributor hereby grants to You a perpetual,
|
70 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
71 |
+
copyright license to reproduce, prepare Derivative Works of,
|
72 |
+
publicly display, publicly perform, sublicense, and distribute the
|
73 |
+
Work and such Derivative Works in Source or Object form.
|
74 |
+
|
75 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
76 |
+
this License, each Contributor hereby grants to You a perpetual,
|
77 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
78 |
+
(except as stated in this section) patent license to make, have made,
|
79 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
80 |
+
where such license applies only to those patent claims licensable
|
81 |
+
by such Contributor that are necessarily infringed by their
|
82 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
83 |
+
with the Work to which such Contribution(s) was submitted. If You
|
84 |
+
institute patent litigation against any entity (including a
|
85 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
86 |
+
or a Contribution incorporated within the Work constitutes direct
|
87 |
+
or contributory patent infringement, then any patent licenses
|
88 |
+
granted to You under this License for that Work shall terminate
|
89 |
+
as of the date such litigation is filed.
|
90 |
+
|
91 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
92 |
+
Work or Derivative Works thereof in any medium, with or without
|
93 |
+
modifications, and in Source or Object form, provided that You
|
94 |
+
meet the following conditions:
|
95 |
+
|
96 |
+
(a) You must give any other recipients of the Work or
|
97 |
+
Derivative Works a copy of this License; and
|
98 |
+
|
99 |
+
(b) You must cause any modified files to carry prominent notices
|
100 |
+
stating that You changed the files; and
|
101 |
+
|
102 |
+
(c) You must retain, in the Source form of any Derivative Works
|
103 |
+
that You distribute, all copyright, patent, trademark, and
|
104 |
+
attribution notices from the Source form of the Work,
|
105 |
+
excluding those notices that do not pertain to any part of
|
106 |
+
the Derivative Works; and
|
107 |
+
|
108 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
109 |
+
distribution, then any Derivative Works that You distribute must
|
110 |
+
include a readable copy of the attribution notices contained
|
111 |
+
within such NOTICE file, excluding those notices that do not
|
112 |
+
pertain to any part of the Derivative Works, in at least one
|
113 |
+
of the following places: within a NOTICE text file distributed
|
114 |
+
as part of the Derivative Works; within the Source form or
|
115 |
+
documentation, if provided along with the Derivative Works; or,
|
116 |
+
within a display generated by the Derivative Works, if and
|
117 |
+
wherever such third-party notices normally appear. The contents
|
118 |
+
of the NOTICE file are for informational purposes only and
|
119 |
+
do not modify the License. You may add Your own attribution
|
120 |
+
notices within Derivative Works that You distribute, alongside
|
121 |
+
or as an addendum to the NOTICE text from the Work, provided
|
122 |
+
that such additional attribution notices cannot be construed
|
123 |
+
as modifying the License.
|
124 |
+
|
125 |
+
You may add Your own copyright statement to Your modifications and
|
126 |
+
may provide additional or different license terms and conditions
|
127 |
+
for use, reproduction, or distribution of Your modifications, or
|
128 |
+
for any such Derivative Works as a whole, provided Your use,
|
129 |
+
reproduction, and distribution of the Work otherwise complies with
|
130 |
+
the conditions stated in this License.
|
131 |
+
|
132 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
133 |
+
any Contribution intentionally submitted for inclusion in the Work
|
134 |
+
by You to the Licensor shall be under the terms and conditions of
|
135 |
+
this License, without any additional terms or conditions.
|
136 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
137 |
+
the terms of any separate license agreement you may have executed
|
138 |
+
with Licensor regarding such Contributions.
|
139 |
+
|
140 |
+
6. Trademarks. This License does not grant permission to use the trade
|
141 |
+
names, trademarks, service marks, or product names of the Licensor,
|
142 |
+
except as required for reasonable and customary use in describing the
|
143 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
144 |
+
|
145 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
146 |
+
agreed to in writing, Licensor provides the Work (and each
|
147 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
148 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
149 |
+
implied, including, without limitation, any warranties or conditions
|
150 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
151 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
152 |
+
appropriateness of using or redistributing the Work and assume any
|
153 |
+
risks associated with Your exercise of permissions under this License.
|
154 |
+
|
155 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
156 |
+
whether in tort (including negligence), contract, or otherwise,
|
157 |
+
unless required by applicable law (such as deliberate and grossly
|
158 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
159 |
+
liable to You for damages, including any direct, indirect, special,
|
160 |
+
incidental, or consequential damages of any character arising as a
|
161 |
+
result of this License or out of the use or inability to use the
|
162 |
+
Work (including but not limited to damages for loss of goodwill,
|
163 |
+
work stoppage, computer failure or malfunction, or any and all
|
164 |
+
other commercial damages or losses), even if such Contributor
|
165 |
+
has been advised of the possibility of such damages.
|
166 |
+
|
167 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
168 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
169 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
170 |
+
or other liability obligations and/or rights consistent with this
|
171 |
+
License. However, in accepting such obligations, You may act only
|
172 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
173 |
+
of any other Contributor, and only if You agree to indemnify,
|
174 |
+
defend, and hold each Contributor harmless for any liability
|
175 |
+
incurred by, or claims asserted against, such Contributor by reason
|
176 |
+
of your accepting any such warranty or additional liability.
|
177 |
+
|
178 |
+
END OF TERMS AND CONDITIONS
|
179 |
+
|
180 |
+
APPENDIX: How to apply the Apache License to your work.
|
181 |
+
|
182 |
+
To apply the Apache License to your work, attach the following
|
183 |
+
boilerplate notice, with the fields enclosed by brackets "{}"
|
184 |
+
replaced with your own identifying information. (Don't include
|
185 |
+
the brackets!) The text should be enclosed in the appropriate
|
186 |
+
comment syntax for the file format. We also recommend that a
|
187 |
+
file or class name and description of purpose be included on the
|
188 |
+
same "printed page" as the copyright notice for easier
|
189 |
+
identification within third-party archives.
|
190 |
+
|
191 |
+
Copyright {yyyy} {name of copyright owner}
|
192 |
+
|
193 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
194 |
+
you may not use this file except in compliance with the License.
|
195 |
+
You may obtain a copy of the License at
|
196 |
+
|
197 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
198 |
+
|
199 |
+
Unless required by applicable law or agreed to in writing, software
|
200 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
201 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
202 |
+
See the License for the specific language governing permissions and
|
203 |
+
limitations under the License.
|
204 |
+
|
205 |
+
------------------------------------------------------------------------------------
|
206 |
+
|
207 |
+
This product bundles various third-party components under other open source licenses.
|
208 |
+
This section summarizes those components and their licenses. See licenses/
|
209 |
+
for text of these licenses.
|
marlin/dense/common/base.h
ADDED
@@ -0,0 +1,32 @@
1 |
+
/*
|
2 |
+
* Modified by HandH1998
|
3 |
+
* Modified by Neural Magic
|
4 |
+
* Copyright (C) Marlin.2024 Elias Frantar
|
5 |
+
*
|
6 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
* you may not use this file except in compliance with the License.
|
8 |
+
* You may obtain a copy of the License at
|
9 |
+
*
|
10 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
*
|
12 |
+
* Unless required by applicable law or agreed to in writing, software
|
13 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
* See the License for the specific language governing permissions and
|
16 |
+
* limitations under the License.
|
17 |
+
*/
|
18 |
+
|
19 |
+
#pragma once
|
20 |
+
|
21 |
+
constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
|
22 |
+
|
23 |
+
// Instances of `Vec` are used to organize groups of >>registers<<, as needed
|
24 |
+
// for instance as inputs to tensor core operations. Consequently, all
|
25 |
+
// corresponding index accesses must be compile-time constants, which is why we
|
26 |
+
// extensively use `#pragma unroll` throughout the kernel code to guarantee
|
27 |
+
// this.
|
28 |
+
template <typename T, int n>
|
29 |
+
struct Vec {
|
30 |
+
T elems[n];
|
31 |
+
__device__ T& operator[](int i) { return elems[i]; }
|
32 |
+
};
|
marlin/dense/common/mem.h
ADDED
@@ -0,0 +1,89 @@
1 |
+
/*
|
2 |
+
* Modified by HandH1998
|
3 |
+
* Modified by Neural Magic
|
4 |
+
* Copyright (C) Marlin.2024 Elias Frantar
|
5 |
+
*
|
6 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
* you may not use this file except in compliance with the License.
|
8 |
+
* You may obtain a copy of the License at
|
9 |
+
*
|
10 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
*
|
12 |
+
* Unless required by applicable law or agreed to in writing, software
|
13 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
* See the License for the specific language governing permissions and
|
16 |
+
* limitations under the License.
|
17 |
+
*/
|
18 |
+
|
19 |
+
#pragma once
|
20 |
+
|
21 |
+
// Predicated asynchronous global->shared copy; used for inputs A where we apply
|
22 |
+
// predication to handle batchsizes that are not multiples of 16.
|
23 |
+
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
|
24 |
+
bool pred = true) {
|
25 |
+
const int BYTES = 16;
|
26 |
+
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
27 |
+
asm volatile(
|
28 |
+
"{\n"
|
29 |
+
" .reg .pred p;\n"
|
30 |
+
" setp.ne.b32 p, %0, 0;\n"
|
31 |
+
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
|
32 |
+
"}\n" ::"r"((int)pred),
|
33 |
+
"r"(smem), "l"(glob_ptr), "n"(BYTES));
|
34 |
+
}
|
35 |
+
|
36 |
+
// Asynchronous global->shared copy
|
37 |
+
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
|
38 |
+
const int BYTES = 16;
|
39 |
+
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
40 |
+
asm volatile(
|
41 |
+
"{\n"
|
42 |
+
" cp.async.cg.shared.global [%0], [%1], %2;\n"
|
43 |
+
"}\n" ::"r"(smem),
|
44 |
+
"l"(glob_ptr), "n"(BYTES));
|
45 |
+
}
|
46 |
+
|
47 |
+
// Async copy fence.
|
48 |
+
__device__ inline void cp_async_fence() {
|
49 |
+
asm volatile("cp.async.commit_group;\n" ::);
|
50 |
+
}
|
51 |
+
|
52 |
+
// Wait until at most `n` async copy stages are still pending.
|
53 |
+
template <int n>
|
54 |
+
__device__ inline void cp_async_wait() {
|
55 |
+
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
|
56 |
+
}
|
57 |
+
|
58 |
+
// Wait until barrier reaches `count`, then lock for current threadblock.
|
59 |
+
__device__ inline void barrier_acquire(int* lock, int count) {
|
60 |
+
if (threadIdx.x == 0) {
|
61 |
+
int state = -1;
|
62 |
+
do
|
63 |
+
// Guarantee that subsequent writes by this threadblock will be visible
|
64 |
+
// globally.
|
65 |
+
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
|
66 |
+
: "=r"(state)
|
67 |
+
: "l"(lock));
|
68 |
+
while (state != count);
|
69 |
+
}
|
70 |
+
__syncthreads();
|
71 |
+
}
|
72 |
+
|
73 |
+
// Release barrier and increment visitation count.
|
74 |
+
__device__ inline void barrier_release(int* lock, bool reset = false) {
|
75 |
+
__syncthreads();
|
76 |
+
if (threadIdx.x == 0) {
|
77 |
+
if (reset) {
|
78 |
+
lock[0] = 0;
|
79 |
+
return;
|
80 |
+
}
|
81 |
+
int val = 1;
|
82 |
+
// Make sure that all writes since acquiring this barrier are visible
|
83 |
+
// globally, while releasing the barrier.
|
84 |
+
asm volatile("fence.acq_rel.gpu;\n");
|
85 |
+
asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
|
86 |
+
:
|
87 |
+
: "l"(lock), "r"(val));
|
88 |
+
}
|
89 |
+
}
|
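
The lock helpers above impose a serial ordering on the global reduction in the kernel: a threadblock with slice index i spins in barrier_acquire until the lock has been released i times, does its part of the reduction, then releases (or, for the last block, resets) the lock. A rough Python analogue of that ordering, with a condition variable standing in for the spin on an acquire load and the relaxed atomic add (conceptual sketch only, not CUDA code):

import threading

lock_val = 0                      # plays the role of the per-slice lock word
cv = threading.Condition()
partials = [1.0, 2.0, 3.0, 4.0]   # per-block partial results for one slice
total = 0.0

def block(slice_idx, slice_count):
    global lock_val, total
    with cv:
        cv.wait_for(lambda: lock_val == slice_idx)  # barrier_acquire(lock, slice_idx)
        total += partials[slice_idx]                # serial reduction step
        last = slice_idx == slice_count - 1
        lock_val = 0 if last else lock_val + 1      # barrier_release(lock, reset=last)
        cv.notify_all()

threads = [threading.Thread(target=block, args=(i, 4)) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(total)  # 10.0, accumulated strictly in slice_idx order
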
marlin/dense/marlin_cuda_kernel.cu
ADDED
@@ -0,0 +1,1068 @@
1 |
+
/*
|
2 |
+
* Modified by Neural Magic
|
3 |
+
* Copyright (C) Marlin.2024 Elias Frantar
|
4 |
+
*
|
5 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
* you may not use this file except in compliance with the License.
|
7 |
+
* You may obtain a copy of the License at
|
8 |
+
*
|
9 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
*
|
11 |
+
* Unless required by applicable law or agreed to in writing, software
|
12 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
* See the License for the specific language governing permissions and
|
15 |
+
* limitations under the License.
|
16 |
+
*/
|
17 |
+
|
18 |
+
#include <torch/all.h>
|
19 |
+
|
20 |
+
#include <ATen/cuda/CUDAContext.h>
|
21 |
+
#include <c10/cuda/CUDAGuard.h>
|
22 |
+
#include <cuda.h>
|
23 |
+
#include <cuda_fp16.h>
|
24 |
+
#include <cuda_runtime.h>
|
25 |
+
|
26 |
+
#include <iostream>
|
27 |
+
|
28 |
+
#include "common/base.h"
|
29 |
+
|
30 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
31 |
+
#include "common/mem.h"
|
32 |
+
#endif
|
33 |
+
|
34 |
+
template <typename T>
|
35 |
+
inline std::string str(T x) {
|
36 |
+
return std::to_string(x);
|
37 |
+
}
|
38 |
+
|
39 |
+
namespace marlin_dense {
|
40 |
+
|
41 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
42 |
+
|
43 |
+
using I4 = Vec<int, 4>;
|
44 |
+
// Matrix fragments for tensor core instructions; their precise layout is
|
45 |
+
// documented here:
|
46 |
+
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
|
47 |
+
using FragA = Vec<half2, 4>;
|
48 |
+
using FragB = Vec<half2, 2>;
|
49 |
+
using FragC = Vec<float, 4>;
|
50 |
+
using FragS = Vec<half2, 1>; // quantization scales
|
51 |
+
|
52 |
+
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
|
53 |
+
// output/accumulation.
|
54 |
+
__device__ inline void mma(const FragA& a_frag, const FragB& frag_b,
|
55 |
+
FragC& frag_c) {
|
56 |
+
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
|
57 |
+
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
58 |
+
float* c = reinterpret_cast<float*>(&frag_c);
|
59 |
+
asm volatile(
|
60 |
+
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
61 |
+
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
62 |
+
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
63 |
+
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
64 |
+
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
65 |
+
}
|
66 |
+
|
67 |
+
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
|
68 |
+
// memory, directly in tensor core layout.
|
69 |
+
__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
|
70 |
+
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
|
71 |
+
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
72 |
+
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
|
73 |
+
: "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
|
74 |
+
: "r"(smem));
|
75 |
+
}
|
76 |
+
|
77 |
+
// Lookup-table based 3-input logical operation; explicitly used for
|
78 |
+
// dequantization as the compiler does not seem to automatically recognize it in
|
79 |
+
// all cases.
|
80 |
+
template <int lut>
|
81 |
+
__device__ inline int lop3(int a, int b, int c) {
|
82 |
+
int res;
|
83 |
+
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
84 |
+
: "=r"(res)
|
85 |
+
: "r"(a), "r"(b), "r"(c), "n"(lut));
|
86 |
+
return res;
|
87 |
+
}
|
88 |
+
|
89 |
+
// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
|
90 |
+
// values. We mostly follow the strategy in the link below, with some small
|
91 |
+
// changes:
|
92 |
+
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
|
93 |
+
__device__ inline FragB dequant(int q) {
|
94 |
+
const int LO = 0x000f000f;
|
95 |
+
const int HI = 0x00f000f0;
|
96 |
+
const int EX = 0x64006400;
|
97 |
+
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
98 |
+
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
99 |
+
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
100 |
+
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
|
101 |
+
// directly into `SUB` and `ADD`.
|
102 |
+
const int SUB = 0x64086408;
|
103 |
+
const int MUL = 0x2c002c00;
|
104 |
+
const int ADD = 0xd480d480;
|
105 |
+
FragB frag_b;
|
106 |
+
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
107 |
+
*reinterpret_cast<const half2*>(&SUB));
|
108 |
+
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
|
109 |
+
*reinterpret_cast<const half2*>(&MUL),
|
110 |
+
*reinterpret_cast<const half2*>(&ADD));
|
111 |
+
return frag_b;
|
112 |
+
}
|
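
The bit-level constants in dequant() can be checked outside the kernel. Per 16-bit lane: OR-ing a low nibble v into the mantissa of 0x6400 (1024.0 in fp16) gives 1024 + v, and subtracting 1032.0 (0x6408) recovers v - 8; for the high nibble, the same OR gives 1024 + 16*v, and multiplying by 0.0625 (0x2c00) then adding -72.0 (0xd480) again recovers v - 8. A small numpy check (illustration only, not part of the kernel):

import numpy as np

v = np.arange(16, dtype=np.uint16)  # every possible 4-bit code

# Low-nibble path: (q & LO) | EX, then __hsub2 with SUB.
lo_bits = ((v & 0x000F) | 0x6400).astype(np.uint16)
lo = lo_bits.view(np.float16) - np.float16(1032.0)

# High-nibble path: (q & HI) | EX, then __hfma2 with MUL and ADD.
hi_bits = (((v << 4) & 0x00F0) | 0x6400).astype(np.uint16)
hi = hi_bits.view(np.float16) * np.float16(0.0625) + np.float16(-72.0)

expected = (v.astype(np.int32) - 8).astype(np.float16)  # signed int4 values
assert np.array_equal(lo, expected)
assert np.array_equal(hi, expected)
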
113 |
+
|
114 |
+
// Multiply dequantized values by the corresponding quantization scale; used
|
115 |
+
// only for grouped quantization.
|
116 |
+
__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
|
117 |
+
half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
|
118 |
+
frag_b[0] = __hmul2(frag_b[0], s);
|
119 |
+
frag_b[1] = __hmul2(frag_b[1], s);
|
120 |
+
}
|
121 |
+
|
122 |
+
template <const int threads, // number of threads in a threadblock
|
123 |
+
const int thread_m_blocks, // number of 16x16 blocks in the m
|
124 |
+
// dimension (batchsize) of the
|
125 |
+
// threadblock
|
126 |
+
const int thread_n_blocks, // same for n dimension (output)
|
127 |
+
const int thread_k_blocks, // same for k dimension (reduction)
|
128 |
+
const int stages, // number of stages for the async global->shared
|
129 |
+
// fetch pipeline
|
130 |
+
const int group_blocks = -1 // number of consecutive 16x16 blocks
|
131 |
+
// with a separate quantization scale
|
132 |
+
>
|
133 |
+
__global__ void Marlin(
|
134 |
+
const int4* __restrict__ A, // fp16 input matrix of shape mxk
|
135 |
+
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
|
136 |
+
int4* __restrict__ C, // fp16 output buffer of shape mxn
|
137 |
+
const int4* __restrict__ s, // fp16 quantization scales of shape
|
138 |
+
// (k/groupsize)xn
|
139 |
+
int prob_m, // batch dimension m
|
140 |
+
int prob_n, // output dimension n
|
141 |
+
int prob_k, // reduction dimension k
|
142 |
+
int* locks // extra global storage for barrier synchronization
|
143 |
+
) {
|
144 |
+
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
|
145 |
+
// same size, which might involve multiple column "slices" (of width 16 *
|
146 |
+
// `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
|
147 |
+
// example:
|
148 |
+
// 0 1 3
|
149 |
+
// 0 2 3
|
150 |
+
// 1 2 4
|
151 |
+
// While this kind of partitioning makes things somewhat more complicated, it
|
152 |
+
// ensures good utilization of all SMs for many kinds of shape and GPU
|
153 |
+
// configurations, while requiring as few slow global cross-threadblock
|
154 |
+
// reductions as possible.
|
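
As a side illustration (Python, not part of the kernel source), the ownership matrix in the comment above follows from giving each block ceildiv(k_tiles * n_tiles, num_blocks) consecutive tiles in column-major order:

import math

k_tiles, n_tiles, num_blocks = 3, 3, 5
iters = math.ceil(k_tiles * n_tiles / num_blocks)  # tiles per block (here 2)

owner = [[None] * n_tiles for _ in range(k_tiles)]
for block in range(num_blocks):
    for t in range(iters * block, min(iters * (block + 1), k_tiles * n_tiles)):
        owner[t % k_tiles][t // k_tiles] = block  # (slice_row, slice_col)

for row in owner:
    print(row)
# [0, 1, 3]
# [0, 2, 3]
# [1, 2, 4]
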
155 |
+
|
156 |
+
// For larger GEMMs we run multiple batchsize 64 versions in parallel for a
|
157 |
+
// better partitioning with less reductions
|
158 |
+
int parallel = 1;
|
159 |
+
if (prob_m > 16 * thread_m_blocks) {
|
160 |
+
parallel = prob_m / (16 * thread_m_blocks);
|
161 |
+
prob_m = 16 * thread_m_blocks;
|
162 |
+
}
|
163 |
+
|
164 |
+
int k_tiles = prob_k / 16 / thread_k_blocks;
|
165 |
+
int n_tiles = prob_n / 16 / thread_n_blocks;
|
166 |
+
int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x);
|
167 |
+
// Ensure that the number of tiles in each stripe is a multiple of the
|
168 |
+
// groupsize; this avoids an annoying special case where a stripe starts in
|
169 |
+
// the middle of group.
|
170 |
+
if (group_blocks != -1)
|
171 |
+
iters = (group_blocks / thread_k_blocks) *
|
172 |
+
ceildiv(iters, (group_blocks / thread_k_blocks));
|
173 |
+
|
174 |
+
int slice_row = (iters * blockIdx.x) % k_tiles;
|
175 |
+
int slice_col_par = (iters * blockIdx.x) / k_tiles;
|
176 |
+
int slice_col = slice_col_par;
|
177 |
+
int slice_iters; // number of threadblock tiles in the current slice
|
178 |
+
int slice_count =
|
179 |
+
0; // total number of active threadblocks in the current slice
|
180 |
+
int slice_idx; // index of threadblock in current slice; numbered bottom to
|
181 |
+
// top
|
182 |
+
|
183 |
+
// We can easily implement parallel problem execution by just remapping
|
184 |
+
// indices and advancing global pointers
|
185 |
+
if (slice_col_par >= n_tiles) {
|
186 |
+
A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;
|
187 |
+
C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
|
188 |
+
locks += (slice_col_par / n_tiles) * n_tiles;
|
189 |
+
slice_col = slice_col_par % n_tiles;
|
190 |
+
}
|
191 |
+
|
192 |
+
// Compute all information about the current slice which is required for
|
193 |
+
// synchronization.
|
194 |
+
auto init_slice = [&]() {
|
195 |
+
slice_iters =
|
196 |
+
iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
|
197 |
+
if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
|
198 |
+
if (slice_iters == 0) return;
|
199 |
+
if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
|
200 |
+
slice_count = 1;
|
201 |
+
slice_idx = 0;
|
202 |
+
int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
|
203 |
+
if (col_first <= k_tiles * (slice_col_par + 1)) {
|
204 |
+
int col_off = col_first - k_tiles * slice_col_par;
|
205 |
+
slice_count = ceildiv(k_tiles - col_off, iters);
|
206 |
+
if (col_off > 0) slice_count++;
|
207 |
+
int delta_first = iters * blockIdx.x - col_first;
|
208 |
+
if (delta_first < 0 || (col_off == 0 && delta_first == 0))
|
209 |
+
slice_idx = slice_count - 1;
|
210 |
+
else {
|
211 |
+
slice_idx = slice_count - 1 - delta_first / iters;
|
212 |
+
if (col_off > 0) slice_idx--;
|
213 |
+
}
|
214 |
+
}
|
215 |
+
if (slice_col == n_tiles) {
|
216 |
+
A += 16 * thread_m_blocks * prob_k / 8;
|
217 |
+
C += 16 * thread_m_blocks * prob_n / 8;
|
218 |
+
locks += n_tiles;
|
219 |
+
slice_col = 0;
|
220 |
+
}
|
221 |
+
};
|
222 |
+
init_slice();
|
223 |
+
|
224 |
+
int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory
|
225 |
+
// We typically use `constexpr` to indicate that this value is a compile-time
|
226 |
+
// constant
|
227 |
+
constexpr int a_sh_stride =
|
228 |
+
16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory
|
229 |
+
constexpr int a_gl_rd_delta_o =
|
230 |
+
16 * thread_k_blocks /
|
231 |
+
8; // delta between subsequent A tiles in global memory
|
232 |
+
int a_gl_rd_delta_i =
|
233 |
+
a_gl_stride *
|
234 |
+
(threads / a_gl_rd_delta_o); // between subsequent accesses within a tile
|
235 |
+
constexpr int a_sh_wr_delta =
|
236 |
+
a_sh_stride *
|
237 |
+
(threads / a_gl_rd_delta_o); // between shared memory writes
|
238 |
+
constexpr int a_sh_rd_delta_o =
|
239 |
+
2 * ((threads / 32) /
|
240 |
+
(thread_n_blocks / 4)); // between shared memory tile reads
|
241 |
+
constexpr int a_sh_rd_delta_i =
|
242 |
+
a_sh_stride * 16; // within a shared memory tile
|
243 |
+
constexpr int a_sh_stage =
|
244 |
+
a_sh_stride * (16 * thread_m_blocks); // overall size of a tile
|
245 |
+
constexpr int a_sh_wr_iters =
|
246 |
+
ceildiv(a_sh_stage,
|
247 |
+
a_sh_wr_delta); // number of shared write iterations for a tile
|
248 |
+
|
249 |
+
int b_gl_stride = 16 * prob_n / 32;
|
250 |
+
constexpr int b_sh_stride = 32 * thread_n_blocks / 4;
|
251 |
+
int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
|
252 |
+
int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride);
|
253 |
+
constexpr int b_sh_wr_delta = threads;
|
254 |
+
constexpr int b_sh_rd_delta = threads;
|
255 |
+
constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
|
256 |
+
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
|
257 |
+
|
258 |
+
int s_gl_stride = prob_n / 8;
|
259 |
+
constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
|
260 |
+
constexpr int s_sh_stage = s_sh_stride;
|
261 |
+
int s_gl_rd_delta = s_gl_stride;
|
262 |
+
|
263 |
+
// Global A read index of current thread.
|
264 |
+
int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
|
265 |
+
(threadIdx.x % a_gl_rd_delta_o);
|
266 |
+
a_gl_rd += a_gl_rd_delta_o * slice_row;
|
267 |
+
// Shared write index of current thread.
|
268 |
+
int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
|
269 |
+
(threadIdx.x % a_gl_rd_delta_o);
|
270 |
+
// Shared read index.
|
271 |
+
int a_sh_rd =
|
272 |
+
a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
|
273 |
+
a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
|
274 |
+
|
275 |
+
int b_gl_rd =
|
276 |
+
b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
|
277 |
+
b_gl_rd += b_sh_stride * slice_col;
|
278 |
+
b_gl_rd += b_gl_rd_delta_o * slice_row;
|
279 |
+
int b_sh_wr = threadIdx.x;
|
280 |
+
int b_sh_rd = threadIdx.x;
|
281 |
+
|
282 |
+
int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
|
283 |
+
s_sh_stride * slice_col + threadIdx.x;
|
284 |
+
int s_sh_wr = threadIdx.x;
|
285 |
+
int s_sh_rd;
|
286 |
+
// We use a different scale layout for grouped and column-wise quantization as
|
287 |
+
// we scale a `half2` tile in column-major layout in the former and in
|
288 |
+
// row-major in the latter case.
|
289 |
+
if (group_blocks != -1)
|
290 |
+
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
291 |
+
(threadIdx.x % 32) / 4;
|
292 |
+
else
|
293 |
+
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
294 |
+
(threadIdx.x % 32) % 4;
|
295 |
+
|
296 |
+
// Precompute which thread should not read memory in which iterations; this is
|
297 |
+
// needed if there are more threads than required for a certain tilesize or
|
298 |
+
// when the batchsize is not a multiple of 16.
|
299 |
+
bool a_sh_wr_pred[a_sh_wr_iters];
|
300 |
+
#pragma unroll
|
301 |
+
for (int i = 0; i < a_sh_wr_iters; i++)
|
302 |
+
a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
|
303 |
+
bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
|
304 |
+
|
305 |
+
// To ensure that writing and reading A tiles to/from shared memory, the
|
306 |
+
// latter in fragment format, is fully bank conflict free, we need to use a
|
307 |
+
// rather fancy XOR-based layout. The key here is that neither reads nor
|
308 |
+
// writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
|
309 |
+
// same shared memory banks. Further, it seems (based on NSight-Compute) that
|
310 |
+
// each warp must also write a consecutive memory segment?
|
311 |
+
auto transform_a = [&](int i) {
|
312 |
+
int row = i / a_gl_rd_delta_o;
|
313 |
+
return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
|
314 |
+
};
|
315 |
+
// Since the computation of this remapping is non-trivial and, due to our main
|
316 |
+
// loop unrolls, all shared memory accesses are static, we simply precompute
|
317 |
+
// both transformed reads and writes.
|
318 |
+
int a_sh_wr_trans[a_sh_wr_iters];
|
319 |
+
#pragma unroll
|
320 |
+
for (int i = 0; i < a_sh_wr_iters; i++)
|
321 |
+
a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
|
322 |
+
int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
|
323 |
+
#pragma unroll
|
324 |
+
for (int i = 0; i < b_sh_wr_iters; i++) {
|
325 |
+
#pragma unroll
|
326 |
+
for (int j = 0; j < thread_m_blocks; j++)
|
327 |
+
a_sh_rd_trans[i][j] =
|
328 |
+
transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
|
329 |
+
}
|
330 |
+
|
331 |
+
// Since B-accesses have non-constant stride they have to be computed at
|
332 |
+
// runtime; we break dependencies between subsequent accesses within a tile by
|
333 |
+
// maintaining multiple pointers (we have enough registers), a tiny
|
334 |
+
// optimization.
|
335 |
+
const int4* B_ptr[b_sh_wr_iters];
|
336 |
+
#pragma unroll
|
337 |
+
for (int i = 0; i < b_sh_wr_iters; i++)
|
338 |
+
B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
|
339 |
+
|
340 |
+
extern __shared__ int4 sh[];
|
341 |
+
// Shared memory storage for global fetch pipelines.
|
342 |
+
int4* sh_a = sh;
|
343 |
+
int4* sh_b = sh_a + (stages * a_sh_stage);
|
344 |
+
int4* sh_s = sh_b + (stages * b_sh_stage);
|
345 |
+
// Register storage for double buffer of shared memory reads.
|
346 |
+
FragA frag_a[2][thread_m_blocks];
|
347 |
+
I4 frag_b_quant[2];
|
348 |
+
FragC frag_c[thread_m_blocks][4][2];
|
349 |
+
FragS frag_s[2][4];
|
350 |
+
|
351 |
+
// Zero accumulators.
|
352 |
+
auto zero_accums = [&]() {
|
353 |
+
#pragma unroll
|
354 |
+
for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
|
355 |
+
reinterpret_cast<float*>(frag_c)[i] = 0;
|
356 |
+
};
|
357 |
+
|
358 |
+
// Asynchronously fetch the next A, B and s tile from global to the next
|
359 |
+
// shared memory pipeline location.
|
360 |
+
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
|
361 |
+
if (pred) {
|
362 |
+
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
|
363 |
+
#pragma unroll
|
364 |
+
for (int i = 0; i < a_sh_wr_iters; i++) {
|
365 |
+
cp_async4_pred(
|
366 |
+
&sh_a_stage[a_sh_wr_trans[i]],
|
367 |
+
&A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
|
368 |
+
a_sh_wr_pred[i]);
|
369 |
+
}
|
370 |
+
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
|
371 |
+
#pragma unroll
|
372 |
+
for (int i = 0; i < b_sh_wr_iters; i++) {
|
373 |
+
cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
|
374 |
+
B_ptr[i] += b_gl_rd_delta_o;
|
375 |
+
}
|
376 |
+
// Only fetch scales if this tile starts a new group
|
377 |
+
if constexpr (group_blocks != -1) {
|
378 |
+
// This assumes group_blocks >= thread_k_blocks
|
379 |
+
// and would need to be modified to support smaller groups.
|
380 |
+
static_assert(group_blocks >= thread_k_blocks);
|
381 |
+
if (pipe % (group_blocks / thread_k_blocks) == 0) {
|
382 |
+
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
|
383 |
+
if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
|
384 |
+
s_gl_rd += s_gl_rd_delta;
|
385 |
+
}
|
386 |
+
}
|
387 |
+
}
|
388 |
+
// Insert a fence even when we are winding down the pipeline to ensure that
|
389 |
+
// waiting is also correct at this point.
|
390 |
+
cp_async_fence();
|
391 |
+
};
|
392 |
+
|
393 |
+
// Wait until the next thread tile has been loaded to shared memory.
|
394 |
+
auto wait_for_stage = [&]() {
|
395 |
+
// We only have `stages - 2` active fetches since we are double buffering
|
396 |
+
// and can only issue the next fetch when it is guaranteed that the previous
|
397 |
+
// shared memory load is fully complete (as it may otherwise be
|
398 |
+
// overwritten).
|
399 |
+
cp_async_wait<stages - 2>();
|
400 |
+
__syncthreads();
|
401 |
+
};
|
402 |
+
|
403 |
+
// Load the next sub-tile from the current location in the shared memory pipe
|
404 |
+
// into the current register buffer.
|
405 |
+
auto fetch_to_registers = [&](int k, int pipe) {
|
406 |
+
// It may seem inefficient that we reload the groups for every sub-tile;
|
407 |
+
// however, this does not seem to be a significant bottleneck, while some
|
408 |
+
// theoretically better attempts have lead to bad instruction ordering by
|
409 |
+
// the compiler and correspondingly a noticeable drop in performance.
|
410 |
+
if constexpr (group_blocks != -1) {
|
411 |
+
// This assumes group_blocks >= thread_k_blocks
|
412 |
+
// and would need to be modified to support smaller groups.
|
413 |
+
static_assert(group_blocks >= thread_k_blocks);
|
414 |
+
int4* sh_s_stage =
|
415 |
+
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
|
416 |
+
(pipe / (group_blocks / thread_k_blocks)));
|
417 |
+
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
|
418 |
+
}
|
419 |
+
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
|
420 |
+
#pragma unroll
|
421 |
+
for (int i = 0; i < thread_m_blocks; i++)
|
422 |
+
ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
|
423 |
+
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
|
424 |
+
frag_b_quant[k % 2] = *reinterpret_cast<I4*>(
|
425 |
+
&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]);
|
426 |
+
};
|
427 |
+
|
428 |
+
// Execute the actual tensor core matmul of a sub-tile.
|
429 |
+
auto matmul = [&](int k) {
|
430 |
+
// We have the m dimension as the inner loop in order to encourage overlapping
|
431 |
+
// dequantization and matmul operations.
|
432 |
+
#pragma unroll
|
433 |
+
for (int j = 0; j < 4; j++) {
|
434 |
+
int b_quant = frag_b_quant[k % 2][j];
|
435 |
+
int b_quant_shift = b_quant >> 8;
|
436 |
+
FragB frag_b0 = dequant(b_quant);
|
437 |
+
// If there are no groups, we can just scale the final output once and can
|
438 |
+
// avoid doing so for each weight.
|
439 |
+
if (group_blocks != -1) scale(frag_b0, frag_s[k % 2][j], 0);
|
440 |
+
FragB frag_b1 = dequant(b_quant_shift);
|
441 |
+
if (group_blocks != -1) scale(frag_b1, frag_s[k % 2][j], 1);
|
442 |
+
#pragma unroll
|
443 |
+
for (int i = 0; i < thread_m_blocks; i++) {
|
444 |
+
mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
|
445 |
+
mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
|
446 |
+
}
|
447 |
+
}
|
448 |
+
};
|
449 |
+
|
450 |
+
// Since we slice across the k dimension of a tile in order to increase the
|
451 |
+
// number of warps while keeping the n dimension of a tile reasonable, we have
|
452 |
+
// multiple warps that accumulate their partial sums of the same output
|
453 |
+
// location, which we have to reduce over in the end. We do this in shared memory.
|
454 |
+
auto thread_block_reduce = [&]() {
|
455 |
+
constexpr int red_off = threads / b_sh_stride / 2;
|
456 |
+
if (red_off >= 1) {
|
457 |
+
int red_idx = threadIdx.x / b_sh_stride;
|
458 |
+
constexpr int red_sh_stride = b_sh_stride * 4 * 2;
|
459 |
+
constexpr int red_sh_delta = b_sh_stride;
|
460 |
+
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) +
|
461 |
+
(threadIdx.x % b_sh_stride);
|
462 |
+
|
463 |
+
// Parallel logarithmic shared memory reduction. We make sure to avoid any
|
464 |
+
// unnecessary read or write iterations, e.g., for two warps we write only
|
465 |
+
// once by warp 1 and read only once by warp 0.
|
466 |
+
|
467 |
+
#pragma unroll
|
468 |
+
for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
|
469 |
+
#pragma unroll
|
470 |
+
for (int i = red_off; i > 0; i /= 2) {
|
471 |
+
if (i <= red_idx && red_idx < 2 * i) {
|
472 |
+
#pragma unroll
|
473 |
+
for (int j = 0; j < 4 * 2; j++) {
|
474 |
+
int red_sh_wr =
|
475 |
+
red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
|
476 |
+
if (i < red_off) {
|
477 |
+
float* c_rd =
|
478 |
+
reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
|
479 |
+
float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
|
480 |
+
#pragma unroll
|
481 |
+
for (int k = 0; k < 4; k++)
|
482 |
+
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
|
483 |
+
c_rd[k] + c_wr[k];
|
484 |
+
}
|
485 |
+
sh[red_sh_wr] =
|
486 |
+
reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
|
487 |
+
}
|
488 |
+
}
|
489 |
+
__syncthreads();
|
490 |
+
}
|
491 |
+
if (red_idx == 0) {
|
492 |
+
#pragma unroll
|
493 |
+
for (int i = 0; i < 4 * 2; i++) {
|
494 |
+
float* c_rd =
|
495 |
+
reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
|
496 |
+
#pragma unroll
|
497 |
+
for (int j = 0; j < 4; j++)
|
498 |
+
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
|
499 |
+
c_rd[j];
|
500 |
+
}
|
501 |
+
}
|
502 |
+
__syncthreads();
|
503 |
+
}
|
504 |
+
}
|
505 |
+
};
|
506 |
+
|
507 |
+
// Since multiple threadblocks may process parts of the same column slice, we
|
508 |
+
// finally have to globally reduce over the results. As the striped
|
509 |
+
// partitioning minimizes the number of such reductions and our outputs are
|
510 |
+
// usually rather small, we perform this reduction serially in L2 cache.
|
511 |
+
auto global_reduce = [&](bool first = false, bool last = false) {
|
512 |
+
// We are very careful here to reduce directly in the output buffer to
|
513 |
+
// maximize L2 cache utilization in this step. To do this, we write out
|
514 |
+
// results in FP16 (but still reduce with FP32 compute).
|
515 |
+
constexpr int active_threads = 32 * thread_n_blocks / 4;
|
516 |
+
if (threadIdx.x < active_threads) {
|
517 |
+
int c_gl_stride = prob_n / 8;
|
518 |
+
int c_gl_wr_delta_o = 8 * c_gl_stride;
|
519 |
+
int c_gl_wr_delta_i = 4 * (active_threads / 32);
|
520 |
+
int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +
|
521 |
+
4 * (threadIdx.x / 32) + threadIdx.x % 4;
|
522 |
+
c_gl_wr += (2 * thread_n_blocks) * slice_col;
|
523 |
+
constexpr int c_sh_wr_delta = active_threads;
|
524 |
+
int c_sh_wr = threadIdx.x;
|
525 |
+
|
526 |
+
int row = (threadIdx.x % 32) / 4;
|
527 |
+
|
528 |
+
if (!first) {
|
529 |
+
// Interestingly, doing direct global accesses here really seems to mess up
|
530 |
+
// the compiler and lead to slowdowns, hence we also use async-copies even
|
531 |
+
// though these fetches are not actually asynchronous.
|
532 |
+
#pragma unroll
|
533 |
+
for (int i = 0; i < thread_m_blocks * 4; i++) {
|
534 |
+
cp_async4_pred(
|
535 |
+
&sh[c_sh_wr + c_sh_wr_delta * i],
|
536 |
+
&C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
|
537 |
+
c_gl_wr_delta_i * (i % 2)],
|
538 |
+
i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
|
539 |
+
}
|
540 |
+
cp_async_fence();
|
541 |
+
cp_async_wait<0>();
|
542 |
+
}
|
543 |
+
|
544 |
+
#pragma unroll
|
545 |
+
for (int i = 0; i < thread_m_blocks * 4; i++) {
|
546 |
+
if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
|
547 |
+
if (!first) {
|
548 |
+
int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
|
549 |
+
#pragma unroll
|
550 |
+
for (int j = 0; j < 2 * 4; j++) {
|
551 |
+
reinterpret_cast<float*>(
|
552 |
+
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
|
553 |
+
__half2float(reinterpret_cast<__half*>(&c_red)[j]);
|
554 |
+
}
|
555 |
+
}
|
556 |
+
if (!last) {
|
557 |
+
int4 c;
|
558 |
+
#pragma unroll
|
559 |
+
for (int j = 0; j < 2 * 4; j++) {
|
560 |
+
reinterpret_cast<__half*>(&c)[j] =
|
561 |
+
__float2half(reinterpret_cast<float*>(
|
562 |
+
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);
|
563 |
+
}
|
564 |
+
C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =
|
565 |
+
c;
|
566 |
+
}
|
567 |
+
}
|
568 |
+
}
|
569 |
+
}
|
570 |
+
};
|
571 |
+
|
572 |
+
// Write out the reduced final result in the correct layout. We only actually
|
573 |
+
// reshuffle matrix fragments in this step, the reduction above is performed
|
574 |
+
// in fragment layout.
|
575 |
+
auto write_result = [&]() {
|
576 |
+
int c_gl_stride = prob_n / 8;
|
577 |
+
constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
|
578 |
+
int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
|
579 |
+
constexpr int c_sh_rd_delta =
|
580 |
+
c_sh_stride * (threads / (2 * thread_n_blocks));
|
581 |
+
|
582 |
+
int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +
|
583 |
+
(threadIdx.x % (2 * thread_n_blocks));
|
584 |
+
c_gl_wr += (2 * thread_n_blocks) * slice_col;
|
585 |
+
int c_sh_wr =
|
586 |
+
(4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
|
587 |
+
c_sh_wr += 32 * (threadIdx.x / 32);
|
588 |
+
int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +
|
589 |
+
(threadIdx.x % (2 * thread_n_blocks));
|
590 |
+
|
591 |
+
int c_gl_wr_end = c_gl_stride * prob_m;
|
592 |
+
|
593 |
+
// We first reorder in shared memory to guarantee the most efficient final
|
594 |
+
// global write patterns
|
595 |
+
auto write = [&](int idx, float c0, float c1, FragS& s) {
|
596 |
+
half2 res = __halves2half2(__float2half(c0), __float2half(c1));
|
597 |
+
if (group_blocks ==
|
598 |
+
-1) // for per-column quantization we finally apply the scale here
|
599 |
+
res = __hmul2(res, s[0]);
|
600 |
+
((half2*)sh)[idx] = res;
|
601 |
+
};
|
602 |
+
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
603 |
+
#pragma unroll
|
604 |
+
for (int i = 0; i < thread_m_blocks; i++) {
|
605 |
+
#pragma unroll
|
606 |
+
for (int j = 0; j < 4; j++) {
|
607 |
+
int wr = c_sh_wr + 8 * j;
|
608 |
+
write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
|
609 |
+
frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
|
610 |
+
write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
|
611 |
+
frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
|
612 |
+
write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
|
613 |
+
frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
|
614 |
+
write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
|
615 |
+
frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
|
616 |
+
}
|
617 |
+
c_sh_wr += 16 * (4 * c_sh_stride);
|
618 |
+
}
|
619 |
+
}
|
620 |
+
__syncthreads();
|
621 |
+
|
622 |
+
#pragma unroll
|
623 |
+
for (int i = 0;
|
624 |
+
i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
|
625 |
+
i++) {
|
626 |
+
if (c_gl_wr < c_gl_wr_end) {
|
627 |
+
C[c_gl_wr] = sh[c_sh_rd];
|
628 |
+
c_gl_wr += c_gl_wr_delta;
|
629 |
+
c_sh_rd += c_sh_rd_delta;
|
630 |
+
}
|
631 |
+
}
|
632 |
+
};
|
633 |
+
|
634 |
+
// Start global fetch and register load pipelines.
|
635 |
+
auto start_pipes = [&]() {
|
636 |
+
#pragma unroll
|
637 |
+
for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters);
|
638 |
+
zero_accums();
|
639 |
+
wait_for_stage();
|
640 |
+
fetch_to_registers(0, 0);
|
641 |
+
a_gl_rd += a_gl_rd_delta_o * (stages - 1);
|
642 |
+
};
|
643 |
+
start_pipes();
|
644 |
+
|
645 |
+
// Main loop.
|
646 |
+
while (slice_iters) {
|
647 |
+
// We unroll over both the global fetch and the register load pipeline to
|
648 |
+
// ensure all shared memory accesses are static. Note that both pipelines have
|
649 |
+
// even length meaning that the next iteration will always start at index 0.
|
650 |
+
#pragma unroll
|
651 |
+
for (int pipe = 0; pipe < stages;) {
|
652 |
+
#pragma unroll
|
653 |
+
for (int k = 0; k < b_sh_wr_iters; k++) {
|
654 |
+
fetch_to_registers(k + 1, pipe % stages);
|
655 |
+
if (k == b_sh_wr_iters - 2) {
|
656 |
+
fetch_to_shared((pipe + stages - 1) % stages, pipe,
|
657 |
+
slice_iters >= stages);
|
658 |
+
pipe++;
|
659 |
+
wait_for_stage();
|
660 |
+
}
|
661 |
+
matmul(k);
|
662 |
+
}
|
663 |
+
slice_iters--;
|
664 |
+
if (slice_iters == 0) break;
|
665 |
+
}
|
666 |
+
a_gl_rd += a_gl_rd_delta_o * stages;
|
667 |
+
|
668 |
+
// Process results and, if necessary, proceed to the next column slice.
|
669 |
+
// While this pattern may not be the most readable, other ways of writing
|
670 |
+
// the loop seemed to give noticeably worse performance after compilation.
|
671 |
+
if (slice_iters == 0) {
|
672 |
+
cp_async_wait<0>();
|
673 |
+
bool last = slice_idx == slice_count - 1;
|
674 |
+
// For per-column scales, we only fetch them here in the final step before
|
675 |
+
// write-out
|
676 |
+
if (group_blocks == -1 && last) {
|
677 |
+
if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
|
678 |
+
cp_async_fence();
|
679 |
+
}
|
680 |
+
thread_block_reduce();
|
681 |
+
if (group_blocks == -1 && last) {
|
682 |
+
cp_async_wait<0>();
|
683 |
+
__syncthreads();
|
684 |
+
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
685 |
+
reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
|
686 |
+
reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
|
687 |
+
}
|
688 |
+
}
|
689 |
+
if (slice_count > 1) { // only globally reduce if there is more than one
|
690 |
+
// block in a slice
|
691 |
+
barrier_acquire(&locks[slice_col], slice_idx);
|
692 |
+
global_reduce(slice_idx == 0, last);
|
693 |
+
barrier_release(&locks[slice_col], last);
|
694 |
+
}
|
695 |
+
if (last) // only the last block in a slice actually writes the result
|
696 |
+
write_result();
|
697 |
+
slice_row = 0;
|
698 |
+
slice_col_par++;
|
699 |
+
slice_col++;
|
700 |
+
init_slice();
|
701 |
+
if (slice_iters) {
|
702 |
+
a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
|
703 |
+
(threadIdx.x % a_gl_rd_delta_o);
|
704 |
+
#pragma unroll
|
705 |
+
for (int i = 0; i < b_sh_wr_iters; i++)
|
706 |
+
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
|
707 |
+
if (slice_col == 0) {
|
708 |
+
#pragma unroll
|
709 |
+
for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
|
710 |
+
}
|
711 |
+
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
|
712 |
+
start_pipes();
|
713 |
+
}
|
714 |
+
}
|
715 |
+
}
|
716 |
+
}
|
717 |
+
|
718 |
+
#else
|
719 |
+
|
720 |
+
template <const int threads, // number of threads in a threadblock
|
721 |
+
const int thread_m_blocks, // number of 16x16 blocks in the m
|
722 |
+
// dimension (batchsize) of the
|
723 |
+
// threadblock
|
724 |
+
const int thread_n_blocks, // same for n dimension (output)
|
725 |
+
const int thread_k_blocks, // same for k dimension (reduction)
|
726 |
+
const int stages, // number of stages for the async global->shared
|
727 |
+
// fetch pipeline
|
728 |
+
const int group_blocks = -1 // number of consecutive 16x16 blocks
|
729 |
+
// with a separate quantization scale
|
730 |
+
>
|
731 |
+
__global__ void Marlin(
|
732 |
+
const int4* __restrict__ A, // fp16 input matrix of shape mxk
|
733 |
+
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
|
734 |
+
int4* __restrict__ C, // fp16 output buffer of shape mxn
|
735 |
+
const int4* __restrict__ s, // fp16 quantization scales of shape
|
736 |
+
// (k/groupsize)xn
|
737 |
+
int prob_m, // batch dimension m
|
738 |
+
int prob_n, // output dimension n
|
739 |
+
int prob_k, // reduction dimension k
|
740 |
+
int* locks // extra global storage for barrier synchronization
|
741 |
+
) {
|
742 |
+
// Marlin is not implemented yet for SM < 8.0
|
743 |
+
assert(false);
|
744 |
+
return;
|
745 |
+
}
|
746 |
+
|
747 |
+
#endif
|
748 |
+
|
749 |
+
// 8 warps are a good choice since every SM has 4 schedulers and having more
|
750 |
+
// than 1 warp per scheduler allows some more latency hiding. At the same time,
|
751 |
+
// we want relatively few warps to have many registers per warp and small tiles.
|
752 |
+
const int USER_THREADS =
|
753 |
+
256; // Note: This is only used with user-provided thread_k/n
|
754 |
+
const int STAGES = 4; // 4 pipeline stages fit into shared memory
|
755 |
+
const int SHARED_MEM =
|
756 |
+
96 * 1024;  // max shared memory on compute capability 8.6 (less than on 8.0)
|
757 |
+
|
758 |
+
static constexpr int min_thread_n = 64;
|
759 |
+
static constexpr int min_thread_k = 64;
|
760 |
+
|
761 |
+
static constexpr int tile_size = 16;
|
762 |
+
static constexpr int max_par = 16;
|
763 |
+
|
764 |
+
static constexpr int pack_factor_4bit =
|
765 |
+
8; // We have 8 4-bit vals inside a 32 bit
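// Illustrative sketch, not part of the upstream kernel: pack_factor_4bit = 8
// means eight 4-bit values share one 32-bit word, so the j-th value is
// recovered with a shift and mask.
inline unsigned int unpack_4bit_example(unsigned int packed, int j) {
  return (packed >> (4 * j)) & 0xFu;  // j in [0, 8)
}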
|
766 |
+
|
767 |
+
#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
768 |
+
GROUP_BLOCKS, NUM_THREADS) \
|
769 |
+
else if (thread_m_blocks == THREAD_M_BLOCKS && \
|
770 |
+
thread_n_blocks == THREAD_N_BLOCKS && \
|
771 |
+
thread_k_blocks == THREAD_K_BLOCKS && \
|
772 |
+
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
|
773 |
+
cudaFuncSetAttribute(Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
|
774 |
+
THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \
|
775 |
+
cudaFuncAttributeMaxDynamicSharedMemorySize, \
|
776 |
+
SHARED_MEM); \
|
777 |
+
Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
778 |
+
STAGES, GROUP_BLOCKS><<<blocks, NUM_THREADS, SHARED_MEM, stream>>>( \
|
779 |
+
A_ptr, B_ptr, C_ptr, s_ptr, prob_m, prob_n, prob_k, locks); \
|
780 |
+
}
|
781 |
+
|
782 |
+
typedef struct {
|
783 |
+
int thread_k;
|
784 |
+
int thread_n;
|
785 |
+
int num_threads;
|
786 |
+
} thread_config_t;
|
787 |
+
|
788 |
+
thread_config_t small_batch_thread_configs[] = {
|
789 |
+
// Ordered by priority
|
790 |
+
|
791 |
+
// thread_k, thread_n, num_threads
|
792 |
+
{128, 128, 256}, // Default
|
793 |
+
{128, 64, 128}, // Reduce N 2X, same K
|
794 |
+
{64, 256, 256}, // Reduce K 2X, increase N 2X
|
795 |
+
{64, 128, 128}, // Reduce K 2X, same N
|
796 |
+
};
|
797 |
+
|
798 |
+
thread_config_t large_batch_thread_configs[] = {
|
799 |
+
// Ordered by priority
|
800 |
+
|
801 |
+
// thread_k, thread_n, num_threads
|
802 |
+
{64, 256, 256}, // Default
|
803 |
+
{128, 128, 256}, // Reduce N 2X, increase K 2X
|
804 |
+
{64, 128, 128}, // Reduce N 2X, same K
|
805 |
+
{128, 64, 128}, // Reduce N 4X, increase K 2X
|
806 |
+
};
|
807 |
+
|
808 |
+
bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n,
|
809 |
+
int prob_k) {
|
810 |
+
// Sanity
|
811 |
+
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
|
812 |
+
th_config.num_threads == -1) {
|
813 |
+
return false;
|
814 |
+
}
|
815 |
+
|
816 |
+
// Verify K/N are divisible by thread K/N
|
817 |
+
if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
|
818 |
+
return false;
|
819 |
+
}
|
820 |
+
|
821 |
+
// thread_k can only be 128 or 64 (because it must not exceed the groupsize,
|
822 |
+
// which is 128)
|
823 |
+
if (th_config.thread_k != 128 && th_config.thread_k != 64) {
|
824 |
+
return false;
|
825 |
+
}
|
826 |
+
|
827 |
+
// Verify min for thread K/N
|
828 |
+
if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
|
829 |
+
return false;
|
830 |
+
}
|
831 |
+
|
832 |
+
// num_threads must be at least 128 (= 4 warps)
|
833 |
+
if (th_config.num_threads < 128) {
|
834 |
+
return false;
|
835 |
+
}
|
836 |
+
|
837 |
+
return true;
|
838 |
+
}
|
839 |
+
|
840 |
+
thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) {
|
841 |
+
if (prob_m <= 16) {
|
842 |
+
for (auto th_config : small_batch_thread_configs) {
|
843 |
+
if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
|
844 |
+
return th_config;
|
845 |
+
}
|
846 |
+
}
|
847 |
+
|
848 |
+
} else {
|
849 |
+
for (auto th_config : large_batch_thread_configs) {
|
850 |
+
if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
|
851 |
+
return th_config;
|
852 |
+
}
|
853 |
+
}
|
854 |
+
}
|
855 |
+
|
856 |
+
return thread_config_t{-1, -1, -1};
|
857 |
+
}
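// Illustrative sketch with hypothetical shapes (not part of the upstream
// code): how the priority lists above resolve.
inline void example_determine_thread_config() {
  // Small batch, prob_m = 8, prob_n = 4096, prob_k = 4096:
  // {128, 128, 256} is the first valid entry (4096 is divisible by both
  // thread_k = 128 and thread_n = 128), so it is chosen.
  thread_config_t small = determine_thread_config(8, 4096, 4096);
  // Large batch, prob_m = 256, prob_n = 4096, prob_k = 4096:
  // {64, 256, 256} is the first valid entry, so it is chosen.
  thread_config_t large = determine_thread_config(256, 4096, 4096);
  (void)small;
  (void)large;
}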
|
858 |
+
|
859 |
+
#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
860 |
+
__CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
|
861 |
+
__CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
|
862 |
+
__CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
|
863 |
+
__CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
|
864 |
+
__CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
|
865 |
+
__CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
|
866 |
+
__CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
|
867 |
+
__CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
|
868 |
+
__CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
|
869 |
+
__CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)
|
870 |
+
|
871 |
+
void marlin_cuda(const void* A, const void* B, void* C, void* s, int prob_m,
|
872 |
+
int prob_n, int prob_k, void* workspace, int groupsize = -1,
|
873 |
+
int dev = 0, cudaStream_t stream = 0, int thread_k = -1,
|
874 |
+
int thread_n = -1, int sms = -1, int max_par = 16) {
|
875 |
+
int tot_m = prob_m;
|
876 |
+
int tot_m_blocks = ceildiv(tot_m, 16);
|
877 |
+
int pad = 16 * tot_m_blocks - tot_m;
|
878 |
+
|
879 |
+
if (sms == -1)
|
880 |
+
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
|
881 |
+
|
882 |
+
// Set thread config
|
883 |
+
thread_config_t th_config;
|
884 |
+
if (thread_k != -1 && thread_n != -1) {
|
885 |
+
// User-defined config
|
886 |
+
th_config = thread_config_t{thread_k, thread_n, USER_THREADS};
|
887 |
+
} else {
|
888 |
+
// Auto config
|
889 |
+
th_config = determine_thread_config(prob_m, prob_n, prob_k);
|
890 |
+
}
|
891 |
+
|
892 |
+
if (!is_valid_config(th_config, prob_m, prob_n, prob_k)) {
|
893 |
+
throw std::runtime_error(
|
894 |
+
"Invalid thread config: thread_k = " + str(th_config.thread_k) +
|
895 |
+
", thread_n = " + str(th_config.thread_n) +
|
896 |
+
", num_threads = " + str(th_config.num_threads) + " for MKN = [" +
|
897 |
+
str(prob_m) + ", " + str(prob_k) + ", " + str(prob_n) + "]");
|
898 |
+
}
|
899 |
+
|
900 |
+
// Uncomment for debug
|
901 |
+
// std::cout << "Using thread_config: thread_k = " + str(th_config.thread_k) +
|
902 |
+
// ", thread_n = " + str(th_config.thread_n) +
|
903 |
+
// ", num_threads = " + str(th_config.num_threads) + " for
|
904 |
+
// MKN = [" + str(prob_m) +
|
905 |
+
// ", " + str(prob_k) + ", " + str(prob_n) + "]\n";
|
906 |
+
|
907 |
+
int num_threads = th_config.num_threads;
|
908 |
+
thread_k = th_config.thread_k;
|
909 |
+
thread_n = th_config.thread_n;
|
910 |
+
|
911 |
+
int thread_k_blocks = thread_k / 16;
|
912 |
+
int thread_n_blocks = thread_n / 16;
|
913 |
+
int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
|
914 |
+
int blocks = sms;
|
915 |
+
|
916 |
+
if (prob_m == 0 || prob_n == 0 || prob_k == 0) {
|
917 |
+
return;
|
918 |
+
}
|
919 |
+
|
920 |
+
TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
|
921 |
+
" is not divisible by thread_n = ", thread_n);
|
922 |
+
TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
|
923 |
+
" is not divisible by thread_k = ", thread_k);
|
924 |
+
if (group_blocks != -1) {
|
925 |
+
TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
|
926 |
+
" is not divisible by group_blocks = ", group_blocks);
|
927 |
+
}
|
928 |
+
|
929 |
+
const int4* A_ptr = (const int4*)A;
|
930 |
+
const int4* B_ptr = (const int4*)B;
|
931 |
+
int4* C_ptr = (int4*)C;
|
932 |
+
const int4* s_ptr = (const int4*)s;
|
933 |
+
|
934 |
+
int* locks = (int*)workspace;
|
935 |
+
|
936 |
+
for (int i = 0; i < tot_m_blocks; i += 4) {
|
937 |
+
int thread_m_blocks = tot_m_blocks - i;
|
938 |
+
prob_m = tot_m - 16 * i;
|
939 |
+
int par = 1;
|
940 |
+
if (thread_m_blocks > 4) {
|
941 |
+
// Note that parallel > 1 currently only works for inputs without any
|
942 |
+
// padding
|
943 |
+
par = (16 * thread_m_blocks - pad) / 64;
|
944 |
+
if (par > max_par) par = max_par;
|
945 |
+
prob_m = 64 * par;
|
946 |
+
i += 4 * (par - 1);
|
947 |
+
thread_m_blocks = 4;
|
948 |
+
}
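// Illustrative worked example (hypothetical numbers, not from the original
// source): for tot_m = 320 with no padding, tot_m_blocks = 20, so on the
// first iteration thread_m_blocks = 20 > 4 and
//   par = (16 * 20 - 0) / 64 = 5,  prob_m = 64 * 5 = 320,  i += 16.
// A single launch then covers the whole batch with thread_m_blocks = 4 and
// par = 5, i.e. five batchsize-64 sub-problems side by side.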
|
949 |
+
|
950 |
+
// For compilation speed, we only define the kernel configurations that have
|
951 |
+
// seemed useful (in terms of performance) in our testing; however, many more
|
952 |
+
// are, in principle, possible.
|
953 |
+
if (false) {
|
954 |
+
}
|
955 |
+
CALL_IF(8, 8, 256)
|
956 |
+
CALL_IF(16, 4, 256)
|
957 |
+
CALL_IF(8, 4, 128)
|
958 |
+
CALL_IF(4, 8, 128)
|
959 |
+
else {
|
960 |
+
throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) +
|
961 |
+
", " + str(prob_k) + ", " + str(prob_n) + "]" +
|
962 |
+
", groupsize = " + str(groupsize) +
|
963 |
+
", thread_m_blocks = " + str(thread_m_blocks) +
|
964 |
+
", thread_n_blocks = " + str(thread_n_blocks) +
|
965 |
+
", thread_k_blocks = " + str(thread_k_blocks));
|
966 |
+
}
|
967 |
+
|
968 |
+
A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
|
969 |
+
C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
|
970 |
+
}
|
971 |
+
}
|
972 |
+
|
973 |
+
} // namespace marlin_dense
|
974 |
+
|
975 |
+
torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
976 |
+
torch::Tensor& b_scales, torch::Tensor& workspace,
|
977 |
+
int64_t size_m, int64_t size_n, int64_t size_k) {
|
978 |
+
// Verify M
|
979 |
+
TORCH_CHECK(size_m == a.size(0),
|
980 |
+
"Shape mismatch: a.size(0) = " + str(a.size(0)) +
|
981 |
+
", size_m = " + str(size_m));
|
982 |
+
|
983 |
+
// Verify K
|
984 |
+
TORCH_CHECK(size_k == a.size(1),
|
985 |
+
"Shape mismatch: a.size(1) = " + str(a.size(1)) +
|
986 |
+
", size_k = " + str(size_k));
|
987 |
+
TORCH_CHECK(size_k % marlin_dense::tile_size == 0,
|
988 |
+
"size_k = " + str(size_k) + " is not divisible by tile_size = " +
|
989 |
+
str(marlin_dense::tile_size));
|
990 |
+
TORCH_CHECK((size_k / marlin_dense::tile_size) == b_q_weight.size(0),
|
991 |
+
"Shape mismatch: b_q_weight.size(0) = " +
|
992 |
+
str(b_q_weight.size(0)) + ", size_k = " + str(size_k) +
|
993 |
+
", tile_size = " + str(marlin_dense::tile_size));
|
994 |
+
|
995 |
+
// Verify N
|
996 |
+
TORCH_CHECK(b_scales.size(1) == size_n,
|
997 |
+
"b_scales.size(1) = " + str(b_scales.size(1)) +
|
998 |
+
", size_n = " + str(size_n));
|
999 |
+
TORCH_CHECK(
|
1000 |
+
b_q_weight.size(1) % marlin_dense::tile_size == 0,
|
1001 |
+
"b_q_weight.size(1) = " + str(b_q_weight.size(1)) +
|
1002 |
+
" is not divisible by tile_size = " + str(marlin_dense::tile_size));
|
1003 |
+
|
1004 |
+
int actual_size_n = (b_q_weight.size(1) / marlin_dense::tile_size) *
|
1005 |
+
marlin_dense::pack_factor_4bit;
|
1006 |
+
TORCH_CHECK(
|
1007 |
+
size_n == actual_size_n,
|
1008 |
+
"size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n));
|
1009 |
+
|
1010 |
+
// Verify A device and strides
|
1011 |
+
TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
|
1012 |
+
TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
|
1013 |
+
|
1014 |
+
// Verify B device and strides
|
1015 |
+
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
|
1016 |
+
TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
|
1017 |
+
|
1018 |
+
// Verify scales device and strides
|
1019 |
+
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
|
1020 |
+
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
|
1021 |
+
|
1022 |
+
// Alloc C matrix
|
1023 |
+
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
|
1024 |
+
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
|
1025 |
+
torch::Tensor c = torch::empty({size_m, size_n}, options);
|
1026 |
+
|
1027 |
+
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
|
1028 |
+
// auto -1)
|
1029 |
+
int thread_k = -1;
|
1030 |
+
// thread_n: `n` size of a thread_tile in `weights` (can usually be left as
|
1031 |
+
// auto -1)
|
1032 |
+
int thread_n = -1;
|
1033 |
+
// sms: number of SMs to use for the kernel (can usually be left as auto -1)
|
1034 |
+
int sms = -1;
|
1035 |
+
|
1036 |
+
// Detect groupsize
|
1037 |
+
if (b_scales.size(0) != 1) {
|
1038 |
+
TORCH_CHECK(size_k % b_scales.size(0) == 0,
|
1039 |
+
"size_k = " + str(size_k) +
|
1040 |
+
", is not divisible by b_scales.size(0) = " +
|
1041 |
+
str(b_scales.size(0)));
|
1042 |
+
}
|
1043 |
+
int groupsize = b_scales.size(0) == 1 ? -1 : size_k / b_scales.size(0);
|
1044 |
+
|
1045 |
+
// Verify groupsize
|
1046 |
+
TORCH_CHECK(groupsize == -1 || groupsize == 128,
|
1047 |
+
"Unexpected groupsize = " + str(groupsize));
|
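// Illustrative worked example (hypothetical shapes): b_scales of shape (1, n)
// yields groupsize = -1 (pure per-channel scales), while size_k = 4096 with
// b_scales of shape (32, n) yields groupsize = 4096 / 32 = 128; any other
// value fails the check above.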
1048 |
+
|
1049 |
+
// Verify workspace size
|
1050 |
+
TORCH_CHECK(size_n % marlin_dense::min_thread_n == 0,
|
1051 |
+
"size_n = " + str(size_n) +
|
1052 |
+
", is not divisible by min_thread_n = " +
|
1053 |
+
str(marlin_dense::min_thread_n));
|
1054 |
+
int min_workspace_size =
|
1055 |
+
(size_n / marlin_dense::min_thread_n) * marlin_dense::max_par;
|
1056 |
+
TORCH_CHECK(workspace.numel() >= min_workspace_size,
|
1057 |
+
"workspace.numel = " + str(workspace.numel()) +
|
1058 |
+
" is below min_workspace_size = " + str(min_workspace_size));
|
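// Illustrative worked example (hypothetical size): for size_n = 4096,
// min_workspace_size = (4096 / 64) * 16 = 1024 int entries, used as lock
// storage for the cross-threadblock reductions.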
1059 |
+
|
1060 |
+
int dev = a.get_device();
|
1061 |
+
marlin_dense::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(),
|
1062 |
+
b_scales.data_ptr(), size_m, size_n, size_k,
|
1063 |
+
workspace.data_ptr(), groupsize, dev,
|
1064 |
+
at::cuda::getCurrentCUDAStream(dev), thread_k,
|
1065 |
+
thread_n, sms, marlin_dense::max_par);
|
1066 |
+
|
1067 |
+
return c;
|
1068 |
+
}
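For orientation, here is a minimal, hypothetical host-side call site for the entry point above; the shapes, the `run_marlin_example` name and the zero-packed weight are illustrative assumptions chosen only to satisfy the shape checks in `marlin_gemm`.

// Hypothetical call site, illustrative only: shapes chosen to pass the checks
// in marlin_gemm above, with per-channel scales (groupsize = -1).
torch::Tensor run_marlin_example() {
  int64_t m = 16, n = 4096, k = 4096;
  auto fp16 = torch::TensorOptions().dtype(torch::kHalf).device(torch::kCUDA);
  auto i32 = torch::TensorOptions().dtype(torch::kInt).device(torch::kCUDA);
  torch::Tensor a = torch::randn({m, k}, fp16);
  // Marlin-packed 4-bit weight: (k / tile_size) x (n * tile_size / pack_factor_4bit).
  torch::Tensor b_q = torch::zeros(
      {k / marlin_dense::tile_size,
       n * marlin_dense::tile_size / marlin_dense::pack_factor_4bit}, i32);
  torch::Tensor scales = torch::ones({1, n}, fp16);  // per-channel scales
  // Lock storage: (n / min_thread_n) * max_par zero-initialized ints.
  torch::Tensor workspace = torch::zeros(
      {(n / marlin_dense::min_thread_n) * marlin_dense::max_par}, i32);
  return marlin_gemm(a, b_q, scales, workspace, m, n, k);
}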
|
marlin/qqq/marlin_qqq_gemm_kernel.cu
ADDED
@@ -0,0 +1,1243 @@
1 |
+
/*
|
2 |
+
* Adapted from
|
3 |
+
* https://github.com/IST-DASLab/marlin/blob/master/marlin/marlin_cuda_kernel.cu
|
4 |
+
* https://github.com/IST-DASLab/marlin/blob/master/marlin/marlin_cuda.cpp
|
5 |
+
* Modified by HandH1998
|
6 |
+
* Copyright (C) 2024 HandH1998
|
7 |
+
* Copyright (C) Marlin.2024 Elias Frantar
|
8 |
+
*
|
9 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
10 |
+
* you may not use this file except in compliance with the License.
|
11 |
+
* You may obtain a copy of the License at
|
12 |
+
*
|
13 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
14 |
+
*
|
15 |
+
* Unless required by applicable law or agreed to in writing, software
|
16 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
17 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
18 |
+
* See the License for the specific language governing permissions and
|
19 |
+
* limitations under the License.
|
20 |
+
*/
|
21 |
+
|
22 |
+
#include <torch/all.h>
|
23 |
+
|
24 |
+
#include <ATen/cuda/CUDAContext.h>
|
25 |
+
#include <c10/cuda/CUDAGuard.h>
|
26 |
+
#include <cuda.h>
|
27 |
+
#include <cuda_fp16.h>
|
28 |
+
#include <cuda_runtime.h>
|
29 |
+
|
30 |
+
#include <iostream>
|
31 |
+
|
32 |
+
#include "../dense/common/base.h"
|
33 |
+
|
34 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
35 |
+
#include "../dense/common/mem.h"
|
36 |
+
#endif
|
37 |
+
|
38 |
+
template <typename T>
|
39 |
+
inline std::string str(T x) {
|
40 |
+
return std::to_string(x);
|
41 |
+
}
|
42 |
+
|
43 |
+
namespace {
|
44 |
+
|
45 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
46 |
+
|
47 |
+
using I4 = Vec<int, 4>;
|
48 |
+
// Matrix fragments for tensor core instructions; their precise layout is
|
49 |
+
// documented here:
|
50 |
+
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-integer-type
|
51 |
+
using FragA = Vec<uint32_t, 2>;
|
52 |
+
using FragB = Vec<uint32_t, 1>;
|
53 |
+
using FragC = Vec<int, 4>;
|
54 |
+
using FragS_GROUP = Vec<half2, 1>; // weight per-group quantization scales
|
55 |
+
using FragS_CHANNEL =
|
56 |
+
Vec<float, 2>; // weight per-channel quantization scales or activation
|
57 |
+
// per-token quantization scales
|
58 |
+
|
59 |
+
// NOTE(HandH1998): cp.async.cg only supports BYTES = 16; however,
|
60 |
+
// cp.async.ca can support BYTES = 4, 8, 16;
|
61 |
+
// as s_tok's shape is equal to prob_m, we need to set s_tok to float type,
|
62 |
+
// and cp_size = 1 float, i.e., 4 BYTES
|
63 |
+
// Asynchronous global->shared copy for activation quantization scales s_tok
|
64 |
+
__device__ inline void cp_async1(void* smem_ptr, const void* glob_ptr) {
|
65 |
+
const int BYTES = 4;
|
66 |
+
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
67 |
+
asm volatile(
|
68 |
+
"{\n"
|
69 |
+
" cp.async.ca.shared.global [%0], [%1], %2;\n"
|
70 |
+
"}\n" ::"r"(smem),
|
71 |
+
"l"(glob_ptr), "n"(BYTES));
|
72 |
+
}
|
73 |
+
|
74 |
+
// m16n8k16 tensor core mma instruction with int8 inputs and int32
|
75 |
+
// output/accumulation.
|
76 |
+
__device__ inline void mma(const FragA& a_frag, const FragB& frag_b,
|
77 |
+
FragC& frag_c) {
|
78 |
+
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
|
79 |
+
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
80 |
+
int* c = reinterpret_cast<int*>(&frag_c);
|
81 |
+
asm volatile(
|
82 |
+
"mma.sync.aligned.m16n8k16.row.col.satfinite.s32.s8.s8.s32 "
|
83 |
+
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
84 |
+
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
85 |
+
: "r"(a[0]), "r"(a[1]), "r"(b[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]),
|
86 |
+
"r"(c[3]));
|
87 |
+
}
|
88 |
+
|
89 |
+
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
|
90 |
+
// memory, directly in int8 tensor core layout.
|
91 |
+
__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
|
92 |
+
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
|
93 |
+
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
94 |
+
asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n"
|
95 |
+
: "=r"(a[0]), "=r"(a[1])
|
96 |
+
: "r"(smem));
|
97 |
+
}
|
98 |
+
|
99 |
+
inline __device__ half2 float2_to_half2(float2 f) {
|
100 |
+
uint32_t res;
|
101 |
+
// NOTE(HandH1998): h0,h1 should be uint16_t, not half
|
102 |
+
uint16_t h0, h1;
|
103 |
+
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(h0) : "f"(f.x));
|
104 |
+
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(h1) : "f"(f.y));
|
105 |
+
asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(res) : "h"(h0), "h"(h1));
|
106 |
+
return reinterpret_cast<half2&>(res);
|
107 |
+
}
|
108 |
+
|
109 |
+
inline __device__ float int32_to_float(int h) {
|
110 |
+
float res;
|
111 |
+
asm volatile("cvt.rn.f32.s32 %0, %1;\n" : "=f"(res) : "r"(h));
|
112 |
+
return res;
|
113 |
+
}
|
114 |
+
|
115 |
+
// Lookup-table based 3-input logical operation; explicitly used for
|
116 |
+
// dequantization as the compiler does not seem to automatically recognize it in
|
117 |
+
// all cases.
|
118 |
+
template <int lut>
|
119 |
+
__device__ inline int lop3(int a, int b, int c) {
|
120 |
+
int res;
|
121 |
+
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
122 |
+
: "=r"(res)
|
123 |
+
: "r"(a), "r"(b), "r"(c), "n"(lut));
|
124 |
+
return res;
|
125 |
+
}
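// Illustrative note, not in the original source: the immediate used with this
// helper throughout the dequant code is (0xf0 & 0xcc) | 0xaa == 0xea, the
// LOP3 truth table of the bitwise function (a & b) | c, so
//   lop3<0xea>(q, MASK, EX)  computes  (q & MASK) | EX
// in a single instruction.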
|
126 |
+
|
127 |
+
// Efficiently dequantize an int32 value into a full B-fragment of 4 int8 values
|
128 |
+
// for weight per channel dequant.
|
129 |
+
__device__ inline FragB dequant_per_channel(int q) {
|
130 |
+
static constexpr int MASK = 0xf0f0f0f0;
|
131 |
+
FragB frag_b;
|
132 |
+
frag_b[0] = (q & MASK);
|
133 |
+
return frag_b;
|
134 |
+
}
|
135 |
+
|
136 |
+
// Efficiently dequantize an int32 value into a full B-fragment of 4 int8 values
|
137 |
+
// for weight per group dequant.
|
138 |
+
__device__ inline FragB dequant_per_group(int q, FragS_GROUP& frag_s, int i) {
|
139 |
+
static constexpr uint32_t LO = 0x000f000f;
|
140 |
+
static constexpr uint32_t HI = 0x00f000f0;
|
141 |
+
static constexpr uint32_t EX = 0x64006400;
|
142 |
+
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
143 |
+
uint32_t t0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
144 |
+
uint32_t t1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
145 |
+
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
|
146 |
+
// directly into `SUB` and `ADD`.
|
147 |
+
static constexpr uint32_t SUB = 0x64086408;
|
148 |
+
static constexpr uint32_t MUL = 0x2c002c00;
|
149 |
+
static constexpr uint32_t ADD = 0xd480d480;
|
150 |
+
*reinterpret_cast<half2*>(&t0) = __hsub2(
|
151 |
+
*reinterpret_cast<half2*>(&t0), *reinterpret_cast<const half2*>(&SUB));
|
152 |
+
*reinterpret_cast<half2*>(&t1) = __hfma2(
|
153 |
+
*reinterpret_cast<half2*>(&t1), *reinterpret_cast<const half2*>(&MUL),
|
154 |
+
*reinterpret_cast<const half2*>(&ADD));
|
155 |
+
|
156 |
+
uint16_t s = reinterpret_cast<uint16_t*>(&frag_s)[i];
|
157 |
+
uint32_t double_s;
|
158 |
+
// pack 2xfp16 to half2
|
159 |
+
asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(double_s) : "h"(s), "h"(s));
|
160 |
+
// dequant and convert 4 half to 4 uint8 (be placed at the low 8 bits of 4
|
161 |
+
// half, respectively)
|
162 |
+
static constexpr uint32_t MAGIC_NUM = 0x64806480;
|
163 |
+
*reinterpret_cast<half2*>(&t0) = __hfma2(
|
164 |
+
*reinterpret_cast<half2*>(&t0), *reinterpret_cast<half2*>(&double_s),
|
165 |
+
*reinterpret_cast<const half2*>(&MAGIC_NUM));
|
166 |
+
*reinterpret_cast<half2*>(&t1) = __hfma2(
|
167 |
+
*reinterpret_cast<half2*>(&t1), *reinterpret_cast<half2*>(&double_s),
|
168 |
+
*reinterpret_cast<const half2*>(&MAGIC_NUM));
|
169 |
+
// take out the 4 uint8 from 4 half, then convert them to 4 int8 and pack 4
|
170 |
+
// int8 into 1 uint32
|
171 |
+
FragB frag_b;
|
172 |
+
uint32_t uint8s;
|
173 |
+
static constexpr uint32_t MASK_0246 = 0x6420;
|
174 |
+
static constexpr uint32_t UINT8s_TO_INT8s_MASK = 0x80808080;
|
175 |
+
asm volatile("prmt.b32 %0,%1,%2,%3;\n"
|
176 |
+
: "=r"(uint8s)
|
177 |
+
: "r"(t0), "r"(t1), "n"(MASK_0246));
|
178 |
+
frag_b[0] = (uint8s ^ UINT8s_TO_INT8s_MASK);
|
179 |
+
return frag_b;
|
180 |
+
}
|
181 |
+
|
182 |
+
template <const int threads, // number of threads in a threadblock
|
183 |
+
const int thread_m_blocks, // number of 16x16 blocks in the m
|
184 |
+
// dimension (batchsize) of the
|
185 |
+
// threadblock
|
186 |
+
const int thread_n_blocks, // same for n dimension (output)
|
187 |
+
const int thread_k_blocks, // same for k dimension (reduction)
|
188 |
+
const int stages, // number of stages for the async global->shared
|
189 |
+
// fetch pipeline
|
190 |
+
const int group_blocks = -1 // number of consecutive 16x16 blocks
|
191 |
+
// with a separate quantization scale
|
192 |
+
>
|
193 |
+
__global__ void Marlin(
|
194 |
+
const int4* __restrict__ A, // int8 input matrix of shape mxk
|
195 |
+
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
|
196 |
+
int4* __restrict__ C, // int32 global_reduce buffer of shape
|
197 |
+
// (max_par*16*4)xn, as int8 tensor core's output is
|
198 |
+
// int32 dtype
|
199 |
+
int4* __restrict__ D, // fp16 output buffer of shape mxn
|
200 |
+
const float* __restrict__ s_tok, // fp32 activation per-token quantization
|
201 |
+
// scales of shape mx1
|
202 |
+
const int4* __restrict__ s_ch, // fp32 weight per-channel quantization
|
203 |
+
// scales of shape 1xn
|
204 |
+
const int4* __restrict__ s_group, // fp16 weight per-group quantization
|
205 |
+
// scales of shape (k/groupsize)xn, when
|
206 |
+
// group_blocks=-1, it should be nullptr
|
207 |
+
int prob_m, // batch dimension m
|
208 |
+
int prob_n, // output dimension n
|
209 |
+
int prob_k, // reduction dimension k
|
210 |
+
int* locks // extra global storage for barrier synchronization
|
211 |
+
) {
|
212 |
+
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
|
213 |
+
// same size, which might involve multiple column "slices" (of width 16 *
|
214 |
+
// `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
|
215 |
+
// example:
|
216 |
+
// 0 1 3
|
217 |
+
// 0 2 3
|
218 |
+
// 1 2 4
|
219 |
+
// While this kind of partitioning makes things somewhat more complicated, it
|
220 |
+
// ensures good utilization of all SMs for many kinds of shape and GPU
|
221 |
+
// configurations, while requiring as few slow global cross-threadblock
|
222 |
+
// reductions as possible.
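// Illustrative worked example (not in the original source) for the 3x3 / 5 SM
// picture above: with k_tiles = 3, n_tiles = 3 and 5 threadblocks,
// iters = ceildiv(3 * 3, 5) = 2 and block b starts at
// slice_row = (2 * b) % 3, slice_col = (2 * b) / 3, so
//   block 0 -> column 0, rows 0-1          block 3 -> column 2, rows 0-1
//   block 1 -> column 0 row 2, column 1 row 0
//   block 2 -> column 1, rows 1-2          block 4 -> column 2, row 2
// which reproduces the 0 1 3 / 0 2 3 / 1 2 4 assignment.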
|
223 |
+
|
224 |
+
// For larger GEMMs we run multiple batchsize 64 versions in parallel for a
|
225 |
+
// better partitioning with less reductions
|
226 |
+
int parallel = 1;
|
227 |
+
if (prob_m > 16 * thread_m_blocks) {
|
228 |
+
parallel = prob_m / (16 * thread_m_blocks);
|
229 |
+
prob_m = 16 * thread_m_blocks;
|
230 |
+
}
|
231 |
+
|
232 |
+
int k_tiles = prob_k / 16 / thread_k_blocks;
|
233 |
+
int n_tiles = prob_n / 16 / thread_n_blocks;
|
234 |
+
int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x);
|
235 |
+
// Ensure that the number of tiles in each stripe is a multiple of the
|
236 |
+
// groupsize; this avoids an annoying special case where a stripe starts in
|
237 |
+
// the middle of group.
|
238 |
+
if constexpr (group_blocks != -1)
|
239 |
+
iters = (group_blocks / thread_k_blocks) *
|
240 |
+
ceildiv(iters, (group_blocks / thread_k_blocks));
|
241 |
+
|
242 |
+
int slice_row = (iters * blockIdx.x) % k_tiles;
|
243 |
+
int slice_col_par = (iters * blockIdx.x) / k_tiles;
|
244 |
+
int slice_col = slice_col_par;
|
245 |
+
int slice_iters; // number of threadblock tiles in the current slice
|
246 |
+
int slice_count =
|
247 |
+
0; // total number of active threadblocks in the current slice
|
248 |
+
int slice_idx; // index of threadblock in current slice; numbered bottom to
|
249 |
+
// top
|
250 |
+
|
251 |
+
// We can easily implement parallel problem execution by just remapping
|
252 |
+
// indices and advancing global pointers
|
253 |
+
if (slice_col_par >= n_tiles) {
|
254 |
+
A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 16;
|
255 |
+
C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 4;
|
256 |
+
D += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
|
257 |
+
s_tok += (slice_col_par / n_tiles) * 16 * thread_m_blocks;
|
258 |
+
locks += (slice_col_par / n_tiles) * n_tiles;
|
259 |
+
slice_col = slice_col_par % n_tiles;
|
260 |
+
}
|
261 |
+
|
262 |
+
// Compute all information about the current slice which is required for
|
263 |
+
// synchronization.
|
264 |
+
auto init_slice = [&]() {
|
265 |
+
slice_iters =
|
266 |
+
iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
|
267 |
+
if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
|
268 |
+
if (slice_iters == 0) return;
|
269 |
+
if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
|
270 |
+
slice_count = 1;
|
271 |
+
slice_idx = 0;
|
272 |
+
int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
|
273 |
+
if (col_first <= k_tiles * (slice_col_par + 1)) {
|
274 |
+
int col_off = col_first - k_tiles * slice_col_par;
|
275 |
+
slice_count = ceildiv(k_tiles - col_off, iters);
|
276 |
+
if (col_off > 0) slice_count++;
|
277 |
+
int delta_first = iters * blockIdx.x - col_first;
|
278 |
+
if (delta_first < 0 || (col_off == 0 && delta_first == 0))
|
279 |
+
slice_idx = slice_count - 1;
|
280 |
+
else {
|
281 |
+
slice_idx = slice_count - 1 - delta_first / iters;
|
282 |
+
if (col_off > 0) slice_idx--;
|
283 |
+
}
|
284 |
+
}
|
285 |
+
if (slice_col == n_tiles) {
|
286 |
+
A += 16 * thread_m_blocks * prob_k / 16;
|
287 |
+
C += 16 * thread_m_blocks * prob_n / 4;
|
288 |
+
D += 16 * thread_m_blocks * prob_n / 8;
|
289 |
+
s_tok += 16 * thread_m_blocks;
|
290 |
+
locks += n_tiles;
|
291 |
+
slice_col = 0;
|
292 |
+
}
|
293 |
+
};
|
294 |
+
init_slice();
|
295 |
+
|
296 |
+
int a_gl_stride = prob_k / 16; // stride of the A matrix in global memory
|
297 |
+
// We typically use `constexpr` to indicate that this value is a compile-time
|
298 |
+
// constant
|
299 |
+
constexpr int a_sh_stride =
|
300 |
+
16 * thread_k_blocks / 16; // stride of an A matrix tile in shared memory
|
301 |
+
constexpr int a_gl_rd_delta_o =
|
302 |
+
16 * thread_k_blocks /
|
303 |
+
16; // delta between subsequent A tiles in global memory
|
304 |
+
int a_gl_rd_delta_i =
|
305 |
+
a_gl_stride *
|
306 |
+
(threads / a_gl_rd_delta_o); // between subsequent accesses within a tile
|
307 |
+
constexpr int a_sh_wr_delta =
|
308 |
+
a_sh_stride *
|
309 |
+
(threads / a_gl_rd_delta_o); // between shared memory writes
|
310 |
+
constexpr int a_sh_rd_delta_o =
|
311 |
+
1 * ((threads / 32) /
|
312 |
+
(thread_n_blocks / 4)); // between shared memory tile reads
|
313 |
+
constexpr int a_sh_rd_delta_i =
|
314 |
+
a_sh_stride * 16; // within a shared memory tile
|
315 |
+
constexpr int a_sh_stage =
|
316 |
+
a_sh_stride * (16 * thread_m_blocks); // overall size of a tile
|
317 |
+
constexpr int a_sh_wr_iters =
|
318 |
+
ceildiv(a_sh_stage,
|
319 |
+
a_sh_wr_delta); // number of shared write iterations for a tile
|
320 |
+
|
321 |
+
int b_gl_stride = 16 * prob_n / 32;
|
322 |
+
constexpr int b_sh_stride = 32 * thread_n_blocks / 4;
|
323 |
+
int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
|
324 |
+
int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride);
|
325 |
+
constexpr int b_sh_wr_delta = threads;
|
326 |
+
constexpr int b_sh_rd_delta = threads;
|
327 |
+
constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
|
328 |
+
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
|
329 |
+
|
330 |
+
constexpr int s_tok_sh_stride = 16 * thread_m_blocks;
|
331 |
+
|
332 |
+
constexpr int s_ch_sh_stride = 16 * thread_n_blocks / 4;
|
333 |
+
|
334 |
+
int s_group_gl_stride = prob_n / 8;
|
335 |
+
constexpr int s_group_sh_stride = 16 * thread_n_blocks / 8;
|
336 |
+
constexpr int s_group_sh_stage = s_group_sh_stride;
|
337 |
+
int s_group_gl_rd_delta = s_group_gl_stride;
|
338 |
+
|
339 |
+
// Global A read index of current thread.
|
340 |
+
int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
|
341 |
+
(threadIdx.x % a_gl_rd_delta_o);
|
342 |
+
a_gl_rd += a_gl_rd_delta_o * slice_row;
|
343 |
+
// Shared write index of current thread.
|
344 |
+
int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
|
345 |
+
(threadIdx.x % a_gl_rd_delta_o);
|
346 |
+
// Shared read index.
|
347 |
+
// NOTE(HandH1998): int8 input a only needs 16 threads to load a 16x16 matrix
|
348 |
+
int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % 16);
|
349 |
+
a_sh_rd += 1 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
|
350 |
+
|
351 |
+
int b_gl_rd =
|
352 |
+
b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
|
353 |
+
b_gl_rd += b_sh_stride * slice_col;
|
354 |
+
b_gl_rd += b_gl_rd_delta_o * slice_row;
|
355 |
+
int b_sh_wr = threadIdx.x;
|
356 |
+
int b_sh_rd = threadIdx.x;
|
357 |
+
|
358 |
+
int s_tok_gl_rd = threadIdx.x;
|
359 |
+
// NOTE(HandH1998): the activation scales s_tok need to be shuffled to [0, 8, 1, 9, 2, 10,
|
360 |
+
// 3, 11, 4, 12, 5, 13, 6, 14, 7, 15]. For example, the row-0 and row-8 scales serve
|
361 |
+
// threads 0, 1, 2, 3. For more details, refer to the mma operand A layout. As
|
362 |
+
// s_tok's size is not fixed, we cannot shuffle it before inference; we shuffle
|
363 |
+
// it when fetching s_tok from global memory to shared memory, which is why
|
364 |
+
// s_tok_sh_wr is computed like this.
|
365 |
+
int s_tok_sh_wr =
|
366 |
+
(threadIdx.x / 16) * 16 + (threadIdx.x % 8) * 2 + (threadIdx.x % 16) / 8;
|
367 |
+
int s_tok_sh_rd = (threadIdx.x % 32) / 4;
|
368 |
+
bool s_tok_sh_wr_pred = threadIdx.x < prob_m;
|
369 |
+
|
370 |
+
int s_ch_gl_rd = s_ch_sh_stride * slice_col + threadIdx.x;
|
371 |
+
int s_ch_sh_wr = threadIdx.x;
|
372 |
+
int s_ch_sh_rd = 16 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
373 |
+
2 * ((threadIdx.x % 32) % 4);
|
374 |
+
bool s_ch_sh_wr_pred = threadIdx.x < s_ch_sh_stride;
|
375 |
+
|
376 |
+
int s_group_gl_rd, s_group_sh_wr, s_group_sh_rd;
|
377 |
+
bool s_group_sh_wr_pred;
|
378 |
+
if constexpr (group_blocks != -1) {
|
379 |
+
s_group_gl_rd =
|
380 |
+
s_group_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
|
381 |
+
s_group_sh_stride * slice_col + threadIdx.x;
|
382 |
+
s_group_sh_wr = threadIdx.x;
|
383 |
+
// NOTE(HandH1998): s_group_sh_rd is related to mma output C
|
384 |
+
s_group_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
385 |
+
(threadIdx.x % 32) / 4;
|
386 |
+
s_group_sh_wr_pred = threadIdx.x < s_group_sh_stride;
|
387 |
+
}
|
388 |
+
|
389 |
+
// Precompute which thread should not read memory in which iterations; this is
|
390 |
+
// needed if there are more threads than required for a certain tilesize or
|
391 |
+
// when the batchsize is not a multiple of 16.
|
392 |
+
bool a_sh_wr_pred[a_sh_wr_iters];
|
393 |
+
#pragma unroll
|
394 |
+
for (int i = 0; i < a_sh_wr_iters; i++)
|
395 |
+
a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
|
396 |
+
|
397 |
+
// To ensure that writing and reading A tiles to/from shared memory, the
|
398 |
+
// latter in fragment format, is fully bank conflict free, we need to use a
|
399 |
+
// rather fancy XOR-based layout. The key here is that neither reads nor
|
400 |
+
// writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
|
401 |
+
// same shared memory banks. Further, it seems (based on NSight-Compute) that
|
402 |
+
// each warp must also write a consecutive memory segment?
|
403 |
+
auto transform_a = [&](int i) {
|
404 |
+
int row = i / a_gl_rd_delta_o;
|
405 |
+
return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
|
406 |
+
};
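// Illustrative note, not in the original source: due to operator precedence
// the expression above is (a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o)) ^ row,
// i.e. the linear index is XOR-ed with its row, which (for the power-of-two
// strides used here) permutes the column within each row so that the int4
// accesses of 8 consecutive threads hit distinct banks across rows.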
|
407 |
+
// Since the computation of this remapping is non-trivial and, due to our main
|
408 |
+
// loop unrolls, all shared memory accesses are static, we simply precompute
|
409 |
+
// both transformed reads and writes.
|
410 |
+
int a_sh_wr_trans[a_sh_wr_iters];
|
411 |
+
#pragma unroll
|
412 |
+
for (int i = 0; i < a_sh_wr_iters; i++)
|
413 |
+
a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
|
414 |
+
int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
|
415 |
+
#pragma unroll
|
416 |
+
for (int i = 0; i < b_sh_wr_iters; i++) {
|
417 |
+
#pragma unroll
|
418 |
+
for (int j = 0; j < thread_m_blocks; j++)
|
419 |
+
a_sh_rd_trans[i][j] =
|
420 |
+
transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
|
421 |
+
}
|
422 |
+
|
423 |
+
// Since B-accesses have non-constant stride they have to be computed at
|
424 |
+
// runtime; we break dependencies between subsequent accesses with a tile by
|
425 |
+
// maintaining multiple pointers (we have enough registers), a tiny
|
426 |
+
// optimization.
|
427 |
+
const int4* B_ptr[b_sh_wr_iters];
|
428 |
+
#pragma unroll
|
429 |
+
for (int i = 0; i < b_sh_wr_iters; i++)
|
430 |
+
B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
|
431 |
+
|
432 |
+
extern __shared__ int4 sh[];
|
433 |
+
// Shared memory storage for global fetch pipelines.
|
434 |
+
// NOTE(HandH1998): stages must be >= 4; otherwise, sh_s_tok = sh + max(stages *
|
435 |
+
// a_sh_stage + stages * b_sh_stage, 4 * stages * a_sh_stage)
|
436 |
+
int4* sh_a = sh;
|
437 |
+
int4* sh_b = sh_a + (stages * a_sh_stage);
|
438 |
+
int4* sh_s_tok = sh_b + (stages * b_sh_stage);
|
439 |
+
int4* sh_s_ch = sh_s_tok + s_tok_sh_stride;
|
440 |
+
int4* sh_s_group = sh_s_ch + s_ch_sh_stride;
|
441 |
+
|
442 |
+
// Register storage for double buffer of shared memory reads.
|
443 |
+
FragA frag_a[2][thread_m_blocks];
|
444 |
+
I4 frag_b_quant[2];
|
445 |
+
FragC frag_c[thread_m_blocks][4][2];
|
446 |
+
FragS_GROUP frag_s_group[2][4];
|
447 |
+
FragS_CHANNEL frag_s_tok[thread_m_blocks];
|
448 |
+
FragS_CHANNEL frag_s_ch[2][4];
|
449 |
+
|
450 |
+
// Zero accumulators.
|
451 |
+
auto zero_accums = [&]() {
|
452 |
+
#pragma unroll
|
453 |
+
for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
|
454 |
+
reinterpret_cast<int*>(frag_c)[i] = 0;
|
455 |
+
};
|
456 |
+
|
457 |
+
// Asynchronously fetch the next A, B and s tile from global to the next
|
458 |
+
// shared memory pipeline location.
|
459 |
+
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
|
460 |
+
if (pred) {
|
461 |
+
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
|
462 |
+
#pragma unroll
|
463 |
+
for (int i = 0; i < a_sh_wr_iters; i++) {
|
464 |
+
cp_async4_pred(
|
465 |
+
&sh_a_stage[a_sh_wr_trans[i]],
|
466 |
+
&A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
|
467 |
+
a_sh_wr_pred[i]);
|
468 |
+
}
|
469 |
+
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
|
470 |
+
#pragma unroll
|
471 |
+
for (int i = 0; i < b_sh_wr_iters; i++) {
|
472 |
+
cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
|
473 |
+
B_ptr[i] += b_gl_rd_delta_o;
|
474 |
+
}
|
475 |
+
// Only fetch scales if this tile starts a new group
|
476 |
+
if constexpr (group_blocks != -1) {
|
477 |
+
if (pipe % (group_blocks / thread_k_blocks) == 0) {
|
478 |
+
int4* sh_s_group_stage = sh_s_group + s_group_sh_stage * pipe;
|
479 |
+
if (s_group_sh_wr_pred)
|
480 |
+
cp_async4(&sh_s_group_stage[s_group_sh_wr],
|
481 |
+
&s_group[s_group_gl_rd]);
|
482 |
+
s_group_gl_rd += s_group_gl_rd_delta;
|
483 |
+
}
|
484 |
+
}
|
485 |
+
}
|
486 |
+
// Insert a fence even when we are winding down the pipeline to ensure that
|
487 |
+
// waiting is also correct at this point.
|
488 |
+
cp_async_fence();
|
489 |
+
};
|
490 |
+
|
491 |
+
// Wait until the next thread tile has been loaded to shared memory.
|
492 |
+
auto wait_for_stage = [&]() {
|
493 |
+
// We only have `stages - 2` active fetches since we are double buffering
|
494 |
+
// and can only issue the next fetch when it is guaranteed that the previous
|
495 |
+
// shared memory load is fully complete (as it may otherwise be
|
496 |
+
// overwritten).
|
497 |
+
cp_async_wait<stages - 2>();
|
498 |
+
__syncthreads();
|
499 |
+
};
|
500 |
+
|
501 |
+
// Load the next sub-tile from the current location in the shared memory pipe
|
502 |
+
// into the current register buffer.
|
503 |
+
auto fetch_to_registers = [&](int k, int pipe) {
|
504 |
+
// It may seem inefficient that we reload the groups for every sub-tile;
|
505 |
+
// however, this does not seem to be a significant bottleneck, while some
|
506 |
+
// theoretically better attempts have lead to bad instruction ordering by
|
507 |
+
// the compiler and correspondingly a noticeable drop in performance.
|
508 |
+
if constexpr (group_blocks != -1) {
|
509 |
+
int4* sh_s_group_stage =
|
510 |
+
sh_s_group +
|
511 |
+
s_group_sh_stage * ((group_blocks / thread_k_blocks) *
|
512 |
+
(pipe / (group_blocks / thread_k_blocks)));
|
513 |
+
reinterpret_cast<int4*>(&frag_s_group[k % 2])[0] =
|
514 |
+
sh_s_group_stage[s_group_sh_rd];
|
515 |
+
}
|
516 |
+
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
|
517 |
+
#pragma unroll
|
518 |
+
for (int i = 0; i < thread_m_blocks; i++)
|
519 |
+
ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
|
520 |
+
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
|
521 |
+
frag_b_quant[k % 2] = *reinterpret_cast<I4*>(
|
522 |
+
&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]);
|
523 |
+
};
|
524 |
+
|
525 |
+
// Execute the actual tensor core matmul of a sub-tile.
|
526 |
+
auto matmul = [&](int k) {
|
527 |
+
// We have the m dimension as the inner loop in order to encourage overlapping
|
528 |
+
// dequantization and matmul operations.
|
529 |
+
#pragma unroll
|
530 |
+
for (int j = 0; j < 4; j++) {
|
531 |
+
int b_quant = frag_b_quant[k % 2][j];
|
532 |
+
// int b_quant_shift = b_quant << 4;
|
533 |
+
FragB frag_b0, frag_b1;
|
534 |
+
// If there are no groups, we can just scale the final output once and can
|
535 |
+
// avoid doing so for each weight.
|
536 |
+
if constexpr (group_blocks != -1) {
|
537 |
+
int b_quant_shift = b_quant >> 8;
|
538 |
+
frag_b0 = dequant_per_group(b_quant, frag_s_group[k % 2][j], 0);
|
539 |
+
frag_b1 = dequant_per_group(b_quant_shift, frag_s_group[k % 2][j], 1);
|
540 |
+
} else {
|
541 |
+
int b_quant_shift = b_quant << 4;
|
542 |
+
frag_b0 = dequant_per_channel(b_quant);
|
543 |
+
frag_b1 = dequant_per_channel(b_quant_shift);
|
544 |
+
}
|
545 |
+
#pragma unroll
|
546 |
+
for (int i = 0; i < thread_m_blocks; i++) {
|
547 |
+
mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
|
548 |
+
mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
|
549 |
+
}
|
550 |
+
}
|
551 |
+
};
|
552 |
+
|
553 |
+
// Since we slice across the k dimension of a tile in order to increase the
|
554 |
+
// number of warps while keeping the n dimension of a tile reasonable, we have
|
555 |
+
// multiple warps that accumulate their partial sums of the same output
|
556 |
+
// location, which we have to reduce over in the end. We do this in shared memory.
|
557 |
+
auto thread_block_reduce = [&]() {
|
558 |
+
constexpr int red_off = threads / b_sh_stride / 2;
|
559 |
+
if (red_off >= 1) {
|
560 |
+
int red_idx = threadIdx.x / b_sh_stride;
|
561 |
+
constexpr int red_sh_stride = b_sh_stride * 4 * 2;
|
562 |
+
constexpr int red_sh_delta = b_sh_stride;
|
563 |
+
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) +
|
564 |
+
(threadIdx.x % b_sh_stride);
|
565 |
+
|
566 |
+
// Parallel logarithmic shared memory reduction. We make sure to avoid any
|
567 |
+
// unnecessary read or write iterations, e.g., for two warps we write only
|
568 |
+
// once by warp 1 and read only once by warp 0.
|
569 |
+
|
570 |
+
#pragma unroll
|
571 |
+
for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
|
572 |
+
#pragma unroll
|
573 |
+
for (int i = red_off; i > 0; i /= 2) {
|
574 |
+
if (i <= red_idx && red_idx < 2 * i) {
|
575 |
+
#pragma unroll
|
576 |
+
for (int j = 0; j < 4 * 2; j++) {
|
577 |
+
int red_sh_wr =
|
578 |
+
red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
|
579 |
+
if (i < red_off) {
|
580 |
+
int* c_rd =
|
581 |
+
reinterpret_cast<int*>(&sh[red_sh_delta * j + red_sh_rd]);
|
582 |
+
int* c_wr = reinterpret_cast<int*>(&sh[red_sh_wr]);
|
583 |
+
#pragma unroll
|
584 |
+
for (int k = 0; k < 4; k++)
|
585 |
+
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
|
586 |
+
c_rd[k] + c_wr[k];
|
587 |
+
}
|
588 |
+
sh[red_sh_wr] =
|
589 |
+
reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
|
590 |
+
}
|
591 |
+
}
|
592 |
+
__syncthreads();
|
593 |
+
}
|
594 |
+
if (red_idx == 0) {
|
595 |
+
#pragma unroll
|
596 |
+
for (int i = 0; i < 4 * 2; i++) {
|
597 |
+
int* c_rd =
|
598 |
+
reinterpret_cast<int*>(&sh[red_sh_delta * i + red_sh_rd]);
|
599 |
+
#pragma unroll
|
600 |
+
for (int j = 0; j < 4; j++)
|
601 |
+
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
|
602 |
+
c_rd[j];
|
603 |
+
}
|
604 |
+
}
|
605 |
+
__syncthreads();
|
606 |
+
}
|
607 |
+
}
|
608 |
+
};
|
609 |
+
|
610 |
+
// Since multiple threadblocks may process parts of the same column slice, we
|
611 |
+
// finally have to globally reduce over the results. As the striped
|
612 |
+
// partitioning minimizes the number of such reductions and our outputs are
|
613 |
+
// usually rather small, we perform this reduction serially in L2 cache.
|
614 |
+
// global_reduce works on INT32 elements, which are the results of INT8 GEMM.
|
615 |
+
// This is why we need another INT32 matrix `C` to reduce instead of the
|
616 |
+
// original half matrix `D`.
|
617 |
+
auto global_reduce = [&](bool first = false, bool last = false) {
|
618 |
+
// We are very careful here to reduce directly in the output buffer to
|
619 |
+
// maximize L2 cache utilization in this step. To do this, we write out
|
620 |
+
// results in FP16 (but still reduce with FP32 compute).
|
621 |
+
constexpr int active_threads = 32 * thread_n_blocks / 4;
|
622 |
+
if (threadIdx.x < active_threads) {
|
623 |
+
int c_gl_stride = prob_n / 4;
|
624 |
+
int c_gl_wr_delta_o = 8 * c_gl_stride;
|
625 |
+
int c_gl_wr_delta_i = 8 * (active_threads / 32);
|
626 |
+
int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +
|
627 |
+
8 * (threadIdx.x / 32) + (threadIdx.x % 4) * 2;
|
628 |
+
c_gl_wr += (4 * thread_n_blocks) * slice_col;
|
629 |
+
constexpr int c_sh_wr_delta = active_threads * 2;
|
630 |
+
int c_sh_wr = 2 * threadIdx.x;
|
631 |
+
|
632 |
+
int row = (threadIdx.x % 32) / 4;
|
633 |
+
|
634 |
+
if (!first) {
|
635 |
+
// Interestingly, doing direct global accesses here really seems to mess up
|
636 |
+
// the compiler and lead to slowdowns, hence we also use async-copies even
|
637 |
+
// though these fetches are not actually asynchronous.
|
638 |
+
#pragma unroll
|
639 |
+
for (int i = 0; i < thread_m_blocks * 4; i++) {
|
640 |
+
cp_async4_pred(
|
641 |
+
&sh[c_sh_wr + c_sh_wr_delta * i],
|
642 |
+
&C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
|
643 |
+
c_gl_wr_delta_i * (i % 2)],
|
644 |
+
i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
|
645 |
+
cp_async4_pred(
|
646 |
+
&sh[c_sh_wr + c_sh_wr_delta * i + 1],
|
647 |
+
&C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
|
648 |
+
c_gl_wr_delta_i * (i % 2) + 1],
|
649 |
+
i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
|
650 |
+
}
|
651 |
+
cp_async_fence();
|
652 |
+
cp_async_wait<0>();
|
653 |
+
}
|
654 |
+
|
655 |
+
#pragma unroll
|
656 |
+
for (int i = 0; i < thread_m_blocks * 4; i++) {
|
657 |
+
if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
|
658 |
+
if (!first) {
|
659 |
+
int4 d_red1 = sh[c_sh_wr + i * c_sh_wr_delta];
|
660 |
+
int4 d_red2 = sh[c_sh_wr + i * c_sh_wr_delta + 1];
|
661 |
+
#pragma unroll
|
662 |
+
for (int j = 0; j < 4; j++) {
|
663 |
+
reinterpret_cast<int*>(
|
664 |
+
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
|
665 |
+
reinterpret_cast<int*>(&d_red1)[j];
|
666 |
+
}
|
667 |
+
#pragma unroll
|
668 |
+
for (int j = 0; j < 4; j++) {
|
669 |
+
reinterpret_cast<int*>(
|
670 |
+
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * (j + 4) + (i % 4)] +=
|
671 |
+
reinterpret_cast<int*>(&d_red2)[j];
|
672 |
+
}
|
673 |
+
}
|
674 |
+
if (!last) {
|
675 |
+
int4 d1, d2;
|
676 |
+
#pragma unroll
|
677 |
+
for (int j = 0; j < 4; j++) {
|
678 |
+
reinterpret_cast<int*>(&d1)[j] = reinterpret_cast<int*>(
|
679 |
+
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)];
|
680 |
+
}
|
681 |
+
#pragma unroll
|
682 |
+
for (int j = 0; j < 4; j++) {
|
683 |
+
reinterpret_cast<int*>(&d2)[j] = reinterpret_cast<int*>(
|
684 |
+
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * (j + 4) + (i % 4)];
|
685 |
+
}
|
686 |
+
C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =
|
687 |
+
d1;
|
688 |
+
C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2) +
|
689 |
+
1] = d2;
|
690 |
+
}
|
691 |
+
}
|
692 |
+
}
|
693 |
+
}
|
694 |
+
};
|
695 |
+
|
696 |
+
// Write out the reduced final result in the correct layout. We only actually
|
697 |
+
// reshuffle matrix fragments in this step; the reduction above is performed
|
698 |
+
// in fragment layout.
|
699 |
+
auto write_result = [&]() {
|
700 |
+
int d_gl_stride = prob_n / 8;
|
701 |
+
constexpr int d_sh_stride = 2 * thread_n_blocks + 1;
|
702 |
+
int d_gl_wr_delta = d_gl_stride * (threads / (2 * thread_n_blocks));
|
703 |
+
constexpr int d_sh_rd_delta =
|
704 |
+
d_sh_stride * (threads / (2 * thread_n_blocks));
|
705 |
+
|
706 |
+
int d_gl_wr = d_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +
|
707 |
+
(threadIdx.x % (2 * thread_n_blocks));
|
708 |
+
d_gl_wr += (2 * thread_n_blocks) * slice_col;
|
709 |
+
int d_sh_wr =
|
710 |
+
(4 * d_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
|
711 |
+
d_sh_wr += 32 * (threadIdx.x / 32);
|
712 |
+
int d_sh_rd = d_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +
|
713 |
+
(threadIdx.x % (2 * thread_n_blocks));
|
714 |
+
|
715 |
+
int d_gl_wr_end = d_gl_stride * prob_m;
|
716 |
+
|
717 |
+
// We first reorder in shared memory to guarantee the most efficient final
|
718 |
+
// global write patterns
|
719 |
+
auto write = [&](int idx, int c0, int c1, float a_s, FragS_CHANNEL& w_s) {
|
720 |
+
float2 deq_res;
|
721 |
+
deq_res.x = int32_to_float(c0) * w_s[0] * a_s;
|
722 |
+
deq_res.y = int32_to_float(c1) * w_s[1] * a_s;
|
723 |
+
((half2*)sh)[idx] = float2_to_half2(deq_res);
|
724 |
+
};
|
725 |
+
|
726 |
+
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
727 |
+
#pragma unroll
|
728 |
+
for (int i = 0; i < thread_m_blocks; i++) {
|
729 |
+
#pragma unroll
|
730 |
+
for (int j = 0; j < 4; j++) {
|
731 |
+
int wr = d_sh_wr + 8 * j;
|
732 |
+
write(wr + (4 * d_sh_stride) * 0 + 0, frag_c[i][j][0][0],
|
733 |
+
frag_c[i][j][0][1], frag_s_tok[i][0],
|
734 |
+
frag_s_ch[j / 2][2 * (j % 2) + 0]);
|
735 |
+
write(wr + (4 * d_sh_stride) * 8 + 0, frag_c[i][j][0][2],
|
736 |
+
frag_c[i][j][0][3], frag_s_tok[i][1],
|
737 |
+
frag_s_ch[j / 2][2 * (j % 2) + 0]);
|
738 |
+
write(wr + (4 * d_sh_stride) * 0 + 4, frag_c[i][j][1][0],
|
739 |
+
frag_c[i][j][1][1], frag_s_tok[i][0],
|
740 |
+
frag_s_ch[j / 2][2 * (j % 2) + 1]);
|
741 |
+
write(wr + (4 * d_sh_stride) * 8 + 4, frag_c[i][j][1][2],
|
742 |
+
frag_c[i][j][1][3], frag_s_tok[i][1],
|
743 |
+
frag_s_ch[j / 2][2 * (j % 2) + 1]);
|
744 |
+
}
|
745 |
+
d_sh_wr += 16 * (4 * d_sh_stride);
|
746 |
+
}
|
747 |
+
}
|
748 |
+
__syncthreads();
|
749 |
+
|
750 |
+
#pragma unroll
|
751 |
+
for (int i = 0;
|
752 |
+
i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
|
753 |
+
i++) {
|
754 |
+
if (d_gl_wr < d_gl_wr_end) {
|
755 |
+
D[d_gl_wr] = sh[d_sh_rd];
|
756 |
+
d_gl_wr += d_gl_wr_delta;
|
757 |
+
d_sh_rd += d_sh_rd_delta;
|
758 |
+
}
|
759 |
+
}
|
760 |
+
};
|
761 |
+
|
762 |
+
// Start global fetch and register load pipelines.
|
763 |
+
auto start_pipes = [&]() {
|
764 |
+
#pragma unroll
|
765 |
+
for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters);
|
766 |
+
zero_accums();
|
767 |
+
wait_for_stage();
|
768 |
+
fetch_to_registers(0, 0);
|
769 |
+
a_gl_rd += a_gl_rd_delta_o * (stages - 1);
|
770 |
+
};
|
771 |
+
start_pipes();
|
772 |
+
|
773 |
+
// Main loop.
|
774 |
+
while (slice_iters) {
|
775 |
+
// We unroll over both the global fetch and the register load pipeline to
|
776 |
+
// ensure all shared memory accesses are static. Note that both pipelines have
|
777 |
+
// even length meaning that the next iteration will always start at index 0.
|
778 |
+
#pragma unroll
|
779 |
+
for (int pipe = 0; pipe < stages;) {
|
780 |
+
#pragma unroll
|
781 |
+
for (int k = 0; k < b_sh_wr_iters; k++) {
|
782 |
+
fetch_to_registers(k + 1, pipe % stages);
|
783 |
+
if (k == b_sh_wr_iters - 2) {
|
784 |
+
fetch_to_shared((pipe + stages - 1) % stages, pipe,
|
785 |
+
slice_iters >= stages);
|
786 |
+
pipe++;
|
787 |
+
wait_for_stage();
|
788 |
+
}
|
789 |
+
matmul(k);
|
790 |
+
}
|
791 |
+
slice_iters--;
|
792 |
+
if (slice_iters == 0) break;
|
793 |
+
}
|
794 |
+
a_gl_rd += a_gl_rd_delta_o * stages;
|
795 |
+
|
796 |
+
// Process results and, if necessary, proceed to the next column slice.
|
797 |
+
// While this pattern may not be the most readable, other ways of writing
|
798 |
+
// the loop seemed to noticeably worse performance after compilation.
|
799 |
+
if (slice_iters == 0) {
|
800 |
+
cp_async_wait<0>();
|
801 |
+
bool last = slice_idx == slice_count - 1;
|
802 |
+
// For per-column scales, we only fetch them here in the final step before
|
803 |
+
// write-out
|
804 |
+
if (last) {
|
805 |
+
if (s_tok_sh_wr_pred) {
|
806 |
+
cp_async1(&sh_s_tok[s_tok_sh_wr], &s_tok[s_tok_gl_rd]);
|
807 |
+
}
|
808 |
+
if (s_ch_sh_wr_pred) {
|
809 |
+
cp_async4(&sh_s_ch[s_ch_sh_wr], &s_ch[s_ch_gl_rd]);
|
810 |
+
}
|
811 |
+
cp_async_fence();
|
812 |
+
}
|
813 |
+
thread_block_reduce();
|
814 |
+
if (last) {
|
815 |
+
cp_async_wait<0>();
|
816 |
+
__syncthreads();
|
817 |
+
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
818 |
+
#pragma unroll
|
819 |
+
for (int i = 0; i < thread_m_blocks; i++) {
|
820 |
+
frag_s_tok[i][0] =
|
821 |
+
*reinterpret_cast<float*>(&sh_s_tok[16 * i + 2 * s_tok_sh_rd]);
|
822 |
+
frag_s_tok[i][1] = *reinterpret_cast<float*>(
|
823 |
+
&sh_s_tok[16 * i + 2 * s_tok_sh_rd + 1]);
|
824 |
+
}
|
825 |
+
reinterpret_cast<int4*>(&frag_s_ch)[0] = sh_s_ch[s_ch_sh_rd + 0];
|
826 |
+
reinterpret_cast<int4*>(&frag_s_ch)[1] = sh_s_ch[s_ch_sh_rd + 1];
|
827 |
+
reinterpret_cast<int4*>(&frag_s_ch)[2] = sh_s_ch[s_ch_sh_rd + 8];
|
828 |
+
reinterpret_cast<int4*>(&frag_s_ch)[3] = sh_s_ch[s_ch_sh_rd + 9];
|
829 |
+
}
|
830 |
+
}
|
831 |
+
if (slice_count > 1) { // only globally reduce if there is more than one
|
832 |
+
// block in a slice
|
833 |
+
barrier_acquire(&locks[slice_col], slice_idx);
|
834 |
+
global_reduce(slice_idx == 0, last);
|
835 |
+
barrier_release(&locks[slice_col], last);
|
836 |
+
}
|
837 |
+
if (last) // only the last block in a slice actually writes the result
|
838 |
+
write_result();
|
839 |
+
slice_row = 0;
|
840 |
+
slice_col_par++;
|
841 |
+
slice_col++;
|
842 |
+
init_slice();
|
843 |
+
if (slice_iters) {
|
844 |
+
a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
|
845 |
+
(threadIdx.x % a_gl_rd_delta_o);
|
846 |
+
#pragma unroll
|
847 |
+
for (int i = 0; i < b_sh_wr_iters; i++)
|
848 |
+
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
|
849 |
+
if (slice_col == 0) {
|
850 |
+
#pragma unroll
|
851 |
+
for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
|
852 |
+
}
|
853 |
+
s_group_gl_rd = s_group_sh_stride * slice_col + threadIdx.x;
|
854 |
+
s_ch_gl_rd = s_ch_sh_stride * slice_col + threadIdx.x;
|
855 |
+
start_pipes();
|
856 |
+
}
|
857 |
+
}
|
858 |
+
}
|
859 |
+
}
|
860 |
+
|
861 |
+
#else
|
862 |
+
|
863 |
+
template <const int threads, // number of threads in a threadblock
|
864 |
+
const int thread_m_blocks, // number of 16x16 blocks in the m
|
865 |
+
// dimension (batchsize) of the
|
866 |
+
// threadblock
|
867 |
+
const int thread_n_blocks, // same for n dimension (output)
|
868 |
+
const int thread_k_blocks, // same for k dimension (reduction)
|
869 |
+
const int stages, // number of stages for the async global->shared
|
870 |
+
// fetch pipeline
|
871 |
+
const int group_blocks = -1 // number of consecutive 16x16 blocks
|
872 |
+
// with a separate quantization scale
|
873 |
+
>
|
874 |
+
__global__ void Marlin(
|
875 |
+
const int4* __restrict__ A, // int8 input matrix of shape mxk
|
876 |
+
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
|
877 |
+
int4* __restrict__ C, // int32 global_reduce buffer of shape
|
878 |
+
// (max_par*16*4)xn, as int8 tensor core's output is
|
879 |
+
// int32 dtype
|
880 |
+
int4* __restrict__ D, // fp16 output buffer of shape mxn
|
881 |
+
const float* __restrict__ s_tok, // fp32 activation per-token quantization
|
882 |
+
// scales of shape mx1
|
883 |
+
const int4* __restrict__ s_ch, // fp32 weight per-channel quantization
|
884 |
+
// scales of shape 1xn
|
885 |
+
const int4* __restrict__ s_group, // fp16 weight per-group quantization
|
886 |
+
// scales of shape (k/groupsize)xn, when
|
887 |
+
// group_blocks=-1, it should be nullptr
|
888 |
+
int prob_m, // batch dimension m
|
889 |
+
int prob_n, // output dimension n
|
890 |
+
int prob_k, // reduction dimension k
|
891 |
+
int* locks // extra global storage for barrier synchronization
|
892 |
+
) {
|
893 |
+
// Marlin is not implemented yet for SM < 8.0
|
894 |
+
assert(false);
|
895 |
+
return;
|
896 |
+
}
|
897 |
+
|
898 |
+
#endif
|
899 |
+
|
900 |
+
// 8 warps are a good choice since every SM has 4 schedulers and having more
|
901 |
+
// than 1 warp per schedule allows some more latency hiding. At the same time,
|
902 |
+
// we want relatively few warps to have many registers per warp and small tiles.
|
903 |
+
const int USER_THREADS =
|
904 |
+
256; // Note: This is only used with user-provided thread_k/n
|
905 |
+
const int STAGES = 4; // 4 pipeline stages fit into shared memory
|
906 |
+
|
907 |
+
static constexpr int min_thread_n = 64;
|
908 |
+
static constexpr int min_thread_k = 64;
|
909 |
+
|
910 |
+
static constexpr int tile_size = 16;
|
911 |
+
static constexpr int max_par = 16;
|
912 |
+
|
913 |
+
static constexpr int pack_factor_4bit =
|
914 |
+
8; // We have 8 4-bit vals inside a 32 bit
|
915 |
+
|
916 |
+
#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
917 |
+
GROUP_BLOCKS, NUM_THREADS) \
|
918 |
+
else if (thread_m_blocks == THREAD_M_BLOCKS && \
|
919 |
+
thread_n_blocks == THREAD_N_BLOCKS && \
|
920 |
+
thread_k_blocks == THREAD_K_BLOCKS && \
|
921 |
+
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
|
922 |
+
cudaFuncSetAttribute(Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
|
923 |
+
THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \
|
924 |
+
cudaFuncAttributeMaxDynamicSharedMemorySize, \
|
925 |
+
max_shared_mem); \
|
926 |
+
Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
927 |
+
STAGES, GROUP_BLOCKS> \
|
928 |
+
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
|
929 |
+
A_ptr, B_ptr, C_ptr, D_ptr, s_tok_ptr, s_ch_ptr, s_group_ptr, \
|
930 |
+
prob_m, prob_n, prob_k, locks); \
|
931 |
+
}
|
932 |
+
|
933 |
+
typedef struct {
|
934 |
+
int thread_k;
|
935 |
+
int thread_n;
|
936 |
+
int num_threads;
|
937 |
+
} thread_config_t;
|
938 |
+
|
939 |
+
thread_config_t small_batch_thread_configs[] = {
|
940 |
+
// Ordered by priority
|
941 |
+
|
942 |
+
// thread_k, thread_n, num_threads
|
943 |
+
{128, 128, 256}, // Default
|
944 |
+
{128, 64, 128}, // Reduce N 2X, same K
|
945 |
+
{64, 256, 256}, // Reduce K 2X, increase N 2X
|
946 |
+
{64, 128, 128}, // Reduce K 2X, same N
|
947 |
+
};
|
948 |
+
|
949 |
+
thread_config_t large_batch_thread_configs[] = {
|
950 |
+
// Ordered by priority
|
951 |
+
|
952 |
+
// thread_k, thread_n, num_threads
|
953 |
+
{64, 256, 256}, // Default
|
954 |
+
{128, 128, 256}, // Reduce N 2X, increase K 2X
|
955 |
+
{64, 128, 128}, // Reduce N 2X, same K
|
956 |
+
{128, 64, 128}, // Reduce N 4X, increase K 2X
|
957 |
+
};
|
958 |
+
|
959 |
+
bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n,
|
960 |
+
int prob_k) {
|
961 |
+
// Sanity
|
962 |
+
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
|
963 |
+
th_config.num_threads == -1) {
|
964 |
+
return false;
|
965 |
+
}
|
966 |
+
|
967 |
+
// Verify K/N are divisible by thread K/N
|
968 |
+
if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
|
969 |
+
return false;
|
970 |
+
}
|
971 |
+
|
972 |
+
// thread_k can be only 128 or 64 (because it must be less than groupsize
|
973 |
+
// which is 128)
|
974 |
+
if (th_config.thread_k != 128 && th_config.thread_k != 64) {
|
975 |
+
return false;
|
976 |
+
}
|
977 |
+
|
978 |
+
// Verify min for thread K/N
|
979 |
+
if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
|
980 |
+
return false;
|
981 |
+
}
|
982 |
+
|
983 |
+
// num_threads must be at least 128 (= 4 warps)
|
984 |
+
if (th_config.num_threads < 128) {
|
985 |
+
return false;
|
986 |
+
}
|
987 |
+
|
988 |
+
return true;
|
989 |
+
}
|
990 |
+
|
991 |
+
thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) {
|
992 |
+
if (prob_m <= 16) {
|
993 |
+
for (auto th_config : small_batch_thread_configs) {
|
994 |
+
if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
|
995 |
+
return th_config;
|
996 |
+
}
|
997 |
+
}
|
998 |
+
|
999 |
+
} else {
|
1000 |
+
for (auto th_config : large_batch_thread_configs) {
|
1001 |
+
if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
|
1002 |
+
return th_config;
|
1003 |
+
}
|
1004 |
+
}
|
1005 |
+
}
|
1006 |
+
|
1007 |
+
return thread_config_t{-1, -1, -1};
|
1008 |
+
}
|
1009 |
+
|
1010 |
+
#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
1011 |
+
__CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
|
1012 |
+
__CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
|
1013 |
+
__CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
|
1014 |
+
__CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
|
1015 |
+
__CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
|
1016 |
+
__CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
|
1017 |
+
__CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
|
1018 |
+
__CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
|
1019 |
+
__CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
|
1020 |
+
__CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)
|
1021 |
+
|
1022 |
+
void marlin_qqq_cuda(const void* A, const void* B, void* C, void* D,
|
1023 |
+
void* s_tok, void* s_ch, void* s_group, int prob_m,
|
1024 |
+
int prob_n, int prob_k, void* workspace,
|
1025 |
+
int groupsize = -1, int dev = 0, cudaStream_t stream = 0,
|
1026 |
+
int thread_k = -1, int thread_n = -1, int sms = -1,
|
1027 |
+
int max_par = 16) {
|
1028 |
+
int tot_m = prob_m;
|
1029 |
+
int tot_m_blocks = ceildiv(tot_m, 16);
|
1030 |
+
int pad = 16 * tot_m_blocks - tot_m;
|
1031 |
+
|
1032 |
+
if (sms == -1)
|
1033 |
+
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
|
1034 |
+
|
1035 |
+
int max_shared_mem = 0;
|
1036 |
+
cudaDeviceGetAttribute(&max_shared_mem,
|
1037 |
+
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
|
1038 |
+
TORCH_CHECK(max_shared_mem > 0);
|
1039 |
+
|
1040 |
+
// Set thread config
|
1041 |
+
thread_config_t th_config;
|
1042 |
+
if (thread_k != -1 && thread_n != -1) {
|
1043 |
+
// User-defined config
|
1044 |
+
th_config = thread_config_t{thread_k, thread_n, USER_THREADS};
|
1045 |
+
} else {
|
1046 |
+
// Auto config
|
1047 |
+
th_config = determine_thread_config(prob_m, prob_n, prob_k);
|
1048 |
+
}
|
1049 |
+
|
1050 |
+
if (!is_valid_config(th_config, prob_m, prob_n, prob_k)) {
|
1051 |
+
throw std::runtime_error(
|
1052 |
+
"Invalid thread config: thread_k = " + str(th_config.thread_k) +
|
1053 |
+
", thread_n = " + str(th_config.thread_n) +
|
1054 |
+
", num_threads = " + str(th_config.num_threads) + " for MKN = [" +
|
1055 |
+
str(prob_m) + ", " + str(prob_k) + ", " + str(prob_n) + "]");
|
1056 |
+
}
|
1057 |
+
|
1058 |
+
int num_threads = th_config.num_threads;
|
1059 |
+
thread_k = th_config.thread_k;
|
1060 |
+
thread_n = th_config.thread_n;
|
1061 |
+
|
1062 |
+
int thread_k_blocks = thread_k / 16;
|
1063 |
+
int thread_n_blocks = thread_n / 16;
|
1064 |
+
int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
|
1065 |
+
int blocks = sms;
|
1066 |
+
|
1067 |
+
if (prob_m == 0 || prob_n == 0 || prob_k == 0) {
|
1068 |
+
return;
|
1069 |
+
}
|
1070 |
+
|
1071 |
+
TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
|
1072 |
+
" is not divisible by thread_n = ", thread_n);
|
1073 |
+
TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
|
1074 |
+
" is not divisible by thread_k = ", thread_k);
|
1075 |
+
if (group_blocks != -1) {
|
1076 |
+
TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
|
1077 |
+
" is not divisible by group_blocks = ", group_blocks);
|
1078 |
+
}
|
1079 |
+
|
1080 |
+
const int4* A_ptr = (const int4*)A;
|
1081 |
+
const int4* B_ptr = (const int4*)B;
|
1082 |
+
int4* C_ptr = (int4*)C;
|
1083 |
+
int4* D_ptr = (int4*)D;
|
1084 |
+
const float* s_tok_ptr = (const float*)s_tok;
|
1085 |
+
const int4* s_ch_ptr = (const int4*)s_ch;
|
1086 |
+
const int4* s_group_ptr = (const int4*)s_group;
|
1087 |
+
|
1088 |
+
int* locks = (int*)workspace;
|
1089 |
+
|
1090 |
+
for (int i = 0; i < tot_m_blocks; i += 4) {
|
1091 |
+
int thread_m_blocks = tot_m_blocks - i;
|
1092 |
+
prob_m = tot_m - 16 * i;
|
1093 |
+
int par = 1;
|
1094 |
+
if (thread_m_blocks > 4) {
|
1095 |
+
// Note that parallel > 1 currently only works for inputs without any
|
1096 |
+
// padding
|
1097 |
+
par = (16 * thread_m_blocks - pad) / 64;
|
1098 |
+
if (par > max_par) par = max_par;
|
1099 |
+
prob_m = 64 * par;
|
1100 |
+
i += 4 * (par - 1);
|
1101 |
+
thread_m_blocks = 4;
|
1102 |
+
}
|
1103 |
+
|
1104 |
+
// For compilation speed, we only define the kernel configurations that have
|
1105 |
+
// seemed useful (in terms of performance) in our testing, however many more
|
1106 |
+
// are, in principle, possible.
|
1107 |
+
if (false) {
|
1108 |
+
}
|
1109 |
+
CALL_IF(8, 8, 256)
|
1110 |
+
CALL_IF(16, 4, 256)
|
1111 |
+
CALL_IF(8, 4, 128)
|
1112 |
+
CALL_IF(4, 8, 128)
|
1113 |
+
else {
|
1114 |
+
throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) +
|
1115 |
+
", " + str(prob_k) + ", " + str(prob_n) + "]" +
|
1116 |
+
", groupsize = " + str(groupsize) +
|
1117 |
+
", thread_m_blocks = " + str(thread_m_blocks) +
|
1118 |
+
", thread_n_blocks = " + str(thread_n_blocks) +
|
1119 |
+
", thread_k_blocks = " + str(thread_k_blocks));
|
1120 |
+
}
|
1121 |
+
|
1122 |
+
A_ptr += 16 * thread_m_blocks * (prob_k / 16) * par;
|
1123 |
+
D_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
|
1124 |
+
s_tok_ptr += 16 * thread_m_blocks * par;
|
1125 |
+
}
|
1126 |
+
}
|
1127 |
+
} // anonymous namespace
|
1128 |
+
|
1129 |
+
torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
|
1130 |
+
torch::Tensor const& b_q_weight,
|
1131 |
+
torch::Tensor const& s_tok,
|
1132 |
+
torch::Tensor const& s_ch,
|
1133 |
+
torch::Tensor const& s_group,
|
1134 |
+
torch::Tensor& workspace, int64_t size_m,
|
1135 |
+
int64_t size_n, int64_t size_k) {
|
1136 |
+
// Verify M
|
1137 |
+
TORCH_CHECK(size_m == a.size(0),
|
1138 |
+
"Shape mismatch: a.size(0) = " + str(a.size(0)) +
|
1139 |
+
", size_m = " + str(size_m));
|
1140 |
+
TORCH_CHECK(size_m == s_tok.numel(),
|
1141 |
+
"Shape mismatch: s_tok.numel() = " + str(s_tok.numel()) +
|
1142 |
+
", size_m = " + str(size_m));
|
1143 |
+
|
1144 |
+
// Verify K
|
1145 |
+
TORCH_CHECK(size_k == a.size(1),
|
1146 |
+
"Shape mismatch: a.size(1) = " + str(a.size(1)) +
|
1147 |
+
", size_k = " + str(size_k));
|
1148 |
+
TORCH_CHECK(size_k % tile_size == 0,
|
1149 |
+
"size_k = " + str(size_k) +
|
1150 |
+
" is not divisible by tile_size = " + str(tile_size));
|
1151 |
+
TORCH_CHECK(
|
1152 |
+
(size_k / tile_size) == b_q_weight.size(0),
|
1153 |
+
"Shape mismatch: b_q_weight.size(0) = " + str(b_q_weight.size(0)) +
|
1154 |
+
", size_k = " + str(size_k) + ", tile_size = " + str(tile_size));
|
1155 |
+
|
1156 |
+
int groupsize = (s_group.numel() == 0) ? -1 : size_k / s_group.size(0);
|
1157 |
+
// Verify groupsize
|
1158 |
+
TORCH_CHECK(groupsize == -1 || groupsize == 128,
|
1159 |
+
"Unexpected groupsize = " + str(groupsize));
|
1160 |
+
|
1161 |
+
// Verify N
|
1162 |
+
TORCH_CHECK(s_ch.numel() == size_n,
|
1163 |
+
"Shape mismatch: s_ch.numel() = " + str(s_ch.numel()) +
|
1164 |
+
", size_n = " + str(size_n));
|
1165 |
+
TORCH_CHECK(b_q_weight.size(1) % tile_size == 0,
|
1166 |
+
"b_q_weight.size(1) = " + str(b_q_weight.size(1)) +
|
1167 |
+
" is not divisible by tile_size = " + str(tile_size));
|
1168 |
+
if (groupsize != -1) {
|
1169 |
+
TORCH_CHECK(s_group.size(1) == size_n,
|
1170 |
+
"Shape mismatch: s_group.size(1) = " + str(s_group.size(1)) +
|
1171 |
+
", size_n = " + str(size_n));
|
1172 |
+
TORCH_CHECK(
|
1173 |
+
size_k % s_group.size(0) == 0,
|
1174 |
+
"size_k = " + str(size_k) +
|
1175 |
+
", is not divisible by s_group.size(0) = " + str(s_group.size(0)));
|
1176 |
+
}
|
1177 |
+
|
1178 |
+
int actual_size_n = (b_q_weight.size(1) / tile_size) * pack_factor_4bit;
|
1179 |
+
TORCH_CHECK(size_n == actual_size_n,
|
1180 |
+
"Shape mismatch: size_n = " + str(size_n) +
|
1181 |
+
", actual_size_n = " + str(actual_size_n));
|
1182 |
+
|
1183 |
+
// Verify A device and strides
|
1184 |
+
TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
|
1185 |
+
TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
|
1186 |
+
|
1187 |
+
// Verify B device and strides
|
1188 |
+
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
|
1189 |
+
TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
|
1190 |
+
|
1191 |
+
// Verify s_tok device, strides and dtype
|
1192 |
+
TORCH_CHECK(s_tok.device().is_cuda(), "s_tok is not on GPU");
|
1193 |
+
TORCH_CHECK(s_tok.is_contiguous(), "s_tok is not contiguous");
|
1194 |
+
TORCH_CHECK(s_tok.dtype() == torch::kFloat32, "s_tok's dtype is not float32");
|
1195 |
+
|
1196 |
+
// Verify s_ch device, strides and dtype
|
1197 |
+
TORCH_CHECK(s_ch.device().is_cuda(), "s_ch is not on GPU");
|
1198 |
+
TORCH_CHECK(s_ch.is_contiguous(), "s_ch is not contiguous");
|
1199 |
+
TORCH_CHECK(s_ch.dtype() == torch::kFloat32, "s_ch's dtype is not float32");
|
1200 |
+
|
1201 |
+
// Verify s_group device, strides and dtype
|
1202 |
+
TORCH_CHECK(s_group.device().is_cuda(), "s_group is not on GPU");
|
1203 |
+
TORCH_CHECK(s_group.is_contiguous(), "s_group is not contiguous");
|
1204 |
+
TORCH_CHECK(s_group.dtype() == torch::kFloat16,
|
1205 |
+
"s_group's dtype is not float16");
|
1206 |
+
|
1207 |
+
// Verify workspace size
|
1208 |
+
TORCH_CHECK(size_n % min_thread_n == 0,
|
1209 |
+
"size_n = " + str(size_n) +
|
1210 |
+
", is not divisible by min_thread_n = " + str(min_thread_n));
|
1211 |
+
int min_workspace_size = (size_n / min_thread_n) * max_par;
|
1212 |
+
TORCH_CHECK(workspace.numel() >= min_workspace_size,
|
1213 |
+
"workspace.numel = " + str(workspace.numel()) +
|
1214 |
+
" is below min_workspace_size = " + str(min_workspace_size));
|
1215 |
+
|
1216 |
+
// Alloc C matrix
|
1217 |
+
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
|
1218 |
+
auto options_c = torch::TensorOptions().dtype(torch::kInt).device(a.device());
|
1219 |
+
torch::Tensor c = torch::empty({max_par * 64, size_n}, options_c);
|
1220 |
+
|
1221 |
+
// Alloc D matrix
|
1222 |
+
auto options_d =
|
1223 |
+
torch::TensorOptions().dtype(torch::kFloat16).device(a.device());
|
1224 |
+
torch::Tensor d = torch::empty({size_m, size_n}, options_d);
|
1225 |
+
|
1226 |
+
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
|
1227 |
+
// auto -1)
|
1228 |
+
int thread_k = -1;
|
1229 |
+
// thread_n: `n` size of a thread_tile in `weights` (can usually be left as
|
1230 |
+
// auto -1)
|
1231 |
+
int thread_n = -1;
|
1232 |
+
// sms: number of SMs to use for the kernel (can usually be left as auto -1)
|
1233 |
+
int sms = -1;
|
1234 |
+
|
1235 |
+
int dev = a.get_device();
|
1236 |
+
marlin_qqq_cuda(
|
1237 |
+
a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), d.data_ptr(),
|
1238 |
+
s_tok.data_ptr(), s_ch.data_ptr(), s_group.data_ptr(), size_m, size_n,
|
1239 |
+
size_k, workspace.data_ptr(), groupsize, dev,
|
1240 |
+
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par);
|
1241 |
+
|
1242 |
+
return d;
|
1243 |
+
}
|
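The marlin_qqq_gemm wrapper above encodes the expected W4A8 tensor layout entirely in its TORCH_CHECKs. The following Python sketch spells those shapes out; it assumes the op is exposed to Python roughly as ops.marlin_qqq_gemm via ext-torch/marlin.py, which is an assumption here, so the call itself is left commented:

import torch

# Hypothetical shapes for one QQQ (W4A8) GEMM: D[m, n] = A[m, k] @ B[k, n]
m, n, k = 16, 4096, 4096
groupsize = 128  # must be 128, or -1 (s_group empty) for channel-wise only

a = torch.empty((m, k), dtype=torch.int8, device="cuda")       # int8 activations
# 4-bit weights packed per 16x16 tile: (k / 16, n * 16 / 8) int32 values
b_q = torch.empty((k // 16, 2 * n), dtype=torch.int32, device="cuda")
s_tok = torch.empty((m,), dtype=torch.float32, device="cuda")   # per-token scales
s_ch = torch.empty((n,), dtype=torch.float32, device="cuda")    # per-channel scales
s_group = torch.empty((k // groupsize, n), dtype=torch.float16, device="cuda")
# workspace holds the inter-block locks: at least (n / min_thread_n) * max_par ints
workspace = torch.zeros((n // 64) * 16, dtype=torch.int32, device="cuda")

# d = ops.marlin_qqq_gemm(a, b_q, s_tok, s_ch, s_group, workspace, m, n, k)
# d.shape == (m, n) and d.dtype == torch.float16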
marlin/sparse/LICENSE
ADDED
@@ -0,0 +1,203 @@
Contains code from https://github.com/IST-DASLab/Sparse-Marlin/

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   (The remainder of the file is the unmodified standard Apache License 2.0
   text: definitions, the copyright and patent grants, the redistribution
   conditions, the warranty disclaimer and limitation of liability, and the
   appendix describing how to apply the license.)
marlin/sparse/common/base.h
ADDED
@@ -0,0 +1,51 @@
/*
 * Copyright (C) 2024 Roberto Lopez Castro ([email protected]). All
 * Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

namespace marlin_24 {

constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }

// Instances of `Vec` are used to organize groups of >>registers<<, as needed
// for instance as inputs to tensor core operations. Consequently, all
// corresponding index accesses must be compile-time constants, which is why we
// extensively use `#pragma unroll` throughout the kernel code to guarantee
// this.
template <typename T, int n>
struct Vec {
  T elems[n];
  __device__ T& operator[](int i) { return elems[i]; }
};

template <int M_, int N_, int K_>
struct ShapeBase {
  static constexpr int M = M_, N = N_, K = K_;
};

using I4 = Vec<int, 4>;

// Matrix fragments for tensor core instructions; their precise layout is
// documented here:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
using FragA = Vec<half2, 4>;
using FragB = Vec<half2, 2>;
using FragM = Vec<uint, 1>;
using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>;  // quantization scales

}  // namespace marlin_24
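ceildiv above is the helper the launch code uses to split a GEMM into 16x16 tiles, with the last m-tile handled by predication rather than requiring exact divisibility. A small illustrative sketch of the resulting tile counts, not part of the kernel:

def ceildiv(a: int, b: int) -> int:
    # Same rounding-up division as marlin_24::ceildiv.
    return (a + b - 1) // b

# Example: a 24 x 4096 x 4096 problem tiled into 16x16 blocks.
prob_m, prob_n, prob_k = 24, 4096, 4096
m_blocks = ceildiv(prob_m, 16)  # 2; the second block is only half full
n_blocks = prob_n // 16         # 256
k_blocks = prob_k // 16         # 256
print(m_blocks, n_blocks, k_blocks)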
marlin/sparse/common/mem.h
ADDED
@@ -0,0 +1,136 @@
/*
 * Copyright (C) 2024 Roberto Lopez Castro ([email protected]). All
 * Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once
#include "base.h"

namespace marlin_24 {
// Predicated asynchronous global->shared copy; used for inputs A where we apply
// predication to handle batchsizes that are not multiples of 16.
__device__ inline void cp_async4_pred_zfill(void* smem_ptr,
                                            const void* glob_ptr,
                                            bool pred = true,
                                            const bool zfill = false) {
  const int BYTES = 16;
  int src_in_bytes = (zfill ? 0 : BYTES);
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   .reg .pred p;\n"
      "   setp.ne.b32 p, %0, 0;\n"
      "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
      "}\n" ::"r"((int)pred),
      "r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes));
}

__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                      bool pred = true) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   .reg .pred p;\n"
      "   setp.ne.b32 p, %0, 0;\n"
      "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
      "}\n" ::"r"((int)pred),
      "r"(smem), "l"(glob_ptr), "n"(BYTES));
}

// Asynchronous global->shared copy
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   cp.async.cg.shared.global [%0], [%1], %2;\n"
      "}\n" ::"r"(smem),
      "l"(glob_ptr), "n"(BYTES));
}

// Async copy fence.
__device__ inline void cp_async_fence() {
  asm volatile("cp.async.commit_group;\n" ::);
}

// Wait until at most `n` async copy stages are still pending.
template <int n>
__device__ inline void cp_async_wait() {
  asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}

// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
               : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
               : "r"(smem));
}

__device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) {
  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_m);
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n"
               : "=r"(a[0]), "=r"(a[1])
               : "r"(smem));
}

// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
__device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) {
  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n"
      : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
      : "r"(smem));
}

// Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int* lock, int count) {
  if (threadIdx.x == 0) {
    int state = -1;
    do
      // Guarantee that subsequent writes by this threadblock will be visible
      // globally.
      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
                   : "=r"(state)
                   : "l"(lock));
    while (state != count);
  }
  __syncthreads();
}

// Release barrier and increment visitation count.
__device__ inline void barrier_release(int* lock, bool reset = false) {
  __syncthreads();
  if (threadIdx.x == 0) {
    if (reset) {
      lock[0] = 0;
      return;
    }
    int val = 1;
    // Make sure that all writes since acquiring this barrier are visible
    // globally, while releasing the barrier.
    asm volatile("fence.acq_rel.gpu;\n");
    asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
                 :
                 : "l"(lock), "r"(val));
  }
}
}  // namespace marlin_24
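barrier_acquire and barrier_release implement a counter-based spin barrier in global memory so that threadblocks cooperating on one output column slice serialize their partial reductions. The Python analogy below mirrors only the protocol (wait for the counter to reach this block's index, then either add one or reset); it does not model the PTX acquire/release fences, and the class name is made up for illustration:

import threading

class SliceLock:
    """Host-side analogy of the locks[] protocol used around global_reduce."""

    def __init__(self) -> None:
        self._count = 0
        self._cond = threading.Condition()

    def acquire(self, expected: int) -> None:
        # Like barrier_acquire(&locks[slice_col], slice_idx): wait until the
        # counter reaches this block's index within the slice.
        with self._cond:
            self._cond.wait_for(lambda: self._count == expected)

    def release(self, last: bool = False) -> None:
        # Like barrier_release(&locks[slice_col], last): the last block resets
        # the counter for the next slice, earlier blocks increment it.
        with self._cond:
            self._count = 0 if last else self._count + 1
            self._cond.notify_all()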
marlin/sparse/common/mma.h
ADDED
@@ -0,0 +1,191 @@
/*
 * Copyright (C) 2024 Roberto Lopez Castro ([email protected]). All
 * Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once
#include "base.h"
#include <cudaTypedefs.h>

namespace marlin_24 {

// On CUDA earlier than 12.5, the ordered_metadata version of this instruction
// is not supported. On later versions of CUDA the version without ordered
// metadata results in the following warning:
// | Advisory: Modifier ‘.sp::ordered_metadata’ should be used on instruction
// | ‘mma’ instead of modifier ‘.sp’ as it is expected to have substantially
// | reduced performance on some future architectures
#if defined CUDA_VERSION && CUDA_VERSION >= 12050
  #define MMA_SP_INST \
    "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
#else
  #define MMA_SP_INST "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
#endif

// m16n8k32 sparse tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
__device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
                              const FragA& frag_b, FragC& frag_c, FragM& frag_m,
                              const int psel) {
  const uint32_t* a0 = reinterpret_cast<const uint32_t*>(&a_frag0);
  const uint32_t* a1 = reinterpret_cast<const uint32_t*>(&a_frag1);
  const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
  const uint32_t* e = reinterpret_cast<const uint32_t*>(&frag_m);

  float* c = reinterpret_cast<float*>(&frag_c);
  if (psel == 0) {
    asm volatile(MMA_SP_INST
                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
                 "{%12,%13,%14,%15}, %16, 0x0;\n"
                 : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
                 : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
                   "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
                   "f"(c[2]), "f"(c[3]), "r"(e[0]));
    asm volatile(MMA_SP_INST
                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
                 "{%12,%13,%14,%15}, %16, 0x0;\n"
                 : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
                 : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
                   "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
                   "f"(c[6]), "f"(c[7]), "r"(e[0]));
  } else {
    asm volatile(MMA_SP_INST
                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
                 "{%12,%13,%14,%15}, %16, 0x1;\n"
                 : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
                 : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
                   "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
                   "f"(c[2]), "f"(c[3]), "r"(e[0]));
    asm volatile(MMA_SP_INST
                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
                 "{%12,%13,%14,%15}, %16, 0x1;\n"
                 : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
                 : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
                   "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
                   "f"(c[6]), "f"(c[7]), "r"(e[0]));
  }
}

// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
  int res;
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(res)
               : "r"(a), "r"(b), "r"(c), "n"(lut));
  return res;
}

__device__ __forceinline__ uint2 to_half4(float c0, float c1, float c2,
                                          float c3) {
  uint2 r;
  asm("{\n\t"
      ".reg .f16 a, b, c, d; \n\t"
      "cvt.rn.f16.f32 a, %2; \n\t"
      "cvt.rn.f16.f32 b, %3; \n\t"
      "cvt.rn.f16.f32 c, %4; \n\t"
      "cvt.rn.f16.f32 d, %5; \n\t"
      "mov.b32 %0, {a, b}; \n\t"
      "mov.b32 %1, {c, d}; \n\t"
      "}"
      : "=r"(r.x), "=r"(r.y)
      : "f"(c0), "f"(c1), "f"(c2), "f"(c3));
  return r;
}

// Constructs destination register by taking bytes from 2 sources (based on
// mask)
template <int start_byte, int mask>
__device__ inline uint32_t prmt(uint32_t a) {
  uint32_t res;
  asm volatile("prmt.b32 %0, %1, %2, %3;\n"
               : "=r"(res)
               : "r"(a), "n"(start_byte), "n"(mask));
  return res;
}

// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
// values. We mostly follow the strategy in the link below, with some small
// changes:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
__device__ inline FragB dequant_4bit(int q) {
  const int LO = 0x000f000f;
  const int HI = 0x00f000f0;
  const int EX = 0x64006400;
  // Guarantee that the `(a & b) | c` operations are LOP3s.
  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
  // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
  // directly into `SUB` and `ADD`.
  const int SUB = 0x64086408;
  const int MUL = 0x2c002c00;
  const int ADD = 0xd480d480;

  FragB frag_b;
  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
                      *reinterpret_cast<const half2*>(&SUB));
  frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
                      *reinterpret_cast<const half2*>(&MUL),
                      *reinterpret_cast<const half2*>(&ADD));
  return frag_b;
}

// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
// values. We mostly follow the strategy in the link below, with some small
// changes:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
__device__ inline FragB dequant_8bit(int q) {
  static constexpr uint32_t mask_for_elt_01 = 0x5250;
  static constexpr uint32_t mask_for_elt_23 = 0x5351;
  static constexpr uint32_t start_byte_for_fp16 = 0x64646464;

  uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
  uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);

  static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;

  FragB frag_b;
  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
  frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
  return frag_b;
}

// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
  half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
  frag_b[0] = __hmul2(frag_b[0], s);
  frag_b[1] = __hmul2(frag_b[1], s);
}

__device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3,
                                    FragS& s0, float* c4, float* c5, float* c6,
                                    float* c7, FragS& s1) {
  *c0 = __fmul_rn(*c0, __half2float(s0[0].x));
  *c1 = __fmul_rn(*c1, __half2float(s0[0].y));
  *c2 = __fmul_rn(*c2, __half2float(s0[1].x));
  *c3 = __fmul_rn(*c3, __half2float(s0[1].y));

  *c4 = __fmul_rn(*c4, __half2float(s1[0].x));
  *c5 = __fmul_rn(*c5, __half2float(s1[0].y));
  *c6 = __fmul_rn(*c6, __half2float(s1[1].x));
  *c7 = __fmul_rn(*c7, __half2float(s1[1].y));
}

}  // namespace marlin_24
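dequant_4bit above avoids integer-to-float conversion instructions by OR-ing each nibble into the mantissa of the fp16 constant 1024.0 and then subtracting a fused bias, so numerically each 4-bit value v simply becomes v - 8. A plain-Python reference of that arithmetic; the nibble positions follow my reading of the LO/HI masks and are an interpretation, not something stated in the source comments:

def dequant_4bit_ref(q: int) -> list[float]:
    """Reference for the four fp16 values one dequant_4bit(q) call produces."""
    nibbles = [
        (q >> 0) & 0xF,   # selected by LO = 0x000f000f, low half2 lane
        (q >> 16) & 0xF,  # selected by LO, high half2 lane
        (q >> 4) & 0xF,   # selected by HI = 0x00f000f0 (rescaled by MUL = 1/16)
        (q >> 20) & 0xF,  # selected by HI, high half2 lane
    ]
    # SUB/ADD fold in the symmetric zero point of 8.
    return [float(v) - 8.0 for v in nibbles]

print(dequant_4bit_ref(0x00530091))  # -> [-7.0, -5.0, 1.0, -3.0]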
marlin/sparse/marlin_24_cuda_kernel.cu
ADDED
@@ -0,0 +1,1140 @@
1 |
+
/*
|
2 |
+
* Notice: This file was modified by Neuralmagic inc to include 8-bit support
|
3 |
+
*
|
4 |
+
* Copyright (C) 2024 Roberto Lopez Castro ([email protected]). All
|
5 |
+
* Rights Reserved.
|
6 |
+
*
|
7 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
8 |
+
* you may not use this file except in compliance with the License.
|
9 |
+
* You may obtain a copy of the License at
|
10 |
+
*
|
11 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
12 |
+
*
|
13 |
+
* Unless required by applicable law or agreed to in writing, software
|
14 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
15 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
16 |
+
* See the License for the specific language governing permissions and
|
17 |
+
* limitations under the License.
|
18 |
+
*/
|
19 |
+
#include <torch/all.h>
|
20 |
+
|
21 |
+
#include <ATen/cuda/CUDAContext.h>
|
22 |
+
#include <c10/cuda/CUDAGuard.h>
|
23 |
+
#include <cuda.h>
|
24 |
+
#include <cuda_fp16.h>
|
25 |
+
#include <cuda_runtime.h>
|
26 |
+
|
27 |
+
#include <iostream>
|
28 |
+
|
29 |
+
#include "common/base.h"
|
30 |
+
#include "core/scalar_type.hpp"
|
31 |
+
|
32 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
33 |
+
|
34 |
+
#else
|
35 |
+
|
36 |
+
#include "common/mem.h"
|
37 |
+
#include "common/mma.h"
|
38 |
+
|
39 |
+
#endif
|
40 |
+
|
41 |
+
template <typename T>
|
42 |
+
inline std::string str(T x) {
|
43 |
+
return std::to_string(x);
|
44 |
+
}
|
45 |
+
|
46 |
+
namespace marlin_24 {
|
47 |
+
|
48 |
+
// 8 warps are a good choice since every SM has 4 schedulers and having more
|
49 |
+
// than 1 warp per schedule allows some more latency hiding. At the same time,
|
50 |
+
// we want relatively few warps to have many registers per warp and small tiles.
|
51 |
+
static constexpr int THREADS = 256;
|
52 |
+
static constexpr int STAGES = 4;
|
53 |
+
|
54 |
+
static constexpr int min_thread_n = 128;
|
55 |
+
|
56 |
+
static constexpr int tile_size = 16;
|
57 |
+
static constexpr int max_par = 64;
|
58 |
+
|
59 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
60 |
+
|
61 |
+
template <const int num_bits, // weight bits
|
62 |
+
const int threads, // number of threads in a threadblock
|
63 |
+
const int thread_m_blocks, // number of 16x16 blocks in the m
|
64 |
+
// dimension (batchsize) of the
|
65 |
+
// threadblock
|
66 |
+
const int thread_n_blocks, // same for n dimension (output)
|
67 |
+
const int thread_k_blocks, // same for k dimension (reduction)
|
68 |
+
const int stages, // number of stages for the async global->shared
|
69 |
+
// fetch pipeline
|
70 |
+
const int group_blocks = -1 // number of consecutive 16x16 blocks
|
71 |
+
// with a separate quantization scale
|
72 |
+
>
|
73 |
+
__global__ void Marlin_24(
|
74 |
+
const int4* __restrict__ A, // fp16 input matrix of shape mxk
|
75 |
+
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
|
76 |
+
const int4* __restrict__ meta, // 2bit metadata information about 2:4
|
77 |
+
// format on B
|
78 |
+
int4* __restrict__ C, // fp16 output buffer of shape mxn
|
79 |
+
const int4* __restrict__ s, // fp16 quantization scales of shape
|
80 |
+
// (k/groupsize)xn
|
81 |
+
int prob_m, // batch dimension m
|
82 |
+
int prob_n, // output dimension n
|
83 |
+
int prob_k, // reduction dimension k
|
84 |
+
int* locks // extra global storage for barrier synchronization
|
85 |
+
) {}
|
86 |
+
|
87 |
+
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
88 |
+
torch::Tensor& b_meta,
|
89 |
+
torch::Tensor& b_scales,
|
90 |
+
torch::Tensor& workspace,
|
91 |
+
vllm::ScalarTypeId const b_q_type_id,
|
92 |
+
int64_t size_m, int64_t size_n,
|
93 |
+
int64_t size_k) {
|
94 |
+
TORCH_CHECK_NOT_IMPLEMENTED(
|
95 |
+
false, "gptq_marlin_24_gemm(..) requires CUDA_ARCH >= 8.0");
|
96 |
+
return torch::empty({1, 1});
|
97 |
+
}
|
98 |
+
|
99 |
+
#else
|
100 |
+
|
101 |
+
template <const int num_bits, // weight bits
|
102 |
+
const int threads, // number of threads in a threadblock
|
103 |
+
const int thread_m_blocks, // number of 16x16 blocks in the m
|
104 |
+
// dimension (batchsize) of the
|
105 |
+
// threadblock
|
106 |
+
const int thread_n_blocks, // same for n dimension (output)
|
107 |
+
const int thread_k_blocks, // same for k dimension (reduction)
|
108 |
+
const int stages, // number of stages for the async global->shared
|
109 |
+
// fetch pipeline
|
110 |
+
const int group_blocks = -1 // number of consecutive 16x16 blocks
|
111 |
+
// with a separate quantization scale
|
112 |
+
>
|
113 |
+
__global__ void Marlin_24(
|
114 |
+
const int4* __restrict__ A, // fp16 input matrix of shape mxk
|
115 |
+
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
|
116 |
+
const int4* __restrict__ meta, // 2bit metadata information about 2:4
|
117 |
+
// format on B
|
118 |
+
int4* __restrict__ C, // fp16 output buffer of shape mxn
|
119 |
+
const int4* __restrict__ s, // fp16 quantization scales of shape
|
120 |
+
// (k/groupsize)xn
|
121 |
+
int prob_m, // batch dimension m
|
122 |
+
int prob_n, // output dimension n
|
123 |
+
int prob_k, // reduction dimension k
|
124 |
+
int* locks // extra global storage for barrier synchronization
|
125 |
+
) {
|
126 |
+
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
|
127 |
+
// same size, which might involve multiple column "slices" (of width 16 *
|
128 |
+
// `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
|
129 |
+
// example:
|
130 |
+
// 0 1 3
|
131 |
+
// 0 2 3
|
132 |
+
// 1 2 4
|
133 |
+
// While this kind of partitioning makes things somewhat more complicated, it
|
134 |
+
// ensures good utilization of all SMs for many kinds of shape and GPU
|
135 |
+
// configurations, while requiring as few slow global cross-threadblock
|
136 |
+
// reductions as possible.
|
137 |
+
|
138 |
+
// For larger GEMMs we run multiple batchsize 64 versions in parallel for a
|
139 |
+
// better partitioning with less reductions
|
140 |
+
int parallel = 1;
|
141 |
+
if (prob_m > 16 * thread_m_blocks) {
|
142 |
+
parallel = prob_m / (16 * thread_m_blocks);
|
143 |
+
prob_m = 16 * thread_m_blocks;
|
144 |
+
}
|
145 |
+
|
146 |
+
// number of thread_k_blocks in k-dim
|
147 |
+
int k_tiles = prob_k / 32 / thread_k_blocks;
|
148 |
+
// number of thread_n_blocks in n-dim
|
149 |
+
int n_tiles = prob_n / 16 / thread_n_blocks;
|
150 |
+
// iters needed to cover all slices
|
151 |
+
int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x);
|
152 |
+
|
153 |
+
// Ensure that the number of tiles in each stripe is a multiple of the
|
154 |
+
// groupsize; this avoids an annoying special case where a stripe starts in
|
155 |
+
// the middle of group.
|
156 |
+
if (group_blocks != -1)
|
157 |
+
iters = (group_blocks / thread_k_blocks) *
|
158 |
+
ceildiv(iters, (group_blocks / thread_k_blocks));
|
159 |
+
|
160 |
+
int slice_row = (iters * blockIdx.x) % k_tiles;
|
161 |
+
int slice_col_par = (iters * blockIdx.x) / k_tiles;
|
162 |
+
int slice_col = slice_col_par;
|
163 |
+
// number of threadblock tiles in the current slice
|
164 |
+
int slice_iters;
|
165 |
+
// total number of active threadblocks in the current slice
|
166 |
+
int slice_count = 0;
|
167 |
+
// index of threadblock in current slice; numbered bottom to top
|
168 |
+
int slice_idx;
|
169 |
+
|
170 |
+
// We can easily implement parallel problem execution by just remapping
|
171 |
+
// indices and advancing global pointers
|
172 |
+
if (slice_col_par >= n_tiles) {
|
173 |
+
A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;
|
174 |
+
C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
|
175 |
+
locks += (slice_col_par / n_tiles) * n_tiles;
|
176 |
+
slice_col = slice_col_par % n_tiles;
|
177 |
+
}
|
178 |
+
|
179 |
+
// Compute all information about the current slice which is required for
|
180 |
+
// synchronization.
|
181 |
+
auto init_slice = [&]() {
|
182 |
+
slice_iters =
|
183 |
+
iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
|
184 |
+
if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
|
185 |
+
if (slice_iters == 0) return;
|
186 |
+
if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
|
187 |
+
slice_count = 1;
|
188 |
+
slice_idx = 0;
|
189 |
+
int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
|
190 |
+
if (col_first <= k_tiles * (slice_col_par + 1)) {
|
191 |
+
int col_off = col_first - k_tiles * slice_col_par;
|
192 |
+
slice_count = ceildiv(k_tiles - col_off, iters);
|
193 |
+
if (col_off > 0) slice_count++;
|
194 |
+
int delta_first = iters * blockIdx.x - col_first;
|
195 |
+
if (delta_first < 0 || (col_off == 0 && delta_first == 0))
|
196 |
+
slice_idx = slice_count - 1;
|
197 |
+
else {
|
198 |
+
slice_idx = slice_count - 1 - delta_first / iters;
|
199 |
+
if (col_off > 0) slice_idx--;
|
200 |
+
}
|
201 |
+
}
|
202 |
+
if (slice_col == n_tiles) {
|
203 |
+
A += 16 * thread_m_blocks * prob_k / 8;
|
204 |
+
C += 16 * thread_m_blocks * prob_n / 8;
|
205 |
+
locks += n_tiles;
|
206 |
+
slice_col = 0;
|
207 |
+
}
|
208 |
+
};
|
209 |
+
init_slice();
|
210 |
+
|
211 |
+
// RLC: 8 is vec_size -> 128-bit instructions, 8 fp16 elements
|
212 |
+
int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory
|
213 |
+
|
214 |
+
// stride of an A matrix tile in shared memory
|
215 |
+
constexpr int a_sh_stride = 32 * thread_k_blocks / 8;
|
216 |
+
// delta between subsequent A tiles in global memory
|
217 |
+
constexpr int a_gl_rd_delta_o = 32 * thread_k_blocks / 8;
|
218 |
+
// between subsequent accesses within a tile
|
219 |
+
int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o);
|
220 |
+
// between shared memory writes
|
221 |
+
constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o);
|
222 |
+
// between shared memory tile reads //RLC: 2 * #warps k-dim
|
223 |
+
constexpr int a_sh_rd_delta_o = 4 * ((threads / 32) / (thread_n_blocks / 4));
|
224 |
+
// within a shared memory tile
|
225 |
+
constexpr int a_sh_rd_delta_i = a_sh_stride * 16;
|
226 |
+
// overall size of a tile
|
227 |
+
constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks);
|
228 |
+
// number of shared write iterations for a tile
|
229 |
+
constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta);
|
230 |
+
|
231 |
+
constexpr int pack_factor = 32 / num_bits;
|
232 |
+
|
233 |
+
int b_gl_stride = 16 * prob_n / (pack_factor * 4);
|
234 |
+
constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;
|
235 |
+
constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2;
|
236 |
+
constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;
|
237 |
+
int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
|
238 |
+
int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads);
|
239 |
+
constexpr int b_sh_wr_delta = threads * b_thread_vecs;
|
240 |
+
constexpr int b_sh_rd_delta = threads * b_thread_vecs;
|
241 |
+
constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
|
242 |
+
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
|
243 |
+
|
244 |
+
int m_gl_stride = 2 * prob_n / 8; // (16*2*4 / 8) = 16
|
245 |
+
constexpr int m_sh_stride =
|
246 |
+
(16 * thread_n_blocks) / 4; // #warps n-dim * threads/warp
|
247 |
+
int m_gl_rd_delta_o = m_gl_stride * thread_k_blocks;
|
248 |
+
int m_gl_rd_delta_i = m_gl_stride * (threads / m_sh_stride);
|
249 |
+
constexpr int m_sh_wr_delta = threads / 2;
|
250 |
+
constexpr int m_sh_rd_delta = threads / 2;
|
251 |
+
constexpr int m_sh_stage = m_sh_stride * thread_k_blocks;
|
252 |
+
constexpr int m_sh_iters = ceildiv(m_sh_stage, m_sh_wr_delta);
|
253 |
+
|
254 |
+
int s_gl_stride = prob_n / 8;
|
255 |
+
constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
|
256 |
+
constexpr int s_sh_stage = s_sh_stride;
|
257 |
+
int s_gl_rd_delta = s_gl_stride;
|
258 |
+
|
259 |
+
// Global A read index of current thread.
|
260 |
+
int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
|
261 |
+
(threadIdx.x % a_gl_rd_delta_o);
|
262 |
+
a_gl_rd += a_gl_rd_delta_o * slice_row;
|
263 |
+
// Shared write index of current thread.
|
264 |
+
int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
|
265 |
+
(threadIdx.x % a_gl_rd_delta_o);
|
266 |
+
// Shared read index.
|
267 |
+
int a_sh_rd =
|
268 |
+
a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
|
269 |
+
a_sh_rd += 4 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
|
270 |
+
|
271 |
+
int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) +
|
272 |
+
(threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
|
273 |
+
b_gl_rd += b_sh_stride * slice_col;
|
274 |
+
b_gl_rd += b_gl_rd_delta_o * slice_row;
|
275 |
+
int b_sh_wr = threadIdx.x * b_thread_vecs;
|
276 |
+
int b_sh_rd = threadIdx.x * b_thread_vecs;
|
277 |
+
|
278 |
+
int m_gl_rd = m_gl_stride * (threadIdx.x / (m_sh_stride)) +
|
279 |
+
(threadIdx.x % (m_sh_stride));
|
280 |
+
m_gl_rd += (m_sh_stride)*slice_col;
|
281 |
+
m_gl_rd += m_gl_rd_delta_o * slice_row;
|
282 |
+
int m_sh_wr = threadIdx.x;
|
283 |
+
int m_sh_rd = threadIdx.x % 16 + (threadIdx.x / 32) * 16;
|
284 |
+
|
285 |
+
int s_gl_rd;
|
286 |
+
if constexpr (group_blocks == -1) {
|
287 |
+
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
|
288 |
+
} else {
|
289 |
+
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
|
290 |
+
s_sh_stride * slice_col + threadIdx.x;
|
291 |
+
}
|
292 |
+
|
293 |
+
int s_sh_wr = threadIdx.x;
|
294 |
+
int s_sh_rd;
|
295 |
+
// We use a different scale layout for grouped and column-wise quantization as
|
296 |
+
// we scale a `half2` tile in column-major layout in the former and in
|
297 |
+
// row-major in the latter case.
|
298 |
+
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
299 |
+
(threadIdx.x % 32) / 4; // Note that in the original Marlin kernel
|
300 |
+
// this is (threadIdx.x % 32) / 4
|
301 |
+
|
302 |
+
// Precompute which thread should not read memory in which iterations; this is
|
303 |
+
// needed if there are more threads than required for a certain tilesize or
|
304 |
+
// when the batchsize is not a multiple of 16.
|
305 |
+
bool a_sh_wr_pred[a_sh_wr_iters];
|
306 |
+
#pragma unroll
|
307 |
+
for (int i = 0; i < a_sh_wr_iters; i++) {
|
308 |
+
a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
|
309 |
+
}
|
310 |
+
bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
|
311 |
+
|
312 |
+
// To ensure that writing and reading A tiles to/from shared memory, the
|
313 |
+
// latter in fragment format, is fully bank conflict free, we need to use a
|
314 |
+
// rather fancy XOR-based layout. The key here is that neither reads nor
|
315 |
+
// writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
|
316 |
+
// same shared memory banks. Further, it seems (based on NSight-Compute) that
|
317 |
+
// each warp must also write a consecutive memory segment?
|
318 |
+
auto transform_a = [&](int i) {
|
319 |
+
int row = i / a_gl_rd_delta_o;
|
320 |
+
return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
|
321 |
+
};
|
322 |
+
// Since the computation of this remapping is non-trivial and, due to our main
|
323 |
+
// loop unrolls, all shared memory accesses are static, we simply precompute
|
324 |
+
// both transformed reads and writes.
|
325 |
+
int a_sh_wr_trans[a_sh_wr_iters];
|
326 |
+
#pragma unroll
|
327 |
+
for (int i = 0; i < a_sh_wr_iters; i++)
|
328 |
+
a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
|
329 |
+
int a_sh_rd_trans[2][b_sh_wr_iters][thread_m_blocks];
|
330 |
+
#pragma unroll
|
331 |
+
for (int i = 0; i < b_sh_wr_iters; i++) {
|
332 |
+
#pragma unroll
|
333 |
+
for (int j = 0; j < thread_m_blocks; j++) {
|
334 |
+
a_sh_rd_trans[0][i][j] =
|
335 |
+
transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
|
336 |
+
a_sh_rd_trans[1][i][j] =
|
337 |
+
transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd + 2);
|
338 |
+
}
|
339 |
+
}
|
340 |
+
|
341 |
+
// Since B-accesses have non-constant stride they have to be computed at
|
342 |
+
// runtime; we break dependencies between subsequent accesses with a tile by
|
343 |
+
// maintining multiple pointers (we have enough registers), a tiny
|
344 |
+
// optimization.
|
345 |
+
const int4* B_ptr[b_sh_wr_iters];
|
346 |
+
#pragma unroll
|
347 |
+
for (int i = 0; i < b_sh_wr_iters; i++)
|
348 |
+
B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
|
349 |
+
|
350 |
+
bool m_sh_wr_pred = threadIdx.x < m_sh_wr_delta;
|
351 |
+
const int4* meta_ptr[m_sh_iters];
|
352 |
+
#pragma unroll
|
353 |
+
for (int i = 0; i < m_sh_iters; i++)
|
354 |
+
meta_ptr[i] = meta + m_gl_rd_delta_i * i + m_gl_rd;
|
355 |
+
|
356 |
+
extern __shared__ int4 sh[];
|
357 |
+
// Shared memory storage for global fetch pipelines.
|
358 |
+
int4* sh_a = sh;
|
359 |
+
int4* sh_b = sh_a + (stages * a_sh_stage);
|
360 |
+
int4* sh_s = sh_b + (stages * b_sh_stage);
|
361 |
+
int4* sh_m = sh_s + (stages * s_sh_stage);
|
362 |
+
// Register storage for double buffer of shared memory reads.
|
363 |
+
FragA frag_a[2][thread_m_blocks][2];
|
364 |
+
I4 frag_b_quant[2][b_thread_vecs];
|
365 |
+
FragM frag_m[2][2];
|
366 |
+
FragC frag_c[thread_m_blocks][4][2];
|
367 |
+
FragS frag_s[2][4];
|
368 |
+
|
369 |
+
// Zero accumulators.
|
370 |
+
auto zero_accums = [&]() {
|
371 |
+
#pragma unroll
|
372 |
+
for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
|
373 |
+
reinterpret_cast<float*>(frag_c)[i] = 0;
|
374 |
+
};
|
375 |
+
|
376 |
+
// Asynchronously fetch the next A, B and s tile from global to the next
|
377 |
+
// shared memory pipeline location.
|
378 |
+
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
|
379 |
+
if (pred) {
|
380 |
+
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
|
381 |
+
#pragma unroll
|
382 |
+
for (int i = 0; i < a_sh_wr_iters; i++) {
|
383 |
+
cp_async4_pred(
|
384 |
+
&sh_a_stage[a_sh_wr_trans[i]],
|
385 |
+
&A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
|
386 |
+
a_sh_wr_pred[i]);
|
387 |
+
}
|
388 |
+
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
|
389 |
+
#pragma unroll
|
390 |
+
for (int i = 0; i < b_sh_wr_iters; i++) {
|
391 |
+
#pragma unroll
|
392 |
+
for (int j = 0; j < b_thread_vecs; j++) {
|
393 |
+
cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);
|
394 |
+
}
|
395 |
+
B_ptr[i] += b_gl_rd_delta_o;
|
396 |
+
}
|
397 |
+
int4* sh_meta_stage = sh_m + m_sh_stage * pipe;
|
398 |
+
#pragma unroll
|
399 |
+
for (int i = 0; i < m_sh_iters; i++) {
|
400 |
+
if (m_sh_wr_pred)
|
401 |
+
cp_async4(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr], meta_ptr[i]);
|
402 |
+
meta_ptr[i] += m_gl_rd_delta_o;
|
403 |
+
}
|
404 |
+
// Only fetch scales if this tile starts a new group
|
405 |
+
if constexpr (group_blocks != -1) {
|
406 |
+
// This assumes group_blocks >= thread_k_blocks
|
407 |
+
// and would need to be modified to support smaller groups.
|
408 |
+
static_assert(group_blocks >= thread_k_blocks);
|
409 |
+
if (pipe % (group_blocks / thread_k_blocks) == 0) {
|
410 |
+
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
|
411 |
+
if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
|
412 |
+
s_gl_rd += s_gl_rd_delta;
|
413 |
+
}
|
414 |
+
}
|
415 |
+
}
|
416 |
+
// Insert a fence even when we are winding down the pipeline to ensure that
|
417 |
+
// waiting is also correct at this point.
|
418 |
+
cp_async_fence();
|
419 |
+
};
|
420 |
+
|
421 |
+
// Wait until the next thread tile has been loaded to shared memory.
|
422 |
+
auto wait_for_stage = [&]() {
|
423 |
+
// We only have `stages - 2` active fetches since we are double buffering
|
424 |
+
// and can only issue the next fetch when it is guaranteed that the previous
|
425 |
+
// shared memory load is fully complete (as it may otherwise be
|
426 |
+
// overwritten).
|
427 |
+
cp_async_wait<stages - 2>();
|
428 |
+
__syncthreads();
|
429 |
+
};
|
430 |
+
|
431 |
+
// Load the next sub-tile from the current location in the shared memory pipe
|
432 |
+
// into the current register buffer.
|
433 |
+
auto fetch_to_registers = [&](int k, int pipe) {
|
434 |
+
// It may seem inefficient that we reload the groups for every sub-tile;
|
435 |
+
// however, this does not seem to be a significant bottleneck, while some
|
436 |
+
// theoretically better attempts have lead to bad instruction ordering by
|
437 |
+
// the compiler and correspondingly a noticeable drop in performance.
|
438 |
+
if constexpr (group_blocks != -1) {
|
439 |
+
// This assumes group_blocks >= thread_k_blocks
|
440 |
+
// and would need to be modified to support smaller groups.
|
441 |
+
static_assert(group_blocks >= thread_k_blocks);
|
442 |
+
int4* sh_s_stage =
|
443 |
+
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
|
444 |
+
(pipe / (group_blocks / thread_k_blocks)));
|
445 |
+
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
|
446 |
+
}
|
447 |
+
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
|
448 |
+
#pragma unroll
|
449 |
+
for (int i = 0; i < thread_m_blocks; i++) {
|
450 |
+
ldsm4(frag_a[k % 2][i][0],
|
451 |
+
&sh_a_stage[a_sh_rd_trans[0][k % b_sh_wr_iters][i]]);
|
452 |
+
ldsm4(frag_a[k % 2][i][1],
|
453 |
+
&sh_a_stage[a_sh_rd_trans[1][k % b_sh_wr_iters][i]]);
|
454 |
+
}
|
455 |
+
|
456 |
+
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
|
457 |
+
#pragma unroll
|
458 |
+
for (int i = 0; i < b_thread_vecs; i++) {
|
459 |
+
frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>(
|
460 |
+
&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);
|
461 |
+
}
|
462 |
+
|
463 |
+
// Load meta with ldsm4
|
464 |
+
int4* sh_m_stage = sh_m + m_sh_stage * pipe;
|
465 |
+
ldsm4_m(frag_m[k % 2][0],
|
466 |
+
&sh_m_stage[m_sh_rd_delta * (k % m_sh_iters) + m_sh_rd]);
|
467 |
+
};
|
468 |
+
|
469 |
+
// Execute the actual tensor core matmul of a sub-tile.
|
470 |
+
auto matmul = [&](int k) {
|
471 |
+
// We have the m dimension as the inner loop in order to encourage overlapping
|
472 |
+
// dequantization and matmul operations.
|
473 |
+
#pragma unroll
|
474 |
+
for (int j = 0; j < 4; j++) {
|
475 |
+
FragB frag_b0;
|
476 |
+
FragB frag_b1;
|
477 |
+
|
478 |
+
if constexpr (num_bits == 4) {
|
479 |
+
int b_quant = frag_b_quant[k % 2][0][j];
|
480 |
+
int b_quant_shift = b_quant >> 8;
|
481 |
+
|
482 |
+
frag_b0 = dequant_4bit(b_quant);
|
483 |
+
frag_b1 = dequant_4bit(b_quant_shift);
|
484 |
+
|
485 |
+
} else {
|
486 |
+
int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);
|
487 |
+
int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
|
488 |
+
int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
|
489 |
+
|
490 |
+
frag_b0 = dequant_8bit(b_quant_0);
|
491 |
+
frag_b1 = dequant_8bit(b_quant_1);
|
492 |
+
}
|
493 |
+
|
494 |
+
// If there are no groups, we can just scale the final output once and can
|
495 |
+
// avoid doing so for each weight.
|
496 |
+
if constexpr (group_blocks != -1) {
|
497 |
+
scale(frag_b0, frag_s[k % 2][j], 0);
|
498 |
+
}
|
499 |
+
if constexpr (group_blocks != -1) {
|
500 |
+
scale(frag_b1, frag_s[k % 2][j], 1);
|
501 |
+
}
|
502 |
+
|
503 |
+
#pragma unroll
|
504 |
+
for (int i = 0; i < thread_m_blocks; i++) {
|
505 |
+
mma_sp(frag_b0, frag_b1, frag_a[k % 2][i][0], frag_c[i][j][0],
|
506 |
+
frag_m[k % 2][j / 2], j % 2);
|
507 |
+
}
|
508 |
+
}
|
509 |
+
};
|
510 |
+
|
511 |
+
// Since we slice across the k dimension of a tile in order to increase the
|
512 |
+
// number of warps while keeping the n dimension of a tile reasonable, we have
|
513 |
+
// multiple warps that accumulate their partial sums of the same output
|
514 |
+
// location; which we have to reduce over in the end. We do in shared memory.
|
515 |
+
auto thread_block_reduce = [&]() {
|
516 |
+
constexpr int red_off = threads / b_sh_stride_threads / 2;
|
517 |
+
if (red_off >= 1) {
|
518 |
+
int red_idx = threadIdx.x / b_sh_stride_threads;
|
519 |
+
constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
|
520 |
+
constexpr int red_sh_delta = b_sh_stride_threads;
|
521 |
+
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
|
522 |
+
(threadIdx.x % b_sh_stride_threads);
|
523 |
+
|
524 |
+
// Parallel logarithmic shared memory reduction. We make sure to avoid any
|
525 |
+
// unnecessary read or write iterations, e.g., for two warps we write only
|
526 |
+
// once by warp 1 and read only once by warp 0.
|
527 |
+
#pragma unroll
|
528 |
+
for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
|
529 |
+
#pragma unroll
|
530 |
+
for (int i = red_off; i > 0; i /= 2) {
|
531 |
+
if (i <= red_idx && red_idx < 2 * i) {
|
532 |
+
#pragma unroll
|
533 |
+
for (int j = 0; j < 4 * 2; j++) {
|
534 |
+
int red_sh_wr =
|
535 |
+
red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
|
536 |
+
if (i < red_off) {
|
537 |
+
float* c_rd =
|
538 |
+
reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
|
539 |
+
float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
|
540 |
+
#pragma unroll
|
541 |
+
for (int k = 0; k < 4; k++)
|
542 |
+
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
|
543 |
+
c_rd[k] + c_wr[k];
|
544 |
+
}
|
545 |
+
sh[red_sh_wr] =
|
546 |
+
reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
|
547 |
+
}
|
548 |
+
}
|
549 |
+
__syncthreads();
|
550 |
+
}
|
551 |
+
if (red_idx == 0) {
|
552 |
+
#pragma unroll
|
553 |
+
for (int i = 0; i < 4 * 2; i++) {
|
554 |
+
float* c_rd =
|
555 |
+
reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
|
556 |
+
#pragma unroll
|
557 |
+
for (int j = 0; j < 4; j++)
|
558 |
+
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
|
559 |
+
c_rd[j];
|
560 |
+
}
|
561 |
+
}
|
562 |
+
__syncthreads();
|
563 |
+
}
|
564 |
+
}
|
565 |
+
};
|
566 |
+
|
567 |
+
// Since multiple threadblocks may process parts of the same column slice, we
|
568 |
+
// finally have to globally reduce over the results. As the striped
|
569 |
+
// partitioning minimizes the number of such reductions and our outputs are
|
570 |
+
// usually rather small, we perform this reduction serially in L2 cache.
|
571 |
+
auto global_reduce = [&](bool first = false, bool last = false) {
|
572 |
+
// We are very careful here to reduce directly in the output buffer to
|
573 |
+
// maximize L2 cache utilization in this step. To do this, we write out
|
574 |
+
// results in FP16 (but still reduce with FP32 compute).
|
575 |
+
constexpr int active_threads = 32 * thread_n_blocks / 4;
|
576 |
+
if (threadIdx.x < active_threads) {
|
577 |
+
int c_gl_stride = prob_n / 8;
|
578 |
+
int c_gl_wr_delta_o = 2 * 4 * c_gl_stride;
|
579 |
+
int c_gl_wr_delta_i =
|
580 |
+
c_gl_stride; // 8 threads (e.g., 0,4,8,12,16,20,24,28)
|
581 |
+
int c_gl_wr = 2 * c_gl_stride * (threadIdx.x % 4) +
|
582 |
+
8 * (threadIdx.x / 32) + (threadIdx.x % 32) / 4;
|
583 |
+
c_gl_wr += (2 * thread_n_blocks) * slice_col;
|
584 |
+
constexpr int c_sh_wr_delta = active_threads;
|
585 |
+
int c_sh_wr = threadIdx.x;
|
586 |
+
|
587 |
+
int col = 2 * ((threadIdx.x % 32) % 4);
|
588 |
+
|
589 |
+
if (!first) {
|
590 |
+
// Interestingly, doing direct global accesses here really seems to mess up
|
591 |
+
// the compiler and lead to slowdowns, hence we also use async-copies even
|
592 |
+
// though these fetches are not actually asynchronous.
|
593 |
+
#pragma unroll
|
594 |
+
for (int i = 0; i < thread_m_blocks * 4; i++) {
|
595 |
+
cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i],
|
596 |
+
&C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
|
597 |
+
c_gl_wr_delta_i * (i % 2)],
|
598 |
+
i < (thread_m_blocks - 1) * 4 ||
|
599 |
+
8 * (i / 2) + col + (i % 2) < prob_m);
|
600 |
+
}
|
601 |
+
cp_async_fence();
|
602 |
+
cp_async_wait<0>();
|
603 |
+
}
|
604 |
+
|
605 |
+
#pragma unroll
|
606 |
+
for (int i = 0; i < thread_m_blocks * 4; i++) {
|
607 |
+
if (i < (thread_m_blocks - 1) * 4 ||
|
608 |
+
8 * (i / 2) + col + (i % 2) < prob_m) {
|
609 |
+
if (!first) {
|
610 |
+
int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
|
611 |
+
#pragma unroll
|
612 |
+
for (int j2 = 0; j2 < 2; j2++) {
|
613 |
+
#pragma unroll
|
614 |
+
for (int j1 = 0; j1 < 4; j1++) {
|
615 |
+
reinterpret_cast<float*>(
|
616 |
+
&frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 +
|
617 |
+
4 * ((i % 4) / 2) + i % 2] +=
|
618 |
+
__half2float(
|
619 |
+
reinterpret_cast<__half*>(&c_red)[(j2 * 4 + j1)]);
|
620 |
+
}
|
621 |
+
}
|
622 |
+
}
|
623 |
+
if (!last) {
|
624 |
+
int4 c;
|
625 |
+
#pragma unroll
|
626 |
+
for (int j2 = 0; j2 < 2; j2++) {
|
627 |
+
#pragma unroll
|
628 |
+
for (int j1 = 0; j1 < 4; j1++) {
|
629 |
+
reinterpret_cast<__half*>(&c)[(j2 * 4 + j1)] =
|
630 |
+
__float2half(reinterpret_cast<float*>(
|
631 |
+
&frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 +
|
632 |
+
4 * ((i % 4) / 2) + i % 2]);
|
633 |
+
}
|
634 |
+
}
|
635 |
+
C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =
|
636 |
+
c;
|
637 |
+
}
|
638 |
+
}
|
639 |
+
}
|
640 |
+
}
|
641 |
+
};
|
642 |
+
|
643 |
+
// Write out the reduce final result in the correct layout. We only actually
|
644 |
+
// reshuffle matrix fragments in this step, the reduction above is performed
|
645 |
+
// in fragment layout.
|
646 |
+
auto write_result = [&]() {
|
647 |
+
int c_gl_stride = prob_n / 8;
|
648 |
+
|
649 |
+
constexpr int c_sh_stride = 2 * thread_n_blocks; // RLC:
|
650 |
+
constexpr int c_sh_stride_2 = 2 * c_sh_stride + 2; // RLC:
|
651 |
+
constexpr int c_sh_stride_3 = 2 * (2 * thread_n_blocks) + 2; // RLC:
|
652 |
+
|
653 |
+
int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
|
654 |
+
|
655 |
+
int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +
|
656 |
+
(threadIdx.x % (2 * thread_n_blocks));
|
657 |
+
c_gl_wr += (2 * thread_n_blocks) * slice_col;
|
658 |
+
|
659 |
+
int c_sh_wr = c_sh_stride_2 * ((threadIdx.x % 32) % 4) +
|
660 |
+
((threadIdx.x % 32) / 4); // RLC:
|
661 |
+
c_sh_wr += 8 * (threadIdx.x / 32); // 128/4(half4)
|
662 |
+
|
663 |
+
constexpr int c_sh_rd_delta =
|
664 |
+
c_sh_stride_3 * (threads / (2 * 2 * thread_n_blocks)); // RLC:
|
665 |
+
int c_sh_rd = c_sh_stride_3 * (threadIdx.x / (2 * 2 * thread_n_blocks)) +
|
666 |
+
(threadIdx.x % (2 * 2 * thread_n_blocks));
|
667 |
+
|
668 |
+
int c_gl_wr_end = c_gl_stride * prob_m;
|
669 |
+
|
670 |
+
auto write = [&](int idx, float c0, float c1, float c2, float c3, FragS& s0,
|
671 |
+
float c4, float c5, float c6, float c7, FragS& s1) {
|
672 |
+
uint2 res[2];
|
673 |
+
res[0] = to_half4(c0, c1, c2, c3);
|
674 |
+
res[1] = to_half4(c4, c5, c6, c7);
|
675 |
+
half2* tmp = (half2*)&res;
|
676 |
+
// for per-column quantization we finally apply the scale here
|
677 |
+
if constexpr (group_blocks == -1 && num_bits == 4) {
|
678 |
+
tmp[0] = __hmul2(tmp[0], s0[0]);
|
679 |
+
tmp[1] = __hmul2(tmp[1], s0[1]);
|
680 |
+
tmp[2] = __hmul2(tmp[2], s1[0]);
|
681 |
+
tmp[3] = __hmul2(tmp[3], s1[1]);
|
682 |
+
}
|
683 |
+
((int4*)sh)[idx] = *((int4*)&res[0]);
|
684 |
+
};
|
685 |
+
|
686 |
+
// RLC: only warp 0 and 1 baseline example
|
687 |
+
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
688 |
+
#pragma unroll
|
689 |
+
for (int i = 0; i < thread_m_blocks; i++) {
|
690 |
+
int wr = c_sh_wr;
|
691 |
+
write(wr, frag_c[i][0][0][0], frag_c[i][1][0][0], frag_c[i][2][0][0],
|
692 |
+
frag_c[i][3][0][0], frag_s[0][0], frag_c[i][0][0][2],
|
693 |
+
frag_c[i][1][0][2], frag_c[i][2][0][2], frag_c[i][3][0][2],
|
694 |
+
frag_s[0][2]);
|
695 |
+
write(wr + c_sh_stride, frag_c[i][0][0][1], frag_c[i][1][0][1],
|
696 |
+
frag_c[i][2][0][1], frag_c[i][3][0][1], frag_s[0][0],
|
697 |
+
frag_c[i][0][0][3], frag_c[i][1][0][3], frag_c[i][2][0][3],
|
698 |
+
frag_c[i][3][0][3], frag_s[0][2]);
|
699 |
+
write(wr + 4 * c_sh_stride_2, frag_c[i][0][1][0], frag_c[i][1][1][0],
|
700 |
+
frag_c[i][2][1][0], frag_c[i][3][1][0], frag_s[0][0],
|
701 |
+
frag_c[i][0][1][2], frag_c[i][1][1][2], frag_c[i][2][1][2],
|
702 |
+
frag_c[i][3][1][2], frag_s[0][2]);
|
703 |
+
write(wr + 4 * c_sh_stride_2 + c_sh_stride, frag_c[i][0][1][1],
|
704 |
+
frag_c[i][1][1][1], frag_c[i][2][1][1], frag_c[i][3][1][1],
|
705 |
+
frag_s[0][0], frag_c[i][0][1][3], frag_c[i][1][1][3],
|
706 |
+
frag_c[i][2][1][3], frag_c[i][3][1][3], frag_s[0][2]);
|
707 |
+
|
708 |
+
c_sh_wr += 8 * c_sh_stride_2;
|
709 |
+
}
|
710 |
+
}
|
711 |
+
__syncthreads();
|
712 |
+
|
713 |
+
#pragma unroll
|
714 |
+
for (int i = 0;
|
715 |
+
i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
|
716 |
+
i++) {
|
717 |
+
if (c_gl_wr < c_gl_wr_end) {
|
718 |
+
C[c_gl_wr] = sh[c_sh_rd];
|
719 |
+
c_gl_wr += c_gl_wr_delta;
|
720 |
+
c_sh_rd += c_sh_rd_delta;
|
721 |
+
}
|
722 |
+
}
|
723 |
+
};
|
724 |
+
|
725 |
+
// Start global fetch and register load pipelines.
|
726 |
+
auto start_pipes = [&]() {
|
727 |
+
#pragma unroll
|
728 |
+
for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters);
|
729 |
+
zero_accums();
|
730 |
+
wait_for_stage();
|
731 |
+
fetch_to_registers(0, 0);
|
732 |
+
a_gl_rd += a_gl_rd_delta_o * (stages - 1);
|
733 |
+
};
|
734 |
+
start_pipes();
|
735 |
+
|
736 |
+
// Main loop.
|
737 |
+
while (slice_iters) {
|
738 |
+
// We unroll over both the global fetch and the register load pipeline to
|
739 |
+
// ensure all shared memory accesses are static. Note that both pipelines have
|
740 |
+
// even length meaning that the next iteration will always start at index 0.
|
741 |
+
#pragma unroll
|
742 |
+
for (int pipe = 0; pipe < stages;) {
|
743 |
+
fetch_to_shared((pipe + stages - 1) % stages, pipe,
|
744 |
+
slice_iters >= stages);
|
745 |
+
matmul(pipe);
|
746 |
+
wait_for_stage();
|
747 |
+
|
748 |
+
fetch_to_registers(pipe + 1, (pipe + 1) % stages);
|
749 |
+
|
750 |
+
pipe++;
|
751 |
+
slice_iters--;
|
752 |
+
if (slice_iters == 0) break;
|
753 |
+
}
|
754 |
+
a_gl_rd += a_gl_rd_delta_o * stages;
|
755 |
+
|
756 |
+
// Process results and, if necessary, proceed to the next column slice.
|
757 |
+
// While this pattern may not be the most readable, other ways of writing
|
758 |
+
// the loop seemed to noticeably worse performance after compilation.
|
759 |
+
if (slice_iters == 0) {
|
760 |
+
cp_async_wait<0>();
|
761 |
+
bool last = slice_idx == slice_count - 1;
|
762 |
+
// For per-column scales, we only fetch them here in the final step before
|
763 |
+
// write-out
|
764 |
+
if constexpr (group_blocks == -1) {
|
765 |
+
if constexpr (num_bits == 8) {
|
766 |
+
if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
|
767 |
+
cp_async_fence();
|
768 |
+
} else {
|
769 |
+
if (last) {
|
770 |
+
if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
|
771 |
+
cp_async_fence();
|
772 |
+
}
|
773 |
+
}
|
774 |
+
}
|
775 |
+
thread_block_reduce();
|
776 |
+
|
777 |
+
if constexpr (group_blocks == -1) {
|
778 |
+
if constexpr (num_bits == 8) {
|
779 |
+
cp_async_wait<0>();
|
780 |
+
__syncthreads();
|
781 |
+
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
782 |
+
*(float4*)(frag_s) = *(float4*)(&sh_s[s_sh_rd]);
|
783 |
+
}
|
784 |
+
} else {
|
785 |
+
if (last) {
|
786 |
+
cp_async_wait<0>();
|
787 |
+
__syncthreads();
|
788 |
+
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
789 |
+
*(float4*)(frag_s) = *(float4*)(&sh_s[s_sh_rd]);
|
790 |
+
}
|
791 |
+
}
|
792 |
+
}
|
793 |
+
}
|
794 |
+
|
795 |
+
// For 8-bit channelwise, we apply the scale before the global reduction
|
796 |
+
// that converts the fp32 results to fp16 (so that we avoid possible
|
797 |
+
// overflow in fp16)
|
798 |
+
if constexpr (group_blocks == -1 && num_bits == 8) {
|
799 |
+
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
800 |
+
#pragma unroll
|
801 |
+
for (int i = 0; i < thread_m_blocks; i++) {
|
802 |
+
scale_floats(&frag_c[i][0][0][0], &frag_c[i][1][0][0],
|
803 |
+
&frag_c[i][2][0][0], &frag_c[i][3][0][0], frag_s[0][0],
|
804 |
+
&frag_c[i][0][0][2], &frag_c[i][1][0][2],
|
805 |
+
&frag_c[i][2][0][2], &frag_c[i][3][0][2],
|
806 |
+
frag_s[0][2]);
|
807 |
+
|
808 |
+
scale_floats(&frag_c[i][0][0][1], &frag_c[i][1][0][1],
|
809 |
+
&frag_c[i][2][0][1], &frag_c[i][3][0][1], frag_s[0][0],
|
810 |
+
&frag_c[i][0][0][3], &frag_c[i][1][0][3],
|
811 |
+
&frag_c[i][2][0][3], &frag_c[i][3][0][3],
|
812 |
+
frag_s[0][2]);
|
813 |
+
|
814 |
+
scale_floats(&frag_c[i][0][1][0], &frag_c[i][1][1][0],
|
815 |
+
&frag_c[i][2][1][0], &frag_c[i][3][1][0], frag_s[0][0],
|
816 |
+
&frag_c[i][0][1][2], &frag_c[i][1][1][2],
|
817 |
+
&frag_c[i][2][1][2], &frag_c[i][3][1][2],
|
818 |
+
frag_s[0][2]);
|
819 |
+
|
820 |
+
scale_floats(&frag_c[i][0][1][1], &frag_c[i][1][1][1],
|
821 |
+
&frag_c[i][2][1][1], &frag_c[i][3][1][1], frag_s[0][0],
|
822 |
+
&frag_c[i][0][1][3], &frag_c[i][1][1][3],
|
823 |
+
&frag_c[i][2][1][3], &frag_c[i][3][1][3],
|
824 |
+
frag_s[0][2]);
|
825 |
+
}
|
826 |
+
}
|
827 |
+
}
|
828 |
+
|
829 |
+
if (slice_count > 1) { // only globally reduce if there is more than one
|
830 |
+
// block in a slice
|
831 |
+
barrier_acquire(&locks[slice_col], slice_idx);
|
832 |
+
global_reduce(slice_idx == 0, last);
|
833 |
+
barrier_release(&locks[slice_col], last);
|
834 |
+
}
|
835 |
+
if (last) // only the last block in a slice actually writes the result
|
836 |
+
write_result();
|
837 |
+
|
838 |
+
slice_row = 0;
|
839 |
+
slice_col_par++;
|
840 |
+
slice_col++;
|
841 |
+
init_slice();
|
842 |
+
if (slice_iters) {
|
843 |
+
a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
|
844 |
+
(threadIdx.x % a_gl_rd_delta_o);
|
845 |
+
#pragma unroll
|
846 |
+
for (int i = 0; i < b_sh_wr_iters; i++)
|
847 |
+
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
|
848 |
+
#pragma unroll
|
849 |
+
for (int i = 0; i < m_sh_iters; i++)
|
850 |
+
meta_ptr[i] += (m_sh_stride)-m_gl_rd_delta_o * k_tiles;
|
851 |
+
if (slice_col == 0) {
|
852 |
+
#pragma unroll
|
853 |
+
for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
|
854 |
+
#pragma unroll
|
855 |
+
for (int i = 0; i < m_sh_iters; i++) meta_ptr[i] -= m_gl_stride;
|
856 |
+
}
|
857 |
+
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
|
858 |
+
start_pipes();
|
859 |
+
}
|
860 |
+
}
|
861 |
+
}
|
862 |
+
}
|
863 |
+
|
864 |
+
#endif
|
865 |
+
|
866 |
+
#define CALL_IF_2_4(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
|
867 |
+
THREAD_K_BLOCKS, GROUP_BLOCKS) \
|
868 |
+
else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
|
869 |
+
thread_n_blocks == THREAD_N_BLOCKS && \
|
870 |
+
thread_k_blocks == THREAD_K_BLOCKS && \
|
871 |
+
group_blocks == GROUP_BLOCKS) { \
|
872 |
+
cudaFuncSetAttribute( \
|
873 |
+
Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS, \
|
874 |
+
THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \
|
875 |
+
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
876 |
+
Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS, \
|
877 |
+
THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS> \
|
878 |
+
<<<blocks, THREADS, max_shared_mem, stream>>>(A_ptr, B_ptr, meta_ptr, \
|
879 |
+
C_ptr, s_ptr, prob_n, \
|
880 |
+
prob_m, prob_k, locks); \
|
881 |
+
}
|
882 |
+
|
883 |
+
void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
|
884 |
+
void* s, int prob_m, int prob_n, int prob_k,
|
885 |
+
void* workspace, int num_bits, int groupsize = -1,
|
886 |
+
int dev = 0, cudaStream_t stream = 0, int thread_k = -1,
|
887 |
+
int thread_m = -1, int sms = -1, int max_par = 16) {
|
888 |
+
int tot_n = prob_n;
|
889 |
+
int tot_n_blocks = ceildiv(tot_n, 16);
|
890 |
+
int pad = 16 * tot_n_blocks - tot_n;
|
891 |
+
|
892 |
+
if (sms == -1) {
|
893 |
+
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
|
894 |
+
}
|
895 |
+
TORCH_CHECK(sms > 0);
|
896 |
+
|
897 |
+
int max_shared_mem = 0;
|
898 |
+
cudaDeviceGetAttribute(&max_shared_mem,
|
899 |
+
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
|
900 |
+
TORCH_CHECK(max_shared_mem > 0);
|
901 |
+
|
902 |
+
if (thread_k == -1 || thread_m == -1) {
|
903 |
+
if (prob_n <= 16) {
|
904 |
+
// For small batchizes, better partitioningif is slightly more important
|
905 |
+
// than better compute utilization
|
906 |
+
thread_k = 128;
|
907 |
+
thread_m = 128;
|
908 |
+
} else {
|
909 |
+
thread_k = 64;
|
910 |
+
thread_m = 256;
|
911 |
+
}
|
912 |
+
// Also had
|
913 |
+
// if prob_n > 256
|
914 |
+
// thread_k = 32;
|
915 |
+
// thread_m = 512;
|
916 |
+
// but this is broken,
|
917 |
+
// TODO(Lucas, Alex M): figure out why
|
918 |
+
}
|
919 |
+
|
920 |
+
int thread_k_blocks = thread_k / 32; // 2:4 version with m16n8k32 instruction
|
921 |
+
int thread_m_blocks = thread_m / 16;
|
922 |
+
int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
|
923 |
+
int blocks = sms;
|
924 |
+
|
925 |
+
TORCH_CHECK(prob_m % thread_m == 0, "prob_m = ", prob_m,
|
926 |
+
" is not divisible by thread_m = ", thread_m);
|
927 |
+
TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
|
928 |
+
" is not divisible by thread_k = ", thread_k);
|
929 |
+
if (group_blocks != -1) {
|
930 |
+
TORCH_CHECK((prob_k / 2) % group_blocks == 0, "prob_k/2 = ", prob_k / 2,
|
931 |
+
" is not divisible by group_blocks = ", group_blocks);
|
932 |
+
}
|
933 |
+
|
934 |
+
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
|
935 |
+
", ", prob_n, ", ", prob_k, "]");
|
936 |
+
|
937 |
+
const int4* A_ptr = (const int4*)A;
|
938 |
+
const int4* B_ptr = (const int4*)B;
|
939 |
+
const int4* meta_ptr = (const int4*)meta;
|
940 |
+
int4* C_ptr = (int4*)C;
|
941 |
+
const int4* s_ptr = (const int4*)s;
|
942 |
+
|
943 |
+
constexpr int max_m_blocks = 4;
|
944 |
+
|
945 |
+
int* locks = (int*)workspace;
|
946 |
+
for (int i = 0; i < tot_n_blocks; i += max_m_blocks) {
|
947 |
+
int thread_n_blocks = tot_n_blocks - i;
|
948 |
+
prob_n = tot_n - 16 * i;
|
949 |
+
int par = 1;
|
950 |
+
if (thread_n_blocks > max_m_blocks) {
|
951 |
+
// Note that parallel > 1 currently only works for inputs without any
|
952 |
+
// padding
|
953 |
+
par = (16 * thread_n_blocks - pad) / (max_m_blocks * 16);
|
954 |
+
if (par > max_par) par = max_par;
|
955 |
+
prob_n = (max_m_blocks * 16) * par;
|
956 |
+
i += max_m_blocks * (par - 1);
|
957 |
+
thread_n_blocks = max_m_blocks;
|
958 |
+
}
|
959 |
+
|
960 |
+
// For compilation speed, we only define the kernel configurations that have
|
961 |
+
// seemed useful (in terms of performance) in our testing, however many more
|
962 |
+
// are, in principle, possible.
|
963 |
+
|
964 |
+
// the false is start of the CALL_IF macros
|
965 |
+
if (false) {
|
966 |
+
} // BMxBNxBK, group
|
967 |
+
// 4-bit
|
968 |
+
CALL_IF_2_4(4, 8, 1, 4, -1) // e.g., 16x128x128
|
969 |
+
CALL_IF_2_4(4, 8, 1, 4, 4) // e.g., 16x128x128, 64
|
970 |
+
|
971 |
+
CALL_IF_2_4(4, 16, 1, 2, -1) // e.g., 16x256x64
|
972 |
+
CALL_IF_2_4(4, 16, 1, 2, 4) // e.g., 16x256x64, 64
|
973 |
+
CALL_IF_2_4(4, 16, 2, 2, -1) // e.g.. 32x256x64
|
974 |
+
CALL_IF_2_4(4, 16, 2, 2, 4)
|
975 |
+
CALL_IF_2_4(4, 16, 3, 2, -1)
|
976 |
+
CALL_IF_2_4(4, 16, 3, 2, 4)
|
977 |
+
CALL_IF_2_4(4, 16, 4, 2, -1)
|
978 |
+
CALL_IF_2_4(4, 16, 4, 2, 4)
|
979 |
+
|
980 |
+
CALL_IF_2_4(4, 32, 1, 1, -1) // e.g., 16x256x64
|
981 |
+
CALL_IF_2_4(4, 32, 1, 1, 4) // e.g., 16x256x64, 64
|
982 |
+
CALL_IF_2_4(4, 32, 2, 1, -1) // e.g.. 32x256x64
|
983 |
+
CALL_IF_2_4(4, 32, 2, 1, 4)
|
984 |
+
CALL_IF_2_4(4, 32, 3, 1, -1)
|
985 |
+
CALL_IF_2_4(4, 32, 3, 1, 4)
|
986 |
+
CALL_IF_2_4(4, 32, 4, 1, -1)
|
987 |
+
CALL_IF_2_4(4, 32, 4, 1, 4)
|
988 |
+
|
989 |
+
// 8-bit
|
990 |
+
CALL_IF_2_4(8, 8, 1, 4, -1) // e.g., 16x128x128
|
991 |
+
CALL_IF_2_4(8, 8, 1, 4, 4) // e.g., 16x128x128, 64
|
992 |
+
|
993 |
+
CALL_IF_2_4(8, 16, 1, 2, -1) // e.g., 16x256x64
|
994 |
+
CALL_IF_2_4(8, 16, 1, 2, 4) // e.g., 16x256x64, 64
|
995 |
+
CALL_IF_2_4(8, 16, 2, 2, -1) // e.g.. 32x256x64
|
996 |
+
CALL_IF_2_4(8, 16, 2, 2, 4)
|
997 |
+
CALL_IF_2_4(8, 16, 3, 2, -1)
|
998 |
+
CALL_IF_2_4(8, 16, 3, 2, 4)
|
999 |
+
CALL_IF_2_4(8, 16, 4, 2, -1)
|
1000 |
+
CALL_IF_2_4(8, 16, 4, 2, 4)
|
1001 |
+
|
1002 |
+
CALL_IF_2_4(8, 32, 1, 1, -1) // e.g., 16x256x64
|
1003 |
+
CALL_IF_2_4(8, 32, 1, 1, 4) // e.g., 16x256x64, 64
|
1004 |
+
CALL_IF_2_4(8, 32, 2, 1, -1) // e.g.. 32x256x64
|
1005 |
+
CALL_IF_2_4(8, 32, 2, 1, 4)
|
1006 |
+
CALL_IF_2_4(8, 32, 3, 1, -1)
|
1007 |
+
CALL_IF_2_4(8, 32, 3, 1, 4)
|
1008 |
+
CALL_IF_2_4(8, 32, 4, 1, -1)
|
1009 |
+
CALL_IF_2_4(8, 32, 4, 1, 4)
|
1010 |
+
else {
|
1011 |
+
throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) +
|
1012 |
+
", " + str(prob_k) + ", " + str(prob_n) + "]" +
|
1013 |
+
", groupsize = " + str(groupsize) +
|
1014 |
+
", thread_m_blocks = " + str(thread_m_blocks) +
|
1015 |
+
", thread_n_blocks = " + str(thread_n_blocks) +
|
1016 |
+
", thread_k_blocks = " + str(thread_k_blocks));
|
1017 |
+
}
|
1018 |
+
|
1019 |
+
A_ptr += 16 * thread_n_blocks * (prob_k / 8) * par;
|
1020 |
+
C_ptr += 16 * thread_n_blocks * (prob_m / 8) * par;
|
1021 |
+
}
|
1022 |
+
}
|
1023 |
+
|
1024 |
+
} // namespace marlin_24
|
1025 |
+
|
1026 |
+
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
1027 |
+
torch::Tensor& b_meta,
|
1028 |
+
torch::Tensor& b_scales,
|
1029 |
+
torch::Tensor& workspace,
|
1030 |
+
vllm::ScalarTypeId const b_q_type_id,
|
1031 |
+
int64_t size_m, int64_t size_n,
|
1032 |
+
int64_t size_k) {
|
1033 |
+
vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
|
1034 |
+
// Verify num_bits
|
1035 |
+
TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
|
1036 |
+
"num_bits must be uint4b8 or uint8b128. Got = ", b_q_type.str());
|
1037 |
+
int pack_factor = 32 / b_q_type.size_bits();
|
1038 |
+
|
1039 |
+
// Verify M
|
1040 |
+
TORCH_CHECK(size_m == a.size(0),
|
1041 |
+
"Shape mismatch: a.size(0) = " + str(a.size(0)) +
|
1042 |
+
", size_m = " + str(size_m));
|
1043 |
+
|
1044 |
+
// Verify K
|
1045 |
+
TORCH_CHECK(size_k == a.size(1),
|
1046 |
+
"Shape mismatch: a.size(1) = " + str(a.size(1)) +
|
1047 |
+
", size_k = " + str(size_k));
|
1048 |
+
TORCH_CHECK(size_k % marlin_24::tile_size == 0,
|
1049 |
+
"size_k = " + str(size_k) + " is not divisible by tile_size = " +
|
1050 |
+
str(marlin_24::tile_size));
|
1051 |
+
TORCH_CHECK((size_k / marlin_24::tile_size / 2) == b_q_weight.size(0),
|
1052 |
+
"Shape mismatch: b_q_weight.size(0) = " +
|
1053 |
+
str(b_q_weight.size(0)) + ", size_k = " + str(size_k) +
|
1054 |
+
", tile_size = " + str(marlin_24::tile_size));
|
1055 |
+
|
1056 |
+
// Verify N
|
1057 |
+
TORCH_CHECK(b_scales.size(1) == size_n,
|
1058 |
+
"b_scales.size(1) = " + str(b_scales.size(1)) +
|
1059 |
+
", size_n = " + str(size_n));
|
1060 |
+
TORCH_CHECK(
|
1061 |
+
b_q_weight.size(1) % marlin_24::tile_size == 0,
|
1062 |
+
"b_q_weight.size(1) = " + str(b_q_weight.size(1)) +
|
1063 |
+
" is not divisible by tile_size = " + str(marlin_24::tile_size));
|
1064 |
+
|
1065 |
+
int actual_size_n = (b_q_weight.size(1) / marlin_24::tile_size) * pack_factor;
|
1066 |
+
TORCH_CHECK(
|
1067 |
+
size_n == actual_size_n,
|
1068 |
+
"size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n));
|
1069 |
+
|
1070 |
+
// Verify meta
|
1071 |
+
TORCH_CHECK(b_meta.size(0) == size_k / 8 / 2 / 2,
|
1072 |
+
"b_meta.size(0) = ", b_meta.size(0),
|
1073 |
+
" is not size_k / 8 / 2 / 2 = ", size_k / 8 / 2 / 2);
|
1074 |
+
TORCH_CHECK(b_meta.size(1) == size_n * 2, "b_meta.size(1) = ", b_meta.size(1),
|
1075 |
+
" is not size_n * 2 = ", size_n * 2);
|
1076 |
+
|
1077 |
+
// Verify A device and strides
|
1078 |
+
TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
|
1079 |
+
TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
|
1080 |
+
TORCH_CHECK(a.dtype() == torch::kFloat16,
|
1081 |
+
"A is not float16, currently only float16 is supported");
|
1082 |
+
|
1083 |
+
// Verify B device and strides
|
1084 |
+
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
|
1085 |
+
TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
|
1086 |
+
|
1087 |
+
// Verify b_meta device and strides
|
1088 |
+
TORCH_CHECK(b_meta.device().is_cuda(), "b_meta is not on GPU");
|
1089 |
+
TORCH_CHECK(b_meta.is_contiguous(), "b_meta is not contiguous");
|
1090 |
+
|
1091 |
+
// Verify scales device and strides
|
1092 |
+
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
|
1093 |
+
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
|
1094 |
+
TORCH_CHECK(b_scales.dtype() == torch::kFloat16,
|
1095 |
+
"A is not float16, currently only float16 is supported");
|
1096 |
+
|
1097 |
+
// Alloc C matrix
|
1098 |
+
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
|
1099 |
+
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
|
1100 |
+
torch::Tensor c = torch::empty({size_m, size_n}, options);
|
1101 |
+
|
1102 |
+
int thread_k = -1;
|
1103 |
+
int thread_m = -1;
|
1104 |
+
int sms = -1;
|
1105 |
+
int max_par = marlin_24::max_par;
|
1106 |
+
|
1107 |
+
int groupsize = -1;
|
1108 |
+
if (b_scales.size(0) > 1) {
|
1109 |
+
TORCH_CHECK(size_k % b_scales.size(0) == 0,
|
1110 |
+
"size_k = " + str(size_k) +
|
1111 |
+
", is not divisible by b_scales.size(0) = " +
|
1112 |
+
str(b_scales.size(0)));
|
1113 |
+
groupsize = size_k / b_scales.size(0);
|
1114 |
+
groupsize /= 2; // Because of 24
|
1115 |
+
}
|
1116 |
+
|
1117 |
+
// Verify groupsize
|
1118 |
+
TORCH_CHECK(groupsize == -1 || groupsize == 64,
|
1119 |
+
"Unexpected groupsize = " + str(groupsize));
|
1120 |
+
|
1121 |
+
// Verify workspace size
|
1122 |
+
TORCH_CHECK(size_n % marlin_24::min_thread_n == 0,
|
1123 |
+
"size_n = " + str(size_n) +
|
1124 |
+
", is not divisible by min_thread_n = " +
|
1125 |
+
str(marlin_24::min_thread_n));
|
1126 |
+
int min_workspace_size =
|
1127 |
+
(size_n / marlin_24::min_thread_n) * marlin_24::max_par;
|
1128 |
+
TORCH_CHECK(workspace.numel() >= min_workspace_size,
|
1129 |
+
"workspace.numel = " + str(workspace.numel()) +
|
1130 |
+
" is below min_workspace_size = " + str(min_workspace_size));
|
1131 |
+
|
1132 |
+
int dev = a.get_device();
|
1133 |
+
marlin_24::marlin_cuda_2_4(
|
1134 |
+
a.data_ptr(), b_q_weight.data_ptr(), b_meta.data_ptr(), c.data_ptr(),
|
1135 |
+
b_scales.data_ptr(), size_n, size_m, size_k, workspace.data_ptr(),
|
1136 |
+
b_q_type.size_bits(), groupsize, dev, at::cuda::getCurrentCUDAStream(dev),
|
1137 |
+
thread_k, thread_m, sms, max_par);
|
1138 |
+
|
1139 |
+
return c;
|
1140 |
+
}
|
tests/kernels/test_marlin_gemm.py
ADDED
@@ -0,0 +1,733 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Tests for the marlin kernel.
|
2 |
+
|
3 |
+
Run `pytest tests/kernels/marlin/test_marlin_gemm.py`.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import pytest
|
7 |
+
import torch
|
8 |
+
|
9 |
+
import quantization
|
10 |
+
|
11 |
+
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
|
12 |
+
|
13 |
+
from quantization.utils.marlin_utils import (
|
14 |
+
GPTQ_MARLIN_24_MAX_PARALLEL,
|
15 |
+
GPTQ_MARLIN_24_MIN_THREAD_N,
|
16 |
+
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES,
|
17 |
+
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES,
|
18 |
+
GPTQ_MARLIN_MAX_PARALLEL,
|
19 |
+
GPTQ_MARLIN_MIN_THREAD_N,
|
20 |
+
MARLIN_SUPPORTED_GROUP_SIZES,
|
21 |
+
MARLIN_QQQ_MAX_PARALLEL,
|
22 |
+
MARLIN_QQQ_MIN_THREAD_N,
|
23 |
+
MARLIN_QQQ_SUPPORTED_GROUP_SIZES,
|
24 |
+
MARLIN_QQQ_SUPPORTED_NUM_BITS,
|
25 |
+
marlin_make_empty_g_idx,
|
26 |
+
marlin_permute_scales,
|
27 |
+
query_marlin_supported_quant_types,
|
28 |
+
)
|
29 |
+
from quantization.utils.marlin_utils_fp8 import (
|
30 |
+
pack_fp8_to_int32,
|
31 |
+
)
|
32 |
+
from quantization.utils.quant_utils import (
|
33 |
+
awq_pack,
|
34 |
+
gptq_pack,
|
35 |
+
gptq_quantize_weights,
|
36 |
+
quantize_weights,
|
37 |
+
sort_weights,
|
38 |
+
)
|
39 |
+
from quantization.scalar_type import scalar_types
|
40 |
+
|
41 |
+
from quantization.utils.marlin_utils_test import (
|
42 |
+
MarlinWorkspace,
|
43 |
+
awq_marlin_quantize,
|
44 |
+
get_weight_perm,
|
45 |
+
marlin_quantize,
|
46 |
+
marlin_weights,
|
47 |
+
)
|
48 |
+
from quantization.utils.marlin_utils_test_24 import (
|
49 |
+
marlin_24_quantize,
|
50 |
+
)
|
51 |
+
from quantization.utils.marlin_utils_test_qqq import ( # noqa: E501
|
52 |
+
marlin_qqq_quantize,
|
53 |
+
)
|
54 |
+
|
55 |
+
|
56 |
+
# Avoid torch._dynamo.exc.Unsupported: cache_size_limit reached
|
57 |
+
torch._dynamo.config.cache_size_limit = 128
|
58 |
+
|
59 |
+
|
60 |
+
capability = torch.cuda.get_device_capability()
|
61 |
+
capability = capability[0] * 10 + capability[1]
|
62 |
+
|
63 |
+
|
64 |
+
ACT_ORDER_OPTS = [False, True]
|
65 |
+
K_FULL_OPTS = [False, True]
|
66 |
+
USE_FP32_REDUCE_OPTS = [False, True]
|
67 |
+
|
68 |
+
MARLIN_K_CHUNKS = [128]
|
69 |
+
MARLIN_N_CHUNKS = [64, 256]
|
70 |
+
|
71 |
+
MARLIN_24_K_CHUNKS = [128]
|
72 |
+
MARLIN_24_N_CHUNKS = [512]
|
73 |
+
|
74 |
+
HQQ_SUPPORTED_GROUP_SIZES = [64]
|
75 |
+
|
76 |
+
MNK_FACTORS = [
|
77 |
+
(1, 1, 1),
|
78 |
+
(1, 4, 8),
|
79 |
+
(1, 7, 5),
|
80 |
+
(13, 17, 67),
|
81 |
+
(26, 37, 13),
|
82 |
+
(67, 13, 11),
|
83 |
+
(257, 13, 11),
|
84 |
+
(658, 13, 11),
|
85 |
+
]
|
86 |
+
|
87 |
+
DTYPES = [torch.float16, torch.bfloat16]
|
88 |
+
|
89 |
+
|
90 |
+
def compute_max_diff(output, output_ref):
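# Mean absolute difference normalized by the mean magnitude of the reference output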
|
91 |
+
return torch.mean(torch.abs(output - output_ref)) / torch.mean(
|
92 |
+
torch.abs(output_ref)
|
93 |
+
)
|
94 |
+
|
95 |
+
|
96 |
+
def rand_data(shape, dtype=torch.float16):
|
97 |
+
return torch.randn(shape, dtype=dtype, device="cuda")
|
98 |
+
|
99 |
+
|
100 |
+
@pytest.mark.skipif(
|
101 |
+
capability < 80,
|
102 |
+
reason="Marlin is not supported on this GPU type.",
|
103 |
+
)
|
104 |
+
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
105 |
+
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
106 |
+
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(False))
|
107 |
+
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
|
108 |
+
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
|
109 |
+
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
110 |
+
def test_gptq_marlin_repack(
|
111 |
+
k_chunk, n_chunk, quant_type, group_size, act_order, mnk_factors
|
112 |
+
):
|
113 |
+
m_factor, n_factor, k_factor = mnk_factors
|
114 |
+
|
115 |
+
size_k = k_chunk * k_factor
|
116 |
+
size_n = n_chunk * n_factor
|
117 |
+
|
118 |
+
# Skip act_order cases that have no real grouping (channelwise, or a single group spanning all of K)
|
119 |
+
if act_order:
|
120 |
+
if group_size == -1:
|
121 |
+
return
|
122 |
+
if group_size == size_k:
|
123 |
+
return
|
124 |
+
|
125 |
+
# Normalize group_size
|
126 |
+
if group_size == -1:
|
127 |
+
group_size = size_k
|
128 |
+
assert group_size <= size_k
|
129 |
+
|
130 |
+
# Create input
|
131 |
+
b_weight = rand_data((size_k, size_n))
|
132 |
+
|
133 |
+
# Quantize (and apply act_order if provided)
|
134 |
+
w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
|
135 |
+
b_weight, quant_type, group_size, act_order
|
136 |
+
)
|
137 |
+
|
138 |
+
# Pack to GPTQ format
|
139 |
+
q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)
|
140 |
+
|
141 |
+
# For act_order, sort the "weights" and "g_idx" so that group ids are
|
142 |
+
# increasing
|
143 |
+
sort_indices = torch.empty(0, dtype=torch.int, device=b_weight.device)
|
144 |
+
if act_order:
|
145 |
+
q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
|
146 |
+
|
147 |
+
# Pack to Marlin format
|
148 |
+
weight_perm = get_weight_perm(quant_type.size_bits)
|
149 |
+
marlin_q_w_1 = marlin_weights(
|
150 |
+
q_w, size_k, size_n, quant_type.size_bits, weight_perm
|
151 |
+
)
|
152 |
+
|
153 |
+
opcheck(
|
154 |
+
quantization._ops.ops.gptq_marlin_repack,
|
155 |
+
(q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits),
|
156 |
+
)
|
157 |
+
|
158 |
+
# Run Marlin repack GPU kernel
|
159 |
+
marlin_q_w_2 = quantization.gptq_marlin_repack(
|
160 |
+
q_w_gptq,
|
161 |
+
sort_indices,
|
162 |
+
size_k,
|
163 |
+
size_n,
|
164 |
+
quant_type.size_bits,
|
165 |
+
)
|
166 |
+
torch.cuda.synchronize()
|
167 |
+
|
168 |
+
torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)
|
169 |
+
|
170 |
+
|
171 |
+
@pytest.mark.skipif(
|
172 |
+
capability < 80,
|
173 |
+
reason="Marlin is not supported on this GPU type.",
|
174 |
+
)
|
175 |
+
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
176 |
+
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
177 |
+
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(False))
|
178 |
+
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
|
179 |
+
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
180 |
+
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, mnk_factors):
|
181 |
+
m_factor, n_factor, k_factor = mnk_factors
|
182 |
+
|
183 |
+
size_k = k_chunk * k_factor
|
184 |
+
size_n = n_chunk * n_factor
|
185 |
+
|
186 |
+
# Normalize group_size
|
187 |
+
if group_size == -1:
|
188 |
+
group_size = size_k
|
189 |
+
assert group_size <= size_k
|
190 |
+
|
191 |
+
# Create input
|
192 |
+
b_weight = rand_data((size_k, size_n))
|
193 |
+
|
194 |
+
# Quantize
|
195 |
+
w_ref, q_w, s, zp = quantize_weights(
|
196 |
+
b_weight, quant_type, group_size, zero_points=True
|
197 |
+
)
|
198 |
+
|
199 |
+
# Pack to AWQ format
|
200 |
+
q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n)
|
201 |
+
|
202 |
+
# Pack to Marlin format
|
203 |
+
weight_perm = get_weight_perm(quant_type.size_bits)
|
204 |
+
marlin_q_w_1 = marlin_weights(
|
205 |
+
q_w, size_k, size_n, quant_type.size_bits, weight_perm
|
206 |
+
)
|
207 |
+
|
208 |
+
opcheck(
|
209 |
+
quantization._ops.ops.awq_marlin_repack, (q_w_awq, size_k, size_n, quant_type.size_bits)
|
210 |
+
)
|
211 |
+
|
212 |
+
# Run Marlin repack GPU kernel
|
213 |
+
marlin_q_w_2 = quantization.awq_marlin_repack(
|
214 |
+
q_w_awq,
|
215 |
+
size_k,
|
216 |
+
size_n,
|
217 |
+
quant_type.size_bits,
|
218 |
+
)
|
219 |
+
torch.cuda.synchronize()
|
220 |
+
|
221 |
+
torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)
|
222 |
+
|
223 |
+
|
224 |
+
@pytest.mark.skipif(
|
225 |
+
capability < 80,
|
226 |
+
reason="Marlin is not supported on this GPU type.",
|
227 |
+
)
|
228 |
+
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
229 |
+
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
230 |
+
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(False))
|
231 |
+
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
|
232 |
+
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
233 |
+
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
|
234 |
+
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
|
235 |
+
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
|
236 |
+
def test_gptq_marlin_gemm(
|
237 |
+
k_chunk,
|
238 |
+
n_chunk,
|
239 |
+
quant_type,
|
240 |
+
group_size,
|
241 |
+
mnk_factors,
|
242 |
+
act_order,
|
243 |
+
is_k_full,
|
244 |
+
use_fp32_reduce,
|
245 |
+
):
|
246 |
+
m_factor, n_factor, k_factor = mnk_factors
|
247 |
+
|
248 |
+
size_m = m_factor
|
249 |
+
size_k = k_chunk * k_factor
|
250 |
+
size_n = n_chunk * n_factor
|
251 |
+
|
252 |
+
if act_order:
|
253 |
+
if group_size == -1:
|
254 |
+
return
|
255 |
+
if group_size == size_k:
|
256 |
+
return
|
257 |
+
|
258 |
+
a_input = rand_data((size_m, size_k))
|
259 |
+
b_weight = rand_data((size_k, size_n))
|
260 |
+
|
261 |
+
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
|
262 |
+
b_weight, quant_type, group_size, act_order
|
263 |
+
)
|
264 |
+
|
265 |
+
marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
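# GPTQ has no zero points; reuse the empty-g_idx helper as a zero-point placeholder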
|
266 |
+
|
267 |
+
workspace = MarlinWorkspace(
|
268 |
+
size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
|
269 |
+
)
|
270 |
+
|
271 |
+
opcheck(
|
272 |
+
quantization._ops.ops.gptq_marlin_gemm,
|
273 |
+
(
|
274 |
+
a_input,
|
275 |
+
marlin_q_w,
|
276 |
+
marlin_s,
|
277 |
+
marlin_zp,
|
278 |
+
g_idx,
|
279 |
+
sort_indices,
|
280 |
+
workspace.scratch,
|
281 |
+
quant_type.id,
|
282 |
+
a_input.shape[0],
|
283 |
+
b_weight.shape[1],
|
284 |
+
a_input.shape[1],
|
285 |
+
is_k_full,
|
286 |
+
False,
|
287 |
+
use_fp32_reduce,
|
288 |
+
False,
|
289 |
+
),
|
290 |
+
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
|
291 |
+
)
|
292 |
+
|
293 |
+
output = quantization.gptq_marlin_gemm(
|
294 |
+
a_input,
|
295 |
+
marlin_q_w,
|
296 |
+
marlin_s,
|
297 |
+
marlin_zp,
|
298 |
+
g_idx,
|
299 |
+
sort_indices,
|
300 |
+
workspace.scratch,
|
301 |
+
quant_type,
|
302 |
+
a_input.shape[0],
|
303 |
+
b_weight.shape[1],
|
304 |
+
a_input.shape[1],
|
305 |
+
is_k_full=is_k_full,
|
306 |
+
has_zp=False,
|
307 |
+
use_fp32_reduce=use_fp32_reduce,
|
308 |
+
is_zp_float=False,
|
309 |
+
)
|
310 |
+
output_ref = torch.matmul(a_input, w_ref)
|
311 |
+
|
312 |
+
torch.cuda.synchronize()
|
313 |
+
|
314 |
+
max_diff = compute_max_diff(output, output_ref)
|
315 |
+
|
316 |
+
assert max_diff < 0.04
|
317 |
+
|
318 |
+
|
319 |
+
# TODO: find better way to test this?
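# Wrapping the call in torch.compile(fullgraph=True) also verifies the custom op traces cleanly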
|
320 |
+
@torch.compile(fullgraph=True)
|
321 |
+
def marlin_24_gemm_tester(
|
322 |
+
a_input,
|
323 |
+
marlin_24_q_w_comp,
|
324 |
+
marlin_24_meta,
|
325 |
+
marlin_24_s,
|
326 |
+
scratch,
|
327 |
+
quant_type,
|
328 |
+
size_m,
|
329 |
+
size_n,
|
330 |
+
size_k,
|
331 |
+
):
|
332 |
+
return quantization.gptq_marlin_24_gemm(
|
333 |
+
a_input,
|
334 |
+
marlin_24_q_w_comp,
|
335 |
+
marlin_24_meta,
|
336 |
+
marlin_24_s,
|
337 |
+
scratch,
|
338 |
+
quant_type,
|
339 |
+
size_m,
|
340 |
+
size_n,
|
341 |
+
size_k,
|
342 |
+
)
|
343 |
+
|
344 |
+
|
345 |
+
@pytest.mark.skipif(
|
346 |
+
capability < 80,
|
347 |
+
reason="Marlin is not supported on this GPU type.",
|
348 |
+
)
|
349 |
+
@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
|
350 |
+
@pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS)
|
351 |
+
@pytest.mark.parametrize("quant_type", GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
|
352 |
+
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES)
|
353 |
+
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
354 |
+
def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, mnk_factors):
|
355 |
+
m_factor, n_factor, k_factor = mnk_factors
|
356 |
+
|
357 |
+
size_m = m_factor
|
358 |
+
size_k = k_chunk * k_factor
|
359 |
+
size_n = n_chunk * n_factor
|
360 |
+
|
361 |
+
a_input = rand_data((size_m, size_k))
|
362 |
+
b_weight = rand_data((size_k, size_n))
|
363 |
+
|
364 |
+
(w_24_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = marlin_24_quantize(
|
365 |
+
b_weight, quant_type, group_size
|
366 |
+
)
|
367 |
+
|
368 |
+
workspace_24 = MarlinWorkspace(
|
369 |
+
size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL
|
370 |
+
)
|
371 |
+
|
372 |
+
output_ref = torch.matmul(a_input, w_24_ref)
|
373 |
+
|
374 |
+
opcheck(
|
375 |
+
quantization._ops.ops.gptq_marlin_24_gemm,
|
376 |
+
(
|
377 |
+
a_input,
|
378 |
+
marlin_24_q_w_comp,
|
379 |
+
marlin_24_meta,
|
380 |
+
marlin_24_s,
|
381 |
+
workspace_24.scratch,
|
382 |
+
quant_type.id,
|
383 |
+
a_input.shape[0],
|
384 |
+
b_weight.shape[1],
|
385 |
+
a_input.shape[1],
|
386 |
+
),
|
387 |
+
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
|
388 |
+
)
|
389 |
+
|
390 |
+
output = marlin_24_gemm_tester(
|
391 |
+
a_input,
|
392 |
+
marlin_24_q_w_comp,
|
393 |
+
marlin_24_meta,
|
394 |
+
marlin_24_s,
|
395 |
+
workspace_24.scratch,
|
396 |
+
quant_type,
|
397 |
+
a_input.shape[0],
|
398 |
+
b_weight.shape[1],
|
399 |
+
a_input.shape[1],
|
400 |
+
)
|
401 |
+
|
402 |
+
torch.cuda.synchronize()
|
403 |
+
|
404 |
+
max_diff = compute_max_diff(output, output_ref)
|
405 |
+
|
406 |
+
assert max_diff < 0.04
|
407 |
+
|
408 |
+
|
409 |
+
@pytest.mark.skipif(
|
410 |
+
capability < 80,
|
411 |
+
reason="Marlin is not supported on this GPU type.",
|
412 |
+
)
|
413 |
+
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
414 |
+
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
415 |
+
@pytest.mark.parametrize("num_bits", [8])
|
416 |
+
@pytest.mark.parametrize("group_size", [-1])
|
417 |
+
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
418 |
+
@pytest.mark.parametrize("dtype", DTYPES)
|
419 |
+
def test_fp8_marlin_gemm(
|
420 |
+
k_chunk,
|
421 |
+
n_chunk,
|
422 |
+
num_bits,
|
423 |
+
group_size,
|
424 |
+
mnk_factors,
|
425 |
+
dtype,
|
426 |
+
):
|
427 |
+
m_factor, n_factor, k_factor = mnk_factors
|
428 |
+
|
429 |
+
size_m = m_factor
|
430 |
+
size_k = k_chunk * k_factor
|
431 |
+
size_n = n_chunk * n_factor
|
432 |
+
|
433 |
+
a_input = rand_data((size_m, size_k), dtype=dtype)
|
434 |
+
b_weight = rand_data((size_k, size_n), dtype=dtype)
|
435 |
+
|
436 |
+
# WEIGHTS
|
437 |
+
fp8_weight, weight_scale = quantization.scaled_fp8_quant(b_weight, scale=None)
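# scale=None lets scaled_fp8_quant compute a per-tensor scale dynamically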
|
438 |
+
# Repack weights to gptq format (packed int32 elements)
|
439 |
+
packed_gptq_qweight = pack_fp8_to_int32(fp8_weight)
|
440 |
+
# Repack weights to marlin format
|
441 |
+
marlin_qweight = quantization.gptq_marlin_repack(
|
442 |
+
b_q_weight=packed_gptq_qweight,
|
443 |
+
perm=torch.empty(0, dtype=torch.int, device="cuda"),
|
444 |
+
size_k=size_k,
|
445 |
+
size_n=size_n,
|
446 |
+
num_bits=8,
|
447 |
+
)
|
448 |
+
|
449 |
+
# WEIGHT SCALES
|
450 |
+
# Currently Marlin doesn't support per-tensor scales, so we
|
451 |
+
# expand it to channelwise
|
452 |
+
scales = weight_scale.repeat(1, size_n).to(a_input.dtype).to("cuda")
|
453 |
+
# Permute scales
|
454 |
+
marlin_scales = marlin_permute_scales(
|
455 |
+
s=scales, size_k=size_k, size_n=size_n, group_size=-1
|
456 |
+
)
|
457 |
+
|
458 |
+
workspace = MarlinWorkspace(
|
459 |
+
size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
|
460 |
+
)
|
461 |
+
|
462 |
+
opcheck(
|
463 |
+
quantization._ops.ops.fp8_marlin_gemm,
|
464 |
+
(
|
465 |
+
a_input,
|
466 |
+
marlin_qweight,
|
467 |
+
marlin_scales,
|
468 |
+
workspace.scratch,
|
469 |
+
num_bits,
|
470 |
+
a_input.shape[0],
|
471 |
+
b_weight.shape[1],
|
472 |
+
a_input.shape[1],
|
473 |
+
),
|
474 |
+
)
|
475 |
+
|
476 |
+
output = quantization.fp8_marlin_gemm(
|
477 |
+
a=a_input,
|
478 |
+
b_q_weight=marlin_qweight,
|
479 |
+
b_scales=marlin_scales,
|
480 |
+
workspace=workspace.scratch,
|
481 |
+
num_bits=num_bits,
|
482 |
+
size_m=a_input.shape[0],
|
483 |
+
size_n=b_weight.shape[1],
|
484 |
+
size_k=a_input.shape[1],
|
485 |
+
)
|
486 |
+
output_ref = torch.matmul(a_input, b_weight)
|
487 |
+
|
488 |
+
torch.cuda.synchronize()
|
489 |
+
|
490 |
+
max_diff = compute_max_diff(output, output_ref)
|
491 |
+
|
492 |
+
assert max_diff < 0.04
|
493 |
+
|
494 |
+
|
495 |
+
@pytest.mark.skipif(
|
496 |
+
capability < 80,
|
497 |
+
reason="Marlin is not supported on this GPU type.",
|
498 |
+
)
|
499 |
+
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
500 |
+
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
501 |
+
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(True))
|
502 |
+
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
|
503 |
+
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
504 |
+
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
|
505 |
+
def test_awq_marlin_gemm(
|
506 |
+
k_chunk,
|
507 |
+
n_chunk,
|
508 |
+
quant_type,
|
509 |
+
group_size,
|
510 |
+
mnk_factors,
|
511 |
+
use_fp32_reduce,
|
512 |
+
):
|
513 |
+
m_factor, n_factor, k_factor = mnk_factors
|
514 |
+
|
515 |
+
size_m = m_factor
|
516 |
+
size_k = k_chunk * k_factor
|
517 |
+
size_n = n_chunk * n_factor
|
518 |
+
|
519 |
+
a_input = rand_data((size_m, size_k))
|
520 |
+
b_weight = rand_data((size_k, size_n))
|
521 |
+
|
522 |
+
w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
|
523 |
+
b_weight, quant_type, group_size
|
524 |
+
)
|
525 |
+
|
526 |
+
g_idx = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
|
527 |
+
sort_indices = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
|
528 |
+
is_k_full = True
|
529 |
+
has_zp = True
|
530 |
+
|
531 |
+
workspace = MarlinWorkspace(
|
532 |
+
size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
|
533 |
+
)
|
534 |
+
|
535 |
+
output = quantization.gptq_marlin_gemm(
|
536 |
+
a_input,
|
537 |
+
marlin_q_w,
|
538 |
+
marlin_s,
|
539 |
+
marlin_zp,
|
540 |
+
g_idx,
|
541 |
+
sort_indices,
|
542 |
+
workspace.scratch,
|
543 |
+
quant_type,
|
544 |
+
a_input.shape[0],
|
545 |
+
b_weight.shape[1],
|
546 |
+
a_input.shape[1],
|
547 |
+
is_k_full=is_k_full,
|
548 |
+
has_zp=has_zp,
|
549 |
+
use_fp32_reduce=use_fp32_reduce,
|
550 |
+
is_zp_float=False,
|
551 |
+
)
|
552 |
+
output_ref = torch.matmul(a_input, w_ref)
|
553 |
+
|
554 |
+
torch.cuda.synchronize()
|
555 |
+
|
556 |
+
max_diff = compute_max_diff(output, output_ref)
|
557 |
+
|
558 |
+
assert max_diff < 0.04
|
559 |
+
|
560 |
+
|
561 |
+
@pytest.mark.skipif(
|
562 |
+
capability < 80,
|
563 |
+
reason="Marlin is not supported on this GPU type.",
|
564 |
+
)
|
565 |
+
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
566 |
+
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
567 |
+
@pytest.mark.parametrize("group_size", HQQ_SUPPORTED_GROUP_SIZES)
|
568 |
+
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
569 |
+
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
|
570 |
+
def test_hqq_marlin_gemm(
|
571 |
+
k_chunk,
|
572 |
+
n_chunk,
|
573 |
+
group_size,
|
574 |
+
mnk_factors,
|
575 |
+
use_fp32_reduce,
|
576 |
+
):
|
577 |
+
m_factor, n_factor, k_factor = mnk_factors
|
578 |
+
|
579 |
+
size_m = m_factor
|
580 |
+
size_k = k_chunk * k_factor
|
581 |
+
size_n = n_chunk * n_factor
|
582 |
+
|
583 |
+
quant_type = scalar_types.uint4
|
584 |
+
|
585 |
+
a_input = rand_data((size_m, size_k))
|
586 |
+
dev = a_input.device
|
587 |
+
|
588 |
+
b_weight = torch.randint(0, 10, (size_n, size_k), dtype=torch.uint8, device=dev)
|
589 |
+
scale = rand_data((size_n, size_k // group_size))
|
590 |
+
zero = rand_data((size_n, size_k // group_size))
|
591 |
+
|
592 |
+
gptq_w_q = gptq_pack(b_weight.transpose(1, 0), 4, size_k, size_n)
|
593 |
+
|
594 |
+
sort_indices = torch.empty(0, dtype=torch.int, device=dev)
|
595 |
+
marlin_w_q = quantization.gptq_marlin_repack(gptq_w_q, sort_indices, size_k, size_n, 4).to(
|
596 |
+
dev
|
597 |
+
)
|
598 |
+
marlin_s = marlin_permute_scales(
|
599 |
+
scale.transpose(1, 0), size_k, size_n, group_size
|
600 |
+
).to(dev)
|
601 |
+
marlin_zp = marlin_permute_scales(
|
602 |
+
zero.transpose(1, 0), size_k, size_n, group_size
|
603 |
+
).to(dev)
|
604 |
+
|
605 |
+
g_idx = marlin_make_empty_g_idx(dev)
|
606 |
+
g_idx_sort_indices = marlin_make_empty_g_idx(dev)
|
607 |
+
|
608 |
+
workspace = MarlinWorkspace(
|
609 |
+
size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
|
610 |
+
)
|
611 |
+
|
612 |
+
output = quantization.gptq_marlin_gemm(
|
613 |
+
a_input,
|
614 |
+
marlin_w_q,
|
615 |
+
marlin_s,
|
616 |
+
marlin_zp,
|
617 |
+
g_idx,
|
618 |
+
g_idx_sort_indices,
|
619 |
+
workspace.scratch,
|
620 |
+
quant_type,
|
621 |
+
a_input.shape[0],
|
622 |
+
b_weight.shape[0],
|
623 |
+
a_input.shape[1],
|
624 |
+
is_k_full=True,
|
625 |
+
has_zp=True,
|
626 |
+
use_fp32_reduce=use_fp32_reduce,
|
627 |
+
is_zp_float=True,
|
628 |
+
)
|
629 |
+
|
630 |
+
b_flat = b_weight.reshape(-1, group_size)
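# Reference path: dequantize per group as (w - zp) * s, then matmul against the fp16 input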
|
631 |
+
zp_flat = zero.reshape(-1, 1)
|
632 |
+
s_flat = scale.reshape(-1, 1)
|
633 |
+
dequant = (b_flat - zp_flat) * s_flat
|
634 |
+
|
635 |
+
output_ref = torch.matmul(a_input, dequant.reshape(b_weight.shape).transpose(1, 0))
|
636 |
+
|
637 |
+
torch.cuda.synchronize()
|
638 |
+
|
639 |
+
max_diff = compute_max_diff(output, output_ref)
|
640 |
+
|
641 |
+
assert max_diff < 0.04
|
642 |
+
|
643 |
+
|
644 |
+
@pytest.mark.skipif(
|
645 |
+
capability < 80,
|
646 |
+
reason="Marlin is not supported on this GPU type.",
|
647 |
+
)
|
648 |
+
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
649 |
+
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
650 |
+
@pytest.mark.parametrize("num_bits", MARLIN_QQQ_SUPPORTED_NUM_BITS)
|
651 |
+
@pytest.mark.parametrize("group_size", MARLIN_QQQ_SUPPORTED_GROUP_SIZES)
|
652 |
+
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
653 |
+
def test_marlin_qqq_gemm(
|
654 |
+
k_chunk,
|
655 |
+
n_chunk,
|
656 |
+
num_bits,
|
657 |
+
group_size,
|
658 |
+
mnk_factors,
|
659 |
+
):
|
660 |
+
int8_traits = torch.iinfo(torch.int8)
|
661 |
+
m_factor, n_factor, k_factor = mnk_factors
|
662 |
+
|
663 |
+
size_m = m_factor
|
664 |
+
size_k = k_chunk * k_factor
|
665 |
+
size_n = n_chunk * n_factor
|
666 |
+
|
667 |
+
a_input = rand_data((size_m, size_k))
|
668 |
+
b_weight = rand_data((size_k, size_n))
|
669 |
+
|
670 |
+
# Quantize activations to int8 with a symmetric per-token (per-row) scale
|
671 |
+
s_a = (
|
672 |
+
a_input.abs().max(dim=-1, keepdim=True)[0].div(int8_traits.max).to(torch.float)
|
673 |
+
)
|
674 |
+
q_a = (a_input / s_a).round().clamp(int8_traits.min, int8_traits.max).to(torch.int8)
|
675 |
+
|
676 |
+
# Quantize weights
|
677 |
+
w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel = (
|
678 |
+
marlin_qqq_quantize(b_weight, num_bits, group_size)
|
679 |
+
)
|
680 |
+
|
681 |
+
workspace = MarlinWorkspace(
|
682 |
+
size_n, MARLIN_QQQ_MIN_THREAD_N, MARLIN_QQQ_MAX_PARALLEL
|
683 |
+
)
|
684 |
+
|
685 |
+
opcheck(
|
686 |
+
quantization._ops.ops.marlin_qqq_gemm,
|
687 |
+
(
|
688 |
+
q_a,
|
689 |
+
marlin_qqq_q_w,
|
690 |
+
s_a,
|
691 |
+
marlin_qqq_s_channel,
|
692 |
+
marlin_qqq_s_group,
|
693 |
+
workspace.scratch,
|
694 |
+
a_input.shape[0],
|
695 |
+
b_weight.shape[1],
|
696 |
+
a_input.shape[1],
|
697 |
+
),
|
698 |
+
)
|
699 |
+
|
700 |
+
output = quantization.marlin_qqq_gemm(
|
701 |
+
q_a,
|
702 |
+
marlin_qqq_q_w,
|
703 |
+
s_a,
|
704 |
+
marlin_qqq_s_channel,
|
705 |
+
marlin_qqq_s_group,
|
706 |
+
workspace.scratch,
|
707 |
+
a_input.shape[0],
|
708 |
+
b_weight.shape[1],
|
709 |
+
a_input.shape[1],
|
710 |
+
)
|
711 |
+
output_ref = torch.matmul(q_a.half() * s_a.half(), w_ref)
|
712 |
+
|
713 |
+
torch.cuda.synchronize()
|
714 |
+
|
715 |
+
max_diff = compute_max_diff(output, output_ref)
|
716 |
+
|
717 |
+
assert max_diff < 0.04
|
718 |
+
|
719 |
+
|
720 |
+
def test_marlin_gemm_opcheck():
|
721 |
+
size_m = 2048
|
722 |
+
size_n = 4096
|
723 |
+
size_k = 4096
|
724 |
+
a = torch.rand((size_m, size_n), device="cuda", dtype=torch.float16)
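# w and s are dummy buffers shaped like Marlin's packed 4-bit weights and group scales (n == k == 4096 here)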
|
725 |
+
w = torch.randint(-5, 5, (256, 8192), device="cuda", dtype=torch.int32)
|
726 |
+
s = torch.full((32, size_k), 0.125, device="cuda", dtype=torch.float16)
|
727 |
+
wk = MarlinWorkspace(
|
728 |
+
size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
|
729 |
+
).scratch
|
730 |
+
x = quantization._ops.ops.marlin_gemm(a, w, s, wk, size_m, size_n, size_k)
|
731 |
+
y = quantization._ops.ops.marlin_gemm(a, w, s, wk, size_m, size_n, size_k)
|
732 |
+
torch.testing.assert_close(x, y)
|
733 |
+
opcheck(quantization._ops.ops.marlin_gemm, (a, w, s, wk, size_m, size_n, size_k))
|
tests/kernels/utils.py
CHANGED
@@ -4,13 +4,20 @@ import itertools
|
|
4 |
import random
|
5 |
import unittest
|
6 |
from numbers import Number
|
7 |
-
from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple,
|
8 |
-
Union)
|
9 |
|
10 |
import pytest
|
11 |
import torch
|
12 |
from torch._prims_common import TensorLikeType
|
13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
|
15 |
"test_schema",
|
16 |
"test_autograd_registration",
|
@@ -18,6 +25,7 @@ ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
|
|
18 |
"test_aot_dispatch_dynamic",
|
19 |
)
|
20 |
|
|
|
21 |
# Copied/modified from torch._refs.__init__.py
|
22 |
def fp8_allclose(
|
23 |
a: TensorLikeType,
|
@@ -29,34 +37,37 @@ def fp8_allclose(
|
|
29 |
"""
|
30 |
Reference implementation of torch.allclose
|
31 |
"""
|
32 |
-
torch._refs._check_close_args(name="torch.allclose",
|
33 |
-
a=a,
|
34 |
-
b=b,
|
35 |
-
rtol=rtol,
|
36 |
-
atol=atol)
|
37 |
|
38 |
return bool(
|
39 |
torch.all(
|
40 |
-
torch.isclose(a.double(),
|
41 |
-
b.double(),
|
42 |
-
rtol=rtol,
|
43 |
-
atol=atol,
|
44 |
-
equal_nan=equal_nan)).item())
|
|
|
45 |
|
46 |
# A special version of op check that has a restricted default set of test_utils
|
47 |
# and a patched version of allclose that supports fp8 types.
|
48 |
-
def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
|
49 |
-
torch._library.custom_ops.CustomOpDef],
|
50 |
-
args: Tuple[Any, ...],
|
51 |
-
kwargs: Optional[Dict[str, Any]] = None,
|
52 |
-
*,
|
53 |
-
test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
|
54 |
-
raise_exception: bool = True,
|
55 |
-
cond: bool = True) -> Dict[str, str]:
|
56 |
-
with unittest.mock.patch("torch.allclose", new=fp8_allclose):
|
57 |
-
return torch.library.opcheck(
|
58 |
-
op,
|
59 |
-
args,
|
60 |
-
kwargs,
|
61 |
-
test_utils=test_utils,
|
62 |
-
raise_exception=raise_exception) if cond else {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
import random
|
5 |
import unittest
|
6 |
from numbers import Number
|
7 |
+
from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
|
|
|
8 |
|
9 |
import pytest
|
10 |
import torch
|
11 |
from torch._prims_common import TensorLikeType
|
12 |
|
13 |
+
# For now, disable "test_aot_dispatch_dynamic" since there are some
|
14 |
+
# bugs related to this test in PyTorch 2.4.
|
15 |
+
DEFAULT_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
|
16 |
+
"test_schema",
|
17 |
+
"test_autograd_registration",
|
18 |
+
"test_faketensor",
|
19 |
+
)
|
20 |
+
|
21 |
ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
|
22 |
"test_schema",
|
23 |
"test_autograd_registration",
|
|
|
25 |
"test_aot_dispatch_dynamic",
|
26 |
)
|
27 |
|
28 |
+
|
29 |
# Copied/modified from torch._refs.__init__.py
|
30 |
def fp8_allclose(
|
31 |
a: TensorLikeType,
|
|
|
37 |
"""
|
38 |
Reference implementation of torch.allclose
|
39 |
"""
|
40 |
+
torch._refs._check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol)
|
|
|
|
|
|
|
|
|
41 |
|
42 |
return bool(
|
43 |
torch.all(
|
44 |
+
torch.isclose(
|
45 |
+
a.double(), b.double(), rtol=rtol, atol=atol, equal_nan=equal_nan
|
46 |
+
)
|
47 |
+
).item()
|
48 |
+
)
|
49 |
+
|
50 |
|
51 |
# A special version of op check that has a restricted default set of test_utils
|
52 |
# and a patched version of allclose that supports fp8 types.
|
53 |
+
def opcheck(
|
54 |
+
op: Union[
|
55 |
+
torch._ops.OpOverload,
|
56 |
+
torch._ops.OpOverloadPacket,
|
57 |
+
torch._library.custom_ops.CustomOpDef,
|
58 |
+
],
|
59 |
+
args: Tuple[Any, ...],
|
60 |
+
kwargs: Optional[Dict[str, Any]] = None,
|
61 |
+
*,
|
62 |
+
test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
|
63 |
+
raise_exception: bool = True,
|
64 |
+
cond: bool = True
|
65 |
+
) -> Dict[str, str]:
|
66 |
+
with unittest.mock.patch("torch.allclose", new=fp8_allclose):
|
67 |
+
return (
|
68 |
+
torch.library.opcheck(
|
69 |
+
op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
|
70 |
+
)
|
71 |
+
if cond
|
72 |
+
else {}
|
73 |
+
)
|