danieldk (HF staff) committed
Commit a6c77d7 · 1 Parent(s): 165b25c
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. build/torch24-cxx11-cu118-x86_64-linux/quantization/__init__.py +30 -150
  2. build/torch24-cxx11-cu118-x86_64-linux/quantization/_ops.py +6 -0
  3. build/torch24-cxx11-cu118-x86_64-linux/quantization/_quantization_0_0_1.abi3.so +2 -2
  4. build/torch24-cxx11-cu118-x86_64-linux/quantization/compressed_tensors.py +110 -0
  5. build/torch24-cxx11-cu118-x86_64-linux/quantization/cutlass.py +75 -0
  6. build/torch24-cxx11-cu118-x86_64-linux/quantization/marlin.py +208 -0
  7. build/torch24-cxx11-cu118-x86_64-linux/quantization/scalar_type.py +330 -0
  8. build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/__init__.py +0 -0
  9. build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils.py +391 -0
  10. build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_fp8.py +100 -0
  11. build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test.py +162 -0
  12. build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test_24.py +473 -0
  13. build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py +125 -0
  14. build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/quant_utils.py +470 -0
  15. build/torch24-cxx11-cu121-x86_64-linux/quantization/__init__.py +30 -150
  16. build/torch24-cxx11-cu121-x86_64-linux/quantization/_ops.py +6 -0
  17. build/torch24-cxx11-cu121-x86_64-linux/quantization/_quantization_0_0_1.abi3.so +2 -2
  18. build/torch24-cxx11-cu121-x86_64-linux/quantization/compressed_tensors.py +110 -0
  19. build/torch24-cxx11-cu121-x86_64-linux/quantization/cutlass.py +75 -0
  20. build/torch24-cxx11-cu121-x86_64-linux/quantization/marlin.py +208 -0
  21. build/torch24-cxx11-cu121-x86_64-linux/quantization/scalar_type.py +330 -0
  22. build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/__init__.py +0 -0
  23. build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils.py +391 -0
  24. build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_fp8.py +100 -0
  25. build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test.py +162 -0
  26. build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test_24.py +473 -0
  27. build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py +125 -0
  28. build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/quant_utils.py +470 -0
  29. build/torch24-cxx11-cu124-x86_64-linux/quantization/__init__.py +30 -150
  30. build/torch24-cxx11-cu124-x86_64-linux/quantization/_ops.py +6 -0
  31. build/torch24-cxx11-cu124-x86_64-linux/quantization/_quantization_0_0_1.abi3.so +2 -2
  32. build/torch24-cxx11-cu124-x86_64-linux/quantization/compressed_tensors.py +110 -0
  33. build/torch24-cxx11-cu124-x86_64-linux/quantization/cutlass.py +75 -0
  34. build/torch24-cxx11-cu124-x86_64-linux/quantization/marlin.py +208 -0
  35. build/torch24-cxx11-cu124-x86_64-linux/quantization/scalar_type.py +330 -0
  36. build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/__init__.py +0 -0
  37. build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils.py +391 -0
  38. build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_fp8.py +100 -0
  39. build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test.py +162 -0
  40. build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test_24.py +473 -0
  41. build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py +125 -0
  42. build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/quant_utils.py +470 -0
  43. build/torch24-cxx98-cu118-x86_64-linux/quantization/__init__.py +30 -150
  44. build/torch24-cxx98-cu118-x86_64-linux/quantization/_ops.py +6 -0
  45. build/torch24-cxx98-cu118-x86_64-linux/quantization/_quantization_0_0_1.abi3.so +2 -2
  46. build/torch24-cxx98-cu118-x86_64-linux/quantization/compressed_tensors.py +110 -0
  47. build/torch24-cxx98-cu118-x86_64-linux/quantization/cutlass.py +75 -0
  48. build/torch24-cxx98-cu118-x86_64-linux/quantization/marlin.py +208 -0
  49. build/torch24-cxx98-cu118-x86_64-linux/quantization/scalar_type.py +330 -0
  50. build/torch24-cxx98-cu118-x86_64-linux/quantization/utils/__init__.py +0 -0
build/torch24-cxx11-cu118-x86_64-linux/quantization/__init__.py CHANGED
@@ -1,150 +1,30 @@
- from typing import Optional, Tuple
-
- import torch
-
- try:
-     from ._ops import ops
- except ImportError as e:
-     # Fallback for local development.
-     try:
-         import _quantization
-         ops = torch.ops._quantization
-     except ImportError:
-         raise e
-
- def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
-     return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
-
- def cutlass_scaled_mm(a: torch.Tensor,
-                       b: torch.Tensor,
-                       scale_a: torch.Tensor,
-                       scale_b: torch.Tensor,
-                       out_dtype: torch.dtype,
-                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-     assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
-     assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
-     assert bias is None or bias.shape[0] == b.shape[
-         1] and bias.dtype == out_dtype
-
-     m = a.shape[0]
-     n = b.shape[1]
-
-     #if current_platform.is_rocm():
-     #    triton_scaled_mm_module = importlib.import_module(
-     #        "vllm.model_executor.layers.quantization.compressed_tensors."
-     #        "triton_scaled_mm")
-     #    triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
-     #    return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
-
-     out = torch.empty((m, n), dtype=out_dtype, device=a.device)
-
-     ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
-
-     return out
-
- # fp8
- def scaled_fp8_quant(
-     input: torch.Tensor,
-     scale: Optional[torch.Tensor] = None,
-     num_token_padding: Optional[int] = None,
-     scale_ub: Optional[torch.Tensor] = None,
-     use_per_token_if_dynamic: bool = False,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
-     """
-     Quantize input tensor to FP8 and return quantized tensor and scale.
-
-     This function supports both static and dynamic quantization: If you
-     provide the scale, it will use static scaling and if you omit it,
-     the scale will be determined dynamically. The function also allows
-     optional padding of the output tensors for downstream kernels that
-     will benefit from padding.
-
-     Args:
-         input: The input tensor to be quantized to FP8
-         scale: Optional scaling factor for the FP8 quantization
-         scale_ub: Optional upper bound for scaling factor in dynamic
-             per token case
-         num_token_padding: If specified, pad the first dimension
-             of the output to at least this value.
-         use_per_token_if_dynamic: Whether to do per_tensor or per_token
-             in the dynamic quantization case.
-
-     Returns:
-         Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
-             scaling factor.
-     """
-     # This code assumes batch_dim and num_tokens are flattened
-     assert (input.ndim == 2)
-     shape: Union[Tuple[int, int], torch.Size] = input.shape
-     # For rocm, the output fp8 dtype is torch.float_e3m3fnuz
-     #out_dtype: torch.dtype = torch.float8_e4m3fnuz \
-     #    if current_platform.is_rocm() else torch.float8_e4m3fn
-     out_dtype = torch.float8_e4m3fn
-     if num_token_padding:
-         shape = (max(num_token_padding, input.shape[0]), shape[1])
-     output = torch.empty(shape, device=input.device, dtype=out_dtype)
-
-     if scale is None:
-         if use_per_token_if_dynamic:
-             scale = torch.empty((shape[0], 1),
-                                 device=input.device,
-                                 dtype=torch.float32)
-             ops.dynamic_per_token_scaled_fp8_quant(
-                 output, input, scale, scale_ub)
-         else:
-             scale = torch.zeros(1, device=input.device, dtype=torch.float32)
-             ops.dynamic_scaled_fp8_quant(output, input, scale)
-     else:
-         # num_token_padding not implemented for this case
-         assert (scale.numel() == 1 or num_token_padding is None)
-         ops.static_scaled_fp8_quant(output, input, scale)
-
-     return output, scale
-
- # int8
- def scaled_int8_quant(
-     input: torch.Tensor,
-     scale: Optional[torch.Tensor] = None,
-     azp: Optional[torch.Tensor] = None,
-     symmetric: bool = True
- ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-     """
-     Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.
-
-     Args:
-         input: The input tensor to be quantized to int8.
-         scale: Optional scaling factor for the int8 quantization.
-             When not provided, we invoke dynamic-per-token quantization.
-         azp: Optional zero-point for the int8 quantization.
-             Must be provided for asymmetric quantization if `scale` is provided.
-         symmetric: Whether to use symmetric quantization (scale only, azp ignored).
-
-     Returns:
-         Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
-     """
-     output = torch.empty_like(input, dtype=torch.int8)
-     if scale is not None:
-         # static-per-tensor quantization.
-         assert symmetric == (
-             azp is
-             None), "azp must only be provided for asymmetric quantization."
-         ops.static_scaled_int8_quant(output, input, scale, azp)
-         return output, scale, azp
-
-     # dynamic-per-token quantization.
-     input_scales = torch.empty((input.numel() // input.shape[-1], 1),
-                                device=input.device,
-                                dtype=torch.float32)
-     input_azp = None if symmetric else torch.empty_like(input_scales,
-                                                         dtype=torch.int32)
-     ops.dynamic_scaled_int8_quant(output, input, input_scales,
-                                   input_azp)
-     return output, input_scales, input_azp
-
- # fp8 marlin
- def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
-                     b_scales: torch.Tensor, workspace: torch.Tensor,
-                     num_bits: int, size_m: int, size_n: int,
-                     size_k: int) -> torch.Tensor:
-     return ops.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace,
-                                num_bits, size_m, size_n, size_k)
+ from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant
+ from .cutlass import (
+     cutlass_scaled_mm_supports_fp8,
+     cutlass_scaled_mm,
+     cutlass_scaled_mm_azp,
+ )
+ from .marlin import (
+     awq_marlin_repack,
+     fp8_marlin_gemm,
+     gptq_marlin_gemm,
+     gptq_marlin_repack,
+     gptq_marlin_24_gemm,
+     marlin_qqq_gemm,
+     marlin_gemm,
+ )
+
+ __all__ = [
+     "awq_marlin_repack",
+     "cutlass_scaled_mm",
+     "cutlass_scaled_mm_azp",
+     "cutlass_scaled_mm_supports_fp8",
+     "fp8_marlin_gemm",
+     "gptq_marlin_24_gemm",
+     "gptq_marlin_gemm",
+     "gptq_marlin_repack",
+     "marlin_gemm",
+     "marlin_qqq_gemm",
+     "scaled_fp8_quant",
+     "scaled_int8_quant",
+ ]
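The refactor keeps the package's flat public API: everything that used to live directly in __init__.py is now re-exported from the new compressed_tensors, cutlass, and marlin submodules, so existing callers should be unaffected. A minimal usage sketch (the tensor shapes and the "quantization" import name, which the bundled utils themselves use, are assumptions rather than part of this commit):

import torch
import quantization as ops  # the built kernel package; import name assumed

x = torch.randn(16, 128, device="cuda", dtype=torch.float16)

# Dynamic per-tensor FP8 quantization: the scale is computed from the data.
x_fp8, x_scale = ops.scaled_fp8_quant(x)

# Dynamic per-token INT8 quantization: symmetric by default, so azp is None.
x_i8, x_i8_scales, azp = ops.scaled_int8_quant(x)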
build/torch24-cxx11-cu118-x86_64-linux/quantization/_ops.py CHANGED
@@ -1,3 +1,9 @@
import torch
from . import _quantization_0_0_1
ops = torch.ops._quantization_0_0_1
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix op by namespace.
+     """
+     return f"_quantization_0_0_1::{op_name}"
build/torch24-cxx11-cu118-x86_64-linux/quantization/_quantization_0_0_1.abi3.so CHANGED
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
- oid sha256:d8b08406547ecaf4b08409b5c8a5144ac0f91faac6c28dcfa6938dd75470db34
- size 70296128
+ oid sha256:35967133ffbd0cac32aafc9e70f441264b1f41710f4f86d68723d2eb9a59cfe8
+ size 85717704
build/torch24-cxx11-cu118-x86_64-linux/quantization/compressed_tensors.py ADDED
@@ -0,0 +1,110 @@
+ from typing import Optional, Tuple
+
+ import torch
+
+ try:
+     from ._ops import ops
+ except ImportError as e:
+     # Fallback for local development.
+     try:
+         import _quantization
+
+         ops = torch.ops._quantization
+     except ImportError:
+         raise e
+
+
+ # fp8
+ def scaled_fp8_quant(
+     input: torch.Tensor,
+     scale: Optional[torch.Tensor] = None,
+     num_token_padding: Optional[int] = None,
+     scale_ub: Optional[torch.Tensor] = None,
+     use_per_token_if_dynamic: bool = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """
+     Quantize input tensor to FP8 and return quantized tensor and scale.
+
+     This function supports both static and dynamic quantization: If you
+     provide the scale, it will use static scaling and if you omit it,
+     the scale will be determined dynamically. The function also allows
+     optional padding of the output tensors for downstream kernels that
+     will benefit from padding.
+
+     Args:
+         input: The input tensor to be quantized to FP8
+         scale: Optional scaling factor for the FP8 quantization
+         scale_ub: Optional upper bound for scaling factor in dynamic
+             per token case
+         num_token_padding: If specified, pad the first dimension
+             of the output to at least this value.
+         use_per_token_if_dynamic: Whether to do per_tensor or per_token
+             in the dynamic quantization case.
+
+     Returns:
+         Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
+             scaling factor.
+     """
+     # This code assumes batch_dim and num_tokens are flattened
+     assert input.ndim == 2
+     shape: Union[Tuple[int, int], torch.Size] = input.shape
+     # For rocm, the output fp8 dtype is torch.float_e3m3fnuz
+     # out_dtype: torch.dtype = torch.float8_e4m3fnuz \
+     #     if current_platform.is_rocm() else torch.float8_e4m3fn
+     out_dtype = torch.float8_e4m3fn
+     if num_token_padding:
+         shape = (max(num_token_padding, input.shape[0]), shape[1])
+     output = torch.empty(shape, device=input.device, dtype=out_dtype)
+
+     if scale is None:
+         if use_per_token_if_dynamic:
+             scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
+             ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub)
+         else:
+             scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+             ops.dynamic_scaled_fp8_quant(output, input, scale)
+     else:
+         # num_token_padding not implemented for this case
+         assert scale.numel() == 1 or num_token_padding is None
+         ops.static_scaled_fp8_quant(output, input, scale)
+
+     return output, scale
+
+
+ # int8
+ def scaled_int8_quant(
+     input: torch.Tensor,
+     scale: Optional[torch.Tensor] = None,
+     azp: Optional[torch.Tensor] = None,
+     symmetric: bool = True,
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+     """
+     Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.
+
+     Args:
+         input: The input tensor to be quantized to int8.
+         scale: Optional scaling factor for the int8 quantization.
+             When not provided, we invoke dynamic-per-token quantization.
+         azp: Optional zero-point for the int8 quantization.
+             Must be provided for asymmetric quantization if `scale` is provided.
+         symmetric: Whether to use symmetric quantization (scale only, azp ignored).
+
+     Returns:
+         Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
+     """
+     output = torch.empty_like(input, dtype=torch.int8)
+     if scale is not None:
+         # static-per-tensor quantization.
+         assert symmetric == (
+             azp is None
+         ), "azp must only be provided for asymmetric quantization."
+         ops.static_scaled_int8_quant(output, input, scale, azp)
+         return output, scale, azp
+
+     # dynamic-per-token quantization.
+     input_scales = torch.empty(
+         (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32
+     )
+     input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32)
+     ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp)
+     return output, input_scales, input_azp
build/torch24-cxx11-cu118-x86_64-linux/quantization/cutlass.py ADDED
@@ -0,0 +1,75 @@
+ from typing import Optional
+
+ import torch
+
+ try:
+     from ._ops import ops
+ except ImportError as e:
+     # Fallback for local development.
+     try:
+         import _quantization
+
+         ops = torch.ops._quantization
+     except ImportError:
+         raise e
+
+
+ def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
+     return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
+
+
+ def cutlass_scaled_mm(
+     a: torch.Tensor,
+     b: torch.Tensor,
+     scale_a: torch.Tensor,
+     scale_b: torch.Tensor,
+     out_dtype: torch.dtype,
+     bias: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+     assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
+     assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
+     assert bias is None or bias.shape[0] == b.shape[1] and bias.dtype == out_dtype
+
+     m = a.shape[0]
+     n = b.shape[1]
+
+     # if current_platform.is_rocm():
+     #     triton_scaled_mm_module = importlib.import_module(
+     #         "vllm.model_executor.layers.quantization.compressed_tensors."
+     #         "triton_scaled_mm")
+     #     triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
+     #     return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+
+     out = torch.empty((m, n), dtype=out_dtype, device=a.device)
+
+     ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
+
+     return out
+
+
+ def cutlass_scaled_mm_azp(
+     a: torch.Tensor,
+     b: torch.Tensor,
+     scale_a: torch.Tensor,
+     scale_b: torch.Tensor,
+     out_dtype: torch.dtype,
+     azp_adj: torch.Tensor,
+     azp: Optional[torch.Tensor] = None,
+     bias: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+     """
+     :param azp_adj: In the per-tensor case, this should include the azp.
+         Always per-channel.
+     :param azp: Only set in the per-token case. Per-token if set.
+     """
+     assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
+     assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
+     assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype
+     assert azp is None or azp.numel() == a.shape[0]
+
+     m = a.shape[0]
+     n = b.shape[1]
+     out = torch.empty((m, n), dtype=out_dtype, device=a.device)
+
+     ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias)
+     return out
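Together with scaled_fp8_quant from compressed_tensors.py above, the scaled-matmul entry point can be exercised roughly as follows. This is an illustrative sketch only: the shapes and the column-major weight layout implied by w_q.t() are assumptions, not something this diff documents.

import torch
from quantization import scaled_fp8_quant, cutlass_scaled_mm

a = torch.randn(32, 128, device="cuda", dtype=torch.float16)   # activations (M, K)
w = torch.randn(256, 128, device="cuda", dtype=torch.float16)  # weights (N, K)

a_q, a_scale = scaled_fp8_quant(a)  # dynamic per-tensor float32 scales
w_q, w_scale = scaled_fp8_quant(w)

# b is passed as (K, N); both dimensions must be multiples of 16 per the asserts.
out = cutlass_scaled_mm(a_q, w_q.t(), a_scale, w_scale, out_dtype=torch.float16)
assert out.shape == (32, 256)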
build/torch24-cxx11-cu118-x86_64-linux/quantization/marlin.py ADDED
@@ -0,0 +1,208 @@
1
+ from typing import TYPE_CHECKING
2
+
3
+ import torch
4
+
5
+ # neuron has torch version that doesn't even have impl_abstract
6
+ if TYPE_CHECKING:
7
+ def register_fake(fn):
8
+ return lambda name: fn
9
+ else:
10
+ try:
11
+ from torch.library import register_fake
12
+ except ImportError:
13
+ from torch.library import impl_abstract as register_fake
14
+
15
+ try:
16
+ from ._ops import ops, add_op_namespace_prefix
17
+ except ImportError as e:
18
+ # Fallback for local development.
19
+ try:
20
+ import _quantization
21
+
22
+ ops = torch.ops._quantization
23
+
24
+ def add_op_namespace_prefix(op_name: str):
25
+ return f"_quantization::{op_name}"
26
+ except ImportError:
27
+ raise e
28
+
29
+
30
+ from .scalar_type import ScalarType
31
+
32
+
33
+ # fp8 marlin
34
+ def fp8_marlin_gemm(
35
+ a: torch.Tensor,
36
+ b_q_weight: torch.Tensor,
37
+ b_scales: torch.Tensor,
38
+ workspace: torch.Tensor,
39
+ num_bits: int,
40
+ size_m: int,
41
+ size_n: int,
42
+ size_k: int,
43
+ ) -> torch.Tensor:
44
+ return ops.fp8_marlin_gemm(
45
+ a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k
46
+ )
47
+
48
+
49
+ # gptq_marlin
50
+ def gptq_marlin_gemm(
51
+ a: torch.Tensor,
52
+ b_q_weight: torch.Tensor,
53
+ b_scales: torch.Tensor,
54
+ b_zeros: torch.Tensor,
55
+ g_idx: torch.Tensor,
56
+ perm: torch.Tensor,
57
+ workspace: torch.Tensor,
58
+ b_q_type: ScalarType,
59
+ size_m: int,
60
+ size_n: int,
61
+ size_k: int,
62
+ is_k_full: bool,
63
+ has_zp: bool = False,
64
+ use_fp32_reduce: bool = False,
65
+ is_zp_float: bool = False,
66
+ ) -> torch.Tensor:
67
+ return ops.gptq_marlin_gemm(
68
+ a,
69
+ b_q_weight,
70
+ b_scales,
71
+ b_zeros,
72
+ g_idx,
73
+ perm,
74
+ workspace,
75
+ b_q_type.id,
76
+ size_m,
77
+ size_n,
78
+ size_k,
79
+ is_k_full,
80
+ has_zp,
81
+ use_fp32_reduce,
82
+ is_zp_float,
83
+ )
84
+
85
+
86
+ # gptq_marlin
87
+ def gptq_marlin_repack(
88
+ b_q_weight: torch.Tensor,
89
+ perm: torch.Tensor,
90
+ size_k: int,
91
+ size_n: int,
92
+ num_bits: int,
93
+ ) -> torch.Tensor:
94
+ return ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits)
95
+
96
+
97
+ # gptq_marlin
98
+ def awq_marlin_repack(
99
+ b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
100
+ ) -> torch.Tensor:
101
+ return ops.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)
102
+
103
+
104
+ # marlin
105
+ def marlin_gemm(
106
+ a: torch.Tensor,
107
+ b_q_weight: torch.Tensor,
108
+ b_scales: torch.Tensor,
109
+ workspace: torch.Tensor,
110
+ size_m: int,
111
+ size_n: int,
112
+ size_k: int,
113
+ ) -> torch.Tensor:
114
+ return ops.marlin_gemm(
115
+ a, b_q_weight, b_scales, workspace, size_m, size_n, size_k
116
+ )
117
+
118
+
119
+ # marlin_24
120
+ def gptq_marlin_24_gemm(
121
+ a: torch.Tensor,
122
+ b_q_weight: torch.Tensor,
123
+ b_meta: torch.Tensor,
124
+ b_scales: torch.Tensor,
125
+ workspace: torch.Tensor,
126
+ b_q_type: ScalarType,
127
+ size_m: int,
128
+ size_n: int,
129
+ size_k: int,
130
+ ) -> torch.Tensor:
131
+ return ops.gptq_marlin_24_gemm(
132
+ a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, size_m, size_n, size_k
133
+ )
134
+
135
+
136
+ # qqq ops
137
+ def marlin_qqq_gemm(
138
+ a: torch.Tensor,
139
+ b_q_weight: torch.Tensor,
140
+ s_tok: torch.Tensor,
141
+ s_ch: torch.Tensor,
142
+ s_group: torch.Tensor,
143
+ workspace: torch.Tensor,
144
+ size_m: int,
145
+ size_n: int,
146
+ size_k: int,
147
+ ) -> torch.Tensor:
148
+ return ops.marlin_qqq_gemm(
149
+ a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k
150
+ )
151
+
152
+
153
+ # Fake ops
154
+
155
+ if hasattr(ops, "gptq_marlin_24_gemm"):
156
+ @register_fake(add_op_namespace_prefix("fp8_marlin_gemm"))
157
+ def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
158
+ b_scales: torch.Tensor, workspace: torch.Tensor,
159
+ num_bits: int, size_m: torch.SymInt,
160
+ size_n: torch.SymInt,
161
+ size_k: torch.SymInt) -> torch.Tensor:
162
+ return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)
163
+
164
+ @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm"))
165
+ def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
166
+ b_meta: torch.Tensor, b_scales: torch.Tensor,
167
+ workspace: torch.Tensor,
168
+ b_q_type: ScalarType, size_m: torch.SymInt,
169
+ size_n: torch.SymInt,
170
+ size_k: torch.SymInt) -> torch.Tensor:
171
+ return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
172
+
173
+ @register_fake(add_op_namespace_prefix("gptq_marlin_gemm"))
174
+ def _gptq_marlin_gemm_fake(a: torch.Tensor,
175
+ b_q_weight: torch.Tensor,
176
+ b_scales: torch.Tensor,
177
+ b_zeros: torch.Tensor,
178
+ g_idx: torch.Tensor,
179
+ perm: torch.Tensor,
180
+ workspace: torch.Tensor,
181
+ b_q_type: ScalarType,
182
+ size_m: torch.SymInt,
183
+ size_n: torch.SymInt,
184
+ size_k: torch.SymInt,
185
+ is_k_full: bool,
186
+ has_zp: bool = False,
187
+ use_fp32_reduce: bool = False,
188
+ is_zp_float: bool = False) -> torch.Tensor:
189
+ return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
190
+
191
+ @register_fake(add_op_namespace_prefix("marlin_qqq_gemm"))
192
+ def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
193
+ s_tok: torch.Tensor, s_ch: torch.Tensor,
194
+ s_group: torch.Tensor, workspace: torch.Tensor,
195
+ size_m: torch.SymInt, size_n: torch.SymInt,
196
+ size_k: torch.SymInt) -> torch.Tensor:
197
+ return torch.empty((size_m, size_n),
198
+ dtype=torch.float16,
199
+ device=a.device)
200
+
201
+ @register_fake(add_op_namespace_prefix("marlin_gemm"))
202
+ def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
203
+ b_scales: torch.Tensor, workspace: torch.Tensor,
204
+ size_m: torch.SymInt, size_n: torch.SymInt,
205
+ size_k: torch.SymInt) -> torch.Tensor:
206
+ return torch.empty((size_m, size_n),
207
+ dtype=torch.float16,
208
+ device=a.device)
build/torch24-cxx11-cu118-x86_64-linux/quantization/scalar_type.py ADDED
@@ -0,0 +1,330 @@
1
+ import functools
2
+ import struct
3
+ from dataclasses import dataclass
4
+ from enum import Enum
5
+ from typing import Optional, Union
6
+
7
+
8
+ # Mirrors enum in `core/scalar_type.hpp`
9
+ class NanRepr(Enum):
10
+ NONE = 0 # nans are not supported
11
+ IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s
12
+ EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s
13
+
14
+
15
+ # This ScalarType class is a parallel implementation of the C++ ScalarType
16
+ # class found in csrc/core/scalar_type.hpp. These two classes should be kept
17
+ # in sync until the inductor fully supports custom C++ classes.
18
+ @dataclass(frozen=True)
19
+ class ScalarType:
20
+ """
21
+ ScalarType can represent a wide range of floating point and integer
22
+ types, in particular it can be used to represent sub-byte data types
23
+ (something that torch.dtype currently does not support). It is also
24
+ capable of representing types with a bias, i.e.:
25
+ `stored_value = value + bias`,
26
+ this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
27
+ of 8). The implementation for this class can be found in
28
+ csrc/core/scalar_type.hpp, these type signatures should be kept in sync
29
+ with that file.
30
+ """
31
+
32
+ exponent: int
33
+ """
34
+ Number of bits in the exponent if this is a floating point type
35
+ (zero if this an integer type)
36
+ """
37
+
38
+ mantissa: int
39
+ """
40
+ Number of bits in the mantissa if this is a floating point type,
41
+ or the number bits representing an integer excluding the sign bit if
42
+ this an integer type.
43
+ """
44
+
45
+ signed: bool
46
+ "If the type is signed (i.e. has a sign bit)"
47
+
48
+ bias: int
49
+ """
50
+ bias used to encode the values in this scalar type
51
+ (value = stored_value - bias, default 0) for example if we store the
52
+ type as an unsigned integer with a bias of 128 then the value 0 will be
53
+ stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
54
+ """
55
+
56
+ _finite_values_only: bool = False
57
+ """
58
+ Private: if infs are supported, used `has_infs()` instead.
59
+ """
60
+
61
+ nan_repr: NanRepr = NanRepr.IEEE_754
62
+ """
63
+ How NaNs are represent in this scalar type, returns NanRepr value.
64
+ (not applicable for integer types)
65
+ """
66
+
67
+ def _floating_point_max_int(self) -> int:
68
+ assert (
69
+ self.mantissa <= 52 and self.exponent <= 11
70
+ ), f"Cannot represent max/min as a double for type {self.__str__()}"
71
+
72
+ max_mantissa = (1 << self.mantissa) - 1
73
+ if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
74
+ max_mantissa = max_mantissa - 1
75
+
76
+ max_exponent = (1 << self.exponent) - 2
77
+ if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN
78
+ or self.nan_repr == NanRepr.NONE):
79
+ assert (
80
+ self.exponent < 11
81
+ ), f"Cannot represent max/min as a double for type {self.__str__()}"
82
+ max_exponent = max_exponent + 1
83
+
84
+ # adjust the exponent to match that of a double
85
+ # for now we assume the exponent bias is the standard 2^(e-1) -1, (where
86
+ # e is the exponent bits), there is some precedent for non-standard
87
+ # biases, example `float8_e4m3b11fnuz` here:
88
+ # https://github.com/jax-ml/ml_dtypes but to avoid premature over
89
+ # complication we are just assuming the standard exponent bias until
90
+ # there is a need to support non-standard biases
91
+ exponent_bias = (1 << (self.exponent - 1)) - 1
92
+ exponent_bias_double = (1 << 10) - 1 # double e = 11
93
+
94
+ max_exponent_double = (max_exponent - exponent_bias +
95
+ exponent_bias_double)
96
+
97
+ # shift the mantissa and exponent into the proper positions for an
98
+ # IEEE double and bitwise-or them together.
99
+ return (max_mantissa <<
100
+ (52 - self.mantissa)) | (max_exponent_double << 52)
101
+
102
+ def _floating_point_max(self) -> float:
103
+ double_raw = self._floating_point_max_int()
104
+ return struct.unpack('!d', struct.pack('!Q', double_raw))[0]
105
+
106
+ def _raw_max(self) -> Union[int, float]:
107
+ if self.is_floating_point():
108
+ return self._floating_point_max()
109
+ else:
110
+ assert (self.size_bits < 64 or self.size_bits == 64
111
+ and self.is_signed()), "Cannot represent max as an int"
112
+ return (1 << self.mantissa) - 1
113
+
114
+ def _raw_min(self) -> Union[int, float]:
115
+ if self.is_floating_point():
116
+ assert self.is_signed(
117
+ ), "We currently assume all floating point types are signed"
118
+ sign_bit_double = 1 << 63
119
+
120
+ max_raw = self._floating_point_max_int()
121
+ min_raw = max_raw | sign_bit_double
122
+ return struct.unpack('!d', struct.pack('!Q', min_raw))[0]
123
+ else:
124
+ assert (not self.is_signed() or
125
+ self.size_bits <= 64), "Cannot represent min as a int64_t"
126
+
127
+ if self.is_signed():
128
+ return -(1 << (self.size_bits - 1))
129
+ else:
130
+ return 0
131
+
132
+ @functools.cached_property
133
+ def id(self) -> int:
134
+ """
135
+ Convert the ScalarType to an int which can be passed to pytorch custom
136
+ ops. This layout of the int must be kept in sync with the C++
137
+ ScalarType's from_id method.
138
+ """
139
+ val = 0
140
+ offset = 0
141
+
142
+ def or_and_advance(member, bit_width):
143
+ nonlocal val
144
+ nonlocal offset
145
+ bit_mask = (1 << bit_width) - 1
146
+ val = val | (int(member) & bit_mask) << offset
147
+ offset = offset + bit_width
148
+
149
+ or_and_advance(self.exponent, 8)
150
+ or_and_advance(self.mantissa, 8)
151
+ or_and_advance(self.signed, 1)
152
+ or_and_advance(self.bias, 32)
153
+ or_and_advance(self._finite_values_only, 1)
154
+ or_and_advance(self.nan_repr.value, 8)
155
+
156
+ assert offset <= 64, \
157
+ f"ScalarType fields too big {offset} to fit into an int64"
158
+
159
+ return val
160
+
161
+ @property
162
+ def size_bits(self) -> int:
163
+ return self.exponent + self.mantissa + int(self.signed)
164
+
165
+ def min(self) -> Union[int, float]:
166
+ """
167
+ Min representable value for this scalar type.
168
+ (accounting for bias if there is one)
169
+ """
170
+ return self._raw_min() - self.bias
171
+
172
+ def max(self) -> Union[int, float]:
173
+ """
174
+ Max representable value for this scalar type.
175
+ (accounting for bias if there is one)
176
+ """
177
+ return self._raw_max() - self.bias
178
+
179
+ def is_signed(self) -> bool:
180
+ """
181
+ If the type is signed (i.e. has a sign bit), same as `signed`
182
+ added for consistency with:
183
+ https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
184
+ """
185
+ return self.signed
186
+
187
+ def is_floating_point(self) -> bool:
188
+ "If the type is a floating point type"
189
+ return self.exponent != 0
190
+
191
+ def is_integer(self) -> bool:
192
+ "If the type is an integer type"
193
+ return self.exponent == 0
194
+
195
+ def has_bias(self) -> bool:
196
+ "If the type has a non-zero bias"
197
+ return self.bias != 0
198
+
199
+ def has_infs(self) -> bool:
200
+ "If the type is floating point and supports infinity"
201
+ return not self._finite_values_only
202
+
203
+ def has_nans(self) -> bool:
204
+ return self.nan_repr != NanRepr.NONE.value
205
+
206
+ def is_ieee_754(self) -> bool:
207
+ """
208
+ If the type is a floating point type that follows IEEE 754
209
+ conventions
210
+ """
211
+ return self.nan_repr == NanRepr.IEEE_754.value and \
212
+ not self._finite_values_only
213
+
214
+ def __str__(self) -> str:
215
+ """
216
+ naming generally follows: https://github.com/jax-ml/ml_dtypes
217
+ for floating point types (leading f) the scheme is:
218
+ `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
219
+ flags:
220
+ - no-flags: means it follows IEEE 754 conventions
221
+ - f: means finite values only (no infinities)
222
+ - n: means nans are supported (non-standard encoding)
223
+ for integer types the scheme is:
224
+ `[u]int<size_bits>[b<bias>]`
225
+ - if bias is not present it means its zero
226
+ """
227
+ if self.is_floating_point():
228
+ ret = "float" + str(self.size_bits) + "_e" + str(
229
+ self.exponent) + "m" + str(self.mantissa)
230
+
231
+ if not self.is_ieee_754():
232
+ if self._finite_values_only:
233
+ ret = ret + "f"
234
+ if self.nan_repr != NanRepr.NONE:
235
+ ret = ret + "n"
236
+
237
+ return ret
238
+ else:
239
+ ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
240
+ if self.has_bias():
241
+ ret = ret + "b" + str(self.bias)
242
+ return ret
243
+
244
+ def __repr__(self) -> str:
245
+ return "ScalarType." + self.__str__()
246
+
247
+ # __len__ needs to be defined (and has to throw TypeError) for pytorch's
248
+ # opcheck to work.
249
+ def __len__(self) -> int:
250
+ raise TypeError
251
+
252
+ #
253
+ # Convenience Constructors
254
+ #
255
+
256
+ @classmethod
257
+ def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
258
+ "Create a signed integer scalar type (size_bits includes sign-bit)."
259
+ ret = cls(0, size_bits - 1, True, bias if bias else 0)
260
+ ret.id # noqa B018: make sure the id is cached
261
+ return ret
262
+
263
+ @classmethod
264
+ def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
265
+ """Create a unsigned integer scalar type."""
266
+ ret = cls(0, size_bits, False, bias if bias else 0)
267
+ ret.id # noqa B018: make sure the id is cached
268
+ return ret
269
+
270
+ @classmethod
271
+ def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
272
+ """
273
+ Create a standard floating point type
274
+ (i.e. follows IEEE 754 conventions).
275
+ """
276
+ assert (mantissa > 0 and exponent > 0)
277
+ ret = cls(exponent, mantissa, True, 0)
278
+ ret.id # noqa B018: make sure the id is cached
279
+ return ret
280
+
281
+ @classmethod
282
+ def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
283
+ nan_repr: NanRepr) -> 'ScalarType':
284
+ """
285
+ Create a non-standard floating point type
286
+ (i.e. does not follow IEEE 754 conventions).
287
+ """
288
+ assert (mantissa > 0 and exponent > 0)
289
+ assert (nan_repr != NanRepr.IEEE_754), (
290
+ "use `float_IEEE754` constructor for floating point types that "
291
+ "follow IEEE 754 conventions")
292
+ ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
293
+ ret.id # noqa B018: make sure the id is cached
294
+ return ret
295
+
296
+
297
+ # naming generally follows: https://github.com/jax-ml/ml_dtypes
298
+ # for floating point types (leading f) the scheme is:
299
+ # `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
300
+ # flags:
301
+ # - no-flags: means it follows IEEE 754 conventions
302
+ # - f: means finite values only (no infinities)
303
+ # - n: means nans are supported (non-standard encoding)
304
+ # for integer types the scheme is:
305
+ # `[u]int<size_bits>[b<bias>]`
306
+ # - if bias is not present it means its zero
307
+
308
+
309
+ class scalar_types:
310
+ int4 = ScalarType.int_(4, None)
311
+ uint4 = ScalarType.uint(4, None)
312
+ int8 = ScalarType.int_(8, None)
313
+ uint8 = ScalarType.uint(8, None)
314
+ float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
315
+ float8_e5m2 = ScalarType.float_IEEE754(5, 2)
316
+ float16_e8m7 = ScalarType.float_IEEE754(8, 7)
317
+ float16_e5m10 = ScalarType.float_IEEE754(5, 10)
318
+
319
+ # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
320
+ float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)
321
+
322
+ # "gptq" types
323
+ uint2b2 = ScalarType.uint(2, 2)
324
+ uint3b4 = ScalarType.uint(3, 4)
325
+ uint4b8 = ScalarType.uint(4, 8)
326
+ uint8b128 = ScalarType.uint(8, 128)
327
+
328
+ # colloquial names
329
+ bfloat16 = float16_e8m7
330
+ float16 = float16_e5m10
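The convenience constructors at the bottom make the bias encoding described in the class docstring (stored_value = value + bias) easy to check. For example, the GPTQ-style uint4b8 type stores 4-bit unsigned values with a bias of 8, giving a representable range of [-8, 7]; a small illustrative check, not part of the commit:

from quantization.scalar_type import scalar_types

t = scalar_types.uint4b8               # ScalarType.uint(4, 8)
assert t.size_bits == 4
assert (t.min(), t.max()) == (-8, 7)   # stored 0..15, shifted down by the bias of 8
assert str(t) == "uint4b8"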
build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/__init__.py ADDED
File without changes
build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils.py ADDED
@@ -0,0 +1,391 @@
1
+ from typing import List, Optional, Tuple
2
+
3
+ import numpy
4
+ import torch
5
+
6
+ import quantization as ops
7
+ from quantization.scalar_type import ScalarType, scalar_types
8
+
9
+ from .quant_utils import pack_cols, unpack_cols
10
+
11
+ GPTQ_MARLIN_TILE = 16
12
+ GPTQ_MARLIN_MIN_THREAD_N = 64
13
+ GPTQ_MARLIN_MIN_THREAD_K = 128
14
+ GPTQ_MARLIN_MAX_PARALLEL = 16
15
+
16
+ GPTQ_MARLIN_24_TILE = 16
17
+ GPTQ_MARLIN_24_MIN_THREAD_N = 128
18
+ GPTQ_MARLIN_24_MIN_THREAD_K = 128
19
+ GPTQ_MARLIN_24_MAX_PARALLEL = 64
20
+
21
+ GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
22
+ GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
23
+
24
+ MARLIN_QQQ_TILE = 16
25
+ MARLIN_QQQ_MIN_THREAD_N = 64
26
+ MARLIN_QQQ_MIN_THREAD_K = 128
27
+ MARLIN_QQQ_MAX_PARALLEL = 16
28
+
29
+ MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]
30
+ MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128]
31
+ MARLIN_QQQ_SUPPORTED_SYM = [True]
32
+
33
+ MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
34
+
35
+ # In case there is a performance issue with Marlin, the variable below can be
36
+ # changed to False, which allows Marlin to perform global reductions in fp16
37
+ # precision (instead of fp32), and therefore, save on some memory movements.
38
+ USE_FP32_REDUCE_DEFAULT = True
39
+
40
+
41
+ # For binary size and compile time, we don't support the same types for with and
42
+ # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
43
+ # TODO: we may want to move this into the C++ so its closer to the actual impl
44
+ def query_marlin_supported_quant_types(
45
+ has_zp: bool, device_capability: Optional[int] = None
46
+ ):
47
+ if device_capability is None:
48
+ capability_tuple = torch.cuda.get_device_capability()
49
+ device_capability = capability_tuple[0] * 10 + capability_tuple[1]
50
+
51
+ if device_capability < 80:
52
+ return []
53
+
54
+ if has_zp:
55
+ # AWQ style, unsigned + runtime zero-point
56
+ return [scalar_types.uint4, scalar_types.uint8]
57
+ else:
58
+ # GPTQ style, unsigned + symmetric bias
59
+ # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able
60
+ # to add `scalar_types.float8_e4m3fn` here
61
+ return [scalar_types.uint4b8, scalar_types.uint8b128]
62
+
63
+
64
+ def _check_marlin_supported(
65
+ quant_type: ScalarType,
66
+ group_size: Optional[int],
67
+ has_zp: bool,
68
+ device_capability: Optional[int] = None,
69
+ ) -> Tuple[bool, Optional[str]]:
70
+
71
+ if device_capability is None:
72
+ capability_tuple = torch.cuda.get_device_capability()
73
+ device_capability = capability_tuple[0] * 10 + capability_tuple[1]
74
+
75
+ supported_types = query_marlin_supported_quant_types(has_zp, device_capability)
76
+
77
+ if quant_type not in supported_types:
78
+ return (
79
+ False,
80
+ f"Marlin does not support weight_bits = {quant_type}. "
81
+ f"Only types = {supported_types} "
82
+ f"are supported (for group_size = {group_size}, "
83
+ f"device_capability = {device_capability}, zp = {has_zp}).",
84
+ )
85
+ if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
86
+ return (
87
+ False,
88
+ f"Marlin does not support group_size = {group_size}. "
89
+ f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
90
+ "are supported.",
91
+ )
92
+
93
+ return True, None
94
+
95
+
96
+ def check_marlin_supported(
97
+ quant_type: ScalarType,
98
+ group_size: int,
99
+ has_zp: bool = False,
100
+ device_capability: Optional[int] = None,
101
+ ) -> bool:
102
+ cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability)
103
+ return cond
104
+
105
+
106
+ def verify_marlin_supported(
107
+ quant_type: ScalarType, group_size: int, has_zp: bool = False
108
+ ) -> None:
109
+ cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp)
110
+ if not cond:
111
+ assert err_msg is not None
112
+ raise ValueError(err_msg)
113
+
114
+
115
+ def verify_marlin_supports_shape(
116
+ output_size_per_partition: int,
117
+ input_size_per_partition: int,
118
+ input_size: int,
119
+ group_size: int,
120
+ ) -> None:
121
+
122
+ # Validate output_size_per_partition
123
+ if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
124
+ raise ValueError(
125
+ f"Weight output_size_per_partition = "
126
+ f"{output_size_per_partition} is not divisible by "
127
+ f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
128
+ "Consider reducing tensor_parallel_size or running "
129
+ "with --quantization gptq."
130
+ )
131
+
132
+ # Validate input_size_per_partition
133
+ if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
134
+ raise ValueError(
135
+ f"Weight input_size_per_partition = "
136
+ f"{input_size_per_partition} is not divisible "
137
+ f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
138
+ "Consider reducing tensor_parallel_size or running "
139
+ "with --quantization gptq."
140
+ )
141
+
142
+ if group_size < input_size and input_size_per_partition % group_size != 0:
143
+ raise ValueError(
144
+ f"Weight input_size_per_partition = {input_size_per_partition}"
145
+ f" is not divisible by group_size = {group_size}."
146
+ "Consider reducing tensor_parallel_size or running "
147
+ "with --quantization gptq."
148
+ )
149
+
150
+
151
+ def check_marlin_supports_shape(
152
+ output_size_per_partition: int,
153
+ input_size_per_partition: int,
154
+ input_size: int,
155
+ group_size: int,
156
+ ) -> Tuple[bool, Optional[str]]:
157
+ try:
158
+ verify_marlin_supports_shape(
159
+ output_size_per_partition, input_size_per_partition, input_size, group_size
160
+ )
161
+ except ValueError as e:
162
+ return False, e.__str__()
163
+ return True, None
164
+
165
+
166
+ def marlin_make_workspace(
167
+ output_size_per_partition: int, device: torch.device
168
+ ) -> torch.Tensor:
169
+ max_workspace_size = (
170
+ output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N
171
+ ) * GPTQ_MARLIN_MAX_PARALLEL
172
+
173
+ return torch.zeros(
174
+ max_workspace_size, dtype=torch.int, device=device, requires_grad=False
175
+ )
176
+
177
+
178
+ def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
179
+ return (not act_order) or (act_order and not is_row_parallel)
180
+
181
+
182
+ def marlin_repeat_scales_on_all_ranks(
183
+ act_order: bool, group_size: int, is_row_parallel: bool
184
+ ) -> bool:
185
+ # Need to repeat scales on every rank if act_ordering or
186
+ # channelwise and RowParallelLinear
187
+ is_channelwise = group_size == -1
188
+ return act_order or (is_channelwise and is_row_parallel)
189
+
190
+
191
+ def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
192
+ return torch.nn.Parameter(
193
+ torch.empty(0, dtype=torch.int, device=device), requires_grad=False
194
+ )
195
+
196
+
197
+ def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
198
+ return torch.nn.Parameter(
199
+ torch.empty(0, dtype=torch.int, device=device), requires_grad=False
200
+ )
201
+
202
+
203
+ def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
204
+ g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
205
+ return g_idx[g_idx_sort_indices], g_idx_sort_indices
206
+
207
+
208
+ def get_scale_perms():
209
+ scale_perm: List[int] = []
210
+ for i in range(8):
211
+ scale_perm.extend([i + 8 * j for j in range(8)])
212
+ scale_perm_single: List[int] = []
213
+ for i in range(4):
214
+ scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
215
+ return scale_perm, scale_perm_single
216
+
217
+
218
+ def marlin_permute_scales(
219
+ s: torch.Tensor, size_k: int, size_n: int, group_size: int
220
+ ) -> torch.Tensor:
221
+
222
+ scale_perm, scale_perm_single = get_scale_perms()
223
+ if group_size < size_k and group_size != -1:
224
+ s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
225
+ else:
226
+ s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
227
+ s = s.reshape((-1, size_n)).contiguous()
228
+
229
+ return s
230
+
231
+
232
+ def marlin_moe_permute_scales(
233
+ s: torch.Tensor,
234
+ size_k: int,
235
+ size_n: int,
236
+ group_size: int,
237
+ ):
238
+ num_experts = s.shape[0]
239
+ output = torch.empty(
240
+ (num_experts, s.shape[1], s.shape[2]),
241
+ device=s.device,
242
+ dtype=s.dtype,
243
+ )
244
+
245
+ for e in range(num_experts):
246
+ output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size)
247
+ return output
248
+
249
+
250
+ def marlin_zero_points(
251
+ zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
252
+ ) -> torch.Tensor:
253
+ # Permute zero-points in a similar way to scales, but do not use the
254
+ # "single" permutation, since zero-points are applied on every MMA
255
+ scale_perm, _ = get_scale_perms()
256
+ zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm]
257
+
258
+ # Interleave column dim (for the dequantize code) and pack it to int32
259
+ if num_bits == 4:
260
+ interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
261
+ elif num_bits == 8:
262
+ interleave = numpy.array([0, 2, 1, 3])
263
+ else:
264
+ raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
265
+
266
+ zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
267
+ zp = zp.reshape((-1, size_n)).contiguous()
268
+ zp = pack_cols(zp, num_bits, size_k, size_n)
269
+
270
+ return zp
271
+
272
+
273
+ def awq_to_marlin_zero_points(
274
+ q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
275
+ ) -> torch.Tensor:
276
+ # AWQ zero-points are quantized and packed on the column dim.
277
+ # In addition, the values are permuted based on dequantizer.
278
+ # Here we undo both of these, and then apply marlin permutation
279
+ # and pack it back.
280
+ q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)
281
+
282
+ # Undo interleaving (use argsort(..) to get inverse perm)
283
+ if num_bits == 4:
284
+ undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7]))
285
+ elif num_bits == 8:
286
+ undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3]))
287
+ else:
288
+ raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
289
+
290
+ q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
291
+ q_zp = q_zp.reshape((-1, size_n)).contiguous()
292
+
293
+ marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
294
+ return marlin_zp
295
+
296
+
297
+ def moe_awq_to_marlin_zero_points(
298
+ q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
299
+ ):
300
+ num_experts = q_zp_packed.shape[0]
301
+ output = torch.empty(
302
+ (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]),
303
+ device=q_zp_packed.device,
304
+ dtype=q_zp_packed.dtype,
305
+ )
306
+ for e in range(num_experts):
307
+ output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits)
308
+ return output
309
+
310
+
311
+ def apply_gptq_marlin_linear(
312
+ input: torch.Tensor,
313
+ weight: torch.Tensor,
314
+ weight_scale: torch.Tensor,
315
+ weight_zp: torch.Tensor,
316
+ g_idx: torch.Tensor,
317
+ g_idx_sort_indices: torch.Tensor,
318
+ workspace: torch.Tensor,
319
+ wtype: ScalarType,
320
+ output_size_per_partition: int,
321
+ input_size_per_partition: int,
322
+ is_k_full: bool,
323
+ bias: Optional[torch.Tensor] = None,
324
+ use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
325
+ ) -> torch.Tensor:
326
+ reshaped_x = input.reshape(-1, input.shape[-1])
327
+ out_shape = input.shape[:-1] + (output_size_per_partition,)
328
+
329
+ output = ops.gptq_marlin_gemm(
330
+ reshaped_x,
331
+ weight,
332
+ weight_scale,
333
+ weight_zp,
334
+ g_idx,
335
+ g_idx_sort_indices,
336
+ workspace,
337
+ wtype,
338
+ size_m=reshaped_x.shape[0],
339
+ size_n=output_size_per_partition,
340
+ size_k=input_size_per_partition,
341
+ is_k_full=is_k_full,
342
+ has_zp=False,
343
+ use_fp32_reduce=use_fp32_reduce,
344
+ is_zp_float=False,
345
+ )
346
+
347
+ if bias is not None:
348
+ output.add_(bias) # In-place add
349
+
350
+ return output.reshape(out_shape)
351
+
352
+
353
+ def apply_awq_marlin_linear(
354
+ input: torch.Tensor,
355
+ weight: torch.Tensor,
356
+ weight_scale: torch.Tensor,
357
+ weight_zp: torch.Tensor,
358
+ g_idx: torch.Tensor,
359
+ g_idx_sort_indices: torch.Tensor,
360
+ workspace: torch.Tensor,
361
+ quant_type: ScalarType,
362
+ output_size_per_partition: int,
363
+ input_size_per_partition: int,
364
+ bias: Optional[torch.Tensor] = None,
365
+ use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
366
+ ) -> torch.Tensor:
367
+ reshaped_x = input.reshape(-1, input.shape[-1])
368
+ out_shape = input.shape[:-1] + (output_size_per_partition,)
369
+
370
+ output = ops.gptq_marlin_gemm(
371
+ reshaped_x,
372
+ weight,
373
+ weight_scale,
374
+ weight_zp,
375
+ g_idx,
376
+ g_idx_sort_indices,
377
+ workspace,
378
+ quant_type,
379
+ size_m=reshaped_x.shape[0],
380
+ size_n=output_size_per_partition,
381
+ size_k=input_size_per_partition,
382
+ is_k_full=True,
383
+ has_zp=True,
384
+ use_fp32_reduce=use_fp32_reduce,
385
+ is_zp_float=False,
386
+ )
387
+
388
+ if bias is not None:
389
+ output.add_(bias) # In-place add
390
+
391
+ return output.reshape(out_shape)
build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_fp8.py ADDED
@@ -0,0 +1,100 @@
+ from typing import Optional
+
+ import torch
+
+ import quantization as ops
+
+ from .marlin_utils import marlin_make_workspace, marlin_permute_scales
+
+
+ def is_fp8_marlin_supported():
+     capability = torch.cuda.get_device_capability()
+     capability = capability[0] * 10 + capability[1]
+     return capability >= 80
+
+
+ def apply_fp8_marlin_linear(
+     input: torch.Tensor,
+     weight: torch.Tensor,
+     weight_scale: torch.Tensor,
+     workspace: torch.Tensor,
+     size_n: int,
+     size_k: int,
+     bias: Optional[torch.Tensor],
+ ) -> torch.Tensor:
+     # For GPUs that lack FP8 hardware support, we can leverage the
+     # Marlin kernel for fast weight-only FP8 quantization
+
+     reshaped_x = input.reshape(-1, input.shape[-1])
+     out_shape = input.shape[:-1] + (size_n,)
+
+     output = ops.fp8_marlin_gemm(
+         a=reshaped_x,
+         b_q_weight=weight,
+         b_scales=weight_scale,
+         workspace=workspace,
+         num_bits=8,
+         size_m=reshaped_x.shape[0],
+         size_n=size_n,
+         size_k=size_k,
+     )
+
+     if bias is not None:
+         output.add_(bias)  # In-place add
+
+     return output.reshape(out_shape)
+
+
+ def prepare_fp8_layer_for_marlin(
+     layer: torch.nn.Module, strategy: str = "tensor"
+ ) -> None:
+     part_size_n = layer.output_size_per_partition
+     part_size_k = layer.input_size_per_partition
+
+     device = layer.weight.device
+
+     # WORKSPACE
+     layer.workspace = marlin_make_workspace(part_size_n, device)
+
+     # WEIGHT
+     # Repack weights to marlin format
+     marlin_qweight = ops.gptq_marlin_repack(
+         b_q_weight=pack_fp8_to_int32(layer.weight),
+         perm=torch.empty(0, dtype=torch.int, device=device),
+         size_k=part_size_k,
+         size_n=part_size_n,
+         num_bits=8,
+     )
+     layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
+
+     # WEIGHT SCALES
+     scales = layer.weight_scale.to(layer.orig_dtype)
+     # Permute scales
+     marlin_scales = marlin_permute_scales(
+         s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1
+     )
+     layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
+
+
+ def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
+     """
+     Repack FP8 weights to gptq format (packed int32 elements)
+     """
+     assert fp8_tensor.dtype == torch.float8_e4m3fn
+     assert fp8_tensor.shape[0] % 4 == 0
+
+     # Reshape to prepare for packing
+     reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
+
+     # Convert fp8 to uint8 (byte) representation
+     byte_tensor = reshaped.view(torch.uint8)
+
+     # Pack 4 uint8 values into one int32
+     packed = (
+         byte_tensor[:, 0].to(torch.int32)
+         | (byte_tensor[:, 1].to(torch.int32) << 8)
+         | (byte_tensor[:, 2].to(torch.int32) << 16)
+         | (byte_tensor[:, 3].to(torch.int32) << 24)
+     )
+
+     return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous()
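pack_fp8_to_int32 folds four FP8 bytes into one int32 along the first dimension, so a (K, N) FP8 weight becomes a (K // 4, N) int32 tensor in the GPTQ-style packed layout that is then handed to gptq_marlin_repack above. A quick illustrative check (shapes and the import path are assumptions):

import torch
from quantization.utils.marlin_utils_fp8 import pack_fp8_to_int32

w = torch.randn(128, 64).to(torch.float8_e4m3fn)  # (K, N) with K divisible by 4
packed = pack_fp8_to_int32(w)
assert packed.dtype == torch.int32
assert packed.shape == (32, 64)  # K // 4 rows of packed values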
build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test.py ADDED
@@ -0,0 +1,162 @@
1
+ """Utility functions used for tests and benchmarks"""
2
+
3
+ from typing import List, Optional
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+ from quantization.scalar_type import ScalarType
9
+
10
+ from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, marlin_zero_points
11
+ from .quant_utils import (
12
+ get_pack_factor,
13
+ gptq_quantize_weights,
14
+ quantize_weights,
15
+ sort_weights,
16
+ )
17
+
18
+
19
+ class MarlinWorkspace:
20
+
21
+ def __init__(self, out_features, min_thread_n, max_parallel):
22
+ assert (
23
+ out_features % min_thread_n == 0
24
+ ), "out_features = {} is undivisible by min_thread_n = {}".format(
25
+ out_features, min_thread_n
26
+ )
27
+
28
+ max_workspace_size = (out_features // min_thread_n) * max_parallel
29
+
30
+ self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda")
31
+
32
+
33
+ def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
34
+ assert q_w.shape == (size_k, size_n)
35
+ assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
36
+ assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}"
37
+
38
+ # Permute weights to 16x64 marlin tiles
39
+ q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
40
+ q_w = q_w.permute((0, 2, 1, 3))
41
+ q_w = q_w.reshape((size_k // tile, size_n * tile))
42
+
43
+ q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)
44
+
45
+ return q_w
46
+
47
+
48
+ def marlin_weights(q_w, size_k, size_n, num_bits, perm):
49
+ # Permute
50
+ q_w = marlin_permute_weights(q_w, size_k, size_n, perm)
51
+
52
+ # Pack
53
+ pack_factor = get_pack_factor(num_bits)
54
+ orig_device = q_w.device
55
+
56
+ q_w = q_w.cpu().numpy().astype(np.uint32)
57
+
58
+ q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32)
59
+ for i in range(pack_factor):
60
+ q_packed |= q_w[:, i::pack_factor] << num_bits * i
61
+
62
+ q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device)
63
+
64
+ return q_packed
65
+
66
+
67
+ def get_weight_perm(num_bits: int):
68
+ perm_list: List[int] = []
69
+ for i in range(32):
70
+ perm1: List[int] = []
71
+ col = i // 4
72
+ for block in [0, 1]:
73
+ for row in [
74
+ 2 * (i % 4),
75
+ 2 * (i % 4) + 1,
76
+ 2 * (i % 4 + 4),
77
+ 2 * (i % 4 + 4) + 1,
78
+ ]:
79
+ perm1.append(16 * row + col + 8 * block)
80
+ for j in range(4):
81
+ perm_list.extend([p + 256 * j for p in perm1])
82
+
83
+ perm = np.array(perm_list)
84
+
85
+ if num_bits == 4:
86
+ interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
87
+ elif num_bits == 8:
88
+ interleave = np.array([0, 2, 1, 3])
89
+ else:
90
+ raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
91
+
92
+ perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
93
+ perm = torch.from_numpy(perm)
94
+ return perm
95
+
96
+
97
+ def marlin_quantize(
98
+ w: torch.Tensor,
99
+ quant_type: ScalarType,
100
+ group_size: int,
101
+ act_order: bool,
102
+ test_perm: Optional[torch.Tensor] = None,
103
+ ):
104
+ size_k, size_n = w.shape
105
+ num_bits = quant_type.size_bits
106
+
107
+ # Normalize group_size
108
+ if group_size == -1:
109
+ group_size = size_k
110
+ assert group_size <= size_k
111
+
112
+ # Quantize (and apply act_order if provided)
113
+ w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
114
+ w, quant_type, group_size, act_order, test_perm
115
+ )
116
+
117
+ # For act_order, sort the "weights" and "g_idx" so that group ids are
118
+ # increasing
119
+ sort_indices = torch.empty(0, dtype=torch.int, device=w.device)
120
+ if act_order:
121
+ q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
122
+
123
+ # Reformat to marlin
124
+ weight_perm = get_weight_perm(num_bits)
125
+ marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
126
+ marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
127
+
128
+ # Create result
129
+ res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
130
+ for i in range(len(res_list)):
131
+ res_list[i] = res_list[i].to(w.device)
132
+
133
+ return res_list
134
+
135
+
136
+ def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int):
137
+ size_k, size_n = w.shape
138
+
139
+ # Normalize group_size
140
+ if group_size == -1:
141
+ group_size = size_k
142
+ assert group_size <= size_k
143
+
144
+ # Detect num groups
145
+ assert size_k % group_size == 0
146
+ num_groups = size_k // group_size
147
+
148
+ # Quantize with zp
149
+ w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True)
150
+
151
+ # Reformat to marlin
152
+ weight_perm = get_weight_perm(quant_type.size_bits)
153
+ marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm)
154
+ marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
155
+ marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits)
156
+
157
+ # Create result
158
+ res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp]
159
+ for i in range(len(res_list)):
160
+ res_list[i] = res_list[i].to(w.device)
161
+
162
+ return res_list
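For reference, a hedged usage sketch of marlin_quantize: it assumes the built package is importable as `quantization` (as the imports above suggest) and uses hypothetical shapes that satisfy the 16x16 tile and group-size constraints. It is not part of the build itself.

import torch

from quantization.scalar_type import scalar_types
from quantization.utils.marlin_utils_test import marlin_quantize

# Hypothetical weight; both dims must be multiples of the 16x16 Marlin tile.
w = torch.randn(128, 256, dtype=torch.half)

w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm = marlin_quantize(
    w, scalar_types.uint4b8, group_size=64, act_order=False
)
print(marlin_q_w.shape, marlin_s.shape)  # packed weight and permuted scales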
build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test_24.py ADDED
@@ -0,0 +1,473 @@
1
+ """Utility functions used for tests and benchmarks"""
2
+
3
+ import random
4
+ from typing import List
5
+
6
+ import numpy
7
+ import torch
8
+
9
+ from quantization.scalar_type import ScalarType
10
+
11
+ from .marlin_utils_test import marlin_weights
12
+ from .quant_utils import gptq_quantize_weights
13
+
14
+
15
+ # This is a PyTorch implementation of the main part of the reorder_meta()
16
+ # function, from tools/util/include/cutlass/util/host_reorder.h file
17
+ # of CUTLASS source tree. Furthermore, CUTLASS template for sparse
18
+ # GEMM decides upon layout of this matrix, and at the moment for the
19
+ # sparse GEMM executed on tensor cores, this is layout described by
20
+ # ColumnMajorInterleaved<2> data structure, in
21
+ # include/cutlass/layout/matrix.h of CUTLASS source tree. The
22
+ # reordering of meta matrix into meta_reordered matrix calculated
23
+ # according to these segments of CUTLASS code is re-implemented here.
24
+ # Note that this calculation produces offsets for scattering metadata
25
+ # matrix elements into reordered metadata matrix elements (or,
26
+ # equivalently, for gathering reordered metadata matrix element back
27
+ # into metadata matrix elements).
28
+ def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device):
29
+ dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)
30
+ dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)
31
+
32
+ # Reorder the rows, then swizzle the 2x2 blocks.
33
+ group_x = 64
34
+ group_y = 32 if meta_dtype.itemsize == 2 else 16
35
+
36
+ dst_rows = (
37
+ dst_rows // group_x * group_x
38
+ + (dst_rows % 2) * 2
39
+ + (dst_rows % 8) // 4
40
+ + ((dst_rows % group_y) % 4) // 2 * 32
41
+ + ((dst_rows % group_x) // 8) * 4
42
+ )
43
+
44
+ topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8)
45
+ bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8)
46
+ dst_rows += topright - bottomleft
47
+ dst_cols -= topright - bottomleft
48
+
49
+ # It is assumed that the meta tensor is stored in the CUTLASS
50
+ # InterleavedColumnMajor layout; the corresponding code that stores
51
+ # values into this tensor was reverse engineered.
52
+ interleave = 2
53
+ cols_maj = dst_cols // interleave
54
+ cols_min = dst_cols % interleave
55
+ return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1)
56
+
57
+
58
+ # This function converts dense matrix into sparse semi-structured
59
+ # representation, producing "compressed" matrix, in the layout used by
60
+ # CUTLASS backend, and corresponding metadata matrix.
61
+ def sparse_semi_structured_from_dense_cutlass(dense):
62
+ if dense.dim() != 2:
63
+ raise RuntimeError(
64
+ f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501
65
+ )
66
+
67
+ m, k = dense.shape
68
+ device = dense.device
69
+
70
+ meta_dtype = torch.int8
71
+ if dense.dtype == torch.int8:
72
+ meta_dtype = torch.int32
73
+ elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]:
74
+ meta_dtype = torch.int16
75
+ else:
76
+ raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
77
+ quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
78
+ if quadbits_per_meta_elem not in (4, 8):
79
+ raise RuntimeError("Invalid number of elements per meta element calculated")
80
+
81
+ if meta_dtype == torch.int32:
82
+ if m % 16 != 0:
83
+ raise RuntimeError(
84
+ f"Number of rows of dense matrix {m} must be divisible by 16"
85
+ )
86
+ else:
87
+ if m % 32 != 0:
88
+ raise RuntimeError(
89
+ f"Number of rows of dense matrix {m} must be divisible by 32"
90
+ )
91
+ if k % (4 * quadbits_per_meta_elem) != 0:
92
+ raise RuntimeError(
93
+ f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501
94
+ )
95
+
96
+ if dense.dtype != torch.float:
97
+ ksparse = 4
98
+ dense_4 = dense.view(-1, k // ksparse, ksparse)
99
+ m0, m1, m2, m3 = (dense_4 != 0).unbind(-1)
100
+ else:
101
+ ksparse = 2
102
+ dense_2 = dense.view(-1, k // ksparse, ksparse)
103
+ m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1)
104
+ meta_ncols = k // (ksparse * quadbits_per_meta_elem)
105
+
106
+ # Encoding quadruples of True/False values as follows:
107
+ # [True, True, False, False] -> 0b0100
108
+ # [True, False, True, False] -> 0b1000
109
+ # [False, True, True, False] -> 0b1001
110
+ # [True, False, False, True ] -> 0b1100
111
+ # [False, True, False, True ] -> 0b1101
112
+ # [False, False, True, True ] -> 0b1110
113
+ # Thus, lower two bits in the encoding are index of the True value
114
+ # at the lowest index in the quadruple, and the higher two bits in
115
+ # the encoding are index of the other True value in the quadruple.
116
+ # In case there are fewer than two True values, then the False value or
117
+ # values at some index or indices are considered True for the
118
+ # encoding. In case there are more than two True values, then the
119
+ # excess True value(s) at some indices are considered False for
120
+ # the encoding. The exact encodings used for these cases are as
121
+ # follows:
122
+ # [False, False, False, False] -> 0b1110
123
+ # [False, False, False, True ] -> 0b1110
124
+ # [False, False, True, False] -> 0b1110
125
+ # [False, True, False, False] -> 0b1001
126
+ # [False, True, True, True ] -> 0b1101
127
+ # [True, False, False, False] -> 0b1000
128
+ # [True, False, True, True ] -> 0b1100
129
+ # [True, True, False, True ] -> 0b0100
130
+ # [True, True, True, False] -> 0b0100
131
+ # [True, True, True, True ] -> 0b0100
132
+ # These particular encodings are chosen, with the help of Espresso
133
+ # logic minimizer software, for the purpose of minimization of
134
+ # corresponding Boolean functions, that translate non-zero flags
135
+ # into encoding bits. Note also that the possible choices for the first
136
+ # and last of these encodings were limited only to (0b0100,
137
+ # 0b1110), in order to produce valid encodings for 1:2 sparsity
138
+ # case.
139
+
140
+ expr0 = m0 & m1
141
+ expr1 = ~m0 & m1
142
+ expr2 = ~m0 & ~m1
143
+ bit0 = expr1
144
+ bit1 = expr2
145
+ bit2 = expr0 | expr2 | m3
146
+ bit3 = expr1 | ~m1
147
+ idxs0 = bit0 | (bit1.to(torch.int64) << 1)
148
+ idxs1 = bit2 | (bit3.to(torch.int64) << 1)
149
+
150
+ if dense.dtype != torch.float:
151
+ sparse0 = dense_4.gather(
152
+ -1, idxs0.unsqueeze(-1)
153
+ ) # type: ignore[possibly-undefined]
154
+ sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
155
+ sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
156
+ else:
157
+ sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(
158
+ m, k // 2
159
+ ) # type: ignore[possibly-undefined]
160
+
161
+ meta_4 = idxs0 | (idxs1 << 2)
162
+ meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)
163
+
164
+ if quadbits_per_meta_elem == 4:
165
+ meta = (
166
+ meta_n[:, :, 0]
167
+ | (meta_n[:, :, 1] << 4)
168
+ | (meta_n[:, :, 2] << 8)
169
+ | (meta_n[:, :, 3] << 12)
170
+ )
171
+ elif quadbits_per_meta_elem == 8:
172
+ meta = (
173
+ meta_n[:, :, 0]
174
+ | (meta_n[:, :, 1] << 4)
175
+ | (meta_n[:, :, 2] << 8)
176
+ | (meta_n[:, :, 3] << 12)
177
+ | (meta_n[:, :, 4] << 16)
178
+ | (meta_n[:, :, 5] << 20)
179
+ | (meta_n[:, :, 6] << 24)
180
+ | (meta_n[:, :, 7] << 28)
181
+ )
182
+
183
+ # Reorder meta tensor elements.
184
+ meta_reordered = meta.new_empty(
185
+ (m * meta_ncols,)
186
+ ) # type: ignore[possibly-undefined]
187
+ meta_offsets = _calculate_meta_reordering_scatter_offsets(
188
+ m, meta_ncols, meta_dtype, device
189
+ )
190
+ meta_reordered.scatter_(0, meta_offsets, meta.view(-1))
191
+
192
+ return (sparse, meta_reordered.view(m, meta_ncols))
193
+
194
+
195
+ # This function performs reverse of the function above - it
196
+ # reconstructs dense matrix from a pair of "compressed" matrix, given
197
+ # in the layout used by CUTLASS backend, and accompanying metadata
198
+ # matrix.
199
+ def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
200
+ if sparse.dim() != 2:
201
+ raise RuntimeError(
202
+ f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501
203
+ )
204
+
205
+ m, k = sparse.shape
206
+ device = sparse.device
207
+
208
+ if meta_reordered.dim() != 2:
209
+ raise RuntimeError(
210
+ f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501
211
+ )
212
+ if meta_reordered.device != device:
213
+ raise RuntimeError(
214
+ f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501
215
+ )
216
+
217
+ meta_dtype = meta_reordered.dtype
218
+ if meta_dtype not in (torch.int16, torch.int32):
219
+ raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix")
220
+ quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
221
+
222
+ ksparse = 4 if sparse.dtype != torch.float else 2
223
+
224
+ meta_nrows, meta_ncols = meta_reordered.shape
225
+ if meta_nrows != m:
226
+ raise RuntimeError(
227
+ f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501
228
+ )
229
+ if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k:
230
+ raise RuntimeError(
231
+ f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501
232
+ "expected according to the number of columns of meta matrix"
233
+ )
234
+
235
+ # Undo meta tensor elements reordering.
236
+ meta_offsets = _calculate_meta_reordering_scatter_offsets(
237
+ m, meta_ncols, meta_dtype, device
238
+ )
239
+ meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols)
240
+
241
+ # Unpack sparse tensor back to original dense tensor, using
242
+ # information provided by meta tensor. Note that torch.float
243
+ # datatype is handled pretty much the same as
244
+ # torch.half/torch.bfloat16, as metadata for a pair of torch.float
245
+ # value is encoded as if underlying 8 bytes contain four
246
+ # torch.half/torch.bfloat16 values, where either first two or last
247
+ # two are zeros.
248
+ meta_2 = torch.empty(
249
+ (m, meta_ncols, 2 * quadbits_per_meta_elem),
250
+ dtype=meta_dtype,
251
+ device=device,
252
+ )
253
+ if quadbits_per_meta_elem == 4:
254
+ meta_2[:, :, 0] = meta & 0b11
255
+ meta_2[:, :, 1] = (meta >> 2) & 0b11
256
+ meta_2[:, :, 2] = (meta >> 4) & 0b11
257
+ meta_2[:, :, 3] = (meta >> 6) & 0b11
258
+ meta_2[:, :, 4] = (meta >> 8) & 0b11
259
+ meta_2[:, :, 5] = (meta >> 10) & 0b11
260
+ meta_2[:, :, 6] = (meta >> 12) & 0b11
261
+ meta_2[:, :, 7] = (meta >> 14) & 0b11
262
+ elif quadbits_per_meta_elem == 8:
263
+ meta_2[:, :, 0] = meta & 0b11
264
+ meta_2[:, :, 1] = (meta >> 2) & 0b11
265
+ meta_2[:, :, 2] = (meta >> 4) & 0b11
266
+ meta_2[:, :, 3] = (meta >> 6) & 0b11
267
+ meta_2[:, :, 4] = (meta >> 8) & 0b11
268
+ meta_2[:, :, 5] = (meta >> 10) & 0b11
269
+ meta_2[:, :, 6] = (meta >> 12) & 0b11
270
+ meta_2[:, :, 7] = (meta >> 14) & 0b11
271
+ meta_2[:, :, 8] = (meta >> 16) & 0b11
272
+ meta_2[:, :, 9] = (meta >> 18) & 0b11
273
+ meta_2[:, :, 10] = (meta >> 20) & 0b11
274
+ meta_2[:, :, 11] = (meta >> 22) & 0b11
275
+ meta_2[:, :, 12] = (meta >> 24) & 0b11
276
+ meta_2[:, :, 13] = (meta >> 26) & 0b11
277
+ meta_2[:, :, 14] = (meta >> 28) & 0b11
278
+ meta_2[:, :, 15] = (meta >> 30) & 0b11
279
+
280
+ dense_offsets = meta_2.view(-1) + (
281
+ torch.arange(0, 2 * m * k // ksparse, device=device) * 4
282
+ ).view(-1, 1).repeat(1, 2).view(-1)
283
+
284
+ dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device)
285
+ if sparse.dtype != torch.float:
286
+ # dense.scatter_(0, dense_offsets, sparse.view(-1))
287
+ dense.scatter_(0, dense_offsets, sparse.reshape(-1))
288
+ else:
289
+ dense.view(torch.half).scatter_(
290
+ 0, dense_offsets, sparse.view(torch.half).view(-1)
291
+ )
292
+
293
+ return dense.view(m, 2 * k)
294
+
295
+
296
+ def mask_creator(tensor):
297
+ """
298
+ Create an N:M sparsity mask for the given tensor.
299
+ Masks will be created using the N:M ratio, where for every block of
300
+ M weights, the N with the largest magnitude are kept and the rest are pruned. Each mask
301
+ will correspond to the given tensor.
302
+
303
+ :param N: The number of weights in a group to keep
304
+ :param M: The size of a weight group
305
+ """
306
+ N = 2
307
+ M = 4
308
+
309
+ mask = None
310
+ # for i, tensor in enumerate(tensors):
311
+ if tensor.numel() % M != 0:
312
+ raise ValueError(
313
+ f"Tensor of size {tensor.shape} can't be evenly divided into " f"{M} groups"
314
+ )
315
+
316
+ num_groups = tensor.numel() // M
317
+
318
+ # N:M sparsity for linear layers
319
+ tensor_temp = tensor.detach().abs().reshape(num_groups, M)
320
+ index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)]
321
+
322
+ w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device)
323
+ mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape)
324
+
325
+ return mask
326
+
327
+
328
+ def inject_24(w, size_k, size_n):
329
+ assert w.shape == (size_k, size_n)
330
+
331
+ mask = mask_creator(w.t()).t().cuda().bool()
332
+
333
+ return (mask * w).contiguous(), mask.contiguous()
334
+
335
+
336
+ def check_24(w, num_rows_to_sample=50, _verbose=False):
337
+ BLOCK_SIZE = 4
338
+ MAX_NON_ZEROS = 2
339
+
340
+ w = w.t().contiguous()
341
+
342
+ print("check_24: w.shape = {}".format(w.shape))
343
+
344
+ num_rows, num_cols = w.shape
345
+ sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample)
346
+ if _verbose:
347
+ print(f"Sampled row idxs = {sampled_row_idxs}")
348
+
349
+ total_segments = 0
350
+ non_24_segments = 0
351
+ for i in sampled_row_idxs:
352
+ for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE):
353
+ total_segments += 1
354
+ block = w[i, j : j + BLOCK_SIZE]
355
+ num_nonzero = torch.count_nonzero(block)
356
+ if num_nonzero > MAX_NON_ZEROS:
357
+ print("i = {} j = {} block = {}".format(i, j, block))
358
+ non_24_segments += 1
359
+
360
+ print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.")
361
+
362
+
363
+ def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType):
364
+ assert q_24.shape == (size_k, size_n)
365
+
366
+ # Remove bias to normalize over 0
367
+ q_24_no_zp = q_24 - wtype.bias
368
+
369
+ # Compress
370
+ q_24_no_zp = q_24_no_zp.t().contiguous()
371
+ q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(q_24_no_zp)
372
+ q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous()
373
+
374
+ # Restore bias
375
+ q_24_comp = q_24_no_zp_comp + wtype.bias
376
+
377
+ # Resize meta to its actual shape (without moving any data)
378
+ meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)
379
+
380
+ return q_24_comp, meta
381
+
382
+
383
+ def get_scale_perms_24():
384
+ scale_perm: List[int] = []
385
+ for i in range(8):
386
+ scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]])
387
+ scale_perm_single: List[int] = []
388
+ for i in range(8):
389
+ scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]])
390
+ return scale_perm, scale_perm_single
391
+
392
+
393
+ def get_weight_perm_24(num_bits: int):
394
+ perm_list: List[int] = []
395
+ for i in range(32):
396
+ perm1: List[int] = []
397
+ col = i // 4
398
+ col_o = col // 2
399
+ for block in [0, 1]:
400
+ for row in [
401
+ 2 * (i % 4),
402
+ 2 * (i % 4) + 1,
403
+ 2 * (i % 4 + 4),
404
+ 2 * (i % 4 + 4) + 1,
405
+ ]:
406
+ perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block)
407
+ for j in range(4):
408
+ perm_list.extend([p + 1 * j for p in perm1])
409
+ perm = numpy.array(perm_list)
410
+
411
+ if num_bits == 4:
412
+ interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
413
+ elif num_bits == 8:
414
+ interleave = numpy.array([0, 2, 1, 3])
415
+ else:
416
+ raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))
417
+
418
+ perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
419
+ perm = torch.from_numpy(perm)
420
+ return perm
421
+
422
+
423
+ def marlin_permute_scales_24(
424
+ s: torch.Tensor, size_k: int, size_n: int, group_size: int
425
+ ) -> torch.Tensor:
426
+
427
+ scale_perm, scale_perm_single = get_scale_perms_24()
428
+ if group_size < size_k and group_size != -1:
429
+ s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
430
+ else:
431
+ s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
432
+ s = s.reshape((-1, size_n)).contiguous()
433
+
434
+ return s
435
+
436
+
437
+ def marlin_24_quantize(
438
+ w: torch.Tensor,
439
+ quant_type: ScalarType,
440
+ group_size: int,
441
+ ):
442
+ size_k, size_n = w.shape
443
+
444
+ # Normalize group_size
445
+ if group_size == -1:
446
+ group_size = size_k
447
+ assert group_size <= size_k
448
+
449
+ # Inject 2:4 sparsity
450
+ w_24, mask_24 = inject_24(w, size_k, size_n)
451
+
452
+ # Quantize
453
+ w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights(
454
+ w_24, quant_type, group_size, act_order=False
455
+ )
456
+
457
+ # Compress quantized weight
458
+ q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, quant_type)
459
+ size_k_comp = size_k // 2
460
+
461
+ # Reformat to marlin
462
+ weight_perm = get_weight_perm_24(quant_type.size_bits)
463
+ marlin_24_q_w_comp = marlin_weights(
464
+ q_w_24_comp, size_k_comp, size_n, quant_type.size_bits, weight_perm
465
+ )
466
+ marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size)
467
+
468
+ # Create result
469
+ res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s]
470
+ for i in range(len(res_list)):
471
+ res_list[i] = res_list[i].to(w.device)
472
+
473
+ return res_list
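As an illustration of the 2:4 structure these helpers enforce, here is a small CPU-only sketch using mask_creator (inject_24 itself moves the mask to CUDA, so it is skipped here). The import path assumes the package layout shown above.

import torch

from quantization.utils.marlin_utils_test_24 import mask_creator

# Toy weight; the element count must be divisible by the group size M = 4.
w = torch.randn(8, 16)
mask = mask_creator(w)

# Every group of four consecutive values keeps exactly two nonzero entries.
print(mask.reshape(-1, 4).sum(dim=1).unique())  # tensor([2.])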
build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py ADDED
@@ -0,0 +1,125 @@
1
+ from typing import List
2
+
3
+ import numpy
4
+ import torch
5
+
6
+ from .marlin_utils_test import marlin_permute_weights
7
+ from .quant_utils import get_pack_factor, qqq_quantize_weights
8
+
9
+
10
+ def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size):
11
+ # Permute
12
+ q_w = marlin_permute_weights(q_w, size_k, size_n, perm)
13
+
14
+ # Pack
15
+ pack_factor = get_pack_factor(num_bits)
16
+ orig_device = q_w.device
17
+
18
+ q_w = q_w.cpu().numpy().astype(numpy.uint32)
19
+
20
+ q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor),
21
+ dtype=numpy.uint32)
22
+ if group_size == size_k:
23
+ for i in range(pack_factor):
24
+ q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i
25
+ else:
26
+ for i in range(pack_factor):
27
+ q_packed |= q_w[:, i::pack_factor] << num_bits * i
28
+
29
+ q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device)
30
+
31
+ return q_packed
32
+
33
+
34
+ def get_qqq_scale_perms():
35
+ scale_perm: List[int] = []
36
+ for i in range(8):
37
+ scale_perm.extend([i + 8 * j for j in range(8)])
38
+ scale_perm_single: List[int] = []
39
+ for i in range(4):
40
+ scale_perm_single.extend(
41
+ [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
42
+ return scale_perm, scale_perm_single
43
+
44
+
45
+ # NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. # noqa: E501
46
+ def get_qqq_weight_perm(num_bits: int, quant_type: str):
47
+ perm_list: List[int] = []
48
+ for i in range(32):
49
+ perm1: List[int] = []
50
+ col = i // 4
51
+ for block in [0, 1]:
52
+ for row in [
53
+ 4 * (i % 4),
54
+ 4 * (i % 4) + 1,
55
+ 4 * (i % 4) + 2,
56
+ 4 * (i % 4) + 3,
57
+ ]:
58
+ perm1.append(16 * row + col + 8 * block)
59
+ for j in range(4):
60
+ perm_list.extend([p + 256 * j for p in perm1])
61
+
62
+ perm = numpy.array(perm_list)
63
+
64
+ assert quant_type in ["per-channel",
65
+ "per-group"], "not supported quantization type"
66
+ if num_bits == 4:
67
+ if quant_type == "per-channel":
68
+ interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3])
69
+ else:
70
+ interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
71
+ else:
72
+ raise Exception("num_bits must be 4, got {}".format(num_bits))
73
+
74
+ perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
75
+ perm = torch.from_numpy(perm)
76
+ return perm
77
+
78
+
79
+ def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size):
80
+ scale_perm, scale_perm_single = get_qqq_scale_perms()
81
+ if group_size < size_k and group_size != -1:
82
+ s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm]
83
+ s_channel = s_channel.reshape(
84
+ (-1, len(scale_perm_single)))[:, scale_perm_single]
85
+ s_group = s_group.reshape((-1, size_n)).contiguous()
86
+ else:
87
+ s_channel = s_channel.reshape(
88
+ (-1, len(scale_perm_single)))[:, scale_perm_single]
89
+ s_channel = s_channel.reshape((-1, size_n)).contiguous()
90
+
91
+ return s_group, s_channel
92
+
93
+
94
+ def marlin_qqq_quantize(
95
+ w: torch.Tensor,
96
+ num_bits: int,
97
+ group_size: int,
98
+ ):
99
+ size_k, size_n = w.shape
100
+
101
+ # Normalize group_size
102
+ if group_size == -1:
103
+ group_size = size_k
104
+ assert group_size <= size_k
105
+ quant_type = "per-channel" if group_size == size_k else "per-group"
106
+
107
+ # Quantize
108
+ w_ref, q_w, s_group, s_channel = qqq_quantize_weights(
109
+ w, num_bits, group_size)
110
+
111
+ # Reformat to marlin_qqq
112
+ weight_perm = get_qqq_weight_perm(num_bits, quant_type)
113
+ marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits,
114
+ weight_perm, group_size)
115
+ marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales(
116
+ s_group, s_channel, size_k, size_n, group_size)
117
+
118
+ # Create result
119
+ res_list = [
120
+ w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel
121
+ ]
122
+ for i in range(len(res_list)):
123
+ res_list[i] = res_list[i].to(w.device)
124
+
125
+ return res_list
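A hedged end-to-end sketch of the QQQ path: per-channel 4-bit quantization of a small random weight, assuming the package is importable as `quantization`. Only num_bits=4 is supported by these helpers.

import torch

from quantization.utils.marlin_utils_test_qqq import marlin_qqq_quantize

# group_size=-1 selects the per-channel scheme.
w = torch.randn(128, 256, dtype=torch.half)
w_ref, q_w, s_group, s_channel = marlin_qqq_quantize(w, num_bits=4, group_size=-1)

print(q_w.shape, s_channel.shape)  # packed weight and per-channel fp32 scales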
build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/quant_utils.py ADDED
@@ -0,0 +1,470 @@
1
+ """This file is used for /tests and /benchmarks"""
2
+
3
+ from typing import List, Optional
4
+
5
+ import numpy
6
+ import torch
7
+
8
+ from quantization.scalar_type import ScalarType, scalar_types
9
+
10
+ SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
11
+ SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
12
+
13
+ MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]
14
+
15
+ # Note: this is a hack. We should update each model to register the
16
+ # stacked params and get it from there instead in a future PR.
17
+ # fused_name: List[shard_name]
18
+ FUSED_LAYER_NAME_MAPPING = {
19
+ "qkv_proj": ["q_proj", "k_proj", "v_proj"],
20
+ "gate_up_proj": ["gate_proj", "up_proj"],
21
+ }
22
+
23
+
24
+ def pack_quantized_values_into_int32(
25
+ w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
26
+ ):
27
+ # move dim to pack to the end
28
+ perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
29
+ inv_perm = tuple(perm.index(i) for i in range(len(perm)))
30
+ w_q_perm = w_q.permute(perm)
31
+
32
+ pack_factor = 32 // wtype.size_bits
33
+ mask = (1 << wtype.size_bits) - 1
34
+
35
+ new_shape_perm = list(w_q_perm.shape)
36
+ assert w_q_perm.shape[-1] % pack_factor == 0
37
+ new_shape_perm[-1] //= pack_factor
38
+
39
+ res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
40
+ for i in range(pack_factor):
41
+ res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i
42
+
43
+ return res.permute(inv_perm)
44
+
45
+
46
+ def unpack_quantized_values_into_int32(
47
+ w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
48
+ ):
49
+ # move dim to pack to the end
50
+ perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
51
+ inv_perm = tuple(perm.index(i) for i in range(len(perm)))
52
+ w_q_perm = w_q.permute(perm)
53
+
54
+ pack_factor = 32 // wtype.size_bits
55
+ mask = (1 << wtype.size_bits) - 1
56
+
57
+ new_shape_perm = list(w_q_perm.shape)
58
+ new_shape_perm[-1] *= pack_factor
59
+
60
+ res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
61
+ for i in range(pack_factor):
62
+ res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask
63
+
64
+ return res.permute(inv_perm)
65
+
66
+
67
+ def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
68
+ # prefix: model.layers.0.self_attn.q_proj
69
+ # proj_name: q_proj
70
+ proj_name = prefix.split(".")[-1]
71
+ if proj_name in FUSED_LAYER_NAME_MAPPING:
72
+ shard_prefixes = [
73
+ prefix.replace(proj_name, shard_proj_name)
74
+ for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name]
75
+ ]
76
+
77
+ is_skipped = None
78
+ for shard_prefix in shard_prefixes:
79
+ is_shard_skipped = shard_prefix in ignored_layers
80
+
81
+ if is_skipped is None:
82
+ is_skipped = is_shard_skipped
83
+ elif is_shard_skipped != is_skipped:
84
+ raise ValueError(
85
+ f"Detected some but not all shards of {prefix} "
86
+ "are quantized. All shards of fused layers "
87
+ "to have the same precision."
88
+ )
89
+ else:
90
+ is_skipped = prefix in ignored_layers
91
+
92
+ assert is_skipped is not None
93
+ return is_skipped
94
+
95
+
96
+ def get_pack_factor(num_bits):
97
+ assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
98
+ return 32 // num_bits
99
+
100
+
101
+ def permute_rows(
102
+ q_w: torch.Tensor,
103
+ w_ref: torch.Tensor,
104
+ group_size: int,
105
+ test_perm: Optional[torch.Tensor] = None,
106
+ ):
107
+ assert q_w.shape == w_ref.shape
108
+
109
+ orig_device = q_w.device
110
+ k_size, _ = q_w.shape
111
+
112
+ g_idx = torch.zeros((k_size,), dtype=torch.int32)
113
+ for i in range(k_size):
114
+ g_idx[i] = i // group_size
115
+
116
+ # Simulate act_order by doing a random permutation on K
117
+ rand_perm = test_perm if test_perm is not None else torch.randperm(k_size)
118
+
119
+ g_idx = g_idx[rand_perm].contiguous()
120
+ q_w = q_w[rand_perm, :].contiguous()
121
+ w_ref = w_ref[rand_perm, :].contiguous()
122
+
123
+ return (
124
+ w_ref.to(device=orig_device),
125
+ q_w.to(device=orig_device),
126
+ g_idx.to(device=orig_device),
127
+ rand_perm.to(device=orig_device),
128
+ )
129
+
130
+
131
+ def quantize_weights(
132
+ w: torch.Tensor,
133
+ quant_type: ScalarType,
134
+ group_size: Optional[int],
135
+ zero_points: bool = False,
136
+ ref_zero_points_after_scales: bool = False,
137
+ ):
138
+ assert (
139
+ quant_type.is_integer()
140
+ ), "Floating point quantization may work but has not been tested"
141
+ assert not zero_points or group_size is not None, (
142
+ "to have group zero points, group_size must be provided "
143
+ "(-1 group_size is channelwise)"
144
+ )
145
+
146
+ orig_device = w.device
147
+ orig_type = w.dtype
148
+ size_k, size_n = w.shape
149
+
150
+ assert w.is_floating_point(), "w must be float"
151
+
152
+ if group_size == -1:
153
+ group_size = size_k
154
+
155
+ # Reshape to [groupsize, -1]
156
+ if group_size is not None and group_size < size_k:
157
+ w = w.reshape((-1, group_size, size_n))
158
+ w = w.permute(1, 0, 2)
159
+ w = w.reshape((group_size, -1))
160
+
161
+ # Compute scale for each group
162
+ max_val = torch.max(w, 0, keepdim=True).values
163
+ min_val = torch.min(w, 0, keepdim=True).values
164
+
165
+ max_q_val = quant_type.max()
166
+ min_q_val = quant_type.min()
167
+
168
+ w_s = torch.Tensor([1.0]).to(w.device) # unscaled case
169
+ maybe_w_zp = None
170
+ if group_size is not None:
171
+ if zero_points:
172
+ assert not quant_type.is_signed() and quant_type.max() > 0
173
+ w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
174
+ maybe_w_zp = (
175
+ torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
176
+ )
177
+ else:
178
+ # If the bias is such that there are no possible negative/positive
179
+ # values, set the max value to inf to avoid divide by 0
180
+ w_s = torch.max(
181
+ abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
182
+ abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
183
+ )
184
+
185
+ # Quantize
186
+ w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
187
+ w_q = torch.clamp(w_q, min_q_val, max_q_val)
188
+
189
+ # Compute ref (dequantized)
190
+ # For some kernels (namely Machete) the zero-points are applied after the
191
+ # scales are applied, for this case computing the reference in similar way
192
+ # allows us to use tighter error tolerances in our unit tests.
193
+ if ref_zero_points_after_scales and maybe_w_zp is not None:
194
+ w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
195
+ else:
196
+ w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
197
+
198
+ if quant_type.has_bias():
199
+ w_q += quant_type.bias
200
+
201
+ # Restore original shapes
202
+ if group_size is not None and group_size < size_k:
203
+
204
+ def reshape_w(w):
205
+ w = w.reshape((group_size, -1, size_n))
206
+ w = w.permute(1, 0, 2)
207
+ w = w.reshape((size_k, size_n)).contiguous()
208
+ return w
209
+
210
+ w_q = reshape_w(w_q)
211
+ w_ref = reshape_w(w_ref)
212
+ w_s = w_s.reshape((-1, size_n)).contiguous()
213
+
214
+ if maybe_w_zp is not None:
215
+ maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
216
+ maybe_w_zp = maybe_w_zp.to(device=orig_device)
217
+
218
+ return (
219
+ w_ref.to(device=orig_device),
220
+ w_q.to(device=orig_device),
221
+ w_s if group_size is not None else None,
222
+ maybe_w_zp,
223
+ )
224
+
225
+
226
+ def gptq_quantize_weights(
227
+ w: torch.Tensor,
228
+ quant_type: ScalarType,
229
+ group_size: int,
230
+ act_order: bool,
231
+ test_perm: Optional[torch.Tensor] = None,
232
+ ):
233
+ size_k, _ = w.shape
234
+
235
+ assert w.is_floating_point(), "w must be float"
236
+ assert (
237
+ quant_type in SUPPORTED_GPTQ_QUANT_TYPES
238
+ ), f"Unsupported gptq type = {quant_type}"
239
+ assert group_size in SUPPORTED_GROUP_SIZES + [
240
+ size_k
241
+ ], f"Unsupported groupsize = {group_size}"
242
+
243
+ w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size)
244
+
245
+ # Apply act_order
246
+ g_idx = torch.empty(0, dtype=torch.int, device=w.device)
247
+ rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
248
+ if act_order:
249
+ assert (
250
+ group_size < size_k
251
+ ), "For act_order, groupsize = {} must be less than size_k = {}".format(
252
+ group_size, size_k
253
+ )
254
+
255
+ w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm)
256
+
257
+ return w_ref, w_q, w_s, g_idx, rand_perm
258
+
259
+
260
+ # QQQ employs different quant schemes for per-group and
261
+ # per-channel quantization.
262
+ def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int):
263
+ orig_device = w.device
264
+ size_k, size_n = w.shape
265
+
266
+ assert w.is_floating_point(), "w must be float"
267
+ assert (
268
+ num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS
269
+ ), f"Unsupported num_bits = {num_bits}"
270
+ assert group_size in SUPPORTED_GROUP_SIZES + [
271
+ size_k
272
+ ], f"Unsupported groupsize = {group_size}"
273
+
274
+ if group_size == -1:
275
+ group_size = size_k
276
+ assert group_size <= size_k
277
+
278
+ if group_size < size_k:
279
+ # Reshape to [groupsize, -1]
280
+ w = w.reshape((-1, group_size, size_n))
281
+ w = w.permute(1, 0, 2)
282
+ w = w.reshape((group_size, -1))
283
+
284
+ max_q_val = 2**num_bits - 1
285
+ half_q_val = (max_q_val + 1) // 2
286
+
287
+ # Compute scale for each group
288
+ s_group = torch.max(torch.abs(w), 0, keepdim=True)[0]
289
+ s_group *= 2 / max_q_val # 2 => symmetric
290
+
291
+ # Quantize
292
+ q_w = torch.round(w / s_group).int()
293
+ q_w += half_q_val
294
+ q_w = torch.clamp(q_w, 0, max_q_val)
295
+ # Compute ref (dequantized)
296
+ w_ref = (q_w - half_q_val).half() * s_group
297
+
298
+ # Restore original shapes
299
+ def reshape_w(w):
300
+ w = w.reshape((group_size, -1, size_n))
301
+ w = w.permute(1, 0, 2)
302
+ w = w.reshape((size_k, size_n)).contiguous()
303
+ return w
304
+
305
+ q_w = reshape_w(q_w)
306
+ w_ref = reshape_w(w_ref)
307
+
308
+ # Compute int8 quantization scale for each channel
309
+ s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0]
310
+ s_channel /= 127.0
311
+ t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8)
312
+ w_ref = t_int8.half() * s_channel
313
+ s_channel = s_channel.reshape(1, -1).to(dtype=torch.float)
314
+
315
+ # Fuse scales
316
+ s_group = (s_group.reshape(-1, size_n).contiguous() / s_channel).to(
317
+ dtype=torch.half
318
+ )
319
+ else:
320
+ max_q_val = 2 ** (num_bits - 1) - 1
321
+
322
+ # Compute scale for each channel
323
+ s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0]
324
+ s_channel /= max_q_val
325
+
326
+ # Quantize
327
+ q_w = torch.round(w / s_channel).int()
328
+ q_w = torch.clamp(q_w, -max_q_val, max_q_val)
329
+ # Compute ref (dequantized)
330
+ w_ref = q_w.half() * s_channel
331
+
332
+ s_group = torch.tensor([], dtype=torch.half)
333
+ # div 2 ** (8 - self.bits)) to offset right shift in unpacking
334
+ s_channel /= 2 ** (8 - num_bits)
335
+ s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float)
336
+
337
+ return (
338
+ w_ref.to(device=orig_device),
339
+ q_w.to(device=orig_device),
340
+ s_group.to(device=orig_device),
341
+ s_channel.to(device=orig_device),
342
+ )
343
+
344
+
345
+ def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
346
+ orig_device = q_w.device
347
+
348
+ sort_indices = torch.argsort(g_idx).to(dtype=torch.int32) # Sort based on g_idx
349
+
350
+ g_idx = g_idx[sort_indices].contiguous()
351
+ q_w = q_w[sort_indices, :].contiguous()
352
+
353
+ return (
354
+ q_w.to(device=orig_device),
355
+ g_idx.to(device=orig_device),
356
+ sort_indices.to(device=orig_device),
357
+ )
358
+
359
+
360
+ def pack_rows(
361
+ q_w: torch.Tensor,
362
+ num_bits: int,
363
+ size_k: int,
364
+ size_n: int,
365
+ ):
366
+ assert q_w.shape == (size_k, size_n)
367
+
368
+ pack_factor = get_pack_factor(num_bits)
369
+ assert size_k % pack_factor == 0
370
+
371
+ orig_device = q_w.device
372
+
373
+ q_w = q_w.cpu().numpy().astype(numpy.uint32)
374
+
375
+ q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)
376
+
377
+ for i in range(pack_factor):
378
+ q_res |= q_w[i::pack_factor, :] << num_bits * i
379
+
380
+ q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
381
+ return q_res
382
+
383
+
384
+ def pack_cols(
385
+ q_w: torch.Tensor,
386
+ num_bits: int,
387
+ size_k: int,
388
+ size_n: int,
389
+ ):
390
+ assert q_w.shape == (size_k, size_n)
391
+
392
+ pack_factor = get_pack_factor(num_bits)
393
+ assert size_n % pack_factor == 0
394
+
395
+ orig_device = q_w.device
396
+
397
+ q_w = q_w.cpu().numpy().astype(numpy.uint32)
398
+
399
+ q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
400
+
401
+ for i in range(pack_factor):
402
+ q_res |= q_w[:, i::pack_factor] << num_bits * i
403
+
404
+ q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
405
+ q_res = q_res.contiguous()
406
+
407
+ return q_res
408
+
409
+
410
+ def unpack_cols(
411
+ packed_q_w: torch.Tensor,
412
+ num_bits: int,
413
+ size_k: int,
414
+ size_n: int,
415
+ ):
416
+ pack_factor = get_pack_factor(num_bits)
417
+ assert size_n % pack_factor == 0
418
+ assert packed_q_w.shape == (
419
+ size_k,
420
+ size_n // pack_factor,
421
+ ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
422
+ packed_q_w.shape, size_k, size_n, pack_factor
423
+ )
424
+
425
+ orig_device = packed_q_w.device
426
+
427
+ packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
428
+ q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
429
+
430
+ mask = (1 << num_bits) - 1
431
+ for i in range(pack_factor):
432
+ vals = packed_q_w_cpu & mask
433
+ packed_q_w_cpu >>= num_bits
434
+ q_res[:, i::pack_factor] = vals
435
+
436
+ q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
437
+ q_res = q_res.contiguous()
438
+
439
+ return q_res
440
+
441
+
442
+ def gptq_pack(
443
+ q_w: torch.Tensor,
444
+ num_bits: int,
445
+ size_k: int,
446
+ size_n: int,
447
+ ):
448
+ return pack_rows(q_w, num_bits, size_k, size_n)
449
+
450
+
451
+ def awq_pack(
452
+ q_w: torch.Tensor,
453
+ num_bits: int,
454
+ size_k: int,
455
+ size_n: int,
456
+ ):
457
+ assert q_w.shape == (size_k, size_n)
458
+
459
+ # Interleave column dim (for the dequantize code) and pack it to int32
460
+ if num_bits == 4:
461
+ interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
462
+ elif num_bits == 8:
463
+ interleave = numpy.array([0, 2, 1, 3])
464
+ else:
465
+ raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
466
+
467
+ q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel()
468
+ q_w = q_w.reshape((-1, size_n)).contiguous()
469
+
470
+ return pack_cols(q_w, num_bits, size_k, size_n)
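To make the packing conventions concrete, here is a small round-trip through pack_cols/unpack_cols; it is a standalone sketch (import path assumed as above) and runs on CPU.

import torch

from quantization.utils.quant_utils import get_pack_factor, pack_cols, unpack_cols

num_bits, size_k, size_n = 4, 4, 16
q_w = torch.randint(0, 2**num_bits, (size_k, size_n), dtype=torch.int32)

packed = pack_cols(q_w, num_bits, size_k, size_n)       # 8 nibbles per int32
unpacked = unpack_cols(packed, num_bits, size_k, size_n)

print(get_pack_factor(num_bits), packed.shape, torch.equal(q_w, unpacked))
# 8 torch.Size([4, 2]) True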
build/torch24-cxx11-cu121-x86_64-linux/quantization/__init__.py CHANGED
@@ -1,150 +1,30 @@
1
- from typing import Optional, Tuple
2
-
3
- import torch
4
-
5
- try:
6
- from ._ops import ops
7
- except ImportError as e:
8
- # Fallback for local development.
9
- try:
10
- import _quantization
11
- ops = torch.ops._quantization
12
- except ImportError:
13
- raise e
14
-
15
- def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
16
- return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
17
-
18
- def cutlass_scaled_mm(a: torch.Tensor,
19
- b: torch.Tensor,
20
- scale_a: torch.Tensor,
21
- scale_b: torch.Tensor,
22
- out_dtype: torch.dtype,
23
- bias: Optional[torch.Tensor] = None) -> torch.Tensor:
24
- assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
25
- assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
26
- assert bias is None or bias.shape[0] == b.shape[
27
- 1] and bias.dtype == out_dtype
28
-
29
- m = a.shape[0]
30
- n = b.shape[1]
31
-
32
- #if current_platform.is_rocm():
33
- # triton_scaled_mm_module = importlib.import_module(
34
- # "vllm.model_executor.layers.quantization.compressed_tensors."
35
- # "triton_scaled_mm")
36
- # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
37
- # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
38
-
39
- out = torch.empty((m, n), dtype=out_dtype, device=a.device)
40
-
41
- ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
42
-
43
- return out
44
-
45
- # fp8
46
- def scaled_fp8_quant(
47
- input: torch.Tensor,
48
- scale: Optional[torch.Tensor] = None,
49
- num_token_padding: Optional[int] = None,
50
- scale_ub: Optional[torch.Tensor] = None,
51
- use_per_token_if_dynamic: bool = False,
52
- ) -> Tuple[torch.Tensor, torch.Tensor]:
53
- """
54
- Quantize input tensor to FP8 and return quantized tensor and scale.
55
-
56
- This function supports both static and dynamic quantization: If you
57
- provide the scale, it will use static scaling and if you omit it,
58
- the scale will be determined dynamically. The function also allows
59
- optional padding of the output tensors for downstream kernels that
60
- will benefit from padding.
61
-
62
- Args:
63
- input: The input tensor to be quantized to FP8
64
- scale: Optional scaling factor for the FP8 quantization
65
- scale_ub: Optional upper bound for scaling factor in dynamic
66
- per token case
67
- num_token_padding: If specified, pad the first dimension
68
- of the output to at least this value.
69
- use_per_token_if_dynamic: Whether to do per_tensor or per_token
70
- in the dynamic quantization case.
71
-
72
- Returns:
73
- Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
74
- scaling factor.
75
- """
76
- # This code assumes batch_dim and num_tokens are flattened
77
- assert (input.ndim == 2)
78
- shape: Union[Tuple[int, int], torch.Size] = input.shape
79
- # For rocm, the output fp8 dtype is torch.float_e3m3fnuz
80
- #out_dtype: torch.dtype = torch.float8_e4m3fnuz \
81
- # if current_platform.is_rocm() else torch.float8_e4m3fn
82
- out_dtype = torch.float8_e4m3fn
83
- if num_token_padding:
84
- shape = (max(num_token_padding, input.shape[0]), shape[1])
85
- output = torch.empty(shape, device=input.device, dtype=out_dtype)
86
-
87
- if scale is None:
88
- if use_per_token_if_dynamic:
89
- scale = torch.empty((shape[0], 1),
90
- device=input.device,
91
- dtype=torch.float32)
92
- ops.dynamic_per_token_scaled_fp8_quant(
93
- output, input, scale, scale_ub)
94
- else:
95
- scale = torch.zeros(1, device=input.device, dtype=torch.float32)
96
- ops.dynamic_scaled_fp8_quant(output, input, scale)
97
- else:
98
- # num_token_padding not implemented for this case
99
- assert (scale.numel() == 1 or num_token_padding is None)
100
- ops.static_scaled_fp8_quant(output, input, scale)
101
-
102
- return output, scale
103
-
104
- # int8
105
- def scaled_int8_quant(
106
- input: torch.Tensor,
107
- scale: Optional[torch.Tensor] = None,
108
- azp: Optional[torch.Tensor] = None,
109
- symmetric: bool = True
110
- ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
111
- """
112
- Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.
113
-
114
- Args:
115
- input: The input tensor to be quantized to int8.
116
- scale: Optional scaling factor for the int8 quantization.
117
- When not provided, we invoke dynamic-per-token quantization.
118
- azp: Optional zero-point for the int8 quantization.
119
- Must be provided for asymmetric quantization if `scale` is provided.
120
- symmetric: Whether to use symmetric quantization (scale only, azp ignored).
121
-
122
- Returns:
123
- Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
124
- """
125
- output = torch.empty_like(input, dtype=torch.int8)
126
- if scale is not None:
127
- # static-per-tensor quantization.
128
- assert symmetric == (
129
- azp is
130
- None), "azp must only be provided for asymmetric quantization."
131
- ops.static_scaled_int8_quant(output, input, scale, azp)
132
- return output, scale, azp
133
-
134
- # dynamic-per-token quantization.
135
- input_scales = torch.empty((input.numel() // input.shape[-1], 1),
136
- device=input.device,
137
- dtype=torch.float32)
138
- input_azp = None if symmetric else torch.empty_like(input_scales,
139
- dtype=torch.int32)
140
- ops.dynamic_scaled_int8_quant(output, input, input_scales,
141
- input_azp)
142
- return output, input_scales, input_azp
143
-
144
- # fp8 marlin
145
- def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
146
- b_scales: torch.Tensor, workspace: torch.Tensor,
147
- num_bits: int, size_m: int, size_n: int,
148
- size_k: int) -> torch.Tensor:
149
- return ops.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace,
150
- num_bits, size_m, size_n, size_k)
 
1
+ from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant
2
+ from .cutlass import (
3
+ cutlass_scaled_mm_supports_fp8,
4
+ cutlass_scaled_mm,
5
+ cutlass_scaled_mm_azp,
6
+ )
7
+ from .marlin import (
8
+ awq_marlin_repack,
9
+ fp8_marlin_gemm,
10
+ gptq_marlin_gemm,
11
+ gptq_marlin_repack,
12
+ gptq_marlin_24_gemm,
13
+ marlin_qqq_gemm,
14
+ marlin_gemm,
15
+ )
16
+
17
+ __all__ = [
18
+ "awq_marlin_repack",
19
+ "cutlass_scaled_mm",
20
+ "cutlass_scaled_mm_azp",
21
+ "cutlass_scaled_mm_supports_fp8",
22
+ "fp8_marlin_gemm",
23
+ "gptq_marlin_24_gemm",
24
+ "gptq_marlin_gemm",
25
+ "gptq_marlin_repack",
26
+ "marlin_gemm",
27
+ "marlin_qqq_gemm",
28
+ "scaled_fp8_quant",
29
+ "scaled_int8_quant",
30
+ ]
build/torch24-cxx11-cu121-x86_64-linux/quantization/_ops.py CHANGED
@@ -1,3 +1,9 @@
1
  import torch
2
  from . import _quantization_0_0_1
3
ops = torch.ops._quantization_0_0_1
1
  import torch
2
  from . import _quantization_0_0_1
3
  ops = torch.ops._quantization_0_0_1
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_quantization_0_0_1::{op_name}"
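The new helper only formats the fully-qualified op name for the versioned namespace. A minimal sketch of its output, assuming the compiled `_quantization_0_0_1` extension is importable so the module loads:

from quantization._ops import add_op_namespace_prefix

print(add_op_namespace_prefix("gptq_marlin_repack"))
# -> "_quantization_0_0_1::gptq_marlin_repack"

This matches how marlin.py below imports both register_fake and add_op_namespace_prefix, presumably so fake/meta implementations can be registered under the versioned op namespace.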
build/torch24-cxx11-cu121-x86_64-linux/quantization/_quantization_0_0_1.abi3.so CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d1fec120a5de02ea58eb455e5f6483ce15da4672c356982bd4ac070864755e28
3
- size 86065792
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c0abd1636906a69e8cf7d85fdfc7b99b6b4f4cc3d753431ad3d49ba674238c27
3
+ size 105267936
build/torch24-cxx11-cu121-x86_64-linux/quantization/compressed_tensors.py ADDED
@@ -0,0 +1,110 @@
1
+ from typing import Optional, Tuple, Union
2
+
3
+ import torch
4
+
5
+ try:
6
+ from ._ops import ops
7
+ except ImportError as e:
8
+ # Fallback for local development.
9
+ try:
10
+ import _quantization
11
+
12
+ ops = torch.ops._quantization
13
+ except ImportError:
14
+ raise e
15
+
16
+
17
+ # fp8
18
+ def scaled_fp8_quant(
19
+ input: torch.Tensor,
20
+ scale: Optional[torch.Tensor] = None,
21
+ num_token_padding: Optional[int] = None,
22
+ scale_ub: Optional[torch.Tensor] = None,
23
+ use_per_token_if_dynamic: bool = False,
24
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
25
+ """
26
+ Quantize input tensor to FP8 and return quantized tensor and scale.
27
+
28
+ This function supports both static and dynamic quantization: If you
29
+ provide the scale, it will use static scaling and if you omit it,
30
+ the scale will be determined dynamically. The function also allows
31
+ optional padding of the output tensors for downstream kernels that
32
+ will benefit from padding.
33
+
34
+ Args:
35
+ input: The input tensor to be quantized to FP8
36
+ scale: Optional scaling factor for the FP8 quantization
37
+ scale_ub: Optional upper bound for scaling factor in dynamic
38
+ per token case
39
+ num_token_padding: If specified, pad the first dimension
40
+ of the output to at least this value.
41
+ use_per_token_if_dynamic: Whether to do per_tensor or per_token
42
+ in the dynamic quantization case.
43
+
44
+ Returns:
45
+ Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
46
+ scaling factor.
47
+ """
48
+ # This code assumes batch_dim and num_tokens are flattened
49
+ assert input.ndim == 2
50
+ shape: Union[Tuple[int, int], torch.Size] = input.shape
51
+ # For rocm, the output fp8 dtype is torch.float_e3m3fnuz
52
+ # out_dtype: torch.dtype = torch.float8_e4m3fnuz \
53
+ # if current_platform.is_rocm() else torch.float8_e4m3fn
54
+ out_dtype = torch.float8_e4m3fn
55
+ if num_token_padding:
56
+ shape = (max(num_token_padding, input.shape[0]), shape[1])
57
+ output = torch.empty(shape, device=input.device, dtype=out_dtype)
58
+
59
+ if scale is None:
60
+ if use_per_token_if_dynamic:
61
+ scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
62
+ ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub)
63
+ else:
64
+ scale = torch.zeros(1, device=input.device, dtype=torch.float32)
65
+ ops.dynamic_scaled_fp8_quant(output, input, scale)
66
+ else:
67
+ # num_token_padding not implemented for this case
68
+ assert scale.numel() == 1 or num_token_padding is None
69
+ ops.static_scaled_fp8_quant(output, input, scale)
70
+
71
+ return output, scale
72
+
73
+
74
+ # int8
75
+ def scaled_int8_quant(
76
+ input: torch.Tensor,
77
+ scale: Optional[torch.Tensor] = None,
78
+ azp: Optional[torch.Tensor] = None,
79
+ symmetric: bool = True,
80
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
81
+ """
82
+ Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.
83
+
84
+ Args:
85
+ input: The input tensor to be quantized to int8.
86
+ scale: Optional scaling factor for the int8 quantization.
87
+ When not provided, we invoke dynamic-per-token quantization.
88
+ azp: Optional zero-point for the int8 quantization.
89
+ Must be provided for asymmetric quantization if `scale` is provided.
90
+ symmetric: Whether to use symmetric quantization (scale only, azp ignored).
91
+
92
+ Returns:
93
+ Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
94
+ """
95
+ output = torch.empty_like(input, dtype=torch.int8)
96
+ if scale is not None:
97
+ # static-per-tensor quantization.
98
+ assert symmetric == (
99
+ azp is None
100
+ ), "azp must only be provided for asymmetric quantization."
101
+ ops.static_scaled_int8_quant(output, input, scale, azp)
102
+ return output, scale, azp
103
+
104
+ # dynamic-per-token quantization.
105
+ input_scales = torch.empty(
106
+ (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32
107
+ )
108
+ input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32)
109
+ ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp)
110
+ return output, input_scales, input_azp
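A hedged usage sketch of the two entry points in this module; it requires a CUDA device and the compiled extension, and the shapes here are arbitrary:

import torch

from quantization.compressed_tensors import scaled_fp8_quant, scaled_int8_quant

x = torch.randn(4, 1024, dtype=torch.half, device="cuda")

# Dynamic quantization: scales are computed by the kernels.
x_fp8, fp8_scale = scaled_fp8_quant(x)           # per-tensor FP8 scale
x_int8, int8_scales, azp = scaled_int8_quant(x)  # per-token INT8 scales, azp is None

print(x_fp8.dtype, fp8_scale.shape, int8_scales.shape)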
build/torch24-cxx11-cu121-x86_64-linux/quantization/cutlass.py ADDED
@@ -0,0 +1,75 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+
5
+ try:
6
+ from ._ops import ops
7
+ except ImportError as e:
8
+ # Fallback for local development.
9
+ try:
10
+ import _quantization
11
+
12
+ ops = torch.ops._quantization
13
+ except ImportError:
14
+ raise e
15
+
16
+
17
+ def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
18
+ return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
19
+
20
+
21
+ def cutlass_scaled_mm(
22
+ a: torch.Tensor,
23
+ b: torch.Tensor,
24
+ scale_a: torch.Tensor,
25
+ scale_b: torch.Tensor,
26
+ out_dtype: torch.dtype,
27
+ bias: Optional[torch.Tensor] = None,
28
+ ) -> torch.Tensor:
29
+ assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
30
+ assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
31
+ assert bias is None or bias.shape[0] == b.shape[1] and bias.dtype == out_dtype
32
+
33
+ m = a.shape[0]
34
+ n = b.shape[1]
35
+
36
+ # if current_platform.is_rocm():
37
+ # triton_scaled_mm_module = importlib.import_module(
38
+ # "vllm.model_executor.layers.quantization.compressed_tensors."
39
+ # "triton_scaled_mm")
40
+ # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
41
+ # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
42
+
43
+ out = torch.empty((m, n), dtype=out_dtype, device=a.device)
44
+
45
+ ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
46
+
47
+ return out
48
+
49
+
50
+ def cutlass_scaled_mm_azp(
51
+ a: torch.Tensor,
52
+ b: torch.Tensor,
53
+ scale_a: torch.Tensor,
54
+ scale_b: torch.Tensor,
55
+ out_dtype: torch.dtype,
56
+ azp_adj: torch.Tensor,
57
+ azp: Optional[torch.Tensor] = None,
58
+ bias: Optional[torch.Tensor] = None,
59
+ ) -> torch.Tensor:
60
+ """
61
+ :param azp_adj: In the per-tensor case, this should include the azp.
62
+ Always per-channel.
63
+ :param azp: Only set in the per-token case. Per-token if set.
64
+ """
65
+ assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
66
+ assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
67
+ assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype
68
+ assert azp is None or azp.numel() == a.shape[0]
69
+
70
+ m = a.shape[0]
71
+ n = b.shape[1]
72
+ out = torch.empty((m, n), dtype=out_dtype, device=a.device)
73
+
74
+ ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias)
75
+ return out
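A hedged sketch of how `cutlass_scaled_mm` might be called with int8 operands. The row-major `a` / column-major `b` layout and the per-token / per-channel scale shapes follow the vLLM kernels this wrapper is derived from and should be treated as assumptions here; a CUDA device with a supported compute capability is required.

import torch
from quantization.cutlass import cutlass_scaled_mm

m, k, n = 32, 256, 128
a = torch.randint(-8, 8, (m, k), dtype=torch.int8, device="cuda")
# b is created transposed so it ends up (k, n) with column-major strides.
b = torch.randint(-8, 8, (n, k), dtype=torch.int8, device="cuda").t()
scale_a = torch.rand(m, 1, dtype=torch.float32, device="cuda")   # per-token
scale_b = torch.rand(1, n, dtype=torch.float32, device="cuda")   # per-channel

out = cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.float16)
print(out.shape)  # (m, n) in fp16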
build/torch24-cxx11-cu121-x86_64-linux/quantization/marlin.py ADDED
@@ -0,0 +1,208 @@
1
+ from typing import TYPE_CHECKING
2
+
3
+ import torch
4
+
5
+ # Neuron has a torch version that doesn't even have impl_abstract
6
+ if TYPE_CHECKING:
7
+ def register_fake(fn):
8
+ return lambda name: fn
9
+ else:
10
+ try:
11
+ from torch.library import register_fake
12
+ except ImportError:
13
+ from torch.library import impl_abstract as register_fake
14
+
15
+ try:
16
+ from ._ops import ops, add_op_namespace_prefix
17
+ except ImportError as e:
18
+ # Fallback for local development.
19
+ try:
20
+ import _quantization
21
+
22
+ ops = torch.ops._quantization
23
+
24
+ def add_op_namespace_prefix(op_name: str):
25
+ return f"_quantization::{op_name}"
26
+ except ImportError:
27
+ raise e
28
+
29
+
30
+ from .scalar_type import ScalarType
31
+
32
+
33
+ # fp8 marlin
34
+ def fp8_marlin_gemm(
35
+ a: torch.Tensor,
36
+ b_q_weight: torch.Tensor,
37
+ b_scales: torch.Tensor,
38
+ workspace: torch.Tensor,
39
+ num_bits: int,
40
+ size_m: int,
41
+ size_n: int,
42
+ size_k: int,
43
+ ) -> torch.Tensor:
44
+ return ops.fp8_marlin_gemm(
45
+ a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k
46
+ )
47
+
48
+
49
+ # gptq_marlin
50
+ def gptq_marlin_gemm(
51
+ a: torch.Tensor,
52
+ b_q_weight: torch.Tensor,
53
+ b_scales: torch.Tensor,
54
+ b_zeros: torch.Tensor,
55
+ g_idx: torch.Tensor,
56
+ perm: torch.Tensor,
57
+ workspace: torch.Tensor,
58
+ b_q_type: ScalarType,
59
+ size_m: int,
60
+ size_n: int,
61
+ size_k: int,
62
+ is_k_full: bool,
63
+ has_zp: bool = False,
64
+ use_fp32_reduce: bool = False,
65
+ is_zp_float: bool = False,
66
+ ) -> torch.Tensor:
67
+ return ops.gptq_marlin_gemm(
68
+ a,
69
+ b_q_weight,
70
+ b_scales,
71
+ b_zeros,
72
+ g_idx,
73
+ perm,
74
+ workspace,
75
+ b_q_type.id,
76
+ size_m,
77
+ size_n,
78
+ size_k,
79
+ is_k_full,
80
+ has_zp,
81
+ use_fp32_reduce,
82
+ is_zp_float,
83
+ )
84
+
85
+
86
+ # gptq_marlin
87
+ def gptq_marlin_repack(
88
+ b_q_weight: torch.Tensor,
89
+ perm: torch.Tensor,
90
+ size_k: int,
91
+ size_n: int,
92
+ num_bits: int,
93
+ ) -> torch.Tensor:
94
+ return ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits)
95
+
96
+
97
+ # gptq_marlin
98
+ def awq_marlin_repack(
99
+ b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
100
+ ) -> torch.Tensor:
101
+ return ops.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)
102
+
103
+
104
+ # marlin
105
+ def marlin_gemm(
106
+ a: torch.Tensor,
107
+ b_q_weight: torch.Tensor,
108
+ b_scales: torch.Tensor,
109
+ workspace: torch.Tensor,
110
+ size_m: int,
111
+ size_n: int,
112
+ size_k: int,
113
+ ) -> torch.Tensor:
114
+ return ops.marlin_gemm(
115
+ a, b_q_weight, b_scales, workspace, size_m, size_n, size_k
116
+ )
117
+
118
+
119
+ # marlin_24
120
+ def gptq_marlin_24_gemm(
121
+ a: torch.Tensor,
122
+ b_q_weight: torch.Tensor,
123
+ b_meta: torch.Tensor,
124
+ b_scales: torch.Tensor,
125
+ workspace: torch.Tensor,
126
+ b_q_type: ScalarType,
127
+ size_m: int,
128
+ size_n: int,
129
+ size_k: int,
130
+ ) -> torch.Tensor:
131
+ return ops.gptq_marlin_24_gemm(
132
+ a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, size_m, size_n, size_k
133
+ )
134
+
135
+
136
+ # qqq ops
137
+ def marlin_qqq_gemm(
138
+ a: torch.Tensor,
139
+ b_q_weight: torch.Tensor,
140
+ s_tok: torch.Tensor,
141
+ s_ch: torch.Tensor,
142
+ s_group: torch.Tensor,
143
+ workspace: torch.Tensor,
144
+ size_m: int,
145
+ size_n: int,
146
+ size_k: int,
147
+ ) -> torch.Tensor:
148
+ return ops.marlin_qqq_gemm(
149
+ a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k
150
+ )
151
+
152
+
153
+ # Fake ops
154
+
155
+ if hasattr(ops, "gptq_marlin_24_gemm"):
156
+ @register_fake(add_op_namespace_prefix("fp8_marlin_gemm"))
157
+ def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
158
+ b_scales: torch.Tensor, workspace: torch.Tensor,
159
+ num_bits: int, size_m: torch.SymInt,
160
+ size_n: torch.SymInt,
161
+ size_k: torch.SymInt) -> torch.Tensor:
162
+ return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)
163
+
164
+ @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm"))
165
+ def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
166
+ b_meta: torch.Tensor, b_scales: torch.Tensor,
167
+ workspace: torch.Tensor,
168
+ b_q_type: ScalarType, size_m: torch.SymInt,
169
+ size_n: torch.SymInt,
170
+ size_k: torch.SymInt) -> torch.Tensor:
171
+ return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
172
+
173
+ @register_fake(add_op_namespace_prefix("gptq_marlin_gemm"))
174
+ def _gptq_marlin_gemm_fake(a: torch.Tensor,
175
+ b_q_weight: torch.Tensor,
176
+ b_scales: torch.Tensor,
177
+ b_zeros: torch.Tensor,
178
+ g_idx: torch.Tensor,
179
+ perm: torch.Tensor,
180
+ workspace: torch.Tensor,
181
+ b_q_type: ScalarType,
182
+ size_m: torch.SymInt,
183
+ size_n: torch.SymInt,
184
+ size_k: torch.SymInt,
185
+ is_k_full: bool,
186
+ has_zp: bool = False,
187
+ use_fp32_reduce: bool = False,
188
+ is_zp_float: bool = False) -> torch.Tensor:
189
+ return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
190
+
191
+ @register_fake(add_op_namespace_prefix("marlin_qqq_gemm"))
192
+ def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
193
+ s_tok: torch.Tensor, s_ch: torch.Tensor,
194
+ s_group: torch.Tensor, workspace: torch.Tensor,
195
+ size_m: torch.SymInt, size_n: torch.SymInt,
196
+ size_k: torch.SymInt) -> torch.Tensor:
197
+ return torch.empty((size_m, size_n),
198
+ dtype=torch.float16,
199
+ device=a.device)
200
+
201
+ @register_fake(add_op_namespace_prefix("marlin_gemm"))
202
+ def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
203
+ b_scales: torch.Tensor, workspace: torch.Tensor,
204
+ size_m: torch.SymInt, size_n: torch.SymInt,
205
+ size_k: torch.SymInt) -> torch.Tensor:
206
+ return torch.empty((size_m, size_n),
207
+ dtype=torch.float16,
208
+ device=a.device)
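The fake registrations above only describe output shapes and dtypes, which lets these custom ops participate in tracing and shape propagation without launching a kernel. A minimal sketch, assuming the compiled `_quantization` extension loads and exposes these ops so the registrations run; the tensor shapes here are placeholders, since the fake implementation only uses the size arguments.

import torch
from quantization import marlin

# On the meta device the registered fake implementation runs instead of the
# CUDA kernel, so only the output shape/dtype is produced.
a = torch.empty(16, 4096, dtype=torch.float16, device="meta")
b_q = torch.empty(2048, 4096, dtype=torch.int32, device="meta")
scales = torch.empty(1, 4096, dtype=torch.float16, device="meta")
workspace = torch.empty(1024, dtype=torch.int32, device="meta")

out = marlin.marlin_gemm(a, b_q, scales, workspace, 16, 4096, 4096)
print(out.shape, out.dtype)  # torch.Size([16, 4096]) torch.float16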
build/torch24-cxx11-cu121-x86_64-linux/quantization/scalar_type.py ADDED
@@ -0,0 +1,330 @@
1
+ import functools
2
+ import struct
3
+ from dataclasses import dataclass
4
+ from enum import Enum
5
+ from typing import Optional, Union
6
+
7
+
8
+ # Mirrors enum in `core/scalar_type.hpp`
9
+ class NanRepr(Enum):
10
+ NONE = 0 # nans are not supported
11
+ IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s
12
+ EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s
13
+
14
+
15
+ # This ScalarType class is a parallel implementation of the C++ ScalarType
16
+ # class found in csrc/core/scalar_type.hpp. These two classes should be kept
17
+ # in sync until the inductor fully supports custom C++ classes.
18
+ @dataclass(frozen=True)
19
+ class ScalarType:
20
+ """
21
+ ScalarType can represent a wide range of floating point and integer
22
+ types, in particular it can be used to represent sub-byte data types
23
+ (something that torch.dtype currently does not support). It is also
24
+ capable of representing types with a bias, i.e.:
25
+ `stored_value = value + bias`,
26
+ this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
27
+ of 8). The implementation for this class can be found in
28
+ csrc/core/scalar_type.hpp, these type signatures should be kept in sync
29
+ with that file.
30
+ """
31
+
32
+ exponent: int
33
+ """
34
+ Number of bits in the exponent if this is a floating point type
35
+ (zero if this is an integer type)
36
+ """
37
+
38
+ mantissa: int
39
+ """
40
+ Number of bits in the mantissa if this is a floating point type,
41
+ or the number of bits representing an integer excluding the sign bit if
42
+ this is an integer type.
43
+ """
44
+
45
+ signed: bool
46
+ "If the type is signed (i.e. has a sign bit)"
47
+
48
+ bias: int
49
+ """
50
+ bias used to encode the values in this scalar type
51
+ (value = stored_value - bias, default 0) for example if we store the
52
+ type as an unsigned integer with a bias of 128 then the value 0 will be
53
+ stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
54
+ """
55
+
56
+ _finite_values_only: bool = False
57
+ """
58
+ Private: to check whether infs are supported, use `has_infs()` instead.
59
+ """
60
+
61
+ nan_repr: NanRepr = NanRepr.IEEE_754
62
+ """
63
+ How NaNs are represented in this scalar type; returns a NanRepr value.
64
+ (not applicable for integer types)
65
+ """
66
+
67
+ def _floating_point_max_int(self) -> int:
68
+ assert (
69
+ self.mantissa <= 52 and self.exponent <= 11
70
+ ), f"Cannot represent max/min as a double for type {self.__str__()}"
71
+
72
+ max_mantissa = (1 << self.mantissa) - 1
73
+ if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
74
+ max_mantissa = max_mantissa - 1
75
+
76
+ max_exponent = (1 << self.exponent) - 2
77
+ if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN
78
+ or self.nan_repr == NanRepr.NONE):
79
+ assert (
80
+ self.exponent < 11
81
+ ), f"Cannot represent max/min as a double for type {self.__str__()}"
82
+ max_exponent = max_exponent + 1
83
+
84
+ # adjust the exponent to match that of a double
85
+ # for now we assume the exponent bias is the standard 2^(e-1) -1, (where
86
+ # e is the exponent bits), there is some precedent for non-standard
87
+ # biases, example `float8_e4m3b11fnuz` here:
88
+ # https://github.com/jax-ml/ml_dtypes but to avoid premature over
89
+ # complication we are just assuming the standard exponent bias until
90
+ # there is a need to support non-standard biases
91
+ exponent_bias = (1 << (self.exponent - 1)) - 1
92
+ exponent_bias_double = (1 << 10) - 1 # double e = 11
93
+
94
+ max_exponent_double = (max_exponent - exponent_bias +
95
+ exponent_bias_double)
96
+
97
+ # shift the mantissa and exponent into the proper positions for an
98
+ # IEEE double and bitwise-or them together.
99
+ return (max_mantissa <<
100
+ (52 - self.mantissa)) | (max_exponent_double << 52)
101
+
102
+ def _floating_point_max(self) -> float:
103
+ double_raw = self._floating_point_max_int()
104
+ return struct.unpack('!d', struct.pack('!Q', double_raw))[0]
105
+
106
+ def _raw_max(self) -> Union[int, float]:
107
+ if self.is_floating_point():
108
+ return self._floating_point_max()
109
+ else:
110
+ assert (self.size_bits < 64 or self.size_bits == 64
111
+ and self.is_signed()), "Cannot represent max as an int"
112
+ return (1 << self.mantissa) - 1
113
+
114
+ def _raw_min(self) -> Union[int, float]:
115
+ if self.is_floating_point():
116
+ assert self.is_signed(
117
+ ), "We currently assume all floating point types are signed"
118
+ sign_bit_double = 1 << 63
119
+
120
+ max_raw = self._floating_point_max_int()
121
+ min_raw = max_raw | sign_bit_double
122
+ return struct.unpack('!d', struct.pack('!Q', min_raw))[0]
123
+ else:
124
+ assert (not self.is_signed() or
125
+ self.size_bits <= 64), "Cannot represent min as a int64_t"
126
+
127
+ if self.is_signed():
128
+ return -(1 << (self.size_bits - 1))
129
+ else:
130
+ return 0
131
+
132
+ @functools.cached_property
133
+ def id(self) -> int:
134
+ """
135
+ Convert the ScalarType to an int which can be passed to pytorch custom
136
+ ops. This layout of the int must be kept in sync with the C++
137
+ ScalarType's from_id method.
138
+ """
139
+ val = 0
140
+ offset = 0
141
+
142
+ def or_and_advance(member, bit_width):
143
+ nonlocal val
144
+ nonlocal offset
145
+ bit_mask = (1 << bit_width) - 1
146
+ val = val | (int(member) & bit_mask) << offset
147
+ offset = offset + bit_width
148
+
149
+ or_and_advance(self.exponent, 8)
150
+ or_and_advance(self.mantissa, 8)
151
+ or_and_advance(self.signed, 1)
152
+ or_and_advance(self.bias, 32)
153
+ or_and_advance(self._finite_values_only, 1)
154
+ or_and_advance(self.nan_repr.value, 8)
155
+
156
+ assert offset <= 64, \
157
+ f"ScalarType fields too big {offset} to fit into an int64"
158
+
159
+ return val
160
+
161
+ @property
162
+ def size_bits(self) -> int:
163
+ return self.exponent + self.mantissa + int(self.signed)
164
+
165
+ def min(self) -> Union[int, float]:
166
+ """
167
+ Min representable value for this scalar type.
168
+ (accounting for bias if there is one)
169
+ """
170
+ return self._raw_min() - self.bias
171
+
172
+ def max(self) -> Union[int, float]:
173
+ """
174
+ Max representable value for this scalar type.
175
+ (accounting for bias if there is one)
176
+ """
177
+ return self._raw_max() - self.bias
178
+
179
+ def is_signed(self) -> bool:
180
+ """
181
+ If the type is signed (i.e. has a sign bit), same as `signed`
182
+ added for consistency with:
183
+ https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
184
+ """
185
+ return self.signed
186
+
187
+ def is_floating_point(self) -> bool:
188
+ "If the type is a floating point type"
189
+ return self.exponent != 0
190
+
191
+ def is_integer(self) -> bool:
192
+ "If the type is an integer type"
193
+ return self.exponent == 0
194
+
195
+ def has_bias(self) -> bool:
196
+ "If the type has a non-zero bias"
197
+ return self.bias != 0
198
+
199
+ def has_infs(self) -> bool:
200
+ "If the type is floating point and supports infinity"
201
+ return not self._finite_values_only
202
+
203
+ def has_nans(self) -> bool:
204
+ return self.nan_repr != NanRepr.NONE
205
+
206
+ def is_ieee_754(self) -> bool:
207
+ """
208
+ If the type is a floating point type that follows IEEE 754
209
+ conventions
210
+ """
211
+ return self.nan_repr == NanRepr.IEEE_754 and \
212
+ not self._finite_values_only
213
+
214
+ def __str__(self) -> str:
215
+ """
216
+ naming generally follows: https://github.com/jax-ml/ml_dtypes
217
+ for floating point types (leading f) the scheme is:
218
+ `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
219
+ flags:
220
+ - no-flags: means it follows IEEE 754 conventions
221
+ - f: means finite values only (no infinities)
222
+ - n: means nans are supported (non-standard encoding)
223
+ for integer types the scheme is:
224
+ `[u]int<size_bits>[b<bias>]`
225
+ - if bias is not present it means it is zero
226
+ """
227
+ if self.is_floating_point():
228
+ ret = "float" + str(self.size_bits) + "_e" + str(
229
+ self.exponent) + "m" + str(self.mantissa)
230
+
231
+ if not self.is_ieee_754():
232
+ if self._finite_values_only:
233
+ ret = ret + "f"
234
+ if self.nan_repr != NanRepr.NONE:
235
+ ret = ret + "n"
236
+
237
+ return ret
238
+ else:
239
+ ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
240
+ if self.has_bias():
241
+ ret = ret + "b" + str(self.bias)
242
+ return ret
243
+
244
+ def __repr__(self) -> str:
245
+ return "ScalarType." + self.__str__()
246
+
247
+ # __len__ needs to be defined (and has to throw TypeError) for pytorch's
248
+ # opcheck to work.
249
+ def __len__(self) -> int:
250
+ raise TypeError
251
+
252
+ #
253
+ # Convenience Constructors
254
+ #
255
+
256
+ @classmethod
257
+ def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
258
+ "Create a signed integer scalar type (size_bits includes sign-bit)."
259
+ ret = cls(0, size_bits - 1, True, bias if bias else 0)
260
+ ret.id # noqa B018: make sure the id is cached
261
+ return ret
262
+
263
+ @classmethod
264
+ def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
265
+ """Create a unsigned integer scalar type."""
266
+ ret = cls(0, size_bits, False, bias if bias else 0)
267
+ ret.id # noqa B018: make sure the id is cached
268
+ return ret
269
+
270
+ @classmethod
271
+ def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
272
+ """
273
+ Create a standard floating point type
274
+ (i.e. follows IEEE 754 conventions).
275
+ """
276
+ assert (mantissa > 0 and exponent > 0)
277
+ ret = cls(exponent, mantissa, True, 0)
278
+ ret.id # noqa B018: make sure the id is cached
279
+ return ret
280
+
281
+ @classmethod
282
+ def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
283
+ nan_repr: NanRepr) -> 'ScalarType':
284
+ """
285
+ Create a non-standard floating point type
286
+ (i.e. does not follow IEEE 754 conventions).
287
+ """
288
+ assert (mantissa > 0 and exponent > 0)
289
+ assert (nan_repr != NanRepr.IEEE_754), (
290
+ "use `float_IEEE754` constructor for floating point types that "
291
+ "follow IEEE 754 conventions")
292
+ ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
293
+ ret.id # noqa B018: make sure the id is cached
294
+ return ret
295
+
296
+
297
+ # naming generally follows: https://github.com/jax-ml/ml_dtypes
298
+ # for floating point types (leading f) the scheme is:
299
+ # `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
300
+ # flags:
301
+ # - no-flags: means it follows IEEE 754 conventions
302
+ # - f: means finite values only (no infinities)
303
+ # - n: means nans are supported (non-standard encoding)
304
+ # for integer types the scheme is:
305
+ # `[u]int<size_bits>[b<bias>]`
306
+ # - if bias is not present it means it is zero
307
+
308
+
309
+ class scalar_types:
310
+ int4 = ScalarType.int_(4, None)
311
+ uint4 = ScalarType.uint(4, None)
312
+ int8 = ScalarType.int_(8, None)
313
+ uint8 = ScalarType.uint(8, None)
314
+ float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
315
+ float8_e5m2 = ScalarType.float_IEEE754(5, 2)
316
+ float16_e8m7 = ScalarType.float_IEEE754(8, 7)
317
+ float16_e5m10 = ScalarType.float_IEEE754(5, 10)
318
+
319
+ # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
320
+ float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)
321
+
322
+ # "gptq" types
323
+ uint2b2 = ScalarType.uint(2, 2)
324
+ uint3b4 = ScalarType.uint(3, 4)
325
+ uint4b8 = ScalarType.uint(4, 8)
326
+ uint8b128 = ScalarType.uint(8, 128)
327
+
328
+ # colloquial names
329
+ bfloat16 = float16_e8m7
330
+ float16 = float16_e5m10
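A small sketch of the `ScalarType` API defined above, assuming the built package is importable (the class itself is pure Python):

from quantization.scalar_type import scalar_types

t = scalar_types.uint4b8                   # GPTQ-style 4-bit: unsigned, bias 8
print(str(t))                              # "uint4b8"
print(t.size_bits, t.min(), t.max())       # 4 -8 7
print(t.is_integer(), t.has_bias())        # True True

fp8 = scalar_types.float8_e4m3fn
print(str(fp8), fp8.max())                 # "float8_e4m3fn" 448.0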
build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/__init__.py ADDED
File without changes
build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils.py ADDED
@@ -0,0 +1,391 @@
1
+ from typing import List, Optional, Tuple
2
+
3
+ import numpy
4
+ import torch
5
+
6
+ import quantization as ops
7
+ from quantization.scalar_type import ScalarType, scalar_types
8
+
9
+ from .quant_utils import pack_cols, unpack_cols
10
+
11
+ GPTQ_MARLIN_TILE = 16
12
+ GPTQ_MARLIN_MIN_THREAD_N = 64
13
+ GPTQ_MARLIN_MIN_THREAD_K = 128
14
+ GPTQ_MARLIN_MAX_PARALLEL = 16
15
+
16
+ GPTQ_MARLIN_24_TILE = 16
17
+ GPTQ_MARLIN_24_MIN_THREAD_N = 128
18
+ GPTQ_MARLIN_24_MIN_THREAD_K = 128
19
+ GPTQ_MARLIN_24_MAX_PARALLEL = 64
20
+
21
+ GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
22
+ GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
23
+
24
+ MARLIN_QQQ_TILE = 16
25
+ MARLIN_QQQ_MIN_THREAD_N = 64
26
+ MARLIN_QQQ_MIN_THREAD_K = 128
27
+ MARLIN_QQQ_MAX_PARALLEL = 16
28
+
29
+ MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]
30
+ MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128]
31
+ MARLIN_QQQ_SUPPORTED_SYM = [True]
32
+
33
+ MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
34
+
35
+ # In case there is a performance issue with Marlin, the variable below can be
36
+ # changed to False, which allows Marlin to perform global reductions in fp16
37
+ # precision (instead of fp32), and therefore, save on some memory movements.
38
+ USE_FP32_REDUCE_DEFAULT = True
39
+
40
+
41
+ # For binary size and compile time, we don't support the same set of types with
42
+ # and without a runtime zero-point. We support the common cases, i.e. AWQ and GPTQ.
43
+ # TODO: we may want to move this into the C++ so it's closer to the actual impl
44
+ def query_marlin_supported_quant_types(
45
+ has_zp: bool, device_capability: Optional[int] = None
46
+ ):
47
+ if device_capability is None:
48
+ capability_tuple = torch.cuda.get_device_capability()
49
+ device_capability = capability_tuple[0] * 10 + capability_tuple[1]
50
+
51
+ if device_capability < 80:
52
+ return []
53
+
54
+ if has_zp:
55
+ # AWQ style, unsigned + runtime zero-point
56
+ return [scalar_types.uint4, scalar_types.uint8]
57
+ else:
58
+ # GPTQ style, unsigned + symmetric bias
59
+ # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able
60
+ # to add `scalar_types.float8_e4m3fn` here
61
+ return [scalar_types.uint4b8, scalar_types.uint8b128]
62
+
63
+
64
+ def _check_marlin_supported(
65
+ quant_type: ScalarType,
66
+ group_size: Optional[int],
67
+ has_zp: bool,
68
+ device_capability: Optional[int] = None,
69
+ ) -> Tuple[bool, Optional[str]]:
70
+
71
+ if device_capability is None:
72
+ capability_tuple = torch.cuda.get_device_capability()
73
+ device_capability = capability_tuple[0] * 10 + capability_tuple[1]
74
+
75
+ supported_types = query_marlin_supported_quant_types(has_zp, device_capability)
76
+
77
+ if quant_type not in supported_types:
78
+ return (
79
+ False,
80
+ f"Marlin does not support weight_bits = {quant_type}. "
81
+ f"Only types = {supported_types} "
82
+ f"are supported (for group_size = {group_size}, "
83
+ f"device_capability = {device_capability}, zp = {has_zp}).",
84
+ )
85
+ if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
86
+ return (
87
+ False,
88
+ f"Marlin does not support group_size = {group_size}. "
89
+ f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
90
+ "are supported.",
91
+ )
92
+
93
+ return True, None
94
+
95
+
96
+ def check_marlin_supported(
97
+ quant_type: ScalarType,
98
+ group_size: int,
99
+ has_zp: bool = False,
100
+ device_capability: Optional[int] = None,
101
+ ) -> bool:
102
+ cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability)
103
+ return cond
104
+
105
+
106
+ def verify_marlin_supported(
107
+ quant_type: ScalarType, group_size: int, has_zp: bool = False
108
+ ) -> None:
109
+ cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp)
110
+ if not cond:
111
+ assert err_msg is not None
112
+ raise ValueError(err_msg)
113
+
114
+
115
+ def verify_marlin_supports_shape(
116
+ output_size_per_partition: int,
117
+ input_size_per_partition: int,
118
+ input_size: int,
119
+ group_size: int,
120
+ ) -> None:
121
+
122
+ # Validate output_size_per_partition
123
+ if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
124
+ raise ValueError(
125
+ f"Weight output_size_per_partition = "
126
+ f"{output_size_per_partition} is not divisible by "
127
+ f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
128
+ "Consider reducing tensor_parallel_size or running "
129
+ "with --quantization gptq."
130
+ )
131
+
132
+ # Validate input_size_per_partition
133
+ if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
134
+ raise ValueError(
135
+ f"Weight input_size_per_partition = "
136
+ f"{input_size_per_partition} is not divisible "
137
+ f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
138
+ "Consider reducing tensor_parallel_size or running "
139
+ "with --quantization gptq."
140
+ )
141
+
142
+ if group_size < input_size and input_size_per_partition % group_size != 0:
143
+ raise ValueError(
144
+ f"Weight input_size_per_partition = {input_size_per_partition}"
145
+ f" is not divisible by group_size = {group_size}."
146
+ "Consider reducing tensor_parallel_size or running "
147
+ "with --quantization gptq."
148
+ )
149
+
150
+
151
+ def check_marlin_supports_shape(
152
+ output_size_per_partition: int,
153
+ input_size_per_partition: int,
154
+ input_size: int,
155
+ group_size: int,
156
+ ) -> Tuple[bool, Optional[str]]:
157
+ try:
158
+ verify_marlin_supports_shape(
159
+ output_size_per_partition, input_size_per_partition, input_size, group_size
160
+ )
161
+ except ValueError as e:
162
+ return False, e.__str__()
163
+ return True, None
164
+
165
+
166
+ def marlin_make_workspace(
167
+ output_size_per_partition: int, device: torch.device
168
+ ) -> torch.Tensor:
169
+ max_workspace_size = (
170
+ output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N
171
+ ) * GPTQ_MARLIN_MAX_PARALLEL
172
+
173
+ return torch.zeros(
174
+ max_workspace_size, dtype=torch.int, device=device, requires_grad=False
175
+ )
176
+
177
+
178
+ def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
179
+ return (not act_order) or (act_order and not is_row_parallel)
180
+
181
+
182
+ def marlin_repeat_scales_on_all_ranks(
183
+ act_order: bool, group_size: int, is_row_parallel: bool
184
+ ) -> bool:
185
+ # Need to repeat scales on every rank if act_ordering or
186
+ # channelwise and RowParallelLinear
187
+ is_channelwise = group_size == -1
188
+ return act_order or (is_channelwise and is_row_parallel)
189
+
190
+
191
+ def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
192
+ return torch.nn.Parameter(
193
+ torch.empty(0, dtype=torch.int, device=device), requires_grad=False
194
+ )
195
+
196
+
197
+ def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
198
+ return torch.nn.Parameter(
199
+ torch.empty(0, dtype=torch.int, device=device), requires_grad=False
200
+ )
201
+
202
+
203
+ def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
204
+ g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
205
+ return g_idx[g_idx_sort_indices], g_idx_sort_indices
206
+
207
+
208
+ def get_scale_perms():
209
+ scale_perm: List[int] = []
210
+ for i in range(8):
211
+ scale_perm.extend([i + 8 * j for j in range(8)])
212
+ scale_perm_single: List[int] = []
213
+ for i in range(4):
214
+ scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
215
+ return scale_perm, scale_perm_single
216
+
217
+
218
+ def marlin_permute_scales(
219
+ s: torch.Tensor, size_k: int, size_n: int, group_size: int
220
+ ) -> torch.Tensor:
221
+
222
+ scale_perm, scale_perm_single = get_scale_perms()
223
+ if group_size < size_k and group_size != -1:
224
+ s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
225
+ else:
226
+ s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
227
+ s = s.reshape((-1, size_n)).contiguous()
228
+
229
+ return s
230
+
231
+
232
+ def marlin_moe_permute_scales(
233
+ s: torch.Tensor,
234
+ size_k: int,
235
+ size_n: int,
236
+ group_size: int,
237
+ ):
238
+ num_experts = s.shape[0]
239
+ output = torch.empty(
240
+ (num_experts, s.shape[1], s.shape[2]),
241
+ device=s.device,
242
+ dtype=s.dtype,
243
+ )
244
+
245
+ for e in range(num_experts):
246
+ output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size)
247
+ return output
248
+
249
+
250
+ def marlin_zero_points(
251
+ zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
252
+ ) -> torch.Tensor:
253
+ # Permute zero-points in a similar way to scales, but do not use the
254
+ # "single" permutation, since zero-points are applied on every MMA
255
+ scale_perm, _ = get_scale_perms()
256
+ zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm]
257
+
258
+ # Interleave column dim (for the dequantize code) and pack it to int32
259
+ if num_bits == 4:
260
+ interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
261
+ elif num_bits == 8:
262
+ interleave = numpy.array([0, 2, 1, 3])
263
+ else:
264
+ raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
265
+
266
+ zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
267
+ zp = zp.reshape((-1, size_n)).contiguous()
268
+ zp = pack_cols(zp, num_bits, size_k, size_n)
269
+
270
+ return zp
271
+
272
+
273
+ def awq_to_marlin_zero_points(
274
+ q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
275
+ ) -> torch.Tensor:
276
+ # AWQ zero-points are quantized and packed on the column dim.
277
+ # In addition, the values are permuted based on dequantizer.
278
+ # Here we undo both of these, and then apply marlin permutation
279
+ # and pack it back.
280
+ q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)
281
+
282
+ # Undo interleaving (use argsort(..) to get inverse perm)
283
+ if num_bits == 4:
284
+ undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7]))
285
+ elif num_bits == 8:
286
+ undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3]))
287
+ else:
288
+ raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
289
+
290
+ q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
291
+ q_zp = q_zp.reshape((-1, size_n)).contiguous()
292
+
293
+ marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
294
+ return marlin_zp
295
+
296
+
297
+ def moe_awq_to_marlin_zero_points(
298
+ q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
299
+ ):
300
+ num_experts = q_zp_packed.shape[0]
301
+ output = torch.empty(
302
+ (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]),
303
+ device=q_zp_packed.device,
304
+ dtype=q_zp_packed.dtype,
305
+ )
306
+ for e in range(num_experts):
307
+ output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits)
308
+ return output
309
+
310
+
311
+ def apply_gptq_marlin_linear(
312
+ input: torch.Tensor,
313
+ weight: torch.Tensor,
314
+ weight_scale: torch.Tensor,
315
+ weight_zp: torch.Tensor,
316
+ g_idx: torch.Tensor,
317
+ g_idx_sort_indices: torch.Tensor,
318
+ workspace: torch.Tensor,
319
+ wtype: ScalarType,
320
+ output_size_per_partition: int,
321
+ input_size_per_partition: int,
322
+ is_k_full: bool,
323
+ bias: Optional[torch.Tensor] = None,
324
+ use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
325
+ ) -> torch.Tensor:
326
+ reshaped_x = input.reshape(-1, input.shape[-1])
327
+ out_shape = input.shape[:-1] + (output_size_per_partition,)
328
+
329
+ output = ops.gptq_marlin_gemm(
330
+ reshaped_x,
331
+ weight,
332
+ weight_scale,
333
+ weight_zp,
334
+ g_idx,
335
+ g_idx_sort_indices,
336
+ workspace,
337
+ wtype,
338
+ size_m=reshaped_x.shape[0],
339
+ size_n=output_size_per_partition,
340
+ size_k=input_size_per_partition,
341
+ is_k_full=is_k_full,
342
+ has_zp=False,
343
+ use_fp32_reduce=use_fp32_reduce,
344
+ is_zp_float=False,
345
+ )
346
+
347
+ if bias is not None:
348
+ output.add_(bias) # In-place add
349
+
350
+ return output.reshape(out_shape)
351
+
352
+
353
+ def apply_awq_marlin_linear(
354
+ input: torch.Tensor,
355
+ weight: torch.Tensor,
356
+ weight_scale: torch.Tensor,
357
+ weight_zp: torch.Tensor,
358
+ g_idx: torch.Tensor,
359
+ g_idx_sort_indices: torch.Tensor,
360
+ workspace: torch.Tensor,
361
+ quant_type: ScalarType,
362
+ output_size_per_partition: int,
363
+ input_size_per_partition: int,
364
+ bias: Optional[torch.Tensor] = None,
365
+ use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
366
+ ) -> torch.Tensor:
367
+ reshaped_x = input.reshape(-1, input.shape[-1])
368
+ out_shape = input.shape[:-1] + (output_size_per_partition,)
369
+
370
+ output = ops.gptq_marlin_gemm(
371
+ reshaped_x,
372
+ weight,
373
+ weight_scale,
374
+ weight_zp,
375
+ g_idx,
376
+ g_idx_sort_indices,
377
+ workspace,
378
+ quant_type,
379
+ size_m=reshaped_x.shape[0],
380
+ size_n=output_size_per_partition,
381
+ size_k=input_size_per_partition,
382
+ is_k_full=True,
383
+ has_zp=True,
384
+ use_fp32_reduce=use_fp32_reduce,
385
+ is_zp_float=False,
386
+ )
387
+
388
+ if bias is not None:
389
+ output.add_(bias) # In-place add
390
+
391
+ return output.reshape(out_shape)
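A brief sketch of the support check and workspace helper above. The device capability is passed explicitly so the check itself does not need a GPU; the last line does assume a CUDA device.

import torch
from quantization.scalar_type import scalar_types
from quantization.utils.marlin_utils import (
    check_marlin_supported, marlin_make_workspace)

# GPTQ-style uint4 with bias 8, group size 128, no runtime zero-point.
ok = check_marlin_supported(
    scalar_types.uint4b8, group_size=128, has_zp=False, device_capability=90)
print(ok)  # True on SM80+ according to the checks above

# Zero-initialised int32 workspace sized from the output partition.
workspace = marlin_make_workspace(4096, torch.device("cuda"))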
build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_fp8.py ADDED
@@ -0,0 +1,100 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+
5
+ import quantization as ops
6
+
7
+ from .marlin_utils import marlin_make_workspace, marlin_permute_scales
8
+
9
+
10
+ def is_fp8_marlin_supported():
11
+ capability = torch.cuda.get_device_capability()
12
+ capability = capability[0] * 10 + capability[1]
13
+ return capability >= 80
14
+
15
+
16
+ def apply_fp8_marlin_linear(
17
+ input: torch.Tensor,
18
+ weight: torch.Tensor,
19
+ weight_scale: torch.Tensor,
20
+ workspace: torch.Tensor,
21
+ size_n: int,
22
+ size_k: int,
23
+ bias: Optional[torch.Tensor],
24
+ ) -> torch.Tensor:
25
+ # For GPUs that lack FP8 hardware support, we can leverage the
26
+ # Marlin kernel for fast weight-only FP8 quantization
27
+
28
+ reshaped_x = input.reshape(-1, input.shape[-1])
29
+ out_shape = input.shape[:-1] + (size_n,)
30
+
31
+ output = ops.fp8_marlin_gemm(
32
+ a=reshaped_x,
33
+ b_q_weight=weight,
34
+ b_scales=weight_scale,
35
+ workspace=workspace,
36
+ num_bits=8,
37
+ size_m=reshaped_x.shape[0],
38
+ size_n=size_n,
39
+ size_k=size_k,
40
+ )
41
+
42
+ if bias is not None:
43
+ output.add_(bias) # In-place add
44
+
45
+ return output.reshape(out_shape)
46
+
47
+
48
+ def prepare_fp8_layer_for_marlin(
49
+ layer: torch.nn.Module, strategy: str = "tensor"
50
+ ) -> None:
51
+ part_size_n = layer.output_size_per_partition
52
+ part_size_k = layer.input_size_per_partition
53
+
54
+ device = layer.weight.device
55
+
56
+ # WORKSPACE
57
+ layer.workspace = marlin_make_workspace(part_size_n, device)
58
+
59
+ # WEIGHT
60
+ # Repack weights to marlin format
61
+ marlin_qweight = ops.gptq_marlin_repack(
62
+ b_q_weight=pack_fp8_to_int32(layer.weight),
63
+ perm=torch.empty(0, dtype=torch.int, device=device),
64
+ size_k=part_size_k,
65
+ size_n=part_size_n,
66
+ num_bits=8,
67
+ )
68
+ layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
69
+
70
+ # WEIGHT SCALES
71
+ scales = layer.weight_scale.to(layer.orig_dtype)
72
+ # Permute scales
73
+ marlin_scales = marlin_permute_scales(
74
+ s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1
75
+ )
76
+ layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
77
+
78
+
79
+ def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
80
+ """
81
+ Repack FP8 weights to gptq format (packed int32 elements)
82
+ """
83
+ assert fp8_tensor.dtype == torch.float8_e4m3fn
84
+ assert fp8_tensor.shape[0] % 4 == 0
85
+
86
+ # Reshape to prepare for packing
87
+ reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
88
+
89
+ # Convert fp8 to uint8 (byte) representation
90
+ byte_tensor = reshaped.view(torch.uint8)
91
+
92
+ # Pack 4 uint8 values into one int32
93
+ packed = (
94
+ byte_tensor[:, 0].to(torch.int32)
95
+ | (byte_tensor[:, 1].to(torch.int32) << 8)
96
+ | (byte_tensor[:, 2].to(torch.int32) << 16)
97
+ | (byte_tensor[:, 3].to(torch.int32) << 24)
98
+ )
99
+
100
+ return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous()
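A minimal sketch of `pack_fp8_to_int32`, which also runs on CPU tensors since it only uses reshape/view and integer bit operations (assuming a torch build with `float8_e4m3fn` support):

import torch
from quantization.utils.marlin_utils_fp8 import pack_fp8_to_int32

# Rows must be a multiple of 4: four fp8 bytes are packed into each int32.
w = torch.randn(8, 16, dtype=torch.float16).to(torch.float8_e4m3fn)
packed = pack_fp8_to_int32(w)
print(packed.shape, packed.dtype)  # torch.Size([2, 16]) torch.int32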
build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test.py ADDED
@@ -0,0 +1,162 @@
1
+ """Utility functions used for tests and benchmarks"""
2
+
3
+ from typing import List, Optional
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+ from quantization.scalar_type import ScalarType
9
+
10
+ from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, marlin_zero_points
11
+ from .quant_utils import (
12
+ get_pack_factor,
13
+ gptq_quantize_weights,
14
+ quantize_weights,
15
+ sort_weights,
16
+ )
17
+
18
+
19
+ class MarlinWorkspace:
20
+
21
+ def __init__(self, out_features, min_thread_n, max_parallel):
22
+ assert (
23
+ out_features % min_thread_n == 0
24
+ ), "out_features = {} is undivisible by min_thread_n = {}".format(
25
+ out_features, min_thread_n
26
+ )
27
+
28
+ max_workspace_size = (out_features // min_thread_n) * max_parallel
29
+
30
+ self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda")
31
+
32
+
33
+ def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
34
+ assert q_w.shape == (size_k, size_n)
35
+ assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
36
+ assert size_n % tile == 0, f"size_n = {size_n}, tile = {tile}"
37
+
38
+ # Permute weights to 16x64 marlin tiles
39
+ q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
40
+ q_w = q_w.permute((0, 2, 1, 3))
41
+ q_w = q_w.reshape((size_k // tile, size_n * tile))
42
+
43
+ q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)
44
+
45
+ return q_w
46
+
47
+
48
+ def marlin_weights(q_w, size_k, size_n, num_bits, perm):
49
+ # Permute
50
+ q_w = marlin_permute_weights(q_w, size_k, size_n, perm)
51
+
52
+ # Pack
53
+ pack_factor = get_pack_factor(num_bits)
54
+ orig_device = q_w.device
55
+
56
+ q_w = q_w.cpu().numpy().astype(np.uint32)
57
+
58
+ q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32)
59
+ for i in range(pack_factor):
60
+ q_packed |= q_w[:, i::pack_factor] << num_bits * i
61
+
62
+ q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device)
63
+
64
+ return q_packed
65
+
66
+
67
+ def get_weight_perm(num_bits: int):
68
+ perm_list: List[int] = []
69
+ for i in range(32):
70
+ perm1: List[int] = []
71
+ col = i // 4
72
+ for block in [0, 1]:
73
+ for row in [
74
+ 2 * (i % 4),
75
+ 2 * (i % 4) + 1,
76
+ 2 * (i % 4 + 4),
77
+ 2 * (i % 4 + 4) + 1,
78
+ ]:
79
+ perm1.append(16 * row + col + 8 * block)
80
+ for j in range(4):
81
+ perm_list.extend([p + 256 * j for p in perm1])
82
+
83
+ perm = np.array(perm_list)
84
+
85
+ if num_bits == 4:
86
+ interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
87
+ elif num_bits == 8:
88
+ interleave = np.array([0, 2, 1, 3])
89
+ else:
90
+ raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
91
+
92
+ perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
93
+ perm = torch.from_numpy(perm)
94
+ return perm
95
+
96
+
97
+ def marlin_quantize(
98
+ w: torch.Tensor,
99
+ quant_type: ScalarType,
100
+ group_size: int,
101
+ act_order: bool,
102
+ test_perm: Optional[torch.Tensor] = None,
103
+ ):
104
+ size_k, size_n = w.shape
105
+ num_bits = quant_type.size_bits
106
+
107
+ # Normalize group_size
108
+ if group_size == -1:
109
+ group_size = size_k
110
+ assert group_size <= size_k
111
+
112
+ # Quantize (and apply act_order if provided)
113
+ w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
114
+ w, quant_type, group_size, act_order, test_perm
115
+ )
116
+
117
+ # For act_order, sort the "weights" and "g_idx" so that group ids are
118
+ # increasing
119
+ sort_indices = torch.empty(0, dtype=torch.int, device=w.device)
120
+ if act_order:
121
+ q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
122
+
123
+ # Reformat to marlin
124
+ weight_perm = get_weight_perm(num_bits)
125
+ marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
126
+ marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
127
+
128
+ # Create result
129
+ res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
130
+ for i in range(len(res_list)):
131
+ res_list[i] = res_list[i].to(w.device)
132
+
133
+ return res_list
134
+
135
+
136
+ def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int):
137
+ size_k, size_n = w.shape
138
+
139
+ # Normalize group_size
140
+ if group_size == -1:
141
+ group_size = size_k
142
+ assert group_size <= size_k
143
+
144
+ # Detect num groups
145
+ assert size_k % group_size == 0
146
+ num_groups = size_k // group_size
147
+
148
+ # Quantize with zp
149
+ w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True)
150
+
151
+ # Reformat to marlin
152
+ weight_perm = get_weight_perm(quant_type.size_bits)
153
+ marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm)
154
+ marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
155
+ marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits)
156
+
157
+ # Create result
158
+ res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp]
159
+ for i in range(len(res_list)):
160
+ res_list[i] = res_list[i].to(w.device)
161
+
162
+ return res_list
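A sketch of how the test helpers above are typically combined to produce Marlin-format GPTQ weights plus a workspace, assuming a CUDA device:

import torch
from quantization.scalar_type import scalar_types
from quantization.utils.marlin_utils import (
    GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL)
from quantization.utils.marlin_utils_test import MarlinWorkspace, marlin_quantize

size_k, size_n = 256, 256
w = torch.randn(size_k, size_n, dtype=torch.half, device="cuda")

# 4-bit GPTQ-style quantization, group size 128, no act_order.
w_ref, q_w, s, g_idx, sort_idx, _ = marlin_quantize(
    w, scalar_types.uint4b8, group_size=128, act_order=False)

workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
                            GPTQ_MARLIN_MAX_PARALLEL)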
build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test_24.py ADDED
@@ -0,0 +1,473 @@
1
+ """Utility functions used for tests and benchmarks"""
2
+
3
+ import random
4
+ from typing import List
5
+
6
+ import numpy
7
+ import torch
8
+
9
+ from quantization.scalar_type import ScalarType
10
+
11
+ from .marlin_utils_test import marlin_weights
12
+ from .quant_utils import gptq_quantize_weights
13
+
14
+
15
+ # This is PyTorch implementation of main part of reorder_meta()
16
+ # function, from tools/util/include/cutlass/util/host_reorder.h file
17
+ # of CUTLASS source tree. Furthermore, CUTLASS template for sparse
18
+ # GEMM decides upon layout of this matrix, and at the moment for the
19
+ # sparse GEMM executed on tensor cores, this is layout described by
20
+ # ColumnMajorInterleaved<2> data structure, in
21
+ # include/cutlass/layout/matrix.h of CUTLASS source tree. The
22
+ # reordering of meta matrix into meta_reordered matrix calculated
23
+ # according to these segments of CUTLASS code is re-implemented here.
24
+ # Note that this calculation produces offsets for scattering metadata
25
+ # matrix elements into reordered metadata matrix elements (or,
26
+ # equivalently, for gathering reordered metadata matrix element back
27
+ # into metadata matrix elements).
28
+ def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device):
29
+ dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)
30
+ dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)
31
+
32
+ # Reorder the rows, then swizzle the 2x2 blocks.
33
+ group_x = 64
34
+ group_y = 32 if meta_dtype.itemsize == 2 else 16
35
+
36
+ dst_rows = (
37
+ dst_rows // group_x * group_x
38
+ + (dst_rows % 2) * 2
39
+ + (dst_rows % 8) // 4
40
+ + ((dst_rows % group_y) % 4) // 2 * 32
41
+ + ((dst_rows % group_x) // 8) * 4
42
+ )
43
+
44
+ topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8)
45
+ bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8)
46
+ dst_rows += topright - bottomleft
47
+ dst_cols -= topright - bottomleft
48
+
49
+ # Assumed that meta tensor is to be stored in CUTLASS
50
+ # InterleavedColumnMajor layout, and reverse engineered
51
+ # corresponding code to store values into this tensor.
52
+ interleave = 2
53
+ cols_maj = dst_cols // interleave
54
+ cols_min = dst_cols % interleave
55
+ return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1)
56
+
57
+
58
+ # This function converts dense matrix into sparse semi-structured
59
+ # representation, producing "compressed" matrix, in the layout used by
60
+ # CUTLASS backend, and corresponding metadata matrix.
61
+ def sparse_semi_structured_from_dense_cutlass(dense):
62
+ if dense.dim() != 2:
63
+ raise RuntimeError(
64
+ f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501
65
+ )
66
+
67
+ m, k = dense.shape
68
+ device = dense.device
69
+
70
+ meta_dtype = torch.int8
71
+ if dense.dtype == torch.int8:
72
+ meta_dtype = torch.int32
73
+ elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]:
74
+ meta_dtype = torch.int16
75
+ else:
76
+ raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
77
+ quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
78
+ if quadbits_per_meta_elem not in (4, 8):
79
+ raise RuntimeError("Invalid number of elements per meta element calculated")
80
+
81
+ if meta_dtype == torch.int32:
82
+ if m % 16 != 0:
83
+ raise RuntimeError(
84
+ f"Number of rows of dense matrix {m} must be divisible by 16"
85
+ )
86
+ else:
87
+ if m % 32 != 0:
88
+ raise RuntimeError(
89
+ f"Number of rows of dense matrix {m} must be divisible by 32"
90
+ )
91
+ if k % (4 * quadbits_per_meta_elem) != 0:
92
+ raise RuntimeError(
93
+ f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501
94
+ )
95
+
96
+ if dense.dtype != torch.float:
97
+ ksparse = 4
98
+ dense_4 = dense.view(-1, k // ksparse, ksparse)
99
+ m0, m1, m2, m3 = (dense_4 != 0).unbind(-1)
100
+ else:
101
+ ksparse = 2
102
+ dense_2 = dense.view(-1, k // ksparse, ksparse)
103
+ m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1)
104
+ meta_ncols = k // (ksparse * quadbits_per_meta_elem)
105
+
106
+ # Encoding quadruples of True/False values as follows:
107
+ # [True, True, False, False] -> 0b0100
108
+ # [True, False, True, False] -> 0b1000
109
+ # [False, True, True, False] -> 0b1001
110
+ # [True, False, False, True ] -> 0b1100
111
+ # [False, True, False, True ] -> 0b1101
112
+ # [False, False, True, True ] -> 0b1110
113
+ # Thus, lower two bits in the encoding are index of the True value
114
+ # at the lowest index in the quadruple, and the higher two bits in
115
+ # the encoding are index of the other True value in the quadruple.
116
+ # In case there are less than two True values, than False value or
117
+ # values at some index or indices are considered True for the
118
+ # encoding. In case there are more than two True values, then the
119
+ # excess True value(s) at some indices are considered False for
120
+ # the encoding. The exact encodings used for these cases are as
121
+ # follows:
122
+ # [False, False, False, False] -> 0b1110
123
+ # [False, False, False, True ] -> 0b1110
124
+ # [False, False, True, False] -> 0b1110
125
+ # [False, True, False, False] -> 0b1001
126
+ # [False, True, True, True ] -> 0b1101
127
+ # [True, False, False, False] -> 0b1000
128
+ # [True, False, True, True ] -> 0b1100
129
+ # [True, True, False, True ] -> 0b0100
130
+ # [True, True, True, False] -> 0b0100
131
+ # [True, True, True, True ] -> 0b0100
132
+ # These particular encodings are chosen, with the help of Espresso
133
+ # logic minimizer software, for the purpose of minimization of
134
+ # corresponding Boolean functions, that translate non-zero flags
135
+ # into encoding bits. Note also possible choices for the first
136
+ # and last of these encodings were limited only to (0b0100,
137
+ # 0b1110), in order to produce valid encodings for 1:2 sparsity
138
+ # case.
139
+
140
+ expr0 = m0 & m1
141
+ expr1 = ~m0 & m1
142
+ expr2 = ~m0 & ~m1
143
+ bit0 = expr1
144
+ bit1 = expr2
145
+ bit2 = expr0 | expr2 | m3
146
+ bit3 = expr1 | ~m1
147
+ idxs0 = bit0 | (bit1.to(torch.int64) << 1)
148
+ idxs1 = bit2 | (bit3.to(torch.int64) << 1)
149
+
150
+ if dense.dtype != torch.float:
151
+ sparse0 = dense_4.gather(
152
+ -1, idxs0.unsqueeze(-1)
153
+ ) # type: ignore[possibly-undefined]
154
+ sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
155
+ sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
156
+ else:
157
+ sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(
158
+ m, k // 2
159
+ ) # type: ignore[possibly-undefined]
160
+
161
+ meta_4 = idxs0 | (idxs1 << 2)
162
+ meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)
163
+
164
+ if quadbits_per_meta_elem == 4:
165
+ meta = (
166
+ meta_n[:, :, 0]
167
+ | (meta_n[:, :, 1] << 4)
168
+ | (meta_n[:, :, 2] << 8)
169
+ | (meta_n[:, :, 3] << 12)
170
+ )
171
+ elif quadbits_per_meta_elem == 8:
172
+ meta = (
173
+ meta_n[:, :, 0]
174
+ | (meta_n[:, :, 1] << 4)
175
+ | (meta_n[:, :, 2] << 8)
176
+ | (meta_n[:, :, 3] << 12)
177
+ | (meta_n[:, :, 4] << 16)
178
+ | (meta_n[:, :, 5] << 20)
179
+ | (meta_n[:, :, 6] << 24)
180
+ | (meta_n[:, :, 7] << 28)
181
+ )
182
+
183
+ # Reorder meta tensor elements.
184
+ meta_reordered = meta.new_empty(
185
+ (m * meta_ncols,)
186
+ ) # type: ignore[possibly-undefined]
187
+ meta_offsets = _calculate_meta_reordering_scatter_offsets(
188
+ m, meta_ncols, meta_dtype, device
189
+ )
190
+ meta_reordered.scatter_(0, meta_offsets, meta.view(-1))
191
+
192
+ return (sparse, meta_reordered.view(m, meta_ncols))
193
+
194
+
195
+ # This function performs reverse of the function above - it
196
+ # reconstructs dense matrix from a pair of "compressed" matrix, given
197
+ # in the layout used by CUTLASS backend, and accompanying metadata
198
+ # matrix.
199
+ def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
200
+ if sparse.dim() != 2:
201
+ raise RuntimeError(
202
+ f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501
203
+ )
204
+
205
+ m, k = sparse.shape
206
+ device = sparse.device
207
+
208
+ if meta_reordered.dim() != 2:
209
+ raise RuntimeError(
210
+ f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501
211
+ )
212
+ if meta_reordered.device != device:
213
+ raise RuntimeError(
214
+ f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501
215
+ )
216
+
217
+ meta_dtype = meta_reordered.dtype
218
+ if meta_dtype not in (torch.int16, torch.int32):
219
+ raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix")
220
+ quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
221
+
222
+ ksparse = 4 if sparse.dtype != torch.float else 2
223
+
224
+ meta_nrows, meta_ncols = meta_reordered.shape
225
+ if meta_nrows != m:
226
+ raise RuntimeError(
227
+ f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501
228
+ )
229
+ if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k:
230
+ raise RuntimeError(
231
+ f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501
232
+ "expected according to the number of columns of meta matrix"
233
+ )
234
+
235
+ # Undo meta tensor elements reordering.
236
+ meta_offsets = _calculate_meta_reordering_scatter_offsets(
237
+ m, meta_ncols, meta_dtype, device
238
+ )
239
+ meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols)
240
+
241
+ # Unpack sparse tensor back to original dense tensor, using
242
+ # information provided by meta tensor. Note that torch.float
243
+ # datatype is handled pretty much the same as
244
+ # torch.half/torch.bfloat16, as metadata for a pair of torch.float
245
+ # value is encoded as if underlying 8 bytes contain four
246
+ # torch.half/torch.bfloat16 values, where either first two or last
247
+ # two are zeros.
248
+ meta_2 = torch.empty(
249
+ (m, meta_ncols, 2 * quadbits_per_meta_elem),
250
+ dtype=meta_dtype,
251
+ device=device,
252
+ )
253
+ if quadbits_per_meta_elem == 4:
254
+ meta_2[:, :, 0] = meta & 0b11
255
+ meta_2[:, :, 1] = (meta >> 2) & 0b11
256
+ meta_2[:, :, 2] = (meta >> 4) & 0b11
257
+ meta_2[:, :, 3] = (meta >> 6) & 0b11
258
+ meta_2[:, :, 4] = (meta >> 8) & 0b11
259
+ meta_2[:, :, 5] = (meta >> 10) & 0b11
260
+ meta_2[:, :, 6] = (meta >> 12) & 0b11
261
+ meta_2[:, :, 7] = (meta >> 14) & 0b11
262
+ elif quadbits_per_meta_elem == 8:
263
+ meta_2[:, :, 0] = meta & 0b11
264
+ meta_2[:, :, 1] = (meta >> 2) & 0b11
265
+ meta_2[:, :, 2] = (meta >> 4) & 0b11
266
+ meta_2[:, :, 3] = (meta >> 6) & 0b11
267
+ meta_2[:, :, 4] = (meta >> 8) & 0b11
268
+ meta_2[:, :, 5] = (meta >> 10) & 0b11
269
+ meta_2[:, :, 6] = (meta >> 12) & 0b11
270
+ meta_2[:, :, 7] = (meta >> 14) & 0b11
271
+ meta_2[:, :, 8] = (meta >> 16) & 0b11
272
+ meta_2[:, :, 9] = (meta >> 18) & 0b11
273
+ meta_2[:, :, 10] = (meta >> 20) & 0b11
274
+ meta_2[:, :, 11] = (meta >> 22) & 0b11
275
+ meta_2[:, :, 12] = (meta >> 24) & 0b11
276
+ meta_2[:, :, 13] = (meta >> 26) & 0b11
277
+ meta_2[:, :, 14] = (meta >> 28) & 0b11
278
+ meta_2[:, :, 15] = (meta >> 30) & 0b11
279
+
280
+ dense_offsets = meta_2.view(-1) + (
281
+ torch.arange(0, 2 * m * k // ksparse, device=device) * 4
282
+ ).view(-1, 1).repeat(1, 2).view(-1)
283
+
284
+ dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device)
285
+ if sparse.dtype != torch.float:
286
+ # dense.scatter_(0, dense_offsets, sparse.view(-1))
287
+ dense.scatter_(0, dense_offsets, sparse.reshape(-1))
288
+ else:
289
+ dense.view(torch.half).scatter_(
290
+ 0, dense_offsets, sparse.view(torch.half).view(-1)
291
+ )
292
+
293
+ return dense.view(m, 2 * k)
294
+
295
+
296
+ def mask_creator(tensor):
297
+ """
298
+ Class for creating N:M sparsity masks.
299
+ Masks will be created using the N:M ratio, where for every block of
300
+ M weights, N will be pruned based on ranked weight value. Each mask
301
+ will correspond to the given tensor.
302
+
303
+ :param N: The number of weights in a group to keep
304
+ :param M: The size of a weight group
305
+ """
306
+ N = 2
307
+ M = 4
308
+
309
+ mask = None
310
+ # for i, tensor in enumerate(tensors):
311
+ if tensor.numel() % M != 0:
312
+ raise ValueError(
313
+ f"Tensor of size {tensor.shape} can't be evenly divided into " f"{M} groups"
314
+ )
315
+
316
+ num_groups = tensor.numel() // M
317
+
318
+ # N:M sparsity for linear layers
319
+ tensor_temp = tensor.detach().abs().reshape(num_groups, M)
320
+ index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)]
321
+
322
+ w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device)
323
+ mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape)
324
+
325
+ return mask
326
+
327
+
328
+ def inject_24(w, size_k, size_n):
329
+ assert w.shape == (size_k, size_n)
330
+
331
+ mask = mask_creator(w.t()).t().cuda().bool()
332
+
333
+ return (mask * w).contiguous(), mask.contiguous()
334
+
335
+
336
+ def check_24(w, num_rows_to_sample=50, _verbose=False):
337
+ BLOCK_SIZE = 4
338
+ MAX_NON_ZEROS = 2
339
+
340
+ w = w.t().contiguous()
341
+
342
+ print("check_24: w.shape = {}".format(w.shape))
343
+
344
+ num_rows, num_cols = w.shape
345
+ sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample)
346
+ if _verbose:
347
+ print(f"Sampled row idxs = {sampled_row_idxs}")
348
+
349
+ total_segments = 0
350
+ non_24_segments = 0
351
+ for i in sampled_row_idxs:
352
+ for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE):
353
+ total_segments += 1
354
+ block = w[i, j : j + BLOCK_SIZE]
355
+ num_nonzero = torch.count_nonzero(block)
356
+ if num_nonzero > MAX_NON_ZEROS:
357
+ print("i = {} j = {} block = {}".format(i, j, block))
358
+ non_24_segments += 1
359
+
360
+ print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.")
361
+
362
+
363
+ def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType):
364
+ assert q_24.shape == (size_k, size_n)
365
+
366
+ # Remove bias to normalize over 0
367
+ q_24_no_zp = q_24 - wtype.bias
368
+
369
+ # Compress
370
+ q_24_no_zp = q_24_no_zp.t().contiguous()
371
+ q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(q_24_no_zp)
372
+ q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous()
373
+
374
+ # Restore bias
375
+ q_24_comp = q_24_no_zp_comp + wtype.bias
376
+
377
+ # Resize meta to its actual shape (without moving any data)
378
+ meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)
379
+
380
+ return q_24_comp, meta
381
+
382
+
383
+ def get_scale_perms_24():
384
+ scale_perm: List[int] = []
385
+ for i in range(8):
386
+ scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]])
387
+ scale_perm_single: List[int] = []
388
+ for i in range(8):
389
+ scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]])
390
+ return scale_perm, scale_perm_single
391
+
392
+
393
+ def get_weight_perm_24(num_bits: int):
394
+ perm_list: List[int] = []
395
+ for i in range(32):
396
+ perm1: List[int] = []
397
+ col = i // 4
398
+ col_o = col // 2
399
+ for block in [0, 1]:
400
+ for row in [
401
+ 2 * (i % 4),
402
+ 2 * (i % 4) + 1,
403
+ 2 * (i % 4 + 4),
404
+ 2 * (i % 4 + 4) + 1,
405
+ ]:
406
+ perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block)
407
+ for j in range(4):
408
+ perm_list.extend([p + 1 * j for p in perm1])
409
+ perm = numpy.array(perm_list)
410
+
411
+ if num_bits == 4:
412
+ interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
413
+ elif num_bits == 8:
414
+ interleave = numpy.array([0, 2, 1, 3])
415
+ else:
416
+ raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))
417
+
418
+ perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
419
+ perm = torch.from_numpy(perm)
420
+ return perm
421
+
422
+
423
+ def marlin_permute_scales_24(
424
+ s: torch.Tensor, size_k: int, size_n: int, group_size: int
425
+ ) -> torch.Tensor:
426
+
427
+ scale_perm, scale_perm_single = get_scale_perms_24()
428
+ if group_size < size_k and group_size != -1:
429
+ s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
430
+ else:
431
+ s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
432
+ s = s.reshape((-1, size_n)).contiguous()
433
+
434
+ return s
435
+
436
+
437
+ def marlin_24_quantize(
438
+ w: torch.Tensor,
439
+ quant_type: ScalarType,
440
+ group_size: int,
441
+ ):
442
+ size_k, size_n = w.shape
443
+
444
+ # Normalize group_size
445
+ if group_size == -1:
446
+ group_size = size_k
447
+ assert group_size <= size_k
448
+
449
+ # Inject 2:4 sparsity
450
+ w_24, mask_24 = inject_24(w, size_k, size_n)
451
+
452
+ # Quantize
453
+ w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights(
454
+ w_24, quant_type, group_size, act_order=False
455
+ )
456
+
457
+ # Compress quantized weight
458
+ q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, quant_type)
459
+ size_k_comp = size_k // 2
460
+
461
+ # Reformat to marlin
462
+ weight_perm = get_weight_perm_24(quant_type.size_bits)
463
+ marlin_24_q_w_comp = marlin_weights(
464
+ q_w_24_comp, size_k_comp, size_n, quant_type.size_bits, weight_perm
465
+ )
466
+ marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size)
467
+
468
+ # Create result
469
+ res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s]
470
+ for i in range(len(res_list)):
471
+ res_list[i] = res_list[i].to(w.device)
472
+
473
+ return res_list
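A minimal usage sketch of `marlin_24_quantize` (the shapes, group size, and CUDA device below are illustrative assumptions mirroring how these test utilities are typically driven; the sizes must satisfy the Marlin 2:4 tiling constraints, and `inject_24` requires a GPU):

import torch
from quantization.scalar_type import scalar_types
from quantization.utils.marlin_utils_test_24 import marlin_24_quantize

w = torch.randn(512, 512, dtype=torch.half, device="cuda")
w_ref, q_w_comp, meta, s = marlin_24_quantize(w, scalar_types.uint4b8, group_size=128)
# q_w_comp is the Marlin-packed, 2:4-compressed weight (K halved), meta the
# sparsity metadata, and s the group scales permuted into Marlin layout.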
build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py ADDED
@@ -0,0 +1,125 @@
1
+ from typing import List
2
+
3
+ import numpy
4
+ import torch
5
+
6
+ from .marlin_utils_test import marlin_permute_weights
7
+ from .quant_utils import get_pack_factor, qqq_quantize_weights
8
+
9
+
10
+ def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size):
11
+ # Permute
12
+ q_w = marlin_permute_weights(q_w, size_k, size_n, perm)
13
+
14
+ # Pack
15
+ pack_factor = get_pack_factor(num_bits)
16
+ orig_device = q_w.device
17
+
18
+ q_w = q_w.cpu().numpy().astype(numpy.uint32)
19
+
20
+ q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor),
21
+ dtype=numpy.uint32)
22
+ if group_size == size_k:
23
+ for i in range(pack_factor):
24
+ q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i
25
+ else:
26
+ for i in range(pack_factor):
27
+ q_packed |= q_w[:, i::pack_factor] << num_bits * i
28
+
29
+ q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device)
30
+
31
+ return q_packed
32
+
33
+
34
+ def get_qqq_scale_perms():
35
+ scale_perm: List[int] = []
36
+ for i in range(8):
37
+ scale_perm.extend([i + 8 * j for j in range(8)])
38
+ scale_perm_single: List[int] = []
39
+ for i in range(4):
40
+ scale_perm_single.extend(
41
+ [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
42
+ return scale_perm, scale_perm_single
43
+
44
+
45
+ # NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. # noqa: E501
46
+ def get_qqq_weight_perm(num_bits: int, quant_type: str):
47
+ perm_list: List[int] = []
48
+ for i in range(32):
49
+ perm1: List[int] = []
50
+ col = i // 4
51
+ for block in [0, 1]:
52
+ for row in [
53
+ 4 * (i % 4),
54
+ 4 * (i % 4) + 1,
55
+ 4 * (i % 4) + 2,
56
+ 4 * (i % 4) + 3,
57
+ ]:
58
+ perm1.append(16 * row + col + 8 * block)
59
+ for j in range(4):
60
+ perm_list.extend([p + 256 * j for p in perm1])
61
+
62
+ perm = numpy.array(perm_list)
63
+
64
+ assert quant_type in ["per-channel",
65
+ "per-group"], "not supported quantization type"
66
+ if num_bits == 4:
67
+ if quant_type == "per-channel":
68
+ interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3])
69
+ else:
70
+ interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
71
+ else:
72
+ raise Exception("num_bits must be 4, got {}".format(num_bits))
73
+
74
+ perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
75
+ perm = torch.from_numpy(perm)
76
+ return perm
77
+
78
+
79
+ def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size):
80
+ scale_perm, scale_perm_single = get_qqq_scale_perms()
81
+ if group_size < size_k and group_size != -1:
82
+ s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm]
83
+ s_channel = s_channel.reshape(
84
+ (-1, len(scale_perm_single)))[:, scale_perm_single]
85
+ s_group = s_group.reshape((-1, size_n)).contiguous()
86
+ else:
87
+ s_channel = s_channel.reshape(
88
+ (-1, len(scale_perm_single)))[:, scale_perm_single]
89
+ s_channel = s_channel.reshape((-1, size_n)).contiguous()
90
+
91
+ return s_group, s_channel
92
+
93
+
94
+ def marlin_qqq_quantize(
95
+ w: torch.Tensor,
96
+ num_bits: int,
97
+ group_size: int,
98
+ ):
99
+ size_k, size_n = w.shape
100
+
101
+ # Normalize group_size
102
+ if group_size == -1:
103
+ group_size = size_k
104
+ assert group_size <= size_k
105
+ quant_type = "per-channel" if group_size == size_k else "per-group"
106
+
107
+ # Quantize
108
+ w_ref, q_w, s_group, s_channel = qqq_quantize_weights(
109
+ w, num_bits, group_size)
110
+
111
+ # Reformat to marlin_qqq
112
+ weight_perm = get_qqq_weight_perm(num_bits, quant_type)
113
+ marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits,
114
+ weight_perm, group_size)
115
+ marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales(
116
+ s_group, s_channel, size_k, size_n, group_size)
117
+
118
+ # Create result
119
+ res_list = [
120
+ w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel
121
+ ]
122
+ for i in range(len(res_list)):
123
+ res_list[i] = res_list[i].to(w.device)
124
+
125
+ return res_list
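A minimal usage sketch of `marlin_qqq_quantize` (shapes and group size are illustrative assumptions; only 4-bit weights are supported by this path, and nothing here requires a GPU):

import torch
from quantization.utils.marlin_utils_test_qqq import marlin_qqq_quantize

w = torch.randn(512, 512, dtype=torch.half)
w_ref, q_w_packed, s_group, s_channel = marlin_qqq_quantize(w, num_bits=4, group_size=128)
# q_w_packed is the Marlin-QQQ packed weight; s_group and s_channel are the
# fused per-group and per-channel scales consumed by marlin_qqq_gemm.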
build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/quant_utils.py ADDED
@@ -0,0 +1,470 @@
1
+ """This file is used for /tests and /benchmarks"""
2
+
3
+ from typing import List, Optional
4
+
5
+ import numpy
6
+ import torch
7
+
8
+ from quantization.scalar_type import ScalarType, scalar_types
9
+
10
+ SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
11
+ SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
12
+
13
+ MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]
14
+
15
+ # Note: this is a hack. We should update each model to register the
16
+ # stacked params and get it from there instead in a future PR.
17
+ # fused_name: List[shard_name]
18
+ FUSED_LAYER_NAME_MAPPING = {
19
+ "qkv_proj": ["q_proj", "k_proj", "v_proj"],
20
+ "gate_up_proj": ["gate_proj", "up_proj"],
21
+ }
22
+
23
+
24
+ def pack_quantized_values_into_int32(
25
+ w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
26
+ ):
27
+ # move dim to pack to the end
28
+ perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
29
+ inv_perm = tuple(perm.index(i) for i in range(len(perm)))
30
+ w_q_perm = w_q.permute(perm)
31
+
32
+ pack_factor = 32 // wtype.size_bits
33
+ mask = (1 << wtype.size_bits) - 1
34
+
35
+ new_shape_perm = list(w_q_perm.shape)
36
+ assert w_q_perm.shape[-1] % pack_factor == 0
37
+ new_shape_perm[-1] //= pack_factor
38
+
39
+ res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
40
+ for i in range(pack_factor):
41
+ res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i
42
+
43
+ return res.permute(inv_perm)
44
+
45
+
46
+ def unpack_quantized_values_into_int32(
47
+ w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
48
+ ):
49
+ # move dim to pack to the end
50
+ perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
51
+ inv_perm = tuple(perm.index(i) for i in range(len(perm)))
52
+ w_q_perm = w_q.permute(perm)
53
+
54
+ pack_factor = 32 // wtype.size_bits
55
+ mask = (1 << wtype.size_bits) - 1
56
+
57
+ new_shape_perm = list(w_q_perm.shape)
58
+ new_shape_perm[-1] *= pack_factor
59
+
60
+ res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
61
+ for i in range(pack_factor):
62
+ res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask
63
+
64
+ return res.permute(inv_perm)
65
+
66
+
67
+ def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
68
+ # prefix: model.layers.0.self_attn.q_proj
69
+ # proj_name: q_proj
70
+ proj_name = prefix.split(".")[-1]
71
+ if proj_name in FUSED_LAYER_NAME_MAPPING:
72
+ shard_prefixes = [
73
+ prefix.replace(proj_name, shard_proj_name)
74
+ for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name]
75
+ ]
76
+
77
+ is_skipped = None
78
+ for shard_prefix in shard_prefixes:
79
+ is_shard_skipped = shard_prefix in ignored_layers
80
+
81
+ if is_skipped is None:
82
+ is_skipped = is_shard_skipped
83
+ elif is_shard_skipped != is_skipped:
84
+ raise ValueError(
85
+ f"Detected some but not all shards of {prefix} "
86
+ "are quantized. All shards of fused layers "
87
+ "to have the same precision."
88
+ )
89
+ else:
90
+ is_skipped = prefix in ignored_layers
91
+
92
+ assert is_skipped is not None
93
+ return is_skipped
94
+
95
+
96
+ def get_pack_factor(num_bits):
97
+ assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
98
+ return 32 // num_bits
99
+
100
+
101
+ def permute_rows(
102
+ q_w: torch.Tensor,
103
+ w_ref: torch.Tensor,
104
+ group_size: int,
105
+ test_perm: Optional[torch.Tensor] = None,
106
+ ):
107
+ assert q_w.shape == w_ref.shape
108
+
109
+ orig_device = q_w.device
110
+ k_size, _ = q_w.shape
111
+
112
+ g_idx = torch.zeros((k_size,), dtype=torch.int32)
113
+ for i in range(k_size):
114
+ g_idx[i] = i // group_size
115
+
116
+ # Simulate act_order by doing a random permutation on K
117
+ rand_perm = test_perm if test_perm is not None else torch.randperm(k_size)
118
+
119
+ g_idx = g_idx[rand_perm].contiguous()
120
+ q_w = q_w[rand_perm, :].contiguous()
121
+ w_ref = w_ref[rand_perm, :].contiguous()
122
+
123
+ return (
124
+ w_ref.to(device=orig_device),
125
+ q_w.to(device=orig_device),
126
+ g_idx.to(device=orig_device),
127
+ rand_perm.to(device=orig_device),
128
+ )
129
+
130
+
131
+ def quantize_weights(
132
+ w: torch.Tensor,
133
+ quant_type: ScalarType,
134
+ group_size: Optional[int],
135
+ zero_points: bool = False,
136
+ ref_zero_points_after_scales: bool = False,
137
+ ):
138
+ assert (
139
+ quant_type.is_integer()
140
+ ), "Floating point quantization may work but has not been tested"
141
+ assert not zero_points or group_size is not None, (
142
+ "to have group zero points, group_size must be provided "
143
+ "(-1 group_size is channelwise)"
144
+ )
145
+
146
+ orig_device = w.device
147
+ orig_type = w.dtype
148
+ size_k, size_n = w.shape
149
+
150
+ assert w.is_floating_point(), "w must be float"
151
+
152
+ if group_size == -1:
153
+ group_size = size_k
154
+
155
+ # Reshape to [groupsize, -1]
156
+ if group_size is not None and group_size < size_k:
157
+ w = w.reshape((-1, group_size, size_n))
158
+ w = w.permute(1, 0, 2)
159
+ w = w.reshape((group_size, -1))
160
+
161
+ # Compute scale for each group
162
+ max_val = torch.max(w, 0, keepdim=True).values
163
+ min_val = torch.min(w, 0, keepdim=True).values
164
+
165
+ max_q_val = quant_type.max()
166
+ min_q_val = quant_type.min()
167
+
168
+ w_s = torch.Tensor([1.0]).to(w.device) # unscaled case
169
+ maybe_w_zp = None
170
+ if group_size is not None:
171
+ if zero_points:
172
+ assert not quant_type.is_signed() and quant_type.max() > 0
173
+ w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
174
+ maybe_w_zp = (
175
+ torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
176
+ )
177
+ else:
178
+ # If the bias is such that there are no possible negative/positive
179
+ # values, set the max value to inf to avoid divide by 0
180
+ w_s = torch.max(
181
+ abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
182
+ abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
183
+ )
184
+
185
+ # Quantize
186
+ w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
187
+ w_q = torch.clamp(w_q, min_q_val, max_q_val)
188
+
189
+ # Compute ref (dequantized)
190
+ # For some kernels (namely Machete) the zero-points are applied after the
191
+ # scales are applied, for this case computing the reference in similar way
192
+ # allows us to use tighter error tolerances in our unit tests.
193
+ if ref_zero_points_after_scales and maybe_w_zp is not None:
194
+ w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
195
+ else:
196
+ w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
197
+
198
+ if quant_type.has_bias():
199
+ w_q += quant_type.bias
200
+
201
+ # Restore original shapes
202
+ if group_size is not None and group_size < size_k:
203
+
204
+ def reshape_w(w):
205
+ w = w.reshape((group_size, -1, size_n))
206
+ w = w.permute(1, 0, 2)
207
+ w = w.reshape((size_k, size_n)).contiguous()
208
+ return w
209
+
210
+ w_q = reshape_w(w_q)
211
+ w_ref = reshape_w(w_ref)
212
+ w_s = w_s.reshape((-1, size_n)).contiguous()
213
+
214
+ if maybe_w_zp is not None:
215
+ maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
216
+ maybe_w_zp = maybe_w_zp.to(device=orig_device)
217
+
218
+ return (
219
+ w_ref.to(device=orig_device),
220
+ w_q.to(device=orig_device),
221
+ w_s if group_size is not None else None,
222
+ maybe_w_zp,
223
+ )
224
+
225
+
226
+ def gptq_quantize_weights(
227
+ w: torch.Tensor,
228
+ quant_type: ScalarType,
229
+ group_size: int,
230
+ act_order: bool,
231
+ test_perm: Optional[torch.Tensor] = None,
232
+ ):
233
+ size_k, _ = w.shape
234
+
235
+ assert w.is_floating_point(), "w must be float"
236
+ assert (
237
+ quant_type in SUPPORTED_GPTQ_QUANT_TYPES
238
+ ), f"Unsupported gptq type = {quant_type}"
239
+ assert group_size in SUPPORTED_GROUP_SIZES + [
240
+ size_k
241
+ ], f"Unsupported groupsize = {group_size}"
242
+
243
+ w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size)
244
+
245
+ # Apply act_order
246
+ g_idx = torch.empty(0, dtype=torch.int, device=w.device)
247
+ rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
248
+ if act_order:
249
+ assert (
250
+ group_size < size_k
251
+ ), "For act_order, groupsize = {} must be less than size_k = {}".format(
252
+ group_size, size_k
253
+ )
254
+
255
+ w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm)
256
+
257
+ return w_ref, w_q, w_s, g_idx, rand_perm
258
+
259
+
260
+ # QQQ employs different quant schemes for per-group and
261
+ # per-channel quantization.
262
+ def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int):
263
+ orig_device = w.device
264
+ size_k, size_n = w.shape
265
+
266
+ assert w.is_floating_point(), "w must be float"
267
+ assert (
268
+ num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS
269
+ ), f"Unsupported num_bits = {num_bits}"
270
+ assert group_size in SUPPORTED_GROUP_SIZES + [
271
+ size_k
272
+ ], f"Unsupported groupsize = {group_size}"
273
+
274
+ if group_size == -1:
275
+ group_size = size_k
276
+ assert group_size <= size_k
277
+
278
+ if group_size < size_k:
279
+ # Reshape to [groupsize, -1]
280
+ w = w.reshape((-1, group_size, size_n))
281
+ w = w.permute(1, 0, 2)
282
+ w = w.reshape((group_size, -1))
283
+
284
+ max_q_val = 2**num_bits - 1
285
+ half_q_val = (max_q_val + 1) // 2
286
+
287
+ # Compute scale for each group
288
+ s_group = torch.max(torch.abs(w), 0, keepdim=True)[0]
289
+ s_group *= 2 / max_q_val # 2 => symmetric
290
+
291
+ # Quantize
292
+ q_w = torch.round(w / s_group).int()
293
+ q_w += half_q_val
294
+ q_w = torch.clamp(q_w, 0, max_q_val)
295
+ # Compute ref (dequantized)
296
+ w_ref = (q_w - half_q_val).half() * s_group
297
+
298
+ # Restore original shapes
299
+ def reshape_w(w):
300
+ w = w.reshape((group_size, -1, size_n))
301
+ w = w.permute(1, 0, 2)
302
+ w = w.reshape((size_k, size_n)).contiguous()
303
+ return w
304
+
305
+ q_w = reshape_w(q_w)
306
+ w_ref = reshape_w(w_ref)
307
+
308
+ # Compute int8 quantization scale for each channel
309
+ s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0]
310
+ s_channel /= 127.0
311
+ t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8)
312
+ w_ref = t_int8.half() * s_channel
313
+ s_channel = s_channel.reshape(1, -1).to(dtype=torch.float)
314
+
315
+ # Fuse scales
316
+ s_group = (s_group.reshape(-1, size_n).contiguous() / s_channel).to(
317
+ dtype=torch.half
318
+ )
319
+ else:
320
+ max_q_val = 2 ** (num_bits - 1) - 1
321
+
322
+ # Compute scale for each channel
323
+ s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0]
324
+ s_channel /= max_q_val
325
+
326
+ # Quantize
327
+ q_w = torch.round(w / s_channel).int()
328
+ q_w = torch.clamp(q_w, -max_q_val, max_q_val)
329
+ # Compute ref (dequantized)
330
+ w_ref = q_w.half() * s_channel
331
+
332
+ s_group = torch.tensor([], dtype=torch.half)
333
+ # div 2 ** (8 - self.bits)) to offset right shift in unpacking
334
+ s_channel /= 2 ** (8 - num_bits)
335
+ s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float)
336
+
337
+ return (
338
+ w_ref.to(device=orig_device),
339
+ q_w.to(device=orig_device),
340
+ s_group.to(device=orig_device),
341
+ s_channel.to(device=orig_device),
342
+ )
343
+
344
+
345
+ def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
346
+ orig_device = q_w.device
347
+
348
+ sort_indices = torch.argsort(g_idx).to(dtype=torch.int32) # Sort based on g_idx
349
+
350
+ g_idx = g_idx[sort_indices].contiguous()
351
+ q_w = q_w[sort_indices, :].contiguous()
352
+
353
+ return (
354
+ q_w.to(device=orig_device),
355
+ g_idx.to(device=orig_device),
356
+ sort_indices.to(device=orig_device),
357
+ )
358
+
359
+
360
+ def pack_rows(
361
+ q_w: torch.Tensor,
362
+ num_bits: int,
363
+ size_k: int,
364
+ size_n: int,
365
+ ):
366
+ assert q_w.shape == (size_k, size_n)
367
+
368
+ pack_factor = get_pack_factor(num_bits)
369
+ assert size_k % pack_factor == 0
370
+
371
+ orig_device = q_w.device
372
+
373
+ q_w = q_w.cpu().numpy().astype(numpy.uint32)
374
+
375
+ q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)
376
+
377
+ for i in range(pack_factor):
378
+ q_res |= q_w[i::pack_factor, :] << num_bits * i
379
+
380
+ q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
381
+ return q_res
382
+
383
+
384
+ def pack_cols(
385
+ q_w: torch.Tensor,
386
+ num_bits: int,
387
+ size_k: int,
388
+ size_n: int,
389
+ ):
390
+ assert q_w.shape == (size_k, size_n)
391
+
392
+ pack_factor = get_pack_factor(num_bits)
393
+ assert size_n % pack_factor == 0
394
+
395
+ orig_device = q_w.device
396
+
397
+ q_w = q_w.cpu().numpy().astype(numpy.uint32)
398
+
399
+ q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
400
+
401
+ for i in range(pack_factor):
402
+ q_res |= q_w[:, i::pack_factor] << num_bits * i
403
+
404
+ q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
405
+ q_res = q_res.contiguous()
406
+
407
+ return q_res
408
+
409
+
410
+ def unpack_cols(
411
+ packed_q_w: torch.Tensor,
412
+ num_bits: int,
413
+ size_k: int,
414
+ size_n: int,
415
+ ):
416
+ pack_factor = get_pack_factor(num_bits)
417
+ assert size_n % pack_factor == 0
418
+ assert packed_q_w.shape == (
419
+ size_k,
420
+ size_n // pack_factor,
421
+ ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
422
+ packed_q_w.shape, size_k, size_n, pack_factor
423
+ )
424
+
425
+ orig_device = packed_q_w.device
426
+
427
+ packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
428
+ q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
429
+
430
+ mask = (1 << num_bits) - 1
431
+ for i in range(pack_factor):
432
+ vals = packed_q_w_cpu & mask
433
+ packed_q_w_cpu >>= num_bits
434
+ q_res[:, i::pack_factor] = vals
435
+
436
+ q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
437
+ q_res = q_res.contiguous()
438
+
439
+ return q_res
440
+
441
+
442
+ def gptq_pack(
443
+ q_w: torch.Tensor,
444
+ num_bits: int,
445
+ size_k: int,
446
+ size_n: int,
447
+ ):
448
+ return pack_rows(q_w, num_bits, size_k, size_n)
449
+
450
+
451
+ def awq_pack(
452
+ q_w: torch.Tensor,
453
+ num_bits: int,
454
+ size_k: int,
455
+ size_n: int,
456
+ ):
457
+ assert q_w.shape == (size_k, size_n)
458
+
459
+ # Interleave column dim (for the dequantize code) and pack it to int32
460
+ if num_bits == 4:
461
+ interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
462
+ elif num_bits == 8:
463
+ interleave = numpy.array([0, 2, 1, 3])
464
+ else:
465
+ raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
466
+
467
+ q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel()
468
+ q_w = q_w.reshape((-1, size_n)).contiguous()
469
+
470
+ return pack_cols(q_w, num_bits, size_k, size_n)
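`pack_cols` and `unpack_cols` are exact inverses for non-negative values that fit in `num_bits`; a quick CPU-only round-trip sketch (the shapes are arbitrary):

import torch
from quantization.utils.quant_utils import pack_cols, unpack_cols

size_k, size_n, num_bits = 16, 64, 4
q_w = torch.randint(0, 2**num_bits, (size_k, size_n), dtype=torch.int32)
packed = pack_cols(q_w, num_bits, size_k, size_n)   # shape (16, 64 // 8)
assert torch.equal(unpack_cols(packed, num_bits, size_k, size_n), q_w)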
build/torch24-cxx11-cu124-x86_64-linux/quantization/__init__.py CHANGED
@@ -1,150 +1,30 @@
1
- from typing import Optional, Tuple
2
-
3
- import torch
4
-
5
- try:
6
- from ._ops import ops
7
- except ImportError as e:
8
- # Fallback for local development.
9
- try:
10
- import _quantization
11
- ops = torch.ops._quantization
12
- except ImportError:
13
- raise e
14
-
15
- def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
16
- return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
17
-
18
- def cutlass_scaled_mm(a: torch.Tensor,
19
- b: torch.Tensor,
20
- scale_a: torch.Tensor,
21
- scale_b: torch.Tensor,
22
- out_dtype: torch.dtype,
23
- bias: Optional[torch.Tensor] = None) -> torch.Tensor:
24
- assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
25
- assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
26
- assert bias is None or bias.shape[0] == b.shape[
27
- 1] and bias.dtype == out_dtype
28
-
29
- m = a.shape[0]
30
- n = b.shape[1]
31
-
32
- #if current_platform.is_rocm():
33
- # triton_scaled_mm_module = importlib.import_module(
34
- # "vllm.model_executor.layers.quantization.compressed_tensors."
35
- # "triton_scaled_mm")
36
- # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
37
- # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
38
-
39
- out = torch.empty((m, n), dtype=out_dtype, device=a.device)
40
-
41
- ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
42
-
43
- return out
44
-
45
- # fp8
46
- def scaled_fp8_quant(
47
- input: torch.Tensor,
48
- scale: Optional[torch.Tensor] = None,
49
- num_token_padding: Optional[int] = None,
50
- scale_ub: Optional[torch.Tensor] = None,
51
- use_per_token_if_dynamic: bool = False,
52
- ) -> Tuple[torch.Tensor, torch.Tensor]:
53
- """
54
- Quantize input tensor to FP8 and return quantized tensor and scale.
55
-
56
- This function supports both static and dynamic quantization: If you
57
- provide the scale, it will use static scaling and if you omit it,
58
- the scale will be determined dynamically. The function also allows
59
- optional padding of the output tensors for downstream kernels that
60
- will benefit from padding.
61
-
62
- Args:
63
- input: The input tensor to be quantized to FP8
64
- scale: Optional scaling factor for the FP8 quantization
65
- scale_ub: Optional upper bound for scaling factor in dynamic
66
- per token case
67
- num_token_padding: If specified, pad the first dimension
68
- of the output to at least this value.
69
- use_per_token_if_dynamic: Whether to do per_tensor or per_token
70
- in the dynamic quantization case.
71
-
72
- Returns:
73
- Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
74
- scaling factor.
75
- """
76
- # This code assumes batch_dim and num_tokens are flattened
77
- assert (input.ndim == 2)
78
- shape: Union[Tuple[int, int], torch.Size] = input.shape
79
- # For rocm, the output fp8 dtype is torch.float_e3m3fnuz
80
- #out_dtype: torch.dtype = torch.float8_e4m3fnuz \
81
- # if current_platform.is_rocm() else torch.float8_e4m3fn
82
- out_dtype = torch.float8_e4m3fn
83
- if num_token_padding:
84
- shape = (max(num_token_padding, input.shape[0]), shape[1])
85
- output = torch.empty(shape, device=input.device, dtype=out_dtype)
86
-
87
- if scale is None:
88
- if use_per_token_if_dynamic:
89
- scale = torch.empty((shape[0], 1),
90
- device=input.device,
91
- dtype=torch.float32)
92
- ops.dynamic_per_token_scaled_fp8_quant(
93
- output, input, scale, scale_ub)
94
- else:
95
- scale = torch.zeros(1, device=input.device, dtype=torch.float32)
96
- ops.dynamic_scaled_fp8_quant(output, input, scale)
97
- else:
98
- # num_token_padding not implemented for this case
99
- assert (scale.numel() == 1 or num_token_padding is None)
100
- ops.static_scaled_fp8_quant(output, input, scale)
101
-
102
- return output, scale
103
-
104
- # int8
105
- def scaled_int8_quant(
106
- input: torch.Tensor,
107
- scale: Optional[torch.Tensor] = None,
108
- azp: Optional[torch.Tensor] = None,
109
- symmetric: bool = True
110
- ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
111
- """
112
- Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.
113
-
114
- Args:
115
- input: The input tensor to be quantized to int8.
116
- scale: Optional scaling factor for the int8 quantization.
117
- When not provided, we invoke dynamic-per-token quantization.
118
- azp: Optional zero-point for the int8 quantization.
119
- Must be provided for asymmetric quantization if `scale` is provided.
120
- symmetric: Whether to use symmetric quantization (scale only, azp ignored).
121
-
122
- Returns:
123
- Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
124
- """
125
- output = torch.empty_like(input, dtype=torch.int8)
126
- if scale is not None:
127
- # static-per-tensor quantization.
128
- assert symmetric == (
129
- azp is
130
- None), "azp must only be provided for asymmetric quantization."
131
- ops.static_scaled_int8_quant(output, input, scale, azp)
132
- return output, scale, azp
133
-
134
- # dynamic-per-token quantization.
135
- input_scales = torch.empty((input.numel() // input.shape[-1], 1),
136
- device=input.device,
137
- dtype=torch.float32)
138
- input_azp = None if symmetric else torch.empty_like(input_scales,
139
- dtype=torch.int32)
140
- ops.dynamic_scaled_int8_quant(output, input, input_scales,
141
- input_azp)
142
- return output, input_scales, input_azp
143
-
144
- # fp8 marlin
145
- def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
146
- b_scales: torch.Tensor, workspace: torch.Tensor,
147
- num_bits: int, size_m: int, size_n: int,
148
- size_k: int) -> torch.Tensor:
149
- return ops.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace,
150
- num_bits, size_m, size_n, size_k)
 
1
+ from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant
2
+ from .cutlass import (
3
+ cutlass_scaled_mm_supports_fp8,
4
+ cutlass_scaled_mm,
5
+ cutlass_scaled_mm_azp,
6
+ )
7
+ from .marlin import (
8
+ awq_marlin_repack,
9
+ fp8_marlin_gemm,
10
+ gptq_marlin_gemm,
11
+ gptq_marlin_repack,
12
+ gptq_marlin_24_gemm,
13
+ marlin_qqq_gemm,
14
+ marlin_gemm,
15
+ )
16
+
17
+ __all__ = [
18
+ "awq_marlin_repack",
19
+ "cutlass_scaled_mm",
20
+ "cutlass_scaled_mm_azp",
21
+ "cutlass_scaled_mm_supports_fp8",
22
+ "fp8_marlin_gemm",
23
+ "gptq_marlin_24_gemm",
24
+ "gptq_marlin_gemm",
25
+ "gptq_marlin_repack",
26
+ "marlin_gemm",
27
+ "marlin_qqq_gemm",
28
+ "scaled_fp8_quant",
29
+ "scaled_int8_quant",
30
+ ]
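The refactor moves the implementations into per-backend modules (`compressed_tensors`, `cutlass`, `marlin`) and re-exports them here, so existing call sites can keep importing from the package root, e.g.:

from quantization import scaled_fp8_quant, cutlass_scaled_mm, gptq_marlin_gemm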
build/torch24-cxx11-cu124-x86_64-linux/quantization/_ops.py CHANGED
@@ -1,3 +1,9 @@
1
  import torch
2
  from . import _quantization_0_0_1
3
  ops = torch.ops._quantization_0_0_1
 
 
 
 
 
 
 
1
  import torch
2
  from . import _quantization_0_0_1
3
  ops = torch.ops._quantization_0_0_1
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_quantization_0_0_1::{op_name}"
build/torch24-cxx11-cu124-x86_64-linux/quantization/_quantization_0_0_1.abi3.so CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:1807b0b92ecf8fbb62bdf33fc4ba57dbabc177082c3e0e24b1ac7cd085462ae0
3
- size 89584848
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a458d5efc51f80028811707ee7b9fadb00f3bfc49917c8377188c286c4bd8e12
3
+ size 109249352
build/torch24-cxx11-cu124-x86_64-linux/quantization/compressed_tensors.py ADDED
@@ -0,0 +1,110 @@
1
+ from typing import Optional, Tuple, Union
2
+
3
+ import torch
4
+
5
+ try:
6
+ from ._ops import ops
7
+ except ImportError as e:
8
+ # Fallback for local development.
9
+ try:
10
+ import _quantization
11
+
12
+ ops = torch.ops._quantization
13
+ except ImportError:
14
+ raise e
15
+
16
+
17
+ # fp8
18
+ def scaled_fp8_quant(
19
+ input: torch.Tensor,
20
+ scale: Optional[torch.Tensor] = None,
21
+ num_token_padding: Optional[int] = None,
22
+ scale_ub: Optional[torch.Tensor] = None,
23
+ use_per_token_if_dynamic: bool = False,
24
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
25
+ """
26
+ Quantize input tensor to FP8 and return quantized tensor and scale.
27
+
28
+ This function supports both static and dynamic quantization: If you
29
+ provide the scale, it will use static scaling and if you omit it,
30
+ the scale will be determined dynamically. The function also allows
31
+ optional padding of the output tensors for downstream kernels that
32
+ will benefit from padding.
33
+
34
+ Args:
35
+ input: The input tensor to be quantized to FP8
36
+ scale: Optional scaling factor for the FP8 quantization
37
+ scale_ub: Optional upper bound for scaling factor in dynamic
38
+ per token case
39
+ num_token_padding: If specified, pad the first dimension
40
+ of the output to at least this value.
41
+ use_per_token_if_dynamic: Whether to do per_tensor or per_token
42
+ in the dynamic quantization case.
43
+
44
+ Returns:
45
+ Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
46
+ scaling factor.
47
+ """
48
+ # This code assumes batch_dim and num_tokens are flattened
49
+ assert input.ndim == 2
50
+ shape: Union[Tuple[int, int], torch.Size] = input.shape
51
+ # For rocm, the output fp8 dtype is torch.float_e3m3fnuz
52
+ # out_dtype: torch.dtype = torch.float8_e4m3fnuz \
53
+ # if current_platform.is_rocm() else torch.float8_e4m3fn
54
+ out_dtype = torch.float8_e4m3fn
55
+ if num_token_padding:
56
+ shape = (max(num_token_padding, input.shape[0]), shape[1])
57
+ output = torch.empty(shape, device=input.device, dtype=out_dtype)
58
+
59
+ if scale is None:
60
+ if use_per_token_if_dynamic:
61
+ scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
62
+ ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub)
63
+ else:
64
+ scale = torch.zeros(1, device=input.device, dtype=torch.float32)
65
+ ops.dynamic_scaled_fp8_quant(output, input, scale)
66
+ else:
67
+ # num_token_padding not implemented for this case
68
+ assert scale.numel() == 1 or num_token_padding is None
69
+ ops.static_scaled_fp8_quant(output, input, scale)
70
+
71
+ return output, scale
72
+
73
+
74
+ # int8
75
+ def scaled_int8_quant(
76
+ input: torch.Tensor,
77
+ scale: Optional[torch.Tensor] = None,
78
+ azp: Optional[torch.Tensor] = None,
79
+ symmetric: bool = True,
80
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
81
+ """
82
+ Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.
83
+
84
+ Args:
85
+ input: The input tensor to be quantized to int8.
86
+ scale: Optional scaling factor for the int8 quantization.
87
+ When not provided, we invoke dynamic-per-token quantization.
88
+ azp: Optional zero-point for the int8 quantization.
89
+ Must be provided for asymmetric quantization if `scale` is provided.
90
+ symmetric: Whether to use symmetric quantization (scale only, azp ignored).
91
+
92
+ Returns:
93
+ Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
94
+ """
95
+ output = torch.empty_like(input, dtype=torch.int8)
96
+ if scale is not None:
97
+ # static-per-tensor quantization.
98
+ assert symmetric == (
99
+ azp is None
100
+ ), "azp must only be provided for asymmetric quantization."
101
+ ops.static_scaled_int8_quant(output, input, scale, azp)
102
+ return output, scale, azp
103
+
104
+ # dynamic-per-token quantization.
105
+ input_scales = torch.empty(
106
+ (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32
107
+ )
108
+ input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32)
109
+ ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp)
110
+ return output, input_scales, input_azp
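A minimal usage sketch of the two entry points above (it assumes the compiled `_quantization` ops are importable and a CUDA device is available; the input shape is arbitrary):

import torch
from quantization import scaled_fp8_quant, scaled_int8_quant

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
x_fp8, fp8_scale = scaled_fp8_quant(x)        # dynamic per-tensor FP8 scale
x_i8, i8_scales, azp = scaled_int8_quant(x)   # dynamic per-token scales; azp is None (symmetric)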
build/torch24-cxx11-cu124-x86_64-linux/quantization/cutlass.py ADDED
@@ -0,0 +1,75 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+
5
+ try:
6
+ from ._ops import ops
7
+ except ImportError as e:
8
+ # Fallback for local development.
9
+ try:
10
+ import _quantization
11
+
12
+ ops = torch.ops._quantization
13
+ except ImportError:
14
+ raise e
15
+
16
+
17
+ def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
18
+ return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
19
+
20
+
21
+ def cutlass_scaled_mm(
22
+ a: torch.Tensor,
23
+ b: torch.Tensor,
24
+ scale_a: torch.Tensor,
25
+ scale_b: torch.Tensor,
26
+ out_dtype: torch.dtype,
27
+ bias: Optional[torch.Tensor] = None,
28
+ ) -> torch.Tensor:
29
+ assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
30
+ assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
31
+ assert bias is None or bias.shape[0] == b.shape[1] and bias.dtype == out_dtype
32
+
33
+ m = a.shape[0]
34
+ n = b.shape[1]
35
+
36
+ # if current_platform.is_rocm():
37
+ # triton_scaled_mm_module = importlib.import_module(
38
+ # "vllm.model_executor.layers.quantization.compressed_tensors."
39
+ # "triton_scaled_mm")
40
+ # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
41
+ # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
42
+
43
+ out = torch.empty((m, n), dtype=out_dtype, device=a.device)
44
+
45
+ ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
46
+
47
+ return out
48
+
49
+
50
+ def cutlass_scaled_mm_azp(
51
+ a: torch.Tensor,
52
+ b: torch.Tensor,
53
+ scale_a: torch.Tensor,
54
+ scale_b: torch.Tensor,
55
+ out_dtype: torch.dtype,
56
+ azp_adj: torch.Tensor,
57
+ azp: Optional[torch.Tensor] = None,
58
+ bias: Optional[torch.Tensor] = None,
59
+ ) -> torch.Tensor:
60
+ """
61
+ :param azp_adj: In the per-tensor case, this should include the azp.
62
+ Always per-channel.
63
+ :param azp: Only set in the per-token case. Per-token if set.
64
+ """
65
+ assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
66
+ assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
67
+ assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype
68
+ assert azp is None or azp.numel() == a.shape[0]
69
+
70
+ m = a.shape[0]
71
+ n = b.shape[1]
72
+ out = torch.empty((m, n), dtype=out_dtype, device=a.device)
73
+
74
+ ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias)
75
+ return out
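A hedged sketch of calling `cutlass_scaled_mm` with int8 inputs and per-tensor scales. It assumes the compiled ops and a GPU supported by the CUTLASS kernels; in the upstream vLLM kernels `a` is expected row-major and `b` column-major, which the Python wrapper above does not check:

import torch
from quantization import cutlass_scaled_mm

a = torch.randint(-8, 8, (32, 64), dtype=torch.int8, device="cuda")
b = torch.randint(-8, 8, (64, 128), dtype=torch.int8, device="cuda").t().contiguous().t()  # column-major
scale_a = torch.full((1, 1), 0.01, dtype=torch.float32, device="cuda")
scale_b = torch.full((1, 1), 0.02, dtype=torch.float32, device="cuda")
out = cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.float16)  # (32, 128) fp16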
build/torch24-cxx11-cu124-x86_64-linux/quantization/marlin.py ADDED
@@ -0,0 +1,208 @@
1
+ from typing import TYPE_CHECKING
2
+
3
+ import torch
4
+
5
+ # Neuron ships a torch version that does not even have impl_abstract.
6
+ if TYPE_CHECKING:
7
+ def register_fake(fn):
8
+ return lambda name: fn
9
+ else:
10
+ try:
11
+ from torch.library import register_fake
12
+ except ImportError:
13
+ from torch.library import impl_abstract as register_fake
14
+
15
+ try:
16
+ from ._ops import ops, add_op_namespace_prefix
17
+ except ImportError as e:
18
+ # Fallback for local development.
19
+ try:
20
+ import _quantization
21
+
22
+ ops = torch.ops._quantization
23
+
24
+ def add_op_namespace_prefix(op_name: str):
25
+ return f"_quantization::{op_name}"
26
+ except ImportError:
27
+ raise e
28
+
29
+
30
+ from .scalar_type import ScalarType
31
+
32
+
33
+ # fp8 marlin
34
+ def fp8_marlin_gemm(
35
+ a: torch.Tensor,
36
+ b_q_weight: torch.Tensor,
37
+ b_scales: torch.Tensor,
38
+ workspace: torch.Tensor,
39
+ num_bits: int,
40
+ size_m: int,
41
+ size_n: int,
42
+ size_k: int,
43
+ ) -> torch.Tensor:
44
+ return ops.fp8_marlin_gemm(
45
+ a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k
46
+ )
47
+
48
+
49
+ # gptq_marlin
50
+ def gptq_marlin_gemm(
51
+ a: torch.Tensor,
52
+ b_q_weight: torch.Tensor,
53
+ b_scales: torch.Tensor,
54
+ b_zeros: torch.Tensor,
55
+ g_idx: torch.Tensor,
56
+ perm: torch.Tensor,
57
+ workspace: torch.Tensor,
58
+ b_q_type: ScalarType,
59
+ size_m: int,
60
+ size_n: int,
61
+ size_k: int,
62
+ is_k_full: bool,
63
+ has_zp: bool = False,
64
+ use_fp32_reduce: bool = False,
65
+ is_zp_float: bool = False,
66
+ ) -> torch.Tensor:
67
+ return ops.gptq_marlin_gemm(
68
+ a,
69
+ b_q_weight,
70
+ b_scales,
71
+ b_zeros,
72
+ g_idx,
73
+ perm,
74
+ workspace,
75
+ b_q_type.id,
76
+ size_m,
77
+ size_n,
78
+ size_k,
79
+ is_k_full,
80
+ has_zp,
81
+ use_fp32_reduce,
82
+ is_zp_float,
83
+ )
84
+
85
+
86
+ # gptq_marlin
87
+ def gptq_marlin_repack(
88
+ b_q_weight: torch.Tensor,
89
+ perm: torch.Tensor,
90
+ size_k: int,
91
+ size_n: int,
92
+ num_bits: int,
93
+ ) -> torch.Tensor:
94
+ return ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits)
95
+
96
+
97
+ # gptq_marlin
98
+ def awq_marlin_repack(
99
+ b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
100
+ ) -> torch.Tensor:
101
+ return ops.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)
102
+
103
+
104
+ # marlin
105
+ def marlin_gemm(
106
+ a: torch.Tensor,
107
+ b_q_weight: torch.Tensor,
108
+ b_scales: torch.Tensor,
109
+ workspace: torch.Tensor,
110
+ size_m: int,
111
+ size_n: int,
112
+ size_k: int,
113
+ ) -> torch.Tensor:
114
+ return ops.marlin_gemm(
115
+ a, b_q_weight, b_scales, workspace, size_m, size_n, size_k
116
+ )
117
+
118
+
119
+ # marlin_24
120
+ def gptq_marlin_24_gemm(
121
+ a: torch.Tensor,
122
+ b_q_weight: torch.Tensor,
123
+ b_meta: torch.Tensor,
124
+ b_scales: torch.Tensor,
125
+ workspace: torch.Tensor,
126
+ b_q_type: ScalarType,
127
+ size_m: int,
128
+ size_n: int,
129
+ size_k: int,
130
+ ) -> torch.Tensor:
131
+ return ops.gptq_marlin_24_gemm(
132
+ a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, size_m, size_n, size_k
133
+ )
134
+
135
+
136
+ # qqq ops
137
+ def marlin_qqq_gemm(
138
+ a: torch.Tensor,
139
+ b_q_weight: torch.Tensor,
140
+ s_tok: torch.Tensor,
141
+ s_ch: torch.Tensor,
142
+ s_group: torch.Tensor,
143
+ workspace: torch.Tensor,
144
+ size_m: int,
145
+ size_n: int,
146
+ size_k: int,
147
+ ) -> torch.Tensor:
148
+ return ops.marlin_qqq_gemm(
149
+ a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k
150
+ )
151
+
152
+
153
+ # Fake ops
154
+
155
+ if hasattr(ops, "gptq_marlin_24_gemm"):
156
+ @register_fake(add_op_namespace_prefix("fp8_marlin_gemm"))
157
+ def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
158
+ b_scales: torch.Tensor, workspace: torch.Tensor,
159
+ num_bits: int, size_m: torch.SymInt,
160
+ size_n: torch.SymInt,
161
+ size_k: torch.SymInt) -> torch.Tensor:
162
+ return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)
163
+
164
+ @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm"))
165
+ def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
166
+ b_meta: torch.Tensor, b_scales: torch.Tensor,
167
+ workspace: torch.Tensor,
168
+ b_q_type: ScalarType, size_m: torch.SymInt,
169
+ size_n: torch.SymInt,
170
+ size_k: torch.SymInt) -> torch.Tensor:
171
+ return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
172
+
173
+ @register_fake(add_op_namespace_prefix("gptq_marlin_gemm"))
174
+ def _gptq_marlin_gemm_fake(a: torch.Tensor,
175
+ b_q_weight: torch.Tensor,
176
+ b_scales: torch.Tensor,
177
+ b_zeros: torch.Tensor,
178
+ g_idx: torch.Tensor,
179
+ perm: torch.Tensor,
180
+ workspace: torch.Tensor,
181
+ b_q_type: ScalarType,
182
+ size_m: torch.SymInt,
183
+ size_n: torch.SymInt,
184
+ size_k: torch.SymInt,
185
+ is_k_full: bool,
186
+ has_zp: bool = False,
187
+ use_fp32_reduce: bool = False,
188
+ is_zp_float: bool = False) -> torch.Tensor:
189
+ return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
190
+
191
+ @register_fake(add_op_namespace_prefix("marlin_qqq_gemm"))
192
+ def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
193
+ s_tok: torch.Tensor, s_ch: torch.Tensor,
194
+ s_group: torch.Tensor, workspace: torch.Tensor,
195
+ size_m: torch.SymInt, size_n: torch.SymInt,
196
+ size_k: torch.SymInt) -> torch.Tensor:
197
+ return torch.empty((size_m, size_n),
198
+ dtype=torch.float16,
199
+ device=a.device)
200
+
201
+ @register_fake(add_op_namespace_prefix("marlin_gemm"))
202
+ def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
203
+ b_scales: torch.Tensor, workspace: torch.Tensor,
204
+ size_m: torch.SymInt, size_n: torch.SymInt,
205
+ size_k: torch.SymInt) -> torch.Tensor:
206
+ return torch.empty((size_m, size_n),
207
+ dtype=torch.float16,
208
+ device=a.device)
build/torch24-cxx11-cu124-x86_64-linux/quantization/scalar_type.py ADDED
@@ -0,0 +1,330 @@
1
+ import functools
2
+ import struct
3
+ from dataclasses import dataclass
4
+ from enum import Enum
5
+ from typing import Optional, Union
6
+
7
+
8
+ # Mirrors enum in `core/scalar_type.hpp`
9
+ class NanRepr(Enum):
10
+ NONE = 0 # nans are not supported
11
+ IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s
12
+ EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s
13
+
14
+
15
+ # This ScalarType class is a parallel implementation of the C++ ScalarType
16
+ # class found in csrc/core/scalar_type.hpp. These two classes should be kept
17
+ # in sync until the inductor fully supports custom C++ classes.
18
+ @dataclass(frozen=True)
19
+ class ScalarType:
20
+ """
21
+ ScalarType can represent a wide range of floating point and integer
22
+ types, in particular it can be used to represent sub-byte data types
23
+ (something that torch.dtype currently does not support). It is also
24
+ capable of representing types with a bias, i.e.:
25
+ `stored_value = value + bias`,
26
+ this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
27
+ of 8). The implementation for this class can be found in
28
+ csrc/core/scalar_type.hpp, these type signatures should be kept in sync
29
+ with that file.
30
+ """
31
+
32
+ exponent: int
33
+ """
34
+ Number of bits in the exponent if this is a floating point type
35
+ (zero if this an integer type)
36
+ """
37
+
38
+ mantissa: int
39
+ """
40
+ Number of bits in the mantissa if this is a floating point type,
41
+ or the number bits representing an integer excluding the sign bit if
42
+ this an integer type.
43
+ """
44
+
45
+ signed: bool
46
+ "If the type is signed (i.e. has a sign bit)"
47
+
48
+ bias: int
49
+ """
50
+ bias used to encode the values in this scalar type
51
+ (value = stored_value - bias, default 0) for example if we store the
52
+ type as an unsigned integer with a bias of 128 then the value 0 will be
53
+ stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
54
+ """
55
+
56
+ _finite_values_only: bool = False
57
+ """
58
+ Private: use `has_infs()` to check whether infs are supported.
59
+ """
60
+
61
+ nan_repr: NanRepr = NanRepr.IEEE_754
62
+ """
63
+ How NaNs are represented in this scalar type; returns a NanRepr value.
64
+ (not applicable for integer types)
65
+ """
66
+
67
+ def _floating_point_max_int(self) -> int:
68
+ assert (
69
+ self.mantissa <= 52 and self.exponent <= 11
70
+ ), f"Cannot represent max/min as a double for type {self.__str__()}"
71
+
72
+ max_mantissa = (1 << self.mantissa) - 1
73
+ if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
74
+ max_mantissa = max_mantissa - 1
75
+
76
+ max_exponent = (1 << self.exponent) - 2
77
+ if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN
78
+ or self.nan_repr == NanRepr.NONE):
79
+ assert (
80
+ self.exponent < 11
81
+ ), f"Cannot represent max/min as a double for type {self.__str__()}"
82
+ max_exponent = max_exponent + 1
83
+
84
+ # adjust the exponent to match that of a double
85
+ # for now we assume the exponent bias is the standard 2^(e-1) -1, (where
86
+ # e is the exponent bits), there is some precedent for non-standard
87
+ # biases, example `float8_e4m3b11fnuz` here:
88
+ # https://github.com/jax-ml/ml_dtypes but to avoid premature over
89
+ # complication we are just assuming the standard exponent bias until
90
+ # there is a need to support non-standard biases
91
+ exponent_bias = (1 << (self.exponent - 1)) - 1
92
+ exponent_bias_double = (1 << 10) - 1 # double e = 11
93
+
94
+ max_exponent_double = (max_exponent - exponent_bias +
95
+ exponent_bias_double)
96
+
97
+ # shift the mantissa and exponent into the proper positions for an
98
+ # IEEE double and bitwise-or them together.
99
+ return (max_mantissa <<
100
+ (52 - self.mantissa)) | (max_exponent_double << 52)
101
+
102
+ def _floating_point_max(self) -> float:
103
+ double_raw = self._floating_point_max_int()
104
+ return struct.unpack('!d', struct.pack('!Q', double_raw))[0]
105
+
106
+ def _raw_max(self) -> Union[int, float]:
107
+ if self.is_floating_point():
108
+ return self._floating_point_max()
109
+ else:
110
+ assert (self.size_bits < 64 or self.size_bits == 64
111
+ and self.is_signed()), "Cannot represent max as an int"
112
+ return (1 << self.mantissa) - 1
113
+
114
+ def _raw_min(self) -> Union[int, float]:
115
+ if self.is_floating_point():
116
+ assert self.is_signed(
117
+ ), "We currently assume all floating point types are signed"
118
+ sign_bit_double = 1 << 63
119
+
120
+ max_raw = self._floating_point_max_int()
121
+ min_raw = max_raw | sign_bit_double
122
+ return struct.unpack('!d', struct.pack('!Q', min_raw))[0]
123
+ else:
124
+ assert (not self.is_signed() or
125
+ self.size_bits <= 64), "Cannot represent min as a int64_t"
126
+
127
+ if self.is_signed():
128
+ return -(1 << (self.size_bits - 1))
129
+ else:
130
+ return 0
131
+
132
+ @functools.cached_property
133
+ def id(self) -> int:
134
+ """
135
+ Convert the ScalarType to an int which can be passed to pytorch custom
136
+ ops. This layout of the int must be kept in sync with the C++
137
+ ScalarType's from_id method.
138
+ """
139
+ val = 0
140
+ offset = 0
141
+
142
+ def or_and_advance(member, bit_width):
143
+ nonlocal val
144
+ nonlocal offset
145
+ bit_mask = (1 << bit_width) - 1
146
+ val = val | (int(member) & bit_mask) << offset
147
+ offset = offset + bit_width
148
+
149
+ or_and_advance(self.exponent, 8)
150
+ or_and_advance(self.mantissa, 8)
151
+ or_and_advance(self.signed, 1)
152
+ or_and_advance(self.bias, 32)
153
+ or_and_advance(self._finite_values_only, 1)
154
+ or_and_advance(self.nan_repr.value, 8)
155
+
156
+ assert offset <= 64, \
157
+ f"ScalarType fields too big {offset} to fit into an int64"
158
+
159
+ return val
160
+
161
+ @property
162
+ def size_bits(self) -> int:
163
+ return self.exponent + self.mantissa + int(self.signed)
164
+
165
+ def min(self) -> Union[int, float]:
166
+ """
167
+ Min representable value for this scalar type.
168
+ (accounting for bias if there is one)
169
+ """
170
+ return self._raw_min() - self.bias
171
+
172
+ def max(self) -> Union[int, float]:
173
+ """
174
+ Max representable value for this scalar type.
175
+ (accounting for bias if there is one)
176
+ """
177
+ return self._raw_max() - self.bias
178
+
179
+ def is_signed(self) -> bool:
180
+ """
181
+ If the type is signed (i.e. has a sign bit), same as `signed`
182
+ added for consistency with:
183
+ https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
184
+ """
185
+ return self.signed
186
+
187
+ def is_floating_point(self) -> bool:
188
+ "If the type is a floating point type"
189
+ return self.exponent != 0
190
+
191
+ def is_integer(self) -> bool:
192
+ "If the type is an integer type"
193
+ return self.exponent == 0
194
+
195
+ def has_bias(self) -> bool:
196
+ "If the type has a non-zero bias"
197
+ return self.bias != 0
198
+
199
+ def has_infs(self) -> bool:
200
+ "If the type is floating point and supports infinity"
201
+ return not self._finite_values_only
202
+
203
+ def has_nans(self) -> bool:
204
+ return self.nan_repr != NanRepr.NONE
205
+
206
+ def is_ieee_754(self) -> bool:
207
+ """
208
+ If the type is a floating point type that follows IEEE 754
209
+ conventions
210
+ """
211
+ return self.nan_repr == NanRepr.IEEE_754 and \
212
+ not self._finite_values_only
213
+
214
+ def __str__(self) -> str:
215
+ """
216
+ naming generally follows: https://github.com/jax-ml/ml_dtypes
217
+ for floating point types (leading f) the scheme is:
218
+ `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
219
+ flags:
220
+ - no-flags: means it follows IEEE 754 conventions
221
+ - f: means finite values only (no infinities)
222
+ - n: means nans are supported (non-standard encoding)
223
+ for integer types the scheme is:
224
+ `[u]int<size_bits>[b<bias>]`
225
+ - if bias is not present it means its zero
226
+ """
227
+ if self.is_floating_point():
228
+ ret = "float" + str(self.size_bits) + "_e" + str(
229
+ self.exponent) + "m" + str(self.mantissa)
230
+
231
+ if not self.is_ieee_754():
232
+ if self._finite_values_only:
233
+ ret = ret + "f"
234
+ if self.nan_repr != NanRepr.NONE:
235
+ ret = ret + "n"
236
+
237
+ return ret
238
+ else:
239
+ ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
240
+ if self.has_bias():
241
+ ret = ret + "b" + str(self.bias)
242
+ return ret
243
+
244
+ def __repr__(self) -> str:
245
+ return "ScalarType." + self.__str__()
246
+
247
+ # __len__ needs to be defined (and has to throw TypeError) for pytorch's
248
+ # opcheck to work.
249
+ def __len__(self) -> int:
250
+ raise TypeError
251
+
252
+ #
253
+ # Convenience Constructors
254
+ #
255
+
256
+ @classmethod
257
+ def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
258
+ "Create a signed integer scalar type (size_bits includes sign-bit)."
259
+ ret = cls(0, size_bits - 1, True, bias if bias else 0)
260
+ ret.id # noqa B018: make sure the id is cached
261
+ return ret
262
+
263
+ @classmethod
264
+ def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
265
+ """Create a unsigned integer scalar type."""
266
+ ret = cls(0, size_bits, False, bias if bias else 0)
267
+ ret.id # noqa B018: make sure the id is cached
268
+ return ret
269
+
270
+ @classmethod
271
+ def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
272
+ """
273
+ Create a standard floating point type
274
+ (i.e. follows IEEE 754 conventions).
275
+ """
276
+ assert (mantissa > 0 and exponent > 0)
277
+ ret = cls(exponent, mantissa, True, 0)
278
+ ret.id # noqa B018: make sure the id is cached
279
+ return ret
280
+
281
+ @classmethod
282
+ def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
283
+ nan_repr: NanRepr) -> 'ScalarType':
284
+ """
285
+ Create a non-standard floating point type
286
+ (i.e. does not follow IEEE 754 conventions).
287
+ """
288
+ assert (mantissa > 0 and exponent > 0)
289
+ assert (nan_repr != NanRepr.IEEE_754), (
290
+ "use `float_IEEE754` constructor for floating point types that "
291
+ "follow IEEE 754 conventions")
292
+ ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
293
+ ret.id # noqa B018: make sure the id is cached
294
+ return ret
295
+
296
+
297
+ # naming generally follows: https://github.com/jax-ml/ml_dtypes
298
+ # for floating point types (leading f) the scheme is:
299
+ # `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
300
+ # flags:
301
+ # - no-flags: means it follows IEEE 754 conventions
302
+ # - f: means finite values only (no infinities)
303
+ # - n: means nans are supported (non-standard encoding)
304
+ # for integer types the scheme is:
305
+ # `[u]int<size_bits>[b<bias>]`
306
+ # - if bias is not present it means it's zero
307
+
308
+
309
+ class scalar_types:
310
+ int4 = ScalarType.int_(4, None)
311
+ uint4 = ScalarType.uint(4, None)
312
+ int8 = ScalarType.int_(8, None)
313
+ uint8 = ScalarType.uint(8, None)
314
+ float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
315
+ float8_e5m2 = ScalarType.float_IEEE754(5, 2)
316
+ float16_e8m7 = ScalarType.float_IEEE754(8, 7)
317
+ float16_e5m10 = ScalarType.float_IEEE754(5, 10)
318
+
319
+ # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
320
+ float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)
321
+
322
+ # "gptq" types
323
+ uint2b2 = ScalarType.uint(2, 2)
324
+ uint3b4 = ScalarType.uint(3, 4)
325
+ uint4b8 = ScalarType.uint(4, 8)
326
+ uint8b128 = ScalarType.uint(8, 128)
327
+
328
+ # colloquial names
329
+ bfloat16 = float16_e8m7
330
+ float16 = float16_e5m10
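A minimal sketch of how the naming scheme above plays out, assuming the `quantization` package from this build is importable (the printed strings are illustrative, not part of the diff):

```python
# Hedged sketch: rendering a few of the predefined scalar types.
from quantization.scalar_type import scalar_types

print(str(scalar_types.uint4b8))         # "uint4b8"  (unsigned 4-bit, bias 8)
print(scalar_types.uint4b8.has_bias())   # True
print(str(scalar_types.float8_e4m3fn))   # "float8_e4m3fn" (finite-only, non-standard NaN)
print(str(scalar_types.float8_e5m2))     # "float8_e5m2"   (IEEE 754, so no flags)
```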
build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/__init__.py ADDED
File without changes
build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils.py ADDED
@@ -0,0 +1,391 @@
1
+ from typing import List, Optional, Tuple
2
+
3
+ import numpy
4
+ import torch
5
+
6
+ import quantization as ops
7
+ from quantization.scalar_type import ScalarType, scalar_types
8
+
9
+ from .quant_utils import pack_cols, unpack_cols
10
+
11
+ GPTQ_MARLIN_TILE = 16
12
+ GPTQ_MARLIN_MIN_THREAD_N = 64
13
+ GPTQ_MARLIN_MIN_THREAD_K = 128
14
+ GPTQ_MARLIN_MAX_PARALLEL = 16
15
+
16
+ GPTQ_MARLIN_24_TILE = 16
17
+ GPTQ_MARLIN_24_MIN_THREAD_N = 128
18
+ GPTQ_MARLIN_24_MIN_THREAD_K = 128
19
+ GPTQ_MARLIN_24_MAX_PARALLEL = 64
20
+
21
+ GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
22
+ GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
23
+
24
+ MARLIN_QQQ_TILE = 16
25
+ MARLIN_QQQ_MIN_THREAD_N = 64
26
+ MARLIN_QQQ_MIN_THREAD_K = 128
27
+ MARLIN_QQQ_MAX_PARALLEL = 16
28
+
29
+ MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]
30
+ MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128]
31
+ MARLIN_QQQ_SUPPORTED_SYM = [True]
32
+
33
+ MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
34
+
35
+ # In case there is a performance issue with Marlin, the variable below can be
36
+ # changed to False, which allows Marlin to perform global reductions in fp16
37
+ # precision (instead of fp32), and therefore, save on some memory movements.
38
+ USE_FP32_REDUCE_DEFAULT = True
39
+
40
+
41
+ # For binary size and compile time, we don't support the same set of types with and
42
+ # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
43
+ # TODO: we may want to move this into the C++ so it's closer to the actual impl
44
+ def query_marlin_supported_quant_types(
45
+ has_zp: bool, device_capability: Optional[int] = None
46
+ ):
47
+ if device_capability is None:
48
+ capability_tuple = torch.cuda.get_device_capability()
49
+ device_capability = capability_tuple[0] * 10 + capability_tuple[1]
50
+
51
+ if device_capability < 80:
52
+ return []
53
+
54
+ if has_zp:
55
+ # AWQ style, unsigned + runtime zero-point
56
+ return [scalar_types.uint4, scalar_types.uint8]
57
+ else:
58
+ # GPTQ style, unsigned + symmetric bias
59
+ # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able
60
+ # to add `scalar_types.float8_e4m3fn` here
61
+ return [scalar_types.uint4b8, scalar_types.uint8b128]
62
+
63
+
64
+ def _check_marlin_supported(
65
+ quant_type: ScalarType,
66
+ group_size: Optional[int],
67
+ has_zp: bool,
68
+ device_capability: Optional[int] = None,
69
+ ) -> Tuple[bool, Optional[str]]:
70
+
71
+ if device_capability is None:
72
+ capability_tuple = torch.cuda.get_device_capability()
73
+ device_capability = capability_tuple[0] * 10 + capability_tuple[1]
74
+
75
+ supported_types = query_marlin_supported_quant_types(has_zp, device_capability)
76
+
77
+ if quant_type not in supported_types:
78
+ return (
79
+ False,
80
+ f"Marlin does not support weight_bits = {quant_type}. "
81
+ f"Only types = {supported_types} "
82
+ f"are supported (for group_size = {group_size}, "
83
+ f"device_capability = {device_capability}, zp = {has_zp}).",
84
+ )
85
+ if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
86
+ return (
87
+ False,
88
+ f"Marlin does not support group_size = {group_size}. "
89
+ f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
90
+ "are supported.",
91
+ )
92
+
93
+ return True, None
94
+
95
+
96
+ def check_marlin_supported(
97
+ quant_type: ScalarType,
98
+ group_size: int,
99
+ has_zp: bool = False,
100
+ device_capability: Optional[int] = None,
101
+ ) -> bool:
102
+ cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability)
103
+ return cond
104
+
105
+
106
+ def verify_marlin_supported(
107
+ quant_type: ScalarType, group_size: int, has_zp: bool = False
108
+ ) -> None:
109
+ cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp)
110
+ if not cond:
111
+ assert err_msg is not None
112
+ raise ValueError(err_msg)
113
+
114
+
115
+ def verify_marlin_supports_shape(
116
+ output_size_per_partition: int,
117
+ input_size_per_partition: int,
118
+ input_size: int,
119
+ group_size: int,
120
+ ) -> None:
121
+
122
+ # Validate output_size_per_partition
123
+ if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
124
+ raise ValueError(
125
+ f"Weight output_size_per_partition = "
126
+ f"{output_size_per_partition} is not divisible by "
127
+ f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
128
+ "Consider reducing tensor_parallel_size or running "
129
+ "with --quantization gptq."
130
+ )
131
+
132
+ # Validate input_size_per_partition
133
+ if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
134
+ raise ValueError(
135
+ f"Weight input_size_per_partition = "
136
+ f"{input_size_per_partition} is not divisible "
137
+ f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
138
+ "Consider reducing tensor_parallel_size or running "
139
+ "with --quantization gptq."
140
+ )
141
+
142
+ if group_size < input_size and input_size_per_partition % group_size != 0:
143
+ raise ValueError(
144
+ f"Weight input_size_per_partition = {input_size_per_partition}"
145
+ f" is not divisible by group_size = {group_size}."
146
+ "Consider reducing tensor_parallel_size or running "
147
+ "with --quantization gptq."
148
+ )
149
+
150
+
151
+ def check_marlin_supports_shape(
152
+ output_size_per_partition: int,
153
+ input_size_per_partition: int,
154
+ input_size: int,
155
+ group_size: int,
156
+ ) -> Tuple[bool, Optional[str]]:
157
+ try:
158
+ verify_marlin_supports_shape(
159
+ output_size_per_partition, input_size_per_partition, input_size, group_size
160
+ )
161
+ except ValueError as e:
162
+ return False, e.__str__()
163
+ return True, None
164
+
165
+
166
+ def marlin_make_workspace(
167
+ output_size_per_partition: int, device: torch.device
168
+ ) -> torch.Tensor:
169
+ max_workspace_size = (
170
+ output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N
171
+ ) * GPTQ_MARLIN_MAX_PARALLEL
172
+
173
+ return torch.zeros(
174
+ max_workspace_size, dtype=torch.int, device=device, requires_grad=False
175
+ )
176
+
177
+
178
+ def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
179
+ return (not act_order) or (act_order and not is_row_parallel)
180
+
181
+
182
+ def marlin_repeat_scales_on_all_ranks(
183
+ act_order: bool, group_size: int, is_row_parallel: bool
184
+ ) -> bool:
185
+ # Need to repeat scales on every rank if act_ordering or
186
+ # channelwise and RowParallelLinear
187
+ is_channelwise = group_size == -1
188
+ return act_order or (is_channelwise and is_row_parallel)
189
+
190
+
191
+ def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
192
+ return torch.nn.Parameter(
193
+ torch.empty(0, dtype=torch.int, device=device), requires_grad=False
194
+ )
195
+
196
+
197
+ def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
198
+ return torch.nn.Parameter(
199
+ torch.empty(0, dtype=torch.int, device=device), requires_grad=False
200
+ )
201
+
202
+
203
+ def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
204
+ g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
205
+ return g_idx[g_idx_sort_indices], g_idx_sort_indices
206
+
207
+
208
+ def get_scale_perms():
209
+ scale_perm: List[int] = []
210
+ for i in range(8):
211
+ scale_perm.extend([i + 8 * j for j in range(8)])
212
+ scale_perm_single: List[int] = []
213
+ for i in range(4):
214
+ scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
215
+ return scale_perm, scale_perm_single
216
+
217
+
218
+ def marlin_permute_scales(
219
+ s: torch.Tensor, size_k: int, size_n: int, group_size: int
220
+ ) -> torch.Tensor:
221
+
222
+ scale_perm, scale_perm_single = get_scale_perms()
223
+ if group_size < size_k and group_size != -1:
224
+ s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
225
+ else:
226
+ s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
227
+ s = s.reshape((-1, size_n)).contiguous()
228
+
229
+ return s
230
+
231
+
232
+ def marlin_moe_permute_scales(
233
+ s: torch.Tensor,
234
+ size_k: int,
235
+ size_n: int,
236
+ group_size: int,
237
+ ):
238
+ num_experts = s.shape[0]
239
+ output = torch.empty(
240
+ (num_experts, s.shape[1], s.shape[2]),
241
+ device=s.device,
242
+ dtype=s.dtype,
243
+ )
244
+
245
+ for e in range(num_experts):
246
+ output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size)
247
+ return output
248
+
249
+
250
+ def marlin_zero_points(
251
+ zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
252
+ ) -> torch.Tensor:
253
+ # Permute zero-points in a similar way to scales, but do not use the
254
+ # "single" permutation, since zero-points are applied on every MMA
255
+ scale_perm, _ = get_scale_perms()
256
+ zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm]
257
+
258
+ # Interleave column dim (for the dequantize code) and pack it to int32
259
+ if num_bits == 4:
260
+ interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
261
+ elif num_bits == 8:
262
+ interleave = numpy.array([0, 2, 1, 3])
263
+ else:
264
+ raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
265
+
266
+ zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
267
+ zp = zp.reshape((-1, size_n)).contiguous()
268
+ zp = pack_cols(zp, num_bits, size_k, size_n)
269
+
270
+ return zp
271
+
272
+
273
+ def awq_to_marlin_zero_points(
274
+ q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
275
+ ) -> torch.Tensor:
276
+ # AWQ zero-points are quantized and packed on the column dim.
277
+ # In addition, the values are permuted based on dequantizer.
278
+ # Here we undo both of these, and then apply marlin permutation
279
+ # and pack it back.
280
+ q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)
281
+
282
+ # Undo interleaving (use argsort(..) to get inverse perm)
283
+ if num_bits == 4:
284
+ undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7]))
285
+ elif num_bits == 8:
286
+ undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3]))
287
+ else:
288
+ raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
289
+
290
+ q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
291
+ q_zp = q_zp.reshape((-1, size_n)).contiguous()
292
+
293
+ marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
294
+ return marlin_zp
295
+
296
+
297
+ def moe_awq_to_marlin_zero_points(
298
+ q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
299
+ ):
300
+ num_experts = q_zp_packed.shape[0]
301
+ output = torch.empty(
302
+ (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]),
303
+ device=q_zp_packed.device,
304
+ dtype=q_zp_packed.dtype,
305
+ )
306
+ for e in range(num_experts):
307
+ output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits)
308
+ return output
309
+
310
+
311
+ def apply_gptq_marlin_linear(
312
+ input: torch.Tensor,
313
+ weight: torch.Tensor,
314
+ weight_scale: torch.Tensor,
315
+ weight_zp: torch.Tensor,
316
+ g_idx: torch.Tensor,
317
+ g_idx_sort_indices: torch.Tensor,
318
+ workspace: torch.Tensor,
319
+ wtype: ScalarType,
320
+ output_size_per_partition: int,
321
+ input_size_per_partition: int,
322
+ is_k_full: bool,
323
+ bias: Optional[torch.Tensor] = None,
324
+ use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
325
+ ) -> torch.Tensor:
326
+ reshaped_x = input.reshape(-1, input.shape[-1])
327
+ out_shape = input.shape[:-1] + (output_size_per_partition,)
328
+
329
+ output = ops.gptq_marlin_gemm(
330
+ reshaped_x,
331
+ weight,
332
+ weight_scale,
333
+ weight_zp,
334
+ g_idx,
335
+ g_idx_sort_indices,
336
+ workspace,
337
+ wtype,
338
+ size_m=reshaped_x.shape[0],
339
+ size_n=output_size_per_partition,
340
+ size_k=input_size_per_partition,
341
+ is_k_full=is_k_full,
342
+ has_zp=False,
343
+ use_fp32_reduce=use_fp32_reduce,
344
+ is_zp_float=False,
345
+ )
346
+
347
+ if bias is not None:
348
+ output.add_(bias) # In-place add
349
+
350
+ return output.reshape(out_shape)
351
+
352
+
353
+ def apply_awq_marlin_linear(
354
+ input: torch.Tensor,
355
+ weight: torch.Tensor,
356
+ weight_scale: torch.Tensor,
357
+ weight_zp: torch.Tensor,
358
+ g_idx: torch.Tensor,
359
+ g_idx_sort_indices: torch.Tensor,
360
+ workspace: torch.Tensor,
361
+ quant_type: ScalarType,
362
+ output_size_per_partition: int,
363
+ input_size_per_partition: int,
364
+ bias: Optional[torch.Tensor] = None,
365
+ use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
366
+ ) -> torch.Tensor:
367
+ reshaped_x = input.reshape(-1, input.shape[-1])
368
+ out_shape = input.shape[:-1] + (output_size_per_partition,)
369
+
370
+ output = ops.gptq_marlin_gemm(
371
+ reshaped_x,
372
+ weight,
373
+ weight_scale,
374
+ weight_zp,
375
+ g_idx,
376
+ g_idx_sort_indices,
377
+ workspace,
378
+ quant_type,
379
+ size_m=reshaped_x.shape[0],
380
+ size_n=output_size_per_partition,
381
+ size_k=input_size_per_partition,
382
+ is_k_full=True,
383
+ has_zp=True,
384
+ use_fp32_reduce=use_fp32_reduce,
385
+ is_zp_float=False,
386
+ )
387
+
388
+ if bias is not None:
389
+ output.add_(bias) # In-place add
390
+
391
+ return output.reshape(out_shape)
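A hedged sketch of how the support-check helpers in this file compose. Passing `device_capability` explicitly avoids any CUDA calls, and the import path assumes the module is reachable as `quantization.utils.marlin_utils` in this build:

```python
# Hedged sketch: query Marlin support for SM 8.0 (Ampere) without touching the GPU.
from quantization.scalar_type import scalar_types
from quantization.utils.marlin_utils import (
    check_marlin_supported,
    query_marlin_supported_quant_types,
)

# GPTQ-style (no runtime zero-point): expect uint4b8 and uint8b128.
print(query_marlin_supported_quant_types(has_zp=False, device_capability=80))

# group_size must be one of MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128].
print(check_marlin_supported(scalar_types.uint4b8, group_size=128,
                             has_zp=False, device_capability=80))  # True
```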
build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_fp8.py ADDED
@@ -0,0 +1,100 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+
5
+ import quantization as ops
6
+
7
+ from .marlin_utils import marlin_make_workspace, marlin_permute_scales
8
+
9
+
10
+ def is_fp8_marlin_supported():
11
+ capability = torch.cuda.get_device_capability()
12
+ capability = capability[0] * 10 + capability[1]
13
+ return capability >= 80
14
+
15
+
16
+ def apply_fp8_marlin_linear(
17
+ input: torch.Tensor,
18
+ weight: torch.Tensor,
19
+ weight_scale: torch.Tensor,
20
+ workspace: torch.Tensor,
21
+ size_n: int,
22
+ size_k: int,
23
+ bias: Optional[torch.Tensor],
24
+ ) -> torch.Tensor:
25
+ # For GPUs that lack FP8 hardware support, we can leverage the
26
+ # Marlin kernel for fast weight-only FP8 quantization
27
+
28
+ reshaped_x = input.reshape(-1, input.shape[-1])
29
+ out_shape = input.shape[:-1] + (size_n,)
30
+
31
+ output = ops.fp8_marlin_gemm(
32
+ a=reshaped_x,
33
+ b_q_weight=weight,
34
+ b_scales=weight_scale,
35
+ workspace=workspace,
36
+ num_bits=8,
37
+ size_m=reshaped_x.shape[0],
38
+ size_n=size_n,
39
+ size_k=size_k,
40
+ )
41
+
42
+ if bias is not None:
43
+ output.add_(bias) # In-place add
44
+
45
+ return output.reshape(out_shape)
46
+
47
+
48
+ def prepare_fp8_layer_for_marlin(
49
+ layer: torch.nn.Module, strategy: str = "tensor"
50
+ ) -> None:
51
+ part_size_n = layer.output_size_per_partition
52
+ part_size_k = layer.input_size_per_partition
53
+
54
+ device = layer.weight.device
55
+
56
+ # WORKSPACE
57
+ layer.workspace = marlin_make_workspace(part_size_n, device)
58
+
59
+ # WEIGHT
60
+ # Repack weights to marlin format
61
+ marlin_qweight = ops.gptq_marlin_repack(
62
+ b_q_weight=pack_fp8_to_int32(layer.weight),
63
+ perm=torch.empty(0, dtype=torch.int, device=device),
64
+ size_k=part_size_k,
65
+ size_n=part_size_n,
66
+ num_bits=8,
67
+ )
68
+ layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
69
+
70
+ # WEIGHT SCALES
71
+ scales = layer.weight_scale.to(layer.orig_dtype)
72
+ # Permute scales
73
+ marlin_scales = marlin_permute_scales(
74
+ s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1
75
+ )
76
+ layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
77
+
78
+
79
+ def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
80
+ """
81
+ Repack FP8 weights to gptq format (packed int32 elements)
82
+ """
83
+ assert fp8_tensor.dtype == torch.float8_e4m3fn
84
+ assert fp8_tensor.shape[0] % 4 == 0
85
+
86
+ # Reshape to prepare for packing
87
+ reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
88
+
89
+ # Convert fp8 to uint8 (byte) representation
90
+ byte_tensor = reshaped.view(torch.uint8)
91
+
92
+ # Pack 4 uint8 values into one int32
93
+ packed = (
94
+ byte_tensor[:, 0].to(torch.int32)
95
+ | (byte_tensor[:, 1].to(torch.int32) << 8)
96
+ | (byte_tensor[:, 2].to(torch.int32) << 16)
97
+ | (byte_tensor[:, 3].to(torch.int32) << 24)
98
+ )
99
+
100
+ return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous()
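A hedged sketch of `pack_fp8_to_int32` in isolation; it assumes a PyTorch build that provides `torch.float8_e4m3fn`, runs on CPU, and imports the module as `quantization.utils.marlin_utils_fp8`:

```python
# Hedged sketch: four fp8 rows are packed byte-wise into one int32 row.
import torch

from quantization.utils.marlin_utils_fp8 import pack_fp8_to_int32

w = torch.randn(8, 16, dtype=torch.float16).to(torch.float8_e4m3fn)
packed = pack_fp8_to_int32(w)
print(w.shape, packed.shape, packed.dtype)
# torch.Size([8, 16]) torch.Size([2, 16]) torch.int32
```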
build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test.py ADDED
@@ -0,0 +1,162 @@
1
+ """Utility functions used for tests and benchmarks"""
2
+
3
+ from typing import List, Optional
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+ from quantization.scalar_type import ScalarType
9
+
10
+ from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, marlin_zero_points
11
+ from .quant_utils import (
12
+ get_pack_factor,
13
+ gptq_quantize_weights,
14
+ quantize_weights,
15
+ sort_weights,
16
+ )
17
+
18
+
19
+ class MarlinWorkspace:
20
+
21
+ def __init__(self, out_features, min_thread_n, max_parallel):
22
+ assert (
23
+ out_features % min_thread_n == 0
24
+ ), "out_features = {} is not divisible by min_thread_n = {}".format(
25
+ out_features, min_thread_n
26
+ )
27
+
28
+ max_workspace_size = (out_features // min_thread_n) * max_parallel
29
+
30
+ self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda")
31
+
32
+
33
+ def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
34
+ assert q_w.shape == (size_k, size_n)
35
+ assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
36
+ assert size_n % tile == 0, f"size_n = {size_n}, tile = {tile}"
37
+
38
+ # Permute weights to 16x64 marlin tiles
39
+ q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
40
+ q_w = q_w.permute((0, 2, 1, 3))
41
+ q_w = q_w.reshape((size_k // tile, size_n * tile))
42
+
43
+ q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)
44
+
45
+ return q_w
46
+
47
+
48
+ def marlin_weights(q_w, size_k, size_n, num_bits, perm):
49
+ # Permute
50
+ q_w = marlin_permute_weights(q_w, size_k, size_n, perm)
51
+
52
+ # Pack
53
+ pack_factor = get_pack_factor(num_bits)
54
+ orig_device = q_w.device
55
+
56
+ q_w = q_w.cpu().numpy().astype(np.uint32)
57
+
58
+ q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32)
59
+ for i in range(pack_factor):
60
+ q_packed |= q_w[:, i::pack_factor] << num_bits * i
61
+
62
+ q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device)
63
+
64
+ return q_packed
65
+
66
+
67
+ def get_weight_perm(num_bits: int):
68
+ perm_list: List[int] = []
69
+ for i in range(32):
70
+ perm1: List[int] = []
71
+ col = i // 4
72
+ for block in [0, 1]:
73
+ for row in [
74
+ 2 * (i % 4),
75
+ 2 * (i % 4) + 1,
76
+ 2 * (i % 4 + 4),
77
+ 2 * (i % 4 + 4) + 1,
78
+ ]:
79
+ perm1.append(16 * row + col + 8 * block)
80
+ for j in range(4):
81
+ perm_list.extend([p + 256 * j for p in perm1])
82
+
83
+ perm = np.array(perm_list)
84
+
85
+ if num_bits == 4:
86
+ interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
87
+ elif num_bits == 8:
88
+ interleave = np.array([0, 2, 1, 3])
89
+ else:
90
+ raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
91
+
92
+ perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
93
+ perm = torch.from_numpy(perm)
94
+ return perm
95
+
96
+
97
+ def marlin_quantize(
98
+ w: torch.Tensor,
99
+ quant_type: ScalarType,
100
+ group_size: int,
101
+ act_order: bool,
102
+ test_perm: Optional[torch.Tensor] = None,
103
+ ):
104
+ size_k, size_n = w.shape
105
+ num_bits = quant_type.size_bits
106
+
107
+ # Normalize group_size
108
+ if group_size == -1:
109
+ group_size = size_k
110
+ assert group_size <= size_k
111
+
112
+ # Quantize (and apply act_order if provided)
113
+ w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
114
+ w, quant_type, group_size, act_order, test_perm
115
+ )
116
+
117
+ # For act_order, sort the "weights" and "g_idx" so that group ids are
118
+ # increasing
119
+ sort_indices = torch.empty(0, dtype=torch.int, device=w.device)
120
+ if act_order:
121
+ q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
122
+
123
+ # Reformat to marlin
124
+ weight_perm = get_weight_perm(num_bits)
125
+ marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
126
+ marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
127
+
128
+ # Create result
129
+ res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
130
+ for i in range(len(res_list)):
131
+ res_list[i] = res_list[i].to(w.device)
132
+
133
+ return res_list
134
+
135
+
136
+ def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int):
137
+ size_k, size_n = w.shape
138
+
139
+ # Normalize group_size
140
+ if group_size == -1:
141
+ group_size = size_k
142
+ assert group_size <= size_k
143
+
144
+ # Detect num groups
145
+ assert size_k % group_size == 0
146
+ num_groups = size_k // group_size
147
+
148
+ # Quantize with zp
149
+ w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True)
150
+
151
+ # Reformat to marlin
152
+ weight_perm = get_weight_perm(quant_type.size_bits)
153
+ marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm)
154
+ marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
155
+ marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits)
156
+
157
+ # Create result
158
+ res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp]
159
+ for i in range(len(res_list)):
160
+ res_list[i] = res_list[i].to(w.device)
161
+
162
+ return res_list
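A hedged usage sketch of `marlin_quantize`; the shapes are illustrative, everything runs in plain PyTorch (no Marlin kernels are invoked), and the import path assumes `quantization.utils.marlin_utils_test`:

```python
# Hedged sketch: quantize a random fp16 weight to the Marlin GPTQ layout.
import torch

from quantization.scalar_type import scalar_types
from quantization.utils.marlin_utils_test import marlin_quantize

w = torch.randn(256, 256, dtype=torch.half)
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm = marlin_quantize(
    w, scalar_types.uint4b8, group_size=128, act_order=False
)
print(marlin_q_w.shape, marlin_s.shape)  # packed weight and permuted scales
```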
build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test_24.py ADDED
@@ -0,0 +1,473 @@
1
+ """Utility functions used for tests and benchmarks"""
2
+
3
+ import random
4
+ from typing import List
5
+
6
+ import numpy
7
+ import torch
8
+
9
+ from quantization.scalar_type import ScalarType
10
+
11
+ from .marlin_utils_test import marlin_weights
12
+ from .quant_utils import gptq_quantize_weights
13
+
14
+
15
+ # This is PyTorch implementation of main part of reorder_meta()
16
+ # function, from tools/util/include/cutlass/util/host_reorder.h file
17
+ # of CUTLASS source tree. Furthermore, CUTLASS template for sparse
18
+ # GEMM decides upon layout of this matrix, and at the moment for the
19
+ # sparse GEMM executed on tensor cores, this is layout described by
20
+ # ColumnMajorInterleaved<2> data structure, in
21
+ # include/cutlass/layout/matrix.h of CUTLASS source tree. The
22
+ # reordering of meta matrix into meta_reordered matrix calculated
23
+ # according to these segments of CUTLASS code is re-implemented here.
24
+ # Note that this calculation produces offsets for scattering metadata
25
+ # matrix elements into reordered metadata matrix elements (or,
26
+ # equivalently, for gathering reordered metadata matrix element back
27
+ # into metadata matrix elements).
28
+ def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device):
29
+ dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)
30
+ dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)
31
+
32
+ # Reorder the rows, then swizzle the 2x2 blocks.
33
+ group_x = 64
34
+ group_y = 32 if meta_dtype.itemsize == 2 else 16
35
+
36
+ dst_rows = (
37
+ dst_rows // group_x * group_x
38
+ + (dst_rows % 2) * 2
39
+ + (dst_rows % 8) // 4
40
+ + ((dst_rows % group_y) % 4) // 2 * 32
41
+ + ((dst_rows % group_x) // 8) * 4
42
+ )
43
+
44
+ topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8)
45
+ bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8)
46
+ dst_rows += topright - bottomleft
47
+ dst_cols -= topright - bottomleft
48
+
49
+ # Assumed that meta tensor is to be stored in CUTLASS
50
+ # InterleavedColumnMajor layout, and reverse engineered
51
+ # corresponding code to store values into this tensor.
52
+ interleave = 2
53
+ cols_maj = dst_cols // interleave
54
+ cols_min = dst_cols % interleave
55
+ return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1)
56
+
57
+
58
+ # This function converts dense matrix into sparse semi-structured
59
+ # representation, producing "compressed" matrix, in the layout used by
60
+ # CUTLASS backend, and corresponding metadata matrix.
61
+ def sparse_semi_structured_from_dense_cutlass(dense):
62
+ if dense.dim() != 2:
63
+ raise RuntimeError(
64
+ f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501
65
+ )
66
+
67
+ m, k = dense.shape
68
+ device = dense.device
69
+
70
+ meta_dtype = torch.int8
71
+ if dense.dtype == torch.int8:
72
+ meta_dtype = torch.int32
73
+ elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]:
74
+ meta_dtype = torch.int16
75
+ else:
76
+ raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
77
+ quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
78
+ if quadbits_per_meta_elem not in (4, 8):
79
+ raise RuntimeError("Invalid number of elements per meta element calculated")
80
+
81
+ if meta_dtype == torch.int32:
82
+ if m % 16 != 0:
83
+ raise RuntimeError(
84
+ f"Number of rows of dense matrix {m} must be divisible by 16"
85
+ )
86
+ else:
87
+ if m % 32 != 0:
88
+ raise RuntimeError(
89
+ f"Number of rows of dense matrix {m} must be divisible by 32"
90
+ )
91
+ if k % (4 * quadbits_per_meta_elem) != 0:
92
+ raise RuntimeError(
93
+ f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501
94
+ )
95
+
96
+ if dense.dtype != torch.float:
97
+ ksparse = 4
98
+ dense_4 = dense.view(-1, k // ksparse, ksparse)
99
+ m0, m1, m2, m3 = (dense_4 != 0).unbind(-1)
100
+ else:
101
+ ksparse = 2
102
+ dense_2 = dense.view(-1, k // ksparse, ksparse)
103
+ m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1)
104
+ meta_ncols = k // (ksparse * quadbits_per_meta_elem)
105
+
106
+ # Encoding quadruples of True/False values as follows:
107
+ # [True, True, False, False] -> 0b0100
108
+ # [True, False, True, False] -> 0b1000
109
+ # [False, True, True, False] -> 0b1001
110
+ # [True, False, False, True ] -> 0b1100
111
+ # [False, True, False, True ] -> 0b1101
112
+ # [False, False, True, True ] -> 0b1110
113
+ # Thus, lower two bits in the encoding are index of the True value
114
+ # at the lowest index in the quadruple, and the higher two bits in
115
+ # the encoding are index of the other True value in the quadruple.
116
+ # In case there are fewer than two True values, then the False value or
117
+ # values at some index or indices are considered True for the
118
+ # encoding. In case there are more than two True values, then the
119
+ # excess True value(s) at some indices are considered False for
120
+ # the encoding. The exact encodings used for these cases are as
121
+ # follows:
122
+ # [False, False, False, False] -> 0b1110
123
+ # [False, False, False, True ] -> 0b1110
124
+ # [False, False, True, False] -> 0b1110
125
+ # [False, True, False, False] -> 0b1001
126
+ # [False, True, True, True ] -> 0b1101
127
+ # [True, False, False, False] -> 0b1000
128
+ # [True, False, True, True ] -> 0b1100
129
+ # [True, True, False, True ] -> 0b0100
130
+ # [True, True, True, False] -> 0b0100
131
+ # [True, True, True, True ] -> 0b0100
132
+ # These particular encodings are chosen, with the help of Espresso
133
+ # logic minimizer software, for the purpose of minimization of
134
+ # corresponding Boolean functions, that translate non-zero flags
135
+ # into encoding bits. Note also possible choices for the first
136
+ # and last of these encodings were limited only to (0b0100,
137
+ # 0b1110), in order to produce valid encodings for 1:2 sparsity
138
+ # case.
139
+
140
+ expr0 = m0 & m1
141
+ expr1 = ~m0 & m1
142
+ expr2 = ~m0 & ~m1
143
+ bit0 = expr1
144
+ bit1 = expr2
145
+ bit2 = expr0 | expr2 | m3
146
+ bit3 = expr1 | ~m1
147
+ idxs0 = bit0 | (bit1.to(torch.int64) << 1)
148
+ idxs1 = bit2 | (bit3.to(torch.int64) << 1)
149
+
150
+ if dense.dtype != torch.float:
151
+ sparse0 = dense_4.gather(
152
+ -1, idxs0.unsqueeze(-1)
153
+ ) # type: ignore[possibly-undefined]
154
+ sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
155
+ sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
156
+ else:
157
+ sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(
158
+ m, k // 2
159
+ ) # type: ignore[possibly-undefined]
160
+
161
+ meta_4 = idxs0 | (idxs1 << 2)
162
+ meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)
163
+
164
+ if quadbits_per_meta_elem == 4:
165
+ meta = (
166
+ meta_n[:, :, 0]
167
+ | (meta_n[:, :, 1] << 4)
168
+ | (meta_n[:, :, 2] << 8)
169
+ | (meta_n[:, :, 3] << 12)
170
+ )
171
+ elif quadbits_per_meta_elem == 8:
172
+ meta = (
173
+ meta_n[:, :, 0]
174
+ | (meta_n[:, :, 1] << 4)
175
+ | (meta_n[:, :, 2] << 8)
176
+ | (meta_n[:, :, 3] << 12)
177
+ | (meta_n[:, :, 4] << 16)
178
+ | (meta_n[:, :, 5] << 20)
179
+ | (meta_n[:, :, 6] << 24)
180
+ | (meta_n[:, :, 7] << 28)
181
+ )
182
+
183
+ # Reorder meta tensor elements.
184
+ meta_reordered = meta.new_empty(
185
+ (m * meta_ncols,)
186
+ ) # type: ignore[possibly-undefined]
187
+ meta_offsets = _calculate_meta_reordering_scatter_offsets(
188
+ m, meta_ncols, meta_dtype, device
189
+ )
190
+ meta_reordered.scatter_(0, meta_offsets, meta.view(-1))
191
+
192
+ return (sparse, meta_reordered.view(m, meta_ncols))
193
+
194
+
195
+ # This function performs reverse of the function above - it
196
+ # reconstructs dense matrix from a pair of "compressed" matrix, given
197
+ # in the layout used by CUTLASS backend, and accompanying metadata
198
+ # matrix.
199
+ def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
200
+ if sparse.dim() != 2:
201
+ raise RuntimeError(
202
+ f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501
203
+ )
204
+
205
+ m, k = sparse.shape
206
+ device = sparse.device
207
+
208
+ if meta_reordered.dim() != 2:
209
+ raise RuntimeError(
210
+ f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501
211
+ )
212
+ if meta_reordered.device != device:
213
+ raise RuntimeError(
214
+ f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501
215
+ )
216
+
217
+ meta_dtype = meta_reordered.dtype
218
+ if meta_dtype not in (torch.int16, torch.int32):
219
+ raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix")
220
+ quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
221
+
222
+ ksparse = 4 if sparse.dtype != torch.float else 2
223
+
224
+ meta_nrows, meta_ncols = meta_reordered.shape
225
+ if meta_nrows != m:
226
+ raise RuntimeError(
227
+ f"Number of rows of meta matrix {meta_nrows} must be equal to number of rows of sparse matrix {m}" # noqa: E501
228
+ )
229
+ if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k:
230
+ raise RuntimeError(
231
+ f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501
232
+ "expected according to the number of columns of meta matrix"
233
+ )
234
+
235
+ # Undo meta tensor elements reordering.
236
+ meta_offsets = _calculate_meta_reordering_scatter_offsets(
237
+ m, meta_ncols, meta_dtype, device
238
+ )
239
+ meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols)
240
+
241
+ # Unpack sparse tensor back to original dense tensor, using
242
+ # information provided by meta tensor. Note that torch.float
243
+ # datatype is handled pretty much the same as
244
+ # torch.half/torch.bfloat16, as metadata for a pair of torch.float
245
+ # value is encoded as if underlying 8 bytes contain four
246
+ # torch.half/torch.bfloat16 values, where either first two or last
247
+ # two are zeros.
248
+ meta_2 = torch.empty(
249
+ (m, meta_ncols, 2 * quadbits_per_meta_elem),
250
+ dtype=meta_dtype,
251
+ device=device,
252
+ )
253
+ if quadbits_per_meta_elem == 4:
254
+ meta_2[:, :, 0] = meta & 0b11
255
+ meta_2[:, :, 1] = (meta >> 2) & 0b11
256
+ meta_2[:, :, 2] = (meta >> 4) & 0b11
257
+ meta_2[:, :, 3] = (meta >> 6) & 0b11
258
+ meta_2[:, :, 4] = (meta >> 8) & 0b11
259
+ meta_2[:, :, 5] = (meta >> 10) & 0b11
260
+ meta_2[:, :, 6] = (meta >> 12) & 0b11
261
+ meta_2[:, :, 7] = (meta >> 14) & 0b11
262
+ elif quadbits_per_meta_elem == 8:
263
+ meta_2[:, :, 0] = meta & 0b11
264
+ meta_2[:, :, 1] = (meta >> 2) & 0b11
265
+ meta_2[:, :, 2] = (meta >> 4) & 0b11
266
+ meta_2[:, :, 3] = (meta >> 6) & 0b11
267
+ meta_2[:, :, 4] = (meta >> 8) & 0b11
268
+ meta_2[:, :, 5] = (meta >> 10) & 0b11
269
+ meta_2[:, :, 6] = (meta >> 12) & 0b11
270
+ meta_2[:, :, 7] = (meta >> 14) & 0b11
271
+ meta_2[:, :, 8] = (meta >> 16) & 0b11
272
+ meta_2[:, :, 9] = (meta >> 18) & 0b11
273
+ meta_2[:, :, 10] = (meta >> 20) & 0b11
274
+ meta_2[:, :, 11] = (meta >> 22) & 0b11
275
+ meta_2[:, :, 12] = (meta >> 24) & 0b11
276
+ meta_2[:, :, 13] = (meta >> 26) & 0b11
277
+ meta_2[:, :, 14] = (meta >> 28) & 0b11
278
+ meta_2[:, :, 15] = (meta >> 30) & 0b11
279
+
280
+ dense_offsets = meta_2.view(-1) + (
281
+ torch.arange(0, 2 * m * k // ksparse, device=device) * 4
282
+ ).view(-1, 1).repeat(1, 2).view(-1)
283
+
284
+ dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device)
285
+ if sparse.dtype != torch.float:
286
+ # dense.scatter_(0, dense_offsets, sparse.view(-1))
287
+ dense.scatter_(0, dense_offsets, sparse.reshape(-1))
288
+ else:
289
+ dense.view(torch.half).scatter_(
290
+ 0, dense_offsets, sparse.view(torch.half).view(-1)
291
+ )
292
+
293
+ return dense.view(m, 2 * k)
294
+
295
+
296
+ def mask_creator(tensor):
297
+ """
298
+ Function for creating N:M sparsity masks.
299
+ Masks will be created using the N:M ratio, where for every block of
300
+ M weights, M - N will be pruned based on ranked weight magnitude. Each mask
301
+ will correspond to the given tensor.
302
+
303
+ :param tensor: The tensor to create a mask for (N = 2, M = 4 are fixed here)
304
+ :return: A mask with the same shape as the given tensor
305
+ """
306
+ N = 2
307
+ M = 4
308
+
309
+ mask = None
310
+ # for i, tensor in enumerate(tensors):
311
+ if tensor.numel() % M != 0:
312
+ raise ValueError(
313
+ f"Tensor of size {tensor.shape} can't be evenly divided into " f"groups of {M}"
314
+ )
315
+
316
+ num_groups = tensor.numel() // M
317
+
318
+ # N:M sparsity for linear layers
319
+ tensor_temp = tensor.detach().abs().reshape(num_groups, M)
320
+ index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)]
321
+
322
+ w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device)
323
+ mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape)
324
+
325
+ return mask
326
+
327
+
328
+ def inject_24(w, size_k, size_n):
329
+ assert w.shape == (size_k, size_n)
330
+
331
+ mask = mask_creator(w.t()).t().cuda().bool()
332
+
333
+ return (mask * w).contiguous(), mask.contiguous()
334
+
335
+
336
+ def check_24(w, num_rows_to_sample=50, _verbose=False):
337
+ BLOCK_SIZE = 4
338
+ MAX_NON_ZEROS = 2
339
+
340
+ w = w.t().contiguous()
341
+
342
+ print("check_24: w.shape = {}".format(w.shape))
343
+
344
+ num_rows, num_cols = w.shape
345
+ sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample)
346
+ if _verbose:
347
+ print(f"Sampled row idxs = {sampled_row_idxs}")
348
+
349
+ total_segments = 0
350
+ non_24_segments = 0
351
+ for i in sampled_row_idxs:
352
+ for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE):
353
+ total_segments += 1
354
+ block = w[i, j : j + BLOCK_SIZE]
355
+ num_nonzero = torch.count_nonzero(block)
356
+ if num_nonzero > MAX_NON_ZEROS:
357
+ print("i = {} j = {} block = {}".format(i, j, block))
358
+ non_24_segments += 1
359
+
360
+ print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.")
361
+
362
+
363
+ def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType):
364
+ assert q_24.shape == (size_k, size_n)
365
+
366
+ # Remove bias to normalize over 0
367
+ q_24_no_zp = q_24 - wtype.bias
368
+
369
+ # Compress
370
+ q_24_no_zp = q_24_no_zp.t().contiguous()
371
+ q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(q_24_no_zp)
372
+ q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous()
373
+
374
+ # Restore bias
375
+ q_24_comp = q_24_no_zp_comp + wtype.bias
376
+
377
+ # Resize meta to its actual shape (without moving any data)
378
+ meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)
379
+
380
+ return q_24_comp, meta
381
+
382
+
383
+ def get_scale_perms_24():
384
+ scale_perm: List[int] = []
385
+ for i in range(8):
386
+ scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]])
387
+ scale_perm_single: List[int] = []
388
+ for i in range(8):
389
+ scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]])
390
+ return scale_perm, scale_perm_single
391
+
392
+
393
+ def get_weight_perm_24(num_bits: int):
394
+ perm_list: List[int] = []
395
+ for i in range(32):
396
+ perm1: List[int] = []
397
+ col = i // 4
398
+ col_o = col // 2
399
+ for block in [0, 1]:
400
+ for row in [
401
+ 2 * (i % 4),
402
+ 2 * (i % 4) + 1,
403
+ 2 * (i % 4 + 4),
404
+ 2 * (i % 4 + 4) + 1,
405
+ ]:
406
+ perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block)
407
+ for j in range(4):
408
+ perm_list.extend([p + 1 * j for p in perm1])
409
+ perm = numpy.array(perm_list)
410
+
411
+ if num_bits == 4:
412
+ interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
413
+ elif num_bits == 8:
414
+ interleave = numpy.array([0, 2, 1, 3])
415
+ else:
416
+ raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))
417
+
418
+ perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
419
+ perm = torch.from_numpy(perm)
420
+ return perm
421
+
422
+
423
+ def marlin_permute_scales_24(
424
+ s: torch.Tensor, size_k: int, size_n: int, group_size: int
425
+ ) -> torch.Tensor:
426
+
427
+ scale_perm, scale_perm_single = get_scale_perms_24()
428
+ if group_size < size_k and group_size != -1:
429
+ s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
430
+ else:
431
+ s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
432
+ s = s.reshape((-1, size_n)).contiguous()
433
+
434
+ return s
435
+
436
+
437
+ def marlin_24_quantize(
438
+ w: torch.Tensor,
439
+ quant_type: ScalarType,
440
+ group_size: int,
441
+ ):
442
+ size_k, size_n = w.shape
443
+
444
+ # Normalize group_size
445
+ if group_size == -1:
446
+ group_size = size_k
447
+ assert group_size <= size_k
448
+
449
+ # Inject 2:4 sparsity
450
+ w_24, mask_24 = inject_24(w, size_k, size_n)
451
+
452
+ # Quantize
453
+ w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights(
454
+ w_24, quant_type, group_size, act_order=False
455
+ )
456
+
457
+ # Compress quantized weight
458
+ q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, quant_type)
459
+ size_k_comp = size_k // 2
460
+
461
+ # Reformat to marlin
462
+ weight_perm = get_weight_perm_24(quant_type.size_bits)
463
+ marlin_24_q_w_comp = marlin_weights(
464
+ q_w_24_comp, size_k_comp, size_n, quant_type.size_bits, weight_perm
465
+ )
466
+ marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size)
467
+
468
+ # Create result
469
+ res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s]
470
+ for i in range(len(res_list)):
471
+ res_list[i] = res_list[i].to(w.device)
472
+
473
+ return res_list
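A hedged sketch of the 2:4 mask on its own; `mask_creator` runs on CPU (while `inject_24` additionally moves data to CUDA), and the import path assumes `quantization.utils.marlin_utils_test_24`:

```python
# Hedged sketch: mask_creator keeps the 2 largest-magnitude weights per group of 4.
import torch

from quantization.utils.marlin_utils_test_24 import mask_creator

w = torch.randn(4, 8)
mask = mask_creator(w)
print(mask)
print(mask.reshape(-1, 4).sum(dim=1))  # every group of 4 keeps exactly 2 weights
```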
build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py ADDED
@@ -0,0 +1,125 @@
1
+ from typing import List
2
+
3
+ import numpy
4
+ import torch
5
+
6
+ from .marlin_utils_test import marlin_permute_weights
7
+ from .quant_utils import get_pack_factor, qqq_quantize_weights
8
+
9
+
10
+ def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size):
11
+ # Permute
12
+ q_w = marlin_permute_weights(q_w, size_k, size_n, perm)
13
+
14
+ # Pack
15
+ pack_factor = get_pack_factor(num_bits)
16
+ orig_device = q_w.device
17
+
18
+ q_w = q_w.cpu().numpy().astype(numpy.uint32)
19
+
20
+ q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor),
21
+ dtype=numpy.uint32)
22
+ if group_size == size_k:
23
+ for i in range(pack_factor):
24
+ q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i
25
+ else:
26
+ for i in range(pack_factor):
27
+ q_packed |= q_w[:, i::pack_factor] << num_bits * i
28
+
29
+ q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device)
30
+
31
+ return q_packed
32
+
33
+
34
+ def get_qqq_scale_perms():
35
+ scale_perm: List[int] = []
36
+ for i in range(8):
37
+ scale_perm.extend([i + 8 * j for j in range(8)])
38
+ scale_perm_single: List[int] = []
39
+ for i in range(4):
40
+ scale_perm_single.extend(
41
+ [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
42
+ return scale_perm, scale_perm_single
43
+
44
+
45
+ # NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. # noqa: E501
46
+ def get_qqq_weight_perm(num_bits: int, quant_type: str):
47
+ perm_list: List[int] = []
48
+ for i in range(32):
49
+ perm1: List[int] = []
50
+ col = i // 4
51
+ for block in [0, 1]:
52
+ for row in [
53
+ 4 * (i % 4),
54
+ 4 * (i % 4) + 1,
55
+ 4 * (i % 4) + 2,
56
+ 4 * (i % 4) + 3,
57
+ ]:
58
+ perm1.append(16 * row + col + 8 * block)
59
+ for j in range(4):
60
+ perm_list.extend([p + 256 * j for p in perm1])
61
+
62
+ perm = numpy.array(perm_list)
63
+
64
+ assert quant_type in ["per-channel",
65
+ "per-group"], "unsupported quantization type"
66
+ if num_bits == 4:
67
+ if quant_type == "per-channel":
68
+ interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3])
69
+ else:
70
+ interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
71
+ else:
72
+ raise Exception("num_bits must be 4, got {}".format(num_bits))
73
+
74
+ perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
75
+ perm = torch.from_numpy(perm)
76
+ return perm
77
+
78
+
79
+ def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size):
80
+ scale_perm, scale_perm_single = get_qqq_scale_perms()
81
+ if group_size < size_k and group_size != -1:
82
+ s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm]
83
+ s_channel = s_channel.reshape(
84
+ (-1, len(scale_perm_single)))[:, scale_perm_single]
85
+ s_group = s_group.reshape((-1, size_n)).contiguous()
86
+ else:
87
+ s_channel = s_channel.reshape(
88
+ (-1, len(scale_perm_single)))[:, scale_perm_single]
89
+ s_channel = s_channel.reshape((-1, size_n)).contiguous()
90
+
91
+ return s_group, s_channel
92
+
93
+
94
+ def marlin_qqq_quantize(
95
+ w: torch.Tensor,
96
+ num_bits: int,
97
+ group_size: int,
98
+ ):
99
+ size_k, size_n = w.shape
100
+
101
+ # Normalize group_size
102
+ if group_size == -1:
103
+ group_size = size_k
104
+ assert group_size <= size_k
105
+ quant_type = "per-channel" if group_size == size_k else "per-group"
106
+
107
+ # Quantize
108
+ w_ref, q_w, s_group, s_channel = qqq_quantize_weights(
109
+ w, num_bits, group_size)
110
+
111
+ # Reformat to marlin_qqq
112
+ weight_perm = get_qqq_weight_perm(num_bits, quant_type)
113
+ marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits,
114
+ weight_perm, group_size)
115
+ marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales(
116
+ s_group, s_channel, size_k, size_n, group_size)
117
+
118
+ # Create result
119
+ res_list = [
120
+ w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel
121
+ ]
122
+ for i in range(len(res_list)):
123
+ res_list[i] = res_list[i].to(w.device)
124
+
125
+ return res_list
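A hedged usage sketch of `marlin_qqq_quantize` in per-group mode; it runs in plain PyTorch, the shapes are illustrative, and the import path assumes `quantization.utils.marlin_utils_test_qqq`:

```python
# Hedged sketch: QQQ quantization of a random fp16 weight (4-bit, group size 128).
import torch

from quantization.utils.marlin_utils_test_qqq import marlin_qqq_quantize

w = torch.randn(256, 256, dtype=torch.half)
w_ref, q_w, s_group, s_channel = marlin_qqq_quantize(w, num_bits=4, group_size=128)
print(q_w.shape, s_group.shape, s_channel.shape)
```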
build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/quant_utils.py ADDED
@@ -0,0 +1,470 @@
1
+ """This file is used for /tests and /benchmarks"""
2
+
3
+ from typing import List, Optional
4
+
5
+ import numpy
6
+ import torch
7
+
8
+ from quantization.scalar_type import ScalarType, scalar_types
9
+
10
+ SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
11
+ SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
12
+
13
+ MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]
14
+
15
+ # Note: this is a hack. We should update each model to register the
16
+ # stacked params and get it from there instead in a future PR.
17
+ # fused_name: List[shard_name]
18
+ FUSED_LAYER_NAME_MAPPING = {
19
+ "qkv_proj": ["q_proj", "k_proj", "v_proj"],
20
+ "gate_up_proj": ["gate_proj", "up_proj"],
21
+ }
22
+
23
+
24
+ def pack_quantized_values_into_int32(
25
+ w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
26
+ ):
27
+ # move dim to pack to the end
28
+ perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
29
+ inv_perm = tuple(perm.index(i) for i in range(len(perm)))
30
+ w_q_perm = w_q.permute(perm)
31
+
32
+ pack_factor = 32 // wtype.size_bits
33
+ mask = (1 << wtype.size_bits) - 1
34
+
35
+ new_shape_perm = list(w_q_perm.shape)
36
+ assert w_q_perm.shape[-1] % pack_factor == 0
37
+ new_shape_perm[-1] //= pack_factor
38
+
39
+ res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
40
+ for i in range(pack_factor):
41
+ res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i
42
+
43
+ return res.permute(inv_perm)
44
+
45
+
46
+ def unpack_quantized_values_into_int32(
47
+ w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
48
+ ):
49
+ # move dim to pack to the end
50
+ perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
51
+ inv_perm = tuple(perm.index(i) for i in range(len(perm)))
52
+ w_q_perm = w_q.permute(perm)
53
+
54
+ pack_factor = 32 // wtype.size_bits
55
+ mask = (1 << wtype.size_bits) - 1
56
+
57
+ new_shape_perm = list(w_q_perm.shape)
58
+ new_shape_perm[-1] *= pack_factor
59
+
60
+ res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
61
+ for i in range(pack_factor):
62
+ res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask
63
+
64
+ return res.permute(inv_perm)
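A hedged round-trip sketch of the two helpers above; values must fit within the type's bit width for the inverse to hold, and the import path assumes `quantization.utils.quant_utils`:

```python
# Hedged sketch: pack/unpack 4-bit values along dim 0 (8 nibbles per int32).
import torch

from quantization.scalar_type import scalar_types
from quantization.utils.quant_utils import (
    pack_quantized_values_into_int32,
    unpack_quantized_values_into_int32,
)

w_q = torch.randint(0, 16, (16, 8), dtype=torch.int32)
packed = pack_quantized_values_into_int32(w_q, scalar_types.uint4, packed_dim=0)
print(packed.shape)  # torch.Size([2, 8])
unpacked = unpack_quantized_values_into_int32(packed, scalar_types.uint4, packed_dim=0)
assert torch.equal(unpacked, w_q)
```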
65
+
66
+
67
+ def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
68
+ # prefix: model.layers.0.self_attn.q_proj
69
+ # proj_name: q_proj
70
+ proj_name = prefix.split(".")[-1]
71
+ if proj_name in FUSED_LAYER_NAME_MAPPING:
72
+ shard_prefixes = [
73
+ prefix.replace(proj_name, shard_proj_name)
74
+ for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name]
75
+ ]
76
+
77
+ is_skipped = None
78
+ for shard_prefix in shard_prefixes:
79
+ is_shard_skipped = shard_prefix in ignored_layers
80
+
81
+ if is_skipped is None:
82
+ is_skipped = is_shard_skipped
83
+ elif is_shard_skipped != is_skipped:
84
+ raise ValueError(
85
+ f"Detected some but not all shards of {prefix} "
86
+ "are quantized. All shards of fused layers "
87
+ "to have the same precision."
88
+ )
89
+ else:
90
+ is_skipped = prefix in ignored_layers
91
+
92
+ assert is_skipped is not None
93
+ return is_skipped
94
+
95
+
96
+ def get_pack_factor(num_bits):
97
+ assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
98
+ return 32 // num_bits
99
+
100
+
101
+ def permute_rows(
102
+ q_w: torch.Tensor,
103
+ w_ref: torch.Tensor,
104
+ group_size: int,
105
+ test_perm: Optional[torch.Tensor] = None,
106
+ ):
107
+ assert q_w.shape == w_ref.shape
108
+
109
+ orig_device = q_w.device
110
+ k_size, _ = q_w.shape
111
+
112
+ g_idx = torch.zeros((k_size,), dtype=torch.int32)
113
+ for i in range(k_size):
114
+ g_idx[i] = i // group_size
115
+
116
+ # Simulate act_order by doing a random permutation on K
117
+ rand_perm = test_perm if test_perm is not None else torch.randperm(k_size)
118
+
119
+ g_idx = g_idx[rand_perm].contiguous()
120
+ q_w = q_w[rand_perm, :].contiguous()
121
+ w_ref = w_ref[rand_perm, :].contiguous()
122
+
123
+ return (
124
+ w_ref.to(device=orig_device),
125
+ q_w.to(device=orig_device),
126
+ g_idx.to(device=orig_device),
127
+ rand_perm.to(device=orig_device),
128
+ )
129
+
130
+
131
+ def quantize_weights(
132
+ w: torch.Tensor,
133
+ quant_type: ScalarType,
134
+ group_size: Optional[int],
135
+ zero_points: bool = False,
136
+ ref_zero_points_after_scales: bool = False,
137
+ ):
138
+ assert (
139
+ quant_type.is_integer()
140
+ ), "Floating point quantization may work but has not been tested"
141
+ assert not zero_points or group_size is not None, (
142
+ "to have group zero points, group_size must be provided "
143
+ "(-1 group_size is channelwise)"
144
+ )
145
+
146
+ orig_device = w.device
147
+ orig_type = w.dtype
148
+ size_k, size_n = w.shape
149
+
150
+ assert w.is_floating_point(), "w must be float"
151
+
152
+ if group_size == -1:
153
+ group_size = size_k
154
+
155
+ # Reshape to [groupsize, -1]
156
+ if group_size is not None and group_size < size_k:
157
+ w = w.reshape((-1, group_size, size_n))
158
+ w = w.permute(1, 0, 2)
159
+ w = w.reshape((group_size, -1))
160
+
161
+ # Compute scale for each group
162
+ max_val = torch.max(w, 0, keepdim=True).values
163
+ min_val = torch.min(w, 0, keepdim=True).values
164
+
165
+ max_q_val = quant_type.max()
166
+ min_q_val = quant_type.min()
167
+
168
+ w_s = torch.Tensor([1.0]).to(w.device) # unscaled case
169
+ maybe_w_zp = None
170
+ if group_size is not None:
171
+ if zero_points:
172
+ assert not quant_type.is_signed() and quant_type.max() > 0
173
+ w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
174
+ maybe_w_zp = (
175
+ torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
176
+ )
177
+ else:
178
+ # If the bias is such that there are no possible negative/positive
179
+ # values, set the max value to inf to avoid divide by 0
180
+ w_s = torch.max(
181
+ abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
182
+ abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
183
+ )
184
+
185
+ # Quantize
186
+ w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
187
+ w_q = torch.clamp(w_q, min_q_val, max_q_val)
188
+
189
+ # Compute ref (dequantized)
190
+ # For some kernels (namely Machete) the zero-points are applied after the
191
+ # scales are applied, for this case computing the reference in similar way
192
+ # allows us to use tighter error tolerances in our unit tests.
193
+ if ref_zero_points_after_scales and maybe_w_zp is not None:
194
+ w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
195
+ else:
196
+ w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
197
+
198
+ if quant_type.has_bias():
199
+ w_q += quant_type.bias
200
+
201
+ # Restore original shapes
202
+ if group_size is not None and group_size < size_k:
203
+
204
+ def reshape_w(w):
205
+ w = w.reshape((group_size, -1, size_n))
206
+ w = w.permute(1, 0, 2)
207
+ w = w.reshape((size_k, size_n)).contiguous()
208
+ return w
209
+
210
+ w_q = reshape_w(w_q)
211
+ w_ref = reshape_w(w_ref)
212
+ w_s = w_s.reshape((-1, size_n)).contiguous()
213
+
214
+ if maybe_w_zp is not None:
215
+ maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
216
+ maybe_w_zp = maybe_w_zp.to(device=orig_device)
217
+
218
+ return (
219
+ w_ref.to(device=orig_device),
220
+ w_q.to(device=orig_device),
221
+ w_s if group_size is not None else None,
222
+ maybe_w_zp,
223
+ )
224
+
225
+
226
+ def gptq_quantize_weights(
227
+ w: torch.Tensor,
228
+ quant_type: ScalarType,
229
+ group_size: int,
230
+ act_order: bool,
231
+ test_perm: Optional[torch.Tensor] = None,
232
+ ):
233
+ size_k, _ = w.shape
234
+
235
+ assert w.is_floating_point(), "w must be float"
236
+ assert (
237
+ quant_type in SUPPORTED_GPTQ_QUANT_TYPES
238
+ ), f"Unsupported gptq type = {quant_type}"
239
+ assert group_size in SUPPORTED_GROUP_SIZES + [
240
+ size_k
241
+ ], f"Unsupported groupsize = {group_size}"
242
+
243
+ w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size)
244
+
245
+ # Apply act_order
246
+ g_idx = torch.empty(0, dtype=torch.int, device=w.device)
247
+ rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
248
+ if act_order:
249
+ assert (
250
+ group_size < size_k
251
+ ), "For act_order, groupsize = {} must be less than size_k = {}".format(
252
+ group_size, size_k
253
+ )
254
+
255
+ w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm)
256
+
257
+ return w_ref, w_q, w_s, g_idx, rand_perm
258
+
259
+
260
+ # QQQ employs different quant schemes for per-group and
261
+ # per-channel quantization.
262
+ def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int):
263
+ orig_device = w.device
264
+ size_k, size_n = w.shape
265
+
266
+ assert w.is_floating_point(), "w must be float"
267
+ assert (
268
+ num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS
269
+ ), f"Unsupported num_bits = {num_bits}"
270
+ assert group_size in SUPPORTED_GROUP_SIZES + [
271
+ size_k
272
+ ], f"Unsupported groupsize = {group_size}"
273
+
274
+ if group_size == -1:
275
+ group_size = size_k
276
+ assert group_size <= size_k
277
+
278
+ if group_size < size_k:
279
+ # Reshape to [groupsize, -1]
280
+ w = w.reshape((-1, group_size, size_n))
281
+ w = w.permute(1, 0, 2)
282
+ w = w.reshape((group_size, -1))
283
+
284
+ max_q_val = 2**num_bits - 1
285
+ half_q_val = (max_q_val + 1) // 2
286
+
287
+ # Compute scale for each group
288
+ s_group = torch.max(torch.abs(w), 0, keepdim=True)[0]
289
+ s_group *= 2 / max_q_val # 2 => symmetric
290
+
291
+ # Quantize
292
+ q_w = torch.round(w / s_group).int()
293
+ q_w += half_q_val
294
+ q_w = torch.clamp(q_w, 0, max_q_val)
295
+ # Compute ref (dequantized)
296
+ w_ref = (q_w - half_q_val).half() * s_group
297
+
298
+ # Restore original shapes
299
+ def reshape_w(w):
300
+ w = w.reshape((group_size, -1, size_n))
301
+ w = w.permute(1, 0, 2)
302
+ w = w.reshape((size_k, size_n)).contiguous()
303
+ return w
304
+
305
+ q_w = reshape_w(q_w)
306
+ w_ref = reshape_w(w_ref)
307
+
308
+ # Compute int8 quantization scale for each channel
309
+ s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0]
310
+ s_channel /= 127.0
311
+ t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8)
312
+ w_ref = t_int8.half() * s_channel
313
+ s_channel = s_channel.reshape(1, -1).to(dtype=torch.float)
314
+
315
+ # Fuse scales
316
+ s_group = (s_group.reshape(-1, size_n).contiguous() / s_channel).to(
317
+ dtype=torch.half
318
+ )
319
+ else:
320
+ max_q_val = 2 ** (num_bits - 1) - 1
321
+
322
+ # Compute scale for each channel
323
+ s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0]
324
+ s_channel /= max_q_val
325
+
326
+ # Quantize
327
+ q_w = torch.round(w / s_channel).int()
328
+ q_w = torch.clamp(q_w, -max_q_val, max_q_val)
329
+ # Compute ref (dequantized)
330
+ w_ref = q_w.half() * s_channel
331
+
332
+ s_group = torch.tensor([], dtype=torch.half)
333
+ # div 2 ** (8 - self.bits)) to offset right shift in unpacking
334
+ s_channel /= 2 ** (8 - num_bits)
335
+ s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float)
336
+
337
+ return (
338
+ w_ref.to(device=orig_device),
339
+ q_w.to(device=orig_device),
340
+ s_group.to(device=orig_device),
341
+ s_channel.to(device=orig_device),
342
+ )
343
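A small sanity sketch for the per-group branch above, assuming the same import path and that 4 is among MARLIN_QQQ_SUPPORTED_NUM_BITS:

import torch

from quantization.utils.quant_utils import qqq_quantize_weights  # assumed import path

w = torch.randn(256, 64, dtype=torch.half)
w_ref, q_w, s_group, s_channel = qqq_quantize_weights(w, num_bits=4, group_size=64)
# Unsigned 4-bit codes carry an implicit offset of half_q_val = 8.
assert int(q_w.min()) >= 0 and int(q_w.max()) <= 15
# s_group is the per-group half-precision scale already divided by the int8
# channel scale; s_channel is the float32 per-channel scale of the int8 stage.
assert s_group.shape == (256 // 64, 64) and s_channel.shape == (1, 64)
assert s_group.dtype == torch.half and s_channel.dtype == torch.float32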
+
344
+
345
+ def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
346
+ orig_device = q_w.device
347
+
348
+ sort_indices = torch.argsort(g_idx).to(dtype=torch.int32) # Sort based on g_idx
349
+
350
+ g_idx = g_idx[sort_indices].contiguous()
351
+ q_w = q_w[sort_indices, :].contiguous()
352
+
353
+ return (
354
+ q_w.to(device=orig_device),
355
+ g_idx.to(device=orig_device),
356
+ sort_indices.to(device=orig_device),
357
+ )
358
+
359
+
360
+ def pack_rows(
361
+ q_w: torch.Tensor,
362
+ num_bits: int,
363
+ size_k: int,
364
+ size_n: int,
365
+ ):
366
+ assert q_w.shape == (size_k, size_n)
367
+
368
+ pack_factor = get_pack_factor(num_bits)
369
+ assert size_k % pack_factor == 0
370
+
371
+ orig_device = q_w.device
372
+
373
+ q_w = q_w.cpu().numpy().astype(numpy.uint32)
374
+
375
+ q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)
376
+
377
+ for i in range(pack_factor):
378
+ q_res |= q_w[i::pack_factor, :] << num_bits * i
379
+
380
+ q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
381
+ return q_res
382
+
383
+
384
+ def pack_cols(
385
+ q_w: torch.Tensor,
386
+ num_bits: int,
387
+ size_k: int,
388
+ size_n: int,
389
+ ):
390
+ assert q_w.shape == (size_k, size_n)
391
+
392
+ pack_factor = get_pack_factor(num_bits)
393
+ assert size_n % pack_factor == 0
394
+
395
+ orig_device = q_w.device
396
+
397
+ q_w = q_w.cpu().numpy().astype(numpy.uint32)
398
+
399
+ q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
400
+
401
+ for i in range(pack_factor):
402
+ q_res |= q_w[:, i::pack_factor] << num_bits * i
403
+
404
+ q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
405
+ q_res = q_res.contiguous()
406
+
407
+ return q_res
408
+
409
+
410
+ def unpack_cols(
411
+ packed_q_w: torch.Tensor,
412
+ num_bits: int,
413
+ size_k: int,
414
+ size_n: int,
415
+ ):
416
+ pack_factor = get_pack_factor(num_bits)
417
+ assert size_n % pack_factor == 0
418
+ assert packed_q_w.shape == (
419
+ size_k,
420
+ size_n // pack_factor,
421
+ ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_factor = {}".format(
422
+ packed_q_w.shape, size_k, size_n, pack_factor
423
+ )
424
+
425
+ orig_device = packed_q_w.device
426
+
427
+ packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
428
+ q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
429
+
430
+ mask = (1 << num_bits) - 1
431
+ for i in range(pack_factor):
432
+ vals = packed_q_w_cpu & mask
433
+ packed_q_w_cpu >>= num_bits
434
+ q_res[:, i::pack_factor] = vals
435
+
436
+ q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
437
+ q_res = q_res.contiguous()
438
+
439
+ return q_res
440
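The column packers are easiest to check as a round trip; a minimal sketch, assuming the same import path and that `get_pack_factor(num_bits)` evaluates to `32 // num_bits`, as used throughout this file:

import torch

from quantization.utils.quant_utils import pack_cols, unpack_cols  # assumed import path

num_bits, size_k, size_n = 4, 16, 64
q_w = torch.randint(0, 2**num_bits, (size_k, size_n), dtype=torch.int32)

packed = pack_cols(q_w, num_bits, size_k, size_n)
# Each group of 32 // 4 = 8 adjacent columns is packed into one int32 column.
assert packed.shape == (size_k, size_n // 8)
# unpack_cols shifts the bit fields back out and restores the original codes.
assert torch.equal(unpack_cols(packed, num_bits, size_k, size_n), q_w)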
+
441
+
442
+ def gptq_pack(
443
+ q_w: torch.Tensor,
444
+ num_bits: int,
445
+ size_k: int,
446
+ size_n: int,
447
+ ):
448
+ return pack_rows(q_w, num_bits, size_k, size_n)
449
+
450
+
451
+ def awq_pack(
452
+ q_w: torch.Tensor,
453
+ num_bits: int,
454
+ size_k: int,
455
+ size_n: int,
456
+ ):
457
+ assert q_w.shape == (size_k, size_n)
458
+
459
+ # Interleave column dim (for the dequantize code) and pack it to int32
460
+ if num_bits == 4:
461
+ interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
462
+ elif num_bits == 8:
463
+ interleave = numpy.array([0, 2, 1, 3])
464
+ else:
465
+ raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
466
+
467
+ q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel()
468
+ q_w = q_w.reshape((-1, size_n)).contiguous()
469
+
470
+ return pack_cols(q_w, num_bits, size_k, size_n)
build/torch24-cxx98-cu118-x86_64-linux/quantization/__init__.py CHANGED
@@ -1,150 +1,30 @@
1
- from typing import Optional, Tuple
2
-
3
- import torch
4
-
5
- try:
6
- from ._ops import ops
7
- except ImportError as e:
8
- # Fallback for local development.
9
- try:
10
- import _quantization
11
- ops = torch.ops._quantization
12
- except ImportError:
13
- raise e
14
-
15
- def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
16
- return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
17
-
18
- def cutlass_scaled_mm(a: torch.Tensor,
19
- b: torch.Tensor,
20
- scale_a: torch.Tensor,
21
- scale_b: torch.Tensor,
22
- out_dtype: torch.dtype,
23
- bias: Optional[torch.Tensor] = None) -> torch.Tensor:
24
- assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
25
- assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
26
- assert bias is None or bias.shape[0] == b.shape[
27
- 1] and bias.dtype == out_dtype
28
-
29
- m = a.shape[0]
30
- n = b.shape[1]
31
-
32
- #if current_platform.is_rocm():
33
- # triton_scaled_mm_module = importlib.import_module(
34
- # "vllm.model_executor.layers.quantization.compressed_tensors."
35
- # "triton_scaled_mm")
36
- # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
37
- # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
38
-
39
- out = torch.empty((m, n), dtype=out_dtype, device=a.device)
40
-
41
- ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
42
-
43
- return out
44
-
45
- # fp8
46
- def scaled_fp8_quant(
47
- input: torch.Tensor,
48
- scale: Optional[torch.Tensor] = None,
49
- num_token_padding: Optional[int] = None,
50
- scale_ub: Optional[torch.Tensor] = None,
51
- use_per_token_if_dynamic: bool = False,
52
- ) -> Tuple[torch.Tensor, torch.Tensor]:
53
- """
54
- Quantize input tensor to FP8 and return quantized tensor and scale.
55
-
56
- This function supports both static and dynamic quantization: If you
57
- provide the scale, it will use static scaling and if you omit it,
58
- the scale will be determined dynamically. The function also allows
59
- optional padding of the output tensors for downstream kernels that
60
- will benefit from padding.
61
-
62
- Args:
63
- input: The input tensor to be quantized to FP8
64
- scale: Optional scaling factor for the FP8 quantization
65
- scale_ub: Optional upper bound for scaling factor in dynamic
66
- per token case
67
- num_token_padding: If specified, pad the first dimension
68
- of the output to at least this value.
69
- use_per_token_if_dynamic: Whether to do per_tensor or per_token
70
- in the dynamic quantization case.
71
-
72
- Returns:
73
- Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
74
- scaling factor.
75
- """
76
- # This code assumes batch_dim and num_tokens are flattened
77
- assert (input.ndim == 2)
78
- shape: Union[Tuple[int, int], torch.Size] = input.shape
79
- # For rocm, the output fp8 dtype is torch.float_e3m3fnuz
80
- #out_dtype: torch.dtype = torch.float8_e4m3fnuz \
81
- # if current_platform.is_rocm() else torch.float8_e4m3fn
82
- out_dtype = torch.float8_e4m3fn
83
- if num_token_padding:
84
- shape = (max(num_token_padding, input.shape[0]), shape[1])
85
- output = torch.empty(shape, device=input.device, dtype=out_dtype)
86
-
87
- if scale is None:
88
- if use_per_token_if_dynamic:
89
- scale = torch.empty((shape[0], 1),
90
- device=input.device,
91
- dtype=torch.float32)
92
- ops.dynamic_per_token_scaled_fp8_quant(
93
- output, input, scale, scale_ub)
94
- else:
95
- scale = torch.zeros(1, device=input.device, dtype=torch.float32)
96
- ops.dynamic_scaled_fp8_quant(output, input, scale)
97
- else:
98
- # num_token_padding not implemented for this case
99
- assert (scale.numel() == 1 or num_token_padding is None)
100
- ops.static_scaled_fp8_quant(output, input, scale)
101
-
102
- return output, scale
103
-
104
- # int8
105
- def scaled_int8_quant(
106
- input: torch.Tensor,
107
- scale: Optional[torch.Tensor] = None,
108
- azp: Optional[torch.Tensor] = None,
109
- symmetric: bool = True
110
- ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
111
- """
112
- Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.
113
-
114
- Args:
115
- input: The input tensor to be quantized to int8.
116
- scale: Optional scaling factor for the int8 quantization.
117
- When not provided, we invoke dynamic-per-token quantization.
118
- azp: Optional zero-point for the int8 quantization.
119
- Must be provided for asymmetric quantization if `scale` is provided.
120
- symmetric: Whether to use symmetric quantization (scale only, azp ignored).
121
-
122
- Returns:
123
- Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
124
- """
125
- output = torch.empty_like(input, dtype=torch.int8)
126
- if scale is not None:
127
- # static-per-tensor quantization.
128
- assert symmetric == (
129
- azp is
130
- None), "azp must only be provided for asymmetric quantization."
131
- ops.static_scaled_int8_quant(output, input, scale, azp)
132
- return output, scale, azp
133
-
134
- # dynamic-per-token quantization.
135
- input_scales = torch.empty((input.numel() // input.shape[-1], 1),
136
- device=input.device,
137
- dtype=torch.float32)
138
- input_azp = None if symmetric else torch.empty_like(input_scales,
139
- dtype=torch.int32)
140
- ops.dynamic_scaled_int8_quant(output, input, input_scales,
141
- input_azp)
142
- return output, input_scales, input_azp
143
-
144
- # fp8 marlin
145
- def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
146
- b_scales: torch.Tensor, workspace: torch.Tensor,
147
- num_bits: int, size_m: int, size_n: int,
148
- size_k: int) -> torch.Tensor:
149
- return ops.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace,
150
- num_bits, size_m, size_n, size_k)
 
1
+ from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant
2
+ from .cutlass import (
3
+ cutlass_scaled_mm_supports_fp8,
4
+ cutlass_scaled_mm,
5
+ cutlass_scaled_mm_azp,
6
+ )
7
+ from .marlin import (
8
+ awq_marlin_repack,
9
+ fp8_marlin_gemm,
10
+ gptq_marlin_gemm,
11
+ gptq_marlin_repack,
12
+ gptq_marlin_24_gemm,
13
+ marlin_qqq_gemm,
14
+ marlin_gemm,
15
+ )
16
+
17
+ __all__ = [
18
+ "awq_marlin_repack",
19
+ "cutlass_scaled_mm",
20
+ "cutlass_scaled_mm_azp",
21
+ "cutlass_scaled_mm_supports_fp8",
22
+ "fp8_marlin_gemm",
23
+ "gptq_marlin_24_gemm",
24
+ "gptq_marlin_gemm",
25
+ "gptq_marlin_repack",
26
+ "marlin_gemm",
27
+ "marlin_qqq_gemm",
28
+ "scaled_fp8_quant",
29
+ "scaled_int8_quant",
30
+ ]
 
build/torch24-cxx98-cu118-x86_64-linux/quantization/_ops.py CHANGED
@@ -1,3 +1,9 @@
1
  import torch
2
  from . import _quantization_0_0_1
3
  ops = torch.ops._quantization_0_0_1
 
 
 
 
 
 
 
1
  import torch
2
  from . import _quantization_0_0_1
3
  ops = torch.ops._quantization_0_0_1
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_quantization_0_0_1::{op_name}"
build/torch24-cxx98-cu118-x86_64-linux/quantization/_quantization_0_0_1.abi3.so CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:bfb228c577303d6d981b88ab6955447786f88e4a551b42b3ab13225fb96aa81b
3
- size 70283536
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:734f235fc2749269910ee4e988da205a9442edf73c0f9b3ef41fff100bc66707
3
+ size 85709024
build/torch24-cxx98-cu118-x86_64-linux/quantization/compressed_tensors.py ADDED
@@ -0,0 +1,110 @@
 
1
+ from typing import Optional, Tuple, Union
2
+
3
+ import torch
4
+
5
+ try:
6
+ from ._ops import ops
7
+ except ImportError as e:
8
+ # Fallback for local development.
9
+ try:
10
+ import _quantization
11
+
12
+ ops = torch.ops._quantization
13
+ except ImportError:
14
+ raise e
15
+
16
+
17
+ # fp8
18
+ def scaled_fp8_quant(
19
+ input: torch.Tensor,
20
+ scale: Optional[torch.Tensor] = None,
21
+ num_token_padding: Optional[int] = None,
22
+ scale_ub: Optional[torch.Tensor] = None,
23
+ use_per_token_if_dynamic: bool = False,
24
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
25
+ """
26
+ Quantize input tensor to FP8 and return quantized tensor and scale.
27
+
28
+ This function supports both static and dynamic quantization: If you
29
+ provide the scale, it will use static scaling and if you omit it,
30
+ the scale will be determined dynamically. The function also allows
31
+ optional padding of the output tensors for downstream kernels that
32
+ will benefit from padding.
33
+
34
+ Args:
35
+ input: The input tensor to be quantized to FP8
36
+ scale: Optional scaling factor for the FP8 quantization
37
+ scale_ub: Optional upper bound for scaling factor in dynamic
38
+ per token case
39
+ num_token_padding: If specified, pad the first dimension
40
+ of the output to at least this value.
41
+ use_per_token_if_dynamic: Whether to do per_tensor or per_token
42
+ in the dynamic quantization case.
43
+
44
+ Returns:
45
+ Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
46
+ scaling factor.
47
+ """
48
+ # This code assumes batch_dim and num_tokens are flattened
49
+ assert input.ndim == 2
50
+ shape: Union[Tuple[int, int], torch.Size] = input.shape
51
+ # For rocm, the output fp8 dtype is torch.float8_e4m3fnuz
52
+ # out_dtype: torch.dtype = torch.float8_e4m3fnuz \
53
+ # if current_platform.is_rocm() else torch.float8_e4m3fn
54
+ out_dtype = torch.float8_e4m3fn
55
+ if num_token_padding:
56
+ shape = (max(num_token_padding, input.shape[0]), shape[1])
57
+ output = torch.empty(shape, device=input.device, dtype=out_dtype)
58
+
59
+ if scale is None:
60
+ if use_per_token_if_dynamic:
61
+ scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
62
+ ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub)
63
+ else:
64
+ scale = torch.zeros(1, device=input.device, dtype=torch.float32)
65
+ ops.dynamic_scaled_fp8_quant(output, input, scale)
66
+ else:
67
+ # num_token_padding not implemented for this case
68
+ assert scale.numel() == 1 or num_token_padding is None
69
+ ops.static_scaled_fp8_quant(output, input, scale)
70
+
71
+ return output, scale
72
+
73
+
74
+ # int8
75
+ def scaled_int8_quant(
76
+ input: torch.Tensor,
77
+ scale: Optional[torch.Tensor] = None,
78
+ azp: Optional[torch.Tensor] = None,
79
+ symmetric: bool = True,
80
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
81
+ """
82
+ Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.
83
+
84
+ Args:
85
+ input: The input tensor to be quantized to int8.
86
+ scale: Optional scaling factor for the int8 quantization.
87
+ When not provided, we invoke dynamic-per-token quantization.
88
+ azp: Optional zero-point for the int8 quantization.
89
+ Must be provided for asymmetric quantization if `scale` is provided.
90
+ symmetric: Whether to use symmetric quantization (scale only, azp ignored).
91
+
92
+ Returns:
93
+ Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
94
+ """
95
+ output = torch.empty_like(input, dtype=torch.int8)
96
+ if scale is not None:
97
+ # static-per-tensor quantization.
98
+ assert symmetric == (
99
+ azp is None
100
+ ), "azp must only be provided for asymmetric quantization."
101
+ ops.static_scaled_int8_quant(output, input, scale, azp)
102
+ return output, scale, azp
103
+
104
+ # dynamic-per-token quantization.
105
+ input_scales = torch.empty(
106
+ (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32
107
+ )
108
+ input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32)
109
+ ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp)
110
+ return output, input_scales, input_azp
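A matching sketch for the int8 path, under the same assumptions as the FP8 sketch above:

import torch

from quantization import scaled_int8_quant  # assumed package name

x = torch.randn(32, 4096, dtype=torch.float16, device="cuda")

# Dynamic per-token symmetric quantization: one float32 scale per row and
# no zero-point.
x_i8, x_scales, x_azp = scaled_int8_quant(x)
assert x_i8.dtype == torch.int8
assert x_scales.shape == (32, 1) and x_azp is None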
build/torch24-cxx98-cu118-x86_64-linux/quantization/cutlass.py ADDED
@@ -0,0 +1,75 @@
 
1
+ from typing import Optional
2
+
3
+ import torch
4
+
5
+ try:
6
+ from ._ops import ops
7
+ except ImportError as e:
8
+ # Fallback for local development.
9
+ try:
10
+ import _quantization
11
+
12
+ ops = torch.ops._quantization
13
+ except ImportError:
14
+ raise e
15
+
16
+
17
+ def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
18
+ return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
19
+
20
+
21
+ def cutlass_scaled_mm(
22
+ a: torch.Tensor,
23
+ b: torch.Tensor,
24
+ scale_a: torch.Tensor,
25
+ scale_b: torch.Tensor,
26
+ out_dtype: torch.dtype,
27
+ bias: Optional[torch.Tensor] = None,
28
+ ) -> torch.Tensor:
29
+ assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
30
+ assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
31
+ assert bias is None or bias.shape[0] == b.shape[1] and bias.dtype == out_dtype
32
+
33
+ m = a.shape[0]
34
+ n = b.shape[1]
35
+
36
+ # if current_platform.is_rocm():
37
+ # triton_scaled_mm_module = importlib.import_module(
38
+ # "vllm.model_executor.layers.quantization.compressed_tensors."
39
+ # "triton_scaled_mm")
40
+ # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
41
+ # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
42
+
43
+ out = torch.empty((m, n), dtype=out_dtype, device=a.device)
44
+
45
+ ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
46
+
47
+ return out
48
+
49
+
50
+ def cutlass_scaled_mm_azp(
51
+ a: torch.Tensor,
52
+ b: torch.Tensor,
53
+ scale_a: torch.Tensor,
54
+ scale_b: torch.Tensor,
55
+ out_dtype: torch.dtype,
56
+ azp_adj: torch.Tensor,
57
+ azp: Optional[torch.Tensor] = None,
58
+ bias: Optional[torch.Tensor] = None,
59
+ ) -> torch.Tensor:
60
+ """
61
+ :param azp_adj: In the per-tensor case, this should include the azp.
62
+ Always per-channel.
63
+ :param azp: Only set in the per-token case. Per-token if set.
64
+ """
65
+ assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
66
+ assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
67
+ assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype
68
+ assert azp is None or azp.numel() == a.shape[0]
69
+
70
+ m = a.shape[0]
71
+ n = b.shape[1]
72
+ out = torch.empty((m, n), dtype=out_dtype, device=a.device)
73
+
74
+ ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias)
75
+ return out
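For orientation, a hedged sketch combining the entry points above with the int8 quantizer. It assumes the compiled extension, a CUDA device, and that `b` is passed column-major (as the transpose of a contiguous tensor) with per-channel scales of shape (1, n); it has not been validated against the kernel's stride checks:

import torch

from quantization import cutlass_scaled_mm, scaled_int8_quant  # assumed package name

m, k, n = 32, 512, 256
a = torch.randn(m, k, dtype=torch.float16, device="cuda")
b = torch.randn(k, n, dtype=torch.float16, device="cuda")

a_q, a_s, _ = scaled_int8_quant(a)                   # per-token scales, shape (m, 1)
b_q, b_s, _ = scaled_int8_quant(b.t().contiguous())  # per-output-channel scales, (n, 1)

# b_q.t() is a column-major int8 view of shape (k, n); scales are float32.
out = cutlass_scaled_mm(a_q, b_q.t(), a_s, b_s.t(), out_dtype=torch.float16)
assert out.shape == (m, n) and out.dtype == torch.float16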
build/torch24-cxx98-cu118-x86_64-linux/quantization/marlin.py ADDED
@@ -0,0 +1,208 @@
 
1
+ from typing import TYPE_CHECKING
2
+
3
+ import torch
4
+
5
+ # neuron has a torch version that doesn't even have impl_abstract
6
+ if TYPE_CHECKING:
7
+ def register_fake(fn):
8
+ return lambda name: fn
9
+ else:
10
+ try:
11
+ from torch.library import register_fake
12
+ except ImportError:
13
+ from torch.library import impl_abstract as register_fake
14
+
15
+ try:
16
+ from ._ops import ops, add_op_namespace_prefix
17
+ except ImportError as e:
18
+ # Fallback for local development.
19
+ try:
20
+ import _quantization
21
+
22
+ ops = torch.ops._quantization
23
+
24
+ def add_op_namespace_prefix(op_name: str):
25
+ return f"_quantization::{op_name}"
26
+ except ImportError:
27
+ raise e
28
+
29
+
30
+ from .scalar_type import ScalarType
31
+
32
+
33
+ # fp8 marlin
34
+ def fp8_marlin_gemm(
35
+ a: torch.Tensor,
36
+ b_q_weight: torch.Tensor,
37
+ b_scales: torch.Tensor,
38
+ workspace: torch.Tensor,
39
+ num_bits: int,
40
+ size_m: int,
41
+ size_n: int,
42
+ size_k: int,
43
+ ) -> torch.Tensor:
44
+ return ops.fp8_marlin_gemm(
45
+ a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k
46
+ )
47
+
48
+
49
+ # gptq_marlin
50
+ def gptq_marlin_gemm(
51
+ a: torch.Tensor,
52
+ b_q_weight: torch.Tensor,
53
+ b_scales: torch.Tensor,
54
+ b_zeros: torch.Tensor,
55
+ g_idx: torch.Tensor,
56
+ perm: torch.Tensor,
57
+ workspace: torch.Tensor,
58
+ b_q_type: ScalarType,
59
+ size_m: int,
60
+ size_n: int,
61
+ size_k: int,
62
+ is_k_full: bool,
63
+ has_zp: bool = False,
64
+ use_fp32_reduce: bool = False,
65
+ is_zp_float: bool = False,
66
+ ) -> torch.Tensor:
67
+ return ops.gptq_marlin_gemm(
68
+ a,
69
+ b_q_weight,
70
+ b_scales,
71
+ b_zeros,
72
+ g_idx,
73
+ perm,
74
+ workspace,
75
+ b_q_type.id,
76
+ size_m,
77
+ size_n,
78
+ size_k,
79
+ is_k_full,
80
+ has_zp,
81
+ use_fp32_reduce,
82
+ is_zp_float,
83
+ )
84
+
85
+
86
+ # gptq_marlin
87
+ def gptq_marlin_repack(
88
+ b_q_weight: torch.Tensor,
89
+ perm: torch.Tensor,
90
+ size_k: int,
91
+ size_n: int,
92
+ num_bits: int,
93
+ ) -> torch.Tensor:
94
+ return ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits)
95
+
96
+
97
+ # gptq_marlin
98
+ def awq_marlin_repack(
99
+ b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
100
+ ) -> torch.Tensor:
101
+ return ops.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)
102
+
103
+
104
+ # marlin
105
+ def marlin_gemm(
106
+ a: torch.Tensor,
107
+ b_q_weight: torch.Tensor,
108
+ b_scales: torch.Tensor,
109
+ workspace: torch.Tensor,
110
+ size_m: int,
111
+ size_n: int,
112
+ size_k: int,
113
+ ) -> torch.Tensor:
114
+ return ops.marlin_gemm(
115
+ a, b_q_weight, b_scales, workspace, size_m, size_n, size_k
116
+ )
117
+
118
+
119
+ # marlin_24
120
+ def gptq_marlin_24_gemm(
121
+ a: torch.Tensor,
122
+ b_q_weight: torch.Tensor,
123
+ b_meta: torch.Tensor,
124
+ b_scales: torch.Tensor,
125
+ workspace: torch.Tensor,
126
+ b_q_type: ScalarType,
127
+ size_m: int,
128
+ size_n: int,
129
+ size_k: int,
130
+ ) -> torch.Tensor:
131
+ return ops.gptq_marlin_24_gemm(
132
+ a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, size_m, size_n, size_k
133
+ )
134
+
135
+
136
+ # qqq ops
137
+ def marlin_qqq_gemm(
138
+ a: torch.Tensor,
139
+ b_q_weight: torch.Tensor,
140
+ s_tok: torch.Tensor,
141
+ s_ch: torch.Tensor,
142
+ s_group: torch.Tensor,
143
+ workspace: torch.Tensor,
144
+ size_m: int,
145
+ size_n: int,
146
+ size_k: int,
147
+ ) -> torch.Tensor:
148
+ return ops.marlin_qqq_gemm(
149
+ a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k
150
+ )
151
+
152
+
153
+ # Fake ops
154
+
155
+ if hasattr(ops, "gptq_marlin_24_gemm"):
156
+ @register_fake(add_op_namespace_prefix("fp8_marlin_gemm"))
157
+ def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
158
+ b_scales: torch.Tensor, workspace: torch.Tensor,
159
+ num_bits: int, size_m: torch.SymInt,
160
+ size_n: torch.SymInt,
161
+ size_k: torch.SymInt) -> torch.Tensor:
162
+ return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)
163
+
164
+ @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm"))
165
+ def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
166
+ b_meta: torch.Tensor, b_scales: torch.Tensor,
167
+ workspace: torch.Tensor,
168
+ b_q_type: ScalarType, size_m: torch.SymInt,
169
+ size_n: torch.SymInt,
170
+ size_k: torch.SymInt) -> torch.Tensor:
171
+ return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
172
+
173
+ @register_fake(add_op_namespace_prefix("gptq_marlin_gemm"))
174
+ def _gptq_marlin_gemm_fake(a: torch.Tensor,
175
+ b_q_weight: torch.Tensor,
176
+ b_scales: torch.Tensor,
177
+ b_zeros: torch.Tensor,
178
+ g_idx: torch.Tensor,
179
+ perm: torch.Tensor,
180
+ workspace: torch.Tensor,
181
+ b_q_type: ScalarType,
182
+ size_m: torch.SymInt,
183
+ size_n: torch.SymInt,
184
+ size_k: torch.SymInt,
185
+ is_k_full: bool,
186
+ has_zp: bool = False,
187
+ use_fp32_reduce: bool = False,
188
+ is_zp_float: bool = False) -> torch.Tensor:
189
+ return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
190
+
191
+ @register_fake(add_op_namespace_prefix("marlin_qqq_gemm"))
192
+ def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
193
+ s_tok: torch.Tensor, s_ch: torch.Tensor,
194
+ s_group: torch.Tensor, workspace: torch.Tensor,
195
+ size_m: torch.SymInt, size_n: torch.SymInt,
196
+ size_k: torch.SymInt) -> torch.Tensor:
197
+ return torch.empty((size_m, size_n),
198
+ dtype=torch.float16,
199
+ device=a.device)
200
+
201
+ @register_fake(add_op_namespace_prefix("marlin_gemm"))
202
+ def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
203
+ b_scales: torch.Tensor, workspace: torch.Tensor,
204
+ size_m: torch.SymInt, size_n: torch.SymInt,
205
+ size_k: torch.SymInt) -> torch.Tensor:
206
+ return torch.empty((size_m, size_n),
207
+ dtype=torch.float16,
208
+ device=a.device)
build/torch24-cxx98-cu118-x86_64-linux/quantization/scalar_type.py ADDED
@@ -0,0 +1,330 @@
 
1
+ import functools
2
+ import struct
3
+ from dataclasses import dataclass
4
+ from enum import Enum
5
+ from typing import Optional, Union
6
+
7
+
8
+ # Mirrors enum in `core/scalar_type.hpp`
9
+ class NanRepr(Enum):
10
+ NONE = 0 # nans are not supported
11
+ IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s
12
+ EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s
13
+
14
+
15
+ # This ScalarType class is a parallel implementation of the C++ ScalarType
16
+ # class found in csrc/core/scalar_type.hpp. These two classes should be kept
17
+ # in sync until the inductor fully supports custom C++ classes.
18
+ @dataclass(frozen=True)
19
+ class ScalarType:
20
+ """
21
+ ScalarType can represent a wide range of floating point and integer
22
+ types, in particular it can be used to represent sub-byte data types
23
+ (something that torch.dtype currently does not support). It is also
24
+ capable of representing types with a bias, i.e.:
25
+ `stored_value = value + bias`,
26
+ this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
27
+ of 8). The implementation for this class can be found in
28
+ csrc/core/scalar_type.hpp, these type signatures should be kept in sync
29
+ with that file.
30
+ """
31
+
32
+ exponent: int
33
+ """
34
+ Number of bits in the exponent if this is a floating point type
35
+ (zero if this is an integer type)
36
+ """
37
+
38
+ mantissa: int
39
+ """
40
+ Number of bits in the mantissa if this is a floating point type,
41
+ or the number of bits representing an integer excluding the sign bit if
42
+ this is an integer type.
43
+ """
44
+
45
+ signed: bool
46
+ "If the type is signed (i.e. has a sign bit)"
47
+
48
+ bias: int
49
+ """
50
+ bias used to encode the values in this scalar type
51
+ (value = stored_value - bias, default 0) for example if we store the
52
+ type as an unsigned integer with a bias of 128 then the value 0 will be
53
+ stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
54
+ """
55
+
56
+ _finite_values_only: bool = False
57
+ """
58
+ Private: use `has_infs()` to check whether infs are supported.
59
+ """
60
+
61
+ nan_repr: NanRepr = NanRepr.IEEE_754
62
+ """
63
+ How NaNs are represented in this scalar type; returns a NanRepr value.
64
+ (not applicable for integer types)
65
+ """
66
+
67
+ def _floating_point_max_int(self) -> int:
68
+ assert (
69
+ self.mantissa <= 52 and self.exponent <= 11
70
+ ), f"Cannot represent max/min as a double for type {self.__str__()}"
71
+
72
+ max_mantissa = (1 << self.mantissa) - 1
73
+ if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
74
+ max_mantissa = max_mantissa - 1
75
+
76
+ max_exponent = (1 << self.exponent) - 2
77
+ if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN
78
+ or self.nan_repr == NanRepr.NONE):
79
+ assert (
80
+ self.exponent < 11
81
+ ), f"Cannot represent max/min as a double for type {self.__str__()}"
82
+ max_exponent = max_exponent + 1
83
+
84
+ # adjust the exponent to match that of a double
85
+ # for now we assume the exponent bias is the standard 2^(e-1) -1, (where
86
+ # e is the exponent bits), there is some precedent for non-standard
87
+ # biases, example `float8_e4m3b11fnuz` here:
88
+ # https://github.com/jax-ml/ml_dtypes but to avoid premature over
89
+ # complication we are just assuming the standard exponent bias until
90
+ # there is a need to support non-standard biases
91
+ exponent_bias = (1 << (self.exponent - 1)) - 1
92
+ exponent_bias_double = (1 << 10) - 1 # double e = 11
93
+
94
+ max_exponent_double = (max_exponent - exponent_bias +
95
+ exponent_bias_double)
96
+
97
+ # shift the mantissa and exponent into the proper positions for an
98
+ # IEEE double and bitwise-or them together.
99
+ return (max_mantissa <<
100
+ (52 - self.mantissa)) | (max_exponent_double << 52)
101
+
102
+ def _floating_point_max(self) -> float:
103
+ double_raw = self._floating_point_max_int()
104
+ return struct.unpack('!d', struct.pack('!Q', double_raw))[0]
105
+
106
+ def _raw_max(self) -> Union[int, float]:
107
+ if self.is_floating_point():
108
+ return self._floating_point_max()
109
+ else:
110
+ assert (self.size_bits < 64 or self.size_bits == 64
111
+ and self.is_signed()), "Cannot represent max as an int"
112
+ return (1 << self.mantissa) - 1
113
+
114
+ def _raw_min(self) -> Union[int, float]:
115
+ if self.is_floating_point():
116
+ assert self.is_signed(
117
+ ), "We currently assume all floating point types are signed"
118
+ sign_bit_double = 1 << 63
119
+
120
+ max_raw = self._floating_point_max_int()
121
+ min_raw = max_raw | sign_bit_double
122
+ return struct.unpack('!d', struct.pack('!Q', min_raw))[0]
123
+ else:
124
+ assert (not self.is_signed() or
125
+ self.size_bits <= 64), "Cannot represent min as a int64_t"
126
+
127
+ if self.is_signed():
128
+ return -(1 << (self.size_bits - 1))
129
+ else:
130
+ return 0
131
+
132
+ @functools.cached_property
133
+ def id(self) -> int:
134
+ """
135
+ Convert the ScalarType to an int which can be passed to pytorch custom
136
+ ops. This layout of the int must be kept in sync with the C++
137
+ ScalarType's from_id method.
138
+ """
139
+ val = 0
140
+ offset = 0
141
+
142
+ def or_and_advance(member, bit_width):
143
+ nonlocal val
144
+ nonlocal offset
145
+ bit_mask = (1 << bit_width) - 1
146
+ val = val | (int(member) & bit_mask) << offset
147
+ offset = offset + bit_width
148
+
149
+ or_and_advance(self.exponent, 8)
150
+ or_and_advance(self.mantissa, 8)
151
+ or_and_advance(self.signed, 1)
152
+ or_and_advance(self.bias, 32)
153
+ or_and_advance(self._finite_values_only, 1)
154
+ or_and_advance(self.nan_repr.value, 8)
155
+
156
+ assert offset <= 64, \
157
+ f"ScalarType fields too big {offset} to fit into an int64"
158
+
159
+ return val
160
+
161
+ @property
162
+ def size_bits(self) -> int:
163
+ return self.exponent + self.mantissa + int(self.signed)
164
+
165
+ def min(self) -> Union[int, float]:
166
+ """
167
+ Min representable value for this scalar type.
168
+ (accounting for bias if there is one)
169
+ """
170
+ return self._raw_min() - self.bias
171
+
172
+ def max(self) -> Union[int, float]:
173
+ """
174
+ Max representable value for this scalar type.
175
+ (accounting for bias if there is one)
176
+ """
177
+ return self._raw_max() - self.bias
178
+
179
+ def is_signed(self) -> bool:
180
+ """
181
+ If the type is signed (i.e. has a sign bit), same as `signed`
182
+ added for consistency with:
183
+ https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
184
+ """
185
+ return self.signed
186
+
187
+ def is_floating_point(self) -> bool:
188
+ "If the type is a floating point type"
189
+ return self.exponent != 0
190
+
191
+ def is_integer(self) -> bool:
192
+ "If the type is an integer type"
193
+ return self.exponent == 0
194
+
195
+ def has_bias(self) -> bool:
196
+ "If the type has a non-zero bias"
197
+ return self.bias != 0
198
+
199
+ def has_infs(self) -> bool:
200
+ "If the type is floating point and supports infinity"
201
+ return not self._finite_values_only
202
+
203
+ def has_nans(self) -> bool:
204
+ return self.nan_repr != NanRepr.NONE
205
+
206
+ def is_ieee_754(self) -> bool:
207
+ """
208
+ If the type is a floating point type that follows IEEE 754
209
+ conventions
210
+ """
211
+ return self.nan_repr == NanRepr.IEEE_754 and \
212
+ not self._finite_values_only
213
+
214
+ def __str__(self) -> str:
215
+ """
216
+ naming generally follows: https://github.com/jax-ml/ml_dtypes
217
+ for floating point types (leading f) the scheme is:
218
+ `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
219
+ flags:
220
+ - no-flags: means it follows IEEE 754 conventions
221
+ - f: means finite values only (no infinities)
222
+ - n: means nans are supported (non-standard encoding)
223
+ for integer types the scheme is:
224
+ `[u]int<size_bits>[b<bias>]`
225
+ - if bias is not present it means it is zero
226
+ """
227
+ if self.is_floating_point():
228
+ ret = "float" + str(self.size_bits) + "_e" + str(
229
+ self.exponent) + "m" + str(self.mantissa)
230
+
231
+ if not self.is_ieee_754():
232
+ if self._finite_values_only:
233
+ ret = ret + "f"
234
+ if self.nan_repr != NanRepr.NONE:
235
+ ret = ret + "n"
236
+
237
+ return ret
238
+ else:
239
+ ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
240
+ if self.has_bias():
241
+ ret = ret + "b" + str(self.bias)
242
+ return ret
243
+
244
+ def __repr__(self) -> str:
245
+ return "ScalarType." + self.__str__()
246
+
247
+ # __len__ needs to be defined (and has to throw TypeError) for pytorch's
248
+ # opcheck to work.
249
+ def __len__(self) -> int:
250
+ raise TypeError
251
+
252
+ #
253
+ # Convenience Constructors
254
+ #
255
+
256
+ @classmethod
257
+ def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
258
+ "Create a signed integer scalar type (size_bits includes sign-bit)."
259
+ ret = cls(0, size_bits - 1, True, bias if bias else 0)
260
+ ret.id # noqa B018: make sure the id is cached
261
+ return ret
262
+
263
+ @classmethod
264
+ def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
265
+ """Create an unsigned integer scalar type."""
266
+ ret = cls(0, size_bits, False, bias if bias else 0)
267
+ ret.id # noqa B018: make sure the id is cached
268
+ return ret
269
+
270
+ @classmethod
271
+ def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
272
+ """
273
+ Create a standard floating point type
274
+ (i.e. follows IEEE 754 conventions).
275
+ """
276
+ assert (mantissa > 0 and exponent > 0)
277
+ ret = cls(exponent, mantissa, True, 0)
278
+ ret.id # noqa B018: make sure the id is cached
279
+ return ret
280
+
281
+ @classmethod
282
+ def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
283
+ nan_repr: NanRepr) -> 'ScalarType':
284
+ """
285
+ Create a non-standard floating point type
286
+ (i.e. does not follow IEEE 754 conventions).
287
+ """
288
+ assert (mantissa > 0 and exponent > 0)
289
+ assert (nan_repr != NanRepr.IEEE_754), (
290
+ "use `float_IEEE754` constructor for floating point types that "
291
+ "follow IEEE 754 conventions")
292
+ ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
293
+ ret.id # noqa B018: make sure the id is cached
294
+ return ret
295
+
296
+
297
+ # naming generally follows: https://github.com/jax-ml/ml_dtypes
298
+ # for floating point types (leading f) the scheme is:
299
+ # `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
300
+ # flags:
301
+ # - no-flags: means it follows IEEE 754 conventions
302
+ # - f: means finite values only (no infinities)
303
+ # - n: means nans are supported (non-standard encoding)
304
+ # for integer types the scheme is:
305
+ # `[u]int<size_bits>[b<bias>]`
306
+ # - if bias is not present it means it is zero
307
+
308
+
309
+ class scalar_types:
310
+ int4 = ScalarType.int_(4, None)
311
+ uint4 = ScalarType.uint(4, None)
312
+ int8 = ScalarType.int_(8, None)
313
+ uint8 = ScalarType.uint(8, None)
314
+ float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
315
+ float8_e5m2 = ScalarType.float_IEEE754(5, 2)
316
+ float16_e8m7 = ScalarType.float_IEEE754(8, 7)
317
+ float16_e5m10 = ScalarType.float_IEEE754(5, 10)
318
+
319
+ # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
320
+ float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)
321
+
322
+ # "gptq" types
323
+ uint2b2 = ScalarType.uint(2, 2)
324
+ uint3b4 = ScalarType.uint(3, 4)
325
+ uint4b8 = ScalarType.uint(4, 8)
326
+ uint8b128 = ScalarType.uint(8, 128)
327
+
328
+ # colloquial names
329
+ bfloat16 = float16_e8m7
330
+ float16 = float16_e5m10
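A quick, pure-Python illustration of the encoding, assuming the module is importable as `quantization.scalar_type`:

from quantization.scalar_type import scalar_types  # assumed import path

t = scalar_types.uint4b8                 # the standard GPTQ 4-bit type
assert t.size_bits == 4 and t.bias == 8
assert (t.min(), t.max()) == (-8, 7)     # stored codes 0..15 minus the bias of 8
assert str(t) == "uint4b8"
assert t.is_integer() and not t.is_signed()
assert scalar_types.float8_e4m3fn.is_floating_point()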
build/torch24-cxx98-cu118-x86_64-linux/quantization/utils/__init__.py ADDED
File without changes