Commit 165b25c · Parent(s): 2dd62c9
danieldk committed: Add full Marlin support and tests for Marlin/CUTLASS
build.toml CHANGED
@@ -10,9 +10,7 @@ src = [
   "ext-torch/torch_binding.h"
 ]
 include = [ "." ]
-pysrc = [
-  "ext-torch/__init__.py"
-]
+pyroot = "ext-torch"
 
 [kernel.cutlass_w8a8]
 capabilities = [ "7.5", "8.0", "8.6", "8.7", "8.9", "9.0", "9.0a" ]
@@ -59,7 +57,6 @@ src = [
   "gptq_marlin/marlin.cuh",
   "gptq_marlin/marlin_dtypes.cuh",
 ]
-#include = [ "." ]
 depends = [ "torch" ]
 
 [kernel.int8_common]
@@ -83,3 +80,21 @@ src = [
 ]
 include = [ "." ]
 depends = [ "torch" ]
+
+[kernel.marlin]
+capabilities = [ "8.0", "8.6", "8.7", "8.9", "9.0", "9.0a" ]
+src = [
+  "core/scalar_type.hpp",
+  "marlin/dense/common/base.h",
+  "marlin/dense/common/mem.h",
+  "marlin/dense/marlin_cuda_kernel.cu",
+  "marlin/qqq/marlin_qqq_gemm_kernel.cu",
+  "marlin/sparse/common/base.h",
+  "marlin/sparse/common/mem.h",
+  "marlin/sparse/common/mma.h",
+  "marlin/sparse/marlin_24_cuda_kernel.cu"
+]
+include = [ "." ]
+depends = [ "torch" ]
+
+
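The new [kernel.marlin] entry is only built for compute capability 8.0 and newer. A minimal, illustrative check before routing a layer to the Marlin kernels could mirror that capability list (this helper is not part of the commit):

import torch

def marlin_is_supported() -> bool:
    # Marlin kernels in this build target SM 8.0+ (see the capabilities list above).
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor >= 80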
ext-torch/__init__.py CHANGED
@@ -1,177 +1,30 @@
-from typing import Optional, Tuple
-
-import torch
-
-try:
-    from ._ops import ops
-except ImportError as e:
-    # Fallback for local development.
-    try:
-        import _quantization
-        ops = torch.ops._quantization
-    except ImportError:
-        raise e
-
-def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
-    return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
-
-def cutlass_scaled_mm(a: torch.Tensor,
-                      b: torch.Tensor,
-                      scale_a: torch.Tensor,
-                      scale_b: torch.Tensor,
-                      out_dtype: torch.dtype,
-                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
-    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
-    assert bias is None or bias.shape[
-        1] and bias.dtype == out_dtype
-
-    m = a.shape[0]
-    n = b.shape[1]
-
-    #if current_platform.is_rocm():
-    #    triton_scaled_mm_module = importlib.import_module(
-    #        "vllm.model_executor.layers.quantization.compressed_tensors."
-    #        "triton_scaled_mm")
-    #    triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
-    #    return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
-
-    out = torch.empty((m, n), dtype=out_dtype, device=a.device)
-
-    ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
-
-    return out
-
-def cutlass_scaled_mm_azp(a: torch.Tensor,
-                          b: torch.Tensor,
-                          scale_a: torch.Tensor,
-                          scale_b: torch.Tensor,
-                          out_dtype: torch.dtype,
-                          azp_adj: torch.Tensor,
-                          azp: Optional[torch.Tensor] = None,
-                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-    """
-    :param azp_adj: In the per-tensor case, this should include the azp.
-    Always per-channel.
-    :param azp: Only set in the per-token case. Per-token if set.
-    """
-    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
-    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
-    assert bias is None or bias.numel(
-    ) == b.shape[1] and bias.dtype == out_dtype
-    assert azp is None or azp.numel() == a.shape[0]
-
-    m = a.shape[0]
-    n = b.shape[1]
-    out = torch.empty((m, n), dtype=out_dtype, device=a.device)
-
-    ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj,
-                              azp, bias)
-    return out
-
-# fp8
-def scaled_fp8_quant(
-    input: torch.Tensor,
-    scale: Optional[torch.Tensor] = None,
-    num_token_padding: Optional[int] = None,
-    scale_ub: Optional[torch.Tensor] = None,
-    use_per_token_if_dynamic: bool = False,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """
-    Quantize input tensor to FP8 and return quantized tensor and scale.
-
-    This function supports both static and dynamic quantization: If you
-    provide the scale, it will use static scaling and if you omit it,
-    the scale will be determined dynamically. The function also allows
-    optional padding of the output tensors for downstream kernels that
-    will benefit from padding.
-
-    Args:
-        input: The input tensor to be quantized to FP8
-        scale: Optional scaling factor for the FP8 quantization
-        scale_ub: Optional upper bound for scaling factor in dynamic
-            per token case
-        num_token_padding: If specified, pad the first dimension
-            of the output to at least this value.
-        use_per_token_if_dynamic: Whether to do per_tensor or per_token
-            in the dynamic quantization case.
-
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
-            scaling factor.
-    """
-    # This code assumes batch_dim and num_tokens are flattened
-    assert (input.ndim == 2)
-    shape: Union[Tuple[int, int], torch.Size] = input.shape
-    # For rocm, the output fp8 dtype is torch.float_e3m3fnuz
-    #out_dtype: torch.dtype = torch.float8_e4m3fnuz \
-    #    if current_platform.is_rocm() else torch.float8_e4m3fn
-    out_dtype = torch.float8_e4m3fn
-    if num_token_padding:
-        shape = (max(num_token_padding, input.shape[0]), shape[1])
-    output = torch.empty(shape, device=input.device, dtype=out_dtype)
-
-    if scale is None:
-        if use_per_token_if_dynamic:
-            scale = torch.empty((shape[0], 1),
-                                device=input.device,
-                                dtype=torch.float32)
-            ops.dynamic_per_token_scaled_fp8_quant(
-                output, input, scale, scale_ub)
-        else:
-            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
-            ops.dynamic_scaled_fp8_quant(output, input, scale)
-    else:
-        # num_token_padding not implemented for this case
-        assert (scale.numel() == 1 or num_token_padding is None)
-        ops.static_scaled_fp8_quant(output, input, scale)
-
-    return output, scale
-
-# int8
-def scaled_int8_quant(
-    input: torch.Tensor,
-    scale: Optional[torch.Tensor] = None,
-    azp: Optional[torch.Tensor] = None,
-    symmetric: bool = True
-) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-    """
-    Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.
-
-    Args:
-        input: The input tensor to be quantized to int8.
-        scale: Optional scaling factor for the int8 quantization.
-            When not provided, we invoke dynamic-per-token quantization.
-        azp: Optional zero-point for the int8 quantization.
-            Must be provided for asymmetric quantization if `scale` is provided.
-        symmetric: Whether to use symmetric quantization (scale only, azp ignored).
-
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
-    """
-    output = torch.empty_like(input, dtype=torch.int8)
-    if scale is not None:
-        # static-per-tensor quantization.
-        assert symmetric == (
-            azp is
-            None), "azp must only be provided for asymmetric quantization."
-        ops.static_scaled_int8_quant(output, input, scale, azp)
-        return output, scale, azp
-
-    # dynamic-per-token quantization.
-    input_scales = torch.empty((input.numel() // input.shape[-1], 1),
-                               device=input.device,
-                               dtype=torch.float32)
-    input_azp = None if symmetric else torch.empty_like(input_scales,
-                                                        dtype=torch.int32)
-    ops.dynamic_scaled_int8_quant(output, input, input_scales,
-                                  input_azp)
-    return output, input_scales, input_azp
-
-# fp8 marlin
-def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
-                    b_scales: torch.Tensor, workspace: torch.Tensor,
-                    num_bits: int, size_m: int, size_n: int,
-                    size_k: int) -> torch.Tensor:
-    return ops.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace,
-                               num_bits, size_m, size_n, size_k)
+from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant
+from .cutlass import (
+    cutlass_scaled_mm_supports_fp8,
+    cutlass_scaled_mm,
+    cutlass_scaled_mm_azp,
+)
+from .marlin import (
+    awq_marlin_repack,
+    fp8_marlin_gemm,
+    gptq_marlin_gemm,
+    gptq_marlin_repack,
+    gptq_marlin_24_gemm,
+    marlin_qqq_gemm,
+    marlin_gemm,
+)
+
+__all__ = [
+    "awq_marlin_repack",
+    "cutlass_scaled_mm",
+    "cutlass_scaled_mm_azp",
+    "cutlass_scaled_mm_supports_fp8",
+    "fp8_marlin_gemm",
+    "gptq_marlin_24_gemm",
+    "gptq_marlin_gemm",
+    "gptq_marlin_repack",
+    "marlin_gemm",
+    "marlin_qqq_gemm",
+    "scaled_fp8_quant",
+    "scaled_int8_quant",
+]
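After the split, the package root still re-exports every kernel, so callers keep a flat namespace. A sketch, assuming the extension is importable as `quantization` (the package name used by the test utilities further below):

from quantization import (
    cutlass_scaled_mm,
    gptq_marlin_gemm,
    marlin_gemm,
    scaled_fp8_quant,
    scaled_int8_quant,
)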
ext-torch/compressed_tensors.py ADDED
@@ -0,0 +1,110 @@
+from typing import Optional, Tuple, Union
+
+import torch
+
+try:
+    from ._ops import ops
+except ImportError as e:
+    # Fallback for local development.
+    try:
+        import _quantization
+
+        ops = torch.ops._quantization
+    except ImportError:
+        raise e
+
+
+# fp8
+def scaled_fp8_quant(
+    input: torch.Tensor,
+    scale: Optional[torch.Tensor] = None,
+    num_token_padding: Optional[int] = None,
+    scale_ub: Optional[torch.Tensor] = None,
+    use_per_token_if_dynamic: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to FP8 and return quantized tensor and scale.
+
+    This function supports both static and dynamic quantization: If you
+    provide the scale, it will use static scaling and if you omit it,
+    the scale will be determined dynamically. The function also allows
+    optional padding of the output tensors for downstream kernels that
+    will benefit from padding.
+
+    Args:
+        input: The input tensor to be quantized to FP8
+        scale: Optional scaling factor for the FP8 quantization
+        scale_ub: Optional upper bound for scaling factor in dynamic
+            per token case
+        num_token_padding: If specified, pad the first dimension
+            of the output to at least this value.
+        use_per_token_if_dynamic: Whether to do per_tensor or per_token
+            in the dynamic quantization case.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
+            scaling factor.
+    """
+    # This code assumes batch_dim and num_tokens are flattened
+    assert input.ndim == 2
+    shape: Union[Tuple[int, int], torch.Size] = input.shape
+    # For rocm, the output fp8 dtype is torch.float_e3m3fnuz
+    # out_dtype: torch.dtype = torch.float8_e4m3fnuz \
+    #     if current_platform.is_rocm() else torch.float8_e4m3fn
+    out_dtype = torch.float8_e4m3fn
+    if num_token_padding:
+        shape = (max(num_token_padding, input.shape[0]), shape[1])
+    output = torch.empty(shape, device=input.device, dtype=out_dtype)
+
+    if scale is None:
+        if use_per_token_if_dynamic:
+            scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
+            ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub)
+        else:
+            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+            ops.dynamic_scaled_fp8_quant(output, input, scale)
+    else:
+        # num_token_padding not implemented for this case
+        assert scale.numel() == 1 or num_token_padding is None
+        ops.static_scaled_fp8_quant(output, input, scale)
+
+    return output, scale
+
+
+# int8
+def scaled_int8_quant(
+    input: torch.Tensor,
+    scale: Optional[torch.Tensor] = None,
+    azp: Optional[torch.Tensor] = None,
+    symmetric: bool = True,
+) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+    """
+    Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.
+
+    Args:
+        input: The input tensor to be quantized to int8.
+        scale: Optional scaling factor for the int8 quantization.
+            When not provided, we invoke dynamic-per-token quantization.
+        azp: Optional zero-point for the int8 quantization.
+            Must be provided for asymmetric quantization if `scale` is provided.
+        symmetric: Whether to use symmetric quantization (scale only, azp ignored).
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
+    """
+    output = torch.empty_like(input, dtype=torch.int8)
+    if scale is not None:
+        # static-per-tensor quantization.
+        assert symmetric == (
+            azp is None
+        ), "azp must only be provided for asymmetric quantization."
+        ops.static_scaled_int8_quant(output, input, scale, azp)
+        return output, scale, azp
+
+    # dynamic-per-token quantization.
+    input_scales = torch.empty(
+        (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32
+    )
+    input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32)
+    ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp)
+    return output, input_scales, input_azp
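A short usage sketch for the quantization helpers above; shapes are illustrative, the package name `quantization` is an assumption, and a CUDA device with the extension built is required:

import torch
import quantization as ops  # assumed package name

x = torch.randn(32, 512, dtype=torch.float16, device="cuda")

# Dynamic per-token FP8: the kernel computes and returns the scales.
x_fp8, fp8_scales = ops.scaled_fp8_quant(x, use_per_token_if_dynamic=True)

# Static FP8: the caller provides a single per-tensor scale.
x_fp8_static, _ = ops.scaled_fp8_quant(x, scale=torch.tensor([1.0], device="cuda"))

# Dynamic per-token int8, symmetric: the returned azp is None.
x_int8, int8_scales, azp = ops.scaled_int8_quant(x)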
ext-torch/cutlass.py ADDED
@@ -0,0 +1,75 @@
+from typing import Optional
+
+import torch
+
+try:
+    from ._ops import ops
+except ImportError as e:
+    # Fallback for local development.
+    try:
+        import _quantization
+
+        ops = torch.ops._quantization
+    except ImportError:
+        raise e
+
+
+def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
+    return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
+
+
+def cutlass_scaled_mm(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    scale_a: torch.Tensor,
+    scale_b: torch.Tensor,
+    out_dtype: torch.dtype,
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
+    assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
+    assert bias is None or bias.shape[0] == b.shape[1] and bias.dtype == out_dtype
+
+    m = a.shape[0]
+    n = b.shape[1]
+
+    # if current_platform.is_rocm():
+    #     triton_scaled_mm_module = importlib.import_module(
+    #         "vllm.model_executor.layers.quantization.compressed_tensors."
+    #         "triton_scaled_mm")
+    #     triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
+    #     return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+
+    out = torch.empty((m, n), dtype=out_dtype, device=a.device)
+
+    ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
+
+    return out
+
+
+def cutlass_scaled_mm_azp(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    scale_a: torch.Tensor,
+    scale_b: torch.Tensor,
+    out_dtype: torch.dtype,
+    azp_adj: torch.Tensor,
+    azp: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    """
+    :param azp_adj: In the per-tensor case, this should include the azp.
+    Always per-channel.
+    :param azp: Only set in the per-token case. Per-token if set.
+    """
+    assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
+    assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
+    assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype
+    assert azp is None or azp.numel() == a.shape[0]
+
+    m = a.shape[0]
+    n = b.shape[1]
+    out = torch.empty((m, n), dtype=out_dtype, device=a.device)
+
+    ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias)
+    return out
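Combined with scaled_int8_quant, cutlass_scaled_mm gives a W8A8 matmul with per-token activation scales and per-channel weight scales. A rough sketch only: the column-major weight layout and float32 scale shapes here are assumptions based on the CUTLASS kernel's usual requirements, and the exact checks live on the C++ side.

import torch
import quantization as ops  # assumed package name

m, k, n = 32, 512, 256
a = torch.randn(m, k, dtype=torch.float16, device="cuda")

a_q, a_scales, _ = ops.scaled_int8_quant(a)                               # per-token scales, (m, 1)
b_q = torch.randint(-8, 8, (n, k), dtype=torch.int8, device="cuda").t()   # column-major (k, n)
b_scales = torch.full((1, n), 0.01, dtype=torch.float32, device="cuda")   # per-channel scales

out = ops.cutlass_scaled_mm(a_q, b_q, a_scales, b_scales, out_dtype=torch.float16)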
ext-torch/marlin.py ADDED
@@ -0,0 +1,208 @@
+from typing import TYPE_CHECKING
+
+import torch
+
+# neuron has torch version that doesn't even have impl_abstract
+if TYPE_CHECKING:
+    def register_fake(fn):
+        return lambda name: fn
+else:
+    try:
+        from torch.library import register_fake
+    except ImportError:
+        from torch.library import impl_abstract as register_fake
+
+try:
+    from ._ops import ops, add_op_namespace_prefix
+except ImportError as e:
+    # Fallback for local development.
+    try:
+        import _quantization
+
+        ops = torch.ops._quantization
+
+        def add_op_namespace_prefix(op_name: str):
+            return f"_quantization::{op_name}"
+    except ImportError:
+        raise e
+
+
+from .scalar_type import ScalarType
+
+
+# fp8 marlin
+def fp8_marlin_gemm(
+    a: torch.Tensor,
+    b_q_weight: torch.Tensor,
+    b_scales: torch.Tensor,
+    workspace: torch.Tensor,
+    num_bits: int,
+    size_m: int,
+    size_n: int,
+    size_k: int,
+) -> torch.Tensor:
+    return ops.fp8_marlin_gemm(
+        a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k
+    )
+
+
+# gptq_marlin
+def gptq_marlin_gemm(
+    a: torch.Tensor,
+    b_q_weight: torch.Tensor,
+    b_scales: torch.Tensor,
+    b_zeros: torch.Tensor,
+    g_idx: torch.Tensor,
+    perm: torch.Tensor,
+    workspace: torch.Tensor,
+    b_q_type: ScalarType,
+    size_m: int,
+    size_n: int,
+    size_k: int,
+    is_k_full: bool,
+    has_zp: bool = False,
+    use_fp32_reduce: bool = False,
+    is_zp_float: bool = False,
+) -> torch.Tensor:
+    return ops.gptq_marlin_gemm(
+        a,
+        b_q_weight,
+        b_scales,
+        b_zeros,
+        g_idx,
+        perm,
+        workspace,
+        b_q_type.id,
+        size_m,
+        size_n,
+        size_k,
+        is_k_full,
+        has_zp,
+        use_fp32_reduce,
+        is_zp_float,
+    )
+
+
+# gptq_marlin
+def gptq_marlin_repack(
+    b_q_weight: torch.Tensor,
+    perm: torch.Tensor,
+    size_k: int,
+    size_n: int,
+    num_bits: int,
+) -> torch.Tensor:
+    return ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits)
+
+
+# gptq_marlin
+def awq_marlin_repack(
+    b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
+) -> torch.Tensor:
+    return ops.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)
+
+
+# marlin
+def marlin_gemm(
+    a: torch.Tensor,
+    b_q_weight: torch.Tensor,
+    b_scales: torch.Tensor,
+    workspace: torch.Tensor,
+    size_m: int,
+    size_n: int,
+    size_k: int,
+) -> torch.Tensor:
+    return ops.marlin_gemm(
+        a, b_q_weight, b_scales, workspace, size_m, size_n, size_k
+    )
+
+
+# marlin_24
+def gptq_marlin_24_gemm(
+    a: torch.Tensor,
+    b_q_weight: torch.Tensor,
+    b_meta: torch.Tensor,
+    b_scales: torch.Tensor,
+    workspace: torch.Tensor,
+    b_q_type: ScalarType,
+    size_m: int,
+    size_n: int,
+    size_k: int,
+) -> torch.Tensor:
+    return ops.gptq_marlin_24_gemm(
+        a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, size_m, size_n, size_k
+    )
+
+
+# qqq ops
+def marlin_qqq_gemm(
+    a: torch.Tensor,
+    b_q_weight: torch.Tensor,
+    s_tok: torch.Tensor,
+    s_ch: torch.Tensor,
+    s_group: torch.Tensor,
+    workspace: torch.Tensor,
+    size_m: int,
+    size_n: int,
+    size_k: int,
+) -> torch.Tensor:
+    return ops.marlin_qqq_gemm(
+        a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k
+    )
+
+
+# Fake ops
+
+if hasattr(ops, "gptq_marlin_24_gemm"):
+    @register_fake(add_op_namespace_prefix("fp8_marlin_gemm"))
+    def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
+                              b_scales: torch.Tensor, workspace: torch.Tensor,
+                              num_bits: int, size_m: torch.SymInt,
+                              size_n: torch.SymInt,
+                              size_k: torch.SymInt) -> torch.Tensor:
+        return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)
+
+    @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm"))
+    def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
+                                  b_meta: torch.Tensor, b_scales: torch.Tensor,
+                                  workspace: torch.Tensor,
+                                  b_q_type: ScalarType, size_m: torch.SymInt,
+                                  size_n: torch.SymInt,
+                                  size_k: torch.SymInt) -> torch.Tensor:
+        return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
+
+    @register_fake(add_op_namespace_prefix("gptq_marlin_gemm"))
+    def _gptq_marlin_gemm_fake(a: torch.Tensor,
+                               b_q_weight: torch.Tensor,
+                               b_scales: torch.Tensor,
+                               b_zeros: torch.Tensor,
+                               g_idx: torch.Tensor,
+                               perm: torch.Tensor,
+                               workspace: torch.Tensor,
+                               b_q_type: ScalarType,
+                               size_m: torch.SymInt,
+                               size_n: torch.SymInt,
+                               size_k: torch.SymInt,
+                               is_k_full: bool,
+                               has_zp: bool = False,
+                               use_fp32_reduce: bool = False,
+                               is_zp_float: bool = False) -> torch.Tensor:
+        return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
+
+    @register_fake(add_op_namespace_prefix("marlin_qqq_gemm"))
+    def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
+                              s_tok: torch.Tensor, s_ch: torch.Tensor,
+                              s_group: torch.Tensor, workspace: torch.Tensor,
+                              size_m: torch.SymInt, size_n: torch.SymInt,
+                              size_k: torch.SymInt) -> torch.Tensor:
+        return torch.empty((size_m, size_n),
+                           dtype=torch.float16,
+                           device=a.device)
+
+    @register_fake(add_op_namespace_prefix("marlin_gemm"))
+    def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
+                          b_scales: torch.Tensor, workspace: torch.Tensor,
+                          size_m: torch.SymInt, size_n: torch.SymInt,
+                          size_k: torch.SymInt) -> torch.Tensor:
+        return torch.empty((size_m, size_n),
+                           dtype=torch.float16,
+                           device=a.device)
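All of the Marlin GEMM wrappers above take a caller-allocated int32 workspace used for global-reduction bookkeeping. A sizing sketch that mirrors marlin_make_workspace from the utils added below (the constants are the GPTQ-Marlin defaults; this helper itself is illustrative, not part of the commit):

import torch

GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MAX_PARALLEL = 16

def make_workspace(size_n: int, device: torch.device) -> torch.Tensor:
    # One counter slot per thread-N tile, times the maximum parallelism the kernel uses.
    slots = (size_n // GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL
    return torch.zeros(slots, dtype=torch.int, device=device)

workspace = make_workspace(4096, torch.device("cuda"))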
ext-torch/scalar_type.py ADDED
@@ -0,0 +1,330 @@
1
+ import functools
2
+ import struct
3
+ from dataclasses import dataclass
4
+ from enum import Enum
5
+ from typing import Optional, Union
6
+
7
+
8
+ # Mirrors enum in `core/scalar_type.hpp`
9
+ class NanRepr(Enum):
10
+ NONE = 0 # nans are not supported
11
+ IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s
12
+ EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s
13
+
14
+
15
+ # This ScalarType class is a parallel implementation of the C++ ScalarType
16
+ # class found in csrc/core/scalar_type.hpp. These two classes should be kept
17
+ # in sync until the inductor fully supports custom C++ classes.
18
+ @dataclass(frozen=True)
19
+ class ScalarType:
20
+ """
21
+ ScalarType can represent a wide range of floating point and integer
22
+ types, in particular it can be used to represent sub-byte data types
23
+ (something that torch.dtype currently does not support). It is also
24
+ capable of representing types with a bias, i.e.:
25
+ `stored_value = value + bias`,
26
+ this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
27
+ of 8). The implementation for this class can be found in
28
+ csrc/core/scalar_type.hpp, these type signatures should be kept in sync
29
+ with that file.
30
+ """
31
+
32
+ exponent: int
33
+ """
34
+ Number of bits in the exponent if this is a floating point type
35
+ (zero if this an integer type)
36
+ """
37
+
38
+ mantissa: int
39
+ """
40
+ Number of bits in the mantissa if this is a floating point type,
41
+ or the number bits representing an integer excluding the sign bit if
42
+ this an integer type.
43
+ """
44
+
45
+ signed: bool
46
+ "If the type is signed (i.e. has a sign bit)"
47
+
48
+ bias: int
49
+ """
50
+ bias used to encode the values in this scalar type
51
+ (value = stored_value - bias, default 0) for example if we store the
52
+ type as an unsigned integer with a bias of 128 then the value 0 will be
53
+ stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
54
+ """
55
+
56
+ _finite_values_only: bool = False
57
+ """
58
+ Private: if infs are supported, used `has_infs()` instead.
59
+ """
60
+
61
+ nan_repr: NanRepr = NanRepr.IEEE_754
62
+ """
63
+ How NaNs are represent in this scalar type, returns NanRepr value.
64
+ (not applicable for integer types)
65
+ """
66
+
67
+ def _floating_point_max_int(self) -> int:
68
+ assert (
69
+ self.mantissa <= 52 and self.exponent <= 11
70
+ ), f"Cannot represent max/min as a double for type {self.__str__()}"
71
+
72
+ max_mantissa = (1 << self.mantissa) - 1
73
+ if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
74
+ max_mantissa = max_mantissa - 1
75
+
76
+ max_exponent = (1 << self.exponent) - 2
77
+ if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN
78
+ or self.nan_repr == NanRepr.NONE):
79
+ assert (
80
+ self.exponent < 11
81
+ ), f"Cannot represent max/min as a double for type {self.__str__()}"
82
+ max_exponent = max_exponent + 1
83
+
84
+ # adjust the exponent to match that of a double
85
+ # for now we assume the exponent bias is the standard 2^(e-1) -1, (where
86
+ # e is the exponent bits), there is some precedent for non-standard
87
+ # biases, example `float8_e4m3b11fnuz` here:
88
+ # https://github.com/jax-ml/ml_dtypes but to avoid premature over
89
+ # complication we are just assuming the standard exponent bias until
90
+ # there is a need to support non-standard biases
91
+ exponent_bias = (1 << (self.exponent - 1)) - 1
92
+ exponent_bias_double = (1 << 10) - 1 # double e = 11
93
+
94
+ max_exponent_double = (max_exponent - exponent_bias +
95
+ exponent_bias_double)
96
+
97
+ # shift the mantissa and exponent into the proper positions for an
98
+ # IEEE double and bitwise-or them together.
99
+ return (max_mantissa <<
100
+ (52 - self.mantissa)) | (max_exponent_double << 52)
101
+
102
+ def _floating_point_max(self) -> float:
103
+ double_raw = self._floating_point_max_int()
104
+ return struct.unpack('!d', struct.pack('!Q', double_raw))[0]
105
+
106
+ def _raw_max(self) -> Union[int, float]:
107
+ if self.is_floating_point():
108
+ return self._floating_point_max()
109
+ else:
110
+ assert (self.size_bits < 64 or self.size_bits == 64
111
+ and self.is_signed()), "Cannot represent max as an int"
112
+ return (1 << self.mantissa) - 1
113
+
114
+ def _raw_min(self) -> Union[int, float]:
115
+ if self.is_floating_point():
116
+ assert self.is_signed(
117
+ ), "We currently assume all floating point types are signed"
118
+ sign_bit_double = 1 << 63
119
+
120
+ max_raw = self._floating_point_max_int()
121
+ min_raw = max_raw | sign_bit_double
122
+ return struct.unpack('!d', struct.pack('!Q', min_raw))[0]
123
+ else:
124
+ assert (not self.is_signed() or
125
+ self.size_bits <= 64), "Cannot represent min as a int64_t"
126
+
127
+ if self.is_signed():
128
+ return -(1 << (self.size_bits - 1))
129
+ else:
130
+ return 0
131
+
132
+ @functools.cached_property
133
+ def id(self) -> int:
134
+ """
135
+ Convert the ScalarType to an int which can be passed to pytorch custom
136
+ ops. This layout of the int must be kept in sync with the C++
137
+ ScalarType's from_id method.
138
+ """
139
+ val = 0
140
+ offset = 0
141
+
142
+ def or_and_advance(member, bit_width):
143
+ nonlocal val
144
+ nonlocal offset
145
+ bit_mask = (1 << bit_width) - 1
146
+ val = val | (int(member) & bit_mask) << offset
147
+ offset = offset + bit_width
148
+
149
+ or_and_advance(self.exponent, 8)
150
+ or_and_advance(self.mantissa, 8)
151
+ or_and_advance(self.signed, 1)
152
+ or_and_advance(self.bias, 32)
153
+ or_and_advance(self._finite_values_only, 1)
154
+ or_and_advance(self.nan_repr.value, 8)
155
+
156
+ assert offset <= 64, \
157
+ f"ScalarType fields too big {offset} to fit into an int64"
158
+
159
+ return val
160
+
161
+ @property
162
+ def size_bits(self) -> int:
163
+ return self.exponent + self.mantissa + int(self.signed)
164
+
165
+ def min(self) -> Union[int, float]:
166
+ """
167
+ Min representable value for this scalar type.
168
+ (accounting for bias if there is one)
169
+ """
170
+ return self._raw_min() - self.bias
171
+
172
+ def max(self) -> Union[int, float]:
173
+ """
174
+ Max representable value for this scalar type.
175
+ (accounting for bias if there is one)
176
+ """
177
+ return self._raw_max() - self.bias
178
+
179
+ def is_signed(self) -> bool:
180
+ """
181
+ If the type is signed (i.e. has a sign bit), same as `signed`
182
+ added for consistency with:
183
+ https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
184
+ """
185
+ return self.signed
186
+
187
+ def is_floating_point(self) -> bool:
188
+ "If the type is a floating point type"
189
+ return self.exponent != 0
190
+
191
+ def is_integer(self) -> bool:
192
+ "If the type is an integer type"
193
+ return self.exponent == 0
194
+
195
+ def has_bias(self) -> bool:
196
+ "If the type has a non-zero bias"
197
+ return self.bias != 0
198
+
199
+ def has_infs(self) -> bool:
200
+ "If the type is floating point and supports infinity"
201
+ return not self._finite_values_only
202
+
203
+ def has_nans(self) -> bool:
204
+ return self.nan_repr != NanRepr.NONE.value
205
+
206
+ def is_ieee_754(self) -> bool:
207
+ """
208
+ If the type is a floating point type that follows IEEE 754
209
+ conventions
210
+ """
211
+ return self.nan_repr == NanRepr.IEEE_754.value and \
212
+ not self._finite_values_only
213
+
214
+ def __str__(self) -> str:
215
+ """
216
+ naming generally follows: https://github.com/jax-ml/ml_dtypes
217
+ for floating point types (leading f) the scheme is:
218
+ `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
219
+ flags:
220
+ - no-flags: means it follows IEEE 754 conventions
221
+ - f: means finite values only (no infinities)
222
+ - n: means nans are supported (non-standard encoding)
223
+ for integer types the scheme is:
224
+ `[u]int<size_bits>[b<bias>]`
225
+ - if bias is not present it means its zero
226
+ """
227
+ if self.is_floating_point():
228
+ ret = "float" + str(self.size_bits) + "_e" + str(
229
+ self.exponent) + "m" + str(self.mantissa)
230
+
231
+ if not self.is_ieee_754():
232
+ if self._finite_values_only:
233
+ ret = ret + "f"
234
+ if self.nan_repr != NanRepr.NONE:
235
+ ret = ret + "n"
236
+
237
+ return ret
238
+ else:
239
+ ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
240
+ if self.has_bias():
241
+ ret = ret + "b" + str(self.bias)
242
+ return ret
243
+
244
+ def __repr__(self) -> str:
245
+ return "ScalarType." + self.__str__()
246
+
247
+ # __len__ needs to be defined (and has to throw TypeError) for pytorch's
248
+ # opcheck to work.
249
+ def __len__(self) -> int:
250
+ raise TypeError
251
+
252
+ #
253
+ # Convenience Constructors
254
+ #
255
+
256
+ @classmethod
257
+ def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
258
+ "Create a signed integer scalar type (size_bits includes sign-bit)."
259
+ ret = cls(0, size_bits - 1, True, bias if bias else 0)
260
+ ret.id # noqa B018: make sure the id is cached
261
+ return ret
262
+
263
+ @classmethod
264
+ def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
265
+ """Create a unsigned integer scalar type."""
266
+ ret = cls(0, size_bits, False, bias if bias else 0)
267
+ ret.id # noqa B018: make sure the id is cached
268
+ return ret
269
+
270
+ @classmethod
271
+ def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
272
+ """
273
+ Create a standard floating point type
274
+ (i.e. follows IEEE 754 conventions).
275
+ """
276
+ assert (mantissa > 0 and exponent > 0)
277
+ ret = cls(exponent, mantissa, True, 0)
278
+ ret.id # noqa B018: make sure the id is cached
279
+ return ret
280
+
281
+ @classmethod
282
+ def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
283
+ nan_repr: NanRepr) -> 'ScalarType':
284
+ """
285
+ Create a non-standard floating point type
286
+ (i.e. does not follow IEEE 754 conventions).
287
+ """
288
+ assert (mantissa > 0 and exponent > 0)
289
+ assert (nan_repr != NanRepr.IEEE_754), (
290
+ "use `float_IEEE754` constructor for floating point types that "
291
+ "follow IEEE 754 conventions")
292
+ ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
293
+ ret.id # noqa B018: make sure the id is cached
294
+ return ret
295
+
296
+
297
+ # naming generally follows: https://github.com/jax-ml/ml_dtypes
298
+ # for floating point types (leading f) the scheme is:
299
+ # `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
300
+ # flags:
301
+ # - no-flags: means it follows IEEE 754 conventions
302
+ # - f: means finite values only (no infinities)
303
+ # - n: means nans are supported (non-standard encoding)
304
+ # for integer types the scheme is:
305
+ # `[u]int<size_bits>[b<bias>]`
306
+ # - if bias is not present it means its zero
307
+
308
+
309
+ class scalar_types:
310
+ int4 = ScalarType.int_(4, None)
311
+ uint4 = ScalarType.uint(4, None)
312
+ int8 = ScalarType.int_(8, None)
313
+ uint8 = ScalarType.uint(8, None)
314
+ float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
315
+ float8_e5m2 = ScalarType.float_IEEE754(5, 2)
316
+ float16_e8m7 = ScalarType.float_IEEE754(8, 7)
317
+ float16_e5m10 = ScalarType.float_IEEE754(5, 10)
318
+
319
+ # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
320
+ float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)
321
+
322
+ # "gptq" types
323
+ uint2b2 = ScalarType.uint(2, 2)
324
+ uint3b4 = ScalarType.uint(3, 4)
325
+ uint4b8 = ScalarType.uint(4, 8)
326
+ uint8b128 = ScalarType.uint(8, 128)
327
+
328
+ # colloquial names
329
+ bfloat16 = float16_e8m7
330
+ float16 = float16_e5m10
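The Python ScalarType mirrors the C++ class and is what the Marlin wrappers pass down via its packed integer id. A small illustrative sketch for the GPTQ 4-bit type (the import path assumes the extension installs as the `quantization` package):

from quantization.scalar_type import scalar_types  # assumed import path

wtype = scalar_types.uint4b8        # 4-bit unsigned, stored with a bias of 8
print(wtype.size_bits)              # 4
print(wtype.min(), wtype.max())     # -8 7
print(hex(wtype.id))                # packed descriptor handed to the C++ kernels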
ext-torch/torch_binding.cpp CHANGED
@@ -65,16 +65,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
       "Tensor! workspace, int num_bits, SymInt size_m, SymInt size_n, "
       "SymInt size_k) -> Tensor");
-  ops.impl("fp8_marlin_gemm", &fp8_marlin_gemm);
 
   // awq_marlin repack from AWQ.
   ops.def(
       "awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
       "SymInt size_n, int num_bits) -> Tensor");
-  ops.impl("awq_marlin_repack", &awq_marlin_repack);
 
   // gptq_marlin Optimized Quantized GEMM for GPTQ.
-  ops.impl("gptq_marlin_gemm", &gptq_marlin_gemm);
   ops.def(
       "gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
       "Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
@@ -86,7 +83,36 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def(
       "gptq_marlin_repack(Tensor b_q_weight, Tensor perm, "
       "SymInt size_k, SymInt size_n, int num_bits) -> Tensor");
+
+  // Marlin (Dense) Optimized Quantized GEMM for GPTQ.
+  ops.def(
+      "marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
+      "Tensor! workspace, SymInt size_m, SymInt size_n, SymInt size_k) -> "
+      "Tensor");
+
+  // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
+  ops.def(
+      "gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, "
+      "Tensor b_scales, Tensor workspace, "
+      "int b_q_type, "
+      "SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor");
+
+  // marlin_qqq_gemm for QQQ.
+  ops.def(
+      "marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "
+      "Tensor s_tok, Tensor s_ch, Tensor s_group, "
+      "Tensor! workspace, SymInt size_m, SymInt size_n, "
+      "SymInt size_k) -> Tensor");
+}
+
+TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, ops) {
+  ops.impl("awq_marlin_repack", &awq_marlin_repack);
+  ops.impl("fp8_marlin_gemm", &fp8_marlin_gemm);
+  ops.impl("gptq_marlin_24_gemm", &gptq_marlin_24_gemm);
+  ops.impl("gptq_marlin_gemm", &gptq_marlin_gemm);
   ops.impl("gptq_marlin_repack", &gptq_marlin_repack);
+  ops.impl("marlin_gemm", &marlin_gemm);
+  ops.impl("marlin_qqq_gemm", &marlin_qqq_gemm);
 }
 
 TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, ops) {
ext-torch/torch_binding.h CHANGED
@@ -74,3 +74,26 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
 torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
                                       torch::Tensor& perm, c10::SymInt size_k,
                                       c10::SymInt size_n, int64_t num_bits);
+
+
+// Marlin
+
+torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
+                          torch::Tensor& b_scales, torch::Tensor& workspace,
+                          int64_t size_m, int64_t size_n, int64_t size_k);
+
+torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
+                                  torch::Tensor& b_meta,
+                                  torch::Tensor& b_scales,
+                                  torch::Tensor& workspace,
+                                  vllm::ScalarTypeId const b_q_type_id,
+                                  int64_t size_m, int64_t size_n,
+                                  int64_t size_k);
+
+torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
+                              torch::Tensor const& b_q_weight,
+                              torch::Tensor const& s_tok,
+                              torch::Tensor const& s_ch,
+                              torch::Tensor const& s_group,
+                              torch::Tensor& workspace, int64_t size_m,
+                              int64_t size_n, int64_t size_k);
ext-torch/utils/marlin_utils.py ADDED
@@ -0,0 +1,391 @@
1
+ from typing import List, Optional, Tuple
2
+
3
+ import numpy
4
+ import torch
5
+
6
+ import quantization as ops
7
+ from quantization.scalar_type import ScalarType, scalar_types
8
+
9
+ from .quant_utils import pack_cols, unpack_cols
10
+
11
+ GPTQ_MARLIN_TILE = 16
12
+ GPTQ_MARLIN_MIN_THREAD_N = 64
13
+ GPTQ_MARLIN_MIN_THREAD_K = 128
14
+ GPTQ_MARLIN_MAX_PARALLEL = 16
15
+
16
+ GPTQ_MARLIN_24_TILE = 16
17
+ GPTQ_MARLIN_24_MIN_THREAD_N = 128
18
+ GPTQ_MARLIN_24_MIN_THREAD_K = 128
19
+ GPTQ_MARLIN_24_MAX_PARALLEL = 64
20
+
21
+ GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
22
+ GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
23
+
24
+ MARLIN_QQQ_TILE = 16
25
+ MARLIN_QQQ_MIN_THREAD_N = 64
26
+ MARLIN_QQQ_MIN_THREAD_K = 128
27
+ MARLIN_QQQ_MAX_PARALLEL = 16
28
+
29
+ MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]
30
+ MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128]
31
+ MARLIN_QQQ_SUPPORTED_SYM = [True]
32
+
33
+ MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
34
+
35
+ # In case there is a performance issue with Marlin, the variable below can be
36
+ # changed to False, which allows Marlin to perform global reductions in fp16
37
+ # precision (instead of fp32), and therefore, save on some memory movements.
38
+ USE_FP32_REDUCE_DEFAULT = True
39
+
40
+
41
+ # For binary size and compile time, we don't support the same types for with and
42
+ # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
43
+ # TODO: we may want to move this into the C++ so its closer to the actual impl
44
+ def query_marlin_supported_quant_types(
45
+ has_zp: bool, device_capability: Optional[int] = None
46
+ ):
47
+ if device_capability is None:
48
+ capability_tuple = torch.cuda.get_device_capability()
49
+ device_capability = capability_tuple[0] * 10 + capability_tuple[1]
50
+
51
+ if device_capability < 80:
52
+ return []
53
+
54
+ if has_zp:
55
+ # AWQ style, unsigned + runtime zero-point
56
+ return [scalar_types.uint4, scalar_types.uint8]
57
+ else:
58
+ # GPTQ style, unsigned + symmetric bias
59
+ # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able
60
+ # to add `scalar_types.float8_e4m3fn` here
61
+ return [scalar_types.uint4b8, scalar_types.uint8b128]
62
+
63
+
64
+ def _check_marlin_supported(
65
+ quant_type: ScalarType,
66
+ group_size: Optional[int],
67
+ has_zp: bool,
68
+ device_capability: Optional[int] = None,
69
+ ) -> Tuple[bool, Optional[str]]:
70
+
71
+ if device_capability is None:
72
+ capability_tuple = torch.cuda.get_device_capability()
73
+ device_capability = capability_tuple[0] * 10 + capability_tuple[1]
74
+
75
+ supported_types = query_marlin_supported_quant_types(has_zp, device_capability)
76
+
77
+ if quant_type not in supported_types:
78
+ return (
79
+ False,
80
+ f"Marlin does not support weight_bits = {quant_type}. "
81
+ f"Only types = {supported_types} "
82
+ f"are supported (for group_size = {group_size}, "
83
+ f"device_capability = {device_capability}, zp = {has_zp}).",
84
+ )
85
+ if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
86
+ return (
87
+ False,
88
+ f"Marlin does not support group_size = {group_size}. "
89
+ f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
90
+ "are supported.",
91
+ )
92
+
93
+ return True, None
94
+
95
+
96
+ def check_marlin_supported(
97
+ quant_type: ScalarType,
98
+ group_size: int,
99
+ has_zp: bool = False,
100
+ device_capability: Optional[int] = None,
101
+ ) -> bool:
102
+ cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability)
103
+ return cond
104
+
105
+
106
+ def verify_marlin_supported(
107
+ quant_type: ScalarType, group_size: int, has_zp: bool = False
108
+ ) -> None:
109
+ cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp)
110
+ if not cond:
111
+ assert err_msg is not None
112
+ raise ValueError(err_msg)
113
+
114
+
115
+ def verify_marlin_supports_shape(
116
+ output_size_per_partition: int,
117
+ input_size_per_partition: int,
118
+ input_size: int,
119
+ group_size: int,
120
+ ) -> None:
121
+
122
+ # Validate output_size_per_partition
123
+ if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
124
+ raise ValueError(
125
+ f"Weight output_size_per_partition = "
126
+ f"{output_size_per_partition} is not divisible by "
127
+ f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
128
+ "Consider reducing tensor_parallel_size or running "
129
+ "with --quantization gptq."
130
+ )
131
+
132
+ # Validate input_size_per_partition
133
+ if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
134
+ raise ValueError(
135
+ f"Weight input_size_per_partition = "
136
+ f"{input_size_per_partition} is not divisible "
137
+ f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
138
+ "Consider reducing tensor_parallel_size or running "
139
+ "with --quantization gptq."
140
+ )
141
+
142
+ if group_size < input_size and input_size_per_partition % group_size != 0:
143
+ raise ValueError(
144
+ f"Weight input_size_per_partition = {input_size_per_partition}"
145
+ f" is not divisible by group_size = {group_size}."
146
+ "Consider reducing tensor_parallel_size or running "
147
+ "with --quantization gptq."
148
+ )
149
+
150
+
151
+ def check_marlin_supports_shape(
152
+ output_size_per_partition: int,
153
+ input_size_per_partition: int,
154
+ input_size: int,
155
+ group_size: int,
156
+ ) -> Tuple[bool, Optional[str]]:
157
+ try:
158
+ verify_marlin_supports_shape(
159
+ output_size_per_partition, input_size_per_partition, input_size, group_size
160
+ )
161
+ except ValueError as e:
162
+ return False, e.__str__()
163
+ return True, None
164
+
165
+
166
+ def marlin_make_workspace(
167
+ output_size_per_partition: int, device: torch.device
168
+ ) -> torch.Tensor:
169
+ max_workspace_size = (
170
+ output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N
171
+ ) * GPTQ_MARLIN_MAX_PARALLEL
172
+
173
+ return torch.zeros(
174
+ max_workspace_size, dtype=torch.int, device=device, requires_grad=False
175
+ )
176
+
177
+
178
+ def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
179
+ return (not act_order) or (act_order and not is_row_parallel)
180
+
181
+
182
+ def marlin_repeat_scales_on_all_ranks(
183
+ act_order: bool, group_size: int, is_row_parallel: bool
184
+ ) -> bool:
185
+ # Need to repeat scales on every rank if act_ordering or
186
+ # channelwise and RowParallelLinear
187
+ is_channelwise = group_size == -1
188
+ return act_order or (is_channelwise and is_row_parallel)
189
+
190
+
191
+ def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
192
+ return torch.nn.Parameter(
193
+ torch.empty(0, dtype=torch.int, device=device), requires_grad=False
194
+ )
195
+
196
+
197
+ def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
198
+ return torch.nn.Parameter(
199
+ torch.empty(0, dtype=torch.int, device=device), requires_grad=False
200
+ )
201
+
202
+
203
+ def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
204
+ g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
205
+ return g_idx[g_idx_sort_indices], g_idx_sort_indices
206
+
207
+
208
+ def get_scale_perms():
209
+ scale_perm: List[int] = []
210
+ for i in range(8):
211
+ scale_perm.extend([i + 8 * j for j in range(8)])
212
+ scale_perm_single: List[int] = []
213
+ for i in range(4):
214
+ scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
215
+ return scale_perm, scale_perm_single
216
+
217
+
218
+ def marlin_permute_scales(
219
+ s: torch.Tensor, size_k: int, size_n: int, group_size: int
220
+ ) -> torch.Tensor:
221
+
222
+ scale_perm, scale_perm_single = get_scale_perms()
223
+ if group_size < size_k and group_size != -1:
224
+ s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
225
+ else:
226
+ s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
227
+ s = s.reshape((-1, size_n)).contiguous()
228
+
229
+ return s
230
+
231
+
232
+ def marlin_moe_permute_scales(
233
+ s: torch.Tensor,
234
+ size_k: int,
235
+ size_n: int,
236
+ group_size: int,
237
+ ):
238
+ num_experts = s.shape[0]
239
+ output = torch.empty(
240
+ (num_experts, s.shape[1], s.shape[2]),
241
+ device=s.device,
242
+ dtype=s.dtype,
243
+ )
244
+
245
+ for e in range(num_experts):
246
+ output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size)
247
+ return output
248
+
249
+
250
+ def marlin_zero_points(
251
+ zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
252
+ ) -> torch.Tensor:
253
+ # Permute zero-points in a similar way to scales, but do not use the
254
+ # "single" permutation, since zero-points are applied on every MMA
255
+ scale_perm, _ = get_scale_perms()
256
+ zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm]
257
+
258
+ # Interleave column dim (for the dequantize code) and pack it to int32
259
+ if num_bits == 4:
260
+ interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
261
+ elif num_bits == 8:
262
+ interleave = numpy.array([0, 2, 1, 3])
263
+ else:
264
+ raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
265
+
266
+ zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
267
+ zp = zp.reshape((-1, size_n)).contiguous()
268
+ zp = pack_cols(zp, num_bits, size_k, size_n)
269
+
270
+ return zp
271
+
272
+
273
+ def awq_to_marlin_zero_points(
274
+ q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
275
+ ) -> torch.Tensor:
276
+ # AWQ zero-points are quantized and packed on the column dim.
277
+ # In addition, the values are permuted based on dequantizer.
278
+ # Here we undo both of these, and then apply marlin permutation
279
+ # and pack it back.
280
+ q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)
281
+
282
+ # Undo interleaving (use argsort(..) to get inverse perm)
283
+ if num_bits == 4:
284
+ undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7]))
285
+ elif num_bits == 8:
286
+ undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3]))
287
+ else:
288
+ raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
289
+
290
+ q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
291
+ q_zp = q_zp.reshape((-1, size_n)).contiguous()
292
+
293
+ marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
294
+ return marlin_zp
295
+
296
+
297
+ def moe_awq_to_marlin_zero_points(
298
+ q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
299
+ ):
300
+ num_experts = q_zp_packed.shape[0]
301
+ output = torch.empty(
302
+ (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]),
303
+ device=q_zp_packed.device,
304
+ dtype=q_zp_packed.dtype,
305
+ )
306
+ for e in range(num_experts):
307
+ output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits)
308
+ return output
309
+
310
+
311
+ def apply_gptq_marlin_linear(
312
+ input: torch.Tensor,
313
+ weight: torch.Tensor,
314
+ weight_scale: torch.Tensor,
315
+ weight_zp: torch.Tensor,
316
+ g_idx: torch.Tensor,
317
+ g_idx_sort_indices: torch.Tensor,
318
+ workspace: torch.Tensor,
319
+ wtype: ScalarType,
320
+ output_size_per_partition: int,
321
+ input_size_per_partition: int,
322
+ is_k_full: bool,
323
+ bias: Optional[torch.Tensor] = None,
324
+ use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
325
+ ) -> torch.Tensor:
326
+ reshaped_x = input.reshape(-1, input.shape[-1])
327
+ out_shape = input.shape[:-1] + (output_size_per_partition,)
328
+
329
+ output = ops.gptq_marlin_gemm(
330
+ reshaped_x,
331
+ weight,
332
+ weight_scale,
333
+ weight_zp,
334
+ g_idx,
335
+ g_idx_sort_indices,
336
+ workspace,
337
+ wtype,
338
+ size_m=reshaped_x.shape[0],
339
+ size_n=output_size_per_partition,
340
+ size_k=input_size_per_partition,
341
+ is_k_full=is_k_full,
342
+ has_zp=False,
343
+ use_fp32_reduce=use_fp32_reduce,
344
+ is_zp_float=False,
345
+ )
346
+
347
+ if bias is not None:
348
+ output.add_(bias) # In-place add
349
+
350
+ return output.reshape(out_shape)
351
+
352
+
353
+ def apply_awq_marlin_linear(
354
+ input: torch.Tensor,
355
+ weight: torch.Tensor,
356
+ weight_scale: torch.Tensor,
357
+ weight_zp: torch.Tensor,
358
+ g_idx: torch.Tensor,
359
+ g_idx_sort_indices: torch.Tensor,
360
+ workspace: torch.Tensor,
361
+ quant_type: ScalarType,
362
+ output_size_per_partition: int,
363
+ input_size_per_partition: int,
364
+ bias: Optional[torch.Tensor] = None,
365
+ use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
366
+ ) -> torch.Tensor:
367
+ reshaped_x = input.reshape(-1, input.shape[-1])
368
+ out_shape = input.shape[:-1] + (output_size_per_partition,)
369
+
370
+ output = ops.gptq_marlin_gemm(
371
+ reshaped_x,
372
+ weight,
373
+ weight_scale,
374
+ weight_zp,
375
+ g_idx,
376
+ g_idx_sort_indices,
377
+ workspace,
378
+ quant_type,
379
+ size_m=reshaped_x.shape[0],
380
+ size_n=output_size_per_partition,
381
+ size_k=input_size_per_partition,
382
+ is_k_full=True,
383
+ has_zp=True,
384
+ use_fp32_reduce=use_fp32_reduce,
385
+ is_zp_float=False,
386
+ )
387
+
388
+ if bias is not None:
389
+ output.add_(bias) # In-place add
390
+
391
+ return output.reshape(out_shape)
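A quick sketch of the support checks above; values are illustrative, the import paths assume the extension installs as the `quantization` package, and passing an explicit device_capability avoids needing a GPU at call time:

from quantization.scalar_type import scalar_types            # assumed import path
from quantization.utils.marlin_utils import (                 # assumed import path
    check_marlin_supported,
    check_marlin_supports_shape,
)

ok = check_marlin_supported(scalar_types.uint4b8, group_size=128,
                            has_zp=False, device_capability=90)

ok_shape, err = check_marlin_supports_shape(
    output_size_per_partition=4096,
    input_size_per_partition=4096,
    input_size=4096,
    group_size=128,
)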
ext-torch/utils/marlin_utils_fp8.py ADDED
@@ -0,0 +1,100 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+
5
+ import quantization as ops
6
+
7
+ from .marlin_utils import marlin_make_workspace, marlin_permute_scales
8
+
9
+
10
+ def is_fp8_marlin_supported():
11
+ capability = torch.cuda.get_device_capability()
12
+ capability = capability[0] * 10 + capability[1]
13
+ return capability >= 80
14
+
15
+
16
+ def apply_fp8_marlin_linear(
17
+ input: torch.Tensor,
18
+ weight: torch.Tensor,
19
+ weight_scale: torch.Tensor,
20
+ workspace: torch.Tensor,
21
+ size_n: int,
22
+ size_k: int,
23
+ bias: Optional[torch.Tensor],
24
+ ) -> torch.Tensor:
25
+ # For GPUs that lack FP8 hardware support, we can leverage the
26
+ # Marlin kernel for fast weight-only FP8 quantization
27
+
28
+ reshaped_x = input.reshape(-1, input.shape[-1])
29
+ out_shape = input.shape[:-1] + (size_n,)
30
+
31
+ output = ops.fp8_marlin_gemm(
32
+ a=reshaped_x,
33
+ b_q_weight=weight,
34
+ b_scales=weight_scale,
35
+ workspace=workspace,
36
+ num_bits=8,
37
+ size_m=reshaped_x.shape[0],
38
+ size_n=size_n,
39
+ size_k=size_k,
40
+ )
41
+
42
+ if bias is not None:
43
+ output.add_(bias) # In-place add
44
+
45
+ return output.reshape(out_shape)
46
+
47
+
48
+ def prepare_fp8_layer_for_marlin(
49
+ layer: torch.nn.Module, strategy: str = "tensor"
50
+ ) -> None:
51
+ part_size_n = layer.output_size_per_partition
52
+ part_size_k = layer.input_size_per_partition
53
+
54
+ device = layer.weight.device
55
+
56
+ # WORKSPACE
57
+ layer.workspace = marlin_make_workspace(part_size_n, device)
58
+
59
+ # WEIGHT
60
+ # Repack weights to marlin format
61
+ marlin_qweight = ops.gptq_marlin_repack(
62
+ b_q_weight=pack_fp8_to_int32(layer.weight),
63
+ perm=torch.empty(0, dtype=torch.int, device=device),
64
+ size_k=part_size_k,
65
+ size_n=part_size_n,
66
+ num_bits=8,
67
+ )
68
+ layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
69
+
70
+ # WEIGHT SCALES
71
+ scales = layer.weight_scale.to(layer.orig_dtype)
72
+ # Permute scales
73
+ marlin_scales = marlin_permute_scales(
74
+ s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1
75
+ )
76
+ layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
77
+
78
+
79
+ def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
80
+ """
81
+ Repack FP8 weights to gptq format (packed int32 elements)
82
+ """
83
+ assert fp8_tensor.dtype == torch.float8_e4m3fn
84
+ assert fp8_tensor.shape[0] % 4 == 0
85
+
86
+ # Reshape to prepare for packing
87
+ reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
88
+
89
+ # Convert fp8 to uint8 (byte) representation
90
+ byte_tensor = reshaped.view(torch.uint8)
91
+
92
+ # Pack 4 uint8 values into one int32
93
+ packed = (
94
+ byte_tensor[:, 0].to(torch.int32)
95
+ | (byte_tensor[:, 1].to(torch.int32) << 8)
96
+ | (byte_tensor[:, 2].to(torch.int32) << 16)
97
+ | (byte_tensor[:, 3].to(torch.int32) << 24)
98
+ )
99
+
100
+ return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous()
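pack_fp8_to_int32 above is a pure byte-packing step: every four consecutive FP8 rows collapse into one int32 row, with row i of each group landing in bits [8*i, 8*i + 8). A minimal sketch of the same bit layout using plain uint8 bytes as stand-ins for FP8 values (so it runs even without float8_e4m3fn support):

import torch

group = torch.tensor([[0x11, 0x22],
                      [0x33, 0x44],
                      [0x55, 0x66],
                      [0x77, 0x78]], dtype=torch.uint8)  # one group of four "fp8" rows
packed = (group[0].to(torch.int32)
          | (group[1].to(torch.int32) << 8)
          | (group[2].to(torch.int32) << 16)
          | (group[3].to(torch.int32) << 24))
# packed == tensor([0x77553311, 0x78664422]): four bytes per int32, first row in the low byte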
ext-torch/utils/marlin_utils_test.py ADDED
@@ -0,0 +1,162 @@
1
+ """Utility functions used for tests and benchmarks"""
2
+
3
+ from typing import List, Optional
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+ from quantization.scalar_type import ScalarType
9
+
10
+ from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, marlin_zero_points
11
+ from .quant_utils import (
12
+ get_pack_factor,
13
+ gptq_quantize_weights,
14
+ quantize_weights,
15
+ sort_weights,
16
+ )
17
+
18
+
19
+ class MarlinWorkspace:
20
+
21
+ def __init__(self, out_features, min_thread_n, max_parallel):
22
+ assert (
23
+ out_features % min_thread_n == 0
24
+ ), "out_features = {} is undivisible by min_thread_n = {}".format(
25
+ out_features, min_thread_n
26
+ )
27
+
28
+ max_workspace_size = (out_features // min_thread_n) * max_parallel
29
+
30
+ self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda")
31
+
32
+
33
+ def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
34
+ assert q_w.shape == (size_k, size_n)
35
+ assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
36
+ assert size_n % tile == 0, f"size_n = {size_n}, tile = {tile}"
37
+
38
+ # Permute weights to 16x64 marlin tiles
39
+ q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
40
+ q_w = q_w.permute((0, 2, 1, 3))
41
+ q_w = q_w.reshape((size_k // tile, size_n * tile))
42
+
43
+ q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)
44
+
45
+ return q_w
46
+
47
+
48
+ def marlin_weights(q_w, size_k, size_n, num_bits, perm):
49
+ # Permute
50
+ q_w = marlin_permute_weights(q_w, size_k, size_n, perm)
51
+
52
+ # Pack
53
+ pack_factor = get_pack_factor(num_bits)
54
+ orig_device = q_w.device
55
+
56
+ q_w = q_w.cpu().numpy().astype(np.uint32)
57
+
58
+ q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32)
59
+ for i in range(pack_factor):
60
+ q_packed |= q_w[:, i::pack_factor] << num_bits * i
61
+
62
+ q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device)
63
+
64
+ return q_packed
65
+
66
+
67
+ def get_weight_perm(num_bits: int):
68
+ perm_list: List[int] = []
69
+ for i in range(32):
70
+ perm1: List[int] = []
71
+ col = i // 4
72
+ for block in [0, 1]:
73
+ for row in [
74
+ 2 * (i % 4),
75
+ 2 * (i % 4) + 1,
76
+ 2 * (i % 4 + 4),
77
+ 2 * (i % 4 + 4) + 1,
78
+ ]:
79
+ perm1.append(16 * row + col + 8 * block)
80
+ for j in range(4):
81
+ perm_list.extend([p + 256 * j for p in perm1])
82
+
83
+ perm = np.array(perm_list)
84
+
85
+ if num_bits == 4:
86
+ interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
87
+ elif num_bits == 8:
88
+ interleave = np.array([0, 2, 1, 3])
89
+ else:
90
+ raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
91
+
92
+ perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
93
+ perm = torch.from_numpy(perm)
94
+ return perm
95
+
96
+
97
+ def marlin_quantize(
98
+ w: torch.Tensor,
99
+ quant_type: ScalarType,
100
+ group_size: int,
101
+ act_order: bool,
102
+ test_perm: Optional[torch.Tensor] = None,
103
+ ):
104
+ size_k, size_n = w.shape
105
+ num_bits = quant_type.size_bits
106
+
107
+ # Normalize group_size
108
+ if group_size == -1:
109
+ group_size = size_k
110
+ assert group_size <= size_k
111
+
112
+ # Quantize (and apply act_order if provided)
113
+ w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
114
+ w, quant_type, group_size, act_order, test_perm
115
+ )
116
+
117
+ # For act_order, sort the "weights" and "g_idx" so that group ids are
118
+ # increasing
119
+ sort_indices = torch.empty(0, dtype=torch.int, device=w.device)
120
+ if act_order:
121
+ q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
122
+
123
+ # Reformat to marlin
124
+ weight_perm = get_weight_perm(num_bits)
125
+ marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
126
+ marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
127
+
128
+ # Create result
129
+ res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
130
+ for i in range(len(res_list)):
131
+ res_list[i] = res_list[i].to(w.device)
132
+
133
+ return res_list
134
+
135
+
136
+ def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int):
137
+ size_k, size_n = w.shape
138
+
139
+ # Normalize group_size
140
+ if group_size == -1:
141
+ group_size = size_k
142
+ assert group_size <= size_k
143
+
144
+ # Detect num groups
145
+ assert size_k % group_size == 0
146
+ num_groups = size_k // group_size
147
+
148
+ # Quantize with zp
149
+ w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True)
150
+
151
+ # Reformat to marlin
152
+ weight_perm = get_weight_perm(quant_type.size_bits)
153
+ marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm)
154
+ marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
155
+ marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits)
156
+
157
+ # Create result
158
+ res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp]
159
+ for i in range(len(res_list)):
160
+ res_list[i] = res_list[i].to(w.device)
161
+
162
+ return res_list
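The packing loop inside marlin_weights stores pack_factor = 32 // num_bits quantized values per uint32, with column j of each group occupying bits [num_bits*j, num_bits*(j+1)). A tiny worked example of that loop for num_bits = 4 and a single row of made-up values:

import numpy as np

num_bits = 4
pack_factor = 32 // num_bits                        # 8 nibbles per uint32
q_w = np.arange(8, dtype=np.uint32).reshape(1, 8)   # values 0..7
q_packed = np.zeros((1, 1), dtype=np.uint32)
for i in range(pack_factor):
    q_packed |= q_w[:, i::pack_factor] << num_bits * i
assert q_packed[0, 0] == 0x76543210                 # column j sits in bits [4*j, 4*j + 4)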
ext-torch/utils/marlin_utils_test_24.py ADDED
@@ -0,0 +1,473 @@
1
+ """Utility functions used for tests and benchmarks"""
2
+
3
+ import random
4
+ from typing import List
5
+
6
+ import numpy
7
+ import torch
8
+
9
+ from quantization.scalar_type import ScalarType
10
+
11
+ from .marlin_utils_test import marlin_weights
12
+ from .quant_utils import gptq_quantize_weights
13
+
14
+
15
+ # This is PyTorch implementation of main part of reorder_meta()
16
+ # function, from tools/util/include/cutlass/util/host_reorder.h file
17
+ # of CUTLASS source tree. Furthermore, CUTLASS template for sparse
18
+ # GEMM decides upon layout of this matrix, and at the moment for the
19
+ # sparse GEMM executed on tensor cores, this is layout described by
20
+ # ColumnMajorInterleaved<2> data structure, in
21
+ # include/cutlass/layout/matrix.h of CUTLASS source tree. The
22
+ # reordering of meta matrix into meta_reordered matrix calculated
23
+ # according to these segments of CUTLASS code is re-implemented here.
24
+ # Note that this calculation produces offsets for scattering metadata
25
+ # matrix elements into reordered metadata matrix elements (or,
26
+ # equivalently, for gathering reordered metadata matrix element back
27
+ # into metadata matrix elements).
28
+ def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device):
29
+ dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)
30
+ dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)
31
+
32
+ # Reorder the rows, then swizzle the 2x2 blocks.
33
+ group_x = 64
34
+ group_y = 32 if meta_dtype.itemsize == 2 else 16
35
+
36
+ dst_rows = (
37
+ dst_rows // group_x * group_x
38
+ + (dst_rows % 2) * 2
39
+ + (dst_rows % 8) // 4
40
+ + ((dst_rows % group_y) % 4) // 2 * 32
41
+ + ((dst_rows % group_x) // 8) * 4
42
+ )
43
+
44
+ topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8)
45
+ bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8)
46
+ dst_rows += topright - bottomleft
47
+ dst_cols -= topright - bottomleft
48
+
49
+ # Assumed that meta tensor is to be stored in CUTLASS
50
+ # InterleavedColumnMajor layout, and reverse engineered
51
+ # corresponding code to store values into this tensor.
52
+ interleave = 2
53
+ cols_maj = dst_cols // interleave
54
+ cols_min = dst_cols % interleave
55
+ return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1)
56
+
57
+
58
+ # This function converts dense matrix into sparse semi-structured
59
+ # representation, producing "compressed" matrix, in the layout used by
60
+ # CUTLASS backend, and corresponding metadata matrix.
61
+ def sparse_semi_structured_from_dense_cutlass(dense):
62
+ if dense.dim() != 2:
63
+ raise RuntimeError(
64
+ f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501
65
+ )
66
+
67
+ m, k = dense.shape
68
+ device = dense.device
69
+
70
+ meta_dtype = torch.int8
71
+ if dense.dtype == torch.int8:
72
+ meta_dtype = torch.int32
73
+ elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]:
74
+ meta_dtype = torch.int16
75
+ else:
76
+ raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
77
+ quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
78
+ if quadbits_per_meta_elem not in (4, 8):
79
+ raise RuntimeError("Invalid number of elements per meta element calculated")
80
+
81
+ if meta_dtype == torch.int32:
82
+ if m % 16 != 0:
83
+ raise RuntimeError(
84
+ f"Number of rows of dense matrix {m} must be divisible by 16"
85
+ )
86
+ else:
87
+ if m % 32 != 0:
88
+ raise RuntimeError(
89
+ f"Number of rows of dense matrix {m} must be divisible by 32"
90
+ )
91
+ if k % (4 * quadbits_per_meta_elem) != 0:
92
+ raise RuntimeError(
93
+ f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501
94
+ )
95
+
96
+ if dense.dtype != torch.float:
97
+ ksparse = 4
98
+ dense_4 = dense.view(-1, k // ksparse, ksparse)
99
+ m0, m1, m2, m3 = (dense_4 != 0).unbind(-1)
100
+ else:
101
+ ksparse = 2
102
+ dense_2 = dense.view(-1, k // ksparse, ksparse)
103
+ m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1)
104
+ meta_ncols = k // (ksparse * quadbits_per_meta_elem)
105
+
106
+ # Encoding quadruples of True/False values as follows:
107
+ # [True, True, False, False] -> 0b0100
108
+ # [True, False, True, False] -> 0b1000
109
+ # [False, True, True, False] -> 0b1001
110
+ # [True, False, False, True ] -> 0b1100
111
+ # [False, True, False, True ] -> 0b1101
112
+ # [False, False, True, True ] -> 0b1110
113
+ # Thus, lower two bits in the encoding are index of the True value
114
+ # at the lowest index in the quadruple, and the higher two bits in
115
+ # the encoding are index of the other True value in the quadruple.
116
+ # In case there are less than two True values, than False value or
117
+ # values at some index or indices are considered True for the
118
+ # encoding. In case there are more than two True values, then the
119
+ # excess True value(s) at some indices are considered False for
120
+ # the encoding. The exact encodings used for these cases are as
121
+ # follows:
122
+ # [False, False, False, False] -> 0b1110
123
+ # [False, False, False, True ] -> 0b1110
124
+ # [False, False, True, False] -> 0b1110
125
+ # [False, True, False, False] -> 0b1001
126
+ # [False, True, True, True ] -> 0b1101
127
+ # [True, False, False, False] -> 0b1000
128
+ # [True, False, True, True ] -> 0b1100
129
+ # [True, True, False, True ] -> 0b0100
130
+ # [True, True, True, False] -> 0b0100
131
+ # [True, True, True, True ] -> 0b0100
132
+ # These particular encodings are chosen, with the help of Espresso
133
+ # logic minimizer software, for the purpose of minimization of
134
+ # corresponding Boolean functions, that translate non-zero flags
135
+ # into encoding bits. Note also possible choices for the first
136
+ # and last of these encodings were limited only to (0b0100,
137
+ # 0b1110), in order to produce valid encodings for 1:2 sparsity
138
+ # case.
139
+
140
+ expr0 = m0 & m1
141
+ expr1 = ~m0 & m1
142
+ expr2 = ~m0 & ~m1
143
+ bit0 = expr1
144
+ bit1 = expr2
145
+ bit2 = expr0 | expr2 | m3
146
+ bit3 = expr1 | ~m1
147
+ idxs0 = bit0 | (bit1.to(torch.int64) << 1)
148
+ idxs1 = bit2 | (bit3.to(torch.int64) << 1)
149
+
150
+ if dense.dtype != torch.float:
151
+ sparse0 = dense_4.gather(
152
+ -1, idxs0.unsqueeze(-1)
153
+ ) # type: ignore[possibly-undefined]
154
+ sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
155
+ sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
156
+ else:
157
+ sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(
158
+ m, k // 2
159
+ ) # type: ignore[possibly-undefined]
160
+
161
+ meta_4 = idxs0 | (idxs1 << 2)
162
+ meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)
163
+
164
+ if quadbits_per_meta_elem == 4:
165
+ meta = (
166
+ meta_n[:, :, 0]
167
+ | (meta_n[:, :, 1] << 4)
168
+ | (meta_n[:, :, 2] << 8)
169
+ | (meta_n[:, :, 3] << 12)
170
+ )
171
+ elif quadbits_per_meta_elem == 8:
172
+ meta = (
173
+ meta_n[:, :, 0]
174
+ | (meta_n[:, :, 1] << 4)
175
+ | (meta_n[:, :, 2] << 8)
176
+ | (meta_n[:, :, 3] << 12)
177
+ | (meta_n[:, :, 4] << 16)
178
+ | (meta_n[:, :, 5] << 20)
179
+ | (meta_n[:, :, 6] << 24)
180
+ | (meta_n[:, :, 7] << 28)
181
+ )
182
+
183
+ # Reorder meta tensor elements.
184
+ meta_reordered = meta.new_empty(
185
+ (m * meta_ncols,)
186
+ ) # type: ignore[possibly-undefined]
187
+ meta_offsets = _calculate_meta_reordering_scatter_offsets(
188
+ m, meta_ncols, meta_dtype, device
189
+ )
190
+ meta_reordered.scatter_(0, meta_offsets, meta.view(-1))
191
+
192
+ return (sparse, meta_reordered.view(m, meta_ncols))
193
+
194
+
195
+ # This function performs reverse of the function above - it
196
+ # reconstructs dense matrix from a pair of "compressed" matrix, given
197
+ # in the layout used by CUTLASS backend, and accompanying metadata
198
+ # matrix.
199
+ def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
200
+ if sparse.dim() != 2:
201
+ raise RuntimeError(
202
+ f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501
203
+ )
204
+
205
+ m, k = sparse.shape
206
+ device = sparse.device
207
+
208
+ if meta_reordered.dim() != 2:
209
+ raise RuntimeError(
210
+ f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501
211
+ )
212
+ if meta_reordered.device != device:
213
+ raise RuntimeError(
214
+ f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501
215
+ )
216
+
217
+ meta_dtype = meta_reordered.dtype
218
+ if meta_dtype not in (torch.int16, torch.int32):
219
+ raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix")
220
+ quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
221
+
222
+ ksparse = 4 if sparse.dtype != torch.float else 2
223
+
224
+ meta_nrows, meta_ncols = meta_reordered.shape
225
+ if meta_nrows != m:
226
+ raise RuntimeError(
227
+ f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501
228
+ )
229
+ if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k:
230
+ raise RuntimeError(
231
+ f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501
232
+ "expected according to the number of columns of meta matrix"
233
+ )
234
+
235
+ # Undo meta tensor elements reordering.
236
+ meta_offsets = _calculate_meta_reordering_scatter_offsets(
237
+ m, meta_ncols, meta_dtype, device
238
+ )
239
+ meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols)
240
+
241
+ # Unpack sparse tensor back to original dense tensor, using
242
+ # information provided by meta tensor. Note that torch.float
243
+ # datatype is handled pretty much the same as
244
+ # torch.half/torch.bfloat16, as metadata for a pair of torch.float
245
+ # value is encoded as if underlying 8 bytes contain four
246
+ # torch.half/torch.bfloat16 values, where either first two or last
247
+ # two are zeros.
248
+ meta_2 = torch.empty(
249
+ (m, meta_ncols, 2 * quadbits_per_meta_elem),
250
+ dtype=meta_dtype,
251
+ device=device,
252
+ )
253
+ if quadbits_per_meta_elem == 4:
254
+ meta_2[:, :, 0] = meta & 0b11
255
+ meta_2[:, :, 1] = (meta >> 2) & 0b11
256
+ meta_2[:, :, 2] = (meta >> 4) & 0b11
257
+ meta_2[:, :, 3] = (meta >> 6) & 0b11
258
+ meta_2[:, :, 4] = (meta >> 8) & 0b11
259
+ meta_2[:, :, 5] = (meta >> 10) & 0b11
260
+ meta_2[:, :, 6] = (meta >> 12) & 0b11
261
+ meta_2[:, :, 7] = (meta >> 14) & 0b11
262
+ elif quadbits_per_meta_elem == 8:
263
+ meta_2[:, :, 0] = meta & 0b11
264
+ meta_2[:, :, 1] = (meta >> 2) & 0b11
265
+ meta_2[:, :, 2] = (meta >> 4) & 0b11
266
+ meta_2[:, :, 3] = (meta >> 6) & 0b11
267
+ meta_2[:, :, 4] = (meta >> 8) & 0b11
268
+ meta_2[:, :, 5] = (meta >> 10) & 0b11
269
+ meta_2[:, :, 6] = (meta >> 12) & 0b11
270
+ meta_2[:, :, 7] = (meta >> 14) & 0b11
271
+ meta_2[:, :, 8] = (meta >> 16) & 0b11
272
+ meta_2[:, :, 9] = (meta >> 18) & 0b11
273
+ meta_2[:, :, 10] = (meta >> 20) & 0b11
274
+ meta_2[:, :, 11] = (meta >> 22) & 0b11
275
+ meta_2[:, :, 12] = (meta >> 24) & 0b11
276
+ meta_2[:, :, 13] = (meta >> 26) & 0b11
277
+ meta_2[:, :, 14] = (meta >> 28) & 0b11
278
+ meta_2[:, :, 15] = (meta >> 30) & 0b11
279
+
280
+ dense_offsets = meta_2.view(-1) + (
281
+ torch.arange(0, 2 * m * k // ksparse, device=device) * 4
282
+ ).view(-1, 1).repeat(1, 2).view(-1)
283
+
284
+ dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device)
285
+ if sparse.dtype != torch.float:
286
+ # dense.scatter_(0, dense_offsets, sparse.view(-1))
287
+ dense.scatter_(0, dense_offsets, sparse.reshape(-1))
288
+ else:
289
+ dense.view(torch.half).scatter_(
290
+ 0, dense_offsets, sparse.view(torch.half).view(-1)
291
+ )
292
+
293
+ return dense.view(m, 2 * k)
294
+
295
+
296
+ def mask_creator(tensor):
297
+ """
298
+ Create N:M sparsity masks.
299
+ Masks will be created using the N:M ratio, where for every block of
300
+ M weights, M - N will be pruned based on ranked weight value (keeping the N largest). Each mask
301
+ will correspond to the given tensor.
302
+
303
+ :param N: The number of weights in a group to keep
304
+ :param M: The size of a weight group
305
+ """
306
+ N = 2
307
+ M = 4
308
+
309
+ mask = None
310
+ # for i, tensor in enumerate(tensors):
311
+ if tensor.numel() % M != 0:
312
+ raise ValueError(
313
+ f"Tensor of size {tensor.shape} can't be evenly divided into " f"{M} groups"
314
+ )
315
+
316
+ num_groups = tensor.numel() // M
317
+
318
+ # N:M sparsity for linear layers
319
+ tensor_temp = tensor.detach().abs().reshape(num_groups, M)
320
+ index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)]
321
+
322
+ w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device)
323
+ mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape)
324
+
325
+ return mask
326
+
327
+
328
+ def inject_24(w, size_k, size_n):
329
+ assert w.shape == (size_k, size_n)
330
+
331
+ mask = mask_creator(w.t()).t().cuda().bool()
332
+
333
+ return (mask * w).contiguous(), mask.contiguous()
334
+
335
+
336
+ def check_24(w, num_rows_to_sample=50, _verbose=False):
337
+ BLOCK_SIZE = 4
338
+ MAX_NON_ZEROS = 2
339
+
340
+ w = w.t().contiguous()
341
+
342
+ print("check_24: w.shape = {}".format(w.shape))
343
+
344
+ num_rows, num_cols = w.shape
345
+ sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample)
346
+ if _verbose:
347
+ print(f"Sampled row idxs = {sampled_row_idxs}")
348
+
349
+ total_segments = 0
350
+ non_24_segments = 0
351
+ for i in sampled_row_idxs:
352
+ for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE):
353
+ total_segments += 1
354
+ block = w[i, j : j + BLOCK_SIZE]
355
+ num_nonzero = torch.count_nonzero(block)
356
+ if num_nonzero > MAX_NON_ZEROS:
357
+ print("i = {} j = {} block = {}".format(i, j, block))
358
+ non_24_segments += 1
359
+
360
+ print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.")
361
+
362
+
363
+ def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType):
364
+ assert q_24.shape == (size_k, size_n)
365
+
366
+ # Remove bias to normalize over 0
367
+ q_24_no_zp = q_24 - wtype.bias
368
+
369
+ # Compress
370
+ q_24_no_zp = q_24_no_zp.t().contiguous()
371
+ q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(q_24_no_zp)
372
+ q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous()
373
+
374
+ # Restore bias
375
+ q_24_comp = q_24_no_zp_comp + wtype.bias
376
+
377
+ # Resize meta to its actual shape (without moving any data)
378
+ meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)
379
+
380
+ return q_24_comp, meta
381
+
382
+
383
+ def get_scale_perms_24():
384
+ scale_perm: List[int] = []
385
+ for i in range(8):
386
+ scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]])
387
+ scale_perm_single: List[int] = []
388
+ for i in range(8):
389
+ scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]])
390
+ return scale_perm, scale_perm_single
391
+
392
+
393
+ def get_weight_perm_24(num_bits: int):
394
+ perm_list: List[int] = []
395
+ for i in range(32):
396
+ perm1: List[int] = []
397
+ col = i // 4
398
+ col_o = col // 2
399
+ for block in [0, 1]:
400
+ for row in [
401
+ 2 * (i % 4),
402
+ 2 * (i % 4) + 1,
403
+ 2 * (i % 4 + 4),
404
+ 2 * (i % 4 + 4) + 1,
405
+ ]:
406
+ perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block)
407
+ for j in range(4):
408
+ perm_list.extend([p + 1 * j for p in perm1])
409
+ perm = numpy.array(perm_list)
410
+
411
+ if num_bits == 4:
412
+ interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
413
+ elif num_bits == 8:
414
+ interleave = numpy.array([0, 2, 1, 3])
415
+ else:
416
+ raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))
417
+
418
+ perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
419
+ perm = torch.from_numpy(perm)
420
+ return perm
421
+
422
+
423
+ def marlin_permute_scales_24(
424
+ s: torch.Tensor, size_k: int, size_n: int, group_size: int
425
+ ) -> torch.Tensor:
426
+
427
+ scale_perm, scale_perm_single = get_scale_perms_24()
428
+ if group_size < size_k and group_size != -1:
429
+ s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
430
+ else:
431
+ s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
432
+ s = s.reshape((-1, size_n)).contiguous()
433
+
434
+ return s
435
+
436
+
437
+ def marlin_24_quantize(
438
+ w: torch.Tensor,
439
+ quant_type: ScalarType,
440
+ group_size: int,
441
+ ):
442
+ size_k, size_n = w.shape
443
+
444
+ # Normalize group_size
445
+ if group_size == -1:
446
+ group_size = size_k
447
+ assert group_size <= size_k
448
+
449
+ # Inject 2:4 sparsity
450
+ w_24, mask_24 = inject_24(w, size_k, size_n)
451
+
452
+ # Quantize
453
+ w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights(
454
+ w_24, quant_type, group_size, act_order=False
455
+ )
456
+
457
+ # Compress quantized weight
458
+ q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, quant_type)
459
+ size_k_comp = size_k // 2
460
+
461
+ # Reformat to marlin
462
+ weight_perm = get_weight_perm_24(quant_type.size_bits)
463
+ marlin_24_q_w_comp = marlin_weights(
464
+ q_w_24_comp, size_k_comp, size_n, quant_type.size_bits, weight_perm
465
+ )
466
+ marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size)
467
+
468
+ # Create result
469
+ res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s]
470
+ for i in range(len(res_list)):
471
+ res_list[i] = res_list[i].to(w.device)
472
+
473
+ return res_list
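mask_creator above implements the 2:4 pattern by ranking magnitudes inside every group of M = 4 weights and zeroing the M - N = 2 smallest. A small sketch of that selection on a single made-up row:

import torch

w = torch.tensor([[0.9, -0.1, 0.4, 0.05, -0.7, 0.2, 0.0, 0.3]])
groups = w.detach().abs().reshape(-1, 4)                 # blocks of M = 4
drop = torch.argsort(groups, dim=1)[:, :2]               # indices of the M - N smallest
mask = torch.ones_like(groups).scatter_(dim=1, index=drop, value=0).reshape(w.shape)
# mask == tensor([[1., 0., 1., 0., 1., 0., 0., 1.]]): two survivors per block of four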
ext-torch/utils/marlin_utils_test_qqq.py ADDED
@@ -0,0 +1,125 @@
1
+ from typing import List
2
+
3
+ import numpy
4
+ import torch
5
+
6
+ from .marlin_utils_test import marlin_permute_weights
7
+ from .quant_utils import get_pack_factor, qqq_quantize_weights
8
+
9
+
10
+ def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size):
11
+ # Permute
12
+ q_w = marlin_permute_weights(q_w, size_k, size_n, perm)
13
+
14
+ # Pack
15
+ pack_factor = get_pack_factor(num_bits)
16
+ orig_device = q_w.device
17
+
18
+ q_w = q_w.cpu().numpy().astype(numpy.uint32)
19
+
20
+ q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor),
21
+ dtype=numpy.uint32)
22
+ if group_size == size_k:
23
+ for i in range(pack_factor):
24
+ q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i
25
+ else:
26
+ for i in range(pack_factor):
27
+ q_packed |= q_w[:, i::pack_factor] << num_bits * i
28
+
29
+ q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device)
30
+
31
+ return q_packed
32
+
33
+
34
+ def get_qqq_scale_perms():
35
+ scale_perm: List[int] = []
36
+ for i in range(8):
37
+ scale_perm.extend([i + 8 * j for j in range(8)])
38
+ scale_perm_single: List[int] = []
39
+ for i in range(4):
40
+ scale_perm_single.extend(
41
+ [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
42
+ return scale_perm, scale_perm_single
43
+
44
+
45
+ # NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. # noqa: E501
46
+ def get_qqq_weight_perm(num_bits: int, quant_type: str):
47
+ perm_list: List[int] = []
48
+ for i in range(32):
49
+ perm1: List[int] = []
50
+ col = i // 4
51
+ for block in [0, 1]:
52
+ for row in [
53
+ 4 * (i % 4),
54
+ 4 * (i % 4) + 1,
55
+ 4 * (i % 4) + 2,
56
+ 4 * (i % 4) + 3,
57
+ ]:
58
+ perm1.append(16 * row + col + 8 * block)
59
+ for j in range(4):
60
+ perm_list.extend([p + 256 * j for p in perm1])
61
+
62
+ perm = numpy.array(perm_list)
63
+
64
+ assert quant_type in ["per-channel",
65
+ "per-group"], "not supported quantization type"
66
+ if num_bits == 4:
67
+ if quant_type == "per-channel":
68
+ interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3])
69
+ else:
70
+ interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
71
+ else:
72
+ raise Exception("num_bits must be 4, got {}".format(num_bits))
73
+
74
+ perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
75
+ perm = torch.from_numpy(perm)
76
+ return perm
77
+
78
+
79
+ def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size):
80
+ scale_perm, scale_perm_single = get_qqq_scale_perms()
81
+ if group_size < size_k and group_size != -1:
82
+ s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm]
83
+ s_channel = s_channel.reshape(
84
+ (-1, len(scale_perm_single)))[:, scale_perm_single]
85
+ s_group = s_group.reshape((-1, size_n)).contiguous()
86
+ else:
87
+ s_channel = s_channel.reshape(
88
+ (-1, len(scale_perm_single)))[:, scale_perm_single]
89
+ s_channel = s_channel.reshape((-1, size_n)).contiguous()
90
+
91
+ return s_group, s_channel
92
+
93
+
94
+ def marlin_qqq_quantize(
95
+ w: torch.Tensor,
96
+ num_bits: int,
97
+ group_size: int,
98
+ ):
99
+ size_k, size_n = w.shape
100
+
101
+ # Normalize group_size
102
+ if group_size == -1:
103
+ group_size = size_k
104
+ assert group_size <= size_k
105
+ quant_type = "per-channel" if group_size == size_k else "per-group"
106
+
107
+ # Quantize
108
+ w_ref, q_w, s_group, s_channel = qqq_quantize_weights(
109
+ w, num_bits, group_size)
110
+
111
+ # Reformat to marlin_qqq
112
+ weight_perm = get_qqq_weight_perm(num_bits, quant_type)
113
+ marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits,
114
+ weight_perm, group_size)
115
+ marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales(
116
+ s_group, s_channel, size_k, size_n, group_size)
117
+
118
+ # Create result
119
+ res_list = [
120
+ w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel
121
+ ]
122
+ for i in range(len(res_list)):
123
+ res_list[i] = res_list[i].to(w.device)
124
+
125
+ return res_list
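One detail worth noting in marlin_qqq_weights: the per-group branch packs values that are already non-negative, but in the per-channel case (group_size == size_k) qqq_quantize_weights produces signed values in [-7, 7], so the & 0xF mask keeps just the two's-complement nibble of each value before packing. For example:

import numpy as np

q = np.array([-3, 5], dtype=np.int32).astype(np.uint32)  # signed values after the uint32 cast
print([hex(int(v) & 0xF) for v in q])                    # ['0xd', '0x5']: -3 is stored as 0b1101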
ext-torch/utils/quant_utils.py ADDED
@@ -0,0 +1,470 @@
1
+ """This file is used for /tests and /benchmarks"""
2
+
3
+ from typing import List, Optional
4
+
5
+ import numpy
6
+ import torch
7
+
8
+ from quantization.scalar_type import ScalarType, scalar_types
9
+
10
+ SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
11
+ SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
12
+
13
+ MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]
14
+
15
+ # Note: this is a hack. We should update each model to register the
16
+ # stacked params and get it from there instead in a future PR.
17
+ # fused_name: List[shard_name]
18
+ FUSED_LAYER_NAME_MAPPING = {
19
+ "qkv_proj": ["q_proj", "k_proj", "v_proj"],
20
+ "gate_up_proj": ["gate_proj", "up_proj"],
21
+ }
22
+
23
+
24
+ def pack_quantized_values_into_int32(
25
+ w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
26
+ ):
27
+ # move dim to pack to the end
28
+ perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
29
+ inv_perm = tuple(perm.index(i) for i in range(len(perm)))
30
+ w_q_perm = w_q.permute(perm)
31
+
32
+ pack_factor = 32 // wtype.size_bits
33
+ mask = (1 << wtype.size_bits) - 1
34
+
35
+ new_shape_perm = list(w_q_perm.shape)
36
+ assert w_q_perm.shape[-1] % pack_factor == 0
37
+ new_shape_perm[-1] //= pack_factor
38
+
39
+ res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
40
+ for i in range(pack_factor):
41
+ res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i
42
+
43
+ return res.permute(inv_perm)
44
+
45
+
46
+ def unpack_quantized_values_into_int32(
47
+ w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
48
+ ):
49
+ # move dim to pack to the end
50
+ perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
51
+ inv_perm = tuple(perm.index(i) for i in range(len(perm)))
52
+ w_q_perm = w_q.permute(perm)
53
+
54
+ pack_factor = 32 // wtype.size_bits
55
+ mask = (1 << wtype.size_bits) - 1
56
+
57
+ new_shape_perm = list(w_q_perm.shape)
58
+ new_shape_perm[-1] *= pack_factor
59
+
60
+ res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
61
+ for i in range(pack_factor):
62
+ res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask
63
+
64
+ return res.permute(inv_perm)
65
+
66
+
67
+ def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
68
+ # prefix: model.layers.0.self_attn.q_proj
69
+ # proj_name: q_proj
70
+ proj_name = prefix.split(".")[-1]
71
+ if proj_name in FUSED_LAYER_NAME_MAPPING:
72
+ shard_prefixes = [
73
+ prefix.replace(proj_name, shard_proj_name)
74
+ for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name]
75
+ ]
76
+
77
+ is_skipped = None
78
+ for shard_prefix in shard_prefixes:
79
+ is_shard_skipped = shard_prefix in ignored_layers
80
+
81
+ if is_skipped is None:
82
+ is_skipped = is_shard_skipped
83
+ elif is_shard_skipped != is_skipped:
84
+ raise ValueError(
85
+ f"Detected some but not all shards of {prefix} "
86
+ "are quantized. All shards of fused layers "
87
+ "to have the same precision."
88
+ )
89
+ else:
90
+ is_skipped = prefix in ignored_layers
91
+
92
+ assert is_skipped is not None
93
+ return is_skipped
94
+
95
+
96
+ def get_pack_factor(num_bits):
97
+ assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
98
+ return 32 // num_bits
99
+
100
+
101
+ def permute_rows(
102
+ q_w: torch.Tensor,
103
+ w_ref: torch.Tensor,
104
+ group_size: int,
105
+ test_perm: Optional[torch.Tensor] = None,
106
+ ):
107
+ assert q_w.shape == w_ref.shape
108
+
109
+ orig_device = q_w.device
110
+ k_size, _ = q_w.shape
111
+
112
+ g_idx = torch.zeros((k_size,), dtype=torch.int32)
113
+ for i in range(k_size):
114
+ g_idx[i] = i // group_size
115
+
116
+ # Simulate act_order by doing a random permutation on K
117
+ rand_perm = test_perm if test_perm is not None else torch.randperm(k_size)
118
+
119
+ g_idx = g_idx[rand_perm].contiguous()
120
+ q_w = q_w[rand_perm, :].contiguous()
121
+ w_ref = w_ref[rand_perm, :].contiguous()
122
+
123
+ return (
124
+ w_ref.to(device=orig_device),
125
+ q_w.to(device=orig_device),
126
+ g_idx.to(device=orig_device),
127
+ rand_perm.to(device=orig_device),
128
+ )
129
+
130
+
131
+ def quantize_weights(
132
+ w: torch.Tensor,
133
+ quant_type: ScalarType,
134
+ group_size: Optional[int],
135
+ zero_points: bool = False,
136
+ ref_zero_points_after_scales: bool = False,
137
+ ):
138
+ assert (
139
+ quant_type.is_integer()
140
+ ), "Floating point quantization may work but has not been tested"
141
+ assert not zero_points or group_size is not None, (
142
+ "to have group zero points, group_size must be provided "
143
+ "(-1 group_size is channelwise)"
144
+ )
145
+
146
+ orig_device = w.device
147
+ orig_type = w.dtype
148
+ size_k, size_n = w.shape
149
+
150
+ assert w.is_floating_point(), "w must be float"
151
+
152
+ if group_size == -1:
153
+ group_size = size_k
154
+
155
+ # Reshape to [groupsize, -1]
156
+ if group_size is not None and group_size < size_k:
157
+ w = w.reshape((-1, group_size, size_n))
158
+ w = w.permute(1, 0, 2)
159
+ w = w.reshape((group_size, -1))
160
+
161
+ # Compute scale for each group
162
+ max_val = torch.max(w, 0, keepdim=True).values
163
+ min_val = torch.min(w, 0, keepdim=True).values
164
+
165
+ max_q_val = quant_type.max()
166
+ min_q_val = quant_type.min()
167
+
168
+ w_s = torch.Tensor([1.0]).to(w.device) # unscaled case
169
+ maybe_w_zp = None
170
+ if group_size is not None:
171
+ if zero_points:
172
+ assert not quant_type.is_signed() and quant_type.max() > 0
173
+ w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
174
+ maybe_w_zp = (
175
+ torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
176
+ )
177
+ else:
178
+ # If the bias is such that there are no possible negative/positive
179
+ # values, set the max value to inf to avoid divide by 0
180
+ w_s = torch.max(
181
+ abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
182
+ abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
183
+ )
184
+
185
+ # Quantize
186
+ w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
187
+ w_q = torch.clamp(w_q, min_q_val, max_q_val)
188
+
189
+ # Compute ref (dequantized)
190
+ # For some kernels (namely Machete) the zero-points are applied after the
191
+ # scales are applied, for this case computing the reference in similar way
192
+ # allows us to use tighter error tolerances in our unit tests.
193
+ if ref_zero_points_after_scales and maybe_w_zp is not None:
194
+ w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
195
+ else:
196
+ w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
197
+
198
+ if quant_type.has_bias():
199
+ w_q += quant_type.bias
200
+
201
+ # Restore original shapes
202
+ if group_size is not None and group_size < size_k:
203
+
204
+ def reshape_w(w):
205
+ w = w.reshape((group_size, -1, size_n))
206
+ w = w.permute(1, 0, 2)
207
+ w = w.reshape((size_k, size_n)).contiguous()
208
+ return w
209
+
210
+ w_q = reshape_w(w_q)
211
+ w_ref = reshape_w(w_ref)
212
+ w_s = w_s.reshape((-1, size_n)).contiguous()
213
+
214
+ if maybe_w_zp is not None:
215
+ maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
216
+ maybe_w_zp = maybe_w_zp.to(device=orig_device)
217
+
218
+ return (
219
+ w_ref.to(device=orig_device),
220
+ w_q.to(device=orig_device),
221
+ w_s if group_size is not None else None,
222
+ maybe_w_zp,
223
+ )
224
+
225
+
226
+ def gptq_quantize_weights(
227
+ w: torch.Tensor,
228
+ quant_type: ScalarType,
229
+ group_size: int,
230
+ act_order: bool,
231
+ test_perm: Optional[torch.Tensor] = None,
232
+ ):
233
+ size_k, _ = w.shape
234
+
235
+ assert w.is_floating_point(), "w must be float"
236
+ assert (
237
+ quant_type in SUPPORTED_GPTQ_QUANT_TYPES
238
+ ), f"Unsupported gptq type = {quant_type}"
239
+ assert group_size in SUPPORTED_GROUP_SIZES + [
240
+ size_k
241
+ ], f"Unsupported groupsize = {group_size}"
242
+
243
+ w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size)
244
+
245
+ # Apply act_order
246
+ g_idx = torch.empty(0, dtype=torch.int, device=w.device)
247
+ rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
248
+ if act_order:
249
+ assert (
250
+ group_size < size_k
251
+ ), "For act_order, groupsize = {} must be less than size_k = {}".format(
252
+ group_size, size_k
253
+ )
254
+
255
+ w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm)
256
+
257
+ return w_ref, w_q, w_s, g_idx, rand_perm
258
+
259
+
260
+ # QQQ employs different quant schemes for per-group and
261
+ # per-channel quantization.
262
+ def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int):
263
+ orig_device = w.device
264
+ size_k, size_n = w.shape
265
+
266
+ assert w.is_floating_point(), "w must be float"
267
+ assert (
268
+ num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS
269
+ ), f"Unsupported num_bits = {num_bits}"
270
+ assert group_size in SUPPORTED_GROUP_SIZES + [
271
+ size_k
272
+ ], f"Unsupported groupsize = {group_size}"
273
+
274
+ if group_size == -1:
275
+ group_size = size_k
276
+ assert group_size <= size_k
277
+
278
+ if group_size < size_k:
279
+ # Reshape to [groupsize, -1]
280
+ w = w.reshape((-1, group_size, size_n))
281
+ w = w.permute(1, 0, 2)
282
+ w = w.reshape((group_size, -1))
283
+
284
+ max_q_val = 2**num_bits - 1
285
+ half_q_val = (max_q_val + 1) // 2
286
+
287
+ # Compute scale for each group
288
+ s_group = torch.max(torch.abs(w), 0, keepdim=True)[0]
289
+ s_group *= 2 / max_q_val # 2 => symmetric
290
+
291
+ # Quantize
292
+ q_w = torch.round(w / s_group).int()
293
+ q_w += half_q_val
294
+ q_w = torch.clamp(q_w, 0, max_q_val)
295
+ # Compute ref (dequantized)
296
+ w_ref = (q_w - half_q_val).half() * s_group
297
+
298
+ # Restore original shapes
299
+ def reshape_w(w):
300
+ w = w.reshape((group_size, -1, size_n))
301
+ w = w.permute(1, 0, 2)
302
+ w = w.reshape((size_k, size_n)).contiguous()
303
+ return w
304
+
305
+ q_w = reshape_w(q_w)
306
+ w_ref = reshape_w(w_ref)
307
+
308
+ # Compute int8 quantization scale for each channel
309
+ s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0]
310
+ s_channel /= 127.0
311
+ t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8)
312
+ w_ref = t_int8.half() * s_channel
313
+ s_channel = s_channel.reshape(1, -1).to(dtype=torch.float)
314
+
315
+ # Fuse scales
316
+ s_group = (s_group.reshape(-1, size_n).contiguous() / s_channel).to(
317
+ dtype=torch.half
318
+ )
319
+ else:
320
+ max_q_val = 2 ** (num_bits - 1) - 1
321
+
322
+ # Compute scale for each channel
323
+ s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0]
324
+ s_channel /= max_q_val
325
+
326
+ # Quantize
327
+ q_w = torch.round(w / s_channel).int()
328
+ q_w = torch.clamp(q_w, -max_q_val, max_q_val)
329
+ # Compute ref (dequantized)
330
+ w_ref = q_w.half() * s_channel
331
+
332
+ s_group = torch.tensor([], dtype=torch.half)
333
+ # div 2 ** (8 - self.bits)) to offset right shift in unpacking
334
+ s_channel /= 2 ** (8 - num_bits)
335
+ s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float)
336
+
337
+ return (
338
+ w_ref.to(device=orig_device),
339
+ q_w.to(device=orig_device),
340
+ s_group.to(device=orig_device),
341
+ s_channel.to(device=orig_device),
342
+ )
343
+
344
+
345
+ def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
346
+ orig_device = q_w.device
347
+
348
+ sort_indices = torch.argsort(g_idx).to(dtype=torch.int32) # Sort based on g_idx
349
+
350
+ g_idx = g_idx[sort_indices].contiguous()
351
+ q_w = q_w[sort_indices, :].contiguous()
352
+
353
+ return (
354
+ q_w.to(device=orig_device),
355
+ g_idx.to(device=orig_device),
356
+ sort_indices.to(device=orig_device),
357
+ )
358
+
359
+
360
+ def pack_rows(
361
+ q_w: torch.Tensor,
362
+ num_bits: int,
363
+ size_k: int,
364
+ size_n: int,
365
+ ):
366
+ assert q_w.shape == (size_k, size_n)
367
+
368
+ pack_factor = get_pack_factor(num_bits)
369
+ assert size_k % pack_factor == 0
370
+
371
+ orig_device = q_w.device
372
+
373
+ q_w = q_w.cpu().numpy().astype(numpy.uint32)
374
+
375
+ q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)
376
+
377
+ for i in range(pack_factor):
378
+ q_res |= q_w[i::pack_factor, :] << num_bits * i
379
+
380
+ q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
381
+ return q_res
382
+
383
+
384
+ def pack_cols(
385
+ q_w: torch.Tensor,
386
+ num_bits: int,
387
+ size_k: int,
388
+ size_n: int,
389
+ ):
390
+ assert q_w.shape == (size_k, size_n)
391
+
392
+ pack_factor = get_pack_factor(num_bits)
393
+ assert size_n % pack_factor == 0
394
+
395
+ orig_device = q_w.device
396
+
397
+ q_w = q_w.cpu().numpy().astype(numpy.uint32)
398
+
399
+ q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
400
+
401
+ for i in range(pack_factor):
402
+ q_res |= q_w[:, i::pack_factor] << num_bits * i
403
+
404
+ q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
405
+ q_res = q_res.contiguous()
406
+
407
+ return q_res
408
+
409
+
410
+ def unpack_cols(
411
+ packed_q_w: torch.Tensor,
412
+ num_bits: int,
413
+ size_k: int,
414
+ size_n: int,
415
+ ):
416
+ pack_factor = get_pack_factor(num_bits)
417
+ assert size_n % pack_factor == 0
418
+ assert packed_q_w.shape == (
419
+ size_k,
420
+ size_n // pack_factor,
421
+ ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
422
+ packed_q_w.shape, size_k, size_n, pack_factor
423
+ )
424
+
425
+ orig_device = packed_q_w.device
426
+
427
+ packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
428
+ q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
429
+
430
+ mask = (1 << num_bits) - 1
431
+ for i in range(pack_factor):
432
+ vals = packed_q_w_cpu & mask
433
+ packed_q_w_cpu >>= num_bits
434
+ q_res[:, i::pack_factor] = vals
435
+
436
+ q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
437
+ q_res = q_res.contiguous()
438
+
439
+ return q_res
440
+
441
+
442
+ def gptq_pack(
443
+ q_w: torch.Tensor,
444
+ num_bits: int,
445
+ size_k: int,
446
+ size_n: int,
447
+ ):
448
+ return pack_rows(q_w, num_bits, size_k, size_n)
449
+
450
+
451
+ def awq_pack(
452
+ q_w: torch.Tensor,
453
+ num_bits: int,
454
+ size_k: int,
455
+ size_n: int,
456
+ ):
457
+ assert q_w.shape == (size_k, size_n)
458
+
459
+ # Interleave column dim (for the dequantize code) and pack it to int32
460
+ if num_bits == 4:
461
+ interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
462
+ elif num_bits == 8:
463
+ interleave = numpy.array([0, 2, 1, 3])
464
+ else:
465
+ raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
466
+
467
+ q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel()
468
+ q_w = q_w.reshape((-1, size_n)).contiguous()
469
+
470
+ return pack_cols(q_w, num_bits, size_k, size_n)
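As a reference for the quantization math above, a simplified channelwise walk-through of the symmetric branch of quantize_weights (group_size == size_k, no zero points), assuming the uint4b8 type, i.e. a logical range of [-8, 7] with storage bias 8:

import torch

w = torch.randn(128, 64)
max_q_val, min_q_val, bias = 7.0, -8.0, 8
max_val = w.max(dim=0, keepdim=True).values
min_val = w.min(dim=0, keepdim=True).values
w_s = torch.max(abs(max_val / max_q_val), abs(min_val / min_q_val))  # per-channel scale
w_q = torch.clamp(torch.round(w / w_s), min_q_val, max_q_val)
w_ref = w_q * w_s                             # dequantized reference, as returned above
w_q = (w_q + bias).to(torch.int32)            # biased storage values in [0, 15]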
marlin/dense/LICENSE ADDED
@@ -0,0 +1,209 @@
1
+ Contains code from https://github.com/IST-DASLab/marlin
2
+
3
+ Apache License
4
+ Version 2.0, January 2004
5
+ http://www.apache.org/licenses/
6
+
7
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
8
+
9
+ 1. Definitions.
10
+
11
+ "License" shall mean the terms and conditions for use, reproduction,
12
+ and distribution as defined by Sections 1 through 9 of this document.
13
+
14
+ "Licensor" shall mean the copyright owner or entity authorized by
15
+ the copyright owner that is granting the License.
16
+
17
+ "Legal Entity" shall mean the union of the acting entity and all
18
+ other entities that control, are controlled by, or are under common
19
+ control with that entity. For the purposes of this definition,
20
+ "control" means (i) the power, direct or indirect, to cause the
21
+ direction or management of such entity, whether by contract or
22
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
23
+ outstanding shares, or (iii) beneficial ownership of such entity.
24
+
25
+ "You" (or "Your") shall mean an individual or Legal Entity
26
+ exercising permissions granted by this License.
27
+
28
+ "Source" form shall mean the preferred form for making modifications,
29
+ including but not limited to software source code, documentation
30
+ source, and configuration files.
31
+
32
+ "Object" form shall mean any form resulting from mechanical
33
+ transformation or translation of a Source form, including but
34
+ not limited to compiled object code, generated documentation,
35
+ and conversions to other media types.
36
+
37
+ "Work" shall mean the work of authorship, whether in Source or
38
+ Object form, made available under the License, as indicated by a
39
+ copyright notice that is included in or attached to the work
40
+ (an example is provided in the Appendix below).
41
+
42
+ "Derivative Works" shall mean any work, whether in Source or Object
43
+ form, that is based on (or derived from) the Work and for which the
44
+ editorial revisions, annotations, elaborations, or other modifications
45
+ represent, as a whole, an original work of authorship. For the purposes
46
+ of this License, Derivative Works shall not include works that remain
47
+ separable from, or merely link (or bind by name) to the interfaces of,
48
+ the Work and Derivative Works thereof.
49
+
50
+ "Contribution" shall mean any work of authorship, including
51
+ the original version of the Work and any modifications or additions
52
+ to that Work or Derivative Works thereof, that is intentionally
53
+ submitted to Licensor for inclusion in the Work by the copyright owner
54
+ or by an individual or Legal Entity authorized to submit on behalf of
55
+ the copyright owner. For the purposes of this definition, "submitted"
56
+ means any form of electronic, verbal, or written communication sent
57
+ to the Licensor or its representatives, including but not limited to
58
+ communication on electronic mailing lists, source code control systems,
59
+ and issue tracking systems that are managed by, or on behalf of, the
60
+ Licensor for the purpose of discussing and improving the Work, but
61
+ excluding communication that is conspicuously marked or otherwise
62
+ designated in writing by the copyright owner as "Not a Contribution."
63
+
64
+ "Contributor" shall mean Licensor and any individual or Legal Entity
65
+ on behalf of whom a Contribution has been received by Licensor and
66
+ subsequently incorporated within the Work.
67
+
68
+ 2. Grant of Copyright License. Subject to the terms and conditions of
69
+ this License, each Contributor hereby grants to You a perpetual,
70
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
71
+ copyright license to reproduce, prepare Derivative Works of,
72
+ publicly display, publicly perform, sublicense, and distribute the
73
+ Work and such Derivative Works in Source or Object form.
74
+
75
+ 3. Grant of Patent License. Subject to the terms and conditions of
76
+ this License, each Contributor hereby grants to You a perpetual,
77
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
78
+ (except as stated in this section) patent license to make, have made,
79
+ use, offer to sell, sell, import, and otherwise transfer the Work,
80
+ where such license applies only to those patent claims licensable
81
+ by such Contributor that are necessarily infringed by their
82
+ Contribution(s) alone or by combination of their Contribution(s)
83
+ with the Work to which such Contribution(s) was submitted. If You
84
+ institute patent litigation against any entity (including a
85
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
86
+ or a Contribution incorporated within the Work constitutes direct
87
+ or contributory patent infringement, then any patent licenses
88
+ granted to You under this License for that Work shall terminate
89
+ as of the date such litigation is filed.
90
+
91
+ 4. Redistribution. You may reproduce and distribute copies of the
92
+ Work or Derivative Works thereof in any medium, with or without
93
+ modifications, and in Source or Object form, provided that You
94
+ meet the following conditions:
95
+
96
+ (a) You must give any other recipients of the Work or
97
+ Derivative Works a copy of this License; and
98
+
99
+ (b) You must cause any modified files to carry prominent notices
100
+ stating that You changed the files; and
101
+
102
+ (c) You must retain, in the Source form of any Derivative Works
103
+ that You distribute, all copyright, patent, trademark, and
104
+ attribution notices from the Source form of the Work,
105
+ excluding those notices that do not pertain to any part of
106
+ the Derivative Works; and
107
+
108
+ (d) If the Work includes a "NOTICE" text file as part of its
109
+ distribution, then any Derivative Works that You distribute must
110
+ include a readable copy of the attribution notices contained
111
+ within such NOTICE file, excluding those notices that do not
112
+ pertain to any part of the Derivative Works, in at least one
113
+ of the following places: within a NOTICE text file distributed
114
+ as part of the Derivative Works; within the Source form or
115
+ documentation, if provided along with the Derivative Works; or,
116
+ within a display generated by the Derivative Works, if and
117
+ wherever such third-party notices normally appear. The contents
118
+ of the NOTICE file are for informational purposes only and
119
+ do not modify the License. You may add Your own attribution
120
+ notices within Derivative Works that You distribute, alongside
121
+ or as an addendum to the NOTICE text from the Work, provided
122
+ that such additional attribution notices cannot be construed
123
+ as modifying the License.
124
+
125
+ You may add Your own copyright statement to Your modifications and
126
+ may provide additional or different license terms and conditions
127
+ for use, reproduction, or distribution of Your modifications, or
128
+ for any such Derivative Works as a whole, provided Your use,
129
+ reproduction, and distribution of the Work otherwise complies with
130
+ the conditions stated in this License.
131
+
132
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
133
+ any Contribution intentionally submitted for inclusion in the Work
134
+ by You to the Licensor shall be under the terms and conditions of
135
+ this License, without any additional terms or conditions.
136
+ Notwithstanding the above, nothing herein shall supersede or modify
137
+ the terms of any separate license agreement you may have executed
138
+ with Licensor regarding such Contributions.
139
+
140
+ 6. Trademarks. This License does not grant permission to use the trade
141
+ names, trademarks, service marks, or product names of the Licensor,
142
+ except as required for reasonable and customary use in describing the
143
+ origin of the Work and reproducing the content of the NOTICE file.
144
+
145
+ 7. Disclaimer of Warranty. Unless required by applicable law or
146
+ agreed to in writing, Licensor provides the Work (and each
147
+ Contributor provides its Contributions) on an "AS IS" BASIS,
148
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
149
+ implied, including, without limitation, any warranties or conditions
150
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
151
+ PARTICULAR PURPOSE. You are solely responsible for determining the
152
+ appropriateness of using or redistributing the Work and assume any
153
+ risks associated with Your exercise of permissions under this License.
154
+
155
+ 8. Limitation of Liability. In no event and under no legal theory,
156
+ whether in tort (including negligence), contract, or otherwise,
157
+ unless required by applicable law (such as deliberate and grossly
158
+ negligent acts) or agreed to in writing, shall any Contributor be
159
+ liable to You for damages, including any direct, indirect, special,
160
+ incidental, or consequential damages of any character arising as a
161
+ result of this License or out of the use or inability to use the
162
+ Work (including but not limited to damages for loss of goodwill,
163
+ work stoppage, computer failure or malfunction, or any and all
164
+ other commercial damages or losses), even if such Contributor
165
+ has been advised of the possibility of such damages.
166
+
167
+ 9. Accepting Warranty or Additional Liability. While redistributing
168
+ the Work or Derivative Works thereof, You may choose to offer,
169
+ and charge a fee for, acceptance of support, warranty, indemnity,
170
+ or other liability obligations and/or rights consistent with this
171
+ License. However, in accepting such obligations, You may act only
172
+ on Your own behalf and on Your sole responsibility, not on behalf
173
+ of any other Contributor, and only if You agree to indemnify,
174
+ defend, and hold each Contributor harmless for any liability
175
+ incurred by, or claims asserted against, such Contributor by reason
176
+ of your accepting any such warranty or additional liability.
177
+
178
+ END OF TERMS AND CONDITIONS
179
+
180
+ APPENDIX: How to apply the Apache License to your work.
181
+
182
+ To apply the Apache License to your work, attach the following
183
+ boilerplate notice, with the fields enclosed by brackets "{}"
184
+ replaced with your own identifying information. (Don't include
185
+ the brackets!) The text should be enclosed in the appropriate
186
+ comment syntax for the file format. We also recommend that a
187
+ file or class name and description of purpose be included on the
188
+ same "printed page" as the copyright notice for easier
189
+ identification within third-party archives.
190
+
191
+ Copyright {yyyy} {name of copyright owner}
192
+
193
+ Licensed under the Apache License, Version 2.0 (the "License");
194
+ you may not use this file except in compliance with the License.
195
+ You may obtain a copy of the License at
196
+
197
+ http://www.apache.org/licenses/LICENSE-2.0
198
+
199
+ Unless required by applicable law or agreed to in writing, software
200
+ distributed under the License is distributed on an "AS IS" BASIS,
201
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
202
+ See the License for the specific language governing permissions and
203
+ limitations under the License.
204
+
205
+ ------------------------------------------------------------------------------------
206
+
207
+ This product bundles various third-party components under other open source licenses.
208
+ This section summarizes those components and their licenses. See licenses/
209
+ for text of these licenses.
marlin/dense/common/base.h ADDED
@@ -0,0 +1,32 @@
1
+ /*
2
+ * Modified by HandH1998
3
+ * Modified by Neural Magic
4
+ * Copyright (C) Marlin.2024 Elias Frantar
5
+ *
6
+ * Licensed under the Apache License, Version 2.0 (the "License");
7
+ * you may not use this file except in compliance with the License.
8
+ * You may obtain a copy of the License at
9
+ *
10
+ * http://www.apache.org/licenses/LICENSE-2.0
11
+ *
12
+ * Unless required by applicable law or agreed to in writing, software
13
+ * distributed under the License is distributed on an "AS IS" BASIS,
14
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ * See the License for the specific language governing permissions and
16
+ * limitations under the License.
17
+ */
18
+
19
+ #pragma once
20
+
21
+ constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
22
+
23
+ // Instances of `Vec` are used to organize groups of >>registers<<, as needed
24
+ // for instance as inputs to tensor core operations. Consequently, all
25
+ // corresponding index accesses must be compile-time constants, which is why we
26
+ // extensively use `#pragma unroll` throughout the kernel code to guarantee
27
+ // this.
28
+ template <typename T, int n>
29
+ struct Vec {
30
+ T elems[n];
31
+ __device__ T& operator[](int i) { return elems[i]; }
32
+ };
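For orientation, the header above only adds two small utilities: `ceildiv` performs integer ceiling division when sizing tile loops, and `Vec<T, n>` is a plain register-backed array whose indices must be compile-time constants, which is why the kernels wrap every access in `#pragma unroll`. A minimal usage sketch, not part of the commit; the `FragAcc` alias and the function names below are made up purely for illustration:

// Sketch only: compile-time-constant indexing into a Vec fragment.
using FragAcc = Vec<float, 4>;  // four accumulator registers per thread

__device__ void zero_fragment(FragAcc& acc) {
#pragma unroll
  for (int i = 0; i < 4; i++)  // fully unrolled, so acc[i] is a constant index
    acc[i] = 0.0f;
}

__device__ int num_k_tiles(int prob_k) {
  return ceildiv(prob_k, 16);  // number of 16-wide tiles covering the k dimension
}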
marlin/dense/common/mem.h ADDED
@@ -0,0 +1,89 @@
1
+ /*
2
+ * Modified by HandH1998
3
+ * Modified by Neural Magic
4
+ * Copyright (C) Marlin.2024 Elias Frantar
5
+ *
6
+ * Licensed under the Apache License, Version 2.0 (the "License");
7
+ * you may not use this file except in compliance with the License.
8
+ * You may obtain a copy of the License at
9
+ *
10
+ * http://www.apache.org/licenses/LICENSE-2.0
11
+ *
12
+ * Unless required by applicable law or agreed to in writing, software
13
+ * distributed under the License is distributed on an "AS IS" BASIS,
14
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ * See the License for the specific language governing permissions and
16
+ * limitations under the License.
17
+ */
18
+
19
+ #pragma once
20
+
21
+ // Predicated asynchronous global->shared copy; used for inputs A where we apply
22
+ // predication to handle batchsizes that are not multiples of 16.
23
+ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
24
+ bool pred = true) {
25
+ const int BYTES = 16;
26
+ uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
27
+ asm volatile(
28
+ "{\n"
29
+ " .reg .pred p;\n"
30
+ " setp.ne.b32 p, %0, 0;\n"
31
+ " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
32
+ "}\n" ::"r"((int)pred),
33
+ "r"(smem), "l"(glob_ptr), "n"(BYTES));
34
+ }
35
+
36
+ // Asynchronous global->shared copy
37
+ __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
38
+ const int BYTES = 16;
39
+ uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
40
+ asm volatile(
41
+ "{\n"
42
+ " cp.async.cg.shared.global [%0], [%1], %2;\n"
43
+ "}\n" ::"r"(smem),
44
+ "l"(glob_ptr), "n"(BYTES));
45
+ }
46
+
47
+ // Async copy fence.
48
+ __device__ inline void cp_async_fence() {
49
+ asm volatile("cp.async.commit_group;\n" ::);
50
+ }
51
+
52
+ // Wait until at most `n` async copy stages are still pending.
53
+ template <int n>
54
+ __device__ inline void cp_async_wait() {
55
+ asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
56
+ }
57
+
58
+ // Wait until barrier reaches `count`, then lock for current threadblock.
59
+ __device__ inline void barrier_acquire(int* lock, int count) {
60
+ if (threadIdx.x == 0) {
61
+ int state = -1;
62
+ do
63
+ // Guarantee that subsequent writes by this threadblock will be visible
64
+ // globally.
65
+ asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
66
+ : "=r"(state)
67
+ : "l"(lock));
68
+ while (state != count);
69
+ }
70
+ __syncthreads();
71
+ }
72
+
73
+ // Release barrier and increment visitation count.
74
+ __device__ inline void barrier_release(int* lock, bool reset = false) {
75
+ __syncthreads();
76
+ if (threadIdx.x == 0) {
77
+ if (reset) {
78
+ lock[0] = 0;
79
+ return;
80
+ }
81
+ int val = 1;
82
+ // Make sure that all writes since acquiring this barrier are visible
83
+ // globally, while releasing the barrier.
84
+ asm volatile("fence.acq_rel.gpu;\n");
85
+ asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
86
+ :
87
+ : "l"(lock), "r"(val));
88
+ }
89
+ }
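The helpers above are always used in the same software-pipelined pattern in the kernels that follow: issue the 16-byte copies for one shared-memory stage, commit them as one group, and only read the data once at most `stages - 2` groups remain in flight. A condensed sketch of that pattern, assuming the declarations in this header; the function and parameter names are illustrative and not part of the commit, and in the real kernels the issue and the wait happen at different points of the pipeline:

// Sketch only: the cp.async prefetch pattern built from the helpers above.
template <int STAGES>
__device__ void fetch_and_wait(int4* sh_stage, const int4* gl_ptr, bool pred) {
  cp_async4_pred(sh_stage, gl_ptr, pred);  // issue a predicated 16-byte global->shared copy
  cp_async_fence();                        // commit the copies issued so far as one group
  cp_async_wait<STAGES - 2>();             // block until at most STAGES-2 groups are still pending
  __syncthreads();                         // the copied data is now visible to all warps
}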
marlin/dense/marlin_cuda_kernel.cu ADDED
@@ -0,0 +1,1068 @@
1
+ /*
2
+ * Modified by Neural Magic
3
+ * Copyright (C) Marlin.2024 Elias Frantar
4
+ *
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
16
+ */
17
+
18
+ #include <torch/all.h>
19
+
20
+ #include <ATen/cuda/CUDAContext.h>
21
+ #include <c10/cuda/CUDAGuard.h>
22
+ #include <cuda.h>
23
+ #include <cuda_fp16.h>
24
+ #include <cuda_runtime.h>
25
+
26
+ #include <iostream>
27
+
28
+ #include "common/base.h"
29
+
30
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
31
+ #include "common/mem.h"
32
+ #endif
33
+
34
+ template <typename T>
35
+ inline std::string str(T x) {
36
+ return std::to_string(x);
37
+ }
38
+
39
+ namespace marlin_dense {
40
+
41
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
42
+
43
+ using I4 = Vec<int, 4>;
44
+ // Matrix fragments for tensor core instructions; their precise layout is
45
+ // documented here:
46
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
47
+ using FragA = Vec<half2, 4>;
48
+ using FragB = Vec<half2, 2>;
49
+ using FragC = Vec<float, 4>;
50
+ using FragS = Vec<half2, 1>; // quantization scales
51
+
52
+ // m16n8k16 tensor core mma instruction with fp16 inputs and fp32
53
+ // output/accumulation.
54
+ __device__ inline void mma(const FragA& a_frag, const FragB& frag_b,
55
+ FragC& frag_c) {
56
+ const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
57
+ const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
58
+ float* c = reinterpret_cast<float*>(&frag_c);
59
+ asm volatile(
60
+ "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
61
+ "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
62
+ : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
63
+ : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
64
+ "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
65
+ }
66
+
67
+ // Instruction for loading a full 16x16 matrix fragment of operand A from shared
68
+ // memory, directly in tensor core layout.
69
+ __device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
70
+ uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
71
+ uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
72
+ asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
73
+ : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
74
+ : "r"(smem));
75
+ }
76
+
77
+ // Lookup-table based 3-input logical operation; explicitly used for
78
+ // dequantization as the compiler does not seem to automatically recognize it in
79
+ // all cases.
80
+ template <int lut>
81
+ __device__ inline int lop3(int a, int b, int c) {
82
+ int res;
83
+ asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
84
+ : "=r"(res)
85
+ : "r"(a), "r"(b), "r"(c), "n"(lut));
86
+ return res;
87
+ }
88
+
89
+ // Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
90
+ // values. We mostly follow the strategy in the link below, with some small
91
+ // changes:
92
+ // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
93
+ __device__ inline FragB dequant(int q) {
94
+ const int LO = 0x000f000f;
95
+ const int HI = 0x00f000f0;
96
+ const int EX = 0x64006400;
97
+ // Guarantee that the `(a & b) | c` operations are LOP3s.
98
+ int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
99
+ int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
100
+ // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
101
+ // directly into `SUB` and `ADD`.
102
+ const int SUB = 0x64086408;
103
+ const int MUL = 0x2c002c00;
104
+ const int ADD = 0xd480d480;
105
+ FragB frag_b;
106
+ frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
107
+ *reinterpret_cast<const half2*>(&SUB));
108
+ frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
109
+ *reinterpret_cast<const half2*>(&MUL),
110
+ *reinterpret_cast<const half2*>(&ADD));
111
+ return frag_b;
112
+ }
113
+
114
+ // Multiply dequantized values by the corresponding quantization scale; used
115
+ // only for grouped quantization.
116
+ __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
117
+ half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
118
+ frag_b[0] = __hmul2(frag_b[0], s);
119
+ frag_b[1] = __hmul2(frag_b[1], s);
120
+ }
121
+
122
+ template <const int threads, // number of threads in a threadblock
123
+ const int thread_m_blocks, // number of 16x16 blocks in the m
124
+ // dimension (batchsize) of the
125
+ // threadblock
126
+ const int thread_n_blocks, // same for n dimension (output)
127
+ const int thread_k_blocks, // same for k dimension (reduction)
128
+ const int stages, // number of stages for the async global->shared
129
+ // fetch pipeline
130
+ const int group_blocks = -1 // number of consecutive 16x16 blocks
131
+ // with a separate quantization scale
132
+ >
133
+ __global__ void Marlin(
134
+ const int4* __restrict__ A, // fp16 input matrix of shape mxk
135
+ const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
136
+ int4* __restrict__ C, // fp16 output buffer of shape mxn
137
+ const int4* __restrict__ s, // fp16 quantization scales of shape
138
+ // (k/groupsize)xn
139
+ int prob_m, // batch dimension m
140
+ int prob_n, // output dimension n
141
+ int prob_k, // reduction dimension k
142
+ int* locks // extra global storage for barrier synchronization
143
+ ) {
144
+ // Each threadblock processes one "stripe" of the B matrix with (roughly) the
145
+ // same size, which might involve multiple column "slices" (of width 16 *
146
+ // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
147
+ // example:
148
+ // 0 1 3
149
+ // 0 2 3
150
+ // 1 2 4
151
+ // While this kind of partitioning makes things somewhat more complicated, it
152
+ // ensures good utilization of all SMs for many kinds of shape and GPU
153
+ // configurations, while requiring as few slow global cross-threadblock
154
+ // reductions as possible.
155
+
156
+ // For larger GEMMs we run multiple batchsize 64 versions in parallel for a
157
+ // better partitioning with fewer reductions
158
+ int parallel = 1;
159
+ if (prob_m > 16 * thread_m_blocks) {
160
+ parallel = prob_m / (16 * thread_m_blocks);
161
+ prob_m = 16 * thread_m_blocks;
162
+ }
163
+
164
+ int k_tiles = prob_k / 16 / thread_k_blocks;
165
+ int n_tiles = prob_n / 16 / thread_n_blocks;
166
+ int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x);
167
+ // Ensure that the number of tiles in each stripe is a multiple of the
168
+ // groupsize; this avoids an annoying special case where a stripe starts in
169
+ // the middle of a group.
170
+ if (group_blocks != -1)
171
+ iters = (group_blocks / thread_k_blocks) *
172
+ ceildiv(iters, (group_blocks / thread_k_blocks));
173
+
174
+ int slice_row = (iters * blockIdx.x) % k_tiles;
175
+ int slice_col_par = (iters * blockIdx.x) / k_tiles;
176
+ int slice_col = slice_col_par;
177
+ int slice_iters; // number of threadblock tiles in the current slice
178
+ int slice_count =
179
+ 0; // total number of active threadblocks in the current slice
180
+ int slice_idx; // index of threadblock in current slice; numbered bottom to
181
+ // top
182
+
183
+ // We can easily implement parallel problem execution by just remapping
184
+ // indices and advancing global pointers
185
+ if (slice_col_par >= n_tiles) {
186
+ A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;
187
+ C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
188
+ locks += (slice_col_par / n_tiles) * n_tiles;
189
+ slice_col = slice_col_par % n_tiles;
190
+ }
191
+
192
+ // Compute all information about the current slice which is required for
193
+ // synchronization.
194
+ auto init_slice = [&]() {
195
+ slice_iters =
196
+ iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
197
+ if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
198
+ if (slice_iters == 0) return;
199
+ if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
200
+ slice_count = 1;
201
+ slice_idx = 0;
202
+ int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
203
+ if (col_first <= k_tiles * (slice_col_par + 1)) {
204
+ int col_off = col_first - k_tiles * slice_col_par;
205
+ slice_count = ceildiv(k_tiles - col_off, iters);
206
+ if (col_off > 0) slice_count++;
207
+ int delta_first = iters * blockIdx.x - col_first;
208
+ if (delta_first < 0 || (col_off == 0 && delta_first == 0))
209
+ slice_idx = slice_count - 1;
210
+ else {
211
+ slice_idx = slice_count - 1 - delta_first / iters;
212
+ if (col_off > 0) slice_idx--;
213
+ }
214
+ }
215
+ if (slice_col == n_tiles) {
216
+ A += 16 * thread_m_blocks * prob_k / 8;
217
+ C += 16 * thread_m_blocks * prob_n / 8;
218
+ locks += n_tiles;
219
+ slice_col = 0;
220
+ }
221
+ };
222
+ init_slice();
223
+
224
+ int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory
225
+ // We typically use `constexpr` to indicate that this value is a compile-time
226
+ // constant
227
+ constexpr int a_sh_stride =
228
+ 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory
229
+ constexpr int a_gl_rd_delta_o =
230
+ 16 * thread_k_blocks /
231
+ 8; // delta between subsequent A tiles in global memory
232
+ int a_gl_rd_delta_i =
233
+ a_gl_stride *
234
+ (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile
235
+ constexpr int a_sh_wr_delta =
236
+ a_sh_stride *
237
+ (threads / a_gl_rd_delta_o); // between shared memory writes
238
+ constexpr int a_sh_rd_delta_o =
239
+ 2 * ((threads / 32) /
240
+ (thread_n_blocks / 4)); // between shared memory tile reads
241
+ constexpr int a_sh_rd_delta_i =
242
+ a_sh_stride * 16; // within a shared memory tile
243
+ constexpr int a_sh_stage =
244
+ a_sh_stride * (16 * thread_m_blocks); // overall size of a tile
245
+ constexpr int a_sh_wr_iters =
246
+ ceildiv(a_sh_stage,
247
+ a_sh_wr_delta); // number of shared write iterations for a tile
248
+
249
+ int b_gl_stride = 16 * prob_n / 32;
250
+ constexpr int b_sh_stride = 32 * thread_n_blocks / 4;
251
+ int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
252
+ int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride);
253
+ constexpr int b_sh_wr_delta = threads;
254
+ constexpr int b_sh_rd_delta = threads;
255
+ constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
256
+ constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
257
+
258
+ int s_gl_stride = prob_n / 8;
259
+ constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
260
+ constexpr int s_sh_stage = s_sh_stride;
261
+ int s_gl_rd_delta = s_gl_stride;
262
+
263
+ // Global A read index of current thread.
264
+ int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
265
+ (threadIdx.x % a_gl_rd_delta_o);
266
+ a_gl_rd += a_gl_rd_delta_o * slice_row;
267
+ // Shared write index of current thread.
268
+ int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
269
+ (threadIdx.x % a_gl_rd_delta_o);
270
+ // Shared read index.
271
+ int a_sh_rd =
272
+ a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
273
+ a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
274
+
275
+ int b_gl_rd =
276
+ b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
277
+ b_gl_rd += b_sh_stride * slice_col;
278
+ b_gl_rd += b_gl_rd_delta_o * slice_row;
279
+ int b_sh_wr = threadIdx.x;
280
+ int b_sh_rd = threadIdx.x;
281
+
282
+ int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
283
+ s_sh_stride * slice_col + threadIdx.x;
284
+ int s_sh_wr = threadIdx.x;
285
+ int s_sh_rd;
286
+ // We use a different scale layout for grouped and column-wise quantization as
287
+ // we scale a `half2` tile in column-major layout in the former and in
288
+ // row-major in the latter case.
289
+ if (group_blocks != -1)
290
+ s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
291
+ (threadIdx.x % 32) / 4;
292
+ else
293
+ s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
294
+ (threadIdx.x % 32) % 4;
295
+
296
+ // Precompute which thread should not read memory in which iterations; this is
297
+ // needed if there are more threads than required for a certain tilesize or
298
+ // when the batchsize is not a multiple of 16.
299
+ bool a_sh_wr_pred[a_sh_wr_iters];
300
+ #pragma unroll
301
+ for (int i = 0; i < a_sh_wr_iters; i++)
302
+ a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
303
+ bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
304
+
305
+ // To ensure that writing and reading A tiles to/from shared memory, the
306
+ // latter in fragment format, is fully bank conflict free, we need to use a
307
+ // rather fancy XOR-based layout. The key here is that neither reads nor
308
+ // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
309
+ // same shared memory banks. Further, it seems (based on NSight-Compute) that
310
+ // each warp must also write a consecutive memory segment?
311
+ auto transform_a = [&](int i) {
312
+ int row = i / a_gl_rd_delta_o;
313
+ return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
314
+ };
315
+ // Since the computation of this remapping is non-trivial and, due to our main
316
+ // loop unrolls, all shared memory accesses are static, we simply precompute
317
+ // both transformed reads and writes.
318
+ int a_sh_wr_trans[a_sh_wr_iters];
319
+ #pragma unroll
320
+ for (int i = 0; i < a_sh_wr_iters; i++)
321
+ a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
322
+ int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
323
+ #pragma unroll
324
+ for (int i = 0; i < b_sh_wr_iters; i++) {
325
+ #pragma unroll
326
+ for (int j = 0; j < thread_m_blocks; j++)
327
+ a_sh_rd_trans[i][j] =
328
+ transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
329
+ }
330
+
331
+ // Since B-accesses have non-constant stride they have to be computed at
332
+ // runtime; we break dependencies between subsequent accesses within a tile by
333
+ // maintaining multiple pointers (we have enough registers), a tiny
334
+ // optimization.
335
+ const int4* B_ptr[b_sh_wr_iters];
336
+ #pragma unroll
337
+ for (int i = 0; i < b_sh_wr_iters; i++)
338
+ B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
339
+
340
+ extern __shared__ int4 sh[];
341
+ // Shared memory storage for global fetch pipelines.
342
+ int4* sh_a = sh;
343
+ int4* sh_b = sh_a + (stages * a_sh_stage);
344
+ int4* sh_s = sh_b + (stages * b_sh_stage);
345
+ // Register storage for double buffer of shared memory reads.
346
+ FragA frag_a[2][thread_m_blocks];
347
+ I4 frag_b_quant[2];
348
+ FragC frag_c[thread_m_blocks][4][2];
349
+ FragS frag_s[2][4];
350
+
351
+ // Zero accumulators.
352
+ auto zero_accums = [&]() {
353
+ #pragma unroll
354
+ for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
355
+ reinterpret_cast<float*>(frag_c)[i] = 0;
356
+ };
357
+
358
+ // Asynchronously fetch the next A, B and s tile from global to the next
359
+ // shared memory pipeline location.
360
+ auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
361
+ if (pred) {
362
+ int4* sh_a_stage = sh_a + a_sh_stage * pipe;
363
+ #pragma unroll
364
+ for (int i = 0; i < a_sh_wr_iters; i++) {
365
+ cp_async4_pred(
366
+ &sh_a_stage[a_sh_wr_trans[i]],
367
+ &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
368
+ a_sh_wr_pred[i]);
369
+ }
370
+ int4* sh_b_stage = sh_b + b_sh_stage * pipe;
371
+ #pragma unroll
372
+ for (int i = 0; i < b_sh_wr_iters; i++) {
373
+ cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
374
+ B_ptr[i] += b_gl_rd_delta_o;
375
+ }
376
+ // Only fetch scales if this tile starts a new group
377
+ if constexpr (group_blocks != -1) {
378
+ // This assumes group_blocks >= thread_k_blocks
379
+ // and would need to be modified to support smaller groups.
380
+ static_assert(group_blocks >= thread_k_blocks);
381
+ if (pipe % (group_blocks / thread_k_blocks) == 0) {
382
+ int4* sh_s_stage = sh_s + s_sh_stage * pipe;
383
+ if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
384
+ s_gl_rd += s_gl_rd_delta;
385
+ }
386
+ }
387
+ }
388
+ // Insert a fence even when we are winding down the pipeline to ensure that
389
+ // waiting is also correct at this point.
390
+ cp_async_fence();
391
+ };
392
+
393
+ // Wait until the next thread tile has been loaded to shared memory.
394
+ auto wait_for_stage = [&]() {
395
+ // We only have `stages - 2` active fetches since we are double buffering
396
+ // and can only issue the next fetch when it is guaranteed that the previous
397
+ // shared memory load is fully complete (as it may otherwise be
398
+ // overwritten).
399
+ cp_async_wait<stages - 2>();
400
+ __syncthreads();
401
+ };
402
+
403
+ // Load the next sub-tile from the current location in the shared memory pipe
404
+ // into the current register buffer.
405
+ auto fetch_to_registers = [&](int k, int pipe) {
406
+ // It may seem inefficient that we reload the groups for every sub-tile;
407
+ // however, this does not seem to be a significant bottleneck, while some
408
+ // theoretically better attempts have led to bad instruction ordering by
409
+ // the compiler and correspondingly a noticeable drop in performance.
410
+ if constexpr (group_blocks != -1) {
411
+ // This assumes group_blocks >= thread_k_blocks
412
+ // and would need to be modified to support smaller groups.
413
+ static_assert(group_blocks >= thread_k_blocks);
414
+ int4* sh_s_stage =
415
+ sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
416
+ (pipe / (group_blocks / thread_k_blocks)));
417
+ reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
418
+ }
419
+ int4* sh_a_stage = sh_a + a_sh_stage * pipe;
420
+ #pragma unroll
421
+ for (int i = 0; i < thread_m_blocks; i++)
422
+ ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
423
+ int4* sh_b_stage = sh_b + b_sh_stage * pipe;
424
+ frag_b_quant[k % 2] = *reinterpret_cast<I4*>(
425
+ &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]);
426
+ };
427
+
428
+ // Execute the actual tensor core matmul of a sub-tile.
429
+ auto matmul = [&](int k) {
430
+ // We have the m dimension as the inner loop in order to encourage overlapping
431
+ // dequantization and matmul operations.
432
+ #pragma unroll
433
+ for (int j = 0; j < 4; j++) {
434
+ int b_quant = frag_b_quant[k % 2][j];
435
+ int b_quant_shift = b_quant >> 8;
436
+ FragB frag_b0 = dequant(b_quant);
437
+ // If there are no groups, we can just scale the final output once and can
438
+ // avoid doing so for each weight.
439
+ if (group_blocks != -1) scale(frag_b0, frag_s[k % 2][j], 0);
440
+ FragB frag_b1 = dequant(b_quant_shift);
441
+ if (group_blocks != -1) scale(frag_b1, frag_s[k % 2][j], 1);
442
+ #pragma unroll
443
+ for (int i = 0; i < thread_m_blocks; i++) {
444
+ mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
445
+ mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
446
+ }
447
+ }
448
+ };
449
+
450
+ // Since we slice across the k dimension of a tile in order to increase the
451
+ // number of warps while keeping the n dimension of a tile reasonable, we have
452
+ // multiple warps that accumulate their partial sums of the same output
453
+ // location, which we have to reduce over in the end. We do this in shared memory.
454
+ auto thread_block_reduce = [&]() {
455
+ constexpr int red_off = threads / b_sh_stride / 2;
456
+ if (red_off >= 1) {
457
+ int red_idx = threadIdx.x / b_sh_stride;
458
+ constexpr int red_sh_stride = b_sh_stride * 4 * 2;
459
+ constexpr int red_sh_delta = b_sh_stride;
460
+ int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) +
461
+ (threadIdx.x % b_sh_stride);
462
+
463
+ // Parallel logarithmic shared memory reduction. We make sure to avoid any
464
+ // unnecessary read or write iterations, e.g., for two warps we write only
465
+ // once by warp 1 and read only once by warp 0.
466
+
467
+ #pragma unroll
468
+ for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
469
+ #pragma unroll
470
+ for (int i = red_off; i > 0; i /= 2) {
471
+ if (i <= red_idx && red_idx < 2 * i) {
472
+ #pragma unroll
473
+ for (int j = 0; j < 4 * 2; j++) {
474
+ int red_sh_wr =
475
+ red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
476
+ if (i < red_off) {
477
+ float* c_rd =
478
+ reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
479
+ float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
480
+ #pragma unroll
481
+ for (int k = 0; k < 4; k++)
482
+ reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
483
+ c_rd[k] + c_wr[k];
484
+ }
485
+ sh[red_sh_wr] =
486
+ reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
487
+ }
488
+ }
489
+ __syncthreads();
490
+ }
491
+ if (red_idx == 0) {
492
+ #pragma unroll
493
+ for (int i = 0; i < 4 * 2; i++) {
494
+ float* c_rd =
495
+ reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
496
+ #pragma unroll
497
+ for (int j = 0; j < 4; j++)
498
+ reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
499
+ c_rd[j];
500
+ }
501
+ }
502
+ __syncthreads();
503
+ }
504
+ }
505
+ };
506
+
507
+ // Since multiple threadblocks may process parts of the same column slice, we
508
+ // finally have to globally reduce over the results. As the striped
509
+ // partitioning minimizes the number of such reductions and our outputs are
510
+ // usually rather small, we perform this reduction serially in L2 cache.
511
+ auto global_reduce = [&](bool first = false, bool last = false) {
512
+ // We are very careful here to reduce directly in the output buffer to
513
+ // maximize L2 cache utilization in this step. To do this, we write out
514
+ // results in FP16 (but still reduce with FP32 compute).
515
+ constexpr int active_threads = 32 * thread_n_blocks / 4;
516
+ if (threadIdx.x < active_threads) {
517
+ int c_gl_stride = prob_n / 8;
518
+ int c_gl_wr_delta_o = 8 * c_gl_stride;
519
+ int c_gl_wr_delta_i = 4 * (active_threads / 32);
520
+ int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +
521
+ 4 * (threadIdx.x / 32) + threadIdx.x % 4;
522
+ c_gl_wr += (2 * thread_n_blocks) * slice_col;
523
+ constexpr int c_sh_wr_delta = active_threads;
524
+ int c_sh_wr = threadIdx.x;
525
+
526
+ int row = (threadIdx.x % 32) / 4;
527
+
528
+ if (!first) {
529
+ // Interestingly, doing direct global accesses here really seems to mess up
530
+ // the compiler and lead to slowdowns, hence we also use async-copies even
531
+ // though these fetches are not actually asynchronous.
532
+ #pragma unroll
533
+ for (int i = 0; i < thread_m_blocks * 4; i++) {
534
+ cp_async4_pred(
535
+ &sh[c_sh_wr + c_sh_wr_delta * i],
536
+ &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
537
+ c_gl_wr_delta_i * (i % 2)],
538
+ i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
539
+ }
540
+ cp_async_fence();
541
+ cp_async_wait<0>();
542
+ }
543
+
544
+ #pragma unroll
545
+ for (int i = 0; i < thread_m_blocks * 4; i++) {
546
+ if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
547
+ if (!first) {
548
+ int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
549
+ #pragma unroll
550
+ for (int j = 0; j < 2 * 4; j++) {
551
+ reinterpret_cast<float*>(
552
+ &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
553
+ __half2float(reinterpret_cast<__half*>(&c_red)[j]);
554
+ }
555
+ }
556
+ if (!last) {
557
+ int4 c;
558
+ #pragma unroll
559
+ for (int j = 0; j < 2 * 4; j++) {
560
+ reinterpret_cast<__half*>(&c)[j] =
561
+ __float2half(reinterpret_cast<float*>(
562
+ &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);
563
+ }
564
+ C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =
565
+ c;
566
+ }
567
+ }
568
+ }
569
+ }
570
+ };
571
+
572
+ // Write out the reduced final result in the correct layout. We only actually
573
+ // reshuffle matrix fragments in this step; the reduction above is performed
574
+ // in fragment layout.
575
+ auto write_result = [&]() {
576
+ int c_gl_stride = prob_n / 8;
577
+ constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
578
+ int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
579
+ constexpr int c_sh_rd_delta =
580
+ c_sh_stride * (threads / (2 * thread_n_blocks));
581
+
582
+ int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +
583
+ (threadIdx.x % (2 * thread_n_blocks));
584
+ c_gl_wr += (2 * thread_n_blocks) * slice_col;
585
+ int c_sh_wr =
586
+ (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
587
+ c_sh_wr += 32 * (threadIdx.x / 32);
588
+ int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +
589
+ (threadIdx.x % (2 * thread_n_blocks));
590
+
591
+ int c_gl_wr_end = c_gl_stride * prob_m;
592
+
593
+ // We first reorder in shared memory to guarantee the most efficient final
594
+ // global write patterns
595
+ auto write = [&](int idx, float c0, float c1, FragS& s) {
596
+ half2 res = __halves2half2(__float2half(c0), __float2half(c1));
597
+ if (group_blocks ==
598
+ -1) // for per-column quantization we finally apply the scale here
599
+ res = __hmul2(res, s[0]);
600
+ ((half2*)sh)[idx] = res;
601
+ };
602
+ if (threadIdx.x / 32 < thread_n_blocks / 4) {
603
+ #pragma unroll
604
+ for (int i = 0; i < thread_m_blocks; i++) {
605
+ #pragma unroll
606
+ for (int j = 0; j < 4; j++) {
607
+ int wr = c_sh_wr + 8 * j;
608
+ write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
609
+ frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
610
+ write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
611
+ frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
612
+ write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
613
+ frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
614
+ write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
615
+ frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
616
+ }
617
+ c_sh_wr += 16 * (4 * c_sh_stride);
618
+ }
619
+ }
620
+ __syncthreads();
621
+
622
+ #pragma unroll
623
+ for (int i = 0;
624
+ i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
625
+ i++) {
626
+ if (c_gl_wr < c_gl_wr_end) {
627
+ C[c_gl_wr] = sh[c_sh_rd];
628
+ c_gl_wr += c_gl_wr_delta;
629
+ c_sh_rd += c_sh_rd_delta;
630
+ }
631
+ }
632
+ };
633
+
634
+ // Start global fetch and register load pipelines.
635
+ auto start_pipes = [&]() {
636
+ #pragma unroll
637
+ for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters);
638
+ zero_accums();
639
+ wait_for_stage();
640
+ fetch_to_registers(0, 0);
641
+ a_gl_rd += a_gl_rd_delta_o * (stages - 1);
642
+ };
643
+ start_pipes();
644
+
645
+ // Main loop.
646
+ while (slice_iters) {
647
+ // We unroll over both the global fetch and the register load pipeline to
648
+ // ensure all shared memory accesses are static. Note that both pipelines have
649
+ // an even length, meaning that the next iteration will always start at index 0.
650
+ #pragma unroll
651
+ for (int pipe = 0; pipe < stages;) {
652
+ #pragma unroll
653
+ for (int k = 0; k < b_sh_wr_iters; k++) {
654
+ fetch_to_registers(k + 1, pipe % stages);
655
+ if (k == b_sh_wr_iters - 2) {
656
+ fetch_to_shared((pipe + stages - 1) % stages, pipe,
657
+ slice_iters >= stages);
658
+ pipe++;
659
+ wait_for_stage();
660
+ }
661
+ matmul(k);
662
+ }
663
+ slice_iters--;
664
+ if (slice_iters == 0) break;
665
+ }
666
+ a_gl_rd += a_gl_rd_delta_o * stages;
667
+
668
+ // Process results and, if necessary, proceed to the next column slice.
669
+ // While this pattern may not be the most readable, other ways of writing
670
+ // the loop seemed to result in noticeably worse performance after compilation.
671
+ if (slice_iters == 0) {
672
+ cp_async_wait<0>();
673
+ bool last = slice_idx == slice_count - 1;
674
+ // For per-column scales, we only fetch them here in the final step before
675
+ // write-out
676
+ if (group_blocks == -1 && last) {
677
+ if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
678
+ cp_async_fence();
679
+ }
680
+ thread_block_reduce();
681
+ if (group_blocks == -1 && last) {
682
+ cp_async_wait<0>();
683
+ __syncthreads();
684
+ if (threadIdx.x / 32 < thread_n_blocks / 4) {
685
+ reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
686
+ reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
687
+ }
688
+ }
689
+ if (slice_count > 1) { // only globally reduce if there is more than one
690
+ // block in a slice
691
+ barrier_acquire(&locks[slice_col], slice_idx);
692
+ global_reduce(slice_idx == 0, last);
693
+ barrier_release(&locks[slice_col], last);
694
+ }
695
+ if (last) // only the last block in a slice actually writes the result
696
+ write_result();
697
+ slice_row = 0;
698
+ slice_col_par++;
699
+ slice_col++;
700
+ init_slice();
701
+ if (slice_iters) {
702
+ a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
703
+ (threadIdx.x % a_gl_rd_delta_o);
704
+ #pragma unroll
705
+ for (int i = 0; i < b_sh_wr_iters; i++)
706
+ B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
707
+ if (slice_col == 0) {
708
+ #pragma unroll
709
+ for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
710
+ }
711
+ s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
712
+ start_pipes();
713
+ }
714
+ }
715
+ }
716
+ }
717
+
718
+ #else
719
+
720
+ template <const int threads, // number of threads in a threadblock
721
+ const int thread_m_blocks, // number of 16x16 blocks in the m
722
+ // dimension (batchsize) of the
723
+ // threadblock
724
+ const int thread_n_blocks, // same for n dimension (output)
725
+ const int thread_k_blocks, // same for k dimension (reduction)
726
+ const int stages, // number of stages for the async global->shared
727
+ // fetch pipeline
728
+ const int group_blocks = -1 // number of consecutive 16x16 blocks
729
+ // with a separate quantization scale
730
+ >
731
+ __global__ void Marlin(
732
+ const int4* __restrict__ A, // fp16 input matrix of shape mxk
733
+ const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
734
+ int4* __restrict__ C, // fp16 output buffer of shape mxn
735
+ const int4* __restrict__ s, // fp16 quantization scales of shape
736
+ // (k/groupsize)xn
737
+ int prob_m, // batch dimension m
738
+ int prob_n, // output dimension n
739
+ int prob_k, // reduction dimension k
740
+ int* locks // extra global storage for barrier synchronization
741
+ ) {
742
+ // Marlin is not implemented yet for SM < 8.0
743
+ assert(false);
744
+ return;
745
+ }
746
+
747
+ #endif
748
+
749
+ // 8 warps are a good choice since every SM has 4 schedulers and having more
750
+ // than 1 warp per scheduler allows some more latency hiding. At the same time,
751
+ // we want relatively few warps to have many registers per warp and small tiles.
752
+ const int USER_THREADS =
753
+ 256; // Note: This is only used with user-provided thread_k/n
754
+ const int STAGES = 4; // 4 pipeline stages fit into shared memory
755
+ const int SHARED_MEM =
756
+ 96 * 1024; // max shared memory on compute capability 8.6 (less than on 8.0)
757
+
758
+ static constexpr int min_thread_n = 64;
759
+ static constexpr int min_thread_k = 64;
760
+
761
+ static constexpr int tile_size = 16;
762
+ static constexpr int max_par = 16;
763
+
764
+ static constexpr int pack_factor_4bit =
765
+ 8; // We have 8 4-bit vals inside a 32-bit value
766
+
767
+ #define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
768
+ GROUP_BLOCKS, NUM_THREADS) \
769
+ else if (thread_m_blocks == THREAD_M_BLOCKS && \
770
+ thread_n_blocks == THREAD_N_BLOCKS && \
771
+ thread_k_blocks == THREAD_K_BLOCKS && \
772
+ group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
773
+ cudaFuncSetAttribute(Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
774
+ THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \
775
+ cudaFuncAttributeMaxDynamicSharedMemorySize, \
776
+ SHARED_MEM); \
777
+ Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
778
+ STAGES, GROUP_BLOCKS><<<blocks, NUM_THREADS, SHARED_MEM, stream>>>( \
779
+ A_ptr, B_ptr, C_ptr, s_ptr, prob_m, prob_n, prob_k, locks); \
780
+ }
781
+
782
+ typedef struct {
783
+ int thread_k;
784
+ int thread_n;
785
+ int num_threads;
786
+ } thread_config_t;
787
+
788
+ thread_config_t small_batch_thread_configs[] = {
789
+ // Ordered by priority
790
+
791
+ // thread_k, thread_n, num_threads
792
+ {128, 128, 256}, // Default
793
+ {128, 64, 128}, // Reduce N 2X, same K
794
+ {64, 256, 256}, // Reduce K 2X, increase N 2X
795
+ {64, 128, 128}, // Reduce K 2X, same N
796
+ };
797
+
798
+ thread_config_t large_batch_thread_configs[] = {
799
+ // Ordered by priority
800
+
801
+ // thread_k, thread_n, num_threads
802
+ {64, 256, 256}, // Default
803
+ {128, 128, 256}, // Reduce N 2X, increase K 2X
804
+ {64, 128, 128}, // Reduce N 2X, same K
805
+ {128, 64, 128}, // Reduce N 4X, increase K 2X
806
+ };
807
+
808
+ bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n,
809
+ int prob_k) {
810
+ // Sanity
811
+ if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
812
+ th_config.num_threads == -1) {
813
+ return false;
814
+ }
815
+
816
+ // Verify K/N are divisible by thread K/N
817
+ if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
818
+ return false;
819
+ }
820
+
821
+ // thread_k can be only 128 or 64 (because it must not exceed the groupsize,
822
+ // which is 128)
823
+ if (th_config.thread_k != 128 && th_config.thread_k != 64) {
824
+ return false;
825
+ }
826
+
827
+ // Verify min for thread K/N
828
+ if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
829
+ return false;
830
+ }
831
+
832
+ // num_threads must be at least 128 (= 4 warps)
833
+ if (th_config.num_threads < 128) {
834
+ return false;
835
+ }
836
+
837
+ return true;
838
+ }
839
+
840
+ thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) {
841
+ if (prob_m <= 16) {
842
+ for (auto th_config : small_batch_thread_configs) {
843
+ if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
844
+ return th_config;
845
+ }
846
+ }
847
+
848
+ } else {
849
+ for (auto th_config : large_batch_thread_configs) {
850
+ if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
851
+ return th_config;
852
+ }
853
+ }
854
+ }
855
+
856
+ return thread_config_t{-1, -1, -1};
857
+ }
858
+
859
+ #define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \
860
+ __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
861
+ __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
862
+ __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
863
+ __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
864
+ __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
865
+ __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
866
+ __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
867
+ __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
868
+ __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
869
+ __CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)
870
+
871
+ void marlin_cuda(const void* A, const void* B, void* C, void* s, int prob_m,
872
+ int prob_n, int prob_k, void* workspace, int groupsize = -1,
873
+ int dev = 0, cudaStream_t stream = 0, int thread_k = -1,
874
+ int thread_n = -1, int sms = -1, int max_par = 16) {
875
+ int tot_m = prob_m;
876
+ int tot_m_blocks = ceildiv(tot_m, 16);
877
+ int pad = 16 * tot_m_blocks - tot_m;
878
+
879
+ if (sms == -1)
880
+ cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
881
+
882
+ // Set thread config
883
+ thread_config_t th_config;
884
+ if (thread_k != -1 && thread_n != -1) {
885
+ // User-defined config
886
+ th_config = thread_config_t{thread_k, thread_n, USER_THREADS};
887
+ } else {
888
+ // Auto config
889
+ th_config = determine_thread_config(prob_m, prob_n, prob_k);
890
+ }
891
+
892
+ if (!is_valid_config(th_config, prob_m, prob_n, prob_k)) {
893
+ throw std::runtime_error(
894
+ "Invalid thread config: thread_k = " + str(th_config.thread_k) +
895
+ ", thread_n = " + str(th_config.thread_n) +
896
+ ", num_threads = " + str(th_config.num_threads) + " for MKN = [" +
897
+ str(prob_m) + ", " + str(prob_k) + ", " + str(prob_n) + "]");
898
+ }
899
+
900
+ // Uncomment for debug
901
+ // std::cout << "Using thread_config: thread_k = " + str(th_config.thread_k) +
902
+ // ", thread_n = " + str(th_config.thread_n) +
903
+ // ", num_threads = " + str(th_config.num_threads) + " for
904
+ // MKN = [" + str(prob_m) +
905
+ // ", " + str(prob_k) + ", " + str(prob_n) + "]\n";
906
+
907
+ int num_threads = th_config.num_threads;
908
+ thread_k = th_config.thread_k;
909
+ thread_n = th_config.thread_n;
910
+
911
+ int thread_k_blocks = thread_k / 16;
912
+ int thread_n_blocks = thread_n / 16;
913
+ int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
914
+ int blocks = sms;
915
+
916
+ if (prob_m == 0 || prob_n == 0 || prob_k == 0) {
917
+ return;
918
+ }
919
+
920
+ TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
921
+ " is not divisible by thread_n = ", thread_n);
922
+ TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
923
+ " is not divisible by thread_k = ", thread_k);
924
+ if (group_blocks != -1) {
925
+ TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
926
+ " is not divisible by group_blocks = ", group_blocks);
927
+ }
928
+
929
+ const int4* A_ptr = (const int4*)A;
930
+ const int4* B_ptr = (const int4*)B;
931
+ int4* C_ptr = (int4*)C;
932
+ const int4* s_ptr = (const int4*)s;
933
+
934
+ int* locks = (int*)workspace;
935
+
936
+ for (int i = 0; i < tot_m_blocks; i += 4) {
937
+ int thread_m_blocks = tot_m_blocks - i;
938
+ prob_m = tot_m - 16 * i;
939
+ int par = 1;
940
+ if (thread_m_blocks > 4) {
941
+ // Note that parallel > 1 currently only works for inputs without any
942
+ // padding
943
+ par = (16 * thread_m_blocks - pad) / 64;
944
+ if (par > max_par) par = max_par;
945
+ prob_m = 64 * par;
946
+ i += 4 * (par - 1);
947
+ thread_m_blocks = 4;
948
+ }
949
+
950
+ // For compilation speed, we only define the kernel configurations that have
951
+ // seemed useful (in terms of performance) in our testing; however, many more
952
+ // are, in principle, possible.
953
+ if (false) {
954
+ }
955
+ CALL_IF(8, 8, 256)
956
+ CALL_IF(16, 4, 256)
957
+ CALL_IF(8, 4, 128)
958
+ CALL_IF(4, 8, 128)
959
+ else {
960
+ throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) +
961
+ ", " + str(prob_k) + ", " + str(prob_n) + "]" +
962
+ ", groupsize = " + str(groupsize) +
963
+ ", thread_m_blocks = " + str(thread_m_blocks) +
964
+ ", thread_n_blocks = " + str(thread_n_blocks) +
965
+ ", thread_k_blocks = " + str(thread_k_blocks));
966
+ }
967
+
968
+ A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
969
+ C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
970
+ }
971
+ }
972
+
973
+ } // namespace marlin_dense
974
+
975
+ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
976
+ torch::Tensor& b_scales, torch::Tensor& workspace,
977
+ int64_t size_m, int64_t size_n, int64_t size_k) {
978
+ // Verify M
979
+ TORCH_CHECK(size_m == a.size(0),
980
+ "Shape mismatch: a.size(0) = " + str(a.size(0)) +
981
+ ", size_m = " + str(size_m));
982
+
983
+ // Verify K
984
+ TORCH_CHECK(size_k == a.size(1),
985
+ "Shape mismatch: a.size(1) = " + str(a.size(1)) +
986
+ ", size_k = " + str(size_k));
987
+ TORCH_CHECK(size_k % marlin_dense::tile_size == 0,
988
+ "size_k = " + str(size_k) + " is not divisible by tile_size = " +
989
+ str(marlin_dense::tile_size));
990
+ TORCH_CHECK((size_k / marlin_dense::tile_size) == b_q_weight.size(0),
991
+ "Shape mismatch: b_q_weight.size(0) = " +
992
+ str(b_q_weight.size(0)) + ", size_k = " + str(size_k) +
993
+ ", tile_size = " + str(marlin_dense::tile_size));
994
+
995
+ // Verify N
996
+ TORCH_CHECK(b_scales.size(1) == size_n,
997
+ "b_scales.size(1) = " + str(b_scales.size(1)) +
998
+ ", size_n = " + str(size_n));
999
+ TORCH_CHECK(
1000
+ b_q_weight.size(1) % marlin_dense::tile_size == 0,
1001
+ "b_q_weight.size(1) = " + str(b_q_weight.size(1)) +
1002
+ " is not divisible by tile_size = " + str(marlin_dense::tile_size));
1003
+
1004
+ int actual_size_n = (b_q_weight.size(1) / marlin_dense::tile_size) *
1005
+ marlin_dense::pack_factor_4bit;
1006
+ TORCH_CHECK(
1007
+ size_n == actual_size_n,
1008
+ "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n));
1009
+
1010
+ // Verify A device and strides
1011
+ TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
1012
+ TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
1013
+
1014
+ // Verify B device and strides
1015
+ TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
1016
+ TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
1017
+
1018
+ // Verify scales device and strides
1019
+ TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
1020
+ TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
1021
+
1022
+ // Alloc C matrix
1023
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
1024
+ auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
1025
+ torch::Tensor c = torch::empty({size_m, size_n}, options);
1026
+
1027
+ // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
1028
+ // auto -1)
1029
+ int thread_k = -1;
1030
+ // thread_n: `n` size of a thread_tile in `weights` (can usually be left as
1031
+ // auto -1)
1032
+ int thread_n = -1;
1033
+ // sms: number of SMs to use for the kernel (can usually be left as auto -1)
1034
+ int sms = -1;
1035
+
1036
+ // Detect groupsize
1037
+ if (b_scales.size(0) != 1) {
1038
+ TORCH_CHECK(size_k % b_scales.size(0) == 0,
1039
+ "size_k = " + str(size_k) +
1040
+ ", is not divisible by b_scales.size(0) = " +
1041
+ str(b_scales.size(0)));
1042
+ }
1043
+ int groupsize = b_scales.size(0) == 1 ? -1 : size_k / b_scales.size(0);
1044
+
1045
+ // Verify groupsize
1046
+ TORCH_CHECK(groupsize == -1 || groupsize == 128,
1047
+ "Unexpected groupsize = " + str(groupsize));
1048
+
1049
+ // Verify workspace size
1050
+ TORCH_CHECK(size_n % marlin_dense::min_thread_n == 0,
1051
+ "size_n = " + str(size_n) +
1052
+ ", is not divisible by min_thread_n = " +
1053
+ str(marlin_dense::min_thread_n));
1054
+ int min_workspace_size =
1055
+ (size_n / marlin_dense::min_thread_n) * marlin_dense::max_par;
1056
+ TORCH_CHECK(workspace.numel() >= min_workspace_size,
1057
+ "workspace.numel = " + str(workspace.numel()) +
1058
+ " is below min_workspace_size = " + str(min_workspace_size));
1059
+
1060
+ int dev = a.get_device();
1061
+ marlin_dense::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(),
1062
+ b_scales.data_ptr(), size_m, size_n, size_k,
1063
+ workspace.data_ptr(), groupsize, dev,
1064
+ at::cuda::getCurrentCUDAStream(dev), thread_k,
1065
+ thread_n, sms, marlin_dense::max_par);
1066
+
1067
+ return c;
1068
+ }
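To make the shape checks in `marlin_gemm` above concrete, here is a host-side sketch of a consistently shaped call. The shapes are inferred from the TORCH_CHECKs (4-bit weights packed 8 per int32 in 16x16 tiles, groupsize 128, a workspace of at least (n / min_thread_n) * max_par ints); the weight tensor here is only shape-correct rather than genuinely Marlin-packed, and the example function name is illustrative, not part of the commit:

// Sketch only: argument shapes accepted by marlin_gemm, per the checks above.
#include <torch/torch.h>

torch::Tensor marlin_gemm_shape_example(int64_t m, int64_t n, int64_t k) {
  auto fp16 = torch::TensorOptions().dtype(torch::kHalf).device(torch::kCUDA);
  auto i32 = torch::TensorOptions().dtype(torch::kInt).device(torch::kCUDA);

  torch::Tensor a = torch::randn({m, k}, fp16);                        // fp16 activations, shape (m, k)
  torch::Tensor b_q_weight = torch::zeros({k / 16, n * 16 / 8}, i32);  // packed 4-bit weights, 8 nibbles per int32
  torch::Tensor b_scales = torch::ones({k / 128, n}, fp16);            // per-group scales for groupsize 128
  torch::Tensor workspace = torch::zeros({(n / 64) * 16}, i32);        // (n / min_thread_n) * max_par lock slots

  return marlin_gemm(a, b_q_weight, b_scales, workspace, m, n, k);
}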
marlin/qqq/marlin_qqq_gemm_kernel.cu ADDED
@@ -0,0 +1,1243 @@
1
+ /*
2
+ * Adapted from
3
+ * https://github.com/IST-DASLab/marlin/blob/master/marlin/marlin_cuda_kernel.cu
4
+ * https://github.com/IST-DASLab/marlin/blob/master/marlin/marlin_cuda.cpp
5
+ * Modified by HandH1998
6
+ * Copyright (C) 2024 HandH1998
7
+ * Copyright (C) Marlin.2024 Elias Frantar
8
+ *
9
+ * Licensed under the Apache License, Version 2.0 (the "License");
10
+ * you may not use this file except in compliance with the License.
11
+ * You may obtain a copy of the License at
12
+ *
13
+ * http://www.apache.org/licenses/LICENSE-2.0
14
+ *
15
+ * Unless required by applicable law or agreed to in writing, software
16
+ * distributed under the License is distributed on an "AS IS" BASIS,
17
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ * See the License for the specific language governing permissions and
19
+ * limitations under the License.
20
+ */
21
+
22
+ #include <torch/all.h>
23
+
24
+ #include <ATen/cuda/CUDAContext.h>
25
+ #include <c10/cuda/CUDAGuard.h>
26
+ #include <cuda.h>
27
+ #include <cuda_fp16.h>
28
+ #include <cuda_runtime.h>
29
+
30
+ #include <iostream>
31
+
32
+ #include "../dense/common/base.h"
33
+
34
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
35
+ #include "../dense/common/mem.h"
36
+ #endif
37
+
38
+ template <typename T>
39
+ inline std::string str(T x) {
40
+ return std::to_string(x);
41
+ }
42
+
43
+ namespace {
44
+
45
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
46
+
47
+ using I4 = Vec<int, 4>;
48
+ // Matrix fragments for tensor core instructions; their precise layout is
49
+ // documented here:
50
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-integer-type
51
+ using FragA = Vec<uint32_t, 2>;
52
+ using FragB = Vec<uint32_t, 1>;
53
+ using FragC = Vec<int, 4>;
54
+ using FragS_GROUP = Vec<half2, 1>; // weight per-group quantization scales
55
+ using FragS_CHANNEL =
56
+ Vec<float, 2>; // weight per-channel quantization scales or activation
57
+ // per-token quantization scales
58
+
59
+ // NOTE(HandH1998): cp.async.cg only supports BYTES = 16; however,
60
+ // cp.async.ca can support BYTES = 4, 8, 16;
61
+ // as s_tok's shape is equal to prob_m, we need to set s_tok to float type,
62
+ // and cp_size = 1 float, i.e., 4 BYTES
63
+ // Asynchronous global->shared copy for activation quantization scales s_tok
64
+ __device__ inline void cp_async1(void* smem_ptr, const void* glob_ptr) {
65
+ const int BYTES = 4;
66
+ uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
67
+ asm volatile(
68
+ "{\n"
69
+ " cp.async.ca.shared.global [%0], [%1], %2;\n"
70
+ "}\n" ::"r"(smem),
71
+ "l"(glob_ptr), "n"(BYTES));
72
+ }
73
+
74
+ // m16n8k16 tensor core mma instruction with int8 inputs and int32
75
+ // output/accumulation.
76
+ __device__ inline void mma(const FragA& a_frag, const FragB& frag_b,
77
+ FragC& frag_c) {
78
+ const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
79
+ const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
80
+ int* c = reinterpret_cast<int*>(&frag_c);
81
+ asm volatile(
82
+ "mma.sync.aligned.m16n8k16.row.col.satfinite.s32.s8.s8.s32 "
83
+ "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
84
+ : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
85
+ : "r"(a[0]), "r"(a[1]), "r"(b[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]),
86
+ "r"(c[3]));
87
+ }
88
+
89
+ // Instruction for loading a full 16x16 matrix fragment of operand A from shared
90
+ // memory, directly in int8 tensor core layout.
91
+ __device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
92
+ uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
93
+ uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
94
+ asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n"
95
+ : "=r"(a[0]), "=r"(a[1])
96
+ : "r"(smem));
97
+ }
98
+
99
+ inline __device__ half2 float2_to_half2(float2 f) {
100
+ uint32_t res;
101
+ // NOTE(HandH1998): h0,h1 should be uint16_t, not half
102
+ uint16_t h0, h1;
103
+ asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(h0) : "f"(f.x));
104
+ asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(h1) : "f"(f.y));
105
+ asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(res) : "h"(h0), "h"(h1));
106
+ return reinterpret_cast<half2&>(res);
107
+ }
108
+
109
+ inline __device__ float int32_to_float(int h) {
110
+ float res;
111
+ asm volatile("cvt.rn.f32.s32 %0, %1;\n" : "=f"(res) : "r"(h));
112
+ return res;
113
+ }
114
+
115
+ // Lookup-table based 3-input logical operation; explicitly used for
116
+ // dequantization as the compiler does not seem to automatically recognize it in
117
+ // all cases.
118
+ template <int lut>
119
+ __device__ inline int lop3(int a, int b, int c) {
120
+ int res;
121
+ asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
122
+ : "=r"(res)
123
+ : "r"(a), "r"(b), "r"(c), "n"(lut));
124
+ return res;
125
+ }
126
+
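+ // Note added for clarity (not in the upstream source): the `lut` template
+ // argument follows the PTX lop3 immLut convention, i.e. it is the desired
+ // boolean function evaluated bitwise on the constants a = 0xF0, b = 0xCC,
+ // c = 0xAA. For example, (0xf0 & 0xcc) | 0xaa below encodes
+ // f(a, b, c) = (a & b) | c.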
127
+ // Efficiently dequantize an int32 value into a full B-fragment of 4 int8 values
128
+ // for weight per channel dequant.
129
+ __device__ inline FragB dequant_per_channel(int q) {
130
+ static constexpr int MASK = 0xf0f0f0f0;
131
+ FragB frag_b;
132
+ frag_b[0] = (q & MASK);
133
+ return frag_b;
134
+ }
135
+
136
+ // Efficiently dequantize an int32 value into a full B-fragment of 4 int8 values
137
+ // for weight per group dequant.
138
+ __device__ inline FragB dequant_per_group(int q, FragS_GROUP& frag_s, int i) {
139
+ static constexpr uint32_t LO = 0x000f000f;
140
+ static constexpr uint32_t HI = 0x00f000f0;
141
+ static constexpr uint32_t EX = 0x64006400;
142
+ // Guarantee that the `(a & b) | c` operations are LOP3s.
143
+ uint32_t t0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
144
+ uint32_t t1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
145
+ // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
146
+ // directly into `SUB` and `ADD`.
147
+ static constexpr uint32_t SUB = 0x64086408;
148
+ static constexpr uint32_t MUL = 0x2c002c00;
149
+ static constexpr uint32_t ADD = 0xd480d480;
150
+ *reinterpret_cast<half2*>(&t0) = __hsub2(
151
+ *reinterpret_cast<half2*>(&t0), *reinterpret_cast<const half2*>(&SUB));
152
+ *reinterpret_cast<half2*>(&t1) = __hfma2(
153
+ *reinterpret_cast<half2*>(&t1), *reinterpret_cast<const half2*>(&MUL),
154
+ *reinterpret_cast<const half2*>(&ADD));
155
+
156
+ uint16_t s = reinterpret_cast<uint16_t*>(&frag_s)[i];
157
+ uint32_t double_s;
158
+ // pack 2xfp16 to half2
159
+ asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(double_s) : "h"(s), "h"(s));
160
+ // dequant, then convert the 4 half values to 4 uint8 values (each placed in
161
+ // the low 8 bits of its half)
162
+ static constexpr uint32_t MAGIC_NUM = 0x64806480;
163
+ *reinterpret_cast<half2*>(&t0) = __hfma2(
164
+ *reinterpret_cast<half2*>(&t0), *reinterpret_cast<half2*>(&double_s),
165
+ *reinterpret_cast<const half2*>(&MAGIC_NUM));
166
+ *reinterpret_cast<half2*>(&t1) = __hfma2(
167
+ *reinterpret_cast<half2*>(&t1), *reinterpret_cast<half2*>(&double_s),
168
+ *reinterpret_cast<const half2*>(&MAGIC_NUM));
169
+ // extract the 4 uint8 values from the 4 halves, then convert them to 4 int8
170
+ // and pack the 4 int8 values into 1 uint32
171
+ FragB frag_b;
172
+ uint32_t uint8s;
173
+ static constexpr uint32_t MASK_0246 = 0x6420;
174
+ static constexpr uint32_t UINT8s_TO_INT8s_MASK = 0x80808080;
175
+ asm volatile("prmt.b32 %0,%1,%2,%3;\n"
176
+ : "=r"(uint8s)
177
+ : "r"(t0), "r"(t1), "n"(MASK_0246));
178
+ frag_b[0] = (uint8s ^ UINT8s_TO_INT8s_MASK);
179
+ return frag_b;
180
+ }
181
+
182
+ template <const int threads, // number of threads in a threadblock
183
+ const int thread_m_blocks, // number of 16x16 blocks in the m
184
+ // dimension (batchsize) of the
185
+ // threadblock
186
+ const int thread_n_blocks, // same for n dimension (output)
187
+ const int thread_k_blocks, // same for k dimension (reduction)
188
+ const int stages, // number of stages for the async global->shared
189
+ // fetch pipeline
190
+ const int group_blocks = -1 // number of consecutive 16x16 blocks
191
+ // with a separate quantization scale
192
+ >
193
+ __global__ void Marlin(
194
+ const int4* __restrict__ A, // int8 input matrix of shape mxk
195
+ const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
196
+ int4* __restrict__ C, // int32 global_reduce buffer of shape
197
+ // (max_par*16*4)xn, as int8 tensor core's output is
198
+ // int32 dtype
199
+ int4* __restrict__ D, // fp16 output buffer of shape mxn
200
+ const float* __restrict__ s_tok, // fp32 activation per-token quantization
201
+ // scales of shape mx1
202
+ const int4* __restrict__ s_ch, // fp32 weight per-channel quantization
203
+ // scales of shape 1xn
204
+ const int4* __restrict__ s_group, // fp16 weight per-group quantization
205
+ // scales of shape (k/groupsize)xn, when
206
+ // group_blocks=-1, it should be nullptr
207
+ int prob_m, // batch dimension m
208
+ int prob_n, // output dimension n
209
+ int prob_k, // reduction dimension k
210
+ int* locks // extra global storage for barrier synchronization
211
+ ) {
212
+ // Each threadblock processes one "stripe" of the B matrix with (roughly) the
213
+ // same size, which might involve multiple column "slices" (of width 16 *
214
+ // `thread_n_blocks`). Stripes are defined as shown in the following example
216
+ // with a 3x3 tile matrix and 5 SMs:
216
+ // 0 1 3
217
+ // 0 2 3
218
+ // 1 2 4
219
+ // While this kind of partitioning makes things somewhat more complicated, it
220
+ // ensures good utilization of all SMs for many kinds of shape and GPU
221
+ // configurations, while requiring as few slow global cross-threadblock
222
+ // reductions as possible.
223
+
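+ // Worked example added for clarity (not in the upstream source): in the 3x3
+ // tile / 5 SM picture above, k_tiles = n_tiles = 3, so there are 9 tiles and
+ // iters = ceildiv(9, 5) = 2 tiles per threadblock. Block 0 therefore covers
+ // rows 0-1 of column 0, block 1 covers row 2 of column 0 plus row 0 of
+ // column 1 (wrapping into the next column slice), and so on, with the last
+ // block picking up the single remaining tile.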
224
+ // For larger GEMMs we run multiple batchsize 64 versions in parallel for a
225
+ // better partitioning with fewer reductions
226
+ int parallel = 1;
227
+ if (prob_m > 16 * thread_m_blocks) {
228
+ parallel = prob_m / (16 * thread_m_blocks);
229
+ prob_m = 16 * thread_m_blocks;
230
+ }
231
+
232
+ int k_tiles = prob_k / 16 / thread_k_blocks;
233
+ int n_tiles = prob_n / 16 / thread_n_blocks;
234
+ int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x);
235
+ // Ensure that the number of tiles in each stripe is a multiple of the
236
+ // groupsize; this avoids an annoying special case where a stripe starts in
237
+ // the middle of a group.
238
+ if constexpr (group_blocks != -1)
239
+ iters = (group_blocks / thread_k_blocks) *
240
+ ceildiv(iters, (group_blocks / thread_k_blocks));
241
+
242
+ int slice_row = (iters * blockIdx.x) % k_tiles;
243
+ int slice_col_par = (iters * blockIdx.x) / k_tiles;
244
+ int slice_col = slice_col_par;
245
+ int slice_iters; // number of threadblock tiles in the current slice
246
+ int slice_count =
247
+ 0; // total number of active threadblocks in the current slice
248
+ int slice_idx; // index of threadblock in current slice; numbered bottom to
249
+ // top
250
+
251
+ // We can easily implement parallel problem execution by just remapping
252
+ // indices and advancing global pointers
253
+ if (slice_col_par >= n_tiles) {
254
+ A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 16;
255
+ C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 4;
256
+ D += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
257
+ s_tok += (slice_col_par / n_tiles) * 16 * thread_m_blocks;
258
+ locks += (slice_col_par / n_tiles) * n_tiles;
259
+ slice_col = slice_col_par % n_tiles;
260
+ }
261
+
262
+ // Compute all information about the current slice which is required for
263
+ // synchronization.
264
+ auto init_slice = [&]() {
265
+ slice_iters =
266
+ iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
267
+ if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
268
+ if (slice_iters == 0) return;
269
+ if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
270
+ slice_count = 1;
271
+ slice_idx = 0;
272
+ int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
273
+ if (col_first <= k_tiles * (slice_col_par + 1)) {
274
+ int col_off = col_first - k_tiles * slice_col_par;
275
+ slice_count = ceildiv(k_tiles - col_off, iters);
276
+ if (col_off > 0) slice_count++;
277
+ int delta_first = iters * blockIdx.x - col_first;
278
+ if (delta_first < 0 || (col_off == 0 && delta_first == 0))
279
+ slice_idx = slice_count - 1;
280
+ else {
281
+ slice_idx = slice_count - 1 - delta_first / iters;
282
+ if (col_off > 0) slice_idx--;
283
+ }
284
+ }
285
+ if (slice_col == n_tiles) {
286
+ A += 16 * thread_m_blocks * prob_k / 16;
287
+ C += 16 * thread_m_blocks * prob_n / 4;
288
+ D += 16 * thread_m_blocks * prob_n / 8;
289
+ s_tok += 16 * thread_m_blocks;
290
+ locks += n_tiles;
291
+ slice_col = 0;
292
+ }
293
+ };
294
+ init_slice();
295
+
296
+ int a_gl_stride = prob_k / 16; // stride of the A matrix in global memory
297
+ // We typically use `constexpr` to indicate that this value is a compile-time
298
+ // constant
299
+ constexpr int a_sh_stride =
300
+ 16 * thread_k_blocks / 16; // stride of an A matrix tile in shared memory
301
+ constexpr int a_gl_rd_delta_o =
302
+ 16 * thread_k_blocks /
303
+ 16; // delta between subsequent A tiles in global memory
304
+ int a_gl_rd_delta_i =
305
+ a_gl_stride *
306
+ (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile
307
+ constexpr int a_sh_wr_delta =
308
+ a_sh_stride *
309
+ (threads / a_gl_rd_delta_o); // between shared memory writes
310
+ constexpr int a_sh_rd_delta_o =
311
+ 1 * ((threads / 32) /
312
+ (thread_n_blocks / 4)); // between shared memory tile reads
313
+ constexpr int a_sh_rd_delta_i =
314
+ a_sh_stride * 16; // within a shared memory tile
315
+ constexpr int a_sh_stage =
316
+ a_sh_stride * (16 * thread_m_blocks); // overall size of a tile
317
+ constexpr int a_sh_wr_iters =
318
+ ceildiv(a_sh_stage,
319
+ a_sh_wr_delta); // number of shared write iterations for a tile
320
+
321
+ int b_gl_stride = 16 * prob_n / 32;
322
+ constexpr int b_sh_stride = 32 * thread_n_blocks / 4;
323
+ int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
324
+ int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride);
325
+ constexpr int b_sh_wr_delta = threads;
326
+ constexpr int b_sh_rd_delta = threads;
327
+ constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
328
+ constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
329
+
330
+ constexpr int s_tok_sh_stride = 16 * thread_m_blocks;
331
+
332
+ constexpr int s_ch_sh_stride = 16 * thread_n_blocks / 4;
333
+
334
+ int s_group_gl_stride = prob_n / 8;
335
+ constexpr int s_group_sh_stride = 16 * thread_n_blocks / 8;
336
+ constexpr int s_group_sh_stage = s_group_sh_stride;
337
+ int s_group_gl_rd_delta = s_group_gl_stride;
338
+
339
+ // Global A read index of current thread.
340
+ int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
341
+ (threadIdx.x % a_gl_rd_delta_o);
342
+ a_gl_rd += a_gl_rd_delta_o * slice_row;
343
+ // Shared write index of current thread.
344
+ int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
345
+ (threadIdx.x % a_gl_rd_delta_o);
346
+ // Shared read index.
347
+ // NOTE(HandH1998): int8 input a only need 16 threads to load 16x16 matrix
348
+ int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % 16);
349
+ a_sh_rd += 1 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
350
+
351
+ int b_gl_rd =
352
+ b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
353
+ b_gl_rd += b_sh_stride * slice_col;
354
+ b_gl_rd += b_gl_rd_delta_o * slice_row;
355
+ int b_sh_wr = threadIdx.x;
356
+ int b_sh_rd = threadIdx.x;
357
+
358
+ int s_tok_gl_rd = threadIdx.x;
359
+ // NOTE(HandH1998): the activation scales s_tok need to be shuffled to [0, 8,
360
+ // 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15]; for example, the row scales
361
+ // 0 and 8 serve threads 0, 1, 2, 3 (see the mma operand A layout). Since
362
+ // s_tok's size is not fixed, we cannot shuffle it before inference; instead we
363
+ // shuffle it while fetching s_tok from global memory to shared memory, which
364
+ // is why s_tok_sh_wr is computed like this.
365
+ int s_tok_sh_wr =
366
+ (threadIdx.x / 16) * 16 + (threadIdx.x % 8) * 2 + (threadIdx.x % 16) / 8;
367
+ int s_tok_sh_rd = (threadIdx.x % 32) / 4;
368
+ bool s_tok_sh_wr_pred = threadIdx.x < prob_m;
369
+
370
+ int s_ch_gl_rd = s_ch_sh_stride * slice_col + threadIdx.x;
371
+ int s_ch_sh_wr = threadIdx.x;
372
+ int s_ch_sh_rd = 16 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
373
+ 2 * ((threadIdx.x % 32) % 4);
374
+ bool s_ch_sh_wr_pred = threadIdx.x < s_ch_sh_stride;
375
+
376
+ int s_group_gl_rd, s_group_sh_wr, s_group_sh_rd;
377
+ bool s_group_sh_wr_pred;
378
+ if constexpr (group_blocks != -1) {
379
+ s_group_gl_rd =
380
+ s_group_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
381
+ s_group_sh_stride * slice_col + threadIdx.x;
382
+ s_group_sh_wr = threadIdx.x;
383
+ // NOTE(HandH1998): s_group_sh_rd is related to mma output C
384
+ s_group_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
385
+ (threadIdx.x % 32) / 4;
386
+ s_group_sh_wr_pred = threadIdx.x < s_group_sh_stride;
387
+ }
388
+
389
+ // Precompute which thread should not read memory in which iterations; this is
390
+ // needed if there are more threads than required for a certain tilesize or
391
+ // when the batchsize is not a multiple of 16.
392
+ bool a_sh_wr_pred[a_sh_wr_iters];
393
+ #pragma unroll
394
+ for (int i = 0; i < a_sh_wr_iters; i++)
395
+ a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
396
+
397
+ // To ensure that writing and reading A tiles to/from shared memory, the
398
+ // latter in fragment format, is fully bank conflict free, we need to use a
399
+ // rather fancy XOR-based layout. The key here is that neither reads nor
400
+ // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
401
+ // same shared memory banks. Further, it seems (based on NSight-Compute) that
402
+ // each warp must also write a consecutive memory segment?
403
+ auto transform_a = [&](int i) {
404
+ int row = i / a_gl_rd_delta_o;
405
+ return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
406
+ };
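+ // Illustration added for clarity (not in the upstream source): due to
+ // operator precedence the expression above evaluates to i ^ row. With
+ // thread_k = 128 (a_gl_rd_delta_o = 8), row 0 (i = 0..7) is left in place,
+ // row 1 maps 8, 9, 10, 11, ... to 9, 8, 11, 10, ..., and row 2 maps 16, 17,
+ // 18, 19, ... to 18, 19, 16, 17, ..., which staggers the int4 accesses of
+ // consecutive rows across different shared memory banks.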
407
+ // Since the computation of this remapping is non-trivial and, due to our main
408
+ // loop unrolls, all shared memory accesses are static, we simply precompute
409
+ // both transformed reads and writes.
410
+ int a_sh_wr_trans[a_sh_wr_iters];
411
+ #pragma unroll
412
+ for (int i = 0; i < a_sh_wr_iters; i++)
413
+ a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
414
+ int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
415
+ #pragma unroll
416
+ for (int i = 0; i < b_sh_wr_iters; i++) {
417
+ #pragma unroll
418
+ for (int j = 0; j < thread_m_blocks; j++)
419
+ a_sh_rd_trans[i][j] =
420
+ transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
421
+ }
422
+
423
+ // Since B-accesses have non-constant stride they have to be computed at
424
+ // runtime; we break dependencies between subsequent accesses within a tile by
425
+ // maintaining multiple pointers (we have enough registers), a tiny
426
+ // optimization.
427
+ const int4* B_ptr[b_sh_wr_iters];
428
+ #pragma unroll
429
+ for (int i = 0; i < b_sh_wr_iters; i++)
430
+ B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
431
+
432
+ extern __shared__ int4 sh[];
433
+ // Shared memory storage for global fetch pipelines.
434
+ // NOTE(HandH1998): stages need >= 4, otherwise, sh_s_tok = sh + max(stages *
435
+ // a_sh_stage + stages * b_sh_stage, 4 * stages * a_sh_stage)
436
+ int4* sh_a = sh;
437
+ int4* sh_b = sh_a + (stages * a_sh_stage);
438
+ int4* sh_s_tok = sh_b + (stages * b_sh_stage);
439
+ int4* sh_s_ch = sh_s_tok + s_tok_sh_stride;
440
+ int4* sh_s_group = sh_s_ch + s_ch_sh_stride;
441
+
442
+ // Register storage for double buffer of shared memory reads.
443
+ FragA frag_a[2][thread_m_blocks];
444
+ I4 frag_b_quant[2];
445
+ FragC frag_c[thread_m_blocks][4][2];
446
+ FragS_GROUP frag_s_group[2][4];
447
+ FragS_CHANNEL frag_s_tok[thread_m_blocks];
448
+ FragS_CHANNEL frag_s_ch[2][4];
449
+
450
+ // Zero accumulators.
451
+ auto zero_accums = [&]() {
452
+ #pragma unroll
453
+ for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
454
+ reinterpret_cast<int*>(frag_c)[i] = 0;
455
+ };
456
+
457
+ // Asynchronously fetch the next A, B and s tile from global to the next
458
+ // shared memory pipeline location.
459
+ auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
460
+ if (pred) {
461
+ int4* sh_a_stage = sh_a + a_sh_stage * pipe;
462
+ #pragma unroll
463
+ for (int i = 0; i < a_sh_wr_iters; i++) {
464
+ cp_async4_pred(
465
+ &sh_a_stage[a_sh_wr_trans[i]],
466
+ &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
467
+ a_sh_wr_pred[i]);
468
+ }
469
+ int4* sh_b_stage = sh_b + b_sh_stage * pipe;
470
+ #pragma unroll
471
+ for (int i = 0; i < b_sh_wr_iters; i++) {
472
+ cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
473
+ B_ptr[i] += b_gl_rd_delta_o;
474
+ }
475
+ // Only fetch scales if this tile starts a new group
476
+ if constexpr (group_blocks != -1) {
477
+ if (pipe % (group_blocks / thread_k_blocks) == 0) {
478
+ int4* sh_s_group_stage = sh_s_group + s_group_sh_stage * pipe;
479
+ if (s_group_sh_wr_pred)
480
+ cp_async4(&sh_s_group_stage[s_group_sh_wr],
481
+ &s_group[s_group_gl_rd]);
482
+ s_group_gl_rd += s_group_gl_rd_delta;
483
+ }
484
+ }
485
+ }
486
+ // Insert a fence even when we are winding down the pipeline to ensure that
487
+ // waiting is also correct at this point.
488
+ cp_async_fence();
489
+ };
490
+
491
+ // Wait until the next thread tile has been loaded to shared memory.
492
+ auto wait_for_stage = [&]() {
493
+ // We only have `stages - 2` active fetches since we are double buffering
494
+ // and can only issue the next fetch when it is guaranteed that the previous
495
+ // shared memory load is fully complete (as it may otherwise be
496
+ // overwritten).
497
+ cp_async_wait<stages - 2>();
498
+ __syncthreads();
499
+ };
500
+
501
+ // Load the next sub-tile from the current location in the shared memory pipe
502
+ // into the current register buffer.
503
+ auto fetch_to_registers = [&](int k, int pipe) {
504
+ // It may seem inefficient that we reload the groups for every sub-tile;
505
+ // however, this does not seem to be a significant bottleneck, while some
506
+ // theoretically better attempts have led to bad instruction ordering by
507
+ // the compiler and correspondingly a noticeable drop in performance.
508
+ if constexpr (group_blocks != -1) {
509
+ int4* sh_s_group_stage =
510
+ sh_s_group +
511
+ s_group_sh_stage * ((group_blocks / thread_k_blocks) *
512
+ (pipe / (group_blocks / thread_k_blocks)));
513
+ reinterpret_cast<int4*>(&frag_s_group[k % 2])[0] =
514
+ sh_s_group_stage[s_group_sh_rd];
515
+ }
516
+ int4* sh_a_stage = sh_a + a_sh_stage * pipe;
517
+ #pragma unroll
518
+ for (int i = 0; i < thread_m_blocks; i++)
519
+ ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
520
+ int4* sh_b_stage = sh_b + b_sh_stage * pipe;
521
+ frag_b_quant[k % 2] = *reinterpret_cast<I4*>(
522
+ &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]);
523
+ };
524
+
525
+ // Execute the actual tensor core matmul of a sub-tile.
526
+ auto matmul = [&](int k) {
527
+ // We have the m dimension as the inner loop in order to encourage overlapping
528
+ // dequantization and matmul operations.
529
+ #pragma unroll
530
+ for (int j = 0; j < 4; j++) {
531
+ int b_quant = frag_b_quant[k % 2][j];
532
+ // int b_quant_shift = b_quant << 4;
533
+ FragB frag_b0, frag_b1;
534
+ // If there are no groups, we can just scale the final output once and can
535
+ // avoid doing so for each weight.
536
+ if constexpr (group_blocks != -1) {
537
+ int b_quant_shift = b_quant >> 8;
538
+ frag_b0 = dequant_per_group(b_quant, frag_s_group[k % 2][j], 0);
539
+ frag_b1 = dequant_per_group(b_quant_shift, frag_s_group[k % 2][j], 1);
540
+ } else {
541
+ int b_quant_shift = b_quant << 4;
542
+ frag_b0 = dequant_per_channel(b_quant);
543
+ frag_b1 = dequant_per_channel(b_quant_shift);
544
+ }
545
+ #pragma unroll
546
+ for (int i = 0; i < thread_m_blocks; i++) {
547
+ mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
548
+ mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
549
+ }
550
+ }
551
+ };
552
+
553
+ // Since we slice across the k dimension of a tile in order to increase the
554
+ // number of warps while keeping the n dimension of a tile reasonable, we have
555
+ // multiple warps that accumulate their partial sums of the same output
556
+ // location, which we have to reduce over in the end. We do this in shared memory.
557
+ auto thread_block_reduce = [&]() {
558
+ constexpr int red_off = threads / b_sh_stride / 2;
559
+ if (red_off >= 1) {
560
+ int red_idx = threadIdx.x / b_sh_stride;
561
+ constexpr int red_sh_stride = b_sh_stride * 4 * 2;
562
+ constexpr int red_sh_delta = b_sh_stride;
563
+ int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) +
564
+ (threadIdx.x % b_sh_stride);
565
+
566
+ // Parallel logarithmic shared memory reduction. We make sure to avoid any
567
+ // unnecessary read or write iterations, e.g., for two warps we write only
568
+ // once by warp 1 and read only once by warp 0.
569
+
570
+ #pragma unroll
571
+ for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
572
+ #pragma unroll
573
+ for (int i = red_off; i > 0; i /= 2) {
574
+ if (i <= red_idx && red_idx < 2 * i) {
575
+ #pragma unroll
576
+ for (int j = 0; j < 4 * 2; j++) {
577
+ int red_sh_wr =
578
+ red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
579
+ if (i < red_off) {
580
+ int* c_rd =
581
+ reinterpret_cast<int*>(&sh[red_sh_delta * j + red_sh_rd]);
582
+ int* c_wr = reinterpret_cast<int*>(&sh[red_sh_wr]);
583
+ #pragma unroll
584
+ for (int k = 0; k < 4; k++)
585
+ reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
586
+ c_rd[k] + c_wr[k];
587
+ }
588
+ sh[red_sh_wr] =
589
+ reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
590
+ }
591
+ }
592
+ __syncthreads();
593
+ }
594
+ if (red_idx == 0) {
595
+ #pragma unroll
596
+ for (int i = 0; i < 4 * 2; i++) {
597
+ int* c_rd =
598
+ reinterpret_cast<int*>(&sh[red_sh_delta * i + red_sh_rd]);
599
+ #pragma unroll
600
+ for (int j = 0; j < 4; j++)
601
+ reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
602
+ c_rd[j];
603
+ }
604
+ }
605
+ __syncthreads();
606
+ }
607
+ }
608
+ };
609
+
610
+ // Since multiple threadblocks may process parts of the same column slice, we
611
+ // finally have to globally reduce over the results. As the striped
612
+ // partitioning minimizes the number of such reductions and our outputs are
613
+ // usually rather small, we perform this reduction serially in L2 cache.
614
+ // global_reduce works on INT32 elements, which are the results of INT8 GEMM.
615
+ // This is why we need another INT32 matrix `C` to reduce into instead of the
616
+ // original half matrix `D`.
617
+ auto global_reduce = [&](bool first = false, bool last = false) {
618
+ // We are very careful here to reduce directly in the output buffer to
619
+ // maximize L2 cache utilization in this step. To do this, we write out
620
+ // results in FP16 (but still reduce with FP32 compute).
621
+ constexpr int active_threads = 32 * thread_n_blocks / 4;
622
+ if (threadIdx.x < active_threads) {
623
+ int c_gl_stride = prob_n / 4;
624
+ int c_gl_wr_delta_o = 8 * c_gl_stride;
625
+ int c_gl_wr_delta_i = 8 * (active_threads / 32);
626
+ int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +
627
+ 8 * (threadIdx.x / 32) + (threadIdx.x % 4) * 2;
628
+ c_gl_wr += (4 * thread_n_blocks) * slice_col;
629
+ constexpr int c_sh_wr_delta = active_threads * 2;
630
+ int c_sh_wr = 2 * threadIdx.x;
631
+
632
+ int row = (threadIdx.x % 32) / 4;
633
+
634
+ if (!first) {
635
+ // Interestingly, doing direct global accesses here really seems to mess up
636
+ // the compiler and lead to slowdowns, hence we also use async-copies even
637
+ // though these fetches are not actually asynchronous.
638
+ #pragma unroll
639
+ for (int i = 0; i < thread_m_blocks * 4; i++) {
640
+ cp_async4_pred(
641
+ &sh[c_sh_wr + c_sh_wr_delta * i],
642
+ &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
643
+ c_gl_wr_delta_i * (i % 2)],
644
+ i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
645
+ cp_async4_pred(
646
+ &sh[c_sh_wr + c_sh_wr_delta * i + 1],
647
+ &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
648
+ c_gl_wr_delta_i * (i % 2) + 1],
649
+ i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
650
+ }
651
+ cp_async_fence();
652
+ cp_async_wait<0>();
653
+ }
654
+
655
+ #pragma unroll
656
+ for (int i = 0; i < thread_m_blocks * 4; i++) {
657
+ if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
658
+ if (!first) {
659
+ int4 d_red1 = sh[c_sh_wr + i * c_sh_wr_delta];
660
+ int4 d_red2 = sh[c_sh_wr + i * c_sh_wr_delta + 1];
661
+ #pragma unroll
662
+ for (int j = 0; j < 4; j++) {
663
+ reinterpret_cast<int*>(
664
+ &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
665
+ reinterpret_cast<int*>(&d_red1)[j];
666
+ }
667
+ #pragma unroll
668
+ for (int j = 0; j < 4; j++) {
669
+ reinterpret_cast<int*>(
670
+ &frag_c)[4 * 2 * 4 * (i / 4) + 4 * (j + 4) + (i % 4)] +=
671
+ reinterpret_cast<int*>(&d_red2)[j];
672
+ }
673
+ }
674
+ if (!last) {
675
+ int4 d1, d2;
676
+ #pragma unroll
677
+ for (int j = 0; j < 4; j++) {
678
+ reinterpret_cast<int*>(&d1)[j] = reinterpret_cast<int*>(
679
+ &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)];
680
+ }
681
+ #pragma unroll
682
+ for (int j = 0; j < 4; j++) {
683
+ reinterpret_cast<int*>(&d2)[j] = reinterpret_cast<int*>(
684
+ &frag_c)[4 * 2 * 4 * (i / 4) + 4 * (j + 4) + (i % 4)];
685
+ }
686
+ C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =
687
+ d1;
688
+ C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2) +
689
+ 1] = d2;
690
+ }
691
+ }
692
+ }
693
+ }
694
+ };
695
+
696
+ // Write out the reduced final result in the correct layout. We only actually
698
+ // reshuffle matrix fragments in this step; the reduction above is performed
698
+ // in fragment layout.
699
+ auto write_result = [&]() {
700
+ int d_gl_stride = prob_n / 8;
701
+ constexpr int d_sh_stride = 2 * thread_n_blocks + 1;
702
+ int d_gl_wr_delta = d_gl_stride * (threads / (2 * thread_n_blocks));
703
+ constexpr int d_sh_rd_delta =
704
+ d_sh_stride * (threads / (2 * thread_n_blocks));
705
+
706
+ int d_gl_wr = d_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +
707
+ (threadIdx.x % (2 * thread_n_blocks));
708
+ d_gl_wr += (2 * thread_n_blocks) * slice_col;
709
+ int d_sh_wr =
710
+ (4 * d_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
711
+ d_sh_wr += 32 * (threadIdx.x / 32);
712
+ int d_sh_rd = d_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +
713
+ (threadIdx.x % (2 * thread_n_blocks));
714
+
715
+ int d_gl_wr_end = d_gl_stride * prob_m;
716
+
717
+ // We first reorder in shared memory to guarantee the most efficient final
718
+ // global write patterns
719
+ auto write = [&](int idx, int c0, int c1, float a_s, FragS_CHANNEL& w_s) {
720
+ float2 deq_res;
721
+ deq_res.x = int32_to_float(c0) * w_s[0] * a_s;
722
+ deq_res.y = int32_to_float(c1) * w_s[1] * a_s;
723
+ ((half2*)sh)[idx] = float2_to_half2(deq_res);
724
+ };
725
+
726
+ if (threadIdx.x / 32 < thread_n_blocks / 4) {
727
+ #pragma unroll
728
+ for (int i = 0; i < thread_m_blocks; i++) {
729
+ #pragma unroll
730
+ for (int j = 0; j < 4; j++) {
731
+ int wr = d_sh_wr + 8 * j;
732
+ write(wr + (4 * d_sh_stride) * 0 + 0, frag_c[i][j][0][0],
733
+ frag_c[i][j][0][1], frag_s_tok[i][0],
734
+ frag_s_ch[j / 2][2 * (j % 2) + 0]);
735
+ write(wr + (4 * d_sh_stride) * 8 + 0, frag_c[i][j][0][2],
736
+ frag_c[i][j][0][3], frag_s_tok[i][1],
737
+ frag_s_ch[j / 2][2 * (j % 2) + 0]);
738
+ write(wr + (4 * d_sh_stride) * 0 + 4, frag_c[i][j][1][0],
739
+ frag_c[i][j][1][1], frag_s_tok[i][0],
740
+ frag_s_ch[j / 2][2 * (j % 2) + 1]);
741
+ write(wr + (4 * d_sh_stride) * 8 + 4, frag_c[i][j][1][2],
742
+ frag_c[i][j][1][3], frag_s_tok[i][1],
743
+ frag_s_ch[j / 2][2 * (j % 2) + 1]);
744
+ }
745
+ d_sh_wr += 16 * (4 * d_sh_stride);
746
+ }
747
+ }
748
+ __syncthreads();
749
+
750
+ #pragma unroll
751
+ for (int i = 0;
752
+ i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
753
+ i++) {
754
+ if (d_gl_wr < d_gl_wr_end) {
755
+ D[d_gl_wr] = sh[d_sh_rd];
756
+ d_gl_wr += d_gl_wr_delta;
757
+ d_sh_rd += d_sh_rd_delta;
758
+ }
759
+ }
760
+ };
761
+
762
+ // Start global fetch and register load pipelines.
763
+ auto start_pipes = [&]() {
764
+ #pragma unroll
765
+ for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters);
766
+ zero_accums();
767
+ wait_for_stage();
768
+ fetch_to_registers(0, 0);
769
+ a_gl_rd += a_gl_rd_delta_o * (stages - 1);
770
+ };
771
+ start_pipes();
772
+
773
+ // Main loop.
774
+ while (slice_iters) {
775
+ // We unroll over both the global fetch and the register load pipeline to
776
+ // ensure all shared memory accesses are static. Note that both pipelines have
777
+ // even length, meaning that the next iteration will always start at index 0.
778
+ #pragma unroll
779
+ for (int pipe = 0; pipe < stages;) {
780
+ #pragma unroll
781
+ for (int k = 0; k < b_sh_wr_iters; k++) {
782
+ fetch_to_registers(k + 1, pipe % stages);
783
+ if (k == b_sh_wr_iters - 2) {
784
+ fetch_to_shared((pipe + stages - 1) % stages, pipe,
785
+ slice_iters >= stages);
786
+ pipe++;
787
+ wait_for_stage();
788
+ }
789
+ matmul(k);
790
+ }
791
+ slice_iters--;
792
+ if (slice_iters == 0) break;
793
+ }
794
+ a_gl_rd += a_gl_rd_delta_o * stages;
795
+
796
+ // Process results and, if necessary, proceed to the next column slice.
797
+ // While this pattern may not be the most readable, other ways of writing
798
+ // the loop seemed to result in noticeably worse performance after compilation.
799
+ if (slice_iters == 0) {
800
+ cp_async_wait<0>();
801
+ bool last = slice_idx == slice_count - 1;
802
+ // For per-column scales, we only fetch them here in the final step before
803
+ // write-out
804
+ if (last) {
805
+ if (s_tok_sh_wr_pred) {
806
+ cp_async1(&sh_s_tok[s_tok_sh_wr], &s_tok[s_tok_gl_rd]);
807
+ }
808
+ if (s_ch_sh_wr_pred) {
809
+ cp_async4(&sh_s_ch[s_ch_sh_wr], &s_ch[s_ch_gl_rd]);
810
+ }
811
+ cp_async_fence();
812
+ }
813
+ thread_block_reduce();
814
+ if (last) {
815
+ cp_async_wait<0>();
816
+ __syncthreads();
817
+ if (threadIdx.x / 32 < thread_n_blocks / 4) {
818
+ #pragma unroll
819
+ for (int i = 0; i < thread_m_blocks; i++) {
820
+ frag_s_tok[i][0] =
821
+ *reinterpret_cast<float*>(&sh_s_tok[16 * i + 2 * s_tok_sh_rd]);
822
+ frag_s_tok[i][1] = *reinterpret_cast<float*>(
823
+ &sh_s_tok[16 * i + 2 * s_tok_sh_rd + 1]);
824
+ }
825
+ reinterpret_cast<int4*>(&frag_s_ch)[0] = sh_s_ch[s_ch_sh_rd + 0];
826
+ reinterpret_cast<int4*>(&frag_s_ch)[1] = sh_s_ch[s_ch_sh_rd + 1];
827
+ reinterpret_cast<int4*>(&frag_s_ch)[2] = sh_s_ch[s_ch_sh_rd + 8];
828
+ reinterpret_cast<int4*>(&frag_s_ch)[3] = sh_s_ch[s_ch_sh_rd + 9];
829
+ }
830
+ }
831
+ if (slice_count > 1) { // only globally reduce if there is more than one
832
+ // block in a slice
833
+ barrier_acquire(&locks[slice_col], slice_idx);
834
+ global_reduce(slice_idx == 0, last);
835
+ barrier_release(&locks[slice_col], last);
836
+ }
837
+ if (last) // only the last block in a slice actually writes the result
838
+ write_result();
839
+ slice_row = 0;
840
+ slice_col_par++;
841
+ slice_col++;
842
+ init_slice();
843
+ if (slice_iters) {
844
+ a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
845
+ (threadIdx.x % a_gl_rd_delta_o);
846
+ #pragma unroll
847
+ for (int i = 0; i < b_sh_wr_iters; i++)
848
+ B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
849
+ if (slice_col == 0) {
850
+ #pragma unroll
851
+ for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
852
+ }
853
+ s_group_gl_rd = s_group_sh_stride * slice_col + threadIdx.x;
854
+ s_ch_gl_rd = s_ch_sh_stride * slice_col + threadIdx.x;
855
+ start_pipes();
856
+ }
857
+ }
858
+ }
859
+ }
860
+
861
+ #else
862
+
863
+ template <const int threads, // number of threads in a threadblock
864
+ const int thread_m_blocks, // number of 16x16 blocks in the m
865
+ // dimension (batchsize) of the
866
+ // threadblock
867
+ const int thread_n_blocks, // same for n dimension (output)
868
+ const int thread_k_blocks, // same for k dimension (reduction)
869
+ const int stages, // number of stages for the async global->shared
870
+ // fetch pipeline
871
+ const int group_blocks = -1 // number of consecutive 16x16 blocks
872
+ // with a separate quantization scale
873
+ >
874
+ __global__ void Marlin(
875
+ const int4* __restrict__ A, // int8 input matrix of shape mxk
876
+ const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
877
+ int4* __restrict__ C, // int32 global_reduce buffer of shape
878
+ // (max_par*16*4)xn, as int8 tensor core's output is
879
+ // int32 dtype
880
+ int4* __restrict__ D, // fp16 output buffer of shape mxn
881
+ const float* __restrict__ s_tok, // fp32 activation per-token quantization
882
+ // scales of shape mx1
883
+ const int4* __restrict__ s_ch, // fp32 weight per-channel quantization
884
+ // scales of shape 1xn
885
+ const int4* __restrict__ s_group, // fp16 weight per-group quantization
886
+ // scales of shape (k/groupsize)xn, when
887
+ // group_blocks=-1, it should be nullptr
888
+ int prob_m, // batch dimension m
889
+ int prob_n, // output dimension n
890
+ int prob_k, // reduction dimension k
891
+ int* locks // extra global storage for barrier synchronization
892
+ ) {
893
+ // Marlin is not implemented yet for SM < 8.0
894
+ assert(false);
895
+ return;
896
+ }
897
+
898
+ #endif
899
+
900
+ // 8 warps are a good choice since every SM has 4 schedulers and having more
901
+ // than 1 warp per scheduler allows some more latency hiding. At the same time,
902
+ // we want relatively few warps to have many registers per warp and small tiles.
903
+ const int USER_THREADS =
904
+ 256; // Note: This is only used with user-provided thread_k/n
905
+ const int STAGES = 4; // 4 pipeline stages fit into shared memory
906
+
907
+ static constexpr int min_thread_n = 64;
908
+ static constexpr int min_thread_k = 64;
909
+
910
+ static constexpr int tile_size = 16;
911
+ static constexpr int max_par = 16;
912
+
913
+ static constexpr int pack_factor_4bit =
914
+ 8; // We have 8 4-bit vals inside a 32 bit
915
+
916
+ #define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
917
+ GROUP_BLOCKS, NUM_THREADS) \
918
+ else if (thread_m_blocks == THREAD_M_BLOCKS && \
919
+ thread_n_blocks == THREAD_N_BLOCKS && \
920
+ thread_k_blocks == THREAD_K_BLOCKS && \
921
+ group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
922
+ cudaFuncSetAttribute(Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
923
+ THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \
924
+ cudaFuncAttributeMaxDynamicSharedMemorySize, \
925
+ max_shared_mem); \
926
+ Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
927
+ STAGES, GROUP_BLOCKS> \
928
+ <<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
929
+ A_ptr, B_ptr, C_ptr, D_ptr, s_tok_ptr, s_ch_ptr, s_group_ptr, \
930
+ prob_m, prob_n, prob_k, locks); \
931
+ }
932
+
933
+ typedef struct {
934
+ int thread_k;
935
+ int thread_n;
936
+ int num_threads;
937
+ } thread_config_t;
938
+
939
+ thread_config_t small_batch_thread_configs[] = {
940
+ // Ordered by priority
941
+
942
+ // thread_k, thread_n, num_threads
943
+ {128, 128, 256}, // Default
944
+ {128, 64, 128}, // Reduce N 2X, same K
945
+ {64, 256, 256}, // Reduce K 2X, increase N 2X
946
+ {64, 128, 128}, // Reduce K 2X, same N
947
+ };
948
+
949
+ thread_config_t large_batch_thread_configs[] = {
950
+ // Ordered by priority
951
+
952
+ // thread_k, thread_n, num_threads
953
+ {64, 256, 256}, // Default
954
+ {128, 128, 256}, // Reduce N 2X, increase K 2X
955
+ {64, 128, 128}, // Reduce N 2X, same K
956
+ {128, 64, 128}, // Reduce N 4X, increase K 2X
957
+ };
958
+
959
+ bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n,
960
+ int prob_k) {
961
+ // Sanity
962
+ if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
963
+ th_config.num_threads == -1) {
964
+ return false;
965
+ }
966
+
967
+ // Verify K/N are divisible by thread K/N
968
+ if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
969
+ return false;
970
+ }
971
+
972
+ // thread_k can only be 128 or 64 (because it must not be larger than the groupsize,
973
+ // which is 128)
974
+ if (th_config.thread_k != 128 && th_config.thread_k != 64) {
975
+ return false;
976
+ }
977
+
978
+ // Verify min for thread K/N
979
+ if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
980
+ return false;
981
+ }
982
+
983
+ // num_threads must be at least 128 (= 4 warps)
984
+ if (th_config.num_threads < 128) {
985
+ return false;
986
+ }
987
+
988
+ return true;
989
+ }
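+ // Example added for clarity (not in the upstream source): for
+ // MKN = [16, 4096, 4096] the default small-batch config
+ // {thread_k = 128, thread_n = 128, num_threads = 256} passes all checks:
+ // 4096 % 128 == 0 for both K and N, thread_k is one of {64, 128}, both
+ // dimensions meet the 64 minimum, and num_threads >= 128.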
990
+
991
+ thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) {
992
+ if (prob_m <= 16) {
993
+ for (auto th_config : small_batch_thread_configs) {
994
+ if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
995
+ return th_config;
996
+ }
997
+ }
998
+
999
+ } else {
1000
+ for (auto th_config : large_batch_thread_configs) {
1001
+ if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
1002
+ return th_config;
1003
+ }
1004
+ }
1005
+ }
1006
+
1007
+ return thread_config_t{-1, -1, -1};
1008
+ }
1009
+
1010
+ #define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \
1011
+ __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
1012
+ __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
1013
+ __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
1014
+ __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
1015
+ __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
1016
+ __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
1017
+ __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
1018
+ __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
1019
+ __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
1020
+ __CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)
1021
+
1022
+ void marlin_qqq_cuda(const void* A, const void* B, void* C, void* D,
1023
+ void* s_tok, void* s_ch, void* s_group, int prob_m,
1024
+ int prob_n, int prob_k, void* workspace,
1025
+ int groupsize = -1, int dev = 0, cudaStream_t stream = 0,
1026
+ int thread_k = -1, int thread_n = -1, int sms = -1,
1027
+ int max_par = 16) {
1028
+ int tot_m = prob_m;
1029
+ int tot_m_blocks = ceildiv(tot_m, 16);
1030
+ int pad = 16 * tot_m_blocks - tot_m;
1031
+
1032
+ if (sms == -1)
1033
+ cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
1034
+
1035
+ int max_shared_mem = 0;
1036
+ cudaDeviceGetAttribute(&max_shared_mem,
1037
+ cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
1038
+ TORCH_CHECK(max_shared_mem > 0);
1039
+
1040
+ // Set thread config
1041
+ thread_config_t th_config;
1042
+ if (thread_k != -1 && thread_n != -1) {
1043
+ // User-defined config
1044
+ th_config = thread_config_t{thread_k, thread_n, USER_THREADS};
1045
+ } else {
1046
+ // Auto config
1047
+ th_config = determine_thread_config(prob_m, prob_n, prob_k);
1048
+ }
1049
+
1050
+ if (!is_valid_config(th_config, prob_m, prob_n, prob_k)) {
1051
+ throw std::runtime_error(
1052
+ "Invalid thread config: thread_k = " + str(th_config.thread_k) +
1053
+ ", thread_n = " + str(th_config.thread_n) +
1054
+ ", num_threads = " + str(th_config.num_threads) + " for MKN = [" +
1055
+ str(prob_m) + ", " + str(prob_k) + ", " + str(prob_n) + "]");
1056
+ }
1057
+
1058
+ int num_threads = th_config.num_threads;
1059
+ thread_k = th_config.thread_k;
1060
+ thread_n = th_config.thread_n;
1061
+
1062
+ int thread_k_blocks = thread_k / 16;
1063
+ int thread_n_blocks = thread_n / 16;
1064
+ int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
1065
+ int blocks = sms;
1066
+
1067
+ if (prob_m == 0 || prob_n == 0 || prob_k == 0) {
1068
+ return;
1069
+ }
1070
+
1071
+ TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
1072
+ " is not divisible by thread_n = ", thread_n);
1073
+ TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
1074
+ " is not divisible by thread_k = ", thread_k);
1075
+ if (group_blocks != -1) {
1076
+ TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
1077
+ " is not divisible by group_blocks = ", group_blocks);
1078
+ }
1079
+
1080
+ const int4* A_ptr = (const int4*)A;
1081
+ const int4* B_ptr = (const int4*)B;
1082
+ int4* C_ptr = (int4*)C;
1083
+ int4* D_ptr = (int4*)D;
1084
+ const float* s_tok_ptr = (const float*)s_tok;
1085
+ const int4* s_ch_ptr = (const int4*)s_ch;
1086
+ const int4* s_group_ptr = (const int4*)s_group;
1087
+
1088
+ int* locks = (int*)workspace;
1089
+
1090
+ for (int i = 0; i < tot_m_blocks; i += 4) {
1091
+ int thread_m_blocks = tot_m_blocks - i;
1092
+ prob_m = tot_m - 16 * i;
1093
+ int par = 1;
1094
+ if (thread_m_blocks > 4) {
1095
+ // Note that parallel > 1 currently only works for inputs without any
1096
+ // padding
1097
+ par = (16 * thread_m_blocks - pad) / 64;
1098
+ if (par > max_par) par = max_par;
1099
+ prob_m = 64 * par;
1100
+ i += 4 * (par - 1);
1101
+ thread_m_blocks = 4;
1102
+ }
1103
+
1104
+ // For compilation speed, we only define the kernel configurations that have
1105
+ // seemed useful (in terms of performance) in our testing; however, many more
1106
+ // are, in principle, possible.
1107
+ if (false) {
1108
+ }
1109
+ CALL_IF(8, 8, 256)
1110
+ CALL_IF(16, 4, 256)
1111
+ CALL_IF(8, 4, 128)
1112
+ CALL_IF(4, 8, 128)
1113
+ else {
1114
+ throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) +
1115
+ ", " + str(prob_k) + ", " + str(prob_n) + "]" +
1116
+ ", groupsize = " + str(groupsize) +
1117
+ ", thread_m_blocks = " + str(thread_m_blocks) +
1118
+ ", thread_n_blocks = " + str(thread_n_blocks) +
1119
+ ", thread_k_blocks = " + str(thread_k_blocks));
1120
+ }
1121
+
1122
+ A_ptr += 16 * thread_m_blocks * (prob_k / 16) * par;
1123
+ D_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
1124
+ s_tok_ptr += 16 * thread_m_blocks * par;
1125
+ }
1126
+ }
1127
+ } // anonymous namespace
1128
+
1129
+ torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
1130
+ torch::Tensor const& b_q_weight,
1131
+ torch::Tensor const& s_tok,
1132
+ torch::Tensor const& s_ch,
1133
+ torch::Tensor const& s_group,
1134
+ torch::Tensor& workspace, int64_t size_m,
1135
+ int64_t size_n, int64_t size_k) {
1136
+ // Verify M
1137
+ TORCH_CHECK(size_m == a.size(0),
1138
+ "Shape mismatch: a.size(0) = " + str(a.size(0)) +
1139
+ ", size_m = " + str(size_m));
1140
+ TORCH_CHECK(size_m == s_tok.numel(),
1141
+ "Shape mismatch: s_tok.numel() = " + str(s_tok.numel()) +
1142
+ ", size_m = " + str(size_m));
1143
+
1144
+ // Verify K
1145
+ TORCH_CHECK(size_k == a.size(1),
1146
+ "Shape mismatch: a.size(1) = " + str(a.size(1)) +
1147
+ ", size_k = " + str(size_k));
1148
+ TORCH_CHECK(size_k % tile_size == 0,
1149
+ "size_k = " + str(size_k) +
1150
+ " is not divisible by tile_size = " + str(tile_size));
1151
+ TORCH_CHECK(
1152
+ (size_k / tile_size) == b_q_weight.size(0),
1153
+ "Shape mismatch: b_q_weight.size(0) = " + str(b_q_weight.size(0)) +
1154
+ ", size_k = " + str(size_k) + ", tile_size = " + str(tile_size));
1155
+
1156
+ int groupsize = (s_group.numel() == 0) ? -1 : size_k / s_group.size(0);
1157
+ // Verify groupsize
1158
+ TORCH_CHECK(groupsize == -1 || groupsize == 128,
1159
+ "Unexpected groupsize = " + str(groupsize));
1160
+
1161
+ // Verify N
1162
+ TORCH_CHECK(s_ch.numel() == size_n,
1163
+ "Shape mismatch: s_ch.numel() = " + str(s_ch.numel()) +
1164
+ ", size_n = " + str(size_n));
1165
+ TORCH_CHECK(b_q_weight.size(1) % tile_size == 0,
1166
+ "b_q_weight.size(1) = " + str(b_q_weight.size(1)) +
1167
+ " is not divisible by tile_size = " + str(tile_size));
1168
+ if (groupsize != -1) {
1169
+ TORCH_CHECK(s_group.size(1) == size_n,
1170
+ "Shape mismatch: s_group.size(1) = " + str(s_group.size(1)) +
1171
+ ", size_n = " + str(size_n));
1172
+ TORCH_CHECK(
1173
+ size_k % s_group.size(0) == 0,
1174
+ "size_k = " + str(size_k) +
1175
+ ", is not divisible by s_group.size(0) = " + str(s_group.size(0)));
1176
+ }
1177
+
1178
+ int actual_size_n = (b_q_weight.size(1) / tile_size) * pack_factor_4bit;
1179
+ TORCH_CHECK(size_n == actual_size_n,
1180
+ "Shape mismatch: size_n = " + str(size_n) +
1181
+ ", actual_size_n = " + str(actual_size_n));
1182
+
1183
+ // Verify A device and strides
1184
+ TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
1185
+ TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
1186
+
1187
+ // Verify B device and strides
1188
+ TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
1189
+ TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
1190
+
1191
+ // Verify s_tok device, strides and dtype
1192
+ TORCH_CHECK(s_tok.device().is_cuda(), "s_tok is not on GPU");
1193
+ TORCH_CHECK(s_tok.is_contiguous(), "s_tok is not contiguous");
1194
+ TORCH_CHECK(s_tok.dtype() == torch::kFloat32, "s_tok's dtype is not float32");
1195
+
1196
+ // Verify s_ch device, strides and dtype
1197
+ TORCH_CHECK(s_ch.device().is_cuda(), "s_ch is not on GPU");
1198
+ TORCH_CHECK(s_ch.is_contiguous(), "s_ch is not contiguous");
1199
+ TORCH_CHECK(s_ch.dtype() == torch::kFloat32, "s_ch's dtype is not float32");
1200
+
1201
+ // Verify s_group device, strides and dtype
1202
+ TORCH_CHECK(s_group.device().is_cuda(), "s_group is not on GPU");
1203
+ TORCH_CHECK(s_group.is_contiguous(), "s_group is not contiguous");
1204
+ TORCH_CHECK(s_group.dtype() == torch::kFloat16,
1205
+ "s_group's dtype is not float16");
1206
+
1207
+ // Verify workspace size
1208
+ TORCH_CHECK(size_n % min_thread_n == 0,
1209
+ "size_n = " + str(size_n) +
1210
+ ", is not divisible by min_thread_n = " + str(min_thread_n));
1211
+ int min_workspace_size = (size_n / min_thread_n) * max_par;
1212
+ TORCH_CHECK(workspace.numel() >= min_workspace_size,
1213
+ "workspace.numel = " + str(workspace.numel()) +
1214
+ " is below min_workspace_size = " + str(min_workspace_size));
1215
+
1216
+ // Alloc C matrix
1217
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
1218
+ auto options_c = torch::TensorOptions().dtype(torch::kInt).device(a.device());
1219
+ torch::Tensor c = torch::empty({max_par * 64, size_n}, options_c);
1220
+
1221
+ // Alloc D matrix
1222
+ auto options_d =
1223
+ torch::TensorOptions().dtype(torch::kFloat16).device(a.device());
1224
+ torch::Tensor d = torch::empty({size_m, size_n}, options_d);
1225
+
1226
+ // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
1227
+ // auto -1)
1228
+ int thread_k = -1;
1229
+ // thread_n: `n` size of a thread_tile in `weights` (can usually be left as
1230
+ // auto -1)
1231
+ int thread_n = -1;
1232
+ // sms: number of SMs to use for the kernel (can usually be left as auto -1)
1233
+ int sms = -1;
1234
+
1235
+ int dev = a.get_device();
1236
+ marlin_qqq_cuda(
1237
+ a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), d.data_ptr(),
1238
+ s_tok.data_ptr(), s_ch.data_ptr(), s_group.data_ptr(), size_m, size_n,
1239
+ size_k, workspace.data_ptr(), groupsize, dev,
1240
+ at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par);
1241
+
1242
+ return d;
1243
+ }
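For orientation, a minimal host-side sketch of how the new marlin_qqq_gemm entry
point could be driven. The shapes, dtypes and workspace size below are inferred
from the checks in marlin_qqq_gemm above; the concrete sizes, the zero/one fill
values, and calling the C++ function directly (rather than through the
registered torch op) are illustrative assumptions, not part of this commit.

    // Hypothetical usage sketch, assuming a matching QQQ repacking step
    // produced b_q, s_tok, s_ch and s_group.
    int64_t m = 16, n = 4096, k = 4096, groupsize = 128;
    auto i8  = torch::TensorOptions().dtype(torch::kInt8).device(torch::kCUDA);
    auto i32 = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
    auto f32 = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
    auto f16 = torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCUDA);

    torch::Tensor a   = torch::zeros({m, k}, i8);                 // int8 activations, m x k
    torch::Tensor b_q = torch::zeros({k / 16, n * 16 / 8}, i32);  // packed 4-bit weights:
                                                                  // size(0) = k / tile_size,
                                                                  // size(1) = n * tile_size / pack_factor_4bit
    torch::Tensor s_tok   = torch::ones({m, 1}, f32);             // fp32 per-token activation scales
    torch::Tensor s_ch    = torch::ones({1, n}, f32);             // fp32 per-channel weight scales
    torch::Tensor s_group = torch::ones({k / groupsize, n}, f16); // fp16 per-group weight scales
    // Lock buffer for the global reduction; needs at least
    // (n / min_thread_n) * max_par int32 entries and should start zeroed.
    torch::Tensor workspace = torch::zeros({(n / 64) * 16}, i32);

    torch::Tensor d = marlin_qqq_gemm(a, b_q, s_tok, s_ch, s_group, workspace, m, n, k);
    // d is the fp16 result of shape m x n.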
marlin/sparse/LICENSE ADDED
@@ -0,0 +1,203 @@
1
+ Contains code from https://github.com/IST-DASLab/Sparse-Marlin/
2
+
3
+ Apache License
4
+ Version 2.0, January 2004
5
+ http://www.apache.org/licenses/
6
+
7
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
8
+
9
+ 1. Definitions.
10
+
11
+ "License" shall mean the terms and conditions for use, reproduction,
12
+ and distribution as defined by Sections 1 through 9 of this document.
13
+
14
+ "Licensor" shall mean the copyright owner or entity authorized by
15
+ the copyright owner that is granting the License.
16
+
17
+ "Legal Entity" shall mean the union of the acting entity and all
18
+ other entities that control, are controlled by, or are under common
19
+ control with that entity. For the purposes of this definition,
20
+ "control" means (i) the power, direct or indirect, to cause the
21
+ direction or management of such entity, whether by contract or
22
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
23
+ outstanding shares, or (iii) beneficial ownership of such entity.
24
+
25
+ "You" (or "Your") shall mean an individual or Legal Entity
26
+ exercising permissions granted by this License.
27
+
28
+ "Source" form shall mean the preferred form for making modifications,
29
+ including but not limited to software source code, documentation
30
+ source, and configuration files.
31
+
32
+ "Object" form shall mean any form resulting from mechanical
33
+ transformation or translation of a Source form, including but
34
+ not limited to compiled object code, generated documentation,
35
+ and conversions to other media types.
36
+
37
+ "Work" shall mean the work of authorship, whether in Source or
38
+ Object form, made available under the License, as indicated by a
39
+ copyright notice that is included in or attached to the work
40
+ (an example is provided in the Appendix below).
41
+
42
+ "Derivative Works" shall mean any work, whether in Source or Object
43
+ form, that is based on (or derived from) the Work and for which the
44
+ editorial revisions, annotations, elaborations, or other modifications
45
+ represent, as a whole, an original work of authorship. For the purposes
46
+ of this License, Derivative Works shall not include works that remain
47
+ separable from, or merely link (or bind by name) to the interfaces of,
48
+ the Work and Derivative Works thereof.
49
+
50
+ "Contribution" shall mean any work of authorship, including
51
+ the original version of the Work and any modifications or additions
52
+ to that Work or Derivative Works thereof, that is intentionally
53
+ submitted to Licensor for inclusion in the Work by the copyright owner
54
+ or by an individual or Legal Entity authorized to submit on behalf of
55
+ the copyright owner. For the purposes of this definition, "submitted"
56
+ means any form of electronic, verbal, or written communication sent
57
+ to the Licensor or its representatives, including but not limited to
58
+ communication on electronic mailing lists, source code control systems,
59
+ and issue tracking systems that are managed by, or on behalf of, the
60
+ Licensor for the purpose of discussing and improving the Work, but
61
+ excluding communication that is conspicuously marked or otherwise
62
+ designated in writing by the copyright owner as "Not a Contribution."
63
+
64
+ "Contributor" shall mean Licensor and any individual or Legal Entity
65
+ on behalf of whom a Contribution has been received by Licensor and
66
+ subsequently incorporated within the Work.
67
+
68
+ 2. Grant of Copyright License. Subject to the terms and conditions of
69
+ this License, each Contributor hereby grants to You a perpetual,
70
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
71
+ copyright license to reproduce, prepare Derivative Works of,
72
+ publicly display, publicly perform, sublicense, and distribute the
73
+ Work and such Derivative Works in Source or Object form.
74
+
75
+ 3. Grant of Patent License. Subject to the terms and conditions of
76
+ this License, each Contributor hereby grants to You a perpetual,
77
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
78
+ (except as stated in this section) patent license to make, have made,
79
+ use, offer to sell, sell, import, and otherwise transfer the Work,
80
+ where such license applies only to those patent claims licensable
81
+ by such Contributor that are necessarily infringed by their
82
+ Contribution(s) alone or by combination of their Contribution(s)
83
+ with the Work to which such Contribution(s) was submitted. If You
84
+ institute patent litigation against any entity (including a
85
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
86
+ or a Contribution incorporated within the Work constitutes direct
87
+ or contributory patent infringement, then any patent licenses
88
+ granted to You under this License for that Work shall terminate
89
+ as of the date such litigation is filed.
90
+
91
+ 4. Redistribution. You may reproduce and distribute copies of the
92
+ Work or Derivative Works thereof in any medium, with or without
93
+ modifications, and in Source or Object form, provided that You
94
+ meet the following conditions:
95
+
96
+ (a) You must give any other recipients of the Work or
97
+ Derivative Works a copy of this License; and
98
+
99
+ (b) You must cause any modified files to carry prominent notices
100
+ stating that You changed the files; and
101
+
102
+ (c) You must retain, in the Source form of any Derivative Works
103
+ that You distribute, all copyright, patent, trademark, and
104
+ attribution notices from the Source form of the Work,
105
+ excluding those notices that do not pertain to any part of
106
+ the Derivative Works; and
107
+
108
+ (d) If the Work includes a "NOTICE" text file as part of its
109
+ distribution, then any Derivative Works that You distribute must
110
+ include a readable copy of the attribution notices contained
111
+ within such NOTICE file, excluding those notices that do not
112
+ pertain to any part of the Derivative Works, in at least one
113
+ of the following places: within a NOTICE text file distributed
114
+ as part of the Derivative Works; within the Source form or
115
+ documentation, if provided along with the Derivative Works; or,
116
+ within a display generated by the Derivative Works, if and
117
+ wherever such third-party notices normally appear. The contents
118
+ of the NOTICE file are for informational purposes only and
119
+ do not modify the License. You may add Your own attribution
120
+ notices within Derivative Works that You distribute, alongside
121
+ or as an addendum to the NOTICE text from the Work, provided
122
+ that such additional attribution notices cannot be construed
123
+ as modifying the License.
124
+
125
+ You may add Your own copyright statement to Your modifications and
126
+ may provide additional or different license terms and conditions
127
+ for use, reproduction, or distribution of Your modifications, or
128
+ for any such Derivative Works as a whole, provided Your use,
129
+ reproduction, and distribution of the Work otherwise complies with
130
+ the conditions stated in this License.
131
+
132
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
133
+ any Contribution intentionally submitted for inclusion in the Work
134
+ by You to the Licensor shall be under the terms and conditions of
135
+ this License, without any additional terms or conditions.
136
+ Notwithstanding the above, nothing herein shall supersede or modify
137
+ the terms of any separate license agreement you may have executed
138
+ with Licensor regarding such Contributions.
139
+
140
+ 6. Trademarks. This License does not grant permission to use the trade
141
+ names, trademarks, service marks, or product names of the Licensor,
142
+ except as required for reasonable and customary use in describing the
143
+ origin of the Work and reproducing the content of the NOTICE file.
144
+
145
+ 7. Disclaimer of Warranty. Unless required by applicable law or
146
+ agreed to in writing, Licensor provides the Work (and each
147
+ Contributor provides its Contributions) on an "AS IS" BASIS,
148
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
149
+ implied, including, without limitation, any warranties or conditions
150
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
151
+ PARTICULAR PURPOSE. You are solely responsible for determining the
152
+ appropriateness of using or redistributing the Work and assume any
153
+ risks associated with Your exercise of permissions under this License.
154
+
155
+ 8. Limitation of Liability. In no event and under no legal theory,
156
+ whether in tort (including negligence), contract, or otherwise,
157
+ unless required by applicable law (such as deliberate and grossly
158
+ negligent acts) or agreed to in writing, shall any Contributor be
159
+ liable to You for damages, including any direct, indirect, special,
160
+ incidental, or consequential damages of any character arising as a
161
+ result of this License or out of the use or inability to use the
162
+ Work (including but not limited to damages for loss of goodwill,
163
+ work stoppage, computer failure or malfunction, or any and all
164
+ other commercial damages or losses), even if such Contributor
165
+ has been advised of the possibility of such damages.
166
+
167
+ 9. Accepting Warranty or Additional Liability. While redistributing
168
+ the Work or Derivative Works thereof, You may choose to offer,
169
+ and charge a fee for, acceptance of support, warranty, indemnity,
170
+ or other liability obligations and/or rights consistent with this
171
+ License. However, in accepting such obligations, You may act only
172
+ on Your own behalf and on Your sole responsibility, not on behalf
173
+ of any other Contributor, and only if You agree to indemnify,
174
+ defend, and hold each Contributor harmless for any liability
175
+ incurred by, or claims asserted against, such Contributor by reason
176
+ of your accepting any such warranty or additional liability.
177
+
178
+ END OF TERMS AND CONDITIONS
179
+
180
+ APPENDIX: How to apply the Apache License to your work.
181
+
182
+ To apply the Apache License to your work, attach the following
183
+ boilerplate notice, with the fields enclosed by brackets "[]"
184
+ replaced with your own identifying information. (Don't include
185
+ the brackets!) The text should be enclosed in the appropriate
186
+ comment syntax for the file format. We also recommend that a
187
+ file or class name and description of purpose be included on the
188
+ same "printed page" as the copyright notice for easier
189
+ identification within third-party archives.
190
+
191
+ Copyright [yyyy] [name of copyright owner]
192
+
193
+ Licensed under the Apache License, Version 2.0 (the "License");
194
+ you may not use this file except in compliance with the License.
195
+ You may obtain a copy of the License at
196
+
197
+ http://www.apache.org/licenses/LICENSE-2.0
198
+
199
+ Unless required by applicable law or agreed to in writing, software
200
+ distributed under the License is distributed on an "AS IS" BASIS,
201
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
202
+ See the License for the specific language governing permissions and
203
+ limitations under the License.
marlin/sparse/common/base.h ADDED
@@ -0,0 +1,51 @@
1
+ /*
2
+ * Copyright (C) 2024 Roberto Lopez Castro ([email protected]). All
3
+ * Rights Reserved.
4
+ *
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
16
+ */
17
+
18
+ #pragma once
19
+
20
+ namespace marlin_24 {
21
+
22
+ constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
23
+
24
+ // Instances of `Vec` are used to organize groups of >>registers<<, as needed
25
+ // for instance as inputs to tensor core operations. Consequently, all
26
+ // corresponding index accesses must be compile-time constants, which is why we
27
+ // extensively use `#pragma unroll` throughout the kernel code to guarantee
28
+ // this.
29
+ template <typename T, int n>
30
+ struct Vec {
31
+ T elems[n];
32
+ __device__ T& operator[](int i) { return elems[i]; }
33
+ };
34
+
35
+ template <int M_, int N_, int K_>
36
+ struct ShapeBase {
37
+ static constexpr int M = M_, N = N_, K = K_;
38
+ };
39
+
40
+ using I4 = Vec<int, 4>;
41
+
42
+ // Matrix fragments for tensor core instructions; their precise layout is
43
+ // documented here:
44
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
45
+ using FragA = Vec<half2, 4>;
46
+ using FragB = Vec<half2, 2>;
47
+ using FragM = Vec<uint, 1>;
48
+ using FragC = Vec<float, 4>;
49
+ using FragS = Vec<half2, 1>; // quantization scales
50
+
51
+ } // namespace marlin_24
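A quick illustration of the constraint stated in the comment above: fragment elements live in registers, so they must always be indexed with compile-time constants. The following sketch is not part of the commit (the function name is invented; the include path matches the layout declared in build.toml):

// Sketch: constant indices keep the Vec elements in registers, and ceildiv is
// the round-up division used throughout for tile counts.
#include <cuda_fp16.h>
#include "marlin/sparse/common/base.h"

__device__ void scale_frag_sketch(marlin_24::FragB& b, __half s) {
  half2 s2 = __half2half2(s);  // broadcast one fp16 scale to both lanes
  b[0] = __hmul2(b[0], s2);    // literal indices -> no local-memory spills
  b[1] = __hmul2(b[1], s2);
}

static_assert(marlin_24::ceildiv(100, 16) == 7,
              "100 elements need 7 tiles of 16");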
marlin/sparse/common/mem.h ADDED
@@ -0,0 +1,136 @@
1
+ /*
2
+ * Copyright (C) 2024 Roberto Lopez Castro ([email protected]). All
3
+ * Rights Reserved.
4
+ *
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
16
+ */
17
+
18
+ #pragma once
19
+ #include "base.h"
20
+
21
+ namespace marlin_24 {
22
+ // Predicated asynchronous global->shared copy; used for inputs A where we apply
23
+ // predication to handle batchsizes that are not multiples of 16.
24
+ __device__ inline void cp_async4_pred_zfill(void* smem_ptr,
25
+ const void* glob_ptr,
26
+ bool pred = true,
27
+ const bool zfill = false) {
28
+ const int BYTES = 16;
29
+ int src_in_bytes = (zfill ? 0 : BYTES);
30
+ uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
31
+ asm volatile(
32
+ "{\n"
33
+ " .reg .pred p;\n"
34
+ " setp.ne.b32 p, %0, 0;\n"
35
+ " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
36
+ "}\n" ::"r"((int)pred),
37
+ "r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes));
38
+ }
39
+
40
+ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
41
+ bool pred = true) {
42
+ const int BYTES = 16;
43
+ uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
44
+ asm volatile(
45
+ "{\n"
46
+ " .reg .pred p;\n"
47
+ " setp.ne.b32 p, %0, 0;\n"
48
+ " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
49
+ "}\n" ::"r"((int)pred),
50
+ "r"(smem), "l"(glob_ptr), "n"(BYTES));
51
+ }
52
+
53
+ // Asynchronous global->shared copy
54
+ __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
55
+ const int BYTES = 16;
56
+ uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
57
+ asm volatile(
58
+ "{\n"
59
+ " cp.async.cg.shared.global [%0], [%1], %2;\n"
60
+ "}\n" ::"r"(smem),
61
+ "l"(glob_ptr), "n"(BYTES));
62
+ }
63
+
64
+ // Async copy fence.
65
+ __device__ inline void cp_async_fence() {
66
+ asm volatile("cp.async.commit_group;\n" ::);
67
+ }
68
+
69
+ // Wait until at most `n` async copy stages are still pending.
70
+ template <int n>
71
+ __device__ inline void cp_async_wait() {
72
+ asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
73
+ }
74
+
75
+ // Instruction for loading a full 16x16 matrix fragment of operand A from shared
76
+ // memory, directly in tensor core layout.
77
+ __device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
78
+ uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
79
+ uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
80
+ asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
81
+ : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
82
+ : "r"(smem));
83
+ }
84
+
85
+ __device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) {
86
+ uint32_t* a = reinterpret_cast<uint32_t*>(&frag_m);
87
+ uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
88
+ asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n"
89
+ : "=r"(a[0]), "=r"(a[1])
90
+ : "r"(smem));
91
+ }
92
+
93
+ // Instruction for loading a full 16x16 matrix fragment of operand A from shared
94
+ // memory, directly in tensor core layout.
95
+ __device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) {
96
+ uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
97
+ uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
98
+ asm volatile(
99
+ "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n"
100
+ : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
101
+ : "r"(smem));
102
+ }
103
+
104
+ // Wait until barrier reaches `count`, then lock for current threadblock.
105
+ __device__ inline void barrier_acquire(int* lock, int count) {
106
+ if (threadIdx.x == 0) {
107
+ int state = -1;
108
+ do
109
+ // Guarantee that subsequent writes by this threadblock will be visible
110
+ // globally.
111
+ asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
112
+ : "=r"(state)
113
+ : "l"(lock));
114
+ while (state != count);
115
+ }
116
+ __syncthreads();
117
+ }
118
+
119
+ // Release barrier and increment visitation count.
120
+ __device__ inline void barrier_release(int* lock, bool reset = false) {
121
+ __syncthreads();
122
+ if (threadIdx.x == 0) {
123
+ if (reset) {
124
+ lock[0] = 0;
125
+ return;
126
+ }
127
+ int val = 1;
128
+ // Make sure that all writes since acquiring this barrier are visible
129
+ // globally, while releasing the barrier.
130
+ asm volatile("fence.acq_rel.gpu;\n");
131
+ asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
132
+ :
133
+ : "l"(lock), "r"(val));
134
+ }
135
+ }
136
+ } // namespace marlin_24
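Taken together, these helpers are the building blocks of the kernel's software pipeline: each tile's global->shared copies are committed as one group with cp_async_fence(), cp_async_wait<n>() later blocks until at most n groups are still in flight, and a __syncthreads() then makes the tile visible to the whole block. A minimal double-buffer sketch of that idiom, not part of the commit (the names and the two-stage layout are invented; the real multi-stage pipeline is in Marlin_24 in marlin_24_cuda_kernel.cu):

// Sketch: stream n_tiles tiles of TILE int4 elements through a two-stage
// shared-memory buffer using the cp.async helpers defined above.
template <int TILE>
__device__ void double_buffer_sketch(const int4* gmem,
                                     int4* smem /* shared, 2 * TILE int4s */,
                                     int n_tiles) {
  auto issue = [&](int tile, int stage) {
    for (int i = threadIdx.x; i < TILE; i += blockDim.x)
      marlin_24::cp_async4(&smem[stage * TILE + i], &gmem[tile * TILE + i]);
    marlin_24::cp_async_fence();  // commit this tile's copies as one group
  };
  if (n_tiles > 0) issue(0, 0);
  for (int t = 0; t < n_tiles; t++) {
    if (t + 1 < n_tiles) {
      issue(t + 1, (t + 1) & 1);      // prefetch the next tile
      marlin_24::cp_async_wait<1>();  // tile t has landed; the prefetch may still be in flight
    } else {
      marlin_24::cp_async_wait<0>();  // last tile: drain all outstanding copies
    }
    __syncthreads();  // make the freshly copied tile visible to every thread
    // ... consume smem[(t & 1) * TILE ...] here ...
    __syncthreads();  // don't let the next prefetch overwrite a stage still in use
  }
}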
marlin/sparse/common/mma.h ADDED
@@ -0,0 +1,191 @@
1
+ /*
2
+ * Copyright (C) 2024 Roberto Lopez Castro ([email protected]). All
3
+ * Rights Reserved.
4
+ *
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
16
+ */
17
+
18
+ #pragma once
19
+ #include "base.h"
20
+ #include <cudaTypedefs.h>
21
+
22
+ namespace marlin_24 {
23
+
24
+ // On CUDA earlier than 12.5, the ordered_metadata version of this instruction
25
+ // is not supported. On later versions of CUDA the version without ordered
26
+ // metadata results in the following warning:
27
+ // | Advisory: Modifier ‘.sp::ordered_metadata’ should be used on instruction
28
+ // | ‘mma’ instead of modifier ‘.sp’ as it is expected to have substantially
29
+ // | reduced performance on some future architectures
30
+ #if defined CUDA_VERSION && CUDA_VERSION >= 12050
31
+ #define MMA_SP_INST \
32
+ "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
33
+ #else
34
+ #define MMA_SP_INST "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
35
+ #endif
36
+
37
+ // m16n8k32 sparse tensor core mma instruction with fp16 inputs and fp32
38
+ // output/accumulation.
39
+ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
40
+ const FragA& frag_b, FragC& frag_c, FragM& frag_m,
41
+ const int psel) {
42
+ const uint32_t* a0 = reinterpret_cast<const uint32_t*>(&a_frag0);
43
+ const uint32_t* a1 = reinterpret_cast<const uint32_t*>(&a_frag1);
44
+ const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
45
+ const uint32_t* e = reinterpret_cast<const uint32_t*>(&frag_m);
46
+
47
+ float* c = reinterpret_cast<float*>(&frag_c);
48
+ if (psel == 0) {
49
+ asm volatile(MMA_SP_INST
50
+ "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
51
+ "{%12,%13,%14,%15}, %16, 0x0;\n"
52
+ : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
53
+ : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
54
+ "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
55
+ "f"(c[2]), "f"(c[3]), "r"(e[0]));
56
+ asm volatile(MMA_SP_INST
57
+ "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
58
+ "{%12,%13,%14,%15}, %16, 0x0;\n"
59
+ : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
60
+ : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
61
+ "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
62
+ "f"(c[6]), "f"(c[7]), "r"(e[0]));
63
+ } else {
64
+ asm volatile(MMA_SP_INST
65
+ "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
66
+ "{%12,%13,%14,%15}, %16, 0x1;\n"
67
+ : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
68
+ : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
69
+ "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
70
+ "f"(c[2]), "f"(c[3]), "r"(e[0]));
71
+ asm volatile(MMA_SP_INST
72
+ "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
73
+ "{%12,%13,%14,%15}, %16, 0x1;\n"
74
+ : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
75
+ : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
76
+ "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
77
+ "f"(c[6]), "f"(c[7]), "r"(e[0]));
78
+ }
79
+ }
80
+
81
+ // Lookup-table based 3-input logical operation; explicitly used for
82
+ // dequantization as the compiler does not seem to automatically recognize it in
83
+ // all cases.
84
+ template <int lut>
85
+ __device__ inline int lop3(int a, int b, int c) {
86
+ int res;
87
+ asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
88
+ : "=r"(res)
89
+ : "r"(a), "r"(b), "r"(c), "n"(lut));
90
+ return res;
91
+ }
92
+
93
+ __device__ __forceinline__ uint2 to_half4(float c0, float c1, float c2,
94
+ float c3) {
95
+ uint2 r;
96
+ asm("{\n\t"
97
+ ".reg .f16 a, b, c, d; \n\t"
98
+ "cvt.rn.f16.f32 a, %2; \n\t"
99
+ "cvt.rn.f16.f32 b, %3; \n\t"
100
+ "cvt.rn.f16.f32 c, %4; \n\t"
101
+ "cvt.rn.f16.f32 d, %5; \n\t"
102
+ "mov.b32 %0, {a, b}; \n\t"
103
+ "mov.b32 %1, {c, d}; \n\t"
104
+ "}"
105
+ : "=r"(r.x), "=r"(r.y)
106
+ : "f"(c0), "f"(c1), "f"(c2), "f"(c3));
107
+ return r;
108
+ }
109
+
110
+ // Constructs destination register by taking bytes from 2 sources (based on
111
+ // mask)
112
+ template <int start_byte, int mask>
113
+ __device__ inline uint32_t prmt(uint32_t a) {
114
+ uint32_t res;
115
+ asm volatile("prmt.b32 %0, %1, %2, %3;\n"
116
+ : "=r"(res)
117
+ : "r"(a), "n"(start_byte), "n"(mask));
118
+ return res;
119
+ }
120
+
121
+ // Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
122
+ // values. We mostly follow the strategy in the link below, with some small
123
+ // changes:
124
+ // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
125
+ __device__ inline FragB dequant_4bit(int q) {
126
+ const int LO = 0x000f000f;
127
+ const int HI = 0x00f000f0;
128
+ const int EX = 0x64006400;
129
+ // Guarantee that the `(a & b) | c` operations are LOP3s.
130
+ int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
131
+ int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
132
+ // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
133
+ // directly into `SUB` and `ADD`.
134
+ const int SUB = 0x64086408;
135
+ const int MUL = 0x2c002c00;
136
+ const int ADD = 0xd480d480;
137
+
138
+ FragB frag_b;
139
+ frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
140
+ *reinterpret_cast<const half2*>(&SUB));
141
+ frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
142
+ *reinterpret_cast<const half2*>(&MUL),
143
+ *reinterpret_cast<const half2*>(&ADD));
144
+ return frag_b;
145
+ }
146
+
147
+ // Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
148
+ // values. We mostly follow the strategy in the link below, with some small
149
+ // changes:
150
+ // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
151
+ __device__ inline FragB dequant_8bit(int q) {
152
+ static constexpr uint32_t mask_for_elt_01 = 0x5250;
153
+ static constexpr uint32_t mask_for_elt_23 = 0x5351;
154
+ static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
155
+
156
+ uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
157
+ uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
158
+
159
+ static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
160
+
161
+ FragB frag_b;
162
+ frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
163
+ *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
164
+ frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
165
+ *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
166
+ return frag_b;
167
+ }
168
+
169
+ // Multiply dequantized values by the corresponding quantization scale; used
170
+ // only for grouped quantization.
171
+ __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
172
+ half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
173
+ frag_b[0] = __hmul2(frag_b[0], s);
174
+ frag_b[1] = __hmul2(frag_b[1], s);
175
+ }
176
+
177
+ __device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3,
178
+ FragS& s0, float* c4, float* c5, float* c6,
179
+ float* c7, FragS& s1) {
180
+ *c0 = __fmul_rn(*c0, __half2float(s0[0].x));
181
+ *c1 = __fmul_rn(*c1, __half2float(s0[0].y));
182
+ *c2 = __fmul_rn(*c2, __half2float(s0[1].x));
183
+ *c3 = __fmul_rn(*c3, __half2float(s0[1].y));
184
+
185
+ *c4 = __fmul_rn(*c4, __half2float(s1[0].x));
186
+ *c5 = __fmul_rn(*c5, __half2float(s1[0].y));
187
+ *c6 = __fmul_rn(*c6, __half2float(s1[1].x));
188
+ *c7 = __fmul_rn(*c7, __half2float(s1[1].y));
189
+ }
190
+
191
+ } // namespace marlin_24
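The magic constants in dequant_4bit/dequant_8bit above implement the FasterTransformer trick of planting the quantized value in the mantissa of an fp16 number with a fixed exponent and then cancelling the exponent bias together with the zero point in a single __hsub2/__hfma2. A small host-side reference of what the 4-bit path produces for one packed 32-bit word; this sketch is for intuition only and is not part of the commit (the function name is invented):

// Reference for the fused 4-bit dequantization: each half2 result holds
// (nibble - 8) for one nibble of each 16-bit half of q.
//   frag_b[0] <- nibbles at bits [3:0] and [19:16]
//   frag_b[1] <- nibbles at bits [7:4] and [23:20]
#include <cstdint>
#include <cstdio>

static void dequant_4bit_reference(uint32_t q, float out[2][2]) {
  out[0][0] = static_cast<float>((q >> 0) & 0xF) - 8.0f;   // frag_b[0].x
  out[0][1] = static_cast<float>((q >> 16) & 0xF) - 8.0f;  // frag_b[0].y
  out[1][0] = static_cast<float>((q >> 4) & 0xF) - 8.0f;   // frag_b[1].x
  out[1][1] = static_cast<float>((q >> 20) & 0xF) - 8.0f;  // frag_b[1].y
}

int main() {
  float out[2][2];
  dequant_4bit_reference(0x00f500a3u, out);
  // Nibbles 0x3, 0x5 -> -5, -3 and 0xa, 0xf -> 2, 7.
  std::printf("%g %g %g %g\n", out[0][0], out[0][1], out[1][0], out[1][1]);
  return 0;
}

The 8-bit path is analogous: prmt drops each byte into the mantissa of 0x64xx (1024 + byte), and the __hsub2 with 0x6480 (1152) removes the 1024 bias and the 128 zero point in one step.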
marlin/sparse/marlin_24_cuda_kernel.cu ADDED
@@ -0,0 +1,1140 @@
1
+ /*
2
+ * Notice: This file was modified by Neuralmagic inc to include 8-bit support
3
+ *
4
+ * Copyright (C) 2024 Roberto Lopez Castro ([email protected]). All
5
+ * Rights Reserved.
6
+ *
7
+ * Licensed under the Apache License, Version 2.0 (the "License");
8
+ * you may not use this file except in compliance with the License.
9
+ * You may obtain a copy of the License at
10
+ *
11
+ * http://www.apache.org/licenses/LICENSE-2.0
12
+ *
13
+ * Unless required by applicable law or agreed to in writing, software
14
+ * distributed under the License is distributed on an "AS IS" BASIS,
15
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ * See the License for the specific language governing permissions and
17
+ * limitations under the License.
18
+ */
19
+ #include <torch/all.h>
20
+
21
+ #include <ATen/cuda/CUDAContext.h>
22
+ #include <c10/cuda/CUDAGuard.h>
23
+ #include <cuda.h>
24
+ #include <cuda_fp16.h>
25
+ #include <cuda_runtime.h>
26
+
27
+ #include <iostream>
28
+
29
+ #include "common/base.h"
30
+ #include "core/scalar_type.hpp"
31
+
32
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
33
+
34
+ #else
35
+
36
+ #include "common/mem.h"
37
+ #include "common/mma.h"
38
+
39
+ #endif
40
+
41
+ template <typename T>
42
+ inline std::string str(T x) {
43
+ return std::to_string(x);
44
+ }
45
+
46
+ namespace marlin_24 {
47
+
48
+ // 8 warps are a good choice since every SM has 4 schedulers and having more
49
+ // than 1 warp per scheduler allows some more latency hiding. At the same time,
50
+ // we want relatively few warps to have many registers per warp and small tiles.
51
+ static constexpr int THREADS = 256;
52
+ static constexpr int STAGES = 4;
53
+
54
+ static constexpr int min_thread_n = 128;
55
+
56
+ static constexpr int tile_size = 16;
57
+ static constexpr int max_par = 64;
58
+
59
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
60
+
61
+ template <const int num_bits, // weight bits
62
+ const int threads, // number of threads in a threadblock
63
+ const int thread_m_blocks, // number of 16x16 blocks in the m
64
+ // dimension (batchsize) of the
65
+ // threadblock
66
+ const int thread_n_blocks, // same for n dimension (output)
67
+ const int thread_k_blocks, // same for k dimension (reduction)
68
+ const int stages, // number of stages for the async global->shared
69
+ // fetch pipeline
70
+ const int group_blocks = -1 // number of consecutive 16x16 blocks
71
+ // with a separate quantization scale
72
+ >
73
+ __global__ void Marlin_24(
74
+ const int4* __restrict__ A, // fp16 input matrix of shape mxk
75
+ const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
76
+ const int4* __restrict__ meta, // 2bit metadata information about 2:4
77
+ // format on B
78
+ int4* __restrict__ C, // fp16 output buffer of shape mxn
79
+ const int4* __restrict__ s, // fp16 quantization scales of shape
80
+ // (k/groupsize)xn
81
+ int prob_m, // batch dimension m
82
+ int prob_n, // output dimension n
83
+ int prob_k, // reduction dimension k
84
+ int* locks // extra global storage for barrier synchronization
85
+ ) {}
86
+
87
+ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
88
+ torch::Tensor& b_meta,
89
+ torch::Tensor& b_scales,
90
+ torch::Tensor& workspace,
91
+ vllm::ScalarTypeId const b_q_type_id,
92
+ int64_t size_m, int64_t size_n,
93
+ int64_t size_k) {
94
+ TORCH_CHECK_NOT_IMPLEMENTED(
95
+ false, "gptq_marlin_24_gemm(..) requires CUDA_ARCH >= 8.0");
96
+ return torch::empty({1, 1});
97
+ }
98
+
99
+ #else
100
+
101
+ template <const int num_bits, // weight bits
102
+ const int threads, // number of threads in a threadblock
103
+ const int thread_m_blocks, // number of 16x16 blocks in the m
104
+ // dimension (batchsize) of the
105
+ // threadblock
106
+ const int thread_n_blocks, // same for n dimension (output)
107
+ const int thread_k_blocks, // same for k dimension (reduction)
108
+ const int stages, // number of stages for the async global->shared
109
+ // fetch pipeline
110
+ const int group_blocks = -1 // number of consecutive 16x16 blocks
111
+ // with a separate quantization scale
112
+ >
113
+ __global__ void Marlin_24(
114
+ const int4* __restrict__ A, // fp16 input matrix of shape mxk
115
+ const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
116
+ const int4* __restrict__ meta, // 2bit metadata information about 2:4
117
+ // format on B
118
+ int4* __restrict__ C, // fp16 output buffer of shape mxn
119
+ const int4* __restrict__ s, // fp16 quantization scales of shape
120
+ // (k/groupsize)xn
121
+ int prob_m, // batch dimension m
122
+ int prob_n, // output dimension n
123
+ int prob_k, // reduction dimension k
124
+ int* locks // extra global storage for barrier synchronization
125
+ ) {
126
+ // Each threadblock processes one "stripe" of the B matrix with (roughly) the
127
+ // same size, which might involve multiple column "slices" (of width 16 *
128
+ // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
129
+ // example:
130
+ // 0 1 3
131
+ // 0 2 3
132
+ // 1 2 4
133
+ // While this kind of partitioning makes things somewhat more complicated, it
134
+ // ensures good utilization of all SMs for many kinds of shape and GPU
135
+ // configurations, while requiring as few slow global cross-threadblock
136
+ // reductions as possible.
137
+
138
+ // For larger GEMMs we run multiple batchsize 64 versions in parallel for a
139
+ // better partitioning with less reductions
140
+ int parallel = 1;
141
+ if (prob_m > 16 * thread_m_blocks) {
142
+ parallel = prob_m / (16 * thread_m_blocks);
143
+ prob_m = 16 * thread_m_blocks;
144
+ }
145
+
146
+ // number of thread_k_blocks in k-dim
147
+ int k_tiles = prob_k / 32 / thread_k_blocks;
148
+ // number of thread_n_blocks in n-dim
149
+ int n_tiles = prob_n / 16 / thread_n_blocks;
150
+ // iters needed to cover all slices
151
+ int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x);
152
+
153
+ // Ensure that the number of tiles in each stripe is a multiple of the
154
+ // groupsize; this avoids an annoying special case where a stripe starts in
155
+ // the middle of group.
156
+ if (group_blocks != -1)
157
+ iters = (group_blocks / thread_k_blocks) *
158
+ ceildiv(iters, (group_blocks / thread_k_blocks));
159
+
160
+ int slice_row = (iters * blockIdx.x) % k_tiles;
161
+ int slice_col_par = (iters * blockIdx.x) / k_tiles;
162
+ int slice_col = slice_col_par;
163
+ // number of threadblock tiles in the current slice
164
+ int slice_iters;
165
+ // total number of active threadblocks in the current slice
166
+ int slice_count = 0;
167
+ // index of threadblock in current slice; numbered bottom to top
168
+ int slice_idx;
169
+
170
+ // We can easily implement parallel problem execution by just remapping
171
+ // indices and advancing global pointers
172
+ if (slice_col_par >= n_tiles) {
173
+ A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;
174
+ C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
175
+ locks += (slice_col_par / n_tiles) * n_tiles;
176
+ slice_col = slice_col_par % n_tiles;
177
+ }
178
+
179
+ // Compute all information about the current slice which is required for
180
+ // synchronization.
181
+ auto init_slice = [&]() {
182
+ slice_iters =
183
+ iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
184
+ if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
185
+ if (slice_iters == 0) return;
186
+ if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
187
+ slice_count = 1;
188
+ slice_idx = 0;
189
+ int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
190
+ if (col_first <= k_tiles * (slice_col_par + 1)) {
191
+ int col_off = col_first - k_tiles * slice_col_par;
192
+ slice_count = ceildiv(k_tiles - col_off, iters);
193
+ if (col_off > 0) slice_count++;
194
+ int delta_first = iters * blockIdx.x - col_first;
195
+ if (delta_first < 0 || (col_off == 0 && delta_first == 0))
196
+ slice_idx = slice_count - 1;
197
+ else {
198
+ slice_idx = slice_count - 1 - delta_first / iters;
199
+ if (col_off > 0) slice_idx--;
200
+ }
201
+ }
202
+ if (slice_col == n_tiles) {
203
+ A += 16 * thread_m_blocks * prob_k / 8;
204
+ C += 16 * thread_m_blocks * prob_n / 8;
205
+ locks += n_tiles;
206
+ slice_col = 0;
207
+ }
208
+ };
209
+ init_slice();
210
+
211
+ // RLC: 8 is vec_size -> 128-bit instructions, 8 fp16 elements
212
+ int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory
213
+
214
+ // stride of an A matrix tile in shared memory
215
+ constexpr int a_sh_stride = 32 * thread_k_blocks / 8;
216
+ // delta between subsequent A tiles in global memory
217
+ constexpr int a_gl_rd_delta_o = 32 * thread_k_blocks / 8;
218
+ // between subsequent accesses within a tile
219
+ int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o);
220
+ // between shared memory writes
221
+ constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o);
222
+ // between shared memory tile reads //RLC: 2 * #warps k-dim
223
+ constexpr int a_sh_rd_delta_o = 4 * ((threads / 32) / (thread_n_blocks / 4));
224
+ // within a shared memory tile
225
+ constexpr int a_sh_rd_delta_i = a_sh_stride * 16;
226
+ // overall size of a tile
227
+ constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks);
228
+ // number of shared write iterations for a tile
229
+ constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta);
230
+
231
+ constexpr int pack_factor = 32 / num_bits;
232
+
233
+ int b_gl_stride = 16 * prob_n / (pack_factor * 4);
234
+ constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;
235
+ constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2;
236
+ constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;
237
+ int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
238
+ int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads);
239
+ constexpr int b_sh_wr_delta = threads * b_thread_vecs;
240
+ constexpr int b_sh_rd_delta = threads * b_thread_vecs;
241
+ constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
242
+ constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
243
+
244
+ int m_gl_stride = 2 * prob_n / 8; // (16*2*4 / 8) = 16
245
+ constexpr int m_sh_stride =
246
+ (16 * thread_n_blocks) / 4; // #warps n-dim * threads/warp
247
+ int m_gl_rd_delta_o = m_gl_stride * thread_k_blocks;
248
+ int m_gl_rd_delta_i = m_gl_stride * (threads / m_sh_stride);
249
+ constexpr int m_sh_wr_delta = threads / 2;
250
+ constexpr int m_sh_rd_delta = threads / 2;
251
+ constexpr int m_sh_stage = m_sh_stride * thread_k_blocks;
252
+ constexpr int m_sh_iters = ceildiv(m_sh_stage, m_sh_wr_delta);
253
+
254
+ int s_gl_stride = prob_n / 8;
255
+ constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
256
+ constexpr int s_sh_stage = s_sh_stride;
257
+ int s_gl_rd_delta = s_gl_stride;
258
+
259
+ // Global A read index of current thread.
260
+ int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
261
+ (threadIdx.x % a_gl_rd_delta_o);
262
+ a_gl_rd += a_gl_rd_delta_o * slice_row;
263
+ // Shared write index of current thread.
264
+ int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
265
+ (threadIdx.x % a_gl_rd_delta_o);
266
+ // Shared read index.
267
+ int a_sh_rd =
268
+ a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
269
+ a_sh_rd += 4 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
270
+
271
+ int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) +
272
+ (threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
273
+ b_gl_rd += b_sh_stride * slice_col;
274
+ b_gl_rd += b_gl_rd_delta_o * slice_row;
275
+ int b_sh_wr = threadIdx.x * b_thread_vecs;
276
+ int b_sh_rd = threadIdx.x * b_thread_vecs;
277
+
278
+ int m_gl_rd = m_gl_stride * (threadIdx.x / (m_sh_stride)) +
279
+ (threadIdx.x % (m_sh_stride));
280
+ m_gl_rd += (m_sh_stride)*slice_col;
281
+ m_gl_rd += m_gl_rd_delta_o * slice_row;
282
+ int m_sh_wr = threadIdx.x;
283
+ int m_sh_rd = threadIdx.x % 16 + (threadIdx.x / 32) * 16;
284
+
285
+ int s_gl_rd;
286
+ if constexpr (group_blocks == -1) {
287
+ s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
288
+ } else {
289
+ s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
290
+ s_sh_stride * slice_col + threadIdx.x;
291
+ }
292
+
293
+ int s_sh_wr = threadIdx.x;
294
+ int s_sh_rd;
295
+ // We use a different scale layout for grouped and column-wise quantization as
296
+ // we scale a `half2` tile in column-major layout in the former and in
297
+ // row-major in the latter case.
298
+ s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
299
+ (threadIdx.x % 32) / 4; // Note that in the original Marlin kernel
300
+ // this is (threadIdx.x % 32) / 4
301
+
302
+ // Precompute which thread should not read memory in which iterations; this is
303
+ // needed if there are more threads than required for a certain tilesize or
304
+ // when the batchsize is not a multiple of 16.
305
+ bool a_sh_wr_pred[a_sh_wr_iters];
306
+ #pragma unroll
307
+ for (int i = 0; i < a_sh_wr_iters; i++) {
308
+ a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
309
+ }
310
+ bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
311
+
312
+ // To ensure that writing and reading A tiles to/from shared memory, the
313
+ // latter in fragment format, is fully bank conflict free, we need to use a
314
+ // rather fancy XOR-based layout. The key here is that neither reads nor
315
+ // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
316
+ // same shared memory banks. Further, it seems (based on NSight-Compute) that
317
+ // each warp must also write a consecutive memory segment?
318
+ auto transform_a = [&](int i) {
319
+ int row = i / a_gl_rd_delta_o;
320
+ return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
321
+ };
322
+ // Since the computation of this remapping is non-trivial and, due to our main
323
+ // loop unrolls, all shared memory accesses are static, we simply precompute
324
+ // both transformed reads and writes.
325
+ int a_sh_wr_trans[a_sh_wr_iters];
326
+ #pragma unroll
327
+ for (int i = 0; i < a_sh_wr_iters; i++)
328
+ a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
329
+ int a_sh_rd_trans[2][b_sh_wr_iters][thread_m_blocks];
330
+ #pragma unroll
331
+ for (int i = 0; i < b_sh_wr_iters; i++) {
332
+ #pragma unroll
333
+ for (int j = 0; j < thread_m_blocks; j++) {
334
+ a_sh_rd_trans[0][i][j] =
335
+ transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
336
+ a_sh_rd_trans[1][i][j] =
337
+ transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd + 2);
338
+ }
339
+ }
340
+
341
+ // Since B-accesses have non-constant stride they have to be computed at
342
+ // runtime; we break dependencies between subsequent accesses within a tile by
343
+ // maintaining multiple pointers (we have enough registers), a tiny
344
+ // optimization.
345
+ const int4* B_ptr[b_sh_wr_iters];
346
+ #pragma unroll
347
+ for (int i = 0; i < b_sh_wr_iters; i++)
348
+ B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
349
+
350
+ bool m_sh_wr_pred = threadIdx.x < m_sh_wr_delta;
351
+ const int4* meta_ptr[m_sh_iters];
352
+ #pragma unroll
353
+ for (int i = 0; i < m_sh_iters; i++)
354
+ meta_ptr[i] = meta + m_gl_rd_delta_i * i + m_gl_rd;
355
+
356
+ extern __shared__ int4 sh[];
357
+ // Shared memory storage for global fetch pipelines.
358
+ int4* sh_a = sh;
359
+ int4* sh_b = sh_a + (stages * a_sh_stage);
360
+ int4* sh_s = sh_b + (stages * b_sh_stage);
361
+ int4* sh_m = sh_s + (stages * s_sh_stage);
362
+ // Register storage for double buffer of shared memory reads.
363
+ FragA frag_a[2][thread_m_blocks][2];
364
+ I4 frag_b_quant[2][b_thread_vecs];
365
+ FragM frag_m[2][2];
366
+ FragC frag_c[thread_m_blocks][4][2];
367
+ FragS frag_s[2][4];
368
+
369
+ // Zero accumulators.
370
+ auto zero_accums = [&]() {
371
+ #pragma unroll
372
+ for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
373
+ reinterpret_cast<float*>(frag_c)[i] = 0;
374
+ };
375
+
376
+ // Asynchronously fetch the next A, B and s tile from global to the next
377
+ // shared memory pipeline location.
378
+ auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
379
+ if (pred) {
380
+ int4* sh_a_stage = sh_a + a_sh_stage * pipe;
381
+ #pragma unroll
382
+ for (int i = 0; i < a_sh_wr_iters; i++) {
383
+ cp_async4_pred(
384
+ &sh_a_stage[a_sh_wr_trans[i]],
385
+ &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
386
+ a_sh_wr_pred[i]);
387
+ }
388
+ int4* sh_b_stage = sh_b + b_sh_stage * pipe;
389
+ #pragma unroll
390
+ for (int i = 0; i < b_sh_wr_iters; i++) {
391
+ #pragma unroll
392
+ for (int j = 0; j < b_thread_vecs; j++) {
393
+ cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);
394
+ }
395
+ B_ptr[i] += b_gl_rd_delta_o;
396
+ }
397
+ int4* sh_meta_stage = sh_m + m_sh_stage * pipe;
398
+ #pragma unroll
399
+ for (int i = 0; i < m_sh_iters; i++) {
400
+ if (m_sh_wr_pred)
401
+ cp_async4(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr], meta_ptr[i]);
402
+ meta_ptr[i] += m_gl_rd_delta_o;
403
+ }
404
+ // Only fetch scales if this tile starts a new group
405
+ if constexpr (group_blocks != -1) {
406
+ // This assumes group_blocks >= thread_k_blocks
407
+ // and would need to be modified to support smaller groups.
408
+ static_assert(group_blocks >= thread_k_blocks);
409
+ if (pipe % (group_blocks / thread_k_blocks) == 0) {
410
+ int4* sh_s_stage = sh_s + s_sh_stage * pipe;
411
+ if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
412
+ s_gl_rd += s_gl_rd_delta;
413
+ }
414
+ }
415
+ }
416
+ // Insert a fence even when we are winding down the pipeline to ensure that
417
+ // waiting is also correct at this point.
418
+ cp_async_fence();
419
+ };
420
+
421
+ // Wait until the next thread tile has been loaded to shared memory.
422
+ auto wait_for_stage = [&]() {
423
+ // We only have `stages - 2` active fetches since we are double buffering
424
+ // and can only issue the next fetch when it is guaranteed that the previous
425
+ // shared memory load is fully complete (as it may otherwise be
426
+ // overwritten).
427
+ cp_async_wait<stages - 2>();
428
+ __syncthreads();
429
+ };
430
+
431
+ // Load the next sub-tile from the current location in the shared memory pipe
432
+ // into the current register buffer.
433
+ auto fetch_to_registers = [&](int k, int pipe) {
434
+ // It may seem inefficient that we reload the groups for every sub-tile;
435
+ // however, this does not seem to be a significant bottleneck, while some
436
+ // theoretically better attempts have led to bad instruction ordering by
437
+ // the compiler and correspondingly a noticeable drop in performance.
438
+ if constexpr (group_blocks != -1) {
439
+ // This assumes group_blocks >= thread_k_blocks
440
+ // and would need to be modified to support smaller groups.
441
+ static_assert(group_blocks >= thread_k_blocks);
442
+ int4* sh_s_stage =
443
+ sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
444
+ (pipe / (group_blocks / thread_k_blocks)));
445
+ reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
446
+ }
447
+ int4* sh_a_stage = sh_a + a_sh_stage * pipe;
448
+ #pragma unroll
449
+ for (int i = 0; i < thread_m_blocks; i++) {
450
+ ldsm4(frag_a[k % 2][i][0],
451
+ &sh_a_stage[a_sh_rd_trans[0][k % b_sh_wr_iters][i]]);
452
+ ldsm4(frag_a[k % 2][i][1],
453
+ &sh_a_stage[a_sh_rd_trans[1][k % b_sh_wr_iters][i]]);
454
+ }
455
+
456
+ int4* sh_b_stage = sh_b + b_sh_stage * pipe;
457
+ #pragma unroll
458
+ for (int i = 0; i < b_thread_vecs; i++) {
459
+ frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>(
460
+ &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);
461
+ }
462
+
463
+ // Load meta with ldsm4
464
+ int4* sh_m_stage = sh_m + m_sh_stage * pipe;
465
+ ldsm4_m(frag_m[k % 2][0],
466
+ &sh_m_stage[m_sh_rd_delta * (k % m_sh_iters) + m_sh_rd]);
467
+ };
468
+
469
+ // Execute the actual tensor core matmul of a sub-tile.
470
+ auto matmul = [&](int k) {
471
+ // We have the m dimension as the inner loop in order to encourage overlapping
472
+ // dequantization and matmul operations.
473
+ #pragma unroll
474
+ for (int j = 0; j < 4; j++) {
475
+ FragB frag_b0;
476
+ FragB frag_b1;
477
+
478
+ if constexpr (num_bits == 4) {
479
+ int b_quant = frag_b_quant[k % 2][0][j];
480
+ int b_quant_shift = b_quant >> 8;
481
+
482
+ frag_b0 = dequant_4bit(b_quant);
483
+ frag_b1 = dequant_4bit(b_quant_shift);
484
+
485
+ } else {
486
+ int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);
487
+ int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
488
+ int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
489
+
490
+ frag_b0 = dequant_8bit(b_quant_0);
491
+ frag_b1 = dequant_8bit(b_quant_1);
492
+ }
493
+
494
+ // If there are no groups, we can just scale the final output once and can
495
+ // avoid doing so for each weight.
496
+ if constexpr (group_blocks != -1) {
497
+ scale(frag_b0, frag_s[k % 2][j], 0);
498
+ }
499
+ if constexpr (group_blocks != -1) {
500
+ scale(frag_b1, frag_s[k % 2][j], 1);
501
+ }
502
+
503
+ #pragma unroll
504
+ for (int i = 0; i < thread_m_blocks; i++) {
505
+ mma_sp(frag_b0, frag_b1, frag_a[k % 2][i][0], frag_c[i][j][0],
506
+ frag_m[k % 2][j / 2], j % 2);
507
+ }
508
+ }
509
+ };
510
+
511
+ // Since we slice across the k dimension of a tile in order to increase the
512
+ // number of warps while keeping the n dimension of a tile reasonable, we have
513
+ // multiple warps that accumulate their partial sums of the same output
514
+ // location, which we have to reduce over in the end. We do this in shared memory.
515
+ auto thread_block_reduce = [&]() {
516
+ constexpr int red_off = threads / b_sh_stride_threads / 2;
517
+ if (red_off >= 1) {
518
+ int red_idx = threadIdx.x / b_sh_stride_threads;
519
+ constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
520
+ constexpr int red_sh_delta = b_sh_stride_threads;
521
+ int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
522
+ (threadIdx.x % b_sh_stride_threads);
523
+
524
+ // Parallel logarithmic shared memory reduction. We make sure to avoid any
525
+ // unnecessary read or write iterations, e.g., for two warps we write only
526
+ // once by warp 1 and read only once by warp 0.
527
+ #pragma unroll
528
+ for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
529
+ #pragma unroll
530
+ for (int i = red_off; i > 0; i /= 2) {
531
+ if (i <= red_idx && red_idx < 2 * i) {
532
+ #pragma unroll
533
+ for (int j = 0; j < 4 * 2; j++) {
534
+ int red_sh_wr =
535
+ red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
536
+ if (i < red_off) {
537
+ float* c_rd =
538
+ reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
539
+ float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
540
+ #pragma unroll
541
+ for (int k = 0; k < 4; k++)
542
+ reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
543
+ c_rd[k] + c_wr[k];
544
+ }
545
+ sh[red_sh_wr] =
546
+ reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
547
+ }
548
+ }
549
+ __syncthreads();
550
+ }
551
+ if (red_idx == 0) {
552
+ #pragma unroll
553
+ for (int i = 0; i < 4 * 2; i++) {
554
+ float* c_rd =
555
+ reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
556
+ #pragma unroll
557
+ for (int j = 0; j < 4; j++)
558
+ reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
559
+ c_rd[j];
560
+ }
561
+ }
562
+ __syncthreads();
563
+ }
564
+ }
565
+ };
566
+
567
+ // Since multiple threadblocks may process parts of the same column slice, we
568
+ // finally have to globally reduce over the results. As the striped
569
+ // partitioning minimizes the number of such reductions and our outputs are
570
+ // usually rather small, we perform this reduction serially in L2 cache.
571
+ auto global_reduce = [&](bool first = false, bool last = false) {
572
+ // We are very careful here to reduce directly in the output buffer to
573
+ // maximize L2 cache utilization in this step. To do this, we write out
574
+ // results in FP16 (but still reduce with FP32 compute).
575
+ constexpr int active_threads = 32 * thread_n_blocks / 4;
576
+ if (threadIdx.x < active_threads) {
577
+ int c_gl_stride = prob_n / 8;
578
+ int c_gl_wr_delta_o = 2 * 4 * c_gl_stride;
579
+ int c_gl_wr_delta_i =
580
+ c_gl_stride; // 8 threads (e.g., 0,4,8,12,16,20,24,28)
581
+ int c_gl_wr = 2 * c_gl_stride * (threadIdx.x % 4) +
582
+ 8 * (threadIdx.x / 32) + (threadIdx.x % 32) / 4;
583
+ c_gl_wr += (2 * thread_n_blocks) * slice_col;
584
+ constexpr int c_sh_wr_delta = active_threads;
585
+ int c_sh_wr = threadIdx.x;
586
+
587
+ int col = 2 * ((threadIdx.x % 32) % 4);
588
+
589
+ if (!first) {
590
+ // Interestingly, doing direct global accesses here really seems to mess up
591
+ // the compiler and lead to slowdowns, hence we also use async-copies even
592
+ // though these fetches are not actually asynchronous.
593
+ #pragma unroll
594
+ for (int i = 0; i < thread_m_blocks * 4; i++) {
595
+ cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i],
596
+ &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
597
+ c_gl_wr_delta_i * (i % 2)],
598
+ i < (thread_m_blocks - 1) * 4 ||
599
+ 8 * (i / 2) + col + (i % 2) < prob_m);
600
+ }
601
+ cp_async_fence();
602
+ cp_async_wait<0>();
603
+ }
604
+
605
+ #pragma unroll
606
+ for (int i = 0; i < thread_m_blocks * 4; i++) {
607
+ if (i < (thread_m_blocks - 1) * 4 ||
608
+ 8 * (i / 2) + col + (i % 2) < prob_m) {
609
+ if (!first) {
610
+ int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
611
+ #pragma unroll
612
+ for (int j2 = 0; j2 < 2; j2++) {
613
+ #pragma unroll
614
+ for (int j1 = 0; j1 < 4; j1++) {
615
+ reinterpret_cast<float*>(
616
+ &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 +
617
+ 4 * ((i % 4) / 2) + i % 2] +=
618
+ __half2float(
619
+ reinterpret_cast<__half*>(&c_red)[(j2 * 4 + j1)]);
620
+ }
621
+ }
622
+ }
623
+ if (!last) {
624
+ int4 c;
625
+ #pragma unroll
626
+ for (int j2 = 0; j2 < 2; j2++) {
627
+ #pragma unroll
628
+ for (int j1 = 0; j1 < 4; j1++) {
629
+ reinterpret_cast<__half*>(&c)[(j2 * 4 + j1)] =
630
+ __float2half(reinterpret_cast<float*>(
631
+ &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 +
632
+ 4 * ((i % 4) / 2) + i % 2]);
633
+ }
634
+ }
635
+ C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =
636
+ c;
637
+ }
638
+ }
639
+ }
640
+ }
641
+ };
642
+
643
+ // Write out the reduce final result in the correct layout. We only actually
644
+ // reshuffle matrix fragments in this step, the reduction above is performed
645
+ // in fragment layout.
646
+ auto write_result = [&]() {
647
+ int c_gl_stride = prob_n / 8;
648
+
649
+ constexpr int c_sh_stride = 2 * thread_n_blocks; // RLC:
650
+ constexpr int c_sh_stride_2 = 2 * c_sh_stride + 2; // RLC:
651
+ constexpr int c_sh_stride_3 = 2 * (2 * thread_n_blocks) + 2; // RLC:
652
+
653
+ int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
654
+
655
+ int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +
656
+ (threadIdx.x % (2 * thread_n_blocks));
657
+ c_gl_wr += (2 * thread_n_blocks) * slice_col;
658
+
659
+ int c_sh_wr = c_sh_stride_2 * ((threadIdx.x % 32) % 4) +
660
+ ((threadIdx.x % 32) / 4); // RLC:
661
+ c_sh_wr += 8 * (threadIdx.x / 32); // 128/4(half4)
662
+
663
+ constexpr int c_sh_rd_delta =
664
+ c_sh_stride_3 * (threads / (2 * 2 * thread_n_blocks)); // RLC:
665
+ int c_sh_rd = c_sh_stride_3 * (threadIdx.x / (2 * 2 * thread_n_blocks)) +
666
+ (threadIdx.x % (2 * 2 * thread_n_blocks));
667
+
668
+ int c_gl_wr_end = c_gl_stride * prob_m;
669
+
670
+ auto write = [&](int idx, float c0, float c1, float c2, float c3, FragS& s0,
671
+ float c4, float c5, float c6, float c7, FragS& s1) {
672
+ uint2 res[2];
673
+ res[0] = to_half4(c0, c1, c2, c3);
674
+ res[1] = to_half4(c4, c5, c6, c7);
675
+ half2* tmp = (half2*)&res;
676
+ // for per-column quantization we finally apply the scale here
677
+ if constexpr (group_blocks == -1 && num_bits == 4) {
678
+ tmp[0] = __hmul2(tmp[0], s0[0]);
679
+ tmp[1] = __hmul2(tmp[1], s0[1]);
680
+ tmp[2] = __hmul2(tmp[2], s1[0]);
681
+ tmp[3] = __hmul2(tmp[3], s1[1]);
682
+ }
683
+ ((int4*)sh)[idx] = *((int4*)&res[0]);
684
+ };
685
+
686
+ // RLC: only warp 0 and 1 baseline example
687
+ if (threadIdx.x / 32 < thread_n_blocks / 4) {
688
+ #pragma unroll
689
+ for (int i = 0; i < thread_m_blocks; i++) {
690
+ int wr = c_sh_wr;
691
+ write(wr, frag_c[i][0][0][0], frag_c[i][1][0][0], frag_c[i][2][0][0],
692
+ frag_c[i][3][0][0], frag_s[0][0], frag_c[i][0][0][2],
693
+ frag_c[i][1][0][2], frag_c[i][2][0][2], frag_c[i][3][0][2],
694
+ frag_s[0][2]);
695
+ write(wr + c_sh_stride, frag_c[i][0][0][1], frag_c[i][1][0][1],
696
+ frag_c[i][2][0][1], frag_c[i][3][0][1], frag_s[0][0],
697
+ frag_c[i][0][0][3], frag_c[i][1][0][3], frag_c[i][2][0][3],
698
+ frag_c[i][3][0][3], frag_s[0][2]);
699
+ write(wr + 4 * c_sh_stride_2, frag_c[i][0][1][0], frag_c[i][1][1][0],
700
+ frag_c[i][2][1][0], frag_c[i][3][1][0], frag_s[0][0],
701
+ frag_c[i][0][1][2], frag_c[i][1][1][2], frag_c[i][2][1][2],
702
+ frag_c[i][3][1][2], frag_s[0][2]);
703
+ write(wr + 4 * c_sh_stride_2 + c_sh_stride, frag_c[i][0][1][1],
704
+ frag_c[i][1][1][1], frag_c[i][2][1][1], frag_c[i][3][1][1],
705
+ frag_s[0][0], frag_c[i][0][1][3], frag_c[i][1][1][3],
706
+ frag_c[i][2][1][3], frag_c[i][3][1][3], frag_s[0][2]);
707
+
708
+ c_sh_wr += 8 * c_sh_stride_2;
709
+ }
710
+ }
711
+ __syncthreads();
712
+
713
+ #pragma unroll
714
+ for (int i = 0;
715
+ i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
716
+ i++) {
717
+ if (c_gl_wr < c_gl_wr_end) {
718
+ C[c_gl_wr] = sh[c_sh_rd];
719
+ c_gl_wr += c_gl_wr_delta;
720
+ c_sh_rd += c_sh_rd_delta;
721
+ }
722
+ }
723
+ };
724
+
725
+ // Start global fetch and register load pipelines.
726
+ auto start_pipes = [&]() {
727
+ #pragma unroll
728
+ for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters);
729
+ zero_accums();
730
+ wait_for_stage();
731
+ fetch_to_registers(0, 0);
732
+ a_gl_rd += a_gl_rd_delta_o * (stages - 1);
733
+ };
734
+ start_pipes();
735
+
736
+ // Main loop.
737
+ while (slice_iters) {
738
+ // We unroll over both the global fetch and the register load pipeline to
739
+ // ensure all shared memory accesses are static. Note that both pipelines have
740
+ // even length meaning that the next iteration will always start at index 0.
741
+ #pragma unroll
742
+ for (int pipe = 0; pipe < stages;) {
743
+ fetch_to_shared((pipe + stages - 1) % stages, pipe,
744
+ slice_iters >= stages);
745
+ matmul(pipe);
746
+ wait_for_stage();
747
+
748
+ fetch_to_registers(pipe + 1, (pipe + 1) % stages);
749
+
750
+ pipe++;
751
+ slice_iters--;
752
+ if (slice_iters == 0) break;
753
+ }
754
+ a_gl_rd += a_gl_rd_delta_o * stages;
755
+
756
+ // Process results and, if necessary, proceed to the next column slice.
757
+ // While this pattern may not be the most readable, other ways of writing
758
+ // the loop seemed to result in noticeably worse performance after compilation.
759
+ if (slice_iters == 0) {
760
+ cp_async_wait<0>();
761
+ bool last = slice_idx == slice_count - 1;
762
+ // For per-column scales, we only fetch them here in the final step before
763
+ // write-out
764
+ if constexpr (group_blocks == -1) {
765
+ if constexpr (num_bits == 8) {
766
+ if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
767
+ cp_async_fence();
768
+ } else {
769
+ if (last) {
770
+ if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
771
+ cp_async_fence();
772
+ }
773
+ }
774
+ }
775
+ thread_block_reduce();
776
+
777
+ if constexpr (group_blocks == -1) {
778
+ if constexpr (num_bits == 8) {
779
+ cp_async_wait<0>();
780
+ __syncthreads();
781
+ if (threadIdx.x / 32 < thread_n_blocks / 4) {
782
+ *(float4*)(frag_s) = *(float4*)(&sh_s[s_sh_rd]);
783
+ }
784
+ } else {
785
+ if (last) {
786
+ cp_async_wait<0>();
787
+ __syncthreads();
788
+ if (threadIdx.x / 32 < thread_n_blocks / 4) {
789
+ *(float4*)(frag_s) = *(float4*)(&sh_s[s_sh_rd]);
790
+ }
791
+ }
792
+ }
793
+ }
794
+
795
+ // For 8-bit channelwise, we apply the scale before the global reduction
796
+ // that converts the fp32 results to fp16 (so that we avoid possible
797
+ // overflow in fp16)
798
+ if constexpr (group_blocks == -1 && num_bits == 8) {
799
+ if (threadIdx.x / 32 < thread_n_blocks / 4) {
800
+ #pragma unroll
801
+ for (int i = 0; i < thread_m_blocks; i++) {
802
+ scale_floats(&frag_c[i][0][0][0], &frag_c[i][1][0][0],
803
+ &frag_c[i][2][0][0], &frag_c[i][3][0][0], frag_s[0][0],
804
+ &frag_c[i][0][0][2], &frag_c[i][1][0][2],
805
+ &frag_c[i][2][0][2], &frag_c[i][3][0][2],
806
+ frag_s[0][2]);
807
+
808
+ scale_floats(&frag_c[i][0][0][1], &frag_c[i][1][0][1],
809
+ &frag_c[i][2][0][1], &frag_c[i][3][0][1], frag_s[0][0],
810
+ &frag_c[i][0][0][3], &frag_c[i][1][0][3],
811
+ &frag_c[i][2][0][3], &frag_c[i][3][0][3],
812
+ frag_s[0][2]);
813
+
814
+ scale_floats(&frag_c[i][0][1][0], &frag_c[i][1][1][0],
815
+ &frag_c[i][2][1][0], &frag_c[i][3][1][0], frag_s[0][0],
816
+ &frag_c[i][0][1][2], &frag_c[i][1][1][2],
817
+ &frag_c[i][2][1][2], &frag_c[i][3][1][2],
818
+ frag_s[0][2]);
819
+
820
+ scale_floats(&frag_c[i][0][1][1], &frag_c[i][1][1][1],
821
+ &frag_c[i][2][1][1], &frag_c[i][3][1][1], frag_s[0][0],
822
+ &frag_c[i][0][1][3], &frag_c[i][1][1][3],
823
+ &frag_c[i][2][1][3], &frag_c[i][3][1][3],
824
+ frag_s[0][2]);
825
+ }
826
+ }
827
+ }
828
+
829
+ if (slice_count > 1) { // only globally reduce if there is more than one
830
+ // block in a slice
831
+ barrier_acquire(&locks[slice_col], slice_idx);
832
+ global_reduce(slice_idx == 0, last);
833
+ barrier_release(&locks[slice_col], last);
834
+ }
835
+ if (last) // only the last block in a slice actually writes the result
836
+ write_result();
837
+
838
+ slice_row = 0;
839
+ slice_col_par++;
840
+ slice_col++;
841
+ init_slice();
842
+ if (slice_iters) {
843
+ a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
844
+ (threadIdx.x % a_gl_rd_delta_o);
845
+ #pragma unroll
846
+ for (int i = 0; i < b_sh_wr_iters; i++)
847
+ B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
848
+ #pragma unroll
849
+ for (int i = 0; i < m_sh_iters; i++)
850
+ meta_ptr[i] += (m_sh_stride)-m_gl_rd_delta_o * k_tiles;
851
+ if (slice_col == 0) {
852
+ #pragma unroll
853
+ for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
854
+ #pragma unroll
855
+ for (int i = 0; i < m_sh_iters; i++) meta_ptr[i] -= m_gl_stride;
856
+ }
857
+ s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
858
+ start_pipes();
859
+ }
860
+ }
861
+ }
862
+ }
863
+
864
+ #endif
865
+
866
+ #define CALL_IF_2_4(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
867
+ THREAD_K_BLOCKS, GROUP_BLOCKS) \
868
+ else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
869
+ thread_n_blocks == THREAD_N_BLOCKS && \
870
+ thread_k_blocks == THREAD_K_BLOCKS && \
871
+ group_blocks == GROUP_BLOCKS) { \
872
+ cudaFuncSetAttribute( \
873
+ Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS, \
874
+ THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \
875
+ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
876
+ Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS, \
877
+ THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS> \
878
+ <<<blocks, THREADS, max_shared_mem, stream>>>(A_ptr, B_ptr, meta_ptr, \
879
+ C_ptr, s_ptr, prob_n, \
880
+ prob_m, prob_k, locks); \
881
+ }
882
+
883
+ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
884
+ void* s, int prob_m, int prob_n, int prob_k,
885
+ void* workspace, int num_bits, int groupsize = -1,
886
+ int dev = 0, cudaStream_t stream = 0, int thread_k = -1,
887
+ int thread_m = -1, int sms = -1, int max_par = 16) {
888
+ int tot_n = prob_n;
889
+ int tot_n_blocks = ceildiv(tot_n, 16);
890
+ int pad = 16 * tot_n_blocks - tot_n;
891
+
892
+ if (sms == -1) {
893
+ cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
894
+ }
895
+ TORCH_CHECK(sms > 0);
896
+
897
+ int max_shared_mem = 0;
898
+ cudaDeviceGetAttribute(&max_shared_mem,
899
+ cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
900
+ TORCH_CHECK(max_shared_mem > 0);
901
+
902
+ if (thread_k == -1 || thread_m == -1) {
903
+ if (prob_n <= 16) {
904
+ // For small batch sizes, better partitioning is slightly more important
905
+ // than better compute utilization
906
+ thread_k = 128;
907
+ thread_m = 128;
908
+ } else {
909
+ thread_k = 64;
910
+ thread_m = 256;
911
+ }
912
+ // Also had
913
+ // if prob_n > 256
914
+ // thread_k = 32;
915
+ // thread_m = 512;
916
+ // but this is broken,
917
+ // TODO(Lucas, Alex M): figure out why
918
+ }
919
+
920
+ int thread_k_blocks = thread_k / 32; // 2:4 version with m16n8k32 instruction
921
+ int thread_m_blocks = thread_m / 16;
922
+ int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
923
+ int blocks = sms;
924
+
925
+ TORCH_CHECK(prob_m % thread_m == 0, "prob_m = ", prob_m,
926
+ " is not divisible by thread_m = ", thread_m);
927
+ TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
928
+ " is not divisible by thread_k = ", thread_k);
929
+ if (group_blocks != -1) {
930
+ TORCH_CHECK((prob_k / 2) % group_blocks == 0, "prob_k/2 = ", prob_k / 2,
931
+ " is not divisible by group_blocks = ", group_blocks);
932
+ }
933
+
934
+ TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
935
+ ", ", prob_n, ", ", prob_k, "]");
936
+
937
+ const int4* A_ptr = (const int4*)A;
938
+ const int4* B_ptr = (const int4*)B;
939
+ const int4* meta_ptr = (const int4*)meta;
940
+ int4* C_ptr = (int4*)C;
941
+ const int4* s_ptr = (const int4*)s;
942
+
943
+ constexpr int max_m_blocks = 4;
944
+
945
+ int* locks = (int*)workspace;
946
+ for (int i = 0; i < tot_n_blocks; i += max_m_blocks) {
947
+ int thread_n_blocks = tot_n_blocks - i;
948
+ prob_n = tot_n - 16 * i;
949
+ int par = 1;
950
+ if (thread_n_blocks > max_m_blocks) {
951
+ // Note that parallel > 1 currently only works for inputs without any
952
+ // padding
953
+ par = (16 * thread_n_blocks - pad) / (max_m_blocks * 16);
954
+ if (par > max_par) par = max_par;
955
+ prob_n = (max_m_blocks * 16) * par;
956
+ i += max_m_blocks * (par - 1);
957
+ thread_n_blocks = max_m_blocks;
958
+ }
959
+
960
+ // For compilation speed, we only define the kernel configurations that have
961
+ // seemed useful (in terms of performance) in our testing; however, many more
962
+ // are, in principle, possible.
963
+
964
+ // The dummy "if (false)" below starts the CALL_IF_2_4 else-if chain
965
+ if (false) {
966
+ } // BMxBNxBK, group
967
+ // 4-bit
968
+ CALL_IF_2_4(4, 8, 1, 4, -1) // e.g., 16x128x128
969
+ CALL_IF_2_4(4, 8, 1, 4, 4) // e.g., 16x128x128, 64
970
+
971
+ CALL_IF_2_4(4, 16, 1, 2, -1) // e.g., 16x256x64
972
+ CALL_IF_2_4(4, 16, 1, 2, 4) // e.g., 16x256x64, 64
973
+ CALL_IF_2_4(4, 16, 2, 2, -1) // e.g., 32x256x64
974
+ CALL_IF_2_4(4, 16, 2, 2, 4)
975
+ CALL_IF_2_4(4, 16, 3, 2, -1)
976
+ CALL_IF_2_4(4, 16, 3, 2, 4)
977
+ CALL_IF_2_4(4, 16, 4, 2, -1)
978
+ CALL_IF_2_4(4, 16, 4, 2, 4)
979
+
980
+ CALL_IF_2_4(4, 32, 1, 1, -1) // e.g., 16x256x64
981
+ CALL_IF_2_4(4, 32, 1, 1, 4) // e.g., 16x256x64, 64
982
+ CALL_IF_2_4(4, 32, 2, 1, -1) // e.g., 32x256x64
983
+ CALL_IF_2_4(4, 32, 2, 1, 4)
984
+ CALL_IF_2_4(4, 32, 3, 1, -1)
985
+ CALL_IF_2_4(4, 32, 3, 1, 4)
986
+ CALL_IF_2_4(4, 32, 4, 1, -1)
987
+ CALL_IF_2_4(4, 32, 4, 1, 4)
988
+
989
+ // 8-bit
990
+ CALL_IF_2_4(8, 8, 1, 4, -1) // e.g., 16x128x128
991
+ CALL_IF_2_4(8, 8, 1, 4, 4) // e.g., 16x128x128, 64
992
+
993
+ CALL_IF_2_4(8, 16, 1, 2, -1) // e.g., 16x256x64
994
+ CALL_IF_2_4(8, 16, 1, 2, 4) // e.g., 16x256x64, 64
995
+ CALL_IF_2_4(8, 16, 2, 2, -1) // e.g., 32x256x64
996
+ CALL_IF_2_4(8, 16, 2, 2, 4)
997
+ CALL_IF_2_4(8, 16, 3, 2, -1)
998
+ CALL_IF_2_4(8, 16, 3, 2, 4)
999
+ CALL_IF_2_4(8, 16, 4, 2, -1)
1000
+ CALL_IF_2_4(8, 16, 4, 2, 4)
1001
+
1002
+ CALL_IF_2_4(8, 32, 1, 1, -1) // e.g., 16x256x64
1003
+ CALL_IF_2_4(8, 32, 1, 1, 4) // e.g., 16x256x64, 64
1004
+ CALL_IF_2_4(8, 32, 2, 1, -1) // e.g., 32x256x64
1005
+ CALL_IF_2_4(8, 32, 2, 1, 4)
1006
+ CALL_IF_2_4(8, 32, 3, 1, -1)
1007
+ CALL_IF_2_4(8, 32, 3, 1, 4)
1008
+ CALL_IF_2_4(8, 32, 4, 1, -1)
1009
+ CALL_IF_2_4(8, 32, 4, 1, 4)
1010
+ else {
1011
+ throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) +
1012
+ ", " + str(prob_k) + ", " + str(prob_n) + "]" +
1013
+ ", groupsize = " + str(groupsize) +
1014
+ ", thread_m_blocks = " + str(thread_m_blocks) +
1015
+ ", thread_n_blocks = " + str(thread_n_blocks) +
1016
+ ", thread_k_blocks = " + str(thread_k_blocks));
1017
+ }
1018
+
1019
+ A_ptr += 16 * thread_n_blocks * (prob_k / 8) * par;
1020
+ C_ptr += 16 * thread_n_blocks * (prob_m / 8) * par;
1021
+ }
1022
+ }
1023
+
1024
+ } // namespace marlin_24
1025
+
1026
+ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
1027
+ torch::Tensor& b_meta,
1028
+ torch::Tensor& b_scales,
1029
+ torch::Tensor& workspace,
1030
+ vllm::ScalarTypeId const b_q_type_id,
1031
+ int64_t size_m, int64_t size_n,
1032
+ int64_t size_k) {
1033
+ vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
1034
+ // Verify num_bits
1035
+ TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
1036
+ "num_bits must be uint4b8 or uint8b128. Got = ", b_q_type.str());
1037
+ int pack_factor = 32 / b_q_type.size_bits();
1038
+
1039
+ // Verify M
1040
+ TORCH_CHECK(size_m == a.size(0),
1041
+ "Shape mismatch: a.size(0) = " + str(a.size(0)) +
1042
+ ", size_m = " + str(size_m));
1043
+
1044
+ // Verify K
1045
+ TORCH_CHECK(size_k == a.size(1),
1046
+ "Shape mismatch: a.size(1) = " + str(a.size(1)) +
1047
+ ", size_k = " + str(size_k));
1048
+ TORCH_CHECK(size_k % marlin_24::tile_size == 0,
1049
+ "size_k = " + str(size_k) + " is not divisible by tile_size = " +
1050
+ str(marlin_24::tile_size));
1051
+ TORCH_CHECK((size_k / marlin_24::tile_size / 2) == b_q_weight.size(0),
1052
+ "Shape mismatch: b_q_weight.size(0) = " +
1053
+ str(b_q_weight.size(0)) + ", size_k = " + str(size_k) +
1054
+ ", tile_size = " + str(marlin_24::tile_size));
1055
+
1056
+ // Verify N
1057
+ TORCH_CHECK(b_scales.size(1) == size_n,
1058
+ "b_scales.size(1) = " + str(b_scales.size(1)) +
1059
+ ", size_n = " + str(size_n));
1060
+ TORCH_CHECK(
1061
+ b_q_weight.size(1) % marlin_24::tile_size == 0,
1062
+ "b_q_weight.size(1) = " + str(b_q_weight.size(1)) +
1063
+ " is not divisible by tile_size = " + str(marlin_24::tile_size));
1064
+
1065
+ int actual_size_n = (b_q_weight.size(1) / marlin_24::tile_size) * pack_factor;
1066
+ TORCH_CHECK(
1067
+ size_n == actual_size_n,
1068
+ "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n));
1069
+
1070
+ // Verify meta
1071
+ TORCH_CHECK(b_meta.size(0) == size_k / 8 / 2 / 2,
1072
+ "b_meta.size(0) = ", b_meta.size(0),
1073
+ " is not size_k / 8 / 2 / 2 = ", size_k / 8 / 2 / 2);
1074
+ TORCH_CHECK(b_meta.size(1) == size_n * 2, "b_meta.size(1) = ", b_meta.size(1),
1075
+ " is not size_n * 2 = ", size_n * 2);
1076
+
1077
+ // Verify A device and strides
1078
+ TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
1079
+ TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
1080
+ TORCH_CHECK(a.dtype() == torch::kFloat16,
1081
+ "A is not float16, currently only float16 is supported");
1082
+
1083
+ // Verify B device and strides
1084
+ TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
1085
+ TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
1086
+
1087
+ // Verify b_meta device and strides
1088
+ TORCH_CHECK(b_meta.device().is_cuda(), "b_meta is not on GPU");
1089
+ TORCH_CHECK(b_meta.is_contiguous(), "b_meta is not contiguous");
1090
+
1091
+ // Verify scales device and strides
1092
+ TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
1093
+ TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
1094
+ TORCH_CHECK(b_scales.dtype() == torch::kFloat16,
1095
+ "A is not float16, currently only float16 is supported");
1096
+
1097
+ // Alloc C matrix
1098
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
1099
+ auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
1100
+ torch::Tensor c = torch::empty({size_m, size_n}, options);
1101
+
1102
+ int thread_k = -1;
1103
+ int thread_m = -1;
1104
+ int sms = -1;
1105
+ int max_par = marlin_24::max_par;
1106
+
1107
+ int groupsize = -1;
1108
+ if (b_scales.size(0) > 1) {
1109
+ TORCH_CHECK(size_k % b_scales.size(0) == 0,
1110
+ "size_k = " + str(size_k) +
1111
+ ", is not divisible by b_scales.size(0) = " +
1112
+ str(b_scales.size(0)));
1113
+ groupsize = size_k / b_scales.size(0);
1114
+ groupsize /= 2; // Because of 24
1115
+ }
1116
+
1117
+ // Verify groupsize
1118
+ TORCH_CHECK(groupsize == -1 || groupsize == 64,
1119
+ "Unexpected groupsize = " + str(groupsize));
1120
+
1121
+ // Verify workspace size
1122
+ TORCH_CHECK(size_n % marlin_24::min_thread_n == 0,
1123
+ "size_n = " + str(size_n) +
1124
+ ", is not divisible by min_thread_n = " +
1125
+ str(marlin_24::min_thread_n));
1126
+ int min_workspace_size =
1127
+ (size_n / marlin_24::min_thread_n) * marlin_24::max_par;
1128
+ TORCH_CHECK(workspace.numel() >= min_workspace_size,
1129
+ "workspace.numel = " + str(workspace.numel()) +
1130
+ " is below min_workspace_size = " + str(min_workspace_size));
1131
+
1132
+ int dev = a.get_device();
1133
+ marlin_24::marlin_cuda_2_4(
1134
+ a.data_ptr(), b_q_weight.data_ptr(), b_meta.data_ptr(), c.data_ptr(),
1135
+ b_scales.data_ptr(), size_n, size_m, size_k, workspace.data_ptr(),
1136
+ b_q_type.size_bits(), groupsize, dev, at::cuda::getCurrentCUDAStream(dev),
1137
+ thread_k, thread_m, sms, max_par);
1138
+
1139
+ return c;
1140
+ }
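
Illustrative note (not part of this commit): the workspace checks in gptq_marlin_24_gemm above reduce to a simple size requirement. A minimal Python sketch, assuming the GPTQ_MARLIN_24_MIN_THREAD_N and GPTQ_MARLIN_24_MAX_PARALLEL constants used by the tests below mirror marlin_24::min_thread_n and marlin_24::max_par on the C++ side:

def marlin_24_min_workspace_numel(size_n: int, min_thread_n: int, max_par: int) -> int:
    # Mirrors the TORCH_CHECKs in gptq_marlin_24_gemm: size_n must be divisible
    # by the minimum thread_n tile, and the lock array needs one slot per
    # column block times the maximum parallelism.
    assert size_n % min_thread_n == 0
    return (size_n // min_thread_n) * max_par

The MarlinWorkspace helper used by the tests is expected to allocate at least this many elements for the locks scratch buffer.
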
tests/kernels/test_marlin_gemm.py ADDED
@@ -0,0 +1,733 @@
1
+ """Tests for the marlin kernel.
2
+
3
+ Run `pytest tests/kernels/test_marlin_gemm.py`.
4
+ """
5
+
6
+ import pytest
7
+ import torch
8
+
9
+ import quantization
10
+
11
+ from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
12
+
13
+ from quantization.utils.marlin_utils import (
14
+ GPTQ_MARLIN_24_MAX_PARALLEL,
15
+ GPTQ_MARLIN_24_MIN_THREAD_N,
16
+ GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES,
17
+ GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES,
18
+ GPTQ_MARLIN_MAX_PARALLEL,
19
+ GPTQ_MARLIN_MIN_THREAD_N,
20
+ MARLIN_SUPPORTED_GROUP_SIZES,
21
+ MARLIN_QQQ_MAX_PARALLEL,
22
+ MARLIN_QQQ_MIN_THREAD_N,
23
+ MARLIN_QQQ_SUPPORTED_GROUP_SIZES,
24
+ MARLIN_QQQ_SUPPORTED_NUM_BITS,
25
+ marlin_make_empty_g_idx,
26
+ marlin_permute_scales,
27
+ query_marlin_supported_quant_types,
28
+ )
29
+ from quantization.utils.marlin_utils_fp8 import (
30
+ pack_fp8_to_int32,
31
+ )
32
+ from quantization.utils.quant_utils import (
33
+ awq_pack,
34
+ gptq_pack,
35
+ gptq_quantize_weights,
36
+ quantize_weights,
37
+ sort_weights,
38
+ )
39
+ from quantization.scalar_type import scalar_types
40
+
41
+ from quantization.utils.marlin_utils_test import (
42
+ MarlinWorkspace,
43
+ awq_marlin_quantize,
44
+ get_weight_perm,
45
+ marlin_quantize,
46
+ marlin_weights,
47
+ )
48
+ from quantization.utils.marlin_utils_test_24 import (
49
+ marlin_24_quantize,
50
+ )
51
+ from quantization.utils.marlin_utils_test_qqq import ( # noqa: E501
52
+ marlin_qqq_quantize,
53
+ )
54
+
55
+
56
+ # Avoid torch._dynamo.exc.Unsupported: cache_size_limit reached
57
+ torch._dynamo.config.cache_size_limit = 128
58
+
59
+
60
+ capability = torch.cuda.get_device_capability()
61
+ capability = capability[0] * 10 + capability[1]
62
+
63
+
64
+ ACT_ORDER_OPTS = [False, True]
65
+ K_FULL_OPTS = [False, True]
66
+ USE_FP32_REDUCE_OPTS = [False, True]
67
+
68
+ MARLIN_K_CHUNKS = [128]
69
+ MARLIN_N_CHUNKS = [64, 256]
70
+
71
+ MARLIN_24_K_CHUNKS = [128]
72
+ MARLIN_24_N_CHUNKS = [512]
73
+
74
+ HQQ_SUPPORTED_GROUP_SIZES = [64]
75
+
76
+ MNK_FACTORS = [
77
+ (1, 1, 1),
78
+ (1, 4, 8),
79
+ (1, 7, 5),
80
+ (13, 17, 67),
81
+ (26, 37, 13),
82
+ (67, 13, 11),
83
+ (257, 13, 11),
84
+ (658, 13, 11),
85
+ ]
86
+
87
+ DTYPES = [torch.float16, torch.bfloat16]
88
+
89
+
90
+ def compute_max_diff(output, output_ref):
91
+ return torch.mean(torch.abs(output - output_ref)) / torch.mean(
92
+ torch.abs(output_ref)
93
+ )
94
+
95
+
96
+ def rand_data(shape, dtype=torch.float16):
97
+ return torch.randn(shape, dtype=dtype, device="cuda")
98
+
99
+
100
+ @pytest.mark.skipif(
101
+ capability < 80,
102
+ reason="Marlin is not supported on this GPU type.",
103
+ )
104
+ @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
105
+ @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
106
+ @pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(False))
107
+ @pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
108
+ @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
109
+ @pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
110
+ def test_gptq_marlin_repack(
111
+ k_chunk, n_chunk, quant_type, group_size, act_order, mnk_factors
112
+ ):
113
+ m_factor, n_factor, k_factor = mnk_factors
114
+
115
+ size_k = k_chunk * k_factor
116
+ size_n = n_chunk * n_factor
117
+
118
+ # Filter act_order
119
+ if act_order:
120
+ if group_size == -1:
121
+ return
122
+ if group_size == size_k:
123
+ return
124
+
125
+ # Normalize group_size
126
+ if group_size == -1:
127
+ group_size = size_k
128
+ assert group_size <= size_k
129
+
130
+ # Create input
131
+ b_weight = rand_data((size_k, size_n))
132
+
133
+ # Quantize (and apply act_order if provided)
134
+ w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
135
+ b_weight, quant_type, group_size, act_order
136
+ )
137
+
138
+ # Pack to GPTQ format
139
+ q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)
140
+
141
+ # For act_order, sort the "weights" and "g_idx" so that group ids are
142
+ # increasing
143
+ sort_indices = torch.empty(0, dtype=torch.int, device=b_weight.device)
144
+ if act_order:
145
+ q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
146
+
147
+ # Pack to Marlin format
148
+ weight_perm = get_weight_perm(quant_type.size_bits)
149
+ marlin_q_w_1 = marlin_weights(
150
+ q_w, size_k, size_n, quant_type.size_bits, weight_perm
151
+ )
152
+
153
+ opcheck(
154
+ quantization._ops.ops.gptq_marlin_repack,
155
+ (q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits),
156
+ )
157
+
158
+ # Run Marlin repack GPU kernel
159
+ marlin_q_w_2 = quantization.gptq_marlin_repack(
160
+ q_w_gptq,
161
+ sort_indices,
162
+ size_k,
163
+ size_n,
164
+ quant_type.size_bits,
165
+ )
166
+ torch.cuda.synchronize()
167
+
168
+ torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)
169
+
170
+
171
+ @pytest.mark.skipif(
172
+ capability < 80,
173
+ reason="Marlin is not supported on this GPU type.",
174
+ )
175
+ @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
176
+ @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
177
+ @pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(False))
178
+ @pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
179
+ @pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
180
+ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, mnk_factors):
181
+ m_factor, n_factor, k_factor = mnk_factors
182
+
183
+ size_k = k_chunk * k_factor
184
+ size_n = n_chunk * n_factor
185
+
186
+ # Normalize group_size
187
+ if group_size == -1:
188
+ group_size = size_k
189
+ assert group_size <= size_k
190
+
191
+ # Create input
192
+ b_weight = rand_data((size_k, size_n))
193
+
194
+ # Quantize
195
+ w_ref, q_w, s, zp = quantize_weights(
196
+ b_weight, quant_type, group_size, zero_points=True
197
+ )
198
+
199
+ # Pack to AWQ format
200
+ q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n)
201
+
202
+ # Pack to Marlin format
203
+ weight_perm = get_weight_perm(quant_type.size_bits)
204
+ marlin_q_w_1 = marlin_weights(
205
+ q_w, size_k, size_n, quant_type.size_bits, weight_perm
206
+ )
207
+
208
+ opcheck(
209
+ quantization._ops.ops.awq_marlin_repack, (q_w_awq, size_k, size_n, quant_type.size_bits)
210
+ )
211
+
212
+ # Run Marlin repack GPU kernel
213
+ marlin_q_w_2 = quantization.awq_marlin_repack(
214
+ q_w_awq,
215
+ size_k,
216
+ size_n,
217
+ quant_type.size_bits,
218
+ )
219
+ torch.cuda.synchronize()
220
+
221
+ torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)
222
+
223
+
224
+ @pytest.mark.skipif(
225
+ capability < 80,
226
+ reason="Marlin is not supported on this GPU type.",
227
+ )
228
+ @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
229
+ @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
230
+ @pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(False))
231
+ @pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
232
+ @pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
233
+ @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
234
+ @pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
235
+ @pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
236
+ def test_gptq_marlin_gemm(
237
+ k_chunk,
238
+ n_chunk,
239
+ quant_type,
240
+ group_size,
241
+ mnk_factors,
242
+ act_order,
243
+ is_k_full,
244
+ use_fp32_reduce,
245
+ ):
246
+ m_factor, n_factor, k_factor = mnk_factors
247
+
248
+ size_m = m_factor
249
+ size_k = k_chunk * k_factor
250
+ size_n = n_chunk * n_factor
251
+
252
+ if act_order:
253
+ if group_size == -1:
254
+ return
255
+ if group_size == size_k:
256
+ return
257
+
258
+ a_input = rand_data((size_m, size_k))
259
+ b_weight = rand_data((size_k, size_n))
260
+
261
+ w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
262
+ b_weight, quant_type, group_size, act_order
263
+ )
264
+
265
+ marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
266
+
267
+ workspace = MarlinWorkspace(
268
+ size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
269
+ )
270
+
271
+ opcheck(
272
+ quantization._ops.ops.gptq_marlin_gemm,
273
+ (
274
+ a_input,
275
+ marlin_q_w,
276
+ marlin_s,
277
+ marlin_zp,
278
+ g_idx,
279
+ sort_indices,
280
+ workspace.scratch,
281
+ quant_type.id,
282
+ a_input.shape[0],
283
+ b_weight.shape[1],
284
+ a_input.shape[1],
285
+ is_k_full,
286
+ False,
287
+ use_fp32_reduce,
288
+ False,
289
+ ),
290
+ test_utils=DEFAULT_OPCHECK_TEST_UTILS,
291
+ )
292
+
293
+ output = quantization.gptq_marlin_gemm(
294
+ a_input,
295
+ marlin_q_w,
296
+ marlin_s,
297
+ marlin_zp,
298
+ g_idx,
299
+ sort_indices,
300
+ workspace.scratch,
301
+ quant_type,
302
+ a_input.shape[0],
303
+ b_weight.shape[1],
304
+ a_input.shape[1],
305
+ is_k_full=is_k_full,
306
+ has_zp=False,
307
+ use_fp32_reduce=use_fp32_reduce,
308
+ is_zp_float=False,
309
+ )
310
+ output_ref = torch.matmul(a_input, w_ref)
311
+
312
+ torch.cuda.synchronize()
313
+
314
+ max_diff = compute_max_diff(output, output_ref)
315
+
316
+ assert max_diff < 0.04
317
+
318
+
319
+ # TODO: find better way to test this?
320
+ @torch.compile(fullgraph=True)
321
+ def marlin_24_gemm_tester(
322
+ a_input,
323
+ marlin_24_q_w_comp,
324
+ marlin_24_meta,
325
+ marlin_24_s,
326
+ scratch,
327
+ quant_type,
328
+ size_m,
329
+ size_n,
330
+ size_k,
331
+ ):
332
+ return quantization.gptq_marlin_24_gemm(
333
+ a_input,
334
+ marlin_24_q_w_comp,
335
+ marlin_24_meta,
336
+ marlin_24_s,
337
+ scratch,
338
+ quant_type,
339
+ size_m,
340
+ size_n,
341
+ size_k,
342
+ )
343
+
344
+
345
+ @pytest.mark.skipif(
346
+ capability < 80,
347
+ reason="Marlin is not supported on this GPU type.",
348
+ )
349
+ @pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
350
+ @pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS)
351
+ @pytest.mark.parametrize("quant_type", GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
352
+ @pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES)
353
+ @pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
354
+ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, mnk_factors):
355
+ m_factor, n_factor, k_factor = mnk_factors
356
+
357
+ size_m = m_factor
358
+ size_k = k_chunk * k_factor
359
+ size_n = n_chunk * n_factor
360
+
361
+ a_input = rand_data((size_m, size_k))
362
+ b_weight = rand_data((size_k, size_n))
363
+
364
+ (w_24_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = marlin_24_quantize(
365
+ b_weight, quant_type, group_size
366
+ )
367
+
368
+ workspace_24 = MarlinWorkspace(
369
+ size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL
370
+ )
371
+
372
+ output_ref = torch.matmul(a_input, w_24_ref)
373
+
374
+ opcheck(
375
+ quantization._ops.ops.gptq_marlin_24_gemm,
376
+ (
377
+ a_input,
378
+ marlin_24_q_w_comp,
379
+ marlin_24_meta,
380
+ marlin_24_s,
381
+ workspace_24.scratch,
382
+ quant_type.id,
383
+ a_input.shape[0],
384
+ b_weight.shape[1],
385
+ a_input.shape[1],
386
+ ),
387
+ test_utils=DEFAULT_OPCHECK_TEST_UTILS,
388
+ )
389
+
390
+ output = marlin_24_gemm_tester(
391
+ a_input,
392
+ marlin_24_q_w_comp,
393
+ marlin_24_meta,
394
+ marlin_24_s,
395
+ workspace_24.scratch,
396
+ quant_type,
397
+ a_input.shape[0],
398
+ b_weight.shape[1],
399
+ a_input.shape[1],
400
+ )
401
+
402
+ torch.cuda.synchronize()
403
+
404
+ max_diff = compute_max_diff(output, output_ref)
405
+
406
+ assert max_diff < 0.04
407
+
408
+
409
+ @pytest.mark.skipif(
410
+ capability < 80,
411
+ reason="Marlin is not supported on this GPU type.",
412
+ )
413
+ @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
414
+ @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
415
+ @pytest.mark.parametrize("num_bits", [8])
416
+ @pytest.mark.parametrize("group_size", [-1])
417
+ @pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
418
+ @pytest.mark.parametrize("dtype", DTYPES)
419
+ def test_fp8_marlin_gemm(
420
+ k_chunk,
421
+ n_chunk,
422
+ num_bits,
423
+ group_size,
424
+ mnk_factors,
425
+ dtype,
426
+ ):
427
+ m_factor, n_factor, k_factor = mnk_factors
428
+
429
+ size_m = m_factor
430
+ size_k = k_chunk * k_factor
431
+ size_n = n_chunk * n_factor
432
+
433
+ a_input = rand_data((size_m, size_k), dtype=dtype)
434
+ b_weight = rand_data((size_k, size_n), dtype=dtype)
435
+
436
+ # WEIGHTS
437
+ fp8_weight, weight_scale = quantization.scaled_fp8_quant(b_weight, scale=None)
438
+ # Repack weights to gptq format (packed int32 elements)
439
+ packed_gptq_qweight = pack_fp8_to_int32(fp8_weight)
440
+ # Repack weights to marlin format
441
+ marlin_qweight = quantization.gptq_marlin_repack(
442
+ b_q_weight=packed_gptq_qweight,
443
+ perm=torch.empty(0, dtype=torch.int, device="cuda"),
444
+ size_k=size_k,
445
+ size_n=size_n,
446
+ num_bits=8,
447
+ )
448
+
449
+ # WEIGHT SCALES
450
+ # Currently Marlin doesn't support per-tensor scales, so we
451
+ # expand it to channelwise
452
+ scales = weight_scale.repeat(1, size_n).to(a_input.dtype).to("cuda")
453
+ # Permute scales
454
+ marlin_scales = marlin_permute_scales(
455
+ s=scales, size_k=size_k, size_n=size_n, group_size=-1
456
+ )
457
+
458
+ workspace = MarlinWorkspace(
459
+ size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
460
+ )
461
+
462
+ opcheck(
463
+ quantization._ops.ops.fp8_marlin_gemm,
464
+ (
465
+ a_input,
466
+ marlin_qweight,
467
+ marlin_scales,
468
+ workspace.scratch,
469
+ num_bits,
470
+ a_input.shape[0],
471
+ b_weight.shape[1],
472
+ a_input.shape[1],
473
+ ),
474
+ )
475
+
476
+ output = quantization.fp8_marlin_gemm(
477
+ a=a_input,
478
+ b_q_weight=marlin_qweight,
479
+ b_scales=marlin_scales,
480
+ workspace=workspace.scratch,
481
+ num_bits=num_bits,
482
+ size_m=a_input.shape[0],
483
+ size_n=b_weight.shape[1],
484
+ size_k=a_input.shape[1],
485
+ )
486
+ output_ref = torch.matmul(a_input, b_weight)
487
+
488
+ torch.cuda.synchronize()
489
+
490
+ max_diff = compute_max_diff(output, output_ref)
491
+
492
+ assert max_diff < 0.04
493
+
494
+
495
+ @pytest.mark.skipif(
496
+ capability < 80,
497
+ reason="Marlin is not supported on this GPU type.",
498
+ )
499
+ @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
500
+ @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
501
+ @pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(True))
502
+ @pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
503
+ @pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
504
+ @pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
505
+ def test_awq_marlin_gemm(
506
+ k_chunk,
507
+ n_chunk,
508
+ quant_type,
509
+ group_size,
510
+ mnk_factors,
511
+ use_fp32_reduce,
512
+ ):
513
+ m_factor, n_factor, k_factor = mnk_factors
514
+
515
+ size_m = m_factor
516
+ size_k = k_chunk * k_factor
517
+ size_n = n_chunk * n_factor
518
+
519
+ a_input = rand_data((size_m, size_k))
520
+ b_weight = rand_data((size_k, size_n))
521
+
522
+ w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
523
+ b_weight, quant_type, group_size
524
+ )
525
+
526
+ g_idx = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
527
+ sort_indices = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
528
+ is_k_full = True
529
+ has_zp = True
530
+
531
+ workspace = MarlinWorkspace(
532
+ size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
533
+ )
534
+
535
+ output = quantization.gptq_marlin_gemm(
536
+ a_input,
537
+ marlin_q_w,
538
+ marlin_s,
539
+ marlin_zp,
540
+ g_idx,
541
+ sort_indices,
542
+ workspace.scratch,
543
+ quant_type,
544
+ a_input.shape[0],
545
+ b_weight.shape[1],
546
+ a_input.shape[1],
547
+ is_k_full=is_k_full,
548
+ has_zp=has_zp,
549
+ use_fp32_reduce=use_fp32_reduce,
550
+ is_zp_float=False,
551
+ )
552
+ output_ref = torch.matmul(a_input, w_ref)
553
+
554
+ torch.cuda.synchronize()
555
+
556
+ max_diff = compute_max_diff(output, output_ref)
557
+
558
+ assert max_diff < 0.04
559
+
560
+
561
+ @pytest.mark.skipif(
562
+ capability < 80,
563
+ reason="Marlin is not supported on this GPU type.",
564
+ )
565
+ @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
566
+ @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
567
+ @pytest.mark.parametrize("group_size", HQQ_SUPPORTED_GROUP_SIZES)
568
+ @pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
569
+ @pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
570
+ def test_hqq_marlin_gemm(
571
+ k_chunk,
572
+ n_chunk,
573
+ group_size,
574
+ mnk_factors,
575
+ use_fp32_reduce,
576
+ ):
577
+ m_factor, n_factor, k_factor = mnk_factors
578
+
579
+ size_m = m_factor
580
+ size_k = k_chunk * k_factor
581
+ size_n = n_chunk * n_factor
582
+
583
+ quant_type = scalar_types.uint4
584
+
585
+ a_input = rand_data((size_m, size_k))
586
+ dev = a_input.device
587
+
588
+ b_weight = torch.randint(0, 10, (size_n, size_k), dtype=torch.uint8, device=dev)
589
+ scale = rand_data((size_n, size_k // group_size))
590
+ zero = rand_data((size_n, size_k // group_size))
591
+
592
+ gptq_w_q = gptq_pack(b_weight.transpose(1, 0), 4, size_k, size_n)
593
+
594
+ sort_indices = torch.empty(0, dtype=torch.int, device=dev)
595
+ marlin_w_q = quantization.gptq_marlin_repack(gptq_w_q, sort_indices, size_k, size_n, 4).to(
596
+ dev
597
+ )
598
+ marlin_s = marlin_permute_scales(
599
+ scale.transpose(1, 0), size_k, size_n, group_size
600
+ ).to(dev)
601
+ marlin_zp = marlin_permute_scales(
602
+ zero.transpose(1, 0), size_k, size_n, group_size
603
+ ).to(dev)
604
+
605
+ g_idx = marlin_make_empty_g_idx(dev)
606
+ g_idx_sort_indices = marlin_make_empty_g_idx(dev)
607
+
608
+ workspace = MarlinWorkspace(
609
+ size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
610
+ )
611
+
612
+ output = quantization.gptq_marlin_gemm(
613
+ a_input,
614
+ marlin_w_q,
615
+ marlin_s,
616
+ marlin_zp,
617
+ g_idx,
618
+ g_idx_sort_indices,
619
+ workspace.scratch,
620
+ quant_type,
621
+ a_input.shape[0],
622
+ b_weight.shape[0],
623
+ a_input.shape[1],
624
+ is_k_full=True,
625
+ has_zp=True,
626
+ use_fp32_reduce=use_fp32_reduce,
627
+ is_zp_float=True,
628
+ )
629
+
630
+ b_flat = b_weight.reshape(-1, group_size)
631
+ zp_flat = zero.reshape(-1, 1)
632
+ s_flat = scale.reshape(-1, 1)
633
+ dequant = (b_flat - zp_flat) * s_flat
634
+
635
+ output_ref = torch.matmul(a_input, dequant.reshape(b_weight.shape).transpose(1, 0))
636
+
637
+ torch.cuda.synchronize()
638
+
639
+ max_diff = compute_max_diff(output, output_ref)
640
+
641
+ assert max_diff < 0.04
642
+
643
+
644
+ @pytest.mark.skipif(
645
+ capability < 80,
646
+ reason="Marlin is not supported on this GPU type.",
647
+ )
648
+ @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
649
+ @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
650
+ @pytest.mark.parametrize("num_bits", MARLIN_QQQ_SUPPORTED_NUM_BITS)
651
+ @pytest.mark.parametrize("group_size", MARLIN_QQQ_SUPPORTED_GROUP_SIZES)
652
+ @pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
653
+ def test_marlin_qqq_gemm(
654
+ k_chunk,
655
+ n_chunk,
656
+ num_bits,
657
+ group_size,
658
+ mnk_factors,
659
+ ):
660
+ int8_traits = torch.iinfo(torch.int8)
661
+ m_factor, n_factor, k_factor = mnk_factors
662
+
663
+ size_m = m_factor
664
+ size_k = k_chunk * k_factor
665
+ size_n = n_chunk * n_factor
666
+
667
+ a_input = rand_data((size_m, size_k))
668
+ b_weight = rand_data((size_k, size_n))
669
+
670
+ # Quantize activations
671
+ s_a = (
672
+ a_input.abs().max(dim=-1, keepdim=True)[0].div(int8_traits.max).to(torch.float)
673
+ )
674
+ q_a = (a_input / s_a).round().clamp(int8_traits.min, int8_traits.max).to(torch.int8)
675
+
676
+ # Quantize weights
677
+ w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel = (
678
+ marlin_qqq_quantize(b_weight, num_bits, group_size)
679
+ )
680
+
681
+ workspace = MarlinWorkspace(
682
+ size_n, MARLIN_QQQ_MIN_THREAD_N, MARLIN_QQQ_MAX_PARALLEL
683
+ )
684
+
685
+ opcheck(
686
+ quantization._ops.ops.marlin_qqq_gemm,
687
+ (
688
+ q_a,
689
+ marlin_qqq_q_w,
690
+ s_a,
691
+ marlin_qqq_s_channel,
692
+ marlin_qqq_s_group,
693
+ workspace.scratch,
694
+ a_input.shape[0],
695
+ b_weight.shape[1],
696
+ a_input.shape[1],
697
+ ),
698
+ )
699
+
700
+ output = quantization.marlin_qqq_gemm(
701
+ q_a,
702
+ marlin_qqq_q_w,
703
+ s_a,
704
+ marlin_qqq_s_channel,
705
+ marlin_qqq_s_group,
706
+ workspace.scratch,
707
+ a_input.shape[0],
708
+ b_weight.shape[1],
709
+ a_input.shape[1],
710
+ )
711
+ output_ref = torch.matmul(q_a.half() * s_a.half(), w_ref)
712
+
713
+ torch.cuda.synchronize()
714
+
715
+ max_diff = compute_max_diff(output, output_ref)
716
+
717
+ assert max_diff < 0.04
718
+
719
+
720
+ def test_marlin_gemm_opcheck():
721
+ size_m = 2048
722
+ size_n = 4096
723
+ size_k = 4096
724
+ a = torch.rand((size_m, size_k), device="cuda", dtype=torch.float16)
725
+ w = torch.randint(-5, 5, (256, 8192), device="cuda", dtype=torch.int32)
726
+ s = torch.full((32, size_k), 0.125, device="cuda", dtype=torch.float16)
727
+ wk = MarlinWorkspace(
728
+ size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
729
+ ).scratch
730
+ x = quantization._ops.ops.marlin_gemm(a, w, s, wk, size_m, size_n, size_k)
731
+ y = quantization._ops.ops.marlin_gemm(a, w, s, wk, size_m, size_n, size_k)
732
+ torch.testing.assert_close(x, y)
733
+ opcheck(quantization._ops.ops.marlin_gemm, (a, w, s, wk, size_m, size_n, size_k))
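
Illustrative note (not part of this commit): all of the GEMM tests above share the same acceptance criterion, the mean absolute error normalized by the mean magnitude of the reference, compared against 0.04. A self-contained restatement of that check:

import torch

def mean_relative_error(output: torch.Tensor, output_ref: torch.Tensor) -> float:
    # Same metric as compute_max_diff above.
    return (torch.mean(torch.abs(output - output_ref))
            / torch.mean(torch.abs(output_ref))).item()

# Synthetic example: an output that is uniformly 1% off passes the 0.04 bound.
ref = torch.randn(64, 64)
assert mean_relative_error(ref * 1.01, ref) < 0.04
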
tests/kernels/utils.py CHANGED
@@ -4,13 +4,20 @@ import itertools
4
  import random
5
  import unittest
6
  from numbers import Number
7
- from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple,
8
- Union)
9
 
10
  import pytest
11
  import torch
12
  from torch._prims_common import TensorLikeType
13
 
 
 
 
 
 
 
 
 
14
  ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
15
  "test_schema",
16
  "test_autograd_registration",
@@ -18,6 +25,7 @@ ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
18
  "test_aot_dispatch_dynamic",
19
  )
20
 
 
21
  # Copied/modified from torch._refs.__init__.py
22
  def fp8_allclose(
23
  a: TensorLikeType,
@@ -29,34 +37,37 @@ def fp8_allclose(
29
  """
30
  Reference implementation of torch.allclose
31
  """
32
- torch._refs._check_close_args(name="torch.allclose",
33
- a=a,
34
- b=b,
35
- rtol=rtol,
36
- atol=atol)
37
 
38
  return bool(
39
  torch.all(
40
- torch.isclose(a.double(),
41
- b.double(),
42
- rtol=rtol,
43
- atol=atol,
44
- equal_nan=equal_nan)).item())
 
45
 
46
  # A special version of op check that has a restricted default set of test_utils
47
  # and a patched version of allclose that supports fp8 types.
48
- def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
49
- torch._library.custom_ops.CustomOpDef],
50
- args: Tuple[Any, ...],
51
- kwargs: Optional[Dict[str, Any]] = None,
52
- *,
53
- test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
54
- raise_exception: bool = True,
55
- cond: bool = True) -> Dict[str, str]:
56
- with unittest.mock.patch('torch.allclose', new=fp8_allclose):
57
- return torch.library.opcheck(
58
- op,
59
- args,
60
- kwargs,
61
- test_utils=test_utils,
62
- raise_exception=raise_exception) if cond else {}
 
 
 
 
 
 
 
4
  import random
5
  import unittest
6
  from numbers import Number
7
+ from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
 
8
 
9
  import pytest
10
  import torch
11
  from torch._prims_common import TensorLikeType
12
 
13
+ # For now, disable "test_aot_dispatch_dynamic" since there are some
14
+ # bugs related to this test in PyTorch 2.4.
15
+ DEFAULT_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
16
+ "test_schema",
17
+ "test_autograd_registration",
18
+ "test_faketensor",
19
+ )
20
+
21
  ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
22
  "test_schema",
23
  "test_autograd_registration",
 
25
  "test_aot_dispatch_dynamic",
26
  )
27
 
28
+
29
  # Copied/modified from torch._refs.__init__.py
30
  def fp8_allclose(
31
  a: TensorLikeType,
 
37
  """
38
  Reference implementation of torch.allclose
39
  """
40
+ torch._refs._check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol)
 
 
 
 
41
 
42
  return bool(
43
  torch.all(
44
+ torch.isclose(
45
+ a.double(), b.double(), rtol=rtol, atol=atol, equal_nan=equal_nan
46
+ )
47
+ ).item()
48
+ )
49
+
50
 
51
  # A special version of op check that has a restricted default set of test_utils
52
  # and a patched version of allclose that supports fp8 types.
53
+ def opcheck(
54
+ op: Union[
55
+ torch._ops.OpOverload,
56
+ torch._ops.OpOverloadPacket,
57
+ torch._library.custom_ops.CustomOpDef,
58
+ ],
59
+ args: Tuple[Any, ...],
60
+ kwargs: Optional[Dict[str, Any]] = None,
61
+ *,
62
+ test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
63
+ raise_exception: bool = True,
64
+ cond: bool = True
65
+ ) -> Dict[str, str]:
66
+ with unittest.mock.patch("torch.allclose", new=fp8_allclose):
67
+ return (
68
+ torch.library.opcheck(
69
+ op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
70
+ )
71
+ if cond
72
+ else {}
73
+ )
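
Illustrative note (not part of this commit): fp8_allclose upcasts both tensors to float64 before calling torch.isclose, which is what lets the patched opcheck above compare fp8 outputs. A minimal sketch of direct use; the float8_e4m3fn dtype is chosen here purely for illustration:

import torch
from tests.kernels.utils import fp8_allclose

a = torch.tensor([0.5, 1.0]).to(torch.float8_e4m3fn)
b = torch.tensor([0.5, 1.0]).to(torch.float8_e4m3fn)
# The comparison happens in float64, so fp8 dtypes are accepted.
assert fp8_allclose(a, b)
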