danieldk HF Staff committed on
Commit
a02ac19
·
1 Parent(s): 3370704

Build (AArch64)

Files changed (44)
  1. build/torch26-cxx11-cu126-aarch64-linux/quantization/__init__.py +2 -2
  2. build/torch26-cxx11-cu126-aarch64-linux/quantization/_ops.py +3 -3
  3. build/torch26-cxx11-cu126-aarch64-linux/quantization/{_quantization_0435ccb.abi3.so → _quantization_9035540.abi3.so} +2 -2
  4. build/torch26-cxx11-cu126-aarch64-linux/quantization/compressed_tensors.py +34 -33
  5. build/torch26-cxx11-cu126-aarch64-linux/quantization/cutlass.py +10 -16
  6. build/torch26-cxx11-cu126-aarch64-linux/quantization/marlin.py +40 -74
  7. build/torch26-cxx11-cu126-aarch64-linux/quantization/platforms.py +69 -0
  8. build/torch26-cxx11-cu126-aarch64-linux/quantization/scalar_type.py +19 -2
  9. build/torch26-cxx11-cu126-aarch64-linux/quantization/utils/marlin_utils.py +231 -170
  10. build/torch26-cxx11-cu126-aarch64-linux/quantization/utils/marlin_utils_fp4.py +282 -0
  11. build/torch26-cxx11-cu126-aarch64-linux/quantization/utils/marlin_utils_fp8.py +90 -68
  12. build/torch26-cxx98-cu126-aarch64-linux/quantization/__init__.py +2 -2
  13. build/torch26-cxx98-cu126-aarch64-linux/quantization/_ops.py +3 -3
  14. build/torch26-cxx98-cu126-aarch64-linux/quantization/{_quantization_0435ccb.abi3.so → _quantization_9035540.abi3.so} +2 -2
  15. build/torch26-cxx98-cu126-aarch64-linux/quantization/compressed_tensors.py +34 -33
  16. build/torch26-cxx98-cu126-aarch64-linux/quantization/cutlass.py +10 -16
  17. build/torch26-cxx98-cu126-aarch64-linux/quantization/marlin.py +40 -74
  18. build/torch26-cxx98-cu126-aarch64-linux/quantization/platforms.py +69 -0
  19. build/torch26-cxx98-cu126-aarch64-linux/quantization/scalar_type.py +19 -2
  20. build/torch26-cxx98-cu126-aarch64-linux/quantization/utils/marlin_utils.py +231 -170
  21. build/torch26-cxx98-cu126-aarch64-linux/quantization/utils/marlin_utils_fp4.py +282 -0
  22. build/torch26-cxx98-cu126-aarch64-linux/quantization/utils/marlin_utils_fp8.py +90 -68
  23. build/torch27-cxx11-cu126-aarch64-linux/quantization/__init__.py +2 -2
  24. build/torch27-cxx11-cu126-aarch64-linux/quantization/_ops.py +3 -3
  25. build/torch27-cxx11-cu126-aarch64-linux/quantization/{_quantization_0435ccb.abi3.so → _quantization_9035540.abi3.so} +2 -2
  26. build/torch27-cxx11-cu126-aarch64-linux/quantization/compressed_tensors.py +34 -33
  27. build/torch27-cxx11-cu126-aarch64-linux/quantization/cutlass.py +10 -16
  28. build/torch27-cxx11-cu126-aarch64-linux/quantization/marlin.py +40 -74
  29. build/torch27-cxx11-cu126-aarch64-linux/quantization/platforms.py +69 -0
  30. build/torch27-cxx11-cu126-aarch64-linux/quantization/scalar_type.py +19 -2
  31. build/torch27-cxx11-cu126-aarch64-linux/quantization/utils/marlin_utils.py +231 -170
  32. build/torch27-cxx11-cu126-aarch64-linux/quantization/utils/marlin_utils_fp4.py +282 -0
  33. build/torch27-cxx11-cu126-aarch64-linux/quantization/utils/marlin_utils_fp8.py +90 -68
  34. build/torch27-cxx11-cu128-aarch64-linux/quantization/__init__.py +2 -2
  35. build/torch27-cxx11-cu128-aarch64-linux/quantization/_ops.py +3 -3
  36. build/torch27-cxx11-cu128-aarch64-linux/quantization/{_quantization_0435ccb.abi3.so → _quantization_9035540.abi3.so} +2 -2
  37. build/torch27-cxx11-cu128-aarch64-linux/quantization/compressed_tensors.py +34 -33
  38. build/torch27-cxx11-cu128-aarch64-linux/quantization/cutlass.py +10 -16
  39. build/torch27-cxx11-cu128-aarch64-linux/quantization/marlin.py +40 -74
  40. build/torch27-cxx11-cu128-aarch64-linux/quantization/platforms.py +69 -0
  41. build/torch27-cxx11-cu128-aarch64-linux/quantization/scalar_type.py +19 -2
  42. build/torch27-cxx11-cu128-aarch64-linux/quantization/utils/marlin_utils.py +231 -170
  43. build/torch27-cxx11-cu128-aarch64-linux/quantization/utils/marlin_utils_fp4.py +282 -0
  44. build/torch27-cxx11-cu128-aarch64-linux/quantization/utils/marlin_utils_fp8.py +90 -68
build/torch26-cxx11-cu126-aarch64-linux/quantization/__init__.py CHANGED
@@ -1,12 +1,12 @@
 from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant
 from .cutlass import (
+    cutlass_scaled_mm_supports_block_fp8,
     cutlass_scaled_mm_supports_fp8,
     cutlass_scaled_mm,
     cutlass_scaled_mm_azp,
 )
 from .marlin import (
     awq_marlin_repack,
-    fp8_marlin_gemm,
     gptq_marlin_gemm,
     gptq_marlin_repack,
     gptq_marlin_24_gemm,
@@ -25,8 +25,8 @@ __all__ = [
     "awq_marlin_repack",
     "cutlass_scaled_mm",
     "cutlass_scaled_mm_azp",
+    "cutlass_scaled_mm_supports_block_fp8",
     "cutlass_scaled_mm_supports_fp8",
-    "fp8_marlin_gemm",
     "gptq_marlin_24_gemm",
     "gptq_marlin_gemm",
     "gptq_marlin_repack",
build/torch26-cxx11-cu126-aarch64-linux/quantization/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _quantization_0435ccb
-ops = torch.ops._quantization_0435ccb
+from . import _quantization_9035540
+ops = torch.ops._quantization_9035540
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_quantization_0435ccb::{op_name}"
+    return f"_quantization_9035540::{op_name}"
build/torch26-cxx11-cu126-aarch64-linux/quantization/{_quantization_0435ccb.abi3.so → _quantization_9035540.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:8088ed15517846e4eec7ea5975b32cc7e4164522cb9305510663f76d36a1cef5
-size 67890136
+oid sha256:aee128710f3a8587386120a226a6caddd5e77cd7a0296a1f7fad51b4028550b1
+size 159934120
build/torch26-cxx11-cu126-aarch64-linux/quantization/compressed_tensors.py CHANGED
@@ -2,17 +2,7 @@ from typing import Optional, Tuple
 
 import torch
 
-try:
-    from ._ops import ops
-except ImportError as e:
-    # Fallback for local development.
-    try:
-        import _quantization
-
-        ops = torch.ops._quantization
-    except ImportError:
-        raise e
-
+from ._ops import ops
 
 # fp8
 def scaled_fp8_quant(
@@ -21,7 +11,8 @@ def scaled_fp8_quant(
     num_token_padding: Optional[int] = None,
     scale_ub: Optional[torch.Tensor] = None,
     use_per_token_if_dynamic: bool = False,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+    output: Optional[torch.Tensor] = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
     """
     Quantize input tensor to FP8 and return quantized tensor and scale.
 
@@ -42,30 +33,36 @@ def scaled_fp8_quant(
         in the dynamic quantization case.
 
     Returns:
-        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
+        tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
             scaling factor.
     """
     # This code assumes batch_dim and num_tokens are flattened
-    assert input.ndim == 2
-    shape: Union[Tuple[int, int], torch.Size] = input.shape
-    # For rocm, the output fp8 dtype is torch.float_e3m3fnuz
-    # out_dtype: torch.dtype = torch.float8_e4m3fnuz \
-    #     if current_platform.is_rocm() else torch.float8_e4m3fn
-    out_dtype = torch.float8_e4m3fn
+    assert (input.ndim == 2)
+    shape: Union[tuple[int, int], torch.Size] = input.shape
+    # For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz
+    out_dtype: torch.dtype = current_platform.fp8_dtype()
     if num_token_padding:
         shape = (max(num_token_padding, input.shape[0]), shape[1])
-    output = torch.empty(shape, device=input.device, dtype=out_dtype)
+    if output is None:
+        output = torch.empty(shape, device=input.device, dtype=out_dtype)
+    else:
+        assert num_token_padding is None, \
+            "padding not supported if output passed in"
+        assert output.dtype == out_dtype
 
     if scale is None:
         if use_per_token_if_dynamic:
-            scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
-            ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub)
+            scale = torch.empty((shape[0], 1),
+                                device=input.device,
+                                dtype=torch.float32)
+            ops.dynamic_per_token_scaled_fp8_quant(
+                output, input.contiguous(), scale, scale_ub)
         else:
            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
            ops.dynamic_scaled_fp8_quant(output, input, scale)
     else:
         # num_token_padding not implemented for this case
-        assert scale.numel() == 1 or num_token_padding is None
+        assert (scale.numel() == 1 and num_token_padding is None)
         ops.static_scaled_fp8_quant(output, input, scale)
 
     return output, scale
@@ -76,8 +73,8 @@ def scaled_int8_quant(
     input: torch.Tensor,
     scale: Optional[torch.Tensor] = None,
     azp: Optional[torch.Tensor] = None,
-    symmetric: bool = True,
-) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+    symmetric: bool = True
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
     """
     Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.
 
@@ -90,21 +87,25 @@ def scaled_int8_quant(
         symmetric: Whether to use symmetric quantization (scale only, azp ignored).
 
     Returns:
-        Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
+        tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
     """
     output = torch.empty_like(input, dtype=torch.int8)
     if scale is not None:
         # static-per-tensor quantization.
         assert symmetric == (
-            azp is None
-        ), "azp must only be provided for asymmetric quantization."
+            azp
+            is None), "azp must only be provided for asymmetric quantization."
         ops.static_scaled_int8_quant(output, input, scale, azp)
         return output, scale, azp
 
     # dynamic-per-token quantization.
-    input_scales = torch.empty(
-        (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32
-    )
-    input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32)
-    ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp)
+    input_scales = torch.empty((input.numel() // input.shape[-1], 1),
+                               device=input.device,
+                               dtype=torch.float32)
+    input_azp = None if symmetric else torch.empty_like(input_scales,
+                                                        dtype=torch.int32)
+    ops.dynamic_scaled_int8_quant(output, input.contiguous(),
+                                  input_scales, input_azp)
     return output, input_scales, input_azp
+
+
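The new `output` parameter lets callers pass a preallocated FP8 buffer to `scaled_fp8_quant`. A minimal usage sketch, illustrative only and not part of the commit; it assumes the built wheel is importable as `quantization` on a CUDA device:

    import torch
    from quantization import scaled_fp8_quant

    x = torch.randn(4, 128, dtype=torch.float16, device="cuda")
    # dynamic per-tensor scale is computed when `scale` is None
    q, scale = scaled_fp8_quant(x)
    # reuse a preallocated output buffer; num_token_padding is not
    # supported in this mode (asserted in the new code above)
    out = torch.empty_like(q)
    q2, scale2 = scaled_fp8_quant(x, output=out)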
build/torch26-cxx11-cu126-aarch64-linux/quantization/cutlass.py CHANGED
@@ -2,22 +2,18 @@ from typing import Optional
 
 import torch
 
-try:
-    from ._ops import ops
-except ImportError as e:
-    # Fallback for local development.
-    try:
-        import _quantization
-
-        ops = torch.ops._quantization
-    except ImportError:
-        raise e
+from ._ops import ops
+from .platforms import current_platform
 
 
 def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
     return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
 
 
+def cutlass_scaled_mm_supports_block_fp8(cuda_device_capability: int) -> bool:
+    return ops.cutlass_scaled_mm_supports_block_fp8(cuda_device_capability)
+
+
 def cutlass_scaled_mm(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -33,12 +29,10 @@ def cutlass_scaled_mm(
     m = a.shape[0]
     n = b.shape[1]
 
-    # if current_platform.is_rocm():
-    #     triton_scaled_mm_module = importlib.import_module(
-    #         "vllm.model_executor.layers.quantization.compressed_tensors."
-    #         "triton_scaled_mm")
-    #     triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
-    #     return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+    cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
+    if not cutlass_compatible_b:
+        from .triton_scaled_mm import triton_scaled_mm
+        return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
 
     out = torch.empty((m, n), dtype=out_dtype, device=a.device)
 
build/torch26-cxx11-cu126-aarch64-linux/quantization/marlin.py CHANGED
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
 
 import torch
 
@@ -30,58 +30,30 @@ except ImportError as e:
 from .scalar_type import ScalarType
 
 
-# fp8 marlin
-def fp8_marlin_gemm(
-    a: torch.Tensor,
-    b_q_weight: torch.Tensor,
-    b_scales: torch.Tensor,
-    workspace: torch.Tensor,
-    num_bits: int,
-    size_m: int,
-    size_n: int,
-    size_k: int,
-) -> torch.Tensor:
-    return ops.fp8_marlin_gemm(
-        a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k
-    )
-
-
 # gptq_marlin
-def gptq_marlin_gemm(
-    a: torch.Tensor,
-    b_q_weight: torch.Tensor,
-    b_scales: torch.Tensor,
-    b_zeros: torch.Tensor,
-    g_idx: torch.Tensor,
-    perm: torch.Tensor,
-    workspace: torch.Tensor,
-    b_q_type: ScalarType,
-    size_m: int,
-    size_n: int,
-    size_k: int,
-    is_k_full: bool,
-    has_zp: bool = False,
-    use_fp32_reduce: bool = False,
-    is_zp_float: bool = False,
-) -> torch.Tensor:
-    return ops.gptq_marlin_gemm(
-        a,
-        b_q_weight,
-        b_scales,
-        b_zeros,
-        g_idx,
-        perm,
-        workspace,
-        b_q_type.id,
-        size_m,
-        size_n,
-        size_k,
-        is_k_full,
-        has_zp,
-        use_fp32_reduce,
-        is_zp_float,
-    )
-
+def gptq_marlin_gemm(a: torch.Tensor,
+                     c: Optional[torch.Tensor],
+                     b_q_weight: torch.Tensor,
+                     b_scales: torch.Tensor,
+                     global_scale: Optional[torch.Tensor],
+                     b_zeros: Optional[torch.Tensor],
+                     g_idx: Optional[torch.Tensor],
+                     perm: Optional[torch.Tensor],
+                     workspace: torch.Tensor,
+                     b_q_type: ScalarType,
+                     size_m: int,
+                     size_n: int,
+                     size_k: int,
+                     is_k_full: bool = True,
+                     use_atomic_add: bool = False,
+                     use_fp32_reduce: bool = False,
+                     is_zp_float: bool = False) -> torch.Tensor:
+    return ops.gptq_marlin_gemm(a, c, b_q_weight, b_scales,
+                                global_scale, b_zeros, g_idx, perm,
+                                workspace, b_q_type.id, size_m,
+                                size_n, size_k, is_k_full,
+                                use_atomic_add, use_fp32_reduce,
+                                is_zp_float)
 
 # gptq_marlin
 def gptq_marlin_repack(
@@ -153,14 +125,6 @@ def marlin_qqq_gemm(
 # Fake ops
 
 if hasattr(ops, "gptq_marlin_24_gemm"):
-    @register_fake(add_op_namespace_prefix("fp8_marlin_gemm"))
-    def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
-                              b_scales: torch.Tensor, workspace: torch.Tensor,
-                              num_bits: int, size_m: torch.SymInt,
-                              size_n: torch.SymInt,
-                              size_k: torch.SymInt) -> torch.Tensor:
-        return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)
-
     @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm"))
     def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                                   b_meta: torch.Tensor, b_scales: torch.Tensor,
@@ -172,20 +136,22 @@ if hasattr(ops, "gptq_marlin_24_gemm"):
 
     @register_fake(add_op_namespace_prefix("gptq_marlin_gemm"))
     def _gptq_marlin_gemm_fake(a: torch.Tensor,
-                               b_q_weight: torch.Tensor,
-                               b_scales: torch.Tensor,
-                               b_zeros: torch.Tensor,
-                               g_idx: torch.Tensor,
-                               perm: torch.Tensor,
-                               workspace: torch.Tensor,
-                               b_q_type: ScalarType,
-                               size_m: torch.SymInt,
-                               size_n: torch.SymInt,
-                               size_k: torch.SymInt,
-                               is_k_full: bool,
-                               has_zp: bool = False,
-                               use_fp32_reduce: bool = False,
-                               is_zp_float: bool = False) -> torch.Tensor:
+                               c: Optional[torch.Tensor],
+                               b_q_weight: torch.Tensor,
+                               b_scales: torch.Tensor,
+                               global_scale: Optional[torch.Tensor],
+                               b_zeros: Optional[torch.Tensor],
+                               g_idx: Optional[torch.Tensor],
+                               perm: Optional[torch.Tensor],
+                               workspace: torch.Tensor,
+                               b_q_type_id: int,
+                               size_m: torch.SymInt,
+                               size_n: torch.SymInt,
+                               size_k: torch.SymInt,
+                               is_k_full: bool = True,
+                               use_atomic_add: bool = False,
+                               use_fp32_reduce: bool = False,
+                               is_zp_float: bool = False) -> torch.Tensor:
         return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
 
     @register_fake(add_op_namespace_prefix("marlin_qqq_gemm"))
build/torch26-cxx11-cu126-aarch64-linux/quantization/platforms.py ADDED
@@ -0,0 +1,69 @@
+from abc import ABC, abstractmethod
+from functools import lru_cache
+from typing import NamedTuple
+
+import torch
+
+IS_ROCM = torch.version.hip is not None
+
+
+class DeviceCapability(NamedTuple):
+    major: int
+    minor: int
+
+    def as_version_str(self) -> str:
+        return f"{self.major}.{self.minor}"
+
+    def to_int(self) -> int:
+        """
+        Express device capability as an integer ``<major><minor>``.
+
+        It is assumed that the minor version is always a single digit.
+        """
+        assert 0 <= self.minor < 10
+        return self.major * 10 + self.minor
+
+
+class Platform(ABC):
+    simple_compile_backend: str = "inductor"
+
+    @classmethod
+    @abstractmethod
+    def get_device_name(cls, device_id: int = 0) -> str: ...
+
+    @abstractmethod
+    def is_rocm(self): ...
+
+
+class CudaPlatform(Platform):
+    @classmethod
+    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
+        major, minor = torch.cuda.get_device_capability(device_id)
+        return DeviceCapability(major=major, minor=minor)
+
+    @classmethod
+    @lru_cache(maxsize=8)
+    def get_device_name(cls, device_id: int = 0) -> str:
+        return torch.cuda.get_device_name(0)
+
+    def is_rocm(self):
+        return False
+
+
+class RocmPlatform(Platform):
+    @classmethod
+    @lru_cache(maxsize=8)
+    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
+        major, minor = torch.cuda.get_device_capability(device_id)
+        return DeviceCapability(major=major, minor=minor)
+
+    @classmethod
+    @lru_cache(maxsize=8)
+    def get_device_name(cls, device_id: int = 0) -> str:
+        return torch.cuda.get_device_name(device_id)
+
+    def is_rocm(self):
+        return True
+
+
+current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
build/torch26-cxx11-cu126-aarch64-linux/quantization/scalar_type.py CHANGED
@@ -1,9 +1,14 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
 import functools
 import struct
 from dataclasses import dataclass
 from enum import Enum
 from typing import Optional, Union
 
+_SCALAR_TYPES_ID_MAP = {}
+
 
 # Mirrors enum in `core/scalar_type.hpp`
 class NanRepr(Enum):
@@ -121,8 +126,8 @@ class ScalarType:
             min_raw = max_raw | sign_bit_double
             return struct.unpack('!d', struct.pack('!Q', min_raw))[0]
         else:
-            assert (not self.is_signed() or
-                    self.size_bits <= 64), "Cannot represent min as a int64_t"
+            assert (not self.is_signed() or self.size_bits
+                    <= 64), "Cannot represent min as a int64_t"
 
             if self.is_signed():
                 return -(1 << (self.size_bits - 1))
@@ -156,6 +161,8 @@ class ScalarType:
         assert offset <= 64, \
            f"ScalarType fields too big {offset} to fit into an int64"
 
+        _SCALAR_TYPES_ID_MAP[val] = self
+
        return val
 
     @property
@@ -293,6 +300,13 @@ class ScalarType:
         ret.id  # noqa B018: make sure the id is cached
         return ret
 
+    @classmethod
+    def from_id(cls, scalar_type_id: int):
+        if scalar_type_id not in _SCALAR_TYPES_ID_MAP:
+            raise ValueError(
+                f"scalar_type_id {scalar_type_id} doesn't exists.")
+        return _SCALAR_TYPES_ID_MAP[scalar_type_id]
+
 
 # naming generally follows: https://github.com/jax-ml/ml_dtypes
 # for floating point types (leading f) the scheme is:
@@ -319,6 +333,9 @@ class scalar_types:
     # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
     float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)
 
+    # fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+    float4_e2m1f = ScalarType.float_(2, 1, True, NanRepr.NONE)
+
     # "gptq" types
     uint2b2 = ScalarType.uint(2, 2)
     uint3b4 = ScalarType.uint(3, 4)
build/torch26-cxx11-cu126-aarch64-linux/quantization/utils/marlin_utils.py CHANGED
@@ -1,4 +1,7 @@
-from typing import List, Optional, Tuple
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import Optional
 
 import numpy
 import torch
@@ -42,7 +45,9 @@ USE_FP32_REDUCE_DEFAULT = True
 # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
 # TODO: we may want to move this into the C++ so its closer to the actual impl
 def query_marlin_supported_quant_types(
-    has_zp: bool, device_capability: Optional[int] = None
+    has_zp: Optional[bool] = None,
+    include_fp_type: bool = True,
+    device_capability: Optional[int] = None,
 ):
     if device_capability is None:
         capability_tuple = torch.cuda.get_device_capability()
@@ -51,137 +56,141 @@ def query_marlin_supported_quant_types(
     if device_capability < 80:
         return []
 
+    # - has_zp is True: return quant_types that has zero points
+    # - has_zp is False: return quant_types that has not zero points
+    # - has_zp is None: both
+    if has_zp is None:
+        types0 = query_marlin_supported_quant_types(False, include_fp_type,
+                                                    device_capability)
+        types1 = query_marlin_supported_quant_types(True, include_fp_type,
+                                                    device_capability)
+        return types0 + types1
+
     if has_zp:
         # AWQ style, unsigned + runtime zero-point
-        return [scalar_types.uint4, scalar_types.uint8]
+        return [scalar_types.uint4]
     else:
         # GPTQ style, unsigned + symmetric bias
-        # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able
-        # to add `scalar_types.float8_e4m3fn` here
-        return [scalar_types.uint4b8, scalar_types.uint8b128]
+        res = [scalar_types.uint4b8, scalar_types.uint8b128]
+        if include_fp_type:
+            res += [scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f]
+        return res
 
 
 def _check_marlin_supported(
-    quant_type: ScalarType,
-    group_size: Optional[int],
-    has_zp: bool,
-    device_capability: Optional[int] = None,
-) -> Tuple[bool, Optional[str]]:
+        quant_type: ScalarType,
+        group_size: Optional[int],
+        has_zp: bool,
+        device_capability: Optional[int] = None) -> tuple[bool, Optional[str]]:
 
     if device_capability is None:
         capability_tuple = torch.cuda.get_device_capability()
         device_capability = capability_tuple[0] * 10 + capability_tuple[1]
 
-    supported_types = query_marlin_supported_quant_types(has_zp, device_capability)
+    supported_types = query_marlin_supported_quant_types(
+        has_zp, True, device_capability)
 
     if quant_type not in supported_types:
-        return (
-            False,
-            f"Marlin does not support weight_bits = {quant_type}. "
-            f"Only types = {supported_types} "
-            f"are supported (for group_size = {group_size}, "
-            f"device_capability = {device_capability}, zp = {has_zp}).",
-        )
-    if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
-        return (
-            False,
-            f"Marlin does not support group_size = {group_size}. "
-            f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
-            "are supported.",
-        )
+        return (False, f"Marlin does not support weight_bits = {quant_type}. "
+                f"Only types = {supported_types} "
+                f"are supported (for group_size = {group_size}, "
+                f"device_capability = {device_capability}, zp = {has_zp}).")
+    if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES):
+        return (False, f"Marlin does not support group_size = {group_size}. "
+                f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
+                "are supported.")
 
     return True, None
 
 
-def check_marlin_supported(
-    quant_type: ScalarType,
-    group_size: int,
-    has_zp: bool = False,
-    device_capability: Optional[int] = None,
-) -> bool:
-    cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability)
+def check_marlin_supported(quant_type: ScalarType,
+                           group_size: int,
+                           has_zp: bool = False,
+                           device_capability: Optional[int] = None) -> bool:
+    cond, _ = _check_marlin_supported(quant_type, group_size, has_zp,
+                                      device_capability)
     return cond
 
 
-def verify_marlin_supported(
-    quant_type: ScalarType, group_size: int, has_zp: bool = False
-) -> None:
+def verify_marlin_supported(quant_type: ScalarType,
+                            group_size: int,
+                            has_zp: bool = False) -> None:
     cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp)
     if not cond:
         assert err_msg is not None
         raise ValueError(err_msg)
 
 
-def verify_marlin_supports_shape(
-    output_size_per_partition: int,
-    input_size_per_partition: int,
-    input_size: int,
-    group_size: int,
-) -> None:
+def verify_marlin_supports_shape(output_size_per_partition: int,
+                                 input_size_per_partition: int,
+                                 input_size: int, group_size: int) -> None:
 
     # Validate output_size_per_partition
     if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
-        raise ValueError(
-            f"Weight output_size_per_partition = "
-            f"{output_size_per_partition} is not divisible by "
-            f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
-            "Consider reducing tensor_parallel_size or running "
-            "with --quantization gptq."
-        )
+        raise ValueError(f"Weight output_size_per_partition = "
+                         f"{output_size_per_partition} is not divisible by "
+                         f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
+                         "Consider reducing tensor_parallel_size or running "
+                         "with --quantization gptq.")
 
     # Validate input_size_per_partition
     if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
-        raise ValueError(
-            f"Weight input_size_per_partition = "
-            f"{input_size_per_partition} is not divisible "
-            f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
-            "Consider reducing tensor_parallel_size or running "
-            "with --quantization gptq."
-        )
-
-    if group_size < input_size and input_size_per_partition % group_size != 0:
+        raise ValueError(f"Weight input_size_per_partition = "
+                         f"{input_size_per_partition} is not divisible "
+                         f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
+                         "Consider reducing tensor_parallel_size or running "
+                         "with --quantization gptq.")
+
+    if (group_size < input_size
+            and input_size_per_partition % group_size != 0):
         raise ValueError(
             f"Weight input_size_per_partition = {input_size_per_partition}"
-            f" is not divisible by group_size = {group_size}."
+            f" is not divisible by group_size = {group_size}. "
             "Consider reducing tensor_parallel_size or running "
-            "with --quantization gptq."
-        )
+            "with --quantization gptq.")
 
 
-def check_marlin_supports_shape(
-    output_size_per_partition: int,
-    input_size_per_partition: int,
-    input_size: int,
-    group_size: int,
-) -> Tuple[bool, Optional[str]]:
+def check_marlin_supports_shape(output_size_per_partition: int,
+                                input_size_per_partition: int,
+                                input_size: int, group_size: int) \
+                                    -> tuple[bool, Optional[str]]:
     try:
-        verify_marlin_supports_shape(
-            output_size_per_partition, input_size_per_partition, input_size, group_size
-        )
+        verify_marlin_supports_shape(output_size_per_partition,
+                                     input_size_per_partition, input_size,
+                                     group_size)
     except ValueError as e:
         return False, e.__str__()
     return True, None
 
 
-def marlin_make_workspace(
-    output_size_per_partition: int, device: torch.device
-) -> torch.Tensor:
-    max_workspace_size = (
-        output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N
-    ) * GPTQ_MARLIN_MAX_PARALLEL
+def marlin_make_workspace(output_size_per_partition: int,
+                          device: torch.device) -> torch.Tensor:
+    max_workspace_size = (output_size_per_partition //
+                          GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL
 
-    return torch.zeros(
-        max_workspace_size, dtype=torch.int, device=device, requires_grad=False
-    )
+    return torch.zeros(max_workspace_size,
+                       dtype=torch.int,
+                       device=device,
+                       requires_grad=False)
+
+
+def marlin_make_workspace_new(device: torch.device,
+                              max_blocks_per_sm: int = 1) -> torch.Tensor:
+    # In the new marlin kernel, we use the num of threadblocks as workspace
+    # size. The num of threadblocks is is sms_count * max_blocks_per_sm.
+    sms = torch.cuda.get_device_properties(device).multi_processor_count
+    return torch.zeros(sms * max_blocks_per_sm,
+                       dtype=torch.int,
+                       device=device,
+                       requires_grad=False)
 
 
 def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
     return (not act_order) or (act_order and not is_row_parallel)
 
 
-def marlin_repeat_scales_on_all_ranks(
-    act_order: bool, group_size: int, is_row_parallel: bool
-) -> bool:
+def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int,
                                      is_row_parallel: bool) -> bool:
     # Need to repeat scales on every rank if act_ordering or
     # channelwise and RowParallelLinear
     is_channelwise = group_size == -1
@@ -189,35 +198,34 @@ def marlin_repeat_scales_on_all_ranks(
 
 
 def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
-    return torch.nn.Parameter(
-        torch.empty(0, dtype=torch.int, device=device), requires_grad=False
-    )
+    return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
+                              requires_grad=False)
 
 
 def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
-    return torch.nn.Parameter(
-        torch.empty(0, dtype=torch.int, device=device), requires_grad=False
-    )
+    return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
+                              requires_grad=False)
 
 
-def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+def marlin_sort_g_idx(
+        g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
     g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
     return g_idx[g_idx_sort_indices], g_idx_sort_indices
 
 
 def get_scale_perms():
-    scale_perm: List[int] = []
+    scale_perm: list[int] = []
     for i in range(8):
         scale_perm.extend([i + 8 * j for j in range(8)])
-    scale_perm_single: List[int] = []
+    scale_perm_single: list[int] = []
     for i in range(4):
-        scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
+        scale_perm_single.extend(
+            [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
     return scale_perm, scale_perm_single
 
 
-def marlin_permute_scales(
-    s: torch.Tensor, size_k: int, size_n: int, group_size: int
-) -> torch.Tensor:
+def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
                          group_size: int) -> torch.Tensor:
 
     scale_perm, scale_perm_single = get_scale_perms()
     if group_size < size_k and group_size != -1:
@@ -247,9 +255,8 @@ def marlin_moe_permute_scales(
     return output
 
 
-def marlin_zero_points(
-    zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
-) -> torch.Tensor:
+def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int,
                       num_bits: int) -> torch.Tensor:
     # Permute zero-points in a similar way to scales, but do not use the
     # "single" permutation, since zero-points are applied on every MMA
     scale_perm, _ = get_scale_perms()
@@ -270,9 +277,8 @@ def marlin_zero_points(
     return zp
 
 
-def awq_to_marlin_zero_points(
-    q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
-) -> torch.Tensor:
+def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
                              size_n: int, num_bits: int) -> torch.Tensor:
     # AWQ zero-points are quantized and packed on the column dim.
     # In addition, the values are permuted based on dequantizer.
     # Here we undo both of these, and then apply marlin permutation
@@ -294,9 +300,8 @@ def awq_to_marlin_zero_points(
     return marlin_zp
 
 
-def moe_awq_to_marlin_zero_points(
-    q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
-):
+def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
                                  size_n: int, num_bits: int):
     num_experts = q_zp_packed.shape[0]
     output = torch.empty(
         (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]),
@@ -304,45 +309,97 @@ def moe_awq_to_marlin_zero_points(
         dtype=q_zp_packed.dtype,
     )
     for e in range(num_experts):
-        output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits)
+        output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n,
                                              num_bits)
     return output
 
 
+def maybe_warn_marlin_atomic_add(device, dtype):
+    if torch.compiler.is_dynamo_compiling():
+        return
+    device_capability = torch.cuda.get_device_capability(device)
+    if device_capability[0] < 9 and dtype == torch.bfloat16:
+        logger.info_once(
+            "You are running Marlin kernel with bf16 on GPUs before SM90. "
+            "You can consider change to fp16 to achieve better performance "
+            "if possible.")
+
+
+def maybe_warn_marlin_atomic_add_env():
+    if torch.compiler.is_dynamo_compiling():
+        return
+    if envs.VLLM_MARLIN_USE_ATOMIC_ADD:
+        return
+    logger.info_once(
+        "Marlin kernel can achieve better performance for small size_n "
+        "with experimental use_atomic_add feature. "
+        "You can consider set environment variable "
+        "VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible.")
+
+
+def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device,
+                                 dtype: torch.dtype) -> bool:
+
+    # the performance of atomicAdd is better than global reduce
+    # only when m*n is small and k is large
+    if n >= 2048 or k < 2048 or device.type != "cuda":
+        return False
+
+    # disable atomicAdd reduce by default,
+    # one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1
+    if not envs.VLLM_MARLIN_USE_ATOMIC_ADD:
+        maybe_warn_marlin_atomic_add_env()
+        return False
+
+    # sm8x doesn't support atomicAdd + bfloat16 natively
+    device_capability = torch.cuda.get_device_capability(device)
+    if device_capability[0] < 9 and dtype == torch.bfloat16:
+        maybe_warn_marlin_atomic_add(device, dtype)
+        return False
+
+    return True
+
+
 def apply_gptq_marlin_linear(
-    input: torch.Tensor,
-    weight: torch.Tensor,
-    weight_scale: torch.Tensor,
-    weight_zp: torch.Tensor,
-    g_idx: torch.Tensor,
-    g_idx_sort_indices: torch.Tensor,
-    workspace: torch.Tensor,
-    wtype: ScalarType,
-    output_size_per_partition: int,
-    input_size_per_partition: int,
-    is_k_full: bool,
-    bias: Optional[torch.Tensor] = None,
-    use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
-) -> torch.Tensor:
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        weight_scale: torch.Tensor,
+        weight_zp: torch.Tensor,
+        g_idx: torch.Tensor,
+        g_idx_sort_indices: torch.Tensor,
+        workspace: torch.Tensor,
+        wtype: ScalarType,
+        output_size_per_partition: int,
+        input_size_per_partition: int,
+        is_k_full: bool,
+        bias: Optional[torch.Tensor] = None,
+        use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
     reshaped_x = input.reshape(-1, input.shape[-1])
-    out_shape = input.shape[:-1] + (output_size_per_partition,)
-
-    output = ops.gptq_marlin_gemm(
-        reshaped_x,
-        weight,
-        weight_scale,
-        weight_zp,
-        g_idx,
-        g_idx_sort_indices,
-        workspace,
-        wtype,
-        size_m=reshaped_x.shape[0],
-        size_n=output_size_per_partition,
-        size_k=input_size_per_partition,
-        is_k_full=is_k_full,
-        has_zp=False,
-        use_fp32_reduce=use_fp32_reduce,
-        is_zp_float=False,
-    )
+    out_shape = input.shape[:-1] + (output_size_per_partition, )
+
+    use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
+                                                  n=output_size_per_partition,
+                                                  k=reshaped_x.size(1),
+                                                  device=input.device,
+                                                  dtype=input.dtype)
+
+    output = ops.gptq_marlin_gemm(reshaped_x,
+                                  None,
+                                  weight,
+                                  weight_scale,
+                                  None,
+                                  weight_zp,
+                                  g_idx,
+                                  g_idx_sort_indices,
+                                  workspace,
+                                  wtype,
+                                  size_m=reshaped_x.shape[0],
+                                  size_n=output_size_per_partition,
+                                  size_k=input_size_per_partition,
+                                  is_k_full=is_k_full,
+                                  use_atomic_add=use_atomic_add,
+                                  use_fp32_reduce=use_fp32_reduce,
+                                  is_zp_float=False)
 
     if bias is not None:
         output.add_(bias)  # In-place add
@@ -351,39 +408,43 @@ def apply_gptq_marlin_linear(
 
 
 def apply_awq_marlin_linear(
-    input: torch.Tensor,
-    weight: torch.Tensor,
-    weight_scale: torch.Tensor,
-    weight_zp: torch.Tensor,
-    g_idx: torch.Tensor,
-    g_idx_sort_indices: torch.Tensor,
-    workspace: torch.Tensor,
-    quant_type: ScalarType,
-    output_size_per_partition: int,
-    input_size_per_partition: int,
-    bias: Optional[torch.Tensor] = None,
-    use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
-) -> torch.Tensor:
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        weight_scale: torch.Tensor,
+        weight_zp: torch.Tensor,
+        g_idx: torch.Tensor,
+        g_idx_sort_indices: torch.Tensor,
+        workspace: torch.Tensor,
+        quant_type: ScalarType,
+        output_size_per_partition: int,
+        input_size_per_partition: int,
+        bias: Optional[torch.Tensor] = None,
+        use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
     reshaped_x = input.reshape(-1, input.shape[-1])
-    out_shape = input.shape[:-1] + (output_size_per_partition,)
-
-    output = ops.gptq_marlin_gemm(
-        reshaped_x,
-        weight,
-        weight_scale,
-        weight_zp,
-        g_idx,
-        g_idx_sort_indices,
-        workspace,
-        quant_type,
-        size_m=reshaped_x.shape[0],
-        size_n=output_size_per_partition,
-        size_k=input_size_per_partition,
-        is_k_full=True,
-        has_zp=True,
-        use_fp32_reduce=use_fp32_reduce,
-        is_zp_float=False,
-    )
+    out_shape = input.shape[:-1] + (output_size_per_partition, )
+
+    use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
+                                                  n=output_size_per_partition,
+                                                  k=reshaped_x.size(1),
+                                                  device=input.device,
+                                                  dtype=input.dtype)
+
+    output = ops.gptq_marlin_gemm(reshaped_x,
+                                  None,
+                                  weight,
+                                  weight_scale,
+                                  None,
+                                  weight_zp,
+                                  g_idx,
+                                  g_idx_sort_indices,
+                                  workspace,
+                                  quant_type,
+                                  size_m=reshaped_x.shape[0],
+                                  size_n=output_size_per_partition,
+                                  size_k=input_size_per_partition,
+                                  use_atomic_add=use_atomic_add,
+                                  use_fp32_reduce=use_fp32_reduce,
+                                  is_zp_float=False)
 
     if bias is not None:
         output.add_(bias)  # In-place add
build/torch26-cxx11-cu126-aarch64-linux/quantization/utils/marlin_utils_fp4.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+
8
+ import quantization as ops
9
+
10
+ from .marlin_utils import (
11
+ USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales,
12
+ should_use_atomic_add_reduce)
13
+ from quantization.scalar_type import scalar_types
14
+
15
+ FP4_MARLIN_SUPPORTED_GROUP_SIZES = [16]
16
+
17
+
18
+ def is_fp4_marlin_supported():
19
+ capability = torch.cuda.get_device_capability()
20
+ capability = capability[0] * 10 + capability[1]
21
+ return capability >= 80
22
+
23
+
24
+ def fp4_marlin_process_scales(marlin_scales):
25
+ if not (marlin_scales >= 0).all():
26
+ logger.warning_once(
27
+ "NVFP4 Marlin assumes the scales to be >=0, but has encountered "
28
+ "negative scales. Accuracy will likely be degraded. This is "
29
+ "because it changes the scales from FP8-S1E4M3 to a special "
30
+ "FP8-S0E5M3 format to speedup the dequantization.")
31
+
32
+ # convert to half first, we would convert to fp8 later
33
+ marlin_scales = marlin_scales.to(torch.half)
34
+
35
+ # 8 is the number of scale number using by one thread
36
+ marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8)
37
+ marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape(
38
+ marlin_scales.size(0) * 2, -1)
39
+
40
+ # fit the layout of fp8 dequantization
41
+ marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view(
42
+ marlin_scales.size(0), -1)
43
+
44
+ # We assume that weight_scale (FP8-S1E4M3) is always greater
45
+ # than or equal to 0. So we can convert
46
+ # (weight_scale * (2 ** 7) to a special FP8-S0E5M3 format.
47
+ # After multiplying by 2 ** 7, the top bit of FP8-S0E5M3 would always be 1
48
+ # when weight_scale > 0. This allows us to have an exponent bias
49
+ # closer to zero after dequantization.
50
+
51
+ marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1
52
+ marlin_scales = marlin_scales.view(torch.float8_e4m3fn)
53
+ marlin_scales = marlin_scales[:, 1::2].contiguous()
54
+
55
+ return marlin_scales
56
+
57
+
58
+ def fp4_marlin_process_global_scale(global_scale):
59
+ assert global_scale.dtype in [torch.half, torch.bfloat16]
60
+ fp4_exponent = 2
61
+ if global_scale.dtype == torch.half:
62
+ target_exponent = 5
63
+ elif global_scale.dtype == torch.bfloat16:
64
+ target_exponent = 8
65
+ # exponent_bias_fp16 = 2 ** 4 - 2 ** 1 = 14
66
+ # exponent_bias_bf16 = 2 ** 7 - 2 ** 1 = 126
67
+ exponent_bias = 2**(target_exponent - 1) - 2**(fp4_exponent - 1)
68
+ return global_scale * (2.0**(exponent_bias - 7))
69
+
70
+
71
+ def apply_fp4_marlin_linear(
72
+ input: torch.Tensor,
73
+ weight: torch.Tensor,
74
+ weight_scale: torch.Tensor,
75
+ weight_scale_2: torch.Tensor,
76
+ workspace: torch.Tensor,
77
+ size_n: int,
78
+ size_k: int,
79
+ bias: Optional[torch.Tensor] = None,
80
+ use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
81
+ # For GPUs that lack FP4 hardware support, we can leverage the
82
+ # Marlin kernel for fast weight-only FP4 quantization
83
+
84
+ reshaped_x = input.reshape(-1, input.shape[-1])
85
+ out_shape = input.shape[:-1] + (size_n, )
86
+
87
+ use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
88
+ n=size_n,
89
+ k=size_k,
90
+ device=input.device,
91
+ dtype=input.dtype)
92
+
93
+ output = ops.gptq_marlin_gemm(a=reshaped_x,
94
+ c=None,
95
+ b_q_weight=weight,
96
+ b_scales=weight_scale,
97
+ global_scale=weight_scale_2,
98
+ b_zeros=None,
99
+ g_idx=None,
100
+ perm=None,
101
+ workspace=workspace,
102
+ b_q_type=scalar_types.float4_e2m1f,
103
+ size_m=reshaped_x.size(0),
104
+ size_n=size_n,
105
+ size_k=size_k,
106
+ use_atomic_add=use_atomic_add,
107
+ use_fp32_reduce=use_fp32_reduce)
108
+
109
+ if bias is not None:
110
+ output.add_(bias) # In-place add
111
+
112
+ return output.reshape(out_shape)
113
+
114
+
115
+ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
116
+ logger.warning_once(
117
+ "Your GPU does not have native support for FP4 computation but "
118
+ "FP4 quantization is being used. Weight-only FP4 compression will "
119
+ "be used leveraging the Marlin kernel. This may degrade "
120
+ "performance for compute-heavy workloads.")
121
+
122
+ part_size_n = layer.output_size_per_partition
123
+ part_size_k = layer.input_size_per_partition
124
+ param_dtype = layer.params_dtype
125
+
126
+ assert layer.weight.shape == (part_size_n, part_size_k // 2)
127
+
128
+ device = layer.weight.device
129
+
130
+ # WORKSPACE
131
+ layer.workspace = marlin_make_workspace_new(device)
132
+
133
+ # WEIGHT
134
+ # Repack weights to marlin format
135
+ perm = torch.empty(0, dtype=torch.int, device=device)
136
+ qweight = layer.weight.view(torch.int32).T.contiguous()
137
+
138
+ marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight,
139
+ perm=perm,
140
+ size_k=part_size_k,
141
+ size_n=part_size_n,
142
+ num_bits=4)
143
+ layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
144
+
145
+ # WEIGHT SCALES
146
+ # Permute scales
147
+ weight_scale = layer.weight_scale.T.to(param_dtype)
148
+ weight_scale = marlin_permute_scales(s=weight_scale,
149
+ size_k=part_size_k,
150
+ size_n=part_size_n,
151
+ group_size=16)
152
+ weight_scale = fp4_marlin_process_scales(weight_scale)
153
+ layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
154
+
155
+ weight_scale_2 = layer.weight_scale_2.to(param_dtype)
156
+ weight_scale_2 = fp4_marlin_process_global_scale(weight_scale_2)
157
+ layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2,
158
+ requires_grad=False)
159
+
160
+ return
161
+
162
+
163
+ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
164
+ logger.warning_once(
165
+ "Your GPU does not have native support for FP4 computation but "
166
+ "FP4 quantization is being used. Weight-only FP4 compression will "
167
+ "be used leveraging the Marlin kernel. This may degrade "
168
+ "performance for compute-heavy workloads.")
169
+
170
+ e = layer.num_experts
171
+ k = layer.hidden_size
172
+ n = layer.intermediate_size_per_partition
173
+
174
+ # WORKSPACE
175
+ device = layer.w13_weight.device
176
+ param_dtype = layer.params_dtype
177
+ layer.workspace = marlin_make_workspace_new(device, 4)
178
+ perm = torch.empty(0, dtype=torch.int, device=device)
179
+
180
+ # WEIGHT
181
+ # Repack weights to marlin format
182
+ for name in ["w13_weight", "w2_weight"]:
183
+ weight = getattr(layer, name)
184
+ tensor_list = []
185
+ if "w13" in name:
186
+ size_n, size_k = n * 2, k
187
+ else:
188
+ size_n, size_k = k, n
189
+
190
+ assert weight.shape == (e, size_n, size_k // 2)
191
+
192
+ for i in range(e):
193
+ qweight = weight[i].view(torch.int32).T.contiguous()
194
+
195
+ marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight,
196
+ perm=perm,
197
+ size_k=size_k,
198
+ size_n=size_n,
199
+ num_bits=4)
200
+ tensor_list.append(marlin_qweight)
201
+
202
+ weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
203
+ weight = torch.nn.Parameter(weight, requires_grad=False)
204
+
205
+ setattr(layer, name, weight)
206
+
207
+ # WEIGHT SCALES
208
+ # Permute scales
209
+ for name in ["w13", "w2"]:
210
+ scales = getattr(layer, name + "_weight_scale").to(param_dtype)
211
+ global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype)
212
+
213
+ tensor_list = []
214
+ if "w13" in name:
215
+ size_n, size_k = n * 2, k
216
+ else:
217
+ size_n, size_k = k, n
218
+
219
+ for i in range(e):
220
+ marlin_scales = marlin_permute_scales(s=scales[i].T,
221
+ size_k=size_k,
222
+ size_n=size_n,
223
+ group_size=16)
224
+ marlin_scales = fp4_marlin_process_scales(marlin_scales)
225
+ tensor_list.append(marlin_scales)
226
+
227
+ scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
228
+ scales = torch.nn.Parameter(scales, requires_grad=False)
229
+ setattr(layer, name + "_weight_scale", scales)
230
+
231
+ global_scale = fp4_marlin_process_global_scale(global_scale)
232
+ global_scale = torch.nn.Parameter(global_scale, requires_grad=False)
233
+ setattr(layer, name + "_weight_scale_2", global_scale)
234
+
235
+
236
+ def rand_marlin_weight_fp4_like(weight, group_size):
237
+ assert group_size > 0
238
+ size_n, size_k = weight.shape
239
+ device = weight.device
240
+
241
+ scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 6
242
+ global_scale = scales.max() / 448
243
+ scales = (scales / global_scale).to(torch.float8_e4m3fn)
244
+
245
+ fp4_weight = torch.randint(0,
246
+ 256, (size_n, size_k // 2),
247
+ dtype=torch.uint8,
248
+ device=weight.device)
249
+ fp4_weight_part_1 = ((fp4_weight & 0b10000000) |
250
+ ((fp4_weight & 0b01110000) >> 2))
251
+ fp4_weight_part_1 = fp4_weight_part_1.view(torch.float8_e4m3fn)
252
+ fp4_weight_part_1 = fp4_weight_part_1.to(weight.dtype) * (2**6)
253
+
254
+ fp4_weight2 = fp4_weight << 4
255
+ fp4_weight_part_2 = ((fp4_weight2 & 0b10000000) |
256
+ ((fp4_weight2 & 0b01110000) >> 2))
257
+ fp4_weight_part_2 = fp4_weight_part_2.view(torch.float8_e4m3fn)
258
+ fp4_weight_part_2 = fp4_weight_part_2.to(weight.dtype) * (2**6)
259
+
260
+ weight_ref = torch.cat(
261
+ [fp4_weight_part_2.unsqueeze(2),
262
+ fp4_weight_part_1.unsqueeze(2)], 2).view(size_n, size_k)
263
+ weight_ref = weight_ref * global_scale.to(weight.dtype) * \
264
+ scales.repeat_interleave(group_size, 1).to(weight.dtype)
265
+
266
+ marlin_qweight = ops.gptq_marlin_repack(
267
+ b_q_weight=fp4_weight.view(torch.int32).T.contiguous(),
268
+ perm=torch.empty(0, dtype=torch.int, device=device),
269
+ size_k=size_k,
270
+ size_n=size_n,
271
+ num_bits=4,
272
+ )
273
+
274
+ marlin_scales = marlin_permute_scales(s=scales.T.to(weight.dtype),
275
+ size_k=size_k,
276
+ size_n=size_n,
277
+ group_size=group_size)
278
+ marlin_scales = fp4_marlin_process_scales(marlin_scales)
279
+
280
+ global_scale = fp4_marlin_process_global_scale(global_scale)
281
+
282
+ return weight_ref.T, marlin_qweight, marlin_scales, global_scale
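A minimal end-to-end sketch of how the FP4 test helper above is consumed (illustrative only, not part of the built file; it assumes a CUDA device and that the modules are importable as quantization.utils.marlin_utils and quantization.utils.marlin_utils_fp4):
import torch
import quantization as ops
from quantization.scalar_type import scalar_types
from quantization.utils.marlin_utils import marlin_make_workspace_new
from quantization.utils.marlin_utils_fp4 import rand_marlin_weight_fp4_like

size_n, size_k = 256, 512
weight = torch.randn(size_n, size_k, dtype=torch.half, device="cuda")
# weight_ref is the exactly dequantized (size_k, size_n) reference matrix
weight_ref, qweight, scales, global_scale = rand_marlin_weight_fp4_like(weight, group_size=16)

a = torch.randn(8, size_k, dtype=torch.half, device="cuda")
out = ops.gptq_marlin_gemm(a=a, c=None, b_q_weight=qweight, b_scales=scales,
                           global_scale=global_scale, b_zeros=None, g_idx=None,
                           perm=None, workspace=marlin_make_workspace_new(a.device),
                           b_q_type=scalar_types.float4_e2m1f,
                           size_m=a.size(0), size_n=size_n, size_k=size_k)
torch.testing.assert_close(out, a @ weight_ref, rtol=1e-1, atol=1e-1)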
build/torch26-cxx11-cu126-aarch64-linux/quantization/utils/marlin_utils_fp8.py CHANGED
@@ -1,10 +1,13 @@
 
 
 
1
  from typing import Optional
2
 
3
  import torch
4
 
5
  import quantization as ops
6
 
7
- from .marlin_utils import marlin_make_workspace, marlin_permute_scales
8
 
9
 
10
  def is_fp8_marlin_supported():
@@ -13,88 +16,107 @@ def is_fp8_marlin_supported():
13
  return capability >= 80
14
 
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  def apply_fp8_marlin_linear(
17
- input: torch.Tensor,
18
- weight: torch.Tensor,
19
- weight_scale: torch.Tensor,
20
- workspace: torch.Tensor,
21
- size_n: int,
22
- size_k: int,
23
- bias: Optional[torch.Tensor],
24
- ) -> torch.Tensor:
25
  # For GPUs that lack FP8 hardware support, we can leverage the
26
  # Marlin kernel for fast weight-only FP8 quantization
27
 
28
  reshaped_x = input.reshape(-1, input.shape[-1])
29
- out_shape = input.shape[:-1] + (size_n,)
30
-
31
- output = ops.fp8_marlin_gemm(
32
- a=reshaped_x,
33
- b_q_weight=weight,
34
- b_scales=weight_scale,
35
- workspace=workspace,
36
- num_bits=8,
37
- size_m=reshaped_x.shape[0],
38
- size_n=size_n,
39
- size_k=size_k,
40
- )
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  if bias is not None:
43
  output.add_(bias) # In-place add
44
 
45
  return output.reshape(out_shape)
46
 
47
-
48
- def prepare_fp8_layer_for_marlin(
49
- layer: torch.nn.Module, strategy: str = "tensor"
50
- ) -> None:
51
- part_size_n = layer.output_size_per_partition
52
- part_size_k = layer.input_size_per_partition
53
-
54
- device = layer.weight.device
55
-
56
- # WORKSPACE
57
- layer.workspace = marlin_make_workspace(part_size_n, device)
58
-
59
- # WEIGHT
60
- # Repack weights to marlin format
61
- marlin_qweight = ops.gptq_marlin_repack(
62
- b_q_weight=pack_fp8_to_int32(layer.weight),
63
- perm=torch.empty(0, dtype=torch.int, device=device),
64
- size_k=part_size_k,
65
- size_n=part_size_n,
66
- num_bits=8,
67
- )
68
- layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
69
-
70
- # WEIGHT SCALES
71
- scales = layer.weight_scale.to(layer.orig_dtype)
72
- # Permute scales
73
- marlin_scales = marlin_permute_scales(
74
- s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1
75
- )
76
- layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
77
-
78
-
79
- def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
80
  """
81
  Repack FP8 weights to gptq format (packed int32 elements)
82
  """
83
  assert fp8_tensor.dtype == torch.float8_e4m3fn
84
- assert fp8_tensor.shape[0] % 4 == 0
85
-
86
- # Reshape to prepare for packing
87
- reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
 
 
88
 
89
- # Convert fp8 to uint8 (byte) representation
90
- byte_tensor = reshaped.view(torch.uint8)
 
 
91
 
92
- # Pack 4 uint8 values into one int32
93
- packed = (
94
- byte_tensor[:, 0].to(torch.int32)
95
- | (byte_tensor[:, 1].to(torch.int32) << 8)
96
- | (byte_tensor[:, 2].to(torch.int32) << 16)
97
- | (byte_tensor[:, 3].to(torch.int32) << 24)
98
- )
99
 
100
- return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous()
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+
4
  from typing import Optional
5
 
6
  import torch
7
 
8
  import quantization as ops
9
 
10
+ from .marlin_utils import USE_FP32_REDUCE_DEFAULT, marlin_make_workspace, marlin_permute_scales
11
 
12
 
13
  def is_fp8_marlin_supported():
 
16
  return capability >= 80
17
 
18
 
19
+ def fp8_fused_exponent_bias_into_scales(scales):
20
+ fp8_exponent = 4
21
+ if scales.dtype == torch.half:
22
+ target_exponent = 5
23
+ elif scales.dtype == torch.bfloat16:
24
+ target_exponent = 8
25
+ # exponent_bias_fp16 = 2 ** 4 - 2 ** 3 = 8
26
+ # exponent_bias_bf16 = 2 ** 7 - 2 ** 3 = 120
27
+ exponent_bias = 2**(target_exponent - 1) - 2**(fp8_exponent - 1)
28
+ s = torch.ones_like(scales) * 2
29
+ s = s**exponent_bias
30
+ return scales * s
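A quick numeric check of the bias folding above (illustrative; assumes the helper is in scope): for FP16 scales the exponent-bias difference is 2**4 - 2**3 = 8, so every scale is simply multiplied by 2**8.
import torch
s = torch.tensor([1.0, 0.5], dtype=torch.half)
assert torch.equal(fp8_fused_exponent_bias_into_scales(s), s * 2**8)  # bf16 scales would be multiplied by 2**120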
31
+
32
+
33
  def apply_fp8_marlin_linear(
34
+ input: torch.Tensor,
35
+ weight: torch.Tensor,
36
+ weight_scale: torch.Tensor,
37
+ workspace: torch.Tensor,
38
+ size_n: int,
39
+ size_k: int,
40
+ bias: Optional[torch.Tensor],
41
+ use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
42
  # For GPUs that lack FP8 hardware support, we can leverage the
43
  # Marlin kernel for fast weight-only FP8 quantization
44
 
45
  reshaped_x = input.reshape(-1, input.shape[-1])
46
+ out_shape = input.shape[:-1] + (size_n, )
47
+
48
+ use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
49
+ n=size_n,
50
+ k=size_k,
51
+ device=input.device,
52
+ dtype=input.dtype)
53
+
54
+ output = ops.gptq_marlin_gemm(a=reshaped_x,
55
+ c=None,
56
+ b_q_weight=weight,
57
+ b_scales=weight_scale,
58
+ global_scale=None,
59
+ b_zeros=None,
60
+ g_idx=None,
61
+ perm=None,
62
+ workspace=workspace,
63
+ b_q_type=scalar_types.float8_e4m3fn,
64
+ size_m=reshaped_x.size(0),
65
+ size_n=size_n,
66
+ size_k=size_k,
67
+ use_atomic_add=use_atomic_add,
68
+ use_fp32_reduce=use_fp32_reduce)
69
 
70
  if bias is not None:
71
  output.add_(bias) # In-place add
72
 
73
  return output.reshape(out_shape)
74
 
75
+ def pack_fp8_to_int32(fp8_tensor: torch.Tensor,
76
+ size_k_first: bool = True) -> torch.Tensor:
 
 
77
  """
78
  Repack FP8 weights to gptq format (packed int32 elements)
79
  """
80
  assert fp8_tensor.dtype == torch.float8_e4m3fn
81
+ assert fp8_tensor.ndim == 2
82
+
83
+ fp8_tensor = fp8_tensor.T if size_k_first else fp8_tensor
84
+ fp8_tensor = fp8_tensor.contiguous()
85
+ # fp8_tensor is contiguous and has shape (N, K) now
86
+ # with `.view(torch.int32)`, it becomes (N, K // 4)
87
+ int32_tensor = fp8_tensor.view(torch.int32)
88
+ return int32_tensor.T.contiguous() if size_k_first else int32_tensor
89
+
90
+
91
+ def marlin_quant_fp8_torch(weight, group_size):
92
+ size_n, size_k = weight.shape
93
+ device = weight.device
94
+
95
+ if group_size != -1:
96
+ scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 448
97
+ repeated_scales = scales.repeat_interleave(group_size, 1)
98
+ fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn)
99
+ weight_ref = fp8_weight.to(weight.dtype) * repeated_scales
100
+ else:
101
+ scales = weight.view(size_n, 1, group_size).abs().max(-1)[0] / 448
102
+ repeated_scales = scales.repeat_interleave(size_k, 1)
103
+ fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn)
104
+ weight_ref = fp8_weight.to(weight.dtype) * repeated_scales
105
+
106
+ packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous()
107
+ marlin_qweight = ops.gptq_marlin_repack(
108
+ b_q_weight=packed_weight,
109
+ perm=torch.empty(0, dtype=torch.int, device=device),
110
+ size_k=size_k,
111
+ size_n=size_n,
112
+ num_bits=8,
113
+ )
114
 
115
+ marlin_scales = marlin_permute_scales(s=scales.T,
116
+ size_k=size_k,
117
+ size_n=size_n,
118
+ group_size=group_size)
119
 
120
+ marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
 
 
 
 
 
 
121
 
122
+ return weight_ref.T, marlin_qweight, marlin_scales
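A usage sketch for the FP8 helper above, mirroring the FP4 one earlier in this build (illustrative; assumes a CUDA device and that the import paths shown exist in this package):
import torch
import quantization as ops
from quantization.scalar_type import scalar_types
from quantization.utils.marlin_utils import marlin_make_workspace_new
from quantization.utils.marlin_utils_fp8 import marlin_quant_fp8_torch

size_n, size_k = 256, 512
weight = torch.randn(size_n, size_k, dtype=torch.half, device="cuda")
weight_ref, qweight, scales = marlin_quant_fp8_torch(weight, group_size=128)

a = torch.randn(4, size_k, dtype=torch.half, device="cuda")
out = ops.gptq_marlin_gemm(a=a, c=None, b_q_weight=qweight, b_scales=scales,
                           global_scale=None, b_zeros=None, g_idx=None, perm=None,
                           workspace=marlin_make_workspace_new(a.device),
                           b_q_type=scalar_types.float8_e4m3fn,
                           size_m=a.size(0), size_n=size_n, size_k=size_k)
torch.testing.assert_close(out, a @ weight_ref, rtol=1e-1, atol=1e-1)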
build/torch26-cxx98-cu126-aarch64-linux/quantization/__init__.py CHANGED
@@ -1,12 +1,12 @@
1
  from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant
2
  from .cutlass import (
 
3
  cutlass_scaled_mm_supports_fp8,
4
  cutlass_scaled_mm,
5
  cutlass_scaled_mm_azp,
6
  )
7
  from .marlin import (
8
  awq_marlin_repack,
9
- fp8_marlin_gemm,
10
  gptq_marlin_gemm,
11
  gptq_marlin_repack,
12
  gptq_marlin_24_gemm,
@@ -25,8 +25,8 @@ __all__ = [
25
  "awq_marlin_repack",
26
  "cutlass_scaled_mm",
27
  "cutlass_scaled_mm_azp",
 
28
  "cutlass_scaled_mm_supports_fp8",
29
- "fp8_marlin_gemm",
30
  "gptq_marlin_24_gemm",
31
  "gptq_marlin_gemm",
32
  "gptq_marlin_repack",
 
1
  from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant
2
  from .cutlass import (
3
+ cutlass_scaled_mm_supports_block_fp8,
4
  cutlass_scaled_mm_supports_fp8,
5
  cutlass_scaled_mm,
6
  cutlass_scaled_mm_azp,
7
  )
8
  from .marlin import (
9
  awq_marlin_repack,
 
10
  gptq_marlin_gemm,
11
  gptq_marlin_repack,
12
  gptq_marlin_24_gemm,
 
25
  "awq_marlin_repack",
26
  "cutlass_scaled_mm",
27
  "cutlass_scaled_mm_azp",
28
+ "cutlass_scaled_mm_supports_block_fp8",
29
  "cutlass_scaled_mm_supports_fp8",
 
30
  "gptq_marlin_24_gemm",
31
  "gptq_marlin_gemm",
32
  "gptq_marlin_repack",
build/torch26-cxx98-cu126-aarch64-linux/quantization/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _quantization_0435ccb
3
- ops = torch.ops._quantization_0435ccb
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_quantization_0435ccb::{op_name}"
 
1
  import torch
2
+ from . import _quantization_9035540
3
+ ops = torch.ops._quantization_9035540
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_quantization_9035540::{op_name}"
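Illustrative use of the helper above from inside the package (the fake-op registrations in marlin.py resolve their names this way):
from ._ops import ops, add_op_namespace_prefix

qualified = add_op_namespace_prefix("gptq_marlin_gemm")
# -> "_quantization_9035540::gptq_marlin_gemm", the name handed to register_fake()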
build/torch26-cxx98-cu126-aarch64-linux/quantization/{_quantization_0435ccb.abi3.so → _quantization_9035540.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:b8c55f64398210ec79884d722f11a1a40f8e61dd9cc1aaf31111592db04da151
3
- size 67884088
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3685a434362226370f1956f59790a58d2f4c8999f9f35acafd25ca9d73bfc5ae
3
+ size 159991696
build/torch26-cxx98-cu126-aarch64-linux/quantization/compressed_tensors.py CHANGED
@@ -2,17 +2,7 @@ from typing import Optional, Tuple
2
 
3
  import torch
4
 
5
- try:
6
- from ._ops import ops
7
- except ImportError as e:
8
- # Fallback for local development.
9
- try:
10
- import _quantization
11
-
12
- ops = torch.ops._quantization
13
- except ImportError:
14
- raise e
15
-
16
 
17
  # fp8
18
  def scaled_fp8_quant(
@@ -21,7 +11,8 @@ def scaled_fp8_quant(
21
  num_token_padding: Optional[int] = None,
22
  scale_ub: Optional[torch.Tensor] = None,
23
  use_per_token_if_dynamic: bool = False,
24
- ) -> Tuple[torch.Tensor, torch.Tensor]:
 
25
  """
26
  Quantize input tensor to FP8 and return quantized tensor and scale.
27
 
@@ -42,30 +33,36 @@ def scaled_fp8_quant(
42
  in the dynamic quantization case.
43
 
44
  Returns:
45
- Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
46
  scaling factor.
47
  """
48
  # This code assumes batch_dim and num_tokens are flattened
49
- assert input.ndim == 2
50
- shape: Union[Tuple[int, int], torch.Size] = input.shape
51
- # For rocm, the output fp8 dtype is torch.float_e3m3fnuz
52
- # out_dtype: torch.dtype = torch.float8_e4m3fnuz \
53
- # if current_platform.is_rocm() else torch.float8_e4m3fn
54
- out_dtype = torch.float8_e4m3fn
55
  if num_token_padding:
56
  shape = (max(num_token_padding, input.shape[0]), shape[1])
57
- output = torch.empty(shape, device=input.device, dtype=out_dtype)
 
 
 
 
 
58
 
59
  if scale is None:
60
  if use_per_token_if_dynamic:
61
- scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
62
- ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub)
 
 
 
63
  else:
64
  scale = torch.zeros(1, device=input.device, dtype=torch.float32)
65
  ops.dynamic_scaled_fp8_quant(output, input, scale)
66
  else:
67
  # num_token_padding not implemented for this case
68
- assert scale.numel() == 1 or num_token_padding is None
69
  ops.static_scaled_fp8_quant(output, input, scale)
70
 
71
  return output, scale
@@ -76,8 +73,8 @@ def scaled_int8_quant(
76
  input: torch.Tensor,
77
  scale: Optional[torch.Tensor] = None,
78
  azp: Optional[torch.Tensor] = None,
79
- symmetric: bool = True,
80
- ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
81
  """
82
  Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.
83
 
@@ -90,21 +87,25 @@ def scaled_int8_quant(
90
  symmetric: Whether to use symmetric quantization (scale only, azp ignored).
91
 
92
  Returns:
93
- Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
94
  """
95
  output = torch.empty_like(input, dtype=torch.int8)
96
  if scale is not None:
97
  # static-per-tensor quantization.
98
  assert symmetric == (
99
- azp is None
100
- ), "azp must only be provided for asymmetric quantization."
101
  ops.static_scaled_int8_quant(output, input, scale, azp)
102
  return output, scale, azp
103
 
104
  # dynamic-per-token quantization.
105
- input_scales = torch.empty(
106
- (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32
107
- )
108
- input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32)
109
- ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp)
 
 
110
  return output, input_scales, input_azp
 
 
 
2
 
3
  import torch
4
 
5
+ from ._ops import ops
 
 
 
 
 
 
 
 
 
 
6
 
7
  # fp8
8
  def scaled_fp8_quant(
 
11
  num_token_padding: Optional[int] = None,
12
  scale_ub: Optional[torch.Tensor] = None,
13
  use_per_token_if_dynamic: bool = False,
14
+ output: Optional[torch.Tensor] = None,
15
+ ) -> tuple[torch.Tensor, torch.Tensor]:
16
  """
17
  Quantize input tensor to FP8 and return quantized tensor and scale.
18
 
 
33
  in the dynamic quantization case.
34
 
35
  Returns:
36
+ tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
37
  scaling factor.
38
  """
39
  # This code assumes batch_dim and num_tokens are flattened
40
+ assert (input.ndim == 2)
41
+ shape: Union[tuple[int, int], torch.Size] = input.shape
42
+ # For ROCm on MI300, the output fp8 dtype is torch.float8_e4m3fnuz
43
+ out_dtype: torch.dtype = current_platform.fp8_dtype()
 
 
44
  if num_token_padding:
45
  shape = (max(num_token_padding, input.shape[0]), shape[1])
46
+ if output is None:
47
+ output = torch.empty(shape, device=input.device, dtype=out_dtype)
48
+ else:
49
+ assert num_token_padding is None, \
50
+ "padding not supported if output passed in"
51
+ assert output.dtype == out_dtype
52
 
53
  if scale is None:
54
  if use_per_token_if_dynamic:
55
+ scale = torch.empty((shape[0], 1),
56
+ device=input.device,
57
+ dtype=torch.float32)
58
+ ops.dynamic_per_token_scaled_fp8_quant(
59
+ output, input.contiguous(), scale, scale_ub)
60
  else:
61
  scale = torch.zeros(1, device=input.device, dtype=torch.float32)
62
  ops.dynamic_scaled_fp8_quant(output, input, scale)
63
  else:
64
  # num_token_padding not implemented for this case
65
+ assert (scale.numel() == 1 and num_token_padding is None)
66
  ops.static_scaled_fp8_quant(output, input, scale)
67
 
68
  return output, scale
 
73
  input: torch.Tensor,
74
  scale: Optional[torch.Tensor] = None,
75
  azp: Optional[torch.Tensor] = None,
76
+ symmetric: bool = True
77
+ ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
78
  """
79
  Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.
80
 
 
87
  symmetric: Whether to use symmetric quantization (scale only, azp ignored).
88
 
89
  Returns:
90
+ tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
91
  """
92
  output = torch.empty_like(input, dtype=torch.int8)
93
  if scale is not None:
94
  # static-per-tensor quantization.
95
  assert symmetric == (
96
+ azp
97
+ is None), "azp must only be provided for asymmetric quantization."
98
  ops.static_scaled_int8_quant(output, input, scale, azp)
99
  return output, scale, azp
100
 
101
  # dynamic-per-token quantization.
102
+ input_scales = torch.empty((input.numel() // input.shape[-1], 1),
103
+ device=input.device,
104
+ dtype=torch.float32)
105
+ input_azp = None if symmetric else torch.empty_like(input_scales,
106
+ dtype=torch.int32)
107
+ ops.dynamic_scaled_int8_quant(output, input.contiguous(),
108
+ input_scales, input_azp)
109
  return output, input_scales, input_azp
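A minimal usage sketch for the two entry points above (illustrative; assumes a CUDA device and that the package is importable as quantization):
import torch
from quantization import scaled_fp8_quant, scaled_int8_quant

x = torch.randn(16, 4096, dtype=torch.half, device="cuda")
x_fp8, fp8_scales = scaled_fp8_quant(x, use_per_token_if_dynamic=True)  # dynamic per-token scales
x_int8, int8_scales, azp = scaled_int8_quant(x)                          # dynamic, symmetric -> azp is None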
110
+
111
+
build/torch26-cxx98-cu126-aarch64-linux/quantization/cutlass.py CHANGED
@@ -2,22 +2,18 @@ from typing import Optional
2
 
3
  import torch
4
 
5
- try:
6
- from ._ops import ops
7
- except ImportError as e:
8
- # Fallback for local development.
9
- try:
10
- import _quantization
11
-
12
- ops = torch.ops._quantization
13
- except ImportError:
14
- raise e
15
 
16
 
17
  def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
18
  return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
19
 
20
 
 
 
 
 
21
  def cutlass_scaled_mm(
22
  a: torch.Tensor,
23
  b: torch.Tensor,
@@ -33,12 +29,10 @@ def cutlass_scaled_mm(
33
  m = a.shape[0]
34
  n = b.shape[1]
35
 
36
- # if current_platform.is_rocm():
37
- # triton_scaled_mm_module = importlib.import_module(
38
- # "vllm.model_executor.layers.quantization.compressed_tensors."
39
- # "triton_scaled_mm")
40
- # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
41
- # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
42
 
43
  out = torch.empty((m, n), dtype=out_dtype, device=a.device)
44
 
 
2
 
3
  import torch
4
 
5
+ from ._ops import ops
6
+ from .platforms import current_platform
 
 
 
 
 
 
 
 
7
 
8
 
9
  def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
10
  return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
11
 
12
 
13
+ def cutlass_scaled_mm_supports_block_fp8(cuda_device_capability: int) -> bool:
14
+ return ops.cutlass_scaled_mm_supports_block_fp8(cuda_device_capability)
15
+
16
+
17
  def cutlass_scaled_mm(
18
  a: torch.Tensor,
19
  b: torch.Tensor,
 
29
  m = a.shape[0]
30
  n = b.shape[1]
31
 
32
+ cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
33
+ if not cutlass_compatible_b:
34
+ from .triton_scaled_mm import triton_scaled_mm
35
+ return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
 
 
36
 
37
  out = torch.empty((m, n), dtype=out_dtype, device=a.device)
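Usage sketch (illustrative; the keyword names scale_a/scale_b/out_dtype come from the fallback call above, and b is assumed to be passed as a column-major (k, n) view):
import torch
from quantization import cutlass_scaled_mm, scaled_fp8_quant

a = torch.randn(64, 4096, dtype=torch.half, device="cuda")
b = torch.randn(1024, 4096, dtype=torch.half, device="cuda")   # (n, k)
a_q, a_s = scaled_fp8_quant(a)                                 # per-tensor dynamic scales
b_q, b_s = scaled_fp8_quant(b)
out = cutlass_scaled_mm(a_q, b_q.t(), scale_a=a_s, scale_b=b_s, out_dtype=torch.half)  # (64, 1024)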
38
 
build/torch26-cxx98-cu126-aarch64-linux/quantization/marlin.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import TYPE_CHECKING
2
 
3
  import torch
4
 
@@ -30,58 +30,30 @@ except ImportError as e:
30
  from .scalar_type import ScalarType
31
 
32
 
33
- # fp8 marlin
34
- def fp8_marlin_gemm(
35
- a: torch.Tensor,
36
- b_q_weight: torch.Tensor,
37
- b_scales: torch.Tensor,
38
- workspace: torch.Tensor,
39
- num_bits: int,
40
- size_m: int,
41
- size_n: int,
42
- size_k: int,
43
- ) -> torch.Tensor:
44
- return ops.fp8_marlin_gemm(
45
- a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k
46
- )
47
-
48
-
49
  # gptq_marlin
50
- def gptq_marlin_gemm(
51
- a: torch.Tensor,
52
- b_q_weight: torch.Tensor,
53
- b_scales: torch.Tensor,
54
- b_zeros: torch.Tensor,
55
- g_idx: torch.Tensor,
56
- perm: torch.Tensor,
57
- workspace: torch.Tensor,
58
- b_q_type: ScalarType,
59
- size_m: int,
60
- size_n: int,
61
- size_k: int,
62
- is_k_full: bool,
63
- has_zp: bool = False,
64
- use_fp32_reduce: bool = False,
65
- is_zp_float: bool = False,
66
- ) -> torch.Tensor:
67
- return ops.gptq_marlin_gemm(
68
- a,
69
- b_q_weight,
70
- b_scales,
71
- b_zeros,
72
- g_idx,
73
- perm,
74
- workspace,
75
- b_q_type.id,
76
- size_m,
77
- size_n,
78
- size_k,
79
- is_k_full,
80
- has_zp,
81
- use_fp32_reduce,
82
- is_zp_float,
83
- )
84
-
85
 
86
  # gptq_marlin
87
  def gptq_marlin_repack(
@@ -153,14 +125,6 @@ def marlin_qqq_gemm(
153
  # Fake ops
154
 
155
  if hasattr(ops, "gptq_marlin_24_gemm"):
156
- @register_fake(add_op_namespace_prefix("fp8_marlin_gemm"))
157
- def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
158
- b_scales: torch.Tensor, workspace: torch.Tensor,
159
- num_bits: int, size_m: torch.SymInt,
160
- size_n: torch.SymInt,
161
- size_k: torch.SymInt) -> torch.Tensor:
162
- return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)
163
-
164
  @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm"))
165
  def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
166
  b_meta: torch.Tensor, b_scales: torch.Tensor,
@@ -172,20 +136,22 @@ if hasattr(ops, "gptq_marlin_24_gemm"):
172
 
173
  @register_fake(add_op_namespace_prefix("gptq_marlin_gemm"))
174
  def _gptq_marlin_gemm_fake(a: torch.Tensor,
175
- b_q_weight: torch.Tensor,
176
- b_scales: torch.Tensor,
177
- b_zeros: torch.Tensor,
178
- g_idx: torch.Tensor,
179
- perm: torch.Tensor,
180
- workspace: torch.Tensor,
181
- b_q_type: ScalarType,
182
- size_m: torch.SymInt,
183
- size_n: torch.SymInt,
184
- size_k: torch.SymInt,
185
- is_k_full: bool,
186
- has_zp: bool = False,
187
- use_fp32_reduce: bool = False,
188
- is_zp_float: bool = False) -> torch.Tensor:
 
 
189
  return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
190
 
191
  @register_fake(add_op_namespace_prefix("marlin_qqq_gemm"))
 
1
+ from typing import TYPE_CHECKING, Optional
2
 
3
  import torch
4
 
 
30
  from .scalar_type import ScalarType
31
 
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  # gptq_marlin
34
+ def gptq_marlin_gemm(a: torch.Tensor,
35
+ c: Optional[torch.Tensor],
36
+ b_q_weight: torch.Tensor,
37
+ b_scales: torch.Tensor,
38
+ global_scale: Optional[torch.Tensor],
39
+ b_zeros: Optional[torch.Tensor],
40
+ g_idx: Optional[torch.Tensor],
41
+ perm: Optional[torch.Tensor],
42
+ workspace: torch.Tensor,
43
+ b_q_type: ScalarType,
44
+ size_m: int,
45
+ size_n: int,
46
+ size_k: int,
47
+ is_k_full: bool = True,
48
+ use_atomic_add: bool = False,
49
+ use_fp32_reduce: bool = False,
50
+ is_zp_float: bool = False) -> torch.Tensor:
51
+ return ops.gptq_marlin_gemm(a, c, b_q_weight, b_scales,
52
+ global_scale, b_zeros, g_idx, perm,
53
+ workspace, b_q_type.id, size_m,
54
+ size_n, size_k, is_k_full,
55
+ use_atomic_add, use_fp32_reduce,
56
+ is_zp_float)
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  # gptq_marlin
59
  def gptq_marlin_repack(
 
125
  # Fake ops
126
 
127
  if hasattr(ops, "gptq_marlin_24_gemm"):
 
 
 
 
 
 
 
 
128
  @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm"))
129
  def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
130
  b_meta: torch.Tensor, b_scales: torch.Tensor,
 
136
 
137
  @register_fake(add_op_namespace_prefix("gptq_marlin_gemm"))
138
  def _gptq_marlin_gemm_fake(a: torch.Tensor,
139
+ c: Optional[torch.Tensor],
140
+ b_q_weight: torch.Tensor,
141
+ b_scales: torch.Tensor,
142
+ global_scale: Optional[torch.Tensor],
143
+ b_zeros: Optional[torch.Tensor],
144
+ g_idx: Optional[torch.Tensor],
145
+ perm: Optional[torch.Tensor],
146
+ workspace: torch.Tensor,
147
+ b_q_type_id: int,
148
+ size_m: torch.SymInt,
149
+ size_n: torch.SymInt,
150
+ size_k: torch.SymInt,
151
+ is_k_full: bool = True,
152
+ use_atomic_add: bool = False,
153
+ use_fp32_reduce: bool = False,
154
+ is_zp_float: bool = False) -> torch.Tensor:
155
  return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
156
 
157
  @register_fake(add_op_namespace_prefix("marlin_qqq_gemm"))
build/torch26-cxx98-cu126-aarch64-linux/quantization/platforms.py ADDED
@@ -0,0 +1,69 @@
 
 
1
+ from abc import ABC, abstractmethod
2
+ from functools import lru_cache
3
+ from typing import NamedTuple
4
+
5
+ import torch
6
+
7
+ IS_ROCM = torch.version.hip is not None
8
+
9
+
10
+ class DeviceCapability(NamedTuple):
11
+ major: int
12
+ minor: int
13
+
14
+ def as_version_str(self) -> str:
15
+ return f"{self.major}.{self.minor}"
16
+
17
+ def to_int(self) -> int:
18
+ """
19
+ Express device capability as an integer ``<major><minor>``.
20
+
21
+ It is assumed that the minor version is always a single digit.
22
+ """
23
+ assert 0 <= self.minor < 10
24
+ return self.major * 10 + self.minor
25
+
26
+
27
+ class Platform(ABC):
28
+ simple_compile_backend: str = "inductor"
29
+
30
+ @classmethod
31
+ @abstractmethod
32
+ def get_device_name(cls, device_id: int = 0) -> str: ...
33
+
34
+ @abstractmethod
35
+ def is_rocm(self): ...
36
+
37
+
38
+ class CudaPlatform(Platform):
39
+ @classmethod
40
+ def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
41
+ major, minor = torch.cuda.get_device_capability(device_id)
42
+ return DeviceCapability(major=major, minor=minor)
43
+
44
+ @classmethod
45
+ @lru_cache(maxsize=8)
46
+ def get_device_name(cls, device_id: int = 0) -> str:
47
+ return torch.cuda.get_device_name(0)
48
+
49
+ def is_rocm(self):
50
+ return False
51
+
52
+
53
+ class RocmPlatform(Platform):
54
+ @classmethod
55
+ @lru_cache(maxsize=8)
56
+ def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
57
+ major, minor = torch.cuda.get_device_capability(device_id)
58
+ return DeviceCapability(major=major, minor=minor)
59
+
60
+ @classmethod
61
+ @lru_cache(maxsize=8)
62
+ def get_device_name(cls, device_id: int = 0) -> str:
63
+ return torch.cuda.get_device_name(device_id)
64
+
65
+ def is_rocm(self):
66
+ return True
67
+
68
+
69
+ current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
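Downstream usage sketch (illustrative; assumes the module is importable as quantization.platforms):
from quantization.platforms import current_platform

cap = current_platform.get_device_capability()   # e.g. DeviceCapability(major=9, minor=0) on H100
if not current_platform.is_rocm() and cap.to_int() >= 89:
    pass  # native FP8 path
else:
    pass  # fall back to the Marlin weight-only path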
build/torch26-cxx98-cu126-aarch64-linux/quantization/scalar_type.py CHANGED
@@ -1,9 +1,14 @@
 
 
 
1
  import functools
2
  import struct
3
  from dataclasses import dataclass
4
  from enum import Enum
5
  from typing import Optional, Union
6
 
 
 
7
 
8
  # Mirrors enum in `core/scalar_type.hpp`
9
  class NanRepr(Enum):
@@ -121,8 +126,8 @@ class ScalarType:
121
  min_raw = max_raw | sign_bit_double
122
  return struct.unpack('!d', struct.pack('!Q', min_raw))[0]
123
  else:
124
- assert (not self.is_signed() or
125
- self.size_bits <= 64), "Cannot represent min as a int64_t"
126
 
127
  if self.is_signed():
128
  return -(1 << (self.size_bits - 1))
@@ -156,6 +161,8 @@ class ScalarType:
156
  assert offset <= 64, \
157
  f"ScalarType fields too big {offset} to fit into an int64"
158
 
 
 
159
  return val
160
 
161
  @property
@@ -293,6 +300,13 @@ class ScalarType:
293
  ret.id # noqa B018: make sure the id is cached
294
  return ret
295
 
 
 
 
 
 
 
 
296
 
297
  # naming generally follows: https://github.com/jax-ml/ml_dtypes
298
  # for floating point types (leading f) the scheme is:
@@ -319,6 +333,9 @@ class scalar_types:
319
  # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
320
  float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)
321
 
 
 
 
322
  # "gptq" types
323
  uint2b2 = ScalarType.uint(2, 2)
324
  uint3b4 = ScalarType.uint(3, 4)
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+
4
  import functools
5
  import struct
6
  from dataclasses import dataclass
7
  from enum import Enum
8
  from typing import Optional, Union
9
 
10
+ _SCALAR_TYPES_ID_MAP = {}
11
+
12
 
13
  # Mirrors enum in `core/scalar_type.hpp`
14
  class NanRepr(Enum):
 
126
  min_raw = max_raw | sign_bit_double
127
  return struct.unpack('!d', struct.pack('!Q', min_raw))[0]
128
  else:
129
+ assert (not self.is_signed() or self.size_bits
130
+ <= 64), "Cannot represent min as an int64_t"
131
 
132
  if self.is_signed():
133
  return -(1 << (self.size_bits - 1))
 
161
  assert offset <= 64, \
162
  f"ScalarType fields too big {offset} to fit into an int64"
163
 
164
+ _SCALAR_TYPES_ID_MAP[val] = self
165
+
166
  return val
167
 
168
  @property
 
300
  ret.id # noqa B018: make sure the id is cached
301
  return ret
302
 
303
+ @classmethod
304
+ def from_id(cls, scalar_type_id: int):
305
+ if scalar_type_id not in _SCALAR_TYPES_ID_MAP:
306
+ raise ValueError(
307
+ f"scalar_type_id {scalar_type_id} doesn't exist.")
308
+ return _SCALAR_TYPES_ID_MAP[scalar_type_id]
309
+
310
 
311
  # naming generally follows: https://github.com/jax-ml/ml_dtypes
312
  # for floating point types (leading f) the scheme is:
 
333
  # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
334
  float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)
335
 
336
+ # fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
337
+ float4_e2m1f = ScalarType.float_(2, 1, True, NanRepr.NONE)
338
+
339
  # "gptq" types
340
  uint2b2 = ScalarType.uint(2, 2)
341
  uint3b4 = ScalarType.uint(3, 4)
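Round-trip sketch for the new from_id lookup above (illustrative): the op layer ships ScalarType.id across the C extension boundary and recovers the Python object here.
from quantization.scalar_type import ScalarType, scalar_types

type_id = scalar_types.float4_e2m1f.id
assert ScalarType.from_id(type_id) == scalar_types.float4_e2m1f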
build/torch26-cxx98-cu126-aarch64-linux/quantization/utils/marlin_utils.py CHANGED
@@ -1,4 +1,7 @@
1
- from typing import List, Optional, Tuple
 
 
 
2
 
3
  import numpy
4
  import torch
@@ -42,7 +45,9 @@ USE_FP32_REDUCE_DEFAULT = True
42
  # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
43
  # TODO: we may want to move this into the C++ so its closer to the actual impl
44
  def query_marlin_supported_quant_types(
45
- has_zp: bool, device_capability: Optional[int] = None
 
 
46
  ):
47
  if device_capability is None:
48
  capability_tuple = torch.cuda.get_device_capability()
@@ -51,137 +56,141 @@ def query_marlin_supported_quant_types(
51
  if device_capability < 80:
52
  return []
53
 
 
 
 
 
 
 
 
 
 
 
54
  if has_zp:
55
  # AWQ style, unsigned + runtime zero-point
56
- return [scalar_types.uint4, scalar_types.uint8]
57
  else:
58
  # GPTQ style, unsigned + symmetric bias
59
- # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able
60
- # to add `scalar_types.float8_e4m3fn` here
61
- return [scalar_types.uint4b8, scalar_types.uint8b128]
 
62
 
63
 
64
  def _check_marlin_supported(
65
- quant_type: ScalarType,
66
- group_size: Optional[int],
67
- has_zp: bool,
68
- device_capability: Optional[int] = None,
69
- ) -> Tuple[bool, Optional[str]]:
70
 
71
  if device_capability is None:
72
  capability_tuple = torch.cuda.get_device_capability()
73
  device_capability = capability_tuple[0] * 10 + capability_tuple[1]
74
 
75
- supported_types = query_marlin_supported_quant_types(has_zp, device_capability)
 
76
 
77
  if quant_type not in supported_types:
78
- return (
79
- False,
80
- f"Marlin does not support weight_bits = {quant_type}. "
81
- f"Only types = {supported_types} "
82
- f"are supported (for group_size = {group_size}, "
83
- f"device_capability = {device_capability}, zp = {has_zp}).",
84
- )
85
- if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
86
- return (
87
- False,
88
- f"Marlin does not support group_size = {group_size}. "
89
- f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
90
- "are supported.",
91
- )
92
 
93
  return True, None
94
 
95
 
96
- def check_marlin_supported(
97
- quant_type: ScalarType,
98
- group_size: int,
99
- has_zp: bool = False,
100
- device_capability: Optional[int] = None,
101
- ) -> bool:
102
- cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability)
103
  return cond
104
 
105
 
106
- def verify_marlin_supported(
107
- quant_type: ScalarType, group_size: int, has_zp: bool = False
108
- ) -> None:
109
  cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp)
110
  if not cond:
111
  assert err_msg is not None
112
  raise ValueError(err_msg)
113
 
114
 
115
- def verify_marlin_supports_shape(
116
- output_size_per_partition: int,
117
- input_size_per_partition: int,
118
- input_size: int,
119
- group_size: int,
120
- ) -> None:
121
 
122
  # Validate output_size_per_partition
123
  if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
124
- raise ValueError(
125
- f"Weight output_size_per_partition = "
126
- f"{output_size_per_partition} is not divisible by "
127
- f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
128
- "Consider reducing tensor_parallel_size or running "
129
- "with --quantization gptq."
130
- )
131
 
132
  # Validate input_size_per_partition
133
  if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
134
- raise ValueError(
135
- f"Weight input_size_per_partition = "
136
- f"{input_size_per_partition} is not divisible "
137
- f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
138
- "Consider reducing tensor_parallel_size or running "
139
- "with --quantization gptq."
140
- )
141
-
142
- if group_size < input_size and input_size_per_partition % group_size != 0:
143
  raise ValueError(
144
  f"Weight input_size_per_partition = {input_size_per_partition}"
145
- f" is not divisible by group_size = {group_size}."
146
  "Consider reducing tensor_parallel_size or running "
147
- "with --quantization gptq."
148
- )
149
 
150
 
151
- def check_marlin_supports_shape(
152
- output_size_per_partition: int,
153
- input_size_per_partition: int,
154
- input_size: int,
155
- group_size: int,
156
- ) -> Tuple[bool, Optional[str]]:
157
  try:
158
- verify_marlin_supports_shape(
159
- output_size_per_partition, input_size_per_partition, input_size, group_size
160
- )
161
  except ValueError as e:
162
  return False, e.__str__()
163
  return True, None
164
 
165
 
166
- def marlin_make_workspace(
167
- output_size_per_partition: int, device: torch.device
168
- ) -> torch.Tensor:
169
- max_workspace_size = (
170
- output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N
171
- ) * GPTQ_MARLIN_MAX_PARALLEL
172
 
173
- return torch.zeros(
174
- max_workspace_size, dtype=torch.int, device=device, requires_grad=False
175
- )
 
 
 
 
 
 
 
 
 
 
 
 
176
 
177
 
178
  def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
179
  return (not act_order) or (act_order and not is_row_parallel)
180
 
181
 
182
- def marlin_repeat_scales_on_all_ranks(
183
- act_order: bool, group_size: int, is_row_parallel: bool
184
- ) -> bool:
185
  # Need to repeat scales on every rank if act_ordering or
186
  # channelwise and RowParallelLinear
187
  is_channelwise = group_size == -1
@@ -189,35 +198,34 @@ def marlin_repeat_scales_on_all_ranks(
189
 
190
 
191
  def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
192
- return torch.nn.Parameter(
193
- torch.empty(0, dtype=torch.int, device=device), requires_grad=False
194
- )
195
 
196
 
197
  def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
198
- return torch.nn.Parameter(
199
- torch.empty(0, dtype=torch.int, device=device), requires_grad=False
200
- )
201
 
202
 
203
- def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
 
204
  g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
205
  return g_idx[g_idx_sort_indices], g_idx_sort_indices
206
 
207
 
208
  def get_scale_perms():
209
- scale_perm: List[int] = []
210
  for i in range(8):
211
  scale_perm.extend([i + 8 * j for j in range(8)])
212
- scale_perm_single: List[int] = []
213
  for i in range(4):
214
- scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
 
215
  return scale_perm, scale_perm_single
216
 
217
 
218
- def marlin_permute_scales(
219
- s: torch.Tensor, size_k: int, size_n: int, group_size: int
220
- ) -> torch.Tensor:
221
 
222
  scale_perm, scale_perm_single = get_scale_perms()
223
  if group_size < size_k and group_size != -1:
@@ -247,9 +255,8 @@ def marlin_moe_permute_scales(
247
  return output
248
 
249
 
250
- def marlin_zero_points(
251
- zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
252
- ) -> torch.Tensor:
253
  # Permute zero-points in a similar way to scales, but do not use the
254
  # "single" permutation, since zero-points are applied on every MMA
255
  scale_perm, _ = get_scale_perms()
@@ -270,9 +277,8 @@ def marlin_zero_points(
270
  return zp
271
 
272
 
273
- def awq_to_marlin_zero_points(
274
- q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
275
- ) -> torch.Tensor:
276
  # AWQ zero-points are quantized and packed on the column dim.
277
  # In addition, the values are permuted based on dequantizer.
278
  # Here we undo both of these, and then apply marlin permutation
@@ -294,9 +300,8 @@ def awq_to_marlin_zero_points(
294
  return marlin_zp
295
 
296
 
297
- def moe_awq_to_marlin_zero_points(
298
- q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
299
- ):
300
  num_experts = q_zp_packed.shape[0]
301
  output = torch.empty(
302
  (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]),
@@ -304,45 +309,97 @@ def moe_awq_to_marlin_zero_points(
304
  dtype=q_zp_packed.dtype,
305
  )
306
  for e in range(num_experts):
307
- output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits)
 
308
  return output
309
 
310
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
  def apply_gptq_marlin_linear(
312
- input: torch.Tensor,
313
- weight: torch.Tensor,
314
- weight_scale: torch.Tensor,
315
- weight_zp: torch.Tensor,
316
- g_idx: torch.Tensor,
317
- g_idx_sort_indices: torch.Tensor,
318
- workspace: torch.Tensor,
319
- wtype: ScalarType,
320
- output_size_per_partition: int,
321
- input_size_per_partition: int,
322
- is_k_full: bool,
323
- bias: Optional[torch.Tensor] = None,
324
- use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
325
- ) -> torch.Tensor:
326
  reshaped_x = input.reshape(-1, input.shape[-1])
327
- out_shape = input.shape[:-1] + (output_size_per_partition,)
328
-
329
- output = ops.gptq_marlin_gemm(
330
- reshaped_x,
331
- weight,
332
- weight_scale,
333
- weight_zp,
334
- g_idx,
335
- g_idx_sort_indices,
336
- workspace,
337
- wtype,
338
- size_m=reshaped_x.shape[0],
339
- size_n=output_size_per_partition,
340
- size_k=input_size_per_partition,
341
- is_k_full=is_k_full,
342
- has_zp=False,
343
- use_fp32_reduce=use_fp32_reduce,
344
- is_zp_float=False,
345
- )
 
 
 
 
 
 
346
 
347
  if bias is not None:
348
  output.add_(bias) # In-place add
@@ -351,39 +408,43 @@ def apply_gptq_marlin_linear(
351
 
352
 
353
  def apply_awq_marlin_linear(
354
- input: torch.Tensor,
355
- weight: torch.Tensor,
356
- weight_scale: torch.Tensor,
357
- weight_zp: torch.Tensor,
358
- g_idx: torch.Tensor,
359
- g_idx_sort_indices: torch.Tensor,
360
- workspace: torch.Tensor,
361
- quant_type: ScalarType,
362
- output_size_per_partition: int,
363
- input_size_per_partition: int,
364
- bias: Optional[torch.Tensor] = None,
365
- use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
366
- ) -> torch.Tensor:
367
  reshaped_x = input.reshape(-1, input.shape[-1])
368
- out_shape = input.shape[:-1] + (output_size_per_partition,)
369
-
370
- output = ops.gptq_marlin_gemm(
371
- reshaped_x,
372
- weight,
373
- weight_scale,
374
- weight_zp,
375
- g_idx,
376
- g_idx_sort_indices,
377
- workspace,
378
- quant_type,
379
- size_m=reshaped_x.shape[0],
380
- size_n=output_size_per_partition,
381
- size_k=input_size_per_partition,
382
- is_k_full=True,
383
- has_zp=True,
384
- use_fp32_reduce=use_fp32_reduce,
385
- is_zp_float=False,
386
- )
 
 
 
 
 
387
 
388
  if bias is not None:
389
  output.add_(bias) # In-place add
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+
4
+ from typing import Optional
5
 
6
  import numpy
7
  import torch
 
45
  # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
46
  # TODO: we may want to move this into the C++ so its closer to the actual impl
47
  def query_marlin_supported_quant_types(
48
+ has_zp: Optional[bool] = None,
49
+ include_fp_type: bool = True,
50
+ device_capability: Optional[int] = None,
51
  ):
52
  if device_capability is None:
53
  capability_tuple = torch.cuda.get_device_capability()
 
56
  if device_capability < 80:
57
  return []
58
 
59
+ # - has_zp is True: return quant_types that have zero points
60
+ # - has_zp is False: return quant_types that have no zero points
61
+ # - has_zp is None: both
62
+ if has_zp is None:
63
+ types0 = query_marlin_supported_quant_types(False, include_fp_type,
64
+ device_capability)
65
+ types1 = query_marlin_supported_quant_types(True, include_fp_type,
66
+ device_capability)
67
+ return types0 + types1
68
+
69
  if has_zp:
70
  # AWQ style, unsigned + runtime zero-point
71
+ return [scalar_types.uint4]
72
  else:
73
  # GPTQ style, unsigned + symmetric bias
74
+ res = [scalar_types.uint4b8, scalar_types.uint8b128]
75
+ if include_fp_type:
76
+ res += [scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f]
77
+ return res
78
 
79
 
80
  def _check_marlin_supported(
81
+ quant_type: ScalarType,
82
+ group_size: Optional[int],
83
+ has_zp: bool,
84
+ device_capability: Optional[int] = None) -> tuple[bool, Optional[str]]:
 
85
 
86
  if device_capability is None:
87
  capability_tuple = torch.cuda.get_device_capability()
88
  device_capability = capability_tuple[0] * 10 + capability_tuple[1]
89
 
90
+ supported_types = query_marlin_supported_quant_types(
91
+ has_zp, True, device_capability)
92
 
93
  if quant_type not in supported_types:
94
+ return (False, f"Marlin does not support weight_bits = {quant_type}. "
95
+ f"Only types = {supported_types} "
96
+ f"are supported (for group_size = {group_size}, "
97
+ f"device_capability = {device_capability}, zp = {has_zp}).")
98
+ if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES):
99
+ return (False, f"Marlin does not support group_size = {group_size}. "
100
+ f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
101
+ "are supported.")
 
 
 
 
 
 
102
 
103
  return True, None
104
 
105
 
106
+ def check_marlin_supported(quant_type: ScalarType,
107
+ group_size: int,
108
+ has_zp: bool = False,
109
+ device_capability: Optional[int] = None) -> bool:
110
+ cond, _ = _check_marlin_supported(quant_type, group_size, has_zp,
111
+ device_capability)
 
112
  return cond
113
 
114
 
115
+ def verify_marlin_supported(quant_type: ScalarType,
116
+ group_size: int,
117
+ has_zp: bool = False) -> None:
118
  cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp)
119
  if not cond:
120
  assert err_msg is not None
121
  raise ValueError(err_msg)
122
 
123
 
124
+ def verify_marlin_supports_shape(output_size_per_partition: int,
125
+ input_size_per_partition: int,
126
+ input_size: int, group_size: int) -> None:
 
 
 
127
 
128
  # Validate output_size_per_partition
129
  if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
130
+ raise ValueError(f"Weight output_size_per_partition = "
131
+ f"{output_size_per_partition} is not divisible by "
132
+ f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
133
+ "Consider reducing tensor_parallel_size or running "
134
+ "with --quantization gptq.")
 
 
135
 
136
  # Validate input_size_per_partition
137
  if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
138
+ raise ValueError(f"Weight input_size_per_partition = "
139
+ f"{input_size_per_partition} is not divisible "
140
+ f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
141
+ "Consider reducing tensor_parallel_size or running "
142
+ "with --quantization gptq.")
143
+
144
+ if (group_size < input_size
145
+ and input_size_per_partition % group_size != 0):
 
146
  raise ValueError(
147
  f"Weight input_size_per_partition = {input_size_per_partition}"
148
+ f" is not divisible by group_size = {group_size}. "
149
  "Consider reducing tensor_parallel_size or running "
150
+ "with --quantization gptq.")
 
151
 
152
 
153
+ def check_marlin_supports_shape(output_size_per_partition: int,
154
+ input_size_per_partition: int,
155
+ input_size: int, group_size: int) \
156
+ -> tuple[bool, Optional[str]]:
 
 
157
  try:
158
+ verify_marlin_supports_shape(output_size_per_partition,
159
+ input_size_per_partition, input_size,
160
+ group_size)
161
  except ValueError as e:
162
  return False, e.__str__()
163
  return True, None
164
 
165
 
166
+ def marlin_make_workspace(output_size_per_partition: int,
167
+ device: torch.device) -> torch.Tensor:
168
+ max_workspace_size = (output_size_per_partition //
169
+ GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL
 
 
170
 
171
+ return torch.zeros(max_workspace_size,
172
+ dtype=torch.int,
173
+ device=device,
174
+ requires_grad=False)
175
+
176
+
177
+ def marlin_make_workspace_new(device: torch.device,
178
+ max_blocks_per_sm: int = 1) -> torch.Tensor:
179
+ # In the new marlin kernel, we use the num of threadblocks as workspace
180
+ # size. The num of threadblocks is sms_count * max_blocks_per_sm.
181
+ sms = torch.cuda.get_device_properties(device).multi_processor_count
182
+ return torch.zeros(sms * max_blocks_per_sm,
183
+ dtype=torch.int,
184
+ device=device,
185
+ requires_grad=False)
186
 
187
 
188
  def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
189
  return (not act_order) or (act_order and not is_row_parallel)
190
 
191
 
192
+ def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int,
193
+ is_row_parallel: bool) -> bool:
 
194
  # Need to repeat scales on every rank if act_ordering or
195
  # channelwise and RowParallelLinear
196
  is_channelwise = group_size == -1
 
198
 
199
 
200
  def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
201
+ return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
202
+ requires_grad=False)
 
203
 
204
 
205
  def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
206
+ return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
207
+ requires_grad=False)
 
208
 
209
 
210
+ def marlin_sort_g_idx(
211
+ g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
212
  g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
213
  return g_idx[g_idx_sort_indices], g_idx_sort_indices
214
 
215
 
216
  def get_scale_perms():
217
+ scale_perm: list[int] = []
218
  for i in range(8):
219
  scale_perm.extend([i + 8 * j for j in range(8)])
220
+ scale_perm_single: list[int] = []
221
  for i in range(4):
222
+ scale_perm_single.extend(
223
+ [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
224
  return scale_perm, scale_perm_single
225
 
226
 
227
+ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
228
+ group_size: int) -> torch.Tensor:
 
229
 
230
  scale_perm, scale_perm_single = get_scale_perms()
231
  if group_size < size_k and group_size != -1:
 
255
  return output
256
 
257
 
258
+ def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int,
259
+ num_bits: int) -> torch.Tensor:
 
260
  # Permute zero-points in a similar way to scales, but do not use the
261
  # "single" permutation, since zero-points are applied on every MMA
262
  scale_perm, _ = get_scale_perms()
 
277
  return zp
278
 
279
 
280
+ def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
281
+ size_n: int, num_bits: int) -> torch.Tensor:
 
282
  # AWQ zero-points are quantized and packed on the column dim.
283
  # In addition, the values are permuted based on dequantizer.
284
  # Here we undo both of these, and then apply marlin permutation
 
300
  return marlin_zp
301
 
302
 
303
+ def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
304
+ size_n: int, num_bits: int):
 
305
  num_experts = q_zp_packed.shape[0]
306
  output = torch.empty(
307
  (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]),
 
309
  dtype=q_zp_packed.dtype,
310
  )
311
  for e in range(num_experts):
312
+ output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n,
313
+ num_bits)
314
  return output
315
 
316
 
317
+ def maybe_warn_marlin_atomic_add(device, dtype):
318
+ if torch.compiler.is_dynamo_compiling():
319
+ return
320
+ device_capability = torch.cuda.get_device_capability(device)
321
+ if device_capability[0] < 9 and dtype == torch.bfloat16:
322
+ logger.info_once(
323
+ "You are running Marlin kernel with bf16 on GPUs before SM90. "
324
+ "You can consider changing to fp16 to achieve better performance "
325
+ "if possible.")
326
+
327
+
328
+ def maybe_warn_marlin_atomic_add_env():
329
+ if torch.compiler.is_dynamo_compiling():
330
+ return
331
+ if envs.VLLM_MARLIN_USE_ATOMIC_ADD:
332
+ return
333
+ logger.info_once(
334
+ "Marlin kernel can achieve better performance for small size_n "
335
+ "with experimental use_atomic_add feature. "
336
+ "You can consider setting the environment variable "
337
+ "VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible.")
338
+
339
+
340
+ def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device,
341
+ dtype: torch.dtype) -> bool:
342
+
343
+ # the performance of atomicAdd is better than global reduce
344
+ # only when m*n is small and k is large
345
+ if n >= 2048 or k < 2048 or device.type != "cuda":
346
+ return False
347
+
348
+ # disable atomicAdd reduce by default,
349
+ # one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1
350
+ if not envs.VLLM_MARLIN_USE_ATOMIC_ADD:
351
+ maybe_warn_marlin_atomic_add_env()
352
+ return False
353
+
354
+ # sm8x doesn't support atomicAdd + bfloat16 natively
355
+ device_capability = torch.cuda.get_device_capability(device)
356
+ if device_capability[0] < 9 and dtype == torch.bfloat16:
357
+ maybe_warn_marlin_atomic_add(device, dtype)
358
+ return False
359
+
360
+ return True
361
+
362
+
363
  def apply_gptq_marlin_linear(
364
+ input: torch.Tensor,
365
+ weight: torch.Tensor,
366
+ weight_scale: torch.Tensor,
367
+ weight_zp: torch.Tensor,
368
+ g_idx: torch.Tensor,
369
+ g_idx_sort_indices: torch.Tensor,
370
+ workspace: torch.Tensor,
371
+ wtype: ScalarType,
372
+ output_size_per_partition: int,
373
+ input_size_per_partition: int,
374
+ is_k_full: bool,
375
+ bias: Optional[torch.Tensor] = None,
376
+ use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
 
377
  reshaped_x = input.reshape(-1, input.shape[-1])
378
+ out_shape = input.shape[:-1] + (output_size_per_partition, )
379
+
380
+ use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
381
+ n=output_size_per_partition,
382
+ k=reshaped_x.size(1),
383
+ device=input.device,
384
+ dtype=input.dtype)
385
+
386
+ output = ops.gptq_marlin_gemm(reshaped_x,
387
+ None,
388
+ weight,
389
+ weight_scale,
390
+ None,
391
+ weight_zp,
392
+ g_idx,
393
+ g_idx_sort_indices,
394
+ workspace,
395
+ wtype,
396
+ size_m=reshaped_x.shape[0],
397
+ size_n=output_size_per_partition,
398
+ size_k=input_size_per_partition,
399
+ is_k_full=is_k_full,
400
+ use_atomic_add=use_atomic_add,
401
+ use_fp32_reduce=use_fp32_reduce,
402
+ is_zp_float=False)
403
 
404
  if bias is not None:
405
  output.add_(bias) # In-place add
 
408
 
409
 
410
  def apply_awq_marlin_linear(
411
+ input: torch.Tensor,
412
+ weight: torch.Tensor,
413
+ weight_scale: torch.Tensor,
414
+ weight_zp: torch.Tensor,
415
+ g_idx: torch.Tensor,
416
+ g_idx_sort_indices: torch.Tensor,
417
+ workspace: torch.Tensor,
418
+ quant_type: ScalarType,
419
+ output_size_per_partition: int,
420
+ input_size_per_partition: int,
421
+ bias: Optional[torch.Tensor] = None,
422
+ use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
 
423
  reshaped_x = input.reshape(-1, input.shape[-1])
424
+ out_shape = input.shape[:-1] + (output_size_per_partition, )
425
+
426
+ use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
427
+ n=output_size_per_partition,
428
+ k=reshaped_x.size(1),
429
+ device=input.device,
430
+ dtype=input.dtype)
431
+
432
+ output = ops.gptq_marlin_gemm(reshaped_x,
433
+ None,
434
+ weight,
435
+ weight_scale,
436
+ None,
437
+ weight_zp,
438
+ g_idx,
439
+ g_idx_sort_indices,
440
+ workspace,
441
+ quant_type,
442
+ size_m=reshaped_x.shape[0],
443
+ size_n=output_size_per_partition,
444
+ size_k=input_size_per_partition,
445
+ use_atomic_add=use_atomic_add,
446
+ use_fp32_reduce=use_fp32_reduce,
447
+ is_zp_float=False)
448
 
449
  if bias is not None:
450
  output.add_(bias) # In-place add
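A worked example of the atomic-add heuristic above (illustrative; assumes a CUDA device and the import path shown). Only decode-shaped problems (small m*n, large k) qualify, and only when VLLM_MARLIN_USE_ATOMIC_ADD=1 is set.
import torch
from quantization.utils.marlin_utils import should_use_atomic_add_reduce

dev = torch.device("cuda")
should_use_atomic_add_reduce(m=1, n=1024, k=8192, device=dev, dtype=torch.float16)  # True iff the env var is set
should_use_atomic_add_reduce(m=1, n=4096, k=8192, device=dev, dtype=torch.float16)  # False: n >= 2048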
build/torch26-cxx98-cu126-aarch64-linux/quantization/utils/marlin_utils_fp4.py ADDED
@@ -0,0 +1,282 @@
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+
8
+ import quantization as ops
9
+
10
+ from .marlin_utils import (
11
+ USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales,
12
+ should_use_atomic_add_reduce)
13
+ from quantization.scalar_type import scalar_types
14
+
15
+ FP4_MARLIN_SUPPORTED_GROUP_SIZES = [16]
16
+
17
+
18
+ def is_fp4_marlin_supported():
19
+ capability = torch.cuda.get_device_capability()
20
+ capability = capability[0] * 10 + capability[1]
21
+ return capability >= 80
22
+
23
+
24
+ def fp4_marlin_process_scales(marlin_scales):
25
+ if not (marlin_scales >= 0).all():
26
+ logger.warning_once(
27
+ "NVFP4 Marlin assumes the scales to be >=0, but has encountered "
28
+ "negative scales. Accuracy will likely be degraded. This is "
29
+ "because it changes the scales from FP8-S1E4M3 to a special "
30
+ "FP8-S0E5M3 format to speedup the dequantization.")
31
+
32
+ # convert to half first; we convert to fp8 later
33
+ marlin_scales = marlin_scales.to(torch.half)
34
+
35
+ # 8 is the number of scales handled by one thread
36
+ marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8)
37
+ marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape(
38
+ marlin_scales.size(0) * 2, -1)
39
+
40
+ # fit the layout of fp8 dequantization
41
+ marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view(
42
+ marlin_scales.size(0), -1)
43
+
44
+ # We assume that weight_scale (FP8-S1E4M3) is always greater
45
+ # than or equal to 0. So we can convert
46
+ # (weight_scale * (2 ** 7)) to a special FP8-S0E5M3 format.
47
+ # After multiplying by 2 ** 7, the top bit of FP8-S0E5M3 would always be 1
48
+ # when weight_scale > 0. This allows us to have an exponent bias
49
+ # closer to zero after dequantization.
50
+
51
+ marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1
52
+ marlin_scales = marlin_scales.view(torch.float8_e4m3fn)
53
+ marlin_scales = marlin_scales[:, 1::2].contiguous()
54
+
55
+ return marlin_scales
56
+
57
+
58
+ def fp4_marlin_process_global_scale(global_scale):
59
+ assert global_scale.dtype in [torch.half, torch.bfloat16]
60
+ fp4_exponent = 2
61
+ if global_scale.dtype == torch.half:
62
+ target_exponent = 5
63
+ elif global_scale.dtype == torch.bfloat16:
64
+ target_exponent = 8
65
+ # exponent_bias_fp16 = 2 ** 4 - 2 ** 1 = 14
66
+ # exponent_bias_bf16 = 2 ** 7 - 2 ** 1 = 126
67
+ exponent_bias = 2**(target_exponent - 1) - 2**(fp4_exponent - 1)
68
+ return global_scale * (2.0**(exponent_bias - 7))
69
+
70
+
71
+ def apply_fp4_marlin_linear(
72
+ input: torch.Tensor,
73
+ weight: torch.Tensor,
74
+ weight_scale: torch.Tensor,
75
+ weight_scale_2: torch.Tensor,
76
+ workspace: torch.Tensor,
77
+ size_n: int,
78
+ size_k: int,
79
+ bias: Optional[torch.Tensor] = None,
80
+ use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
81
+ # For GPUs that lack FP4 hardware support, we can leverage the
82
+ # Marlin kernel for fast weight-only FP4 quantization
83
+
84
+ reshaped_x = input.reshape(-1, input.shape[-1])
85
+ out_shape = input.shape[:-1] + (size_n, )
86
+
87
+ use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
88
+ n=size_n,
89
+ k=size_k,
90
+ device=input.device,
91
+ dtype=input.dtype)
92
+
93
+ output = ops.gptq_marlin_gemm(a=reshaped_x,
94
+ c=None,
95
+ b_q_weight=weight,
96
+ b_scales=weight_scale,
97
+ global_scale=weight_scale_2,
98
+ b_zeros=None,
99
+ g_idx=None,
100
+ perm=None,
101
+ workspace=workspace,
102
+ b_q_type=scalar_types.float4_e2m1f,
103
+ size_m=reshaped_x.size(0),
104
+ size_n=size_n,
105
+ size_k=size_k,
106
+ use_atomic_add=use_atomic_add,
107
+ use_fp32_reduce=use_fp32_reduce)
108
+
109
+ if bias is not None:
110
+ output.add_(bias) # In-place add
111
+
112
+ return output.reshape(out_shape)
113
+
114
+
115
+ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
116
+ logger.warning_once(
117
+ "Your GPU does not have native support for FP4 computation but "
118
+ "FP4 quantization is being used. Weight-only FP4 compression will "
119
+ "be used leveraging the Marlin kernel. This may degrade "
120
+ "performance for compute-heavy workloads.")
121
+
122
+ part_size_n = layer.output_size_per_partition
123
+ part_size_k = layer.input_size_per_partition
124
+ param_dtype = layer.params_dtype
125
+
126
+ assert layer.weight.shape == (part_size_n, part_size_k // 2)
127
+
128
+ device = layer.weight.device
129
+
130
+ # WORKSPACE
131
+ layer.workspace = marlin_make_workspace_new(device)
132
+
133
+ # WEIGHT
134
+ # Repack weights to marlin format
135
+ perm = torch.empty(0, dtype=torch.int, device=device)
136
+ qweight = layer.weight.view(torch.int32).T.contiguous()
137
+
138
+ marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight,
139
+ perm=perm,
140
+ size_k=part_size_k,
141
+ size_n=part_size_n,
142
+ num_bits=4)
143
+ layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
144
+
145
+ # WEIGHT SCALES
146
+ # Permute scales
147
+ weight_scale = layer.weight_scale.T.to(param_dtype)
148
+ weight_scale = marlin_permute_scales(s=weight_scale,
149
+ size_k=part_size_k,
150
+ size_n=part_size_n,
151
+ group_size=16)
152
+ weight_scale = fp4_marlin_process_scales(weight_scale)
153
+ layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
154
+
155
+ weight_scale_2 = layer.weight_scale_2.to(param_dtype)
156
+ weight_scale_2 = fp4_marlin_process_global_scale(weight_scale_2)
157
+ layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2,
158
+ requires_grad=False)
159
+
160
+ return
161
+
162
+
163
+ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
164
+ logger.warning_once(
165
+ "Your GPU does not have native support for FP4 computation but "
166
+ "FP4 quantization is being used. Weight-only FP4 compression will "
167
+ "be used leveraging the Marlin kernel. This may degrade "
168
+ "performance for compute-heavy workloads.")
169
+
170
+ e = layer.num_experts
171
+ k = layer.hidden_size
172
+ n = layer.intermediate_size_per_partition
173
+
174
+ # WORKSPACE
175
+ device = layer.w13_weight.device
176
+ param_dtype = layer.params_dtype
177
+ layer.workspace = marlin_make_workspace_new(device, 4)
178
+ perm = torch.empty(0, dtype=torch.int, device=device)
179
+
180
+ # WEIGHT
181
+ # Repack weights to marlin format
182
+ for name in ["w13_weight", "w2_weight"]:
183
+ weight = getattr(layer, name)
184
+ tensor_list = []
185
+ if "w13" in name:
186
+ size_n, size_k = n * 2, k
187
+ else:
188
+ size_n, size_k = k, n
189
+
190
+ assert weight.shape == (e, size_n, size_k // 2)
191
+
192
+ for i in range(e):
193
+ qweight = weight[i].view(torch.int32).T.contiguous()
194
+
195
+ marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight,
196
+ perm=perm,
197
+ size_k=size_k,
198
+ size_n=size_n,
199
+ num_bits=4)
200
+ tensor_list.append(marlin_qweight)
201
+
202
+ weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
203
+ weight = torch.nn.Parameter(weight, requires_grad=False)
204
+
205
+ setattr(layer, name, weight)
206
+
207
+ # WEIGHT SCALES
208
+ # Permute scales
209
+ for name in ["w13", "w2"]:
210
+ scales = getattr(layer, name + "_weight_scale").to(param_dtype)
211
+ global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype)
212
+
213
+ tensor_list = []
214
+ if "w13" in name:
215
+ size_n, size_k = n * 2, k
216
+ else:
217
+ size_n, size_k = k, n
218
+
219
+ for i in range(e):
220
+ marlin_scales = marlin_permute_scales(s=scales[i].T,
221
+ size_k=size_k,
222
+ size_n=size_n,
223
+ group_size=16)
224
+ marlin_scales = fp4_marlin_process_scales(marlin_scales)
225
+ tensor_list.append(marlin_scales)
226
+
227
+ scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
228
+ scales = torch.nn.Parameter(scales, requires_grad=False)
229
+ setattr(layer, name + "_weight_scale", scales)
230
+
231
+ global_scale = fp4_marlin_process_global_scale(global_scale)
232
+ global_scale = torch.nn.Parameter(global_scale, requires_grad=False)
233
+ setattr(layer, name + "_weight_scale_2", global_scale)
234
+
235
+
236
+ def rand_marlin_weight_fp4_like(weight, group_size):
237
+ assert group_size > 0
238
+ size_n, size_k = weight.shape
239
+ device = weight.device
240
+
241
+ scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 6
242
+ global_scale = scales.max() / 448
243
+ scales = (scales / global_scale).to(torch.float8_e4m3fn)
244
+
245
+ fp4_weight = torch.randint(0,
246
+ 256, (size_n, size_k // 2),
247
+ dtype=torch.uint8,
248
+ device=weight.device)
249
+ fp4_weight_part_1 = ((fp4_weight & 0b10000000) |
250
+ ((fp4_weight & 0b01110000) >> 2))
251
+ fp4_weight_part_1 = fp4_weight_part_1.view(torch.float8_e4m3fn)
252
+ fp4_weight_part_1 = fp4_weight_part_1.to(weight.dtype) * (2**6)
253
+
254
+ fp4_weight2 = fp4_weight << 4
255
+ fp4_weight_part_2 = ((fp4_weight2 & 0b10000000) |
256
+ ((fp4_weight2 & 0b01110000) >> 2))
257
+ fp4_weight_part_2 = fp4_weight_part_2.view(torch.float8_e4m3fn)
258
+ fp4_weight_part_2 = fp4_weight_part_2.to(weight.dtype) * (2**6)
259
+
260
+ weight_ref = torch.cat(
261
+ [fp4_weight_part_2.unsqueeze(2),
262
+ fp4_weight_part_1.unsqueeze(2)], 2).view(size_n, size_k)
263
+ weight_ref = weight_ref * global_scale.to(weight.dtype) * \
264
+ scales.repeat_interleave(group_size, 1).to(weight.dtype)
265
+
266
+ marlin_qweight = ops.gptq_marlin_repack(
267
+ b_q_weight=fp4_weight.view(torch.int32).T.contiguous(),
268
+ perm=torch.empty(0, dtype=torch.int, device=device),
269
+ size_k=size_k,
270
+ size_n=size_n,
271
+ num_bits=4,
272
+ )
273
+
274
+ marlin_scales = marlin_permute_scales(s=scales.T.to(weight.dtype),
275
+ size_k=size_k,
276
+ size_n=size_n,
277
+ group_size=group_size)
278
+ marlin_scales = fp4_marlin_process_scales(marlin_scales)
279
+
280
+ global_scale = fp4_marlin_process_global_scale(global_scale)
281
+
282
+ return weight_ref.T, marlin_qweight, marlin_scales, global_scale
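The nibble-to-float8 trick used in rand_marlin_weight_fp4_like above (keep the sign bit, shift the three exponent/mantissa bits into the e4m3 field, then scale by 2**6) can be spot-checked on CPU. A minimal sketch, assuming a PyTorch build that exposes torch.float8_e4m3fn (true for the torch 2.6/2.7 targets here); the variable names are illustrative only:

import torch

# All 16 E2M1 code points, placed in the high nibble of a byte,
# exactly where fp4_weight_part_1 reads them from.
codes = torch.arange(16, dtype=torch.uint8) << 4
# Keep the sign bit, move the 3 exponent/mantissa bits into the e4m3 field.
as_e4m3 = (codes & 0b10000000) | ((codes & 0b01110000) >> 2)
decoded = as_e4m3.view(torch.float8_e4m3fn).to(torch.float32) * (2 ** 6)

# Nominal E2M1 values: 0, 0.5, 1, 1.5, 2, 3, 4, 6 (and their negatives).
expected = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
assert torch.equal(decoded, torch.cat([expected, -expected]))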
build/torch26-cxx98-cu126-aarch64-linux/quantization/utils/marlin_utils_fp8.py CHANGED
@@ -1,10 +1,13 @@
 
 
 
1
  from typing import Optional
2
 
3
  import torch
4
 
5
  import quantization as ops
6
 
7
- from .marlin_utils import marlin_make_workspace, marlin_permute_scales
8
 
9
 
10
  def is_fp8_marlin_supported():
@@ -13,88 +16,107 @@ def is_fp8_marlin_supported():
13
  return capability >= 80
14
 
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  def apply_fp8_marlin_linear(
17
- input: torch.Tensor,
18
- weight: torch.Tensor,
19
- weight_scale: torch.Tensor,
20
- workspace: torch.Tensor,
21
- size_n: int,
22
- size_k: int,
23
- bias: Optional[torch.Tensor],
24
- ) -> torch.Tensor:
25
  # For GPUs that lack FP8 hardware support, we can leverage the
26
  # Marlin kernel for fast weight-only FP8 quantization
27
 
28
  reshaped_x = input.reshape(-1, input.shape[-1])
29
- out_shape = input.shape[:-1] + (size_n,)
30
-
31
- output = ops.fp8_marlin_gemm(
32
- a=reshaped_x,
33
- b_q_weight=weight,
34
- b_scales=weight_scale,
35
- workspace=workspace,
36
- num_bits=8,
37
- size_m=reshaped_x.shape[0],
38
- size_n=size_n,
39
- size_k=size_k,
40
- )
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  if bias is not None:
43
  output.add_(bias) # In-place add
44
 
45
  return output.reshape(out_shape)
46
 
47
-
48
- def prepare_fp8_layer_for_marlin(
49
- layer: torch.nn.Module, strategy: str = "tensor"
50
- ) -> None:
51
- part_size_n = layer.output_size_per_partition
52
- part_size_k = layer.input_size_per_partition
53
-
54
- device = layer.weight.device
55
-
56
- # WORKSPACE
57
- layer.workspace = marlin_make_workspace(part_size_n, device)
58
-
59
- # WEIGHT
60
- # Repack weights to marlin format
61
- marlin_qweight = ops.gptq_marlin_repack(
62
- b_q_weight=pack_fp8_to_int32(layer.weight),
63
- perm=torch.empty(0, dtype=torch.int, device=device),
64
- size_k=part_size_k,
65
- size_n=part_size_n,
66
- num_bits=8,
67
- )
68
- layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
69
-
70
- # WEIGHT SCALES
71
- scales = layer.weight_scale.to(layer.orig_dtype)
72
- # Permute scales
73
- marlin_scales = marlin_permute_scales(
74
- s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1
75
- )
76
- layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
77
-
78
-
79
- def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
80
  """
81
  Repack FP8 weights to gptq format (packed int32 elements)
82
  """
83
  assert fp8_tensor.dtype == torch.float8_e4m3fn
84
- assert fp8_tensor.shape[0] % 4 == 0
85
-
86
- # Reshape to prepare for packing
87
- reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
- # Convert fp8 to uint8 (byte) representation
90
- byte_tensor = reshaped.view(torch.uint8)
 
 
91
 
92
- # Pack 4 uint8 values into one int32
93
- packed = (
94
- byte_tensor[:, 0].to(torch.int32)
95
- | (byte_tensor[:, 1].to(torch.int32) << 8)
96
- | (byte_tensor[:, 2].to(torch.int32) << 16)
97
- | (byte_tensor[:, 3].to(torch.int32) << 24)
98
- )
99
 
100
- return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous()
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+
4
  from typing import Optional
5
 
6
  import torch
7
 
8
  import quantization as ops
9
 
10
+ from .marlin_utils import USE_FP32_REDUCE_DEFAULT, marlin_make_workspace, marlin_permute_scales, should_use_atomic_add_reduce
+ from quantization.scalar_type import scalar_types
11
 
12
 
13
  def is_fp8_marlin_supported():
 
16
  return capability >= 80
17
 
18
 
19
+ def fp8_fused_exponent_bias_into_scales(scales):
20
+ fp8_exponent = 4
21
+ if scales.dtype == torch.half:
22
+ target_exponent = 5
23
+ elif scales.dtype == torch.bfloat16:
24
+ target_exponent = 8
25
+ # exponent_bias_fp16 = 2 ** 4 - 2 ** 3 = 8
26
+ # exponent_bias_bf16 = 2 ** 7 - 2 ** 3 = 120
27
+ exponent_bias = 2**(target_exponent - 1) - 2**(fp8_exponent - 1)
28
+ s = torch.ones_like(scales) * 2
29
+ s = s**exponent_bias
30
+ return scales * s
31
+
32
+
33
  def apply_fp8_marlin_linear(
34
+ input: torch.Tensor,
35
+ weight: torch.Tensor,
36
+ weight_scale: torch.Tensor,
37
+ workspace: torch.Tensor,
38
+ size_n: int,
39
+ size_k: int,
40
+ bias: Optional[torch.Tensor],
41
+ use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
42
  # For GPUs that lack FP8 hardware support, we can leverage the
43
  # Marlin kernel for fast weight-only FP8 quantization
44
 
45
  reshaped_x = input.reshape(-1, input.shape[-1])
46
+ out_shape = input.shape[:-1] + (size_n, )
47
+
48
+ use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
49
+ n=size_n,
50
+ k=size_k,
51
+ device=input.device,
52
+ dtype=input.dtype)
53
+
54
+ output = ops.gptq_marlin_gemm(a=reshaped_x,
55
+ c=None,
56
+ b_q_weight=weight,
57
+ b_scales=weight_scale,
58
+ global_scale=None,
59
+ b_zeros=None,
60
+ g_idx=None,
61
+ perm=None,
62
+ workspace=workspace,
63
+ b_q_type=scalar_types.float8_e4m3fn,
64
+ size_m=reshaped_x.size(0),
65
+ size_n=size_n,
66
+ size_k=size_k,
67
+ use_atomic_add=use_atomic_add,
68
+ use_fp32_reduce=use_fp32_reduce)
69
 
70
  if bias is not None:
71
  output.add_(bias) # In-place add
72
 
73
  return output.reshape(out_shape)
74
 
75
+ def pack_fp8_to_int32(fp8_tensor: torch.Tensor,
76
+ size_k_first: bool = True) -> torch.Tensor:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  """
78
  Repack FP8 weights to gptq format (packed int32 elements)
79
  """
80
  assert fp8_tensor.dtype == torch.float8_e4m3fn
81
+ assert fp8_tensor.ndim == 2
82
+
83
+ fp8_tensor = fp8_tensor.T if size_k_first else fp8_tensor
84
+ fp8_tensor = fp8_tensor.contiguous()
85
+ # fp8_tensor is contiguous and has shape (N, K) now
86
+ # with `.view(torch.int32)`, it becomes (N, K // 4)
87
+ int32_tensor = fp8_tensor.view(torch.int32)
88
+ return int32_tensor.T.contiguous() if size_k_first else int32_tensor
89
+
90
+
91
+ def marlin_quant_fp8_torch(weight, group_size):
92
+ size_n, size_k = weight.shape
93
+ device = weight.device
94
+
95
+ if group_size != -1:
96
+ scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 448
97
+ repeated_scales = scales.repeat_interleave(group_size, 1)
98
+ fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn)
99
+ weight_ref = fp8_weight.to(weight.dtype) * repeated_scales
100
+ else:
101
+ scales = weight.view(size_n, 1, group_size).abs().max(-1)[0] / 448
102
+ repeated_scales = scales.repeat_interleave(size_k, 1)
103
+ fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn)
104
+ weight_ref = fp8_weight.to(weight.dtype) * repeated_scales
105
+
106
+ packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous()
107
+ marlin_qweight = ops.gptq_marlin_repack(
108
+ b_q_weight=packed_weight,
109
+ perm=torch.empty(0, dtype=torch.int, device=device),
110
+ size_k=size_k,
111
+ size_n=size_n,
112
+ num_bits=8,
113
+ )
114
 
115
+ marlin_scales = marlin_permute_scales(s=scales.T,
116
+ size_k=size_k,
117
+ size_n=size_n,
118
+ group_size=group_size)
119
 
120
+ marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
 
 
 
 
 
 
121
 
122
+ return weight_ref.T, marlin_qweight, marlin_scales
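The new pack_fp8_to_int32 above relies on `.view(torch.int32)` reinterpreting four consecutive FP8 bytes as one little-endian int32, which is what the previous shift-based implementation computed explicitly. A minimal sketch of that equivalence, assuming torch.float8_e4m3fn is available; it runs on CPU and the shape below is arbitrary:

import torch

t = torch.randn(8, 16).to(torch.float8_e4m3fn)   # (N, K) with K % 4 == 0
packed_view = t.view(torch.int32)                # (N, K // 4)

# Explicit little-endian packing, mirroring the removed shift-based code.
b = t.view(torch.uint8).to(torch.int32).reshape(8, -1, 4)
packed_manual = (b[..., 0]
                 | (b[..., 1] << 8)
                 | (b[..., 2] << 16)
                 | (b[..., 3] << 24))
assert torch.equal(packed_view, packed_manual)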
build/torch27-cxx11-cu126-aarch64-linux/quantization/__init__.py CHANGED
@@ -1,12 +1,12 @@
1
  from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant
2
  from .cutlass import (
 
3
  cutlass_scaled_mm_supports_fp8,
4
  cutlass_scaled_mm,
5
  cutlass_scaled_mm_azp,
6
  )
7
  from .marlin import (
8
  awq_marlin_repack,
9
- fp8_marlin_gemm,
10
  gptq_marlin_gemm,
11
  gptq_marlin_repack,
12
  gptq_marlin_24_gemm,
@@ -25,8 +25,8 @@ __all__ = [
25
  "awq_marlin_repack",
26
  "cutlass_scaled_mm",
27
  "cutlass_scaled_mm_azp",
 
28
  "cutlass_scaled_mm_supports_fp8",
29
- "fp8_marlin_gemm",
30
  "gptq_marlin_24_gemm",
31
  "gptq_marlin_gemm",
32
  "gptq_marlin_repack",
 
1
  from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant
2
  from .cutlass import (
3
+ cutlass_scaled_mm_supports_block_fp8,
4
  cutlass_scaled_mm_supports_fp8,
5
  cutlass_scaled_mm,
6
  cutlass_scaled_mm_azp,
7
  )
8
  from .marlin import (
9
  awq_marlin_repack,
 
10
  gptq_marlin_gemm,
11
  gptq_marlin_repack,
12
  gptq_marlin_24_gemm,
 
25
  "awq_marlin_repack",
26
  "cutlass_scaled_mm",
27
  "cutlass_scaled_mm_azp",
28
+ "cutlass_scaled_mm_supports_block_fp8",
29
  "cutlass_scaled_mm_supports_fp8",
 
30
  "gptq_marlin_24_gemm",
31
  "gptq_marlin_gemm",
32
  "gptq_marlin_repack",
build/torch27-cxx11-cu126-aarch64-linux/quantization/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _quantization_0435ccb
3
- ops = torch.ops._quantization_0435ccb
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_quantization_0435ccb::{op_name}"
 
1
  import torch
2
+ from . import _quantization_9035540
3
+ ops = torch.ops._quantization_9035540
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_quantization_9035540::{op_name}"
build/torch27-cxx11-cu126-aarch64-linux/quantization/{_quantization_0435ccb.abi3.so → _quantization_9035540.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:7e8455e10805adf431198b60afbdbc1c7d79e65a67aab2a501ef9fe822484f3c
3
- size 67890208
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:219fc94b48e46777769dd2cd61785791b4fd00c58824d6de5252defbf48c30e5
3
+ size 159999608
build/torch27-cxx11-cu126-aarch64-linux/quantization/compressed_tensors.py CHANGED
@@ -2,17 +2,7 @@ from typing import Optional, Tuple
2
 
3
  import torch
4
 
5
- try:
6
- from ._ops import ops
7
- except ImportError as e:
8
- # Fallback for local development.
9
- try:
10
- import _quantization
11
-
12
- ops = torch.ops._quantization
13
- except ImportError:
14
- raise e
15
-
16
 
17
  # fp8
18
  def scaled_fp8_quant(
@@ -21,7 +11,8 @@ def scaled_fp8_quant(
21
  num_token_padding: Optional[int] = None,
22
  scale_ub: Optional[torch.Tensor] = None,
23
  use_per_token_if_dynamic: bool = False,
24
- ) -> Tuple[torch.Tensor, torch.Tensor]:
 
25
  """
26
  Quantize input tensor to FP8 and return quantized tensor and scale.
27
 
@@ -42,30 +33,36 @@ def scaled_fp8_quant(
42
  in the dynamic quantization case.
43
 
44
  Returns:
45
- Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
46
  scaling factor.
47
  """
48
  # This code assumes batch_dim and num_tokens are flattened
49
- assert input.ndim == 2
50
- shape: Union[Tuple[int, int], torch.Size] = input.shape
51
- # For rocm, the output fp8 dtype is torch.float_e3m3fnuz
52
- # out_dtype: torch.dtype = torch.float8_e4m3fnuz \
53
- # if current_platform.is_rocm() else torch.float8_e4m3fn
54
- out_dtype = torch.float8_e4m3fn
55
  if num_token_padding:
56
  shape = (max(num_token_padding, input.shape[0]), shape[1])
57
- output = torch.empty(shape, device=input.device, dtype=out_dtype)
 
 
 
 
 
58
 
59
  if scale is None:
60
  if use_per_token_if_dynamic:
61
- scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
62
- ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub)
 
 
 
63
  else:
64
  scale = torch.zeros(1, device=input.device, dtype=torch.float32)
65
  ops.dynamic_scaled_fp8_quant(output, input, scale)
66
  else:
67
  # num_token_padding not implemented for this case
68
- assert scale.numel() == 1 or num_token_padding is None
69
  ops.static_scaled_fp8_quant(output, input, scale)
70
 
71
  return output, scale
@@ -76,8 +73,8 @@ def scaled_int8_quant(
76
  input: torch.Tensor,
77
  scale: Optional[torch.Tensor] = None,
78
  azp: Optional[torch.Tensor] = None,
79
- symmetric: bool = True,
80
- ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
81
  """
82
  Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.
83
 
@@ -90,21 +87,25 @@ def scaled_int8_quant(
90
  symmetric: Whether to use symmetric quantization (scale only, azp ignored).
91
 
92
  Returns:
93
- Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
94
  """
95
  output = torch.empty_like(input, dtype=torch.int8)
96
  if scale is not None:
97
  # static-per-tensor quantization.
98
  assert symmetric == (
99
- azp is None
100
- ), "azp must only be provided for asymmetric quantization."
101
  ops.static_scaled_int8_quant(output, input, scale, azp)
102
  return output, scale, azp
103
 
104
  # dynamic-per-token quantization.
105
- input_scales = torch.empty(
106
- (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32
107
- )
108
- input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32)
109
- ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp)
 
 
110
  return output, input_scales, input_azp
 
 
 
2
 
3
  import torch
4
 
5
+ from ._ops import ops
 
 
 
 
 
 
 
 
 
 
6
 
7
  # fp8
8
  def scaled_fp8_quant(
 
11
  num_token_padding: Optional[int] = None,
12
  scale_ub: Optional[torch.Tensor] = None,
13
  use_per_token_if_dynamic: bool = False,
14
+ output: Optional[torch.Tensor] = None,
15
+ ) -> tuple[torch.Tensor, torch.Tensor]:
16
  """
17
  Quantize input tensor to FP8 and return quantized tensor and scale.
18
 
 
33
  in the dynamic quantization case.
34
 
35
  Returns:
36
+ tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
37
  scaling factor.
38
  """
39
  # This code assumes batch_dim and num_tokens are flattened
40
+ assert (input.ndim == 2)
41
+ shape: Union[tuple[int, int], torch.Size] = input.shape
42
+ # For ROCm on MI300, the output fp8 dtype is torch.float8_e4m3fnuz
43
+ out_dtype: torch.dtype = current_platform.fp8_dtype()
 
 
44
  if num_token_padding:
45
  shape = (max(num_token_padding, input.shape[0]), shape[1])
46
+ if output is None:
47
+ output = torch.empty(shape, device=input.device, dtype=out_dtype)
48
+ else:
49
+ assert num_token_padding is None, \
50
+ "padding not supported if output passed in"
51
+ assert output.dtype == out_dtype
52
 
53
  if scale is None:
54
  if use_per_token_if_dynamic:
55
+ scale = torch.empty((shape[0], 1),
56
+ device=input.device,
57
+ dtype=torch.float32)
58
+ ops.dynamic_per_token_scaled_fp8_quant(
59
+ output, input.contiguous(), scale, scale_ub)
60
  else:
61
  scale = torch.zeros(1, device=input.device, dtype=torch.float32)
62
  ops.dynamic_scaled_fp8_quant(output, input, scale)
63
  else:
64
  # num_token_padding not implemented for this case
65
+ assert (scale.numel() == 1 and num_token_padding is None)
66
  ops.static_scaled_fp8_quant(output, input, scale)
67
 
68
  return output, scale
 
73
  input: torch.Tensor,
74
  scale: Optional[torch.Tensor] = None,
75
  azp: Optional[torch.Tensor] = None,
76
+ symmetric: bool = True
77
+ ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
78
  """
79
  Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.
80
 
 
87
  symmetric: Whether to use symmetric quantization (scale only, azp ignored).
88
 
89
  Returns:
90
+ tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
91
  """
92
  output = torch.empty_like(input, dtype=torch.int8)
93
  if scale is not None:
94
  # static-per-tensor quantization.
95
  assert symmetric == (
96
+ azp
97
+ is None), "azp must only be provided for asymmetric quantization."
98
  ops.static_scaled_int8_quant(output, input, scale, azp)
99
  return output, scale, azp
100
 
101
  # dynamic-per-token quantization.
102
+ input_scales = torch.empty((input.numel() // input.shape[-1], 1),
103
+ device=input.device,
104
+ dtype=torch.float32)
105
+ input_azp = None if symmetric else torch.empty_like(input_scales,
106
+ dtype=torch.int32)
107
+ ops.dynamic_scaled_int8_quant(output, input.contiguous(),
108
+ input_scales, input_azp)
109
  return output, input_scales, input_azp
110
+
111
+
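For the dynamic (scale=None) path of scaled_int8_quant above, the returned scales have one entry per token. A rough pure-PyTorch sketch of what symmetric per-token int8 quantization computes (row-wise scale = absmax / 127); the CUDA kernel's exact rounding may differ, and per_token_int8_reference is an illustrative name, not part of the package:

import torch

def per_token_int8_reference(x: torch.Tensor):
    # One scale per token (row); quantized values stay within the int8 range.
    scales = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp((x / scales).round(), -128, 127).to(torch.int8)
    return q, scales

x = torch.randn(4, 64)
q, scales = per_token_int8_reference(x)
assert q.dtype == torch.int8 and scales.shape == (4, 1)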
build/torch27-cxx11-cu126-aarch64-linux/quantization/cutlass.py CHANGED
@@ -2,22 +2,18 @@ from typing import Optional
2
 
3
  import torch
4
 
5
- try:
6
- from ._ops import ops
7
- except ImportError as e:
8
- # Fallback for local development.
9
- try:
10
- import _quantization
11
-
12
- ops = torch.ops._quantization
13
- except ImportError:
14
- raise e
15
 
16
 
17
  def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
18
  return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
19
 
20
 
 
 
 
 
21
  def cutlass_scaled_mm(
22
  a: torch.Tensor,
23
  b: torch.Tensor,
@@ -33,12 +29,10 @@ def cutlass_scaled_mm(
33
  m = a.shape[0]
34
  n = b.shape[1]
35
 
36
- # if current_platform.is_rocm():
37
- # triton_scaled_mm_module = importlib.import_module(
38
- # "vllm.model_executor.layers.quantization.compressed_tensors."
39
- # "triton_scaled_mm")
40
- # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
41
- # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
42
 
43
  out = torch.empty((m, n), dtype=out_dtype, device=a.device)
44
 
 
2
 
3
  import torch
4
 
5
+ from ._ops import ops
6
+ from .platforms import current_platform
 
 
 
 
 
 
 
 
7
 
8
 
9
  def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
10
  return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
11
 
12
 
13
+ def cutlass_scaled_mm_supports_block_fp8(cuda_device_capability: int) -> bool:
14
+ return ops.cutlass_scaled_mm_supports_block_fp8(cuda_device_capability)
15
+
16
+
17
  def cutlass_scaled_mm(
18
  a: torch.Tensor,
19
  b: torch.Tensor,
 
29
  m = a.shape[0]
30
  n = b.shape[1]
31
 
32
+ cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
33
+ if not cutlass_compatible_b:
34
+ from .triton_scaled_mm import triton_scaled_mm
35
+ return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
 
 
36
 
37
  out = torch.empty((m, n), dtype=out_dtype, device=a.device)
38
 
build/torch27-cxx11-cu126-aarch64-linux/quantization/marlin.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import TYPE_CHECKING
2
 
3
  import torch
4
 
@@ -30,58 +30,30 @@ except ImportError as e:
30
  from .scalar_type import ScalarType
31
 
32
 
33
- # fp8 marlin
34
- def fp8_marlin_gemm(
35
- a: torch.Tensor,
36
- b_q_weight: torch.Tensor,
37
- b_scales: torch.Tensor,
38
- workspace: torch.Tensor,
39
- num_bits: int,
40
- size_m: int,
41
- size_n: int,
42
- size_k: int,
43
- ) -> torch.Tensor:
44
- return ops.fp8_marlin_gemm(
45
- a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k
46
- )
47
-
48
-
49
  # gptq_marlin
50
- def gptq_marlin_gemm(
51
- a: torch.Tensor,
52
- b_q_weight: torch.Tensor,
53
- b_scales: torch.Tensor,
54
- b_zeros: torch.Tensor,
55
- g_idx: torch.Tensor,
56
- perm: torch.Tensor,
57
- workspace: torch.Tensor,
58
- b_q_type: ScalarType,
59
- size_m: int,
60
- size_n: int,
61
- size_k: int,
62
- is_k_full: bool,
63
- has_zp: bool = False,
64
- use_fp32_reduce: bool = False,
65
- is_zp_float: bool = False,
66
- ) -> torch.Tensor:
67
- return ops.gptq_marlin_gemm(
68
- a,
69
- b_q_weight,
70
- b_scales,
71
- b_zeros,
72
- g_idx,
73
- perm,
74
- workspace,
75
- b_q_type.id,
76
- size_m,
77
- size_n,
78
- size_k,
79
- is_k_full,
80
- has_zp,
81
- use_fp32_reduce,
82
- is_zp_float,
83
- )
84
-
85
 
86
  # gptq_marlin
87
  def gptq_marlin_repack(
@@ -153,14 +125,6 @@ def marlin_qqq_gemm(
153
  # Fake ops
154
 
155
  if hasattr(ops, "gptq_marlin_24_gemm"):
156
- @register_fake(add_op_namespace_prefix("fp8_marlin_gemm"))
157
- def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
158
- b_scales: torch.Tensor, workspace: torch.Tensor,
159
- num_bits: int, size_m: torch.SymInt,
160
- size_n: torch.SymInt,
161
- size_k: torch.SymInt) -> torch.Tensor:
162
- return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)
163
-
164
  @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm"))
165
  def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
166
  b_meta: torch.Tensor, b_scales: torch.Tensor,
@@ -172,20 +136,22 @@ if hasattr(ops, "gptq_marlin_24_gemm"):
172
 
173
  @register_fake(add_op_namespace_prefix("gptq_marlin_gemm"))
174
  def _gptq_marlin_gemm_fake(a: torch.Tensor,
175
- b_q_weight: torch.Tensor,
176
- b_scales: torch.Tensor,
177
- b_zeros: torch.Tensor,
178
- g_idx: torch.Tensor,
179
- perm: torch.Tensor,
180
- workspace: torch.Tensor,
181
- b_q_type: ScalarType,
182
- size_m: torch.SymInt,
183
- size_n: torch.SymInt,
184
- size_k: torch.SymInt,
185
- is_k_full: bool,
186
- has_zp: bool = False,
187
- use_fp32_reduce: bool = False,
188
- is_zp_float: bool = False) -> torch.Tensor:
 
 
189
  return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
190
 
191
  @register_fake(add_op_namespace_prefix("marlin_qqq_gemm"))
 
1
+ from typing import TYPE_CHECKING, Optional
2
 
3
  import torch
4
 
 
30
  from .scalar_type import ScalarType
31
 
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  # gptq_marlin
34
+ def gptq_marlin_gemm(a: torch.Tensor,
35
+ c: Optional[torch.Tensor],
36
+ b_q_weight: torch.Tensor,
37
+ b_scales: torch.Tensor,
38
+ global_scale: Optional[torch.Tensor],
39
+ b_zeros: Optional[torch.Tensor],
40
+ g_idx: Optional[torch.Tensor],
41
+ perm: Optional[torch.Tensor],
42
+ workspace: torch.Tensor,
43
+ b_q_type: ScalarType,
44
+ size_m: int,
45
+ size_n: int,
46
+ size_k: int,
47
+ is_k_full: bool = True,
48
+ use_atomic_add: bool = False,
49
+ use_fp32_reduce: bool = False,
50
+ is_zp_float: bool = False) -> torch.Tensor:
51
+ return ops.gptq_marlin_gemm(a, c, b_q_weight, b_scales,
52
+ global_scale, b_zeros, g_idx, perm,
53
+ workspace, b_q_type.id, size_m,
54
+ size_n, size_k, is_k_full,
55
+ use_atomic_add, use_fp32_reduce,
56
+ is_zp_float)
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  # gptq_marlin
59
  def gptq_marlin_repack(
 
125
  # Fake ops
126
 
127
  if hasattr(ops, "gptq_marlin_24_gemm"):
 
 
 
 
 
 
 
 
128
  @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm"))
129
  def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
130
  b_meta: torch.Tensor, b_scales: torch.Tensor,
 
136
 
137
  @register_fake(add_op_namespace_prefix("gptq_marlin_gemm"))
138
  def _gptq_marlin_gemm_fake(a: torch.Tensor,
139
+ c: Optional[torch.Tensor],
140
+ b_q_weight: torch.Tensor,
141
+ b_scales: torch.Tensor,
142
+ global_scale: Optional[torch.Tensor],
143
+ b_zeros: Optional[torch.Tensor],
144
+ g_idx: Optional[torch.Tensor],
145
+ perm: Optional[torch.Tensor],
146
+ workspace: torch.Tensor,
147
+ b_q_type_id: int,
148
+ size_m: torch.SymInt,
149
+ size_n: torch.SymInt,
150
+ size_k: torch.SymInt,
151
+ is_k_full: bool = True,
152
+ use_atomic_add: bool = False,
153
+ use_fp32_reduce: bool = False,
154
+ is_zp_float: bool = False) -> torch.Tensor:
155
  return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
156
 
157
  @register_fake(add_op_namespace_prefix("marlin_qqq_gemm"))
build/torch27-cxx11-cu126-aarch64-linux/quantization/platforms.py ADDED
@@ -0,0 +1,69 @@
1
+ from abc import ABC, abstractmethod
2
+ from functools import lru_cache
3
+ from typing import NamedTuple
4
+
5
+ import torch
6
+
7
+ IS_ROCM = torch.version.hip is not None
8
+
9
+
10
+ class DeviceCapability(NamedTuple):
11
+ major: int
12
+ minor: int
13
+
14
+ def as_version_str(self) -> str:
15
+ return f"{self.major}.{self.minor}"
16
+
17
+ def to_int(self) -> int:
18
+ """
19
+ Express device capability as an integer ``<major><minor>``.
20
+
21
+ It is assumed that the minor version is always a single digit.
22
+ """
23
+ assert 0 <= self.minor < 10
24
+ return self.major * 10 + self.minor
25
+
26
+
27
+ class Platform(ABC):
28
+ simple_compile_backend: str = "inductor"
29
+
30
+ @classmethod
31
+ @abstractmethod
32
+ def get_device_name(cls, device_id: int = 0) -> str: ...
33
+
34
+ @abstractmethod
35
+ def is_rocm(self): ...
36
+
37
+
38
+ class CudaPlatform(Platform):
39
+ @classmethod
40
+ def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
41
+ major, minor = torch.cuda.get_device_capability(device_id)
42
+ return DeviceCapability(major=major, minor=minor)
43
+
44
+ @classmethod
45
+ @lru_cache(maxsize=8)
46
+ def get_device_name(cls, device_id: int = 0) -> str:
47
+ return torch.cuda.get_device_name(0)
48
+
49
+ def is_rocm(self):
50
+ return False
51
+
52
+
53
+ class RocmPlatform(Platform):
54
+ @classmethod
55
+ @lru_cache(maxsize=8)
56
+ def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
57
+ major, minor = torch.cuda.get_device_capability(device_id)
58
+ return DeviceCapability(major=major, minor=minor)
59
+
60
+ @classmethod
61
+ @lru_cache(maxsize=8)
62
+ def get_device_name(cls, device_id: int = 0) -> str:
63
+ return torch.cuda.get_device_name(device_id)
64
+
65
+ def is_rocm(self):
66
+ return True
67
+
68
+
69
+ current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
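A standalone sketch of what the helpers above compute: ROCm is detected via torch.version.hip, and DeviceCapability.to_int() folds (major, minor) into a single integer such as 90 for SM90. It only prints when a GPU is visible:

import torch

is_rocm = torch.version.hip is not None
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print("rocm" if is_rocm else "cuda", major * 10 + minor)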
build/torch27-cxx11-cu126-aarch64-linux/quantization/scalar_type.py CHANGED
@@ -1,9 +1,14 @@
 
 
 
1
  import functools
2
  import struct
3
  from dataclasses import dataclass
4
  from enum import Enum
5
  from typing import Optional, Union
6
 
 
 
7
 
8
  # Mirrors enum in `core/scalar_type.hpp`
9
  class NanRepr(Enum):
@@ -121,8 +126,8 @@ class ScalarType:
121
  min_raw = max_raw | sign_bit_double
122
  return struct.unpack('!d', struct.pack('!Q', min_raw))[0]
123
  else:
124
- assert (not self.is_signed() or
125
- self.size_bits <= 64), "Cannot represent min as a int64_t"
126
 
127
  if self.is_signed():
128
  return -(1 << (self.size_bits - 1))
@@ -156,6 +161,8 @@ class ScalarType:
156
  assert offset <= 64, \
157
  f"ScalarType fields too big {offset} to fit into an int64"
158
 
 
 
159
  return val
160
 
161
  @property
@@ -293,6 +300,13 @@ class ScalarType:
293
  ret.id # noqa B018: make sure the id is cached
294
  return ret
295
 
 
 
 
 
 
 
 
296
 
297
  # naming generally follows: https://github.com/jax-ml/ml_dtypes
298
  # for floating point types (leading f) the scheme is:
@@ -319,6 +333,9 @@ class scalar_types:
319
  # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
320
  float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)
321
 
 
 
 
322
  # "gptq" types
323
  uint2b2 = ScalarType.uint(2, 2)
324
  uint3b4 = ScalarType.uint(3, 4)
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+
4
  import functools
5
  import struct
6
  from dataclasses import dataclass
7
  from enum import Enum
8
  from typing import Optional, Union
9
 
10
+ _SCALAR_TYPES_ID_MAP = {}
11
+
12
 
13
  # Mirrors enum in `core/scalar_type.hpp`
14
  class NanRepr(Enum):
 
126
  min_raw = max_raw | sign_bit_double
127
  return struct.unpack('!d', struct.pack('!Q', min_raw))[0]
128
  else:
129
+ assert (not self.is_signed() or self.size_bits
130
+ <= 64), "Cannot represent min as a int64_t"
131
 
132
  if self.is_signed():
133
  return -(1 << (self.size_bits - 1))
 
161
  assert offset <= 64, \
162
  f"ScalarType fields too big {offset} to fit into an int64"
163
 
164
+ _SCALAR_TYPES_ID_MAP[val] = self
165
+
166
  return val
167
 
168
  @property
 
300
  ret.id # noqa B018: make sure the id is cached
301
  return ret
302
 
303
+ @classmethod
304
+ def from_id(cls, scalar_type_id: int):
305
+ if scalar_type_id not in _SCALAR_TYPES_ID_MAP:
306
+ raise ValueError(
307
+ f"scalar_type_id {scalar_type_id} doesn't exists.")
308
+ return _SCALAR_TYPES_ID_MAP[scalar_type_id]
309
+
310
 
311
  # naming generally follows: https://github.com/jax-ml/ml_dtypes
312
  # for floating point types (leading f) the scheme is:
 
333
  # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
334
  float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)
335
 
336
+ # fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
337
+ float4_e2m1f = ScalarType.float_(2, 1, True, NanRepr.NONE)
338
+
339
  # "gptq" types
340
  uint2b2 = ScalarType.uint(2, 2)
341
  uint3b4 = ScalarType.uint(3, 4)
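A small usage sketch of the new from_id lookup, assuming the built quantization package is importable: accessing .id caches the instance in _SCALAR_TYPES_ID_MAP, so the integer id maps back to an equal ScalarType.

from quantization.scalar_type import ScalarType, scalar_types

t = scalar_types.float4_e2m1f
assert ScalarType.from_id(t.id) == t   # the cached id round-trips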
build/torch27-cxx11-cu126-aarch64-linux/quantization/utils/marlin_utils.py CHANGED
@@ -1,4 +1,7 @@
1
- from typing import List, Optional, Tuple
 
 
 
2
 
3
  import numpy
4
  import torch
@@ -42,7 +45,9 @@ USE_FP32_REDUCE_DEFAULT = True
42
  # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
43
  # TODO: we may want to move this into the C++ so its closer to the actual impl
44
  def query_marlin_supported_quant_types(
45
- has_zp: bool, device_capability: Optional[int] = None
 
 
46
  ):
47
  if device_capability is None:
48
  capability_tuple = torch.cuda.get_device_capability()
@@ -51,137 +56,141 @@ def query_marlin_supported_quant_types(
51
  if device_capability < 80:
52
  return []
53
 
 
 
 
 
 
 
 
 
 
 
54
  if has_zp:
55
  # AWQ style, unsigned + runtime zero-point
56
- return [scalar_types.uint4, scalar_types.uint8]
57
  else:
58
  # GPTQ style, unsigned + symmetric bias
59
- # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able
60
- # to add `scalar_types.float8_e4m3fn` here
61
- return [scalar_types.uint4b8, scalar_types.uint8b128]
 
62
 
63
 
64
  def _check_marlin_supported(
65
- quant_type: ScalarType,
66
- group_size: Optional[int],
67
- has_zp: bool,
68
- device_capability: Optional[int] = None,
69
- ) -> Tuple[bool, Optional[str]]:
70
 
71
  if device_capability is None:
72
  capability_tuple = torch.cuda.get_device_capability()
73
  device_capability = capability_tuple[0] * 10 + capability_tuple[1]
74
 
75
- supported_types = query_marlin_supported_quant_types(has_zp, device_capability)
 
76
 
77
  if quant_type not in supported_types:
78
- return (
79
- False,
80
- f"Marlin does not support weight_bits = {quant_type}. "
81
- f"Only types = {supported_types} "
82
- f"are supported (for group_size = {group_size}, "
83
- f"device_capability = {device_capability}, zp = {has_zp}).",
84
- )
85
- if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
86
- return (
87
- False,
88
- f"Marlin does not support group_size = {group_size}. "
89
- f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
90
- "are supported.",
91
- )
92
 
93
  return True, None
94
 
95
 
96
- def check_marlin_supported(
97
- quant_type: ScalarType,
98
- group_size: int,
99
- has_zp: bool = False,
100
- device_capability: Optional[int] = None,
101
- ) -> bool:
102
- cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability)
103
  return cond
104
 
105
 
106
- def verify_marlin_supported(
107
- quant_type: ScalarType, group_size: int, has_zp: bool = False
108
- ) -> None:
109
  cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp)
110
  if not cond:
111
  assert err_msg is not None
112
  raise ValueError(err_msg)
113
 
114
 
115
- def verify_marlin_supports_shape(
116
- output_size_per_partition: int,
117
- input_size_per_partition: int,
118
- input_size: int,
119
- group_size: int,
120
- ) -> None:
121
 
122
  # Validate output_size_per_partition
123
  if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
124
- raise ValueError(
125
- f"Weight output_size_per_partition = "
126
- f"{output_size_per_partition} is not divisible by "
127
- f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
128
- "Consider reducing tensor_parallel_size or running "
129
- "with --quantization gptq."
130
- )
131
 
132
  # Validate input_size_per_partition
133
  if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
134
- raise ValueError(
135
- f"Weight input_size_per_partition = "
136
- f"{input_size_per_partition} is not divisible "
137
- f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
138
- "Consider reducing tensor_parallel_size or running "
139
- "with --quantization gptq."
140
- )
141
-
142
- if group_size < input_size and input_size_per_partition % group_size != 0:
143
  raise ValueError(
144
  f"Weight input_size_per_partition = {input_size_per_partition}"
145
- f" is not divisible by group_size = {group_size}."
146
  "Consider reducing tensor_parallel_size or running "
147
- "with --quantization gptq."
148
- )
149
 
150
 
151
- def check_marlin_supports_shape(
152
- output_size_per_partition: int,
153
- input_size_per_partition: int,
154
- input_size: int,
155
- group_size: int,
156
- ) -> Tuple[bool, Optional[str]]:
157
  try:
158
- verify_marlin_supports_shape(
159
- output_size_per_partition, input_size_per_partition, input_size, group_size
160
- )
161
  except ValueError as e:
162
  return False, e.__str__()
163
  return True, None
164
 
165
 
166
- def marlin_make_workspace(
167
- output_size_per_partition: int, device: torch.device
168
- ) -> torch.Tensor:
169
- max_workspace_size = (
170
- output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N
171
- ) * GPTQ_MARLIN_MAX_PARALLEL
172
 
173
- return torch.zeros(
174
- max_workspace_size, dtype=torch.int, device=device, requires_grad=False
175
- )
 
 
 
 
 
 
 
 
 
 
 
 
176
 
177
 
178
  def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
179
  return (not act_order) or (act_order and not is_row_parallel)
180
 
181
 
182
- def marlin_repeat_scales_on_all_ranks(
183
- act_order: bool, group_size: int, is_row_parallel: bool
184
- ) -> bool:
185
  # Need to repeat scales on every rank if act_ordering or
186
  # channelwise and RowParallelLinear
187
  is_channelwise = group_size == -1
@@ -189,35 +198,34 @@ def marlin_repeat_scales_on_all_ranks(
189
 
190
 
191
  def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
192
- return torch.nn.Parameter(
193
- torch.empty(0, dtype=torch.int, device=device), requires_grad=False
194
- )
195
 
196
 
197
  def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
198
- return torch.nn.Parameter(
199
- torch.empty(0, dtype=torch.int, device=device), requires_grad=False
200
- )
201
 
202
 
203
- def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
 
204
  g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
205
  return g_idx[g_idx_sort_indices], g_idx_sort_indices
206
 
207
 
208
  def get_scale_perms():
209
- scale_perm: List[int] = []
210
  for i in range(8):
211
  scale_perm.extend([i + 8 * j for j in range(8)])
212
- scale_perm_single: List[int] = []
213
  for i in range(4):
214
- scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
 
215
  return scale_perm, scale_perm_single
216
 
217
 
218
- def marlin_permute_scales(
219
- s: torch.Tensor, size_k: int, size_n: int, group_size: int
220
- ) -> torch.Tensor:
221
 
222
  scale_perm, scale_perm_single = get_scale_perms()
223
  if group_size < size_k and group_size != -1:
@@ -247,9 +255,8 @@ def marlin_moe_permute_scales(
247
  return output
248
 
249
 
250
- def marlin_zero_points(
251
- zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
252
- ) -> torch.Tensor:
253
  # Permute zero-points in a similar way to scales, but do not use the
254
  # "single" permutation, since zero-points are applied on every MMA
255
  scale_perm, _ = get_scale_perms()
@@ -270,9 +277,8 @@ def marlin_zero_points(
270
  return zp
271
 
272
 
273
- def awq_to_marlin_zero_points(
274
- q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
275
- ) -> torch.Tensor:
276
  # AWQ zero-points are quantized and packed on the column dim.
277
  # In addition, the values are permuted based on dequantizer.
278
  # Here we undo both of these, and then apply marlin permutation
@@ -294,9 +300,8 @@ def awq_to_marlin_zero_points(
294
  return marlin_zp
295
 
296
 
297
- def moe_awq_to_marlin_zero_points(
298
- q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
299
- ):
300
  num_experts = q_zp_packed.shape[0]
301
  output = torch.empty(
302
  (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]),
@@ -304,45 +309,97 @@ def moe_awq_to_marlin_zero_points(
304
  dtype=q_zp_packed.dtype,
305
  )
306
  for e in range(num_experts):
307
- output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits)
 
308
  return output
309
 
310
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
  def apply_gptq_marlin_linear(
312
- input: torch.Tensor,
313
- weight: torch.Tensor,
314
- weight_scale: torch.Tensor,
315
- weight_zp: torch.Tensor,
316
- g_idx: torch.Tensor,
317
- g_idx_sort_indices: torch.Tensor,
318
- workspace: torch.Tensor,
319
- wtype: ScalarType,
320
- output_size_per_partition: int,
321
- input_size_per_partition: int,
322
- is_k_full: bool,
323
- bias: Optional[torch.Tensor] = None,
324
- use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
325
- ) -> torch.Tensor:
326
  reshaped_x = input.reshape(-1, input.shape[-1])
327
- out_shape = input.shape[:-1] + (output_size_per_partition,)
328
-
329
- output = ops.gptq_marlin_gemm(
330
- reshaped_x,
331
- weight,
332
- weight_scale,
333
- weight_zp,
334
- g_idx,
335
- g_idx_sort_indices,
336
- workspace,
337
- wtype,
338
- size_m=reshaped_x.shape[0],
339
- size_n=output_size_per_partition,
340
- size_k=input_size_per_partition,
341
- is_k_full=is_k_full,
342
- has_zp=False,
343
- use_fp32_reduce=use_fp32_reduce,
344
- is_zp_float=False,
345
- )
 
 
 
 
 
 
346
 
347
  if bias is not None:
348
  output.add_(bias) # In-place add
@@ -351,39 +408,43 @@ def apply_gptq_marlin_linear(
351
 
352
 
353
  def apply_awq_marlin_linear(
354
- input: torch.Tensor,
355
- weight: torch.Tensor,
356
- weight_scale: torch.Tensor,
357
- weight_zp: torch.Tensor,
358
- g_idx: torch.Tensor,
359
- g_idx_sort_indices: torch.Tensor,
360
- workspace: torch.Tensor,
361
- quant_type: ScalarType,
362
- output_size_per_partition: int,
363
- input_size_per_partition: int,
364
- bias: Optional[torch.Tensor] = None,
365
- use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
366
- ) -> torch.Tensor:
367
  reshaped_x = input.reshape(-1, input.shape[-1])
368
- out_shape = input.shape[:-1] + (output_size_per_partition,)
369
-
370
- output = ops.gptq_marlin_gemm(
371
- reshaped_x,
372
- weight,
373
- weight_scale,
374
- weight_zp,
375
- g_idx,
376
- g_idx_sort_indices,
377
- workspace,
378
- quant_type,
379
- size_m=reshaped_x.shape[0],
380
- size_n=output_size_per_partition,
381
- size_k=input_size_per_partition,
382
- is_k_full=True,
383
- has_zp=True,
384
- use_fp32_reduce=use_fp32_reduce,
385
- is_zp_float=False,
386
- )
 
 
 
 
 
387
 
388
  if bias is not None:
389
  output.add_(bias) # In-place add
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+
4
+ from typing import Optional
5
 
6
  import numpy
7
  import torch
 
45
  # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
46
  # TODO: we may want to move this into the C++ so its closer to the actual impl
47
  def query_marlin_supported_quant_types(
48
+ has_zp: Optional[bool] = None,
49
+ include_fp_type: bool = True,
50
+ device_capability: Optional[int] = None,
51
  ):
52
  if device_capability is None:
53
  capability_tuple = torch.cuda.get_device_capability()
 
56
  if device_capability < 80:
57
  return []
58
 
59
+ # - has_zp is True: return quant_types that have zero points
60
+ # - has_zp is False: return quant_types that do not have zero points
61
+ # - has_zp is None: both
62
+ if has_zp is None:
63
+ types0 = query_marlin_supported_quant_types(False, include_fp_type,
64
+ device_capability)
65
+ types1 = query_marlin_supported_quant_types(True, include_fp_type,
66
+ device_capability)
67
+ return types0 + types1
68
+
69
  if has_zp:
70
  # AWQ style, unsigned + runtime zero-point
71
+ return [scalar_types.uint4]
72
  else:
73
  # GPTQ style, unsigned + symmetric bias
74
+ res = [scalar_types.uint4b8, scalar_types.uint8b128]
75
+ if include_fp_type:
76
+ res += [scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f]
77
+ return res
78
 
79
 
80
  def _check_marlin_supported(
81
+ quant_type: ScalarType,
82
+ group_size: Optional[int],
83
+ has_zp: bool,
84
+ device_capability: Optional[int] = None) -> tuple[bool, Optional[str]]:
 
85
 
86
  if device_capability is None:
87
  capability_tuple = torch.cuda.get_device_capability()
88
  device_capability = capability_tuple[0] * 10 + capability_tuple[1]
89
 
90
+ supported_types = query_marlin_supported_quant_types(
91
+ has_zp, True, device_capability)
92
 
93
  if quant_type not in supported_types:
94
+ return (False, f"Marlin does not support weight_bits = {quant_type}. "
95
+ f"Only types = {supported_types} "
96
+ f"are supported (for group_size = {group_size}, "
97
+ f"device_capability = {device_capability}, zp = {has_zp}).")
98
+ if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES):
99
+ return (False, f"Marlin does not support group_size = {group_size}. "
100
+ f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
101
+ "are supported.")
 
 
 
 
 
 
102
 
103
  return True, None
104
 
105
 
106
+ def check_marlin_supported(quant_type: ScalarType,
107
+ group_size: int,
108
+ has_zp: bool = False,
109
+ device_capability: Optional[int] = None) -> bool:
110
+ cond, _ = _check_marlin_supported(quant_type, group_size, has_zp,
111
+ device_capability)
 
112
  return cond
113
 
114
 
115
+ def verify_marlin_supported(quant_type: ScalarType,
116
+ group_size: int,
117
+ has_zp: bool = False) -> None:
118
  cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp)
119
  if not cond:
120
  assert err_msg is not None
121
  raise ValueError(err_msg)
122
 
123
 
124
+ def verify_marlin_supports_shape(output_size_per_partition: int,
125
+ input_size_per_partition: int,
126
+ input_size: int, group_size: int) -> None:
 
 
 
127
 
128
  # Validate output_size_per_partition
129
  if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
130
+ raise ValueError(f"Weight output_size_per_partition = "
131
+ f"{output_size_per_partition} is not divisible by "
132
+ f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
133
+ "Consider reducing tensor_parallel_size or running "
134
+ "with --quantization gptq.")
 
 
135
 
136
  # Validate input_size_per_partition
137
  if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
138
+ raise ValueError(f"Weight input_size_per_partition = "
139
+ f"{input_size_per_partition} is not divisible "
140
+ f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
141
+ "Consider reducing tensor_parallel_size or running "
142
+ "with --quantization gptq.")
143
+
144
+ if (group_size < input_size
145
+ and input_size_per_partition % group_size != 0):
 
146
  raise ValueError(
147
  f"Weight input_size_per_partition = {input_size_per_partition}"
148
+ f" is not divisible by group_size = {group_size}. "
149
  "Consider reducing tensor_parallel_size or running "
150
+ "with --quantization gptq.")
 
151
 
152
 
153
+ def check_marlin_supports_shape(output_size_per_partition: int,
154
+ input_size_per_partition: int,
155
+ input_size: int, group_size: int) \
156
+ -> tuple[bool, Optional[str]]:
 
 
157
  try:
158
+ verify_marlin_supports_shape(output_size_per_partition,
159
+ input_size_per_partition, input_size,
160
+ group_size)
161
  except ValueError as e:
162
  return False, e.__str__()
163
  return True, None
164
 
165
 
166
+ def marlin_make_workspace(output_size_per_partition: int,
167
+ device: torch.device) -> torch.Tensor:
168
+ max_workspace_size = (output_size_per_partition //
169
+ GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL
 
 
170
 
171
+ return torch.zeros(max_workspace_size,
172
+ dtype=torch.int,
173
+ device=device,
174
+ requires_grad=False)
175
+
176
+
177
+ def marlin_make_workspace_new(device: torch.device,
178
+ max_blocks_per_sm: int = 1) -> torch.Tensor:
179
+ # In the new marlin kernel, we use the num of threadblocks as workspace
180
+ # size. The num of threadblocks is is sms_count * max_blocks_per_sm.
181
+ sms = torch.cuda.get_device_properties(device).multi_processor_count
182
+ return torch.zeros(sms * max_blocks_per_sm,
183
+ dtype=torch.int,
184
+ device=device,
185
+ requires_grad=False)
186
 
187
 
188
  def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
189
  return (not act_order) or (act_order and not is_row_parallel)
190
 
191
 
192
+ def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int,
193
+ is_row_parallel: bool) -> bool:
 
194
  # Need to repeat scales on every rank if act_ordering or
195
  # channelwise and RowParallelLinear
196
  is_channelwise = group_size == -1
 
198
 
199
 
200
  def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
201
+ return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
202
+ requires_grad=False)
 
203
 
204
 
205
  def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
206
+ return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
207
+ requires_grad=False)
 
208
 
209
 
210
+ def marlin_sort_g_idx(
211
+ g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
212
  g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
213
  return g_idx[g_idx_sort_indices], g_idx_sort_indices
214
 
215
 
216
  def get_scale_perms():
217
+ scale_perm: list[int] = []
218
  for i in range(8):
219
  scale_perm.extend([i + 8 * j for j in range(8)])
220
+ scale_perm_single: list[int] = []
221
  for i in range(4):
222
+ scale_perm_single.extend(
223
+ [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
224
  return scale_perm, scale_perm_single
225
 
226
 
227
+ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
228
+ group_size: int) -> torch.Tensor:
 
229
 
230
  scale_perm, scale_perm_single = get_scale_perms()
231
  if group_size < size_k and group_size != -1:
 
255
  return output
256
 
257
 
258
+ def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int,
259
+ num_bits: int) -> torch.Tensor:
 
260
  # Permute zero-points in a similar way to scales, but do not use the
261
  # "single" permutation, since zero-points are applied on every MMA
262
  scale_perm, _ = get_scale_perms()
 
277
  return zp
278
 
279
 
280
+ def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
281
+ size_n: int, num_bits: int) -> torch.Tensor:
 
282
  # AWQ zero-points are quantized and packed on the column dim.
283
  # In addition, the values are permuted based on dequantizer.
284
  # Here we undo both of these, and then apply marlin permutation
 
300
  return marlin_zp
301
 
302
 
303
+ def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
304
+ size_n: int, num_bits: int):
 
305
  num_experts = q_zp_packed.shape[0]
306
  output = torch.empty(
307
  (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]),
 
309
  dtype=q_zp_packed.dtype,
310
  )
311
  for e in range(num_experts):
312
+ output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n,
313
+ num_bits)
314
  return output
315
 
316
 
317
+ def maybe_warn_marlin_atomic_add(device, dtype):
318
+ if torch.compiler.is_dynamo_compiling():
319
+ return
320
+ device_capability = torch.cuda.get_device_capability(device)
321
+ if device_capability[0] < 9 and dtype == torch.bfloat16:
322
+ logger.info_once(
323
+ "You are running Marlin kernel with bf16 on GPUs before SM90. "
324
+ "You can consider change to fp16 to achieve better performance "
325
+ "if possible.")
326
+
327
+
328
+ def maybe_warn_marlin_atomic_add_env():
329
+ if torch.compiler.is_dynamo_compiling():
330
+ return
331
+ if envs.VLLM_MARLIN_USE_ATOMIC_ADD:
332
+ return
333
+ logger.info_once(
334
+ "Marlin kernel can achieve better performance for small size_n "
335
+ "with experimental use_atomic_add feature. "
336
+ "You can consider set environment variable "
337
+ "VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible.")
338
+
339
+
340
+ def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device,
341
+ dtype: torch.dtype) -> bool:
342
+
343
+ # the performance of atomicAdd is better than global reduce
344
+ # only when m*n is small and k is large
345
+ if n >= 2048 or k < 2048 or device.type != "cuda":
346
+ return False
347
+
348
+ # disable atomicAdd reduce by default,
349
+ # one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1
350
+ if not envs.VLLM_MARLIN_USE_ATOMIC_ADD:
351
+ maybe_warn_marlin_atomic_add_env()
352
+ return False
353
+
354
+ # sm8x doesn't support atomicAdd + bfloat16 natively
355
+ device_capability = torch.cuda.get_device_capability(device)
356
+ if device_capability[0] < 9 and dtype == torch.bfloat16:
357
+ maybe_warn_marlin_atomic_add(device, dtype)
358
+ return False
359
+
360
+ return True
361
+
362
+
363
  def apply_gptq_marlin_linear(
364
+ input: torch.Tensor,
365
+ weight: torch.Tensor,
366
+ weight_scale: torch.Tensor,
367
+ weight_zp: torch.Tensor,
368
+ g_idx: torch.Tensor,
369
+ g_idx_sort_indices: torch.Tensor,
370
+ workspace: torch.Tensor,
371
+ wtype: ScalarType,
372
+ output_size_per_partition: int,
373
+ input_size_per_partition: int,
374
+ is_k_full: bool,
375
+ bias: Optional[torch.Tensor] = None,
376
+ use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
 
377
  reshaped_x = input.reshape(-1, input.shape[-1])
378
+ out_shape = input.shape[:-1] + (output_size_per_partition, )
379
+
380
+ use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
381
+ n=output_size_per_partition,
382
+ k=reshaped_x.size(1),
383
+ device=input.device,
384
+ dtype=input.dtype)
385
+
386
+ output = ops.gptq_marlin_gemm(reshaped_x,
387
+ None,
388
+ weight,
389
+ weight_scale,
390
+ None,
391
+ weight_zp,
392
+ g_idx,
393
+ g_idx_sort_indices,
394
+ workspace,
395
+ wtype,
396
+ size_m=reshaped_x.shape[0],
397
+ size_n=output_size_per_partition,
398
+ size_k=input_size_per_partition,
399
+ is_k_full=is_k_full,
400
+ use_atomic_add=use_atomic_add,
401
+ use_fp32_reduce=use_fp32_reduce,
402
+ is_zp_float=False)
403
 
404
  if bias is not None:
405
  output.add_(bias) # In-place add
 
408
 
409
 
410
  def apply_awq_marlin_linear(
411
+ input: torch.Tensor,
412
+ weight: torch.Tensor,
413
+ weight_scale: torch.Tensor,
414
+ weight_zp: torch.Tensor,
415
+ g_idx: torch.Tensor,
416
+ g_idx_sort_indices: torch.Tensor,
417
+ workspace: torch.Tensor,
418
+ quant_type: ScalarType,
419
+ output_size_per_partition: int,
420
+ input_size_per_partition: int,
421
+ bias: Optional[torch.Tensor] = None,
422
+ use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
 
423
  reshaped_x = input.reshape(-1, input.shape[-1])
424
+ out_shape = input.shape[:-1] + (output_size_per_partition, )
425
+
426
+ use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
427
+ n=output_size_per_partition,
428
+ k=reshaped_x.size(1),
429
+ device=input.device,
430
+ dtype=input.dtype)
431
+
432
+ output = ops.gptq_marlin_gemm(reshaped_x,
433
+ None,
434
+ weight,
435
+ weight_scale,
436
+ None,
437
+ weight_zp,
438
+ g_idx,
439
+ g_idx_sort_indices,
440
+ workspace,
441
+ quant_type,
442
+ size_m=reshaped_x.shape[0],
443
+ size_n=output_size_per_partition,
444
+ size_k=input_size_per_partition,
445
+ use_atomic_add=use_atomic_add,
446
+ use_fp32_reduce=use_fp32_reduce,
447
+ is_zp_float=False)
448
 
449
  if bias is not None:
450
  output.add_(bias) # In-place add
build/torch27-cxx11-cu126-aarch64-linux/quantization/utils/marlin_utils_fp4.py ADDED
@@ -0,0 +1,282 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+
8
+ import quantization as ops
9
+
10
+ from .marlin_utils import (
11
+ USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales,
12
+ should_use_atomic_add_reduce)
13
+ from quantization.scalar_type import scalar_types
14
+
15
+ FP4_MARLIN_SUPPORTED_GROUP_SIZES = [16]
16
+
17
+
18
+ def is_fp4_marlin_supported():
19
+ capability = torch.cuda.get_device_capability()
20
+ capability = capability[0] * 10 + capability[1]
21
+ return capability >= 80
22
+
23
+
24
+ def fp4_marlin_process_scales(marlin_scales):
25
+ if not (marlin_scales >= 0).all():
26
+ logger.warning_once(
27
+ "NVFP4 Marlin assumes the scales to be >=0, but has encountered "
28
+ "negative scales. Accuracy will likely be degraded. This is "
29
+ "because it changes the scales from FP8-S1E4M3 to a special "
30
+ "FP8-S0E5M3 format to speedup the dequantization.")
31
+
32
+ # convert to half first, we would convert to fp8 later
33
+ marlin_scales = marlin_scales.to(torch.half)
34
+
35
+ # 8 is the number of scales handled by one thread
36
+ marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8)
37
+ marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape(
38
+ marlin_scales.size(0) * 2, -1)
39
+
40
+ # fit the layout of fp8 dequantization
41
+ marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view(
42
+ marlin_scales.size(0), -1)
43
+
44
+ # We assume that weight_scale (FP8-S1E4M3) is always greater
45
+ # than or equal to 0. So we can convert
46
+ # (weight_scale * (2 ** 7)) to a special FP8-S0E5M3 format.
47
+ # After multiplying by 2 ** 7, the top bit of FP8-S0E5M3 would always be 1
48
+ # when weight_scale > 0. This allows us to have an exponent bias
49
+ # closer to zero after dequantization.
50
+
51
+ marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1
52
+ marlin_scales = marlin_scales.view(torch.float8_e4m3fn)
53
+ marlin_scales = marlin_scales[:, 1::2].contiguous()
54
+
55
+ return marlin_scales
56
+
57
+
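A minimal sketch of the bit trick above for a single non-negative fp16 scale (the value 1.5 is arbitrary); it shows that the kept high byte is exactly the sign-less FP8-S0E5M3 code of scale * 2**7:

import torch

s = torch.tensor([1.5], dtype=torch.half)
packed = ((s * (2 ** 7)).view(torch.int16) << 1).view(torch.uint8)[1::2]
# 5 exponent bits followed by the top 3 mantissa bits of s * 2**7 (no sign bit),
# matching the [:, 1::2] high-byte selection done on the float8 view above.
print(f"{int(packed):08b}")  # -> 10110100 for s = 1.5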
58
+ def fp4_marlin_process_global_scale(global_scale):
59
+ assert global_scale.dtype in [torch.half, torch.bfloat16]
60
+ fp4_exponent = 2
61
+ if global_scale.dtype == torch.half:
62
+ target_exponent = 5
63
+ elif global_scale.dtype == torch.bfloat16:
64
+ target_exponent = 8
65
+ # exponent_bias_fp16 = 2 ** 4 - 2 ** 1 = 14
66
+ # exponent_bias_bf16 = 2 ** 7 - 2 ** 1 = 126
67
+ exponent_bias = 2**(target_exponent - 1) - 2**(fp4_exponent - 1)
68
+ return global_scale * (2.0**(exponent_bias - 7))
69
+
70
+
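For fp16 the comment above works out to a fixed multiplier of 2**(14 - 7) = 128 (and 2**119 for bfloat16); a quick check of that arithmetic, using an arbitrary scale value:

import torch

gs = torch.tensor([0.02], dtype=torch.half)
assert torch.equal(fp4_marlin_process_global_scale(gs), gs * 128)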
71
+ def apply_fp4_marlin_linear(
72
+ input: torch.Tensor,
73
+ weight: torch.Tensor,
74
+ weight_scale: torch.Tensor,
75
+ weight_scale_2: torch.Tensor,
76
+ workspace: torch.Tensor,
77
+ size_n: int,
78
+ size_k: int,
79
+ bias: Optional[torch.Tensor] = None,
80
+ use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
81
+ # For GPUs that lack FP4 hardware support, we can leverage the
82
+ # Marlin kernel for fast weight-only FP4 quantization
83
+
84
+ reshaped_x = input.reshape(-1, input.shape[-1])
85
+ out_shape = input.shape[:-1] + (size_n, )
86
+
87
+ use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
88
+ n=size_n,
89
+ k=size_k,
90
+ device=input.device,
91
+ dtype=input.dtype)
92
+
93
+ output = ops.gptq_marlin_gemm(a=reshaped_x,
94
+ c=None,
95
+ b_q_weight=weight,
96
+ b_scales=weight_scale,
97
+ global_scale=weight_scale_2,
98
+ b_zeros=None,
99
+ g_idx=None,
100
+ perm=None,
101
+ workspace=workspace,
102
+ b_q_type=scalar_types.float4_e2m1f,
103
+ size_m=reshaped_x.size(0),
104
+ size_n=size_n,
105
+ size_k=size_k,
106
+ use_atomic_add=use_atomic_add,
107
+ use_fp32_reduce=use_fp32_reduce)
108
+
109
+ if bias is not None:
110
+ output.add_(bias) # In-place add
111
+
112
+ return output.reshape(out_shape)
113
+
114
+
115
+ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
116
+ logger.warning_once(
117
+ "Your GPU does not have native support for FP4 computation but "
118
+ "FP4 quantization is being used. Weight-only FP4 compression will "
119
+ "be used leveraging the Marlin kernel. This may degrade "
120
+ "performance for compute-heavy workloads.")
121
+
122
+ part_size_n = layer.output_size_per_partition
123
+ part_size_k = layer.input_size_per_partition
124
+ param_dtype = layer.params_dtype
125
+
126
+ assert layer.weight.shape == (part_size_n, part_size_k // 2)
127
+
128
+ device = layer.weight.device
129
+
130
+ # WORKSPACE
131
+ layer.workspace = marlin_make_workspace_new(device)
132
+
133
+ # WEIGHT
134
+ # Repack weights to marlin format
135
+ perm = torch.empty(0, dtype=torch.int, device=device)
136
+ qweight = layer.weight.view(torch.int32).T.contiguous()
137
+
138
+ marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight,
139
+ perm=perm,
140
+ size_k=part_size_k,
141
+ size_n=part_size_n,
142
+ num_bits=4)
143
+ layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
144
+
145
+ # WEIGHT SCALES
146
+ # Permute scales
147
+ weight_scale = layer.weight_scale.T.to(param_dtype)
148
+ weight_scale = marlin_permute_scales(s=weight_scale,
149
+ size_k=part_size_k,
150
+ size_n=part_size_n,
151
+ group_size=16)
152
+ weight_scale = fp4_marlin_process_scales(weight_scale)
153
+ layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
154
+
155
+ weight_scale_2 = layer.weight_scale_2.to(param_dtype)
156
+ weight_scale_2 = fp4_marlin_process_global_scale(weight_scale_2)
157
+ layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2,
158
+ requires_grad=False)
159
+
160
+ return
161
+
162
+
163
+ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
164
+ logger.warning_once(
165
+ "Your GPU does not have native support for FP4 computation but "
166
+ "FP4 quantization is being used. Weight-only FP4 compression will "
167
+ "be used leveraging the Marlin kernel. This may degrade "
168
+ "performance for compute-heavy workloads.")
169
+
170
+ e = layer.num_experts
171
+ k = layer.hidden_size
172
+ n = layer.intermediate_size_per_partition
173
+
174
+ # WORKSPACE
175
+ device = layer.w13_weight.device
176
+ param_dtype = layer.params_dtype
177
+ layer.workspace = marlin_make_workspace_new(device, 4)
178
+ perm = torch.empty(0, dtype=torch.int, device=device)
179
+
180
+ # WEIGHT
181
+ # Repack weights to marlin format
182
+ for name in ["w13_weight", "w2_weight"]:
183
+ weight = getattr(layer, name)
184
+ tensor_list = []
185
+ if "w13" in name:
186
+ size_n, size_k = n * 2, k
187
+ else:
188
+ size_n, size_k = k, n
189
+
190
+ assert weight.shape == (e, size_n, size_k // 2)
191
+
192
+ for i in range(e):
193
+ qweight = weight[i].view(torch.int32).T.contiguous()
194
+
195
+ marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight,
196
+ perm=perm,
197
+ size_k=size_k,
198
+ size_n=size_n,
199
+ num_bits=4)
200
+ tensor_list.append(marlin_qweight)
201
+
202
+ weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
203
+ weight = torch.nn.Parameter(weight, requires_grad=False)
204
+
205
+ setattr(layer, name, weight)
206
+
207
+ # WEIGHT SCALES
208
+ # Permute scales
209
+ for name in ["w13", "w2"]:
210
+ scales = getattr(layer, name + "_weight_scale").to(param_dtype)
211
+ global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype)
212
+
213
+ tensor_list = []
214
+ if "w13" in name:
215
+ size_n, size_k = n * 2, k
216
+ else:
217
+ size_n, size_k = k, n
218
+
219
+ for i in range(e):
220
+ marlin_scales = marlin_permute_scales(s=scales[i].T,
221
+ size_k=size_k,
222
+ size_n=size_n,
223
+ group_size=16)
224
+ marlin_scales = fp4_marlin_process_scales(marlin_scales)
225
+ tensor_list.append(marlin_scales)
226
+
227
+ scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
228
+ scales = torch.nn.Parameter(scales, requires_grad=False)
229
+ setattr(layer, name + "_weight_scale", scales)
230
+
231
+ global_scale = fp4_marlin_process_global_scale(global_scale)
232
+ global_scale = torch.nn.Parameter(global_scale, requires_grad=False)
233
+ setattr(layer, name + "_weight_scale_2", global_scale)
234
+
235
+
236
+ def rand_marlin_weight_fp4_like(weight, group_size):
237
+ assert group_size > 0
238
+ size_n, size_k = weight.shape
239
+ device = weight.device
240
+
241
+ scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 6
242
+ global_scale = scales.max() / 448
243
+ scales = (scales / global_scale).to(torch.float8_e4m3fn)
244
+
245
+ fp4_weight = torch.randint(0,
246
+ 256, (size_n, size_k // 2),
247
+ dtype=torch.uint8,
248
+ device=weight.device)
249
+ fp4_weight_part_1 = ((fp4_weight & 0b10000000) |
250
+ ((fp4_weight & 0b01110000) >> 2))
251
+ fp4_weight_part_1 = fp4_weight_part_1.view(torch.float8_e4m3fn)
252
+ fp4_weight_part_1 = fp4_weight_part_1.to(weight.dtype) * (2**6)
253
+
254
+ fp4_weight2 = fp4_weight << 4
255
+ fp4_weight_part_2 = ((fp4_weight2 & 0b10000000) |
256
+ ((fp4_weight2 & 0b01110000) >> 2))
257
+ fp4_weight_part_2 = fp4_weight_part_2.view(torch.float8_e4m3fn)
258
+ fp4_weight_part_2 = fp4_weight_part_2.to(weight.dtype) * (2**6)
259
+
260
+ weight_ref = torch.cat(
261
+ [fp4_weight_part_2.unsqueeze(2),
262
+ fp4_weight_part_1.unsqueeze(2)], 2).view(size_n, size_k)
263
+ weight_ref = weight_ref * global_scale.to(weight.dtype) * \
264
+ scales.repeat_interleave(group_size, 1).to(weight.dtype)
265
+
266
+ marlin_qweight = ops.gptq_marlin_repack(
267
+ b_q_weight=fp4_weight.view(torch.int32).T.contiguous(),
268
+ perm=torch.empty(0, dtype=torch.int, device=device),
269
+ size_k=size_k,
270
+ size_n=size_n,
271
+ num_bits=4,
272
+ )
273
+
274
+ marlin_scales = marlin_permute_scales(s=scales.T.to(weight.dtype),
275
+ size_k=size_k,
276
+ size_n=size_n,
277
+ group_size=group_size)
278
+ marlin_scales = fp4_marlin_process_scales(marlin_scales)
279
+
280
+ global_scale = fp4_marlin_process_global_scale(global_scale)
281
+
282
+ return weight_ref.T, marlin_qweight, marlin_scales, global_scale
build/torch27-cxx11-cu126-aarch64-linux/quantization/utils/marlin_utils_fp8.py CHANGED
@@ -1,10 +1,13 @@
 
 
 
1
  from typing import Optional
2
 
3
  import torch
4
 
5
  import quantization as ops
6
 
7
- from .marlin_utils import marlin_make_workspace, marlin_permute_scales
8
 
9
 
10
  def is_fp8_marlin_supported():
@@ -13,88 +16,107 @@ def is_fp8_marlin_supported():
13
  return capability >= 80
14
 
15

16
  def apply_fp8_marlin_linear(
17
- input: torch.Tensor,
18
- weight: torch.Tensor,
19
- weight_scale: torch.Tensor,
20
- workspace: torch.Tensor,
21
- size_n: int,
22
- size_k: int,
23
- bias: Optional[torch.Tensor],
24
- ) -> torch.Tensor:
25
  # For GPUs that lack FP8 hardware support, we can leverage the
26
  # Marlin kernel for fast weight-only FP8 quantization
27
 
28
  reshaped_x = input.reshape(-1, input.shape[-1])
29
- out_shape = input.shape[:-1] + (size_n,)
30
-
31
- output = ops.fp8_marlin_gemm(
32
- a=reshaped_x,
33
- b_q_weight=weight,
34
- b_scales=weight_scale,
35
- workspace=workspace,
36
- num_bits=8,
37
- size_m=reshaped_x.shape[0],
38
- size_n=size_n,
39
- size_k=size_k,
40
- )
41
 
42
  if bias is not None:
43
  output.add_(bias) # In-place add
44
 
45
  return output.reshape(out_shape)
46
 
47
-
48
- def prepare_fp8_layer_for_marlin(
49
- layer: torch.nn.Module, strategy: str = "tensor"
50
- ) -> None:
51
- part_size_n = layer.output_size_per_partition
52
- part_size_k = layer.input_size_per_partition
53
-
54
- device = layer.weight.device
55
-
56
- # WORKSPACE
57
- layer.workspace = marlin_make_workspace(part_size_n, device)
58
-
59
- # WEIGHT
60
- # Repack weights to marlin format
61
- marlin_qweight = ops.gptq_marlin_repack(
62
- b_q_weight=pack_fp8_to_int32(layer.weight),
63
- perm=torch.empty(0, dtype=torch.int, device=device),
64
- size_k=part_size_k,
65
- size_n=part_size_n,
66
- num_bits=8,
67
- )
68
- layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
69
-
70
- # WEIGHT SCALES
71
- scales = layer.weight_scale.to(layer.orig_dtype)
72
- # Permute scales
73
- marlin_scales = marlin_permute_scales(
74
- s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1
75
- )
76
- layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
77
-
78
-
79
- def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
80
  """
81
  Repack FP8 weights to gptq format (packed int32 elements)
82
  """
83
  assert fp8_tensor.dtype == torch.float8_e4m3fn
84
- assert fp8_tensor.shape[0] % 4 == 0
85
-
86
- # Reshape to prepare for packing
87
- reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
88
 
89
- # Convert fp8 to uint8 (byte) representation
90
- byte_tensor = reshaped.view(torch.uint8)
 
 
91
 
92
- # Pack 4 uint8 values into one int32
93
- packed = (
94
- byte_tensor[:, 0].to(torch.int32)
95
- | (byte_tensor[:, 1].to(torch.int32) << 8)
96
- | (byte_tensor[:, 2].to(torch.int32) << 16)
97
- | (byte_tensor[:, 3].to(torch.int32) << 24)
98
- )
99
 
100
- return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous()
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+
4
  from typing import Optional
5
 
6
  import torch
7
 
8
  import quantization as ops
9
 
10
+ from .marlin_utils import USE_FP32_REDUCE_DEFAULT, marlin_make_workspace, marlin_permute_scales
11
 
12
 
13
  def is_fp8_marlin_supported():
 
16
  return capability >= 80
17
 
18
 
19
+ def fp8_fused_exponent_bias_into_scales(scales):
20
+ fp8_exponent = 4
21
+ if scales.dtype == torch.half:
22
+ target_exponent = 5
23
+ elif scales.dtype == torch.bfloat16:
24
+ target_exponent = 8
25
+ # exponent_bias_fp16 = 2 ** 4 - 2 ** 3 = 8
26
+ # exponent_bias_bf16 = 2 ** 7 - 2 ** 3 = 120
27
+ exponent_bias = 2**(target_exponent - 1) - 2**(fp8_exponent - 1)
28
+ s = torch.ones_like(scales) * 2
29
+ s = s**exponent_bias
30
+ return scales * s
31
+
32
+
33
  def apply_fp8_marlin_linear(
34
+ input: torch.Tensor,
35
+ weight: torch.Tensor,
36
+ weight_scale: torch.Tensor,
37
+ workspace: torch.Tensor,
38
+ size_n: int,
39
+ size_k: int,
40
+ bias: Optional[torch.Tensor],
41
+ use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
42
  # For GPUs that lack FP8 hardware support, we can leverage the
43
  # Marlin kernel for fast weight-only FP8 quantization
44
 
45
  reshaped_x = input.reshape(-1, input.shape[-1])
46
+ out_shape = input.shape[:-1] + (size_n, )
47
+
48
+ use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
49
+ n=size_n,
50
+ k=size_k,
51
+ device=input.device,
52
+ dtype=input.dtype)
53
+
54
+ output = ops.gptq_marlin_gemm(a=reshaped_x,
55
+ c=None,
56
+ b_q_weight=weight,
57
+ b_scales=weight_scale,
58
+ global_scale=None,
59
+ b_zeros=None,
60
+ g_idx=None,
61
+ perm=None,
62
+ workspace=workspace,
63
+ b_q_type=scalar_types.float8_e4m3fn,
64
+ size_m=reshaped_x.size(0),
65
+ size_n=size_n,
66
+ size_k=size_k,
67
+ use_atomic_add=use_atomic_add,
68
+ use_fp32_reduce=use_fp32_reduce)
69
 
70
  if bias is not None:
71
  output.add_(bias) # In-place add
72
 
73
  return output.reshape(out_shape)
74
 
75
+ def pack_fp8_to_int32(fp8_tensor: torch.Tensor,
76
+ size_k_first: bool = True) -> torch.Tensor:
77
  """
78
  Repack FP8 weights to gptq format (packed int32 elements)
79
  """
80
  assert fp8_tensor.dtype == torch.float8_e4m3fn
81
+ assert fp8_tensor.ndim == 2
82
+
83
+ fp8_tensor = fp8_tensor.T if size_k_first else fp8_tensor
84
+ fp8_tensor = fp8_tensor.contiguous()
85
+ # fp8_tensor is contiguous and has shape (N, K) now
86
+ # with `.view(torch.int32)`, it becomes (N, K // 4)
87
+ int32_tensor = fp8_tensor.view(torch.int32)
88
+ return int32_tensor.T.contiguous() if size_k_first else int32_tensor
89
+
90
+
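The `.view(torch.int32)` above is equivalent to the explicit byte packing done by the removed helper further up in this diff; a small self-check of that claim (arbitrary shapes, little-endian byte order, and a PyTorch build with float8_e4m3fn are assumed):

import torch

w = torch.randn(8, 16).clamp(-2, 2).to(torch.float8_e4m3fn)   # (N, K) with K % 4 == 0
packed = pack_fp8_to_int32(w, size_k_first=False)              # (N, K // 4), int32
b = w.view(torch.uint8)
manual = (b[:, 0::4].to(torch.int32)
          | (b[:, 1::4].to(torch.int32) << 8)
          | (b[:, 2::4].to(torch.int32) << 16)
          | (b[:, 3::4].to(torch.int32) << 24))
assert torch.equal(packed, manual)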
91
+ def marlin_quant_fp8_torch(weight, group_size):
92
+ size_n, size_k = weight.shape
93
+ device = weight.device
94
+
95
+ if group_size != -1:
96
+ scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 448
97
+ repeated_scales = scales.repeat_interleave(group_size, 1)
98
+ fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn)
99
+ weight_ref = fp8_weight.to(weight.dtype) * repeated_scales
100
+ else:
101
+ scales = weight.view(size_n, 1, group_size).abs().max(-1)[0] / 448
102
+ repeated_scales = scales.repeat_interleave(size_k, 1)
103
+ fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn)
104
+ weight_ref = fp8_weight.to(weight.dtype) * repeated_scales
105
+
106
+ packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous()
107
+ marlin_qweight = ops.gptq_marlin_repack(
108
+ b_q_weight=packed_weight,
109
+ perm=torch.empty(0, dtype=torch.int, device=device),
110
+ size_k=size_k,
111
+ size_n=size_n,
112
+ num_bits=8,
113
+ )
114
 
115
+ marlin_scales = marlin_permute_scales(s=scales.T,
116
+ size_k=size_k,
117
+ size_n=size_n,
118
+ group_size=group_size)
119
 
120
+ marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
 
 
 
 
 
 
121
 
122
+ return weight_ref.T, marlin_qweight, marlin_scales
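A hypothetical smoke test for the reference quantizer above (the shapes are arbitrary; a CUDA device and the compiled _quantization extension are required because of the gptq_marlin_repack call):

import torch

w = torch.randn(256, 512, dtype=torch.half, device="cuda")   # (size_n, size_k)
weight_ref, marlin_qweight, marlin_scales = marlin_quant_fp8_torch(w, group_size=128)
print((weight_ref.T - w).abs().max())  # FP8 round-trip error, expected to be small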
build/torch27-cxx11-cu128-aarch64-linux/quantization/__init__.py CHANGED
@@ -1,12 +1,12 @@
1
  from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant
2
  from .cutlass import (
 
3
  cutlass_scaled_mm_supports_fp8,
4
  cutlass_scaled_mm,
5
  cutlass_scaled_mm_azp,
6
  )
7
  from .marlin import (
8
  awq_marlin_repack,
9
- fp8_marlin_gemm,
10
  gptq_marlin_gemm,
11
  gptq_marlin_repack,
12
  gptq_marlin_24_gemm,
@@ -25,8 +25,8 @@ __all__ = [
25
  "awq_marlin_repack",
26
  "cutlass_scaled_mm",
27
  "cutlass_scaled_mm_azp",
 
28
  "cutlass_scaled_mm_supports_fp8",
29
- "fp8_marlin_gemm",
30
  "gptq_marlin_24_gemm",
31
  "gptq_marlin_gemm",
32
  "gptq_marlin_repack",
 
1
  from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant
2
  from .cutlass import (
3
+ cutlass_scaled_mm_supports_block_fp8,
4
  cutlass_scaled_mm_supports_fp8,
5
  cutlass_scaled_mm,
6
  cutlass_scaled_mm_azp,
7
  )
8
  from .marlin import (
9
  awq_marlin_repack,
 
10
  gptq_marlin_gemm,
11
  gptq_marlin_repack,
12
  gptq_marlin_24_gemm,
 
25
  "awq_marlin_repack",
26
  "cutlass_scaled_mm",
27
  "cutlass_scaled_mm_azp",
28
+ "cutlass_scaled_mm_supports_block_fp8",
29
  "cutlass_scaled_mm_supports_fp8",
 
30
  "gptq_marlin_24_gemm",
31
  "gptq_marlin_gemm",
32
  "gptq_marlin_repack",
build/torch27-cxx11-cu128-aarch64-linux/quantization/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _quantization_0435ccb
3
- ops = torch.ops._quantization_0435ccb
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_quantization_0435ccb::{op_name}"
 
1
  import torch
2
+ from . import _quantization_9035540
3
+ ops = torch.ops._quantization_9035540
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_quantization_9035540::{op_name}"
build/torch27-cxx11-cu128-aarch64-linux/quantization/{_quantization_0435ccb.abi3.so → _quantization_9035540.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:6a816838b0e28d1984f1ac89da4f25bff61a0dcfe6dc54351bc08e544363d499
3
- size 121278800
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4d670f7d449a8d177ce46784fb4617dcb0edc30f8d8a62305ed1213310256167
3
+ size 296561248
build/torch27-cxx11-cu128-aarch64-linux/quantization/compressed_tensors.py CHANGED
@@ -2,17 +2,7 @@ from typing import Optional, Tuple
2
 
3
  import torch
4
 
5
- try:
6
- from ._ops import ops
7
- except ImportError as e:
8
- # Fallback for local development.
9
- try:
10
- import _quantization
11
-
12
- ops = torch.ops._quantization
13
- except ImportError:
14
- raise e
15
-
16
 
17
  # fp8
18
  def scaled_fp8_quant(
@@ -21,7 +11,8 @@ def scaled_fp8_quant(
21
  num_token_padding: Optional[int] = None,
22
  scale_ub: Optional[torch.Tensor] = None,
23
  use_per_token_if_dynamic: bool = False,
24
- ) -> Tuple[torch.Tensor, torch.Tensor]:
 
25
  """
26
  Quantize input tensor to FP8 and return quantized tensor and scale.
27
 
@@ -42,30 +33,36 @@ def scaled_fp8_quant(
42
  in the dynamic quantization case.
43
 
44
  Returns:
45
- Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
46
  scaling factor.
47
  """
48
  # This code assumes batch_dim and num_tokens are flattened
49
- assert input.ndim == 2
50
- shape: Union[Tuple[int, int], torch.Size] = input.shape
51
- # For rocm, the output fp8 dtype is torch.float_e3m3fnuz
52
- # out_dtype: torch.dtype = torch.float8_e4m3fnuz \
53
- # if current_platform.is_rocm() else torch.float8_e4m3fn
54
- out_dtype = torch.float8_e4m3fn
55
  if num_token_padding:
56
  shape = (max(num_token_padding, input.shape[0]), shape[1])
57
- output = torch.empty(shape, device=input.device, dtype=out_dtype)
 
 
 
 
 
58
 
59
  if scale is None:
60
  if use_per_token_if_dynamic:
61
- scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
62
- ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub)
 
 
 
63
  else:
64
  scale = torch.zeros(1, device=input.device, dtype=torch.float32)
65
  ops.dynamic_scaled_fp8_quant(output, input, scale)
66
  else:
67
  # num_token_padding not implemented for this case
68
- assert scale.numel() == 1 or num_token_padding is None
69
  ops.static_scaled_fp8_quant(output, input, scale)
70
 
71
  return output, scale
@@ -76,8 +73,8 @@ def scaled_int8_quant(
76
  input: torch.Tensor,
77
  scale: Optional[torch.Tensor] = None,
78
  azp: Optional[torch.Tensor] = None,
79
- symmetric: bool = True,
80
- ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
81
  """
82
  Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.
83
 
@@ -90,21 +87,25 @@ def scaled_int8_quant(
90
  symmetric: Whether to use symmetric quantization (scale only, azp ignored).
91
 
92
  Returns:
93
- Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
94
  """
95
  output = torch.empty_like(input, dtype=torch.int8)
96
  if scale is not None:
97
  # static-per-tensor quantization.
98
  assert symmetric == (
99
- azp is None
100
- ), "azp must only be provided for asymmetric quantization."
101
  ops.static_scaled_int8_quant(output, input, scale, azp)
102
  return output, scale, azp
103
 
104
  # dynamic-per-token quantization.
105
- input_scales = torch.empty(
106
- (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32
107
- )
108
- input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32)
109
- ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp)
 
 
110
  return output, input_scales, input_azp
 
 
 
2
 
3
  import torch
4
 
5
+ from ._ops import ops
6
 
7
  # fp8
8
  def scaled_fp8_quant(
 
11
  num_token_padding: Optional[int] = None,
12
  scale_ub: Optional[torch.Tensor] = None,
13
  use_per_token_if_dynamic: bool = False,
14
+ output: Optional[torch.Tensor] = None,
15
+ ) -> tuple[torch.Tensor, torch.Tensor]:
16
  """
17
  Quantize input tensor to FP8 and return quantized tensor and scale.
18
 
 
33
  in the dynamic quantization case.
34
 
35
  Returns:
36
+ tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
37
  scaling factor.
38
  """
39
  # This code assumes batch_dim and num_tokens are flattened
40
+ assert (input.ndim == 2)
41
+ shape: Union[tuple[int, int], torch.Size] = input.shape
42
+ # For ROCm on MI300, the output fp8 dtype is torch.float8_e4m3fnuz
43
+ out_dtype: torch.dtype = current_platform.fp8_dtype()
 
 
44
  if num_token_padding:
45
  shape = (max(num_token_padding, input.shape[0]), shape[1])
46
+ if output is None:
47
+ output = torch.empty(shape, device=input.device, dtype=out_dtype)
48
+ else:
49
+ assert num_token_padding is None, \
50
+ "padding not supported if output passed in"
51
+ assert output.dtype == out_dtype
52
 
53
  if scale is None:
54
  if use_per_token_if_dynamic:
55
+ scale = torch.empty((shape[0], 1),
56
+ device=input.device,
57
+ dtype=torch.float32)
58
+ ops.dynamic_per_token_scaled_fp8_quant(
59
+ output, input.contiguous(), scale, scale_ub)
60
  else:
61
  scale = torch.zeros(1, device=input.device, dtype=torch.float32)
62
  ops.dynamic_scaled_fp8_quant(output, input, scale)
63
  else:
64
  # num_token_padding not implemented for this case
65
+ assert (scale.numel() == 1 and num_token_padding is None)
66
  ops.static_scaled_fp8_quant(output, input, scale)
67
 
68
  return output, scale
 
73
  input: torch.Tensor,
74
  scale: Optional[torch.Tensor] = None,
75
  azp: Optional[torch.Tensor] = None,
76
+ symmetric: bool = True
77
+ ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
78
  """
79
  Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.
80
 
 
87
  symmetric: Whether to use symmetric quantization (scale only, azp ignored).
88
 
89
  Returns:
90
+ tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
91
  """
92
  output = torch.empty_like(input, dtype=torch.int8)
93
  if scale is not None:
94
  # static-per-tensor quantization.
95
  assert symmetric == (
96
+ azp
97
+ is None), "azp must only be provided for asymmetric quantization."
98
  ops.static_scaled_int8_quant(output, input, scale, azp)
99
  return output, scale, azp
100
 
101
  # dynamic-per-token quantization.
102
+ input_scales = torch.empty((input.numel() // input.shape[-1], 1),
103
+ device=input.device,
104
+ dtype=torch.float32)
105
+ input_azp = None if symmetric else torch.empty_like(input_scales,
106
+ dtype=torch.int32)
107
+ ops.dynamic_scaled_int8_quant(output, input.contiguous(),
108
+ input_scales, input_azp)
109
  return output, input_scales, input_azp
110
+
111
+
build/torch27-cxx11-cu128-aarch64-linux/quantization/cutlass.py CHANGED
@@ -2,22 +2,18 @@ from typing import Optional
2
 
3
  import torch
4
 
5
- try:
6
- from ._ops import ops
7
- except ImportError as e:
8
- # Fallback for local development.
9
- try:
10
- import _quantization
11
-
12
- ops = torch.ops._quantization
13
- except ImportError:
14
- raise e
15
 
16
 
17
  def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
18
  return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
19
 
20
 
 
 
 
 
21
  def cutlass_scaled_mm(
22
  a: torch.Tensor,
23
  b: torch.Tensor,
@@ -33,12 +29,10 @@ def cutlass_scaled_mm(
33
  m = a.shape[0]
34
  n = b.shape[1]
35
 
36
- # if current_platform.is_rocm():
37
- # triton_scaled_mm_module = importlib.import_module(
38
- # "vllm.model_executor.layers.quantization.compressed_tensors."
39
- # "triton_scaled_mm")
40
- # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
41
- # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
42
 
43
  out = torch.empty((m, n), dtype=out_dtype, device=a.device)
44
 
 
2
 
3
  import torch
4
 
5
+ from ._ops import ops
6
+ from .platforms import current_platform
7
 
8
 
9
  def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
10
  return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
11
 
12
 
13
+ def cutlass_scaled_mm_supports_block_fp8(cuda_device_capability: int) -> bool:
14
+ return ops.cutlass_scaled_mm_supports_block_fp8(cuda_device_capability)
15
+
16
+
17
  def cutlass_scaled_mm(
18
  a: torch.Tensor,
19
  b: torch.Tensor,
 
29
  m = a.shape[0]
30
  n = b.shape[1]
31
 
32
+ cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
33
+ if not cutlass_compatible_b:
34
+ from .triton_scaled_mm import triton_scaled_mm
35
+ return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
 
 
36
 
37
  out = torch.empty((m, n), dtype=out_dtype, device=a.device)
38
 
build/torch27-cxx11-cu128-aarch64-linux/quantization/marlin.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import TYPE_CHECKING
2
 
3
  import torch
4
 
@@ -30,58 +30,30 @@ except ImportError as e:
30
  from .scalar_type import ScalarType
31
 
32
 
33
- # fp8 marlin
34
- def fp8_marlin_gemm(
35
- a: torch.Tensor,
36
- b_q_weight: torch.Tensor,
37
- b_scales: torch.Tensor,
38
- workspace: torch.Tensor,
39
- num_bits: int,
40
- size_m: int,
41
- size_n: int,
42
- size_k: int,
43
- ) -> torch.Tensor:
44
- return ops.fp8_marlin_gemm(
45
- a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k
46
- )
47
-
48
-
49
  # gptq_marlin
50
- def gptq_marlin_gemm(
51
- a: torch.Tensor,
52
- b_q_weight: torch.Tensor,
53
- b_scales: torch.Tensor,
54
- b_zeros: torch.Tensor,
55
- g_idx: torch.Tensor,
56
- perm: torch.Tensor,
57
- workspace: torch.Tensor,
58
- b_q_type: ScalarType,
59
- size_m: int,
60
- size_n: int,
61
- size_k: int,
62
- is_k_full: bool,
63
- has_zp: bool = False,
64
- use_fp32_reduce: bool = False,
65
- is_zp_float: bool = False,
66
- ) -> torch.Tensor:
67
- return ops.gptq_marlin_gemm(
68
- a,
69
- b_q_weight,
70
- b_scales,
71
- b_zeros,
72
- g_idx,
73
- perm,
74
- workspace,
75
- b_q_type.id,
76
- size_m,
77
- size_n,
78
- size_k,
79
- is_k_full,
80
- has_zp,
81
- use_fp32_reduce,
82
- is_zp_float,
83
- )
84
-
85
 
86
  # gptq_marlin
87
  def gptq_marlin_repack(
@@ -153,14 +125,6 @@ def marlin_qqq_gemm(
153
  # Fake ops
154
 
155
  if hasattr(ops, "gptq_marlin_24_gemm"):
156
- @register_fake(add_op_namespace_prefix("fp8_marlin_gemm"))
157
- def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
158
- b_scales: torch.Tensor, workspace: torch.Tensor,
159
- num_bits: int, size_m: torch.SymInt,
160
- size_n: torch.SymInt,
161
- size_k: torch.SymInt) -> torch.Tensor:
162
- return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)
163
-
164
  @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm"))
165
  def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
166
  b_meta: torch.Tensor, b_scales: torch.Tensor,
@@ -172,20 +136,22 @@ if hasattr(ops, "gptq_marlin_24_gemm"):
172
 
173
  @register_fake(add_op_namespace_prefix("gptq_marlin_gemm"))
174
  def _gptq_marlin_gemm_fake(a: torch.Tensor,
175
- b_q_weight: torch.Tensor,
176
- b_scales: torch.Tensor,
177
- b_zeros: torch.Tensor,
178
- g_idx: torch.Tensor,
179
- perm: torch.Tensor,
180
- workspace: torch.Tensor,
181
- b_q_type: ScalarType,
182
- size_m: torch.SymInt,
183
- size_n: torch.SymInt,
184
- size_k: torch.SymInt,
185
- is_k_full: bool,
186
- has_zp: bool = False,
187
- use_fp32_reduce: bool = False,
188
- is_zp_float: bool = False) -> torch.Tensor:
 
 
189
  return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
190
 
191
  @register_fake(add_op_namespace_prefix("marlin_qqq_gemm"))
 
1
+ from typing import TYPE_CHECKING, Optional
2
 
3
  import torch
4
 
 
30
  from .scalar_type import ScalarType
31
 
32

33
  # gptq_marlin
34
+ def gptq_marlin_gemm(a: torch.Tensor,
35
+ c: Optional[torch.Tensor],
36
+ b_q_weight: torch.Tensor,
37
+ b_scales: torch.Tensor,
38
+ global_scale: Optional[torch.Tensor],
39
+ b_zeros: Optional[torch.Tensor],
40
+ g_idx: Optional[torch.Tensor],
41
+ perm: Optional[torch.Tensor],
42
+ workspace: torch.Tensor,
43
+ b_q_type: ScalarType,
44
+ size_m: int,
45
+ size_n: int,
46
+ size_k: int,
47
+ is_k_full: bool = True,
48
+ use_atomic_add: bool = False,
49
+ use_fp32_reduce: bool = False,
50
+ is_zp_float: bool = False) -> torch.Tensor:
51
+ return ops.gptq_marlin_gemm(a, c, b_q_weight, b_scales,
52
+ global_scale, b_zeros, g_idx, perm,
53
+ workspace, b_q_type.id, size_m,
54
+ size_n, size_k, is_k_full,
55
+ use_atomic_add, use_fp32_reduce,
56
+ is_zp_float)
57
 
58
  # gptq_marlin
59
  def gptq_marlin_repack(
 
125
  # Fake ops
126
 
127
  if hasattr(ops, "gptq_marlin_24_gemm"):
128
  @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm"))
129
  def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
130
  b_meta: torch.Tensor, b_scales: torch.Tensor,
 
136
 
137
  @register_fake(add_op_namespace_prefix("gptq_marlin_gemm"))
138
  def _gptq_marlin_gemm_fake(a: torch.Tensor,
139
+ c: Optional[torch.Tensor],
140
+ b_q_weight: torch.Tensor,
141
+ b_scales: torch.Tensor,
142
+ global_scale: Optional[torch.Tensor],
143
+ b_zeros: Optional[torch.Tensor],
144
+ g_idx: Optional[torch.Tensor],
145
+ perm: Optional[torch.Tensor],
146
+ workspace: torch.Tensor,
147
+ b_q_type_id: int,
148
+ size_m: torch.SymInt,
149
+ size_n: torch.SymInt,
150
+ size_k: torch.SymInt,
151
+ is_k_full: bool = True,
152
+ use_atomic_add: bool = False,
153
+ use_fp32_reduce: bool = False,
154
+ is_zp_float: bool = False) -> torch.Tensor:
155
  return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
156
 
157
  @register_fake(add_op_namespace_prefix("marlin_qqq_gemm"))
build/torch27-cxx11-cu128-aarch64-linux/quantization/platforms.py ADDED
@@ -0,0 +1,69 @@
1
+ from abc import ABC, abstractmethod
2
+ from functools import lru_cache
3
+ from typing import NamedTuple
4
+
5
+ import torch
6
+
7
+ IS_ROCM = torch.version.hip is not None
8
+
9
+
10
+ class DeviceCapability(NamedTuple):
11
+ major: int
12
+ minor: int
13
+
14
+ def as_version_str(self) -> str:
15
+ return f"{self.major}.{self.minor}"
16
+
17
+ def to_int(self) -> int:
18
+ """
19
+ Express device capability as an integer ``<major><minor>``.
20
+
21
+ It is assumed that the minor version is always a single digit.
22
+ """
23
+ assert 0 <= self.minor < 10
24
+ return self.major * 10 + self.minor
25
+
26
+
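For example, an SM86 GPU reports (8, 6), which packs into the integer that the capability checks elsewhere in this package compare against:

cap = DeviceCapability(major=8, minor=6)
assert cap.to_int() == 86 and cap.as_version_str() == "8.6"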
27
+ class Platform(ABC):
28
+ simple_compile_backend: str = "inductor"
29
+
30
+ @classmethod
31
+ @abstractmethod
32
+ def get_device_name(cls, device_id: int = 0) -> str: ...
33
+
34
+ @abstractmethod
35
+ def is_rocm(self): ...
36
+
37
+
38
+ class CudaPlatform(Platform):
39
+ @classmethod
40
+ def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
41
+ major, minor = torch.cuda.get_device_capability(device_id)
42
+ return DeviceCapability(major=major, minor=minor)
43
+
44
+ @classmethod
45
+ @lru_cache(maxsize=8)
46
+ def get_device_name(cls, device_id: int = 0) -> str:
47
+ return torch.cuda.get_device_name(0)
48
+
49
+ def is_rocm(self):
50
+ return False
51
+
52
+
53
+ class RocmPlatform(Platform):
54
+ @classmethod
55
+ @lru_cache(maxsize=8)
56
+ def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
57
+ major, minor = torch.cuda.get_device_capability(device_id)
58
+ return DeviceCapability(major=major, minor=minor)
59
+
60
+ @classmethod
61
+ @lru_cache(maxsize=8)
62
+ def get_device_name(cls, device_id: int = 0) -> str:
63
+ return torch.cuda.get_device_name(device_id)
64
+
65
+ def is_rocm(self):
66
+ return True
67
+
68
+
69
+ current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
build/torch27-cxx11-cu128-aarch64-linux/quantization/scalar_type.py CHANGED
@@ -1,9 +1,14 @@
 
 
 
1
  import functools
2
  import struct
3
  from dataclasses import dataclass
4
  from enum import Enum
5
  from typing import Optional, Union
6
 
 
 
7
 
8
  # Mirrors enum in `core/scalar_type.hpp`
9
  class NanRepr(Enum):
@@ -121,8 +126,8 @@ class ScalarType:
121
  min_raw = max_raw | sign_bit_double
122
  return struct.unpack('!d', struct.pack('!Q', min_raw))[0]
123
  else:
124
- assert (not self.is_signed() or
125
- self.size_bits <= 64), "Cannot represent min as a int64_t"
126
 
127
  if self.is_signed():
128
  return -(1 << (self.size_bits - 1))
@@ -156,6 +161,8 @@ class ScalarType:
156
  assert offset <= 64, \
157
  f"ScalarType fields too big {offset} to fit into an int64"
158
 
 
 
159
  return val
160
 
161
  @property
@@ -293,6 +300,13 @@ class ScalarType:
293
  ret.id # noqa B018: make sure the id is cached
294
  return ret
295
 
 
 
 
 
 
 
 
296
 
297
  # naming generally follows: https://github.com/jax-ml/ml_dtypes
298
  # for floating point types (leading f) the scheme is:
@@ -319,6 +333,9 @@ class scalar_types:
319
  # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
320
  float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)
321
 
 
 
 
322
  # "gptq" types
323
  uint2b2 = ScalarType.uint(2, 2)
324
  uint3b4 = ScalarType.uint(3, 4)
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+
4
  import functools
5
  import struct
6
  from dataclasses import dataclass
7
  from enum import Enum
8
  from typing import Optional, Union
9
 
10
+ _SCALAR_TYPES_ID_MAP = {}
11
+
12
 
13
  # Mirrors enum in `core/scalar_type.hpp`
14
  class NanRepr(Enum):
 
126
  min_raw = max_raw | sign_bit_double
127
  return struct.unpack('!d', struct.pack('!Q', min_raw))[0]
128
  else:
129
+ assert (not self.is_signed() or self.size_bits
130
+ <= 64), "Cannot represent min as a int64_t"
131
 
132
  if self.is_signed():
133
  return -(1 << (self.size_bits - 1))
 
161
  assert offset <= 64, \
162
  f"ScalarType fields too big {offset} to fit into an int64"
163
 
164
+ _SCALAR_TYPES_ID_MAP[val] = self
165
+
166
  return val
167
 
168
  @property
 
300
  ret.id # noqa B018: make sure the id is cached
301
  return ret
302
 
303
+ @classmethod
304
+ def from_id(cls, scalar_type_id: int):
305
+ if scalar_type_id not in _SCALAR_TYPES_ID_MAP:
306
+ raise ValueError(
307
+ f"scalar_type_id {scalar_type_id} doesn't exists.")
308
+ return _SCALAR_TYPES_ID_MAP[scalar_type_id]
309
+
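A round-trip sketch for the new lookup: once a type's id has been computed (which registers it in _SCALAR_TYPES_ID_MAP), from_id recovers an equal ScalarType, which is handy when only the raw integer id crosses the torch.ops boundary:

t = scalar_types.uint4b8
assert ScalarType.from_id(t.id) == t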
310
 
311
  # naming generally follows: https://github.com/jax-ml/ml_dtypes
312
  # for floating point types (leading f) the scheme is:
 
333
  # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
334
  float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)
335
 
336
+ # fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
337
+ float4_e2m1f = ScalarType.float_(2, 1, True, NanRepr.NONE)
338
+
339
  # "gptq" types
340
  uint2b2 = ScalarType.uint(2, 2)
341
  uint3b4 = ScalarType.uint(3, 4)
build/torch27-cxx11-cu128-aarch64-linux/quantization/utils/marlin_utils.py CHANGED
@@ -1,4 +1,7 @@
1
- from typing import List, Optional, Tuple
 
 
 
2
 
3
  import numpy
4
  import torch
@@ -42,7 +45,9 @@ USE_FP32_REDUCE_DEFAULT = True
42
  # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
43
  # TODO: we may want to move this into the C++ so its closer to the actual impl
44
  def query_marlin_supported_quant_types(
45
- has_zp: bool, device_capability: Optional[int] = None
 
 
46
  ):
47
  if device_capability is None:
48
  capability_tuple = torch.cuda.get_device_capability()
@@ -51,137 +56,141 @@ def query_marlin_supported_quant_types(
51
  if device_capability < 80:
52
  return []
53

54
  if has_zp:
55
  # AWQ style, unsigned + runtime zero-point
56
- return [scalar_types.uint4, scalar_types.uint8]
57
  else:
58
  # GPTQ style, unsigned + symmetric bias
59
- # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able
60
- # to add `scalar_types.float8_e4m3fn` here
61
- return [scalar_types.uint4b8, scalar_types.uint8b128]
 
62
 
63
 
64
  def _check_marlin_supported(
65
- quant_type: ScalarType,
66
- group_size: Optional[int],
67
- has_zp: bool,
68
- device_capability: Optional[int] = None,
69
- ) -> Tuple[bool, Optional[str]]:
70
 
71
  if device_capability is None:
72
  capability_tuple = torch.cuda.get_device_capability()
73
  device_capability = capability_tuple[0] * 10 + capability_tuple[1]
74
 
75
- supported_types = query_marlin_supported_quant_types(has_zp, device_capability)
 
76
 
77
  if quant_type not in supported_types:
78
- return (
79
- False,
80
- f"Marlin does not support weight_bits = {quant_type}. "
81
- f"Only types = {supported_types} "
82
- f"are supported (for group_size = {group_size}, "
83
- f"device_capability = {device_capability}, zp = {has_zp}).",
84
- )
85
- if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
86
- return (
87
- False,
88
- f"Marlin does not support group_size = {group_size}. "
89
- f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
90
- "are supported.",
91
- )
92
 
93
  return True, None
94
 
95
 
96
- def check_marlin_supported(
97
- quant_type: ScalarType,
98
- group_size: int,
99
- has_zp: bool = False,
100
- device_capability: Optional[int] = None,
101
- ) -> bool:
102
- cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability)
103
  return cond
104
 
105
 
106
- def verify_marlin_supported(
107
- quant_type: ScalarType, group_size: int, has_zp: bool = False
108
- ) -> None:
109
  cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp)
110
  if not cond:
111
  assert err_msg is not None
112
  raise ValueError(err_msg)
113
 
114
 
115
- def verify_marlin_supports_shape(
116
- output_size_per_partition: int,
117
- input_size_per_partition: int,
118
- input_size: int,
119
- group_size: int,
120
- ) -> None:
121
 
122
  # Validate output_size_per_partition
123
  if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
124
- raise ValueError(
125
- f"Weight output_size_per_partition = "
126
- f"{output_size_per_partition} is not divisible by "
127
- f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
128
- "Consider reducing tensor_parallel_size or running "
129
- "with --quantization gptq."
130
- )
131
 
132
  # Validate input_size_per_partition
133
  if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
134
- raise ValueError(
135
- f"Weight input_size_per_partition = "
136
- f"{input_size_per_partition} is not divisible "
137
- f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
138
- "Consider reducing tensor_parallel_size or running "
139
- "with --quantization gptq."
140
- )
141
-
142
- if group_size < input_size and input_size_per_partition % group_size != 0:
143
  raise ValueError(
144
  f"Weight input_size_per_partition = {input_size_per_partition}"
145
- f" is not divisible by group_size = {group_size}."
146
  "Consider reducing tensor_parallel_size or running "
147
- "with --quantization gptq."
148
- )
149
 
150
 
151
- def check_marlin_supports_shape(
152
- output_size_per_partition: int,
153
- input_size_per_partition: int,
154
- input_size: int,
155
- group_size: int,
156
- ) -> Tuple[bool, Optional[str]]:
157
  try:
158
- verify_marlin_supports_shape(
159
- output_size_per_partition, input_size_per_partition, input_size, group_size
160
- )
161
  except ValueError as e:
162
  return False, e.__str__()
163
  return True, None
164
 
165
 
166
- def marlin_make_workspace(
167
- output_size_per_partition: int, device: torch.device
168
- ) -> torch.Tensor:
169
- max_workspace_size = (
170
- output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N
171
- ) * GPTQ_MARLIN_MAX_PARALLEL
172
 
173
- return torch.zeros(
174
- max_workspace_size, dtype=torch.int, device=device, requires_grad=False
175
- )
176
 
177
 
178
  def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
179
  return (not act_order) or (act_order and not is_row_parallel)
180
 
181
 
182
- def marlin_repeat_scales_on_all_ranks(
183
- act_order: bool, group_size: int, is_row_parallel: bool
184
- ) -> bool:
185
  # Need to repeat scales on every rank if act_ordering or
186
  # channelwise and RowParallelLinear
187
  is_channelwise = group_size == -1
@@ -189,35 +198,34 @@ def marlin_repeat_scales_on_all_ranks(
189
 
190
 
191
  def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
192
- return torch.nn.Parameter(
193
- torch.empty(0, dtype=torch.int, device=device), requires_grad=False
194
- )
195
 
196
 
197
  def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
198
- return torch.nn.Parameter(
199
- torch.empty(0, dtype=torch.int, device=device), requires_grad=False
200
- )
201
 
202
 
203
- def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
 
204
  g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
205
  return g_idx[g_idx_sort_indices], g_idx_sort_indices
206
 
207
 
208
  def get_scale_perms():
209
- scale_perm: List[int] = []
210
  for i in range(8):
211
  scale_perm.extend([i + 8 * j for j in range(8)])
212
- scale_perm_single: List[int] = []
213
  for i in range(4):
214
- scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
 
215
  return scale_perm, scale_perm_single
216
 
217
 
218
- def marlin_permute_scales(
219
- s: torch.Tensor, size_k: int, size_n: int, group_size: int
220
- ) -> torch.Tensor:
221
 
222
  scale_perm, scale_perm_single = get_scale_perms()
223
  if group_size < size_k and group_size != -1:
@@ -247,9 +255,8 @@ def marlin_moe_permute_scales(
247
  return output
248
 
249
 
250
- def marlin_zero_points(
251
- zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
252
- ) -> torch.Tensor:
253
  # Permute zero-points in a similar way to scales, but do not use the
254
  # "single" permutation, since zero-points are applied on every MMA
255
  scale_perm, _ = get_scale_perms()
@@ -270,9 +277,8 @@ def marlin_zero_points(
270
  return zp
271
 
272
 
273
- def awq_to_marlin_zero_points(
274
- q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
275
- ) -> torch.Tensor:
276
  # AWQ zero-points are quantized and packed on the column dim.
277
  # In addition, the values are permuted based on dequantizer.
278
  # Here we undo both of these, and then apply marlin permutation
@@ -294,9 +300,8 @@ def awq_to_marlin_zero_points(
294
  return marlin_zp
295
 
296
 
297
- def moe_awq_to_marlin_zero_points(
298
- q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
299
- ):
300
  num_experts = q_zp_packed.shape[0]
301
  output = torch.empty(
302
  (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]),
@@ -304,45 +309,97 @@ def moe_awq_to_marlin_zero_points(
304
  dtype=q_zp_packed.dtype,
305
  )
306
  for e in range(num_experts):
307
- output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits)
 
308
  return output
309
 
310

311
  def apply_gptq_marlin_linear(
312
- input: torch.Tensor,
313
- weight: torch.Tensor,
314
- weight_scale: torch.Tensor,
315
- weight_zp: torch.Tensor,
316
- g_idx: torch.Tensor,
317
- g_idx_sort_indices: torch.Tensor,
318
- workspace: torch.Tensor,
319
- wtype: ScalarType,
320
- output_size_per_partition: int,
321
- input_size_per_partition: int,
322
- is_k_full: bool,
323
- bias: Optional[torch.Tensor] = None,
324
- use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
325
- ) -> torch.Tensor:
326
  reshaped_x = input.reshape(-1, input.shape[-1])
327
- out_shape = input.shape[:-1] + (output_size_per_partition,)
328
-
329
- output = ops.gptq_marlin_gemm(
330
- reshaped_x,
331
- weight,
332
- weight_scale,
333
- weight_zp,
334
- g_idx,
335
- g_idx_sort_indices,
336
- workspace,
337
- wtype,
338
- size_m=reshaped_x.shape[0],
339
- size_n=output_size_per_partition,
340
- size_k=input_size_per_partition,
341
- is_k_full=is_k_full,
342
- has_zp=False,
343
- use_fp32_reduce=use_fp32_reduce,
344
- is_zp_float=False,
345
- )
 
 
 
 
 
 
346
 
347
  if bias is not None:
348
  output.add_(bias) # In-place add
@@ -351,39 +408,43 @@ def apply_gptq_marlin_linear(
351
 
352
 
353
  def apply_awq_marlin_linear(
354
- input: torch.Tensor,
355
- weight: torch.Tensor,
356
- weight_scale: torch.Tensor,
357
- weight_zp: torch.Tensor,
358
- g_idx: torch.Tensor,
359
- g_idx_sort_indices: torch.Tensor,
360
- workspace: torch.Tensor,
361
- quant_type: ScalarType,
362
- output_size_per_partition: int,
363
- input_size_per_partition: int,
364
- bias: Optional[torch.Tensor] = None,
365
- use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
366
- ) -> torch.Tensor:
367
  reshaped_x = input.reshape(-1, input.shape[-1])
368
- out_shape = input.shape[:-1] + (output_size_per_partition,)
369
-
370
- output = ops.gptq_marlin_gemm(
371
- reshaped_x,
372
- weight,
373
- weight_scale,
374
- weight_zp,
375
- g_idx,
376
- g_idx_sort_indices,
377
- workspace,
378
- quant_type,
379
- size_m=reshaped_x.shape[0],
380
- size_n=output_size_per_partition,
381
- size_k=input_size_per_partition,
382
- is_k_full=True,
383
- has_zp=True,
384
- use_fp32_reduce=use_fp32_reduce,
385
- is_zp_float=False,
386
- )
 
 
 
 
 
387
 
388
  if bias is not None:
389
  output.add_(bias) # In-place add
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+
4
+ from typing import Optional
5
 
6
  import numpy
7
  import torch
 
45
  # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
46
  # TODO: we may want to move this into the C++ so its closer to the actual impl
47
  def query_marlin_supported_quant_types(
48
+ has_zp: Optional[bool] = None,
49
+ include_fp_type: bool = True,
50
+ device_capability: Optional[int] = None,
51
  ):
52
  if device_capability is None:
53
  capability_tuple = torch.cuda.get_device_capability()
 
56
  if device_capability < 80:
57
  return []
58
 
59
+ # - has_zp is True: return quant_types that has zero points
60
+ # - has_zp is False: return quant_types that has not zero points
61
+ # - has_zp is None: both
62
+ if has_zp is None:
63
+ types0 = query_marlin_supported_quant_types(False, include_fp_type,
64
+ device_capability)
65
+ types1 = query_marlin_supported_quant_types(True, include_fp_type,
66
+ device_capability)
67
+ return types0 + types1
68
+
69
  if has_zp:
70
  # AWQ style, unsigned + runtime zero-point
71
+ return [scalar_types.uint4]
72
  else:
73
  # GPTQ style, unsigned + symmetric bias
74
+ res = [scalar_types.uint4b8, scalar_types.uint8b128]
75
+ if include_fp_type:
76
+ res += [scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f]
77
+ return res
78
 
79
 
80
  def _check_marlin_supported(
81
+ quant_type: ScalarType,
82
+ group_size: Optional[int],
83
+ has_zp: bool,
84
+ device_capability: Optional[int] = None) -> tuple[bool, Optional[str]]:
 
85
 
86
  if device_capability is None:
87
  capability_tuple = torch.cuda.get_device_capability()
88
  device_capability = capability_tuple[0] * 10 + capability_tuple[1]
89
 
90
+ supported_types = query_marlin_supported_quant_types(
91
+ has_zp, True, device_capability)
92
 
93
  if quant_type not in supported_types:
94
+ return (False, f"Marlin does not support weight_bits = {quant_type}. "
95
+ f"Only types = {supported_types} "
96
+ f"are supported (for group_size = {group_size}, "
97
+ f"device_capability = {device_capability}, zp = {has_zp}).")
98
+ if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES):
99
+ return (False, f"Marlin does not support group_size = {group_size}. "
100
+ f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
101
+ "are supported.")
 
 
 
 
 
 
102
 
103
  return True, None
104
 
105
 
106
+ def check_marlin_supported(quant_type: ScalarType,
107
+ group_size: int,
108
+ has_zp: bool = False,
109
+ device_capability: Optional[int] = None) -> bool:
110
+ cond, _ = _check_marlin_supported(quant_type, group_size, has_zp,
111
+ device_capability)
 
112
  return cond
113
 
114
 
115
+ def verify_marlin_supported(quant_type: ScalarType,
116
+ group_size: int,
117
+ has_zp: bool = False) -> None:
118
  cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp)
119
  if not cond:
120
  assert err_msg is not None
121
  raise ValueError(err_msg)
122
 
123
 
124
+ def verify_marlin_supports_shape(output_size_per_partition: int,
125
+ input_size_per_partition: int,
126
+ input_size: int, group_size: int) -> None:
 
 
 
127
 
128
  # Validate output_size_per_partition
129
  if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
130
+ raise ValueError(f"Weight output_size_per_partition = "
131
+ f"{output_size_per_partition} is not divisible by "
132
+ f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
133
+ "Consider reducing tensor_parallel_size or running "
134
+ "with --quantization gptq.")
 
 
135
 
136
  # Validate input_size_per_partition
137
  if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
138
+ raise ValueError(f"Weight input_size_per_partition = "
139
+ f"{input_size_per_partition} is not divisible "
140
+ f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
141
+ "Consider reducing tensor_parallel_size or running "
142
+ "with --quantization gptq.")
143
+
144
+ if (group_size < input_size
145
+ and input_size_per_partition % group_size != 0):
 
146
  raise ValueError(
147
  f"Weight input_size_per_partition = {input_size_per_partition}"
148
+ f" is not divisible by group_size = {group_size}. "
149
  "Consider reducing tensor_parallel_size or running "
150
+ "with --quantization gptq.")
 
151
 
152
 
153
+ def check_marlin_supports_shape(output_size_per_partition: int,
154
+ input_size_per_partition: int,
155
+ input_size: int, group_size: int) \
156
+ -> tuple[bool, Optional[str]]:
 
 
157
  try:
158
+ verify_marlin_supports_shape(output_size_per_partition,
159
+ input_size_per_partition, input_size,
160
+ group_size)
161
  except ValueError as e:
162
  return False, e.__str__()
163
  return True, None
164
 
165
 
166
+ def marlin_make_workspace(output_size_per_partition: int,
167
+ device: torch.device) -> torch.Tensor:
168
+ max_workspace_size = (output_size_per_partition //
169
+ GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL
 
 
170
 
171
+ return torch.zeros(max_workspace_size,
172
+ dtype=torch.int,
173
+ device=device,
174
+ requires_grad=False)
175
+
176
+
177
+ def marlin_make_workspace_new(device: torch.device,
178
+ max_blocks_per_sm: int = 1) -> torch.Tensor:
179
+ # In the new marlin kernel, we use the num of threadblocks as workspace
180
+ # size. The num of threadblocks is sms_count * max_blocks_per_sm.
181
+ sms = torch.cuda.get_device_properties(device).multi_processor_count
182
+ return torch.zeros(sms * max_blocks_per_sm,
183
+ dtype=torch.int,
184
+ device=device,
185
+ requires_grad=False)
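A quick illustration of the sizing rule above (a CUDA device is assumed):

import torch

dev = torch.device("cuda")
ws = marlin_make_workspace_new(dev, max_blocks_per_sm=4)
assert ws.numel() == 4 * torch.cuda.get_device_properties(dev).multi_processor_count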
186
 
187
 
188
  def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
189
  return (not act_order) or (act_order and not is_row_parallel)
190
 
191
 
192
+ def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int,
193
+ is_row_parallel: bool) -> bool:
 
194
  # Need to repeat scales on every rank if act_ordering or
195
  # channelwise and RowParallelLinear
196
  is_channelwise = group_size == -1
 
198
 
199
 
200
  def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
201
+ return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
202
+ requires_grad=False)
 
203
 
204
 
205
  def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
206
+ return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
207
+ requires_grad=False)
 
208
 
209
 
210
+ def marlin_sort_g_idx(
211
+ g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
212
  g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
213
  return g_idx[g_idx_sort_indices], g_idx_sort_indices
214
 
215
 
216
  def get_scale_perms():
217
+ scale_perm: list[int] = []
218
  for i in range(8):
219
  scale_perm.extend([i + 8 * j for j in range(8)])
220
+ scale_perm_single: list[int] = []
221
  for i in range(4):
222
+ scale_perm_single.extend(
223
+ [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
224
  return scale_perm, scale_perm_single
225
 
226
 
227
+ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
228
+ group_size: int) -> torch.Tensor:
 
229
 
230
  scale_perm, scale_perm_single = get_scale_perms()
231
  if group_size < size_k and group_size != -1:
 
255
  return output
256
 
257
 
258
+ def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int,
259
+ num_bits: int) -> torch.Tensor:
 
260
  # Permute zero-points in a similar way to scales, but do not use the
261
  # "single" permutation, since zero-points are applied on every MMA
262
  scale_perm, _ = get_scale_perms()
 
277
  return zp
278
 
279
 
280
+ def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
281
+ size_n: int, num_bits: int) -> torch.Tensor:
 
282
  # AWQ zero-points are quantized and packed on the column dim.
283
  # In addition, the values are permuted based on the dequantizer.
284
  # Here we undo both of these, and then apply marlin permutation
 
300
  return marlin_zp
301
 
302
 
303
+ def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
304
+ size_n: int, num_bits: int):
 
305
  num_experts = q_zp_packed.shape[0]
306
  output = torch.empty(
307
  (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]),
 
309
  dtype=q_zp_packed.dtype,
310
  )
311
  for e in range(num_experts):
312
+ output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n,
313
+ num_bits)
314
  return output
315
 
316
 
317
+ def maybe_warn_marlin_atomic_add(device, dtype):
318
+ if torch.compiler.is_dynamo_compiling():
319
+ return
320
+ device_capability = torch.cuda.get_device_capability(device)
321
+ if device_capability[0] < 9 and dtype == torch.bfloat16:
322
+ logger.info_once(
323
+ "You are running Marlin kernel with bf16 on GPUs before SM90. "
324
+ "You can consider change to fp16 to achieve better performance "
325
+ "if possible.")
326
+
327
+
328
+ def maybe_warn_marlin_atomic_add_env():
329
+ if torch.compiler.is_dynamo_compiling():
330
+ return
331
+ if envs.VLLM_MARLIN_USE_ATOMIC_ADD:
332
+ return
333
+ logger.info_once(
334
+ "Marlin kernel can achieve better performance for small size_n "
335
+ "with experimental use_atomic_add feature. "
336
+ "You can consider set environment variable "
337
+ "VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible.")
338
+
339
+
340
+ def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device,
341
+ dtype: torch.dtype) -> bool:
342
+
343
+ # atomicAdd outperforms the global reduce only when m*n is small
344
+ # and k is large
345
+ if n >= 2048 or k < 2048 or device.type != "cuda":
346
+ return False
347
+
348
+ # disable atomicAdd reduce by default,
349
+ # one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1
350
+ if not envs.VLLM_MARLIN_USE_ATOMIC_ADD:
351
+ maybe_warn_marlin_atomic_add_env()
352
+ return False
353
+
354
+ # sm8x doesn't support atomicAdd + bfloat16 natively
355
+ device_capability = torch.cuda.get_device_capability(device)
356
+ if device_capability[0] < 9 and dtype == torch.bfloat16:
357
+ maybe_warn_marlin_atomic_add(device, dtype)
358
+ return False
359
+
360
+ return True
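A hedged sketch of the heuristic above: atomicAdd is only considered for small n and large k on CUDA, and is still gated by VLLM_MARLIN_USE_ATOMIC_ADD and by bf16 support (SM90+). The import path is assumed from this build:

import torch
from quantization.utils.marlin_utils import should_use_atomic_add_reduce

if torch.cuda.is_available():
    dev = torch.device("cuda:0")
    # small n, large k: a candidate, but False unless VLLM_MARLIN_USE_ATOMIC_ADD=1
    print(should_use_atomic_add_reduce(m=16, n=1024, k=8192, device=dev,
                                       dtype=torch.float16))
    # large n: the global reduce is preferred, so always False
    print(should_use_atomic_add_reduce(m=16, n=4096, k=8192, device=dev,
                                       dtype=torch.float16))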
361
+
362
+
363
  def apply_gptq_marlin_linear(
364
+ input: torch.Tensor,
365
+ weight: torch.Tensor,
366
+ weight_scale: torch.Tensor,
367
+ weight_zp: torch.Tensor,
368
+ g_idx: torch.Tensor,
369
+ g_idx_sort_indices: torch.Tensor,
370
+ workspace: torch.Tensor,
371
+ wtype: ScalarType,
372
+ output_size_per_partition: int,
373
+ input_size_per_partition: int,
374
+ is_k_full: bool,
375
+ bias: Optional[torch.Tensor] = None,
376
+ use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
 
377
  reshaped_x = input.reshape(-1, input.shape[-1])
378
+ out_shape = input.shape[:-1] + (output_size_per_partition, )
379
+
380
+ use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
381
+ n=output_size_per_partition,
382
+ k=reshaped_x.size(1),
383
+ device=input.device,
384
+ dtype=input.dtype)
385
+
386
+ output = ops.gptq_marlin_gemm(reshaped_x,
387
+ None,
388
+ weight,
389
+ weight_scale,
390
+ None,
391
+ weight_zp,
392
+ g_idx,
393
+ g_idx_sort_indices,
394
+ workspace,
395
+ wtype,
396
+ size_m=reshaped_x.shape[0],
397
+ size_n=output_size_per_partition,
398
+ size_k=input_size_per_partition,
399
+ is_k_full=is_k_full,
400
+ use_atomic_add=use_atomic_add,
401
+ use_fp32_reduce=use_fp32_reduce,
402
+ is_zp_float=False)
403
 
404
  if bias is not None:
405
  output.add_(bias) # In-place add
 
408
 
409
 
410
  def apply_awq_marlin_linear(
411
+ input: torch.Tensor,
412
+ weight: torch.Tensor,
413
+ weight_scale: torch.Tensor,
414
+ weight_zp: torch.Tensor,
415
+ g_idx: torch.Tensor,
416
+ g_idx_sort_indices: torch.Tensor,
417
+ workspace: torch.Tensor,
418
+ quant_type: ScalarType,
419
+ output_size_per_partition: int,
420
+ input_size_per_partition: int,
421
+ bias: Optional[torch.Tensor] = None,
422
+ use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
 
423
  reshaped_x = input.reshape(-1, input.shape[-1])
424
+ out_shape = input.shape[:-1] + (output_size_per_partition, )
425
+
426
+ use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
427
+ n=output_size_per_partition,
428
+ k=reshaped_x.size(1),
429
+ device=input.device,
430
+ dtype=input.dtype)
431
+
432
+ output = ops.gptq_marlin_gemm(reshaped_x,
433
+ None,
434
+ weight,
435
+ weight_scale,
436
+ None,
437
+ weight_zp,
438
+ g_idx,
439
+ g_idx_sort_indices,
440
+ workspace,
441
+ quant_type,
442
+ size_m=reshaped_x.shape[0],
443
+ size_n=output_size_per_partition,
444
+ size_k=input_size_per_partition,
445
+ use_atomic_add=use_atomic_add,
446
+ use_fp32_reduce=use_fp32_reduce,
447
+ is_zp_float=False)
448
 
449
  if bias is not None:
450
  output.add_(bias) # In-place add
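Editorial note: apply_gptq_marlin_linear and apply_awq_marlin_linear share the same reshape contract; the activation is flattened to [M, K] for the kernel and the result is reshaped back with N appended. A standalone sketch with hypothetical sizes:

import torch

K, N = 4096, 11008                       # hypothetical partition sizes
x = torch.randn(2, 3, K)                 # arbitrary leading dims
reshaped_x = x.reshape(-1, x.shape[-1])  # [6, K] fed to gptq_marlin_gemm
out_shape = x.shape[:-1] + (N, )         # (2, 3, N) after the final reshape
print(reshaped_x.shape, out_shape)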
build/torch27-cxx11-cu128-aarch64-linux/quantization/utils/marlin_utils_fp4.py ADDED
@@ -0,0 +1,282 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+
8
+ import quantization as ops
9
+
10
+ from .marlin_utils import (
11
+ USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales,
12
+ should_use_atomic_add_reduce)
13
+ from quantization.scalar_type import scalar_types
14
+
15
+ FP4_MARLIN_SUPPORTED_GROUP_SIZES = [16]
16
+
17
+
18
+ def is_fp4_marlin_supported():
19
+ capability = torch.cuda.get_device_capability()
20
+ capability = capability[0] * 10 + capability[1]
21
+ return capability >= 80
22
+
23
+
24
+ def fp4_marlin_process_scales(marlin_scales):
25
+ if not (marlin_scales >= 0).all():
26
+ logger.warning_once(
27
+ "NVFP4 Marlin assumes the scales to be >=0, but has encountered "
28
+ "negative scales. Accuracy will likely be degraded. This is "
29
+ "because it changes the scales from FP8-S1E4M3 to a special "
30
+ "FP8-S0E5M3 format to speedup the dequantization.")
31
+
32
+ # convert to half first; we will convert to fp8 later
33
+ marlin_scales = marlin_scales.to(torch.half)
34
+
35
+ # 8 is the number of scales used by one thread
36
+ marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8)
37
+ marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape(
38
+ marlin_scales.size(0) * 2, -1)
39
+
40
+ # fit the layout of fp8 dequantization
41
+ marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view(
42
+ marlin_scales.size(0), -1)
43
+
44
+ # We assume that weight_scale (FP8-S1E4M3) is always greater
45
+ # than or equal to 0. So we can convert
46
+ # weight_scale * (2 ** 7) to a special FP8-S0E5M3 format.
47
+ # After multiplying by 2 ** 7, the top bit of FP8-S0E5M3 would always be 1
48
+ # when weight_scale > 0. This allows us to have an exponent bias
49
+ # closer to zero after dequantization.
50
+
51
+ marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1
52
+ marlin_scales = marlin_scales.view(torch.float8_e4m3fn)
53
+ marlin_scales = marlin_scales[:, 1::2].contiguous()
54
+
55
+ return marlin_scales
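A shape-only editorial sketch of the scale processing above (import path assumed; the reshaping requires size_k // 16 to be even and size_n to be a multiple of 8): the group scales keep their (size_k // 16, size_n) layout but come back reinterpreted as float8_e4m3fn.

import torch
from quantization.utils.marlin_utils_fp4 import fp4_marlin_process_scales  # assumed path

size_k, size_n = 64, 32
scales = torch.rand(size_k // 16, size_n, dtype=torch.half)   # >= 0, as required
processed = fp4_marlin_process_scales(scales)
print(processed.dtype, processed.shape)   # torch.float8_e4m3fn torch.Size([4, 32])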
56
+
57
+
58
+ def fp4_marlin_process_global_scale(global_scale):
59
+ assert global_scale.dtype in [torch.half, torch.bfloat16]
60
+ fp4_exponent = 2
61
+ if global_scale.dtype == torch.half:
62
+ target_exponent = 5
63
+ elif global_scale.dtype == torch.bfloat16:
64
+ target_exponent = 8
65
+ # exponent_bias_fp16 = 2 ** 4 - 2 ** 1 = 14
66
+ # exponent_bias_bf16 = 2 ** 7 - 2 ** 1 = 126
67
+ exponent_bias = 2**(target_exponent - 1) - 2**(fp4_exponent - 1)
68
+ return global_scale * (2.0**(exponent_bias - 7))
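Worked arithmetic for the global-scale adjustment above (editorial sketch, import path assumed): for an fp16 global scale, exponent_bias = 2**4 - 2**1 = 14, so the scale is multiplied by 2**(14 - 7) = 128.

import torch
from quantization.utils.marlin_utils_fp4 import fp4_marlin_process_global_scale

g = torch.tensor([0.01], dtype=torch.half)
print(fp4_marlin_process_global_scale(g))   # ~1.28 = 0.01 * 128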
69
+
70
+
71
+ def apply_fp4_marlin_linear(
72
+ input: torch.Tensor,
73
+ weight: torch.Tensor,
74
+ weight_scale: torch.Tensor,
75
+ weight_scale_2: torch.Tensor,
76
+ workspace: torch.Tensor,
77
+ size_n: int,
78
+ size_k: int,
79
+ bias: Optional[torch.Tensor] = None,
80
+ use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
81
+ # For GPUs that lack FP4 hardware support, we can leverage the
82
+ # Marlin kernel for fast weight-only FP4 quantization
83
+
84
+ reshaped_x = input.reshape(-1, input.shape[-1])
85
+ out_shape = input.shape[:-1] + (size_n, )
86
+
87
+ use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
88
+ n=size_n,
89
+ k=size_k,
90
+ device=input.device,
91
+ dtype=input.dtype)
92
+
93
+ output = ops.gptq_marlin_gemm(a=reshaped_x,
94
+ c=None,
95
+ b_q_weight=weight,
96
+ b_scales=weight_scale,
97
+ global_scale=weight_scale_2,
98
+ b_zeros=None,
99
+ g_idx=None,
100
+ perm=None,
101
+ workspace=workspace,
102
+ b_q_type=scalar_types.float4_e2m1f,
103
+ size_m=reshaped_x.size(0),
104
+ size_n=size_n,
105
+ size_k=size_k,
106
+ use_atomic_add=use_atomic_add,
107
+ use_fp32_reduce=use_fp32_reduce)
108
+
109
+ if bias is not None:
110
+ output.add_(bias) # In-place add
111
+
112
+ return output.reshape(out_shape)
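An end-to-end editorial sketch tying the pieces together with the test helper rand_marlin_weight_fp4_like defined later in this file; it needs a CUDA device and the built _quantization extension, and the import paths are assumed:

import torch
from quantization.utils.marlin_utils import marlin_make_workspace_new
from quantization.utils.marlin_utils_fp4 import (apply_fp4_marlin_linear,
                                                 rand_marlin_weight_fp4_like)

if torch.cuda.is_available():
    size_n, size_k, group_size = 256, 512, 16
    w = torch.randn(size_n, size_k, dtype=torch.half, device="cuda")
    w_ref, qweight, scales, global_scale = rand_marlin_weight_fp4_like(w, group_size)
    x = torch.randn(8, size_k, dtype=torch.half, device="cuda")
    workspace = marlin_make_workspace_new(x.device)
    y = apply_fp4_marlin_linear(x, qweight, scales, global_scale, workspace,
                                size_n=size_n, size_k=size_k)
    print((y - x @ w_ref).abs().max())   # small quantization error expected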
113
+
114
+
115
+ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
116
+ logger.warning_once(
117
+ "Your GPU does not have native support for FP4 computation but "
118
+ "FP4 quantization is being used. Weight-only FP4 compression will "
119
+ "be used leveraging the Marlin kernel. This may degrade "
120
+ "performance for compute-heavy workloads.")
121
+
122
+ part_size_n = layer.output_size_per_partition
123
+ part_size_k = layer.input_size_per_partition
124
+ param_dtype = layer.params_dtype
125
+
126
+ assert layer.weight.shape == (part_size_n, part_size_k // 2)
127
+
128
+ device = layer.weight.device
129
+
130
+ # WORKSPACE
131
+ layer.workspace = marlin_make_workspace_new(device)
132
+
133
+ # WEIGHT
134
+ # Repack weights to marlin format
135
+ perm = torch.empty(0, dtype=torch.int, device=device)
136
+ qweight = layer.weight.view(torch.int32).T.contiguous()
137
+
138
+ marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight,
139
+ perm=perm,
140
+ size_k=part_size_k,
141
+ size_n=part_size_n,
142
+ num_bits=4)
143
+ layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
144
+
145
+ # WEIGHT SCALES
146
+ # Permute scales
147
+ weight_scale = layer.weight_scale.T.to(param_dtype)
148
+ weight_scale = marlin_permute_scales(s=weight_scale,
149
+ size_k=part_size_k,
150
+ size_n=part_size_n,
151
+ group_size=16)
152
+ weight_scale = fp4_marlin_process_scales(weight_scale)
153
+ layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
154
+
155
+ weight_scale_2 = layer.weight_scale_2.to(param_dtype)
156
+ weight_scale_2 = fp4_marlin_process_global_scale(weight_scale_2)
157
+ layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2,
158
+ requires_grad=False)
159
+
160
+ return
161
+
162
+
163
+ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
164
+ logger.warning_once(
165
+ "Your GPU does not have native support for FP4 computation but "
166
+ "FP4 quantization is being used. Weight-only FP4 compression will "
167
+ "be used leveraging the Marlin kernel. This may degrade "
168
+ "performance for compute-heavy workloads.")
169
+
170
+ e = layer.num_experts
171
+ k = layer.hidden_size
172
+ n = layer.intermediate_size_per_partition
173
+
174
+ # WORKSPACE
175
+ device = layer.w13_weight.device
176
+ param_dtype = layer.params_dtype
177
+ layer.workspace = marlin_make_workspace_new(device, 4)
178
+ perm = torch.empty(0, dtype=torch.int, device=device)
179
+
180
+ # WEIGHT
181
+ # Repack weights to marlin format
182
+ for name in ["w13_weight", "w2_weight"]:
183
+ weight = getattr(layer, name)
184
+ tensor_list = []
185
+ if "w13" in name:
186
+ size_n, size_k = n * 2, k
187
+ else:
188
+ size_n, size_k = k, n
189
+
190
+ assert weight.shape == (e, size_n, size_k // 2)
191
+
192
+ for i in range(e):
193
+ qweight = weight[i].view(torch.int32).T.contiguous()
194
+
195
+ marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight,
196
+ perm=perm,
197
+ size_k=size_k,
198
+ size_n=size_n,
199
+ num_bits=4)
200
+ tensor_list.append(marlin_qweight)
201
+
202
+ weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
203
+ weight = torch.nn.Parameter(weight, requires_grad=False)
204
+
205
+ setattr(layer, name, weight)
206
+
207
+ # WEIGHT SCALES
208
+ # Permute scales
209
+ for name in ["w13", "w2"]:
210
+ scales = getattr(layer, name + "_weight_scale").to(param_dtype)
211
+ global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype)
212
+
213
+ tensor_list = []
214
+ if "w13" in name:
215
+ size_n, size_k = n * 2, k
216
+ else:
217
+ size_n, size_k = k, n
218
+
219
+ for i in range(e):
220
+ marlin_scales = marlin_permute_scales(s=scales[i].T,
221
+ size_k=size_k,
222
+ size_n=size_n,
223
+ group_size=16)
224
+ marlin_scales = fp4_marlin_process_scales(marlin_scales)
225
+ tensor_list.append(marlin_scales)
226
+
227
+ scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
228
+ scales = torch.nn.Parameter(scales, requires_grad=False)
229
+ setattr(layer, name + "_weight_scale", scales)
230
+
231
+ global_scale = fp4_marlin_process_global_scale(global_scale)
232
+ global_scale = torch.nn.Parameter(global_scale, requires_grad=False)
233
+ setattr(layer, name + "_weight_scale_2", global_scale)
234
+
235
+
236
+ def rand_marlin_weight_fp4_like(weight, group_size):
237
+ assert group_size > 0
238
+ size_n, size_k = weight.shape
239
+ device = weight.device
240
+
241
+ scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 6
242
+ global_scale = scales.max() / 448
243
+ scales = (scales / global_scale).to(torch.float8_e4m3fn)
244
+
245
+ fp4_weight = torch.randint(0,
246
+ 256, (size_n, size_k // 2),
247
+ dtype=torch.uint8,
248
+ device=weight.device)
249
+ fp4_weight_part_1 = ((fp4_weight & 0b10000000) |
250
+ ((fp4_weight & 0b01110000) >> 2))
251
+ fp4_weight_part_1 = fp4_weight_part_1.view(torch.float8_e4m3fn)
252
+ fp4_weight_part_1 = fp4_weight_part_1.to(weight.dtype) * (2**6)
253
+
254
+ fp4_weight2 = fp4_weight << 4
255
+ fp4_weight_part_2 = ((fp4_weight2 & 0b10000000) |
256
+ ((fp4_weight2 & 0b01110000) >> 2))
257
+ fp4_weight_part_2 = fp4_weight_part_2.view(torch.float8_e4m3fn)
258
+ fp4_weight_part_2 = fp4_weight_part_2.to(weight.dtype) * (2**6)
259
+
260
+ weight_ref = torch.cat(
261
+ [fp4_weight_part_2.unsqueeze(2),
262
+ fp4_weight_part_1.unsqueeze(2)], 2).view(size_n, size_k)
263
+ weight_ref = weight_ref * global_scale.to(weight.dtype) * \
264
+ scales.repeat_interleave(group_size, 1).to(weight.dtype)
265
+
266
+ marlin_qweight = ops.gptq_marlin_repack(
267
+ b_q_weight=fp4_weight.view(torch.int32).T.contiguous(),
268
+ perm=torch.empty(0, dtype=torch.int, device=device),
269
+ size_k=size_k,
270
+ size_n=size_n,
271
+ num_bits=4,
272
+ )
273
+
274
+ marlin_scales = marlin_permute_scales(s=scales.T.to(weight.dtype),
275
+ size_k=size_k,
276
+ size_n=size_n,
277
+ group_size=group_size)
278
+ marlin_scales = fp4_marlin_process_scales(marlin_scales)
279
+
280
+ global_scale = fp4_marlin_process_global_scale(global_scale)
281
+
282
+ return weight_ref.T, marlin_qweight, marlin_scales, global_scale
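The bit manipulation above maps each e2m1 nibble into an FP8-E4M3 byte and rescales by 2**6 to undo the exponent-bias difference. A tiny standalone check of that mapping for one value (editorial sketch):

import torch

nibble_hi = 0b0011_0000   # e2m1 value 1.5 stored in the high nibble of a byte
byte = torch.tensor([nibble_hi], dtype=torch.uint8)
fp8 = ((byte & 0b10000000) | ((byte & 0b01110000) >> 2)).view(torch.float8_e4m3fn)
print(fp8.to(torch.float32) * (2 ** 6))   # tensor([1.5]), matching the e2m1 encoding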
build/torch27-cxx11-cu128-aarch64-linux/quantization/utils/marlin_utils_fp8.py CHANGED
@@ -1,10 +1,13 @@
 
 
 
1
  from typing import Optional
2
 
3
  import torch
4
 
5
  import quantization as ops
6
 
7
- from .marlin_utils import marlin_make_workspace, marlin_permute_scales
8
 
9
 
10
  def is_fp8_marlin_supported():
@@ -13,88 +16,107 @@ def is_fp8_marlin_supported():
13
  return capability >= 80
14
 
15
16
  def apply_fp8_marlin_linear(
17
- input: torch.Tensor,
18
- weight: torch.Tensor,
19
- weight_scale: torch.Tensor,
20
- workspace: torch.Tensor,
21
- size_n: int,
22
- size_k: int,
23
- bias: Optional[torch.Tensor],
24
- ) -> torch.Tensor:
25
  # For GPUs that lack FP8 hardware support, we can leverage the
26
  # Marlin kernel for fast weight-only FP8 quantization
27
 
28
  reshaped_x = input.reshape(-1, input.shape[-1])
29
- out_shape = input.shape[:-1] + (size_n,)
30
-
31
- output = ops.fp8_marlin_gemm(
32
- a=reshaped_x,
33
- b_q_weight=weight,
34
- b_scales=weight_scale,
35
- workspace=workspace,
36
- num_bits=8,
37
- size_m=reshaped_x.shape[0],
38
- size_n=size_n,
39
- size_k=size_k,
40
- )
 
42
  if bias is not None:
43
  output.add_(bias) # In-place add
44
 
45
  return output.reshape(out_shape)
46
 
47
-
48
- def prepare_fp8_layer_for_marlin(
49
- layer: torch.nn.Module, strategy: str = "tensor"
50
- ) -> None:
51
- part_size_n = layer.output_size_per_partition
52
- part_size_k = layer.input_size_per_partition
53
-
54
- device = layer.weight.device
55
-
56
- # WORKSPACE
57
- layer.workspace = marlin_make_workspace(part_size_n, device)
58
-
59
- # WEIGHT
60
- # Repack weights to marlin format
61
- marlin_qweight = ops.gptq_marlin_repack(
62
- b_q_weight=pack_fp8_to_int32(layer.weight),
63
- perm=torch.empty(0, dtype=torch.int, device=device),
64
- size_k=part_size_k,
65
- size_n=part_size_n,
66
- num_bits=8,
67
- )
68
- layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
69
-
70
- # WEIGHT SCALES
71
- scales = layer.weight_scale.to(layer.orig_dtype)
72
- # Permute scales
73
- marlin_scales = marlin_permute_scales(
74
- s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1
75
- )
76
- layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
77
-
78
-
79
- def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
80
  """
81
  Repack FP8 weights to gptq format (packed int32 elements)
82
  """
83
  assert fp8_tensor.dtype == torch.float8_e4m3fn
84
- assert fp8_tensor.shape[0] % 4 == 0
85
-
86
- # Reshape to prepare for packing
87
- reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
88
 
89
- # Convert fp8 to uint8 (byte) representation
90
- byte_tensor = reshaped.view(torch.uint8)
 
 
91
 
92
- # Pack 4 uint8 values into one int32
93
- packed = (
94
- byte_tensor[:, 0].to(torch.int32)
95
- | (byte_tensor[:, 1].to(torch.int32) << 8)
96
- | (byte_tensor[:, 2].to(torch.int32) << 16)
97
- | (byte_tensor[:, 3].to(torch.int32) << 24)
98
- )
99
 
100
- return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous()
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+
4
  from typing import Optional
5
 
6
  import torch
7
 
8
  import quantization as ops
9
 
10
+ from .marlin_utils import (USE_FP32_REDUCE_DEFAULT, marlin_make_workspace,
+ marlin_permute_scales, should_use_atomic_add_reduce)
+ from quantization.scalar_type import scalar_types
11
 
12
 
13
  def is_fp8_marlin_supported():
 
16
  return capability >= 80
17
 
18
 
19
+ def fp8_fused_exponent_bias_into_scales(scales):
20
+ fp8_exponent = 4
21
+ if scales.dtype == torch.half:
22
+ target_exponent = 5
23
+ elif scales.dtype == torch.bfloat16:
24
+ target_exponent = 8
25
+ # exponent_bias_fp16 = 2 ** 4 - 2 ** 3 = 8
26
+ # exponent_bias_bf16 = 2 ** 7 - 2 ** 3 = 120
27
+ exponent_bias = 2**(target_exponent - 1) - 2**(fp8_exponent - 1)
28
+ s = torch.ones_like(scales) * 2
29
+ s = s**exponent_bias
30
+ return scales * s
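A worked example for the scale fusion above with fp16 scales (editorial sketch, import path assumed): exponent_bias = 2**4 - 2**3 = 8, so every scale is multiplied by 2**8 = 256.

import torch
from quantization.utils.marlin_utils_fp8 import fp8_fused_exponent_bias_into_scales

s = torch.tensor([0.5, 1.0], dtype=torch.half)
print(fp8_fused_exponent_bias_into_scales(s))   # tensor([128., 256.], dtype=torch.float16)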
31
+
32
+
33
  def apply_fp8_marlin_linear(
34
+ input: torch.Tensor,
35
+ weight: torch.Tensor,
36
+ weight_scale: torch.Tensor,
37
+ workspace: torch.Tensor,
38
+ size_n: int,
39
+ size_k: int,
40
+ bias: Optional[torch.Tensor],
41
+ use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
42
  # For GPUs that lack FP8 hardware support, we can leverage the
43
  # Marlin kernel for fast weight-only FP8 quantization
44
 
45
  reshaped_x = input.reshape(-1, input.shape[-1])
46
+ out_shape = input.shape[:-1] + (size_n, )
47
+
48
+ use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
49
+ n=size_n,
50
+ k=size_k,
51
+ device=input.device,
52
+ dtype=input.dtype)
53
+
54
+ output = ops.gptq_marlin_gemm(a=reshaped_x,
55
+ c=None,
56
+ b_q_weight=weight,
57
+ b_scales=weight_scale,
58
+ global_scale=None,
59
+ b_zeros=None,
60
+ g_idx=None,
61
+ perm=None,
62
+ workspace=workspace,
63
+ b_q_type=scalar_types.float8_e4m3fn,
64
+ size_m=reshaped_x.size(0),
65
+ size_n=size_n,
66
+ size_k=size_k,
67
+ use_atomic_add=use_atomic_add,
68
+ use_fp32_reduce=use_fp32_reduce)
69
 
70
  if bias is not None:
71
  output.add_(bias) # In-place add
72
 
73
  return output.reshape(out_shape)
74
 
75
+ def pack_fp8_to_int32(fp8_tensor: torch.Tensor,
76
+ size_k_first: bool = True) -> torch.Tensor:
77
  """
78
  Repack FP8 weights to gptq format (packed int32 elements)
79
  """
80
  assert fp8_tensor.dtype == torch.float8_e4m3fn
81
+ assert fp8_tensor.ndim == 2
82
+
83
+ fp8_tensor = fp8_tensor.T if size_k_first else fp8_tensor
84
+ fp8_tensor = fp8_tensor.contiguous()
85
+ # fp8_tensor is contiguous and have shape (N, K) now
86
+ # with `.view(torch.int32)`, it become (N, K // 4)
87
+ int32_tensor = fp8_tensor.view(torch.int32)
88
+ return int32_tensor.T.contiguous() if size_k_first else int32_tensor
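A shape sketch for the repacking above (editorial, import path assumed): with size_k_first=True, a (K, N) FP8 tensor packs four K-values into each int32, giving a (K // 4, N) result.

import torch
from quantization.utils.marlin_utils_fp8 import pack_fp8_to_int32

K, N = 8, 4
t = torch.randn(K, N).to(torch.float8_e4m3fn)
packed = pack_fp8_to_int32(t, size_k_first=True)
print(packed.dtype, packed.shape)   # torch.int32 torch.Size([2, 4])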
89
+
90
+
91
+ def marlin_quant_fp8_torch(weight, group_size):
92
+ size_n, size_k = weight.shape
93
+ device = weight.device
94
+
95
+ if group_size != -1:
96
+ scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 448
97
+ repeated_scales = scales.repeat_interleave(group_size, 1)
98
+ fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn)
99
+ weight_ref = fp8_weight.to(weight.dtype) * repeated_scales
100
+ else:
101
+ scales = weight.view(size_n, 1, group_size).abs().max(-1)[0] / 448
102
+ repeated_scales = scales.repeat_interleave(size_k, 1)
103
+ fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn)
104
+ weight_ref = fp8_weight.to(weight.dtype) * repeated_scales
105
+
106
+ packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous()
107
+ marlin_qweight = ops.gptq_marlin_repack(
108
+ b_q_weight=packed_weight,
109
+ perm=torch.empty(0, dtype=torch.int, device=device),
110
+ size_k=size_k,
111
+ size_n=size_n,
112
+ num_bits=8,
113
+ )
114
 
115
+ marlin_scales = marlin_permute_scales(s=scales.T,
116
+ size_k=size_k,
117
+ size_n=size_n,
118
+ group_size=group_size)
119
 
120
+ marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
121
 
122
+ return weight_ref.T, marlin_qweight, marlin_scales
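A hedged end-to-end sketch using the reference quantizer above together with apply_fp8_marlin_linear; it needs a CUDA device and the built kernel, and the import paths are assumed:

import torch
from quantization.utils.marlin_utils import marlin_make_workspace_new
from quantization.utils.marlin_utils_fp8 import (apply_fp8_marlin_linear,
                                                 marlin_quant_fp8_torch)

if torch.cuda.is_available():
    size_n, size_k = 256, 512
    w = torch.randn(size_n, size_k, dtype=torch.half, device="cuda")
    w_ref, qweight, scales = marlin_quant_fp8_torch(w, group_size=-1)  # -1 = channelwise
    x = torch.randn(4, size_k, dtype=torch.half, device="cuda")
    workspace = marlin_make_workspace_new(x.device)
    y = apply_fp8_marlin_linear(x, qweight, scales, workspace,
                                size_n=size_n, size_k=size_k, bias=None)
    print((y - x @ w_ref).abs().max())   # small quantization error expected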