danieldk committed
Commit 41b5840
Parent: 85bad96
This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. build/torch25-cxx11-cu118-x86_64-linux/quantization/_ops.py +3 -3
  2. build/torch25-cxx11-cu118-x86_64-linux/quantization/{_quantization_0435ccb.abi3.so → _quantization_85bad96.abi3.so} +2 -2
  3. build/torch25-cxx11-cu121-x86_64-linux/quantization/_ops.py +3 -3
  4. build/torch25-cxx11-cu121-x86_64-linux/quantization/{_quantization_0435ccb.abi3.so → _quantization_85bad96.abi3.so} +2 -2
  5. build/torch25-cxx11-cu124-x86_64-linux/quantization/_ops.py +3 -3
  6. build/torch25-cxx11-cu124-x86_64-linux/quantization/{_quantization_0435ccb.abi3.so → _quantization_85bad96.abi3.so} +2 -2
  7. build/torch25-cxx98-cu118-x86_64-linux/quantization/_ops.py +3 -3
  8. build/torch25-cxx98-cu118-x86_64-linux/quantization/{_quantization_0435ccb.abi3.so → _quantization_85bad96.abi3.so} +2 -2
  9. build/torch25-cxx98-cu121-x86_64-linux/quantization/_ops.py +3 -3
  10. build/torch25-cxx98-cu121-x86_64-linux/quantization/_quantization_0435ccb.abi3.so +0 -3
  11. build/torch25-cxx98-cu121-x86_64-linux/quantization/_quantization_85bad96.abi3.so +3 -0
  12. build/torch25-cxx98-cu124-x86_64-linux/quantization/_ops.py +3 -3
  13. build/torch25-cxx98-cu124-x86_64-linux/quantization/_quantization_0435ccb.abi3.so +0 -3
  14. build/torch25-cxx98-cu124-x86_64-linux/quantization/_quantization_85bad96.abi3.so +3 -0
  15. build/torch26-cxx11-cu118-x86_64-linux/quantization/_ops.py +3 -3
  16. build/torch26-cxx11-cu118-x86_64-linux/quantization/_quantization_0435ccb.abi3.so +0 -3
  17. build/torch26-cxx11-cu118-x86_64-linux/quantization/_quantization_85bad96.abi3.so +3 -0
  18. build/torch26-cxx11-cu124-x86_64-linux/quantization/_ops.py +3 -3
  19. build/torch26-cxx11-cu124-x86_64-linux/quantization/_quantization_0435ccb.abi3.so +0 -3
  20. build/torch26-cxx11-cu124-x86_64-linux/quantization/_quantization_85bad96.abi3.so +3 -0
  21. build/torch26-cxx11-cu126-x86_64-linux/quantization/_ops.py +3 -3
  22. build/torch26-cxx11-cu126-x86_64-linux/quantization/_quantization_0435ccb.abi3.so +0 -3
  23. build/torch26-cxx11-cu126-x86_64-linux/quantization/_quantization_85bad96.abi3.so +3 -0
  24. build/torch26-cxx11-rocm62-x86_64-linux/quantization/__init__.py +0 -39
  25. build/torch26-cxx11-rocm62-x86_64-linux/quantization/_ops.py +0 -9
  26. build/torch26-cxx11-rocm62-x86_64-linux/quantization/_quantization_0435ccb.abi3.so +0 -3
  27. build/torch26-cxx11-rocm62-x86_64-linux/quantization/compressed_tensors.py +0 -110
  28. build/torch26-cxx11-rocm62-x86_64-linux/quantization/cutlass.py +0 -75
  29. build/torch26-cxx11-rocm62-x86_64-linux/quantization/marlin.py +0 -208
  30. build/torch26-cxx11-rocm62-x86_64-linux/quantization/scalar_type.py +0 -330
  31. build/torch26-cxx11-rocm62-x86_64-linux/quantization/utils/__init__.py +0 -0
  32. build/torch26-cxx11-rocm62-x86_64-linux/quantization/utils/marlin_utils.py +0 -391
  33. build/torch26-cxx11-rocm62-x86_64-linux/quantization/utils/marlin_utils_fp8.py +0 -100
  34. build/torch26-cxx11-rocm62-x86_64-linux/quantization/utils/marlin_utils_test.py +0 -162
  35. build/torch26-cxx11-rocm62-x86_64-linux/quantization/utils/marlin_utils_test_24.py +0 -473
  36. build/torch26-cxx11-rocm62-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py +0 -125
  37. build/torch26-cxx11-rocm62-x86_64-linux/quantization/utils/quant_utils.py +0 -470
  38. build/torch26-cxx98-cu118-x86_64-linux/quantization/_ops.py +3 -3
  39. build/torch26-cxx98-cu118-x86_64-linux/quantization/_quantization_0435ccb.abi3.so +0 -3
  40. build/torch26-cxx98-cu118-x86_64-linux/quantization/_quantization_85bad96.abi3.so +3 -0
  41. build/torch26-cxx98-cu124-x86_64-linux/quantization/_ops.py +3 -3
  42. build/torch26-cxx98-cu124-x86_64-linux/quantization/_quantization_0435ccb.abi3.so +0 -3
  43. build/torch26-cxx98-cu124-x86_64-linux/quantization/_quantization_85bad96.abi3.so +3 -0
  44. build/torch26-cxx98-cu126-x86_64-linux/quantization/_ops.py +3 -3
  45. build/torch26-cxx98-cu126-x86_64-linux/quantization/_quantization_0435ccb.abi3.so +0 -3
  46. build/torch26-cxx98-cu126-x86_64-linux/quantization/_quantization_85bad96.abi3.so +3 -0
  47. build/torch27-cxx11-cu118-x86_64-linux/quantization/_ops.py +3 -3
  48. build/torch27-cxx11-cu118-x86_64-linux/quantization/_quantization_0435ccb.abi3.so +0 -3
  49. build/torch27-cxx11-cu118-x86_64-linux/quantization/_quantization_85bad96.abi3.so +3 -0
  50. build/torch27-cxx11-cu126-x86_64-linux/quantization/_ops.py +3 -3
build/torch25-cxx11-cu118-x86_64-linux/quantization/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _quantization_0435ccb
-ops = torch.ops._quantization_0435ccb
+from . import _quantization_85bad96
+ops = torch.ops._quantization_85bad96
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_quantization_0435ccb::{op_name}"
+    return f"_quantization_85bad96::{op_name}"
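The only functional change in each `_ops.py` is the build-hash suffix baked into the native extension's module name and op namespace. As a minimal sketch of how that namespace is typically consumed (hypothetical call, assuming the package is importable as `quantization` and the extension has registered its ops):

import torch

from quantization._ops import ops, add_op_namespace_prefix

# Kernels are reached through the versioned namespace object, e.g.
# ops.static_scaled_fp8_quant(out, x, scale) for one of the ops this
# extension registers.

# The prefix helper yields the fully qualified op name, which is what
# torch.library utilities such as register_fake expect:
print(add_op_namespace_prefix("static_scaled_fp8_quant"))
# -> "_quantization_85bad96::static_scaled_fp8_quant"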
build/torch25-cxx11-cu118-x86_64-linux/quantization/{_quantization_0435ccb.abi3.so → _quantization_85bad96.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:b01b2ee690ad54303926af36debf43382f596fee6396822365b8ea88ae284eec
-size 63485168
+oid sha256:87d7a88222f779536dfac6c2cfbbc77860d30ba000102d7a799d23c81f70ccd2
+size 87836032
build/torch25-cxx11-cu121-x86_64-linux/quantization/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _quantization_0435ccb
-ops = torch.ops._quantization_0435ccb
+from . import _quantization_85bad96
+ops = torch.ops._quantization_85bad96
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_quantization_0435ccb::{op_name}"
+    return f"_quantization_85bad96::{op_name}"
build/torch25-cxx11-cu121-x86_64-linux/quantization/{_quantization_0435ccb.abi3.so → _quantization_85bad96.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:8531abccfae1c201e83ad1279bdb092dd77a89e4dc7bc166bbe0625e2bbc6665
-size 64993040
+oid sha256:ea0ed35f3ac02d07df1e2618da72c1db734ccf4bf6c325de8e590280d7166528
+size 90990352
build/torch25-cxx11-cu124-x86_64-linux/quantization/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _quantization_0435ccb
-ops = torch.ops._quantization_0435ccb
+from . import _quantization_85bad96
+ops = torch.ops._quantization_85bad96
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_quantization_0435ccb::{op_name}"
+    return f"_quantization_85bad96::{op_name}"
build/torch25-cxx11-cu124-x86_64-linux/quantization/{_quantization_0435ccb.abi3.so → _quantization_85bad96.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:290d7f6a8c742481b6655732a3c2d61567fb7bd24c69e23a98dd1eff94895db6
-size 67517912
+oid sha256:7d2ade93d66de219962f3b8e812e79042c42cbf654d775e53a109176d0d8200c
+size 93728312
build/torch25-cxx98-cu118-x86_64-linux/quantization/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _quantization_0435ccb
-ops = torch.ops._quantization_0435ccb
+from . import _quantization_85bad96
+ops = torch.ops._quantization_85bad96
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_quantization_0435ccb::{op_name}"
+    return f"_quantization_85bad96::{op_name}"
build/torch25-cxx98-cu118-x86_64-linux/quantization/{_quantization_0435ccb.abi3.so → _quantization_85bad96.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:893739f27c86e11a259df04fd24021f374256e434339f682baaf0c5fccfc3c8a
-size 63468944
+oid sha256:0539e415b6f20e6660b324c19f77f2a7738c666a321a86bfcfbbba88044b257b
+size 87811624
build/torch25-cxx98-cu121-x86_64-linux/quantization/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _quantization_0435ccb
-ops = torch.ops._quantization_0435ccb
+from . import _quantization_85bad96
+ops = torch.ops._quantization_85bad96
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_quantization_0435ccb::{op_name}"
+    return f"_quantization_85bad96::{op_name}"
build/torch25-cxx98-cu121-x86_64-linux/quantization/_quantization_0435ccb.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:6692deb3c40c4bcee0ff28bf9d426c843f1a858e2a0bd12d92b5332c0adff4cf
-size 64992856
build/torch25-cxx98-cu121-x86_64-linux/quantization/_quantization_85bad96.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:13400949025d98b9a2465d6277121da0d149b381e523404f870c6a8d948b3128
+size 90994264
build/torch25-cxx98-cu124-x86_64-linux/quantization/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _quantization_0435ccb
-ops = torch.ops._quantization_0435ccb
+from . import _quantization_85bad96
+ops = torch.ops._quantization_85bad96
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_quantization_0435ccb::{op_name}"
+    return f"_quantization_85bad96::{op_name}"
build/torch25-cxx98-cu124-x86_64-linux/quantization/_quantization_0435ccb.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:f232faff4a6793b272825d95405d811f3ccbf8c2393e127d3f6772ff2441f165
-size 67519424
build/torch25-cxx98-cu124-x86_64-linux/quantization/_quantization_85bad96.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fbb5a4c5ab7bc96bf6efc98bb452017ec414ce9b30d9e22d7fb389728f780d27
+size 93721640
build/torch26-cxx11-cu118-x86_64-linux/quantization/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _quantization_0435ccb
-ops = torch.ops._quantization_0435ccb
+from . import _quantization_85bad96
+ops = torch.ops._quantization_85bad96
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_quantization_0435ccb::{op_name}"
+    return f"_quantization_85bad96::{op_name}"
build/torch26-cxx11-cu118-x86_64-linux/quantization/_quantization_0435ccb.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:932f97f2f20cfe21d3f9b75494026e85fc7552c0aac43113ad1af6715a32482c
-size 63484368
build/torch26-cxx11-cu118-x86_64-linux/quantization/_quantization_85bad96.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:56da42995a3fb8fb62656c51c61588f780cd580bfb46fc72a84059760c518cd1
+size 87822944
build/torch26-cxx11-cu124-x86_64-linux/quantization/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _quantization_0435ccb
-ops = torch.ops._quantization_0435ccb
+from . import _quantization_85bad96
+ops = torch.ops._quantization_85bad96
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_quantization_0435ccb::{op_name}"
+    return f"_quantization_85bad96::{op_name}"
build/torch26-cxx11-cu124-x86_64-linux/quantization/_quantization_0435ccb.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:36d47af9178b6187e6a14651f30b21f6038d31ed591688aba5c03b29b0bf88cc
-size 67517488
build/torch26-cxx11-cu124-x86_64-linux/quantization/_quantization_85bad96.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:04fc245b191d5e3bea3debd7383257455edda3ac7d5f4aa372ceea04e767f0ef
+size 93719704
build/torch26-cxx11-cu126-x86_64-linux/quantization/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _quantization_0435ccb
-ops = torch.ops._quantization_0435ccb
+from . import _quantization_85bad96
+ops = torch.ops._quantization_85bad96
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_quantization_0435ccb::{op_name}"
+    return f"_quantization_85bad96::{op_name}"
build/torch26-cxx11-cu126-x86_64-linux/quantization/_quantization_0435ccb.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:6b7a98b58caa01f436b3f089dfb62e2ec96a85ffdfad621f332701e0bc69b6a8
-size 68279984
build/torch26-cxx11-cu126-x86_64-linux/quantization/_quantization_85bad96.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fdcffa1ede49fa22a1ea2b534261466a969444f7498425870e50c1e503b789b5
+size 94514968
build/torch26-cxx11-rocm62-x86_64-linux/quantization/__init__.py DELETED
@@ -1,39 +0,0 @@
-from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant
-from .cutlass import (
-    cutlass_scaled_mm_supports_fp8,
-    cutlass_scaled_mm,
-    cutlass_scaled_mm_azp,
-)
-from .marlin import (
-    awq_marlin_repack,
-    fp8_marlin_gemm,
-    gptq_marlin_gemm,
-    gptq_marlin_repack,
-    gptq_marlin_24_gemm,
-    marlin_qqq_gemm,
-    marlin_gemm,
-)
-from .scalar_type import (
-    ScalarType,
-    scalar_types,
-)
-from ._ops import ops
-
-
-__all__ = [
-    "ScalarType",
-    "awq_marlin_repack",
-    "cutlass_scaled_mm",
-    "cutlass_scaled_mm_azp",
-    "cutlass_scaled_mm_supports_fp8",
-    "fp8_marlin_gemm",
-    "gptq_marlin_24_gemm",
-    "gptq_marlin_gemm",
-    "gptq_marlin_repack",
-    "marlin_gemm",
-    "marlin_qqq_gemm",
-    "ops",
-    "scalar_types",
-    "scaled_fp8_quant",
-    "scaled_int8_quant",
-]
build/torch26-cxx11-rocm62-x86_64-linux/quantization/_ops.py DELETED
@@ -1,9 +0,0 @@
-import torch
-from . import _quantization_0435ccb
-ops = torch.ops._quantization_0435ccb
-
-def add_op_namespace_prefix(op_name: str):
-    """
-    Prefix op by namespace.
-    """
-    return f"_quantization_0435ccb::{op_name}"
build/torch26-cxx11-rocm62-x86_64-linux/quantization/_quantization_0435ccb.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:7eacdfbcfd25927283bc5c3653704a6c27da69a7c8ba7ef68f0691a66679054c
-size 2878744
build/torch26-cxx11-rocm62-x86_64-linux/quantization/compressed_tensors.py DELETED
@@ -1,110 +0,0 @@
-from typing import Optional, Tuple
-
-import torch
-
-try:
-    from ._ops import ops
-except ImportError as e:
-    # Fallback for local development.
-    try:
-        import _quantization
-
-        ops = torch.ops._quantization
-    except ImportError:
-        raise e
-
-
-# fp8
-def scaled_fp8_quant(
-    input: torch.Tensor,
-    scale: Optional[torch.Tensor] = None,
-    num_token_padding: Optional[int] = None,
-    scale_ub: Optional[torch.Tensor] = None,
-    use_per_token_if_dynamic: bool = False,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """
-    Quantize input tensor to FP8 and return quantized tensor and scale.
-
-    This function supports both static and dynamic quantization: If you
-    provide the scale, it will use static scaling and if you omit it,
-    the scale will be determined dynamically. The function also allows
-    optional padding of the output tensors for downstream kernels that
-    will benefit from padding.
-
-    Args:
-        input: The input tensor to be quantized to FP8
-        scale: Optional scaling factor for the FP8 quantization
-        scale_ub: Optional upper bound for scaling factor in dynamic
-            per token case
-        num_token_padding: If specified, pad the first dimension
-            of the output to at least this value.
-        use_per_token_if_dynamic: Whether to do per_tensor or per_token
-            in the dynamic quantization case.
-
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
-            scaling factor.
-    """
-    # This code assumes batch_dim and num_tokens are flattened
-    assert input.ndim == 2
-    shape: Union[Tuple[int, int], torch.Size] = input.shape
-    # For rocm, the output fp8 dtype is torch.float_e3m3fnuz
-    # out_dtype: torch.dtype = torch.float8_e4m3fnuz \
-    #     if current_platform.is_rocm() else torch.float8_e4m3fn
-    out_dtype = torch.float8_e4m3fn
-    if num_token_padding:
-        shape = (max(num_token_padding, input.shape[0]), shape[1])
-    output = torch.empty(shape, device=input.device, dtype=out_dtype)
-
-    if scale is None:
-        if use_per_token_if_dynamic:
-            scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
-            ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub)
-        else:
-            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
-            ops.dynamic_scaled_fp8_quant(output, input, scale)
-    else:
-        # num_token_padding not implemented for this case
-        assert scale.numel() == 1 or num_token_padding is None
-        ops.static_scaled_fp8_quant(output, input, scale)
-
-    return output, scale
-
-
-# int8
-def scaled_int8_quant(
-    input: torch.Tensor,
-    scale: Optional[torch.Tensor] = None,
-    azp: Optional[torch.Tensor] = None,
-    symmetric: bool = True,
-) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-    """
-    Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.
-
-    Args:
-        input: The input tensor to be quantized to int8.
-        scale: Optional scaling factor for the int8 quantization.
-            When not provided, we invoke dynamic-per-token quantization.
-        azp: Optional zero-point for the int8 quantization.
-            Must be provided for asymmetric quantization if `scale` is provided.
-        symmetric: Whether to use symmetric quantization (scale only, azp ignored).
-
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
-    """
-    output = torch.empty_like(input, dtype=torch.int8)
-    if scale is not None:
-        # static-per-tensor quantization.
-        assert symmetric == (
-            azp is None
-        ), "azp must only be provided for asymmetric quantization."
-        ops.static_scaled_int8_quant(output, input, scale, azp)
-        return output, scale, azp
-
-    # dynamic-per-token quantization.
-    input_scales = torch.empty(
-        (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32
-    )
-    input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32)
-    ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp)
-    return output, input_scales, input_azp
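For context on the API being dropped for ROCm here (it remains available in the CUDA builds), a minimal usage sketch, assuming the package is installed as `quantization` and a CUDA device is available:

import torch
from quantization import scaled_fp8_quant, scaled_int8_quant

x = torch.randn(4, 128, dtype=torch.float16, device="cuda")

# Dynamic per-tensor FP8 quantization: the kernel computes the scale.
x_fp8, fp8_scale = scaled_fp8_quant(x)

# Dynamic per-token symmetric INT8 quantization: one scale per row and
# no zero-point, because symmetric=True by default.
x_int8, int8_scales, azp = scaled_int8_quant(x)
assert azp is None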
build/torch26-cxx11-rocm62-x86_64-linux/quantization/cutlass.py DELETED
@@ -1,75 +0,0 @@
-from typing import Optional
-
-import torch
-
-try:
-    from ._ops import ops
-except ImportError as e:
-    # Fallback for local development.
-    try:
-        import _quantization
-
-        ops = torch.ops._quantization
-    except ImportError:
-        raise e
-
-
-def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
-    return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
-
-
-def cutlass_scaled_mm(
-    a: torch.Tensor,
-    b: torch.Tensor,
-    scale_a: torch.Tensor,
-    scale_b: torch.Tensor,
-    out_dtype: torch.dtype,
-    bias: Optional[torch.Tensor] = None,
-) -> torch.Tensor:
-    assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
-    assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
-    assert bias is None or bias.shape[0] == b.shape[1] and bias.dtype == out_dtype
-
-    m = a.shape[0]
-    n = b.shape[1]
-
-    # if current_platform.is_rocm():
-    #     triton_scaled_mm_module = importlib.import_module(
-    #         "vllm.model_executor.layers.quantization.compressed_tensors."
-    #         "triton_scaled_mm")
-    #     triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
-    #     return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
-
-    out = torch.empty((m, n), dtype=out_dtype, device=a.device)
-
-    ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
-
-    return out
-
-
-def cutlass_scaled_mm_azp(
-    a: torch.Tensor,
-    b: torch.Tensor,
-    scale_a: torch.Tensor,
-    scale_b: torch.Tensor,
-    out_dtype: torch.dtype,
-    azp_adj: torch.Tensor,
-    azp: Optional[torch.Tensor] = None,
-    bias: Optional[torch.Tensor] = None,
-) -> torch.Tensor:
-    """
-    :param azp_adj: In the per-tensor case, this should include the azp.
-        Always per-channel.
-    :param azp: Only set in the per-token case. Per-token if set.
-    """
-    assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
-    assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
-    assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype
-    assert azp is None or azp.numel() == a.shape[0]
-
-    m = a.shape[0]
-    n = b.shape[1]
-    out = torch.empty((m, n), dtype=out_dtype, device=a.device)
-
-    ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias)
-    return out
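Likewise, a minimal sketch of the CUTLASS scaled-matmul wrapper removed above, assuming an FP8-capable CUDA GPU; shape and dtype choices follow the asserts in `cutlass_scaled_mm`, and the column-major layout for `b` is an assumption about what the kernel expects:

import torch
from quantization import cutlass_scaled_mm, cutlass_scaled_mm_supports_fp8

major, minor = torch.cuda.get_device_capability()
if cutlass_scaled_mm_supports_fp8(major * 10 + minor):
    m, k, n = 16, 64, 32  # b dimensions must be multiples of 16
    a = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)
    # Make b column-major, the layout the CUTLASS kernels typically expect.
    b = torch.randn(k, n, device="cuda").to(torch.float8_e4m3fn).t().contiguous().t()
    scale_a = torch.ones(1, device="cuda", dtype=torch.float32)
    scale_b = torch.ones(1, device="cuda", dtype=torch.float32)
    out = cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)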
build/torch26-cxx11-rocm62-x86_64-linux/quantization/marlin.py DELETED
@@ -1,208 +0,0 @@
1
- from typing import TYPE_CHECKING
2
-
3
- import torch
4
-
5
- # neuron has torch version that doesn't even have impl_abstract
6
- if TYPE_CHECKING:
7
- def register_fake(fn):
8
- return lambda name: fn
9
- else:
10
- try:
11
- from torch.library import register_fake
12
- except ImportError:
13
- from torch.library import impl_abstract as register_fake
14
-
15
- try:
16
- from ._ops import ops, add_op_namespace_prefix
17
- except ImportError as e:
18
- # Fallback for local development.
19
- try:
20
- import _quantization
21
-
22
- ops = torch.ops._quantization
23
-
24
- def add_op_namespace_prefix(op_name: str):
25
- return f"_quantization::{op_name}"
26
- except ImportError:
27
- raise e
28
-
29
-
30
- from .scalar_type import ScalarType
31
-
32
-
33
- # fp8 marlin
34
- def fp8_marlin_gemm(
35
- a: torch.Tensor,
36
- b_q_weight: torch.Tensor,
37
- b_scales: torch.Tensor,
38
- workspace: torch.Tensor,
39
- num_bits: int,
40
- size_m: int,
41
- size_n: int,
42
- size_k: int,
43
- ) -> torch.Tensor:
44
- return ops.fp8_marlin_gemm(
45
- a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k
46
- )
47
-
48
-
49
- # gptq_marlin
50
- def gptq_marlin_gemm(
51
- a: torch.Tensor,
52
- b_q_weight: torch.Tensor,
53
- b_scales: torch.Tensor,
54
- b_zeros: torch.Tensor,
55
- g_idx: torch.Tensor,
56
- perm: torch.Tensor,
57
- workspace: torch.Tensor,
58
- b_q_type: ScalarType,
59
- size_m: int,
60
- size_n: int,
61
- size_k: int,
62
- is_k_full: bool,
63
- has_zp: bool = False,
64
- use_fp32_reduce: bool = False,
65
- is_zp_float: bool = False,
66
- ) -> torch.Tensor:
67
- return ops.gptq_marlin_gemm(
68
- a,
69
- b_q_weight,
70
- b_scales,
71
- b_zeros,
72
- g_idx,
73
- perm,
74
- workspace,
75
- b_q_type.id,
76
- size_m,
77
- size_n,
78
- size_k,
79
- is_k_full,
80
- has_zp,
81
- use_fp32_reduce,
82
- is_zp_float,
83
- )
84
-
85
-
86
- # gptq_marlin
87
- def gptq_marlin_repack(
88
- b_q_weight: torch.Tensor,
89
- perm: torch.Tensor,
90
- size_k: int,
91
- size_n: int,
92
- num_bits: int,
93
- ) -> torch.Tensor:
94
- return ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits)
95
-
96
-
97
- # gptq_marlin
98
- def awq_marlin_repack(
99
- b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
100
- ) -> torch.Tensor:
101
- return ops.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)
102
-
103
-
104
- # marlin
105
- def marlin_gemm(
106
- a: torch.Tensor,
107
- b_q_weight: torch.Tensor,
108
- b_scales: torch.Tensor,
109
- workspace: torch.Tensor,
110
- size_m: int,
111
- size_n: int,
112
- size_k: int,
113
- ) -> torch.Tensor:
114
- return ops.marlin_gemm(
115
- a, b_q_weight, b_scales, workspace, size_m, size_n, size_k
116
- )
117
-
118
-
119
- # marlin_24
120
- def gptq_marlin_24_gemm(
121
- a: torch.Tensor,
122
- b_q_weight: torch.Tensor,
123
- b_meta: torch.Tensor,
124
- b_scales: torch.Tensor,
125
- workspace: torch.Tensor,
126
- b_q_type: ScalarType,
127
- size_m: int,
128
- size_n: int,
129
- size_k: int,
130
- ) -> torch.Tensor:
131
- return ops.gptq_marlin_24_gemm(
132
- a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, size_m, size_n, size_k
133
- )
134
-
135
-
136
- # qqq ops
137
- def marlin_qqq_gemm(
138
- a: torch.Tensor,
139
- b_q_weight: torch.Tensor,
140
- s_tok: torch.Tensor,
141
- s_ch: torch.Tensor,
142
- s_group: torch.Tensor,
143
- workspace: torch.Tensor,
144
- size_m: int,
145
- size_n: int,
146
- size_k: int,
147
- ) -> torch.Tensor:
148
- return ops.marlin_qqq_gemm(
149
- a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k
150
- )
151
-
152
-
153
- # Fake ops
154
-
155
- if hasattr(ops, "gptq_marlin_24_gemm"):
156
- @register_fake(add_op_namespace_prefix("fp8_marlin_gemm"))
157
- def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
158
- b_scales: torch.Tensor, workspace: torch.Tensor,
159
- num_bits: int, size_m: torch.SymInt,
160
- size_n: torch.SymInt,
161
- size_k: torch.SymInt) -> torch.Tensor:
162
- return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)
163
-
164
- @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm"))
165
- def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
166
- b_meta: torch.Tensor, b_scales: torch.Tensor,
167
- workspace: torch.Tensor,
168
- b_q_type: ScalarType, size_m: torch.SymInt,
169
- size_n: torch.SymInt,
170
- size_k: torch.SymInt) -> torch.Tensor:
171
- return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
172
-
173
- @register_fake(add_op_namespace_prefix("gptq_marlin_gemm"))
174
- def _gptq_marlin_gemm_fake(a: torch.Tensor,
175
- b_q_weight: torch.Tensor,
176
- b_scales: torch.Tensor,
177
- b_zeros: torch.Tensor,
178
- g_idx: torch.Tensor,
179
- perm: torch.Tensor,
180
- workspace: torch.Tensor,
181
- b_q_type: ScalarType,
182
- size_m: torch.SymInt,
183
- size_n: torch.SymInt,
184
- size_k: torch.SymInt,
185
- is_k_full: bool,
186
- has_zp: bool = False,
187
- use_fp32_reduce: bool = False,
188
- is_zp_float: bool = False) -> torch.Tensor:
189
- return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
190
-
191
- @register_fake(add_op_namespace_prefix("marlin_qqq_gemm"))
192
- def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
193
- s_tok: torch.Tensor, s_ch: torch.Tensor,
194
- s_group: torch.Tensor, workspace: torch.Tensor,
195
- size_m: torch.SymInt, size_n: torch.SymInt,
196
- size_k: torch.SymInt) -> torch.Tensor:
197
- return torch.empty((size_m, size_n),
198
- dtype=torch.float16,
199
- device=a.device)
200
-
201
- @register_fake(add_op_namespace_prefix("marlin_gemm"))
202
- def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
203
- b_scales: torch.Tensor, workspace: torch.Tensor,
204
- size_m: torch.SymInt, size_n: torch.SymInt,
205
- size_k: torch.SymInt) -> torch.Tensor:
206
- return torch.empty((size_m, size_n),
207
- dtype=torch.float16,
208
- device=a.device)
build/torch26-cxx11-rocm62-x86_64-linux/quantization/scalar_type.py DELETED
@@ -1,330 +0,0 @@
1
- import functools
2
- import struct
3
- from dataclasses import dataclass
4
- from enum import Enum
5
- from typing import Optional, Union
6
-
7
-
8
- # Mirrors enum in `core/scalar_type.hpp`
9
- class NanRepr(Enum):
10
- NONE = 0 # nans are not supported
11
- IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s
12
- EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s
13
-
14
-
15
- # This ScalarType class is a parallel implementation of the C++ ScalarType
16
- # class found in csrc/core/scalar_type.hpp. These two classes should be kept
17
- # in sync until the inductor fully supports custom C++ classes.
18
- @dataclass(frozen=True)
19
- class ScalarType:
20
- """
21
- ScalarType can represent a wide range of floating point and integer
22
- types, in particular it can be used to represent sub-byte data types
23
- (something that torch.dtype currently does not support). It is also
24
- capable of representing types with a bias, i.e.:
25
- `stored_value = value + bias`,
26
- this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
27
- of 8). The implementation for this class can be found in
28
- csrc/core/scalar_type.hpp, these type signatures should be kept in sync
29
- with that file.
30
- """
31
-
32
- exponent: int
33
- """
34
- Number of bits in the exponent if this is a floating point type
35
- (zero if this an integer type)
36
- """
37
-
38
- mantissa: int
39
- """
40
- Number of bits in the mantissa if this is a floating point type,
41
- or the number bits representing an integer excluding the sign bit if
42
- this an integer type.
43
- """
44
-
45
- signed: bool
46
- "If the type is signed (i.e. has a sign bit)"
47
-
48
- bias: int
49
- """
50
- bias used to encode the values in this scalar type
51
- (value = stored_value - bias, default 0) for example if we store the
52
- type as an unsigned integer with a bias of 128 then the value 0 will be
53
- stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
54
- """
55
-
56
- _finite_values_only: bool = False
57
- """
58
- Private: if infs are supported, used `has_infs()` instead.
59
- """
60
-
61
- nan_repr: NanRepr = NanRepr.IEEE_754
62
- """
63
- How NaNs are represent in this scalar type, returns NanRepr value.
64
- (not applicable for integer types)
65
- """
66
-
67
- def _floating_point_max_int(self) -> int:
68
- assert (
69
- self.mantissa <= 52 and self.exponent <= 11
70
- ), f"Cannot represent max/min as a double for type {self.__str__()}"
71
-
72
- max_mantissa = (1 << self.mantissa) - 1
73
- if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
74
- max_mantissa = max_mantissa - 1
75
-
76
- max_exponent = (1 << self.exponent) - 2
77
- if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN
78
- or self.nan_repr == NanRepr.NONE):
79
- assert (
80
- self.exponent < 11
81
- ), f"Cannot represent max/min as a double for type {self.__str__()}"
82
- max_exponent = max_exponent + 1
83
-
84
- # adjust the exponent to match that of a double
85
- # for now we assume the exponent bias is the standard 2^(e-1) -1, (where
86
- # e is the exponent bits), there is some precedent for non-standard
87
- # biases, example `float8_e4m3b11fnuz` here:
88
- # https://github.com/jax-ml/ml_dtypes but to avoid premature over
89
- # complication we are just assuming the standard exponent bias until
90
- # there is a need to support non-standard biases
91
- exponent_bias = (1 << (self.exponent - 1)) - 1
92
- exponent_bias_double = (1 << 10) - 1 # double e = 11
93
-
94
- max_exponent_double = (max_exponent - exponent_bias +
95
- exponent_bias_double)
96
-
97
- # shift the mantissa and exponent into the proper positions for an
98
- # IEEE double and bitwise-or them together.
99
- return (max_mantissa <<
100
- (52 - self.mantissa)) | (max_exponent_double << 52)
101
-
102
- def _floating_point_max(self) -> float:
103
- double_raw = self._floating_point_max_int()
104
- return struct.unpack('!d', struct.pack('!Q', double_raw))[0]
105
-
106
- def _raw_max(self) -> Union[int, float]:
107
- if self.is_floating_point():
108
- return self._floating_point_max()
109
- else:
110
- assert (self.size_bits < 64 or self.size_bits == 64
111
- and self.is_signed()), "Cannot represent max as an int"
112
- return (1 << self.mantissa) - 1
113
-
114
- def _raw_min(self) -> Union[int, float]:
115
- if self.is_floating_point():
116
- assert self.is_signed(
117
- ), "We currently assume all floating point types are signed"
118
- sign_bit_double = 1 << 63
119
-
120
- max_raw = self._floating_point_max_int()
121
- min_raw = max_raw | sign_bit_double
122
- return struct.unpack('!d', struct.pack('!Q', min_raw))[0]
123
- else:
124
- assert (not self.is_signed() or
125
- self.size_bits <= 64), "Cannot represent min as a int64_t"
126
-
127
- if self.is_signed():
128
- return -(1 << (self.size_bits - 1))
129
- else:
130
- return 0
131
-
132
- @functools.cached_property
133
- def id(self) -> int:
134
- """
135
- Convert the ScalarType to an int which can be passed to pytorch custom
136
- ops. This layout of the int must be kept in sync with the C++
137
- ScalarType's from_id method.
138
- """
139
- val = 0
140
- offset = 0
141
-
142
- def or_and_advance(member, bit_width):
143
- nonlocal val
144
- nonlocal offset
145
- bit_mask = (1 << bit_width) - 1
146
- val = val | (int(member) & bit_mask) << offset
147
- offset = offset + bit_width
148
-
149
- or_and_advance(self.exponent, 8)
150
- or_and_advance(self.mantissa, 8)
151
- or_and_advance(self.signed, 1)
152
- or_and_advance(self.bias, 32)
153
- or_and_advance(self._finite_values_only, 1)
154
- or_and_advance(self.nan_repr.value, 8)
155
-
156
- assert offset <= 64, \
157
- f"ScalarType fields too big {offset} to fit into an int64"
158
-
159
- return val
160
-
161
- @property
162
- def size_bits(self) -> int:
163
- return self.exponent + self.mantissa + int(self.signed)
164
-
165
- def min(self) -> Union[int, float]:
166
- """
167
- Min representable value for this scalar type.
168
- (accounting for bias if there is one)
169
- """
170
- return self._raw_min() - self.bias
171
-
172
- def max(self) -> Union[int, float]:
173
- """
174
- Max representable value for this scalar type.
175
- (accounting for bias if there is one)
176
- """
177
- return self._raw_max() - self.bias
178
-
179
- def is_signed(self) -> bool:
180
- """
181
- If the type is signed (i.e. has a sign bit), same as `signed`
182
- added for consistency with:
183
- https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
184
- """
185
- return self.signed
186
-
187
- def is_floating_point(self) -> bool:
188
- "If the type is a floating point type"
189
- return self.exponent != 0
190
-
191
- def is_integer(self) -> bool:
192
- "If the type is an integer type"
193
- return self.exponent == 0
194
-
195
- def has_bias(self) -> bool:
196
- "If the type has a non-zero bias"
197
- return self.bias != 0
198
-
199
- def has_infs(self) -> bool:
200
- "If the type is floating point and supports infinity"
201
- return not self._finite_values_only
202
-
203
- def has_nans(self) -> bool:
204
- return self.nan_repr != NanRepr.NONE.value
205
-
206
- def is_ieee_754(self) -> bool:
207
- """
208
- If the type is a floating point type that follows IEEE 754
209
- conventions
210
- """
211
- return self.nan_repr == NanRepr.IEEE_754.value and \
212
- not self._finite_values_only
213
-
214
- def __str__(self) -> str:
215
- """
216
- naming generally follows: https://github.com/jax-ml/ml_dtypes
217
- for floating point types (leading f) the scheme is:
218
- `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
219
- flags:
220
- - no-flags: means it follows IEEE 754 conventions
221
- - f: means finite values only (no infinities)
222
- - n: means nans are supported (non-standard encoding)
223
- for integer types the scheme is:
224
- `[u]int<size_bits>[b<bias>]`
225
- - if bias is not present it means its zero
226
- """
227
- if self.is_floating_point():
228
- ret = "float" + str(self.size_bits) + "_e" + str(
229
- self.exponent) + "m" + str(self.mantissa)
230
-
231
- if not self.is_ieee_754():
232
- if self._finite_values_only:
233
- ret = ret + "f"
234
- if self.nan_repr != NanRepr.NONE:
235
- ret = ret + "n"
236
-
237
- return ret
238
- else:
239
- ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
240
- if self.has_bias():
241
- ret = ret + "b" + str(self.bias)
242
- return ret
243
-
244
- def __repr__(self) -> str:
245
- return "ScalarType." + self.__str__()
246
-
247
- # __len__ needs to be defined (and has to throw TypeError) for pytorch's
248
- # opcheck to work.
249
- def __len__(self) -> int:
250
- raise TypeError
251
-
252
- #
253
- # Convenience Constructors
254
- #
255
-
256
- @classmethod
257
- def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
258
- "Create a signed integer scalar type (size_bits includes sign-bit)."
259
- ret = cls(0, size_bits - 1, True, bias if bias else 0)
260
- ret.id # noqa B018: make sure the id is cached
261
- return ret
262
-
263
- @classmethod
264
- def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
265
- """Create a unsigned integer scalar type."""
266
- ret = cls(0, size_bits, False, bias if bias else 0)
267
- ret.id # noqa B018: make sure the id is cached
268
- return ret
269
-
270
- @classmethod
271
- def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
272
- """
273
- Create a standard floating point type
274
- (i.e. follows IEEE 754 conventions).
275
- """
276
- assert (mantissa > 0 and exponent > 0)
277
- ret = cls(exponent, mantissa, True, 0)
278
- ret.id # noqa B018: make sure the id is cached
279
- return ret
280
-
281
- @classmethod
282
- def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
283
- nan_repr: NanRepr) -> 'ScalarType':
284
- """
285
- Create a non-standard floating point type
286
- (i.e. does not follow IEEE 754 conventions).
287
- """
288
- assert (mantissa > 0 and exponent > 0)
289
- assert (nan_repr != NanRepr.IEEE_754), (
290
- "use `float_IEEE754` constructor for floating point types that "
291
- "follow IEEE 754 conventions")
292
- ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
293
- ret.id # noqa B018: make sure the id is cached
294
- return ret
295
-
296
-
297
- # naming generally follows: https://github.com/jax-ml/ml_dtypes
298
- # for floating point types (leading f) the scheme is:
299
- # `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
300
- # flags:
301
- # - no-flags: means it follows IEEE 754 conventions
302
- # - f: means finite values only (no infinities)
303
- # - n: means nans are supported (non-standard encoding)
304
- # for integer types the scheme is:
305
- # `[u]int<size_bits>[b<bias>]`
306
- # - if bias is not present it means its zero
307
-
308
-
309
- class scalar_types:
310
- int4 = ScalarType.int_(4, None)
311
- uint4 = ScalarType.uint(4, None)
312
- int8 = ScalarType.int_(8, None)
313
- uint8 = ScalarType.uint(8, None)
314
- float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
315
- float8_e5m2 = ScalarType.float_IEEE754(5, 2)
316
- float16_e8m7 = ScalarType.float_IEEE754(8, 7)
317
- float16_e5m10 = ScalarType.float_IEEE754(5, 10)
318
-
319
- # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
320
- float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)
321
-
322
- # "gptq" types
323
- uint2b2 = ScalarType.uint(2, 2)
324
- uint3b4 = ScalarType.uint(3, 4)
325
- uint4b8 = ScalarType.uint(4, 8)
326
- uint8b128 = ScalarType.uint(8, 128)
327
-
328
- # colloquial names
329
- bfloat16 = float16_e8m7
330
- float16 = float16_e5m10
build/torch26-cxx11-rocm62-x86_64-linux/quantization/utils/__init__.py DELETED
File without changes
build/torch26-cxx11-rocm62-x86_64-linux/quantization/utils/marlin_utils.py DELETED
@@ -1,391 +0,0 @@
1
- from typing import List, Optional, Tuple
2
-
3
- import numpy
4
- import torch
5
-
6
- import quantization as ops
7
- from quantization.scalar_type import ScalarType, scalar_types
8
-
9
- from .quant_utils import pack_cols, unpack_cols
10
-
11
- GPTQ_MARLIN_TILE = 16
12
- GPTQ_MARLIN_MIN_THREAD_N = 64
13
- GPTQ_MARLIN_MIN_THREAD_K = 128
14
- GPTQ_MARLIN_MAX_PARALLEL = 16
15
-
16
- GPTQ_MARLIN_24_TILE = 16
17
- GPTQ_MARLIN_24_MIN_THREAD_N = 128
18
- GPTQ_MARLIN_24_MIN_THREAD_K = 128
19
- GPTQ_MARLIN_24_MAX_PARALLEL = 64
20
-
21
- GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
22
- GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
23
-
24
- MARLIN_QQQ_TILE = 16
25
- MARLIN_QQQ_MIN_THREAD_N = 64
26
- MARLIN_QQQ_MIN_THREAD_K = 128
27
- MARLIN_QQQ_MAX_PARALLEL = 16
28
-
29
- MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]
30
- MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128]
31
- MARLIN_QQQ_SUPPORTED_SYM = [True]
32
-
33
- MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
34
-
35
- # In case there is a performance issue with Marlin, the variable below can be
36
- # changed to False, which allows Marlin to perform global reductions in fp16
37
- # precision (instead of fp32), and therefore, save on some memory movements.
38
- USE_FP32_REDUCE_DEFAULT = True
39
-
40
-
41
- # For binary size and compile time, we don't support the same types for with and
42
- # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
43
- # TODO: we may want to move this into the C++ so its closer to the actual impl
44
- def query_marlin_supported_quant_types(
45
- has_zp: bool, device_capability: Optional[int] = None
46
- ):
47
- if device_capability is None:
48
- capability_tuple = torch.cuda.get_device_capability()
49
- device_capability = capability_tuple[0] * 10 + capability_tuple[1]
50
-
51
- if device_capability < 80:
52
- return []
53
-
54
- if has_zp:
55
- # AWQ style, unsigned + runtime zero-point
56
- return [scalar_types.uint4, scalar_types.uint8]
57
- else:
58
- # GPTQ style, unsigned + symmetric bias
59
- # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able
60
- # to add `scalar_types.float8_e4m3fn` here
61
- return [scalar_types.uint4b8, scalar_types.uint8b128]
62
-
63
-
64
- def _check_marlin_supported(
65
- quant_type: ScalarType,
66
- group_size: Optional[int],
67
- has_zp: bool,
68
- device_capability: Optional[int] = None,
69
- ) -> Tuple[bool, Optional[str]]:
70
-
71
- if device_capability is None:
72
- capability_tuple = torch.cuda.get_device_capability()
73
- device_capability = capability_tuple[0] * 10 + capability_tuple[1]
74
-
75
- supported_types = query_marlin_supported_quant_types(has_zp, device_capability)
76
-
77
- if quant_type not in supported_types:
78
- return (
79
- False,
80
- f"Marlin does not support weight_bits = {quant_type}. "
81
- f"Only types = {supported_types} "
82
- f"are supported (for group_size = {group_size}, "
83
- f"device_capability = {device_capability}, zp = {has_zp}).",
84
- )
85
- if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
86
- return (
87
- False,
88
- f"Marlin does not support group_size = {group_size}. "
89
- f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
90
- "are supported.",
91
- )
92
-
93
- return True, None
94
-
95
-
96
- def check_marlin_supported(
97
- quant_type: ScalarType,
98
- group_size: int,
99
- has_zp: bool = False,
100
- device_capability: Optional[int] = None,
101
- ) -> bool:
102
- cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability)
103
- return cond
104
-
105
-
106
- def verify_marlin_supported(
107
- quant_type: ScalarType, group_size: int, has_zp: bool = False
108
- ) -> None:
109
- cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp)
110
- if not cond:
111
- assert err_msg is not None
112
- raise ValueError(err_msg)
113
-
114
-
115
- def verify_marlin_supports_shape(
116
- output_size_per_partition: int,
117
- input_size_per_partition: int,
118
- input_size: int,
119
- group_size: int,
120
- ) -> None:
121
-
122
- # Validate output_size_per_partition
123
- if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
124
- raise ValueError(
125
- f"Weight output_size_per_partition = "
126
- f"{output_size_per_partition} is not divisible by "
127
- f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
128
- "Consider reducing tensor_parallel_size or running "
129
- "with --quantization gptq."
130
- )
131
-
132
- # Validate input_size_per_partition
133
- if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
134
- raise ValueError(
135
- f"Weight input_size_per_partition = "
136
- f"{input_size_per_partition} is not divisible "
137
- f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
138
- "Consider reducing tensor_parallel_size or running "
139
- "with --quantization gptq."
140
- )
141
-
142
- if group_size < input_size and input_size_per_partition % group_size != 0:
143
- raise ValueError(
144
- f"Weight input_size_per_partition = {input_size_per_partition}"
145
- f" is not divisible by group_size = {group_size}."
146
- "Consider reducing tensor_parallel_size or running "
147
- "with --quantization gptq."
148
- )
149
-
150
-
151
- def check_marlin_supports_shape(
152
- output_size_per_partition: int,
153
- input_size_per_partition: int,
154
- input_size: int,
155
- group_size: int,
156
- ) -> Tuple[bool, Optional[str]]:
157
- try:
158
- verify_marlin_supports_shape(
159
- output_size_per_partition, input_size_per_partition, input_size, group_size
160
- )
161
- except ValueError as e:
162
- return False, e.__str__()
163
- return True, None
164
-
165
-
166
- def marlin_make_workspace(
167
- output_size_per_partition: int, device: torch.device
168
- ) -> torch.Tensor:
169
- max_workspace_size = (
170
- output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N
171
- ) * GPTQ_MARLIN_MAX_PARALLEL
172
-
173
- return torch.zeros(
174
- max_workspace_size, dtype=torch.int, device=device, requires_grad=False
175
- )
176
-
177
-
178
- def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
179
- return (not act_order) or (act_order and not is_row_parallel)
180
-
181
-
182
- def marlin_repeat_scales_on_all_ranks(
183
- act_order: bool, group_size: int, is_row_parallel: bool
184
- ) -> bool:
185
- # Need to repeat scales on every rank if act_ordering or
186
- # channelwise and RowParallelLinear
187
- is_channelwise = group_size == -1
188
- return act_order or (is_channelwise and is_row_parallel)
189
-
190
-
191
- def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
192
- return torch.nn.Parameter(
193
- torch.empty(0, dtype=torch.int, device=device), requires_grad=False
194
- )
195
-
196
-
197
- def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
198
- return torch.nn.Parameter(
199
- torch.empty(0, dtype=torch.int, device=device), requires_grad=False
200
- )
201
-
202
-
203
- def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
204
- g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
205
- return g_idx[g_idx_sort_indices], g_idx_sort_indices
206
-
207
-
208
- def get_scale_perms():
209
- scale_perm: List[int] = []
210
- for i in range(8):
211
- scale_perm.extend([i + 8 * j for j in range(8)])
212
- scale_perm_single: List[int] = []
213
- for i in range(4):
214
- scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
215
- return scale_perm, scale_perm_single
216
-
217
-
218
- def marlin_permute_scales(
219
- s: torch.Tensor, size_k: int, size_n: int, group_size: int
220
- ) -> torch.Tensor:
221
-
222
- scale_perm, scale_perm_single = get_scale_perms()
223
- if group_size < size_k and group_size != -1:
224
- s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
225
- else:
226
- s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
227
- s = s.reshape((-1, size_n)).contiguous()
228
-
229
- return s
230
-
231
-
232
- def marlin_moe_permute_scales(
233
- s: torch.Tensor,
234
- size_k: int,
235
- size_n: int,
236
- group_size: int,
237
- ):
238
- num_experts = s.shape[0]
239
- output = torch.empty(
240
- (num_experts, s.shape[1], s.shape[2]),
241
- device=s.device,
242
- dtype=s.dtype,
243
- )
244
-
245
- for e in range(num_experts):
246
- output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size)
247
- return output
248
-
249
-
250
- def marlin_zero_points(
251
- zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
252
- ) -> torch.Tensor:
253
- # Permute zero-points in a similar way to scales, but do not use the
254
- # "single" permutation, since zero-points are applied on every MMA
255
- scale_perm, _ = get_scale_perms()
256
- zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm]
257
-
258
- # Interleave column dim (for the dequantize code) and pack it to int32
259
- if num_bits == 4:
260
- interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
261
- elif num_bits == 8:
262
- interleave = numpy.array([0, 2, 1, 3])
263
- else:
264
- raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
265
-
266
- zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
267
- zp = zp.reshape((-1, size_n)).contiguous()
268
- zp = pack_cols(zp, num_bits, size_k, size_n)
269
-
270
- return zp
271
-
272
-
273
- def awq_to_marlin_zero_points(
274
- q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
275
- ) -> torch.Tensor:
276
- # AWQ zero-points are quantized and packed on the column dim.
277
- # In addition, the values are permuted based on dequantizer.
278
- # Here we undo both of these, and then apply marlin permutation
279
- # and pack it back.
280
- q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)
281
-
282
- # Undo interleaving (use argsort(..) to get inverse perm)
283
- if num_bits == 4:
284
- undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7]))
285
- elif num_bits == 8:
286
- undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3]))
287
- else:
288
- raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
289
-
290
- q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
291
- q_zp = q_zp.reshape((-1, size_n)).contiguous()
292
-
293
- marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
294
- return marlin_zp
295
-
296
-
297
- def moe_awq_to_marlin_zero_points(
298
- q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
299
- ):
300
- num_experts = q_zp_packed.shape[0]
301
- output = torch.empty(
302
- (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]),
303
- device=q_zp_packed.device,
304
- dtype=q_zp_packed.dtype,
305
- )
306
- for e in range(num_experts):
307
- output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits)
308
- return output
309
-
310
-
311
- def apply_gptq_marlin_linear(
312
- input: torch.Tensor,
313
- weight: torch.Tensor,
314
- weight_scale: torch.Tensor,
315
- weight_zp: torch.Tensor,
316
- g_idx: torch.Tensor,
317
- g_idx_sort_indices: torch.Tensor,
318
- workspace: torch.Tensor,
319
- wtype: ScalarType,
320
- output_size_per_partition: int,
321
- input_size_per_partition: int,
322
- is_k_full: bool,
323
- bias: Optional[torch.Tensor] = None,
324
- use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
325
- ) -> torch.Tensor:
326
- reshaped_x = input.reshape(-1, input.shape[-1])
327
- out_shape = input.shape[:-1] + (output_size_per_partition,)
328
-
329
- output = ops.gptq_marlin_gemm(
330
- reshaped_x,
331
- weight,
332
- weight_scale,
333
- weight_zp,
334
- g_idx,
335
- g_idx_sort_indices,
336
- workspace,
337
- wtype,
338
- size_m=reshaped_x.shape[0],
339
- size_n=output_size_per_partition,
340
- size_k=input_size_per_partition,
341
- is_k_full=is_k_full,
342
- has_zp=False,
343
- use_fp32_reduce=use_fp32_reduce,
344
- is_zp_float=False,
345
- )
346
-
347
- if bias is not None:
348
- output.add_(bias) # In-place add
349
-
350
- return output.reshape(out_shape)
351
-
352
-
353
- def apply_awq_marlin_linear(
354
- input: torch.Tensor,
355
- weight: torch.Tensor,
356
- weight_scale: torch.Tensor,
357
- weight_zp: torch.Tensor,
358
- g_idx: torch.Tensor,
359
- g_idx_sort_indices: torch.Tensor,
360
- workspace: torch.Tensor,
361
- quant_type: ScalarType,
362
- output_size_per_partition: int,
363
- input_size_per_partition: int,
364
- bias: Optional[torch.Tensor] = None,
365
- use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
366
- ) -> torch.Tensor:
367
- reshaped_x = input.reshape(-1, input.shape[-1])
368
- out_shape = input.shape[:-1] + (output_size_per_partition,)
369
-
370
- output = ops.gptq_marlin_gemm(
371
- reshaped_x,
372
- weight,
373
- weight_scale,
374
- weight_zp,
375
- g_idx,
376
- g_idx_sort_indices,
377
- workspace,
378
- quant_type,
379
- size_m=reshaped_x.shape[0],
380
- size_n=output_size_per_partition,
381
- size_k=input_size_per_partition,
382
- is_k_full=True,
383
- has_zp=True,
384
- use_fp32_reduce=use_fp32_reduce,
385
- is_zp_float=False,
386
- )
387
-
388
- if bias is not None:
389
- output.add_(bias) # In-place add
390
-
391
- return output.reshape(out_shape)
build/torch26-cxx11-rocm62-x86_64-linux/quantization/utils/marlin_utils_fp8.py DELETED
@@ -1,100 +0,0 @@
1
- from typing import Optional
2
-
3
- import torch
4
-
5
- import quantization as ops
6
-
7
- from .marlin_utils import marlin_make_workspace, marlin_permute_scales
8
-
9
-
10
- def is_fp8_marlin_supported():
11
- capability = torch.cuda.get_device_capability()
12
- capability = capability[0] * 10 + capability[1]
13
- return capability >= 80
14
-
15
-
16
- def apply_fp8_marlin_linear(
17
- input: torch.Tensor,
18
- weight: torch.Tensor,
19
- weight_scale: torch.Tensor,
20
- workspace: torch.Tensor,
21
- size_n: int,
22
- size_k: int,
23
- bias: Optional[torch.Tensor],
24
- ) -> torch.Tensor:
25
- # For GPUs that lack FP8 hardware support, we can leverage the
26
- # Marlin kernel for fast weight-only FP8 quantization
27
-
28
- reshaped_x = input.reshape(-1, input.shape[-1])
29
- out_shape = input.shape[:-1] + (size_n,)
30
-
31
- output = ops.fp8_marlin_gemm(
32
- a=reshaped_x,
33
- b_q_weight=weight,
34
- b_scales=weight_scale,
35
- workspace=workspace,
36
- num_bits=8,
37
- size_m=reshaped_x.shape[0],
38
- size_n=size_n,
39
- size_k=size_k,
40
- )
41
-
42
- if bias is not None:
43
- output.add_(bias) # In-place add
44
-
45
- return output.reshape(out_shape)
46
-
47
-
48
- def prepare_fp8_layer_for_marlin(
49
- layer: torch.nn.Module, strategy: str = "tensor"
50
- ) -> None:
51
- part_size_n = layer.output_size_per_partition
52
- part_size_k = layer.input_size_per_partition
53
-
54
- device = layer.weight.device
55
-
56
- # WORKSPACE
57
- layer.workspace = marlin_make_workspace(part_size_n, device)
58
-
59
- # WEIGHT
60
- # Repack weights to marlin format
61
- marlin_qweight = ops.gptq_marlin_repack(
62
- b_q_weight=pack_fp8_to_int32(layer.weight),
63
- perm=torch.empty(0, dtype=torch.int, device=device),
64
- size_k=part_size_k,
65
- size_n=part_size_n,
66
- num_bits=8,
67
- )
68
- layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
69
-
70
- # WEIGHT SCALES
71
- scales = layer.weight_scale.to(layer.orig_dtype)
72
- # Permute scales
73
- marlin_scales = marlin_permute_scales(
74
- s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1
75
- )
76
- layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
77
-
78
-
79
- def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
80
- """
81
- Repack FP8 weights to gptq format (packed int32 elements)
82
- """
83
- assert fp8_tensor.dtype == torch.float8_e4m3fn
84
- assert fp8_tensor.shape[0] % 4 == 0
85
-
86
- # Reshape to prepare for packing
87
- reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
88
-
89
- # Convert fp8 to uint8 (byte) representation
90
- byte_tensor = reshaped.view(torch.uint8)
91
-
92
- # Pack 4 uint8 values into one int32
93
- packed = (
94
- byte_tensor[:, 0].to(torch.int32)
95
- | (byte_tensor[:, 1].to(torch.int32) << 8)
96
- | (byte_tensor[:, 2].to(torch.int32) << 16)
97
- | (byte_tensor[:, 3].to(torch.int32) << 24)
98
- )
99
-
100
- return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous()
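pack_fp8_to_int32 above packs four consecutive FP8 rows into one int32 row, with byte i of each int32 holding row 4*g + i. A self-contained sketch of the packing and the matching unpacking, assuming a little-endian host and a PyTorch build that provides torch.float8_e4m3fn:

import torch

x = torch.randn(8, 16, dtype=torch.float16).to(torch.float8_e4m3fn)

# Pack: group rows in fours, reinterpret as raw bytes, OR four bytes into one int32.
grouped = x.reshape(-1, 4, *x.shape[1:])       # (2, 4, 16)
byte_tensor = grouped.view(torch.uint8)
packed = (
    byte_tensor[:, 0].to(torch.int32)
    | (byte_tensor[:, 1].to(torch.int32) << 8)
    | (byte_tensor[:, 2].to(torch.int32) << 16)
    | (byte_tensor[:, 3].to(torch.int32) << 24)
)                                              # (2, 16), same layout as pack_fp8_to_int32(x)

# Unpack: pull each byte back out and re-interleave the rows.
rows = [((packed >> (8 * i)) & 0xFF).to(torch.uint8) for i in range(4)]
recovered = torch.stack(rows, dim=1).reshape(x.shape)
assert torch.equal(recovered, x.view(torch.uint8))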
build/torch26-cxx11-rocm62-x86_64-linux/quantization/utils/marlin_utils_test.py DELETED
@@ -1,162 +0,0 @@
1
- """Utility functions used for tests and benchmarks"""
2
-
3
- from typing import List, Optional
4
-
5
- import numpy as np
6
- import torch
7
-
8
- from quantization.scalar_type import ScalarType
9
-
10
- from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, marlin_zero_points
11
- from .quant_utils import (
12
- get_pack_factor,
13
- gptq_quantize_weights,
14
- quantize_weights,
15
- sort_weights,
16
- )
17
-
18
-
19
- class MarlinWorkspace:
20
-
21
- def __init__(self, out_features, min_thread_n, max_parallel):
22
- assert (
23
- out_features % min_thread_n == 0
24
- ), "out_features = {} is undivisible by min_thread_n = {}".format(
25
- out_features, min_thread_n
26
- )
27
-
28
- max_workspace_size = (out_features // min_thread_n) * max_parallel
29
-
30
- self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda")
31
-
32
-
33
- def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
34
- assert q_w.shape == (size_k, size_n)
35
- assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
36
- assert size_n % tile == 0, f"size_n = {size_n}, tile = {tile}"
37
-
38
- # Permute weights to 16x64 marlin tiles
39
- q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
40
- q_w = q_w.permute((0, 2, 1, 3))
41
- q_w = q_w.reshape((size_k // tile, size_n * tile))
42
-
43
- q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)
44
-
45
- return q_w
46
-
47
-
48
- def marlin_weights(q_w, size_k, size_n, num_bits, perm):
49
- # Permute
50
- q_w = marlin_permute_weights(q_w, size_k, size_n, perm)
51
-
52
- # Pack
53
- pack_factor = get_pack_factor(num_bits)
54
- orig_device = q_w.device
55
-
56
- q_w = q_w.cpu().numpy().astype(np.uint32)
57
-
58
- q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32)
59
- for i in range(pack_factor):
60
- q_packed |= q_w[:, i::pack_factor] << num_bits * i
61
-
62
- q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device)
63
-
64
- return q_packed
65
-
66
-
67
- def get_weight_perm(num_bits: int):
68
- perm_list: List[int] = []
69
- for i in range(32):
70
- perm1: List[int] = []
71
- col = i // 4
72
- for block in [0, 1]:
73
- for row in [
74
- 2 * (i % 4),
75
- 2 * (i % 4) + 1,
76
- 2 * (i % 4 + 4),
77
- 2 * (i % 4 + 4) + 1,
78
- ]:
79
- perm1.append(16 * row + col + 8 * block)
80
- for j in range(4):
81
- perm_list.extend([p + 256 * j for p in perm1])
82
-
83
- perm = np.array(perm_list)
84
-
85
- if num_bits == 4:
86
- interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
87
- elif num_bits == 8:
88
- interleave = np.array([0, 2, 1, 3])
89
- else:
90
- raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
91
-
92
- perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
93
- perm = torch.from_numpy(perm)
94
- return perm
95
-
96
-
97
- def marlin_quantize(
98
- w: torch.Tensor,
99
- quant_type: ScalarType,
100
- group_size: int,
101
- act_order: bool,
102
- test_perm: Optional[torch.Tensor] = None,
103
- ):
104
- size_k, size_n = w.shape
105
- num_bits = quant_type.size_bits
106
-
107
- # Normalize group_size
108
- if group_size == -1:
109
- group_size = size_k
110
- assert group_size <= size_k
111
-
112
- # Quantize (and apply act_order if provided)
113
- w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
114
- w, quant_type, group_size, act_order, test_perm
115
- )
116
-
117
- # For act_order, sort the "weights" and "g_idx" so that group ids are
118
- # increasing
119
- sort_indices = torch.empty(0, dtype=torch.int, device=w.device)
120
- if act_order:
121
- q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
122
-
123
- # Reformat to marlin
124
- weight_perm = get_weight_perm(num_bits)
125
- marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
126
- marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
127
-
128
- # Create result
129
- res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
130
- for i in range(len(res_list)):
131
- res_list[i] = res_list[i].to(w.device)
132
-
133
- return res_list
134
-
135
-
136
- def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int):
137
- size_k, size_n = w.shape
138
-
139
- # Normalize group_size
140
- if group_size == -1:
141
- group_size = size_k
142
- assert group_size <= size_k
143
-
144
- # Detect num groups
145
- assert size_k % group_size == 0
146
- num_groups = size_k // group_size
147
-
148
- # Quantize with zp
149
- w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True)
150
-
151
- # Reformat to marlin
152
- weight_perm = get_weight_perm(quant_type.size_bits)
153
- marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm)
154
- marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
155
- marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits)
156
-
157
- # Create result
158
- res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp]
159
- for i in range(len(res_list)):
160
- res_list[i] = res_list[i].to(w.device)
161
-
162
- return res_list
 
build/torch26-cxx11-rocm62-x86_64-linux/quantization/utils/marlin_utils_test_24.py DELETED
@@ -1,473 +0,0 @@
1
- """Utility functions used for tests and benchmarks"""
2
-
3
- import random
4
- from typing import List
5
-
6
- import numpy
7
- import torch
8
-
9
- from quantization.scalar_type import ScalarType
10
-
11
- from .marlin_utils_test import marlin_weights
12
- from .quant_utils import gptq_quantize_weights
13
-
14
-
15
- # This is PyTorch implementation of main part of reorder_meta()
16
- # function, from tools/util/include/cutlass/util/host_reorder.h file
17
- # of CUTLASS source tree. Furthermore, CUTLASS template for sparse
18
- # GEMM decides upon layout of this matrix, and at the moment for the
19
- # sparse GEMM executed on tensor cores, this is layout described by
20
- # ColumnMajorInterleaved<2> data structure, in
21
- # include/cutlass/layout/matrix.h of CUTLASS source tree. The
22
- # reordering of meta matrix into meta_reordered matrix calculated
23
- # according to these segments of CUTLASS code is re-implemented here.
24
- # Note that this calculation produces offsets for scattering metadata
25
- # matrix elements into reordered metadata matrix elements (or,
26
- # equivalently, for gathering reordered metadata matrix element back
27
- # into metadata matrix elements).
28
- def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device):
29
- dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)
30
- dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)
31
-
32
- # Reorder the rows, then swizzle the 2x2 blocks.
33
- group_x = 64
34
- group_y = 32 if meta_dtype.itemsize == 2 else 16
35
-
36
- dst_rows = (
37
- dst_rows // group_x * group_x
38
- + (dst_rows % 2) * 2
39
- + (dst_rows % 8) // 4
40
- + ((dst_rows % group_y) % 4) // 2 * 32
41
- + ((dst_rows % group_x) // 8) * 4
42
- )
43
-
44
- topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8)
45
- bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8)
46
- dst_rows += topright - bottomleft
47
- dst_cols -= topright - bottomleft
48
-
49
- # Assumed that meta tensor is to be stored in CUTLASS
50
- # InterleavedColumnMajor layout, and reverse engineered
51
- # corresponding code to store values into this tensor.
52
- interleave = 2
53
- cols_maj = dst_cols // interleave
54
- cols_min = dst_cols % interleave
55
- return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1)
56
-
57
-
58
- # This function converts dense matrix into sparse semi-structured
59
- # representation, producing "compressed" matrix, in the layout used by
60
- # CUTLASS backend, and corresponding metadata matrix.
61
- def sparse_semi_structured_from_dense_cutlass(dense):
62
- if dense.dim() != 2:
63
- raise RuntimeError(
64
- f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501
65
- )
66
-
67
- m, k = dense.shape
68
- device = dense.device
69
-
70
- meta_dtype = torch.int8
71
- if dense.dtype == torch.int8:
72
- meta_dtype = torch.int32
73
- elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]:
74
- meta_dtype = torch.int16
75
- else:
76
- raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
77
- quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
78
- if quadbits_per_meta_elem not in (4, 8):
79
- raise RuntimeError("Invalid number of elements per meta element calculated")
80
-
81
- if meta_dtype == torch.int32:
82
- if m % 16 != 0:
83
- raise RuntimeError(
84
- f"Number of rows of dense matrix {m} must be divisible by 16"
85
- )
86
- else:
87
- if m % 32 != 0:
88
- raise RuntimeError(
89
- f"Number of rows of dense matrix {m} must be divisible by 32"
90
- )
91
- if k % (4 * quadbits_per_meta_elem) != 0:
92
- raise RuntimeError(
93
- f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501
94
- )
95
-
96
- if dense.dtype != torch.float:
97
- ksparse = 4
98
- dense_4 = dense.view(-1, k // ksparse, ksparse)
99
- m0, m1, m2, m3 = (dense_4 != 0).unbind(-1)
100
- else:
101
- ksparse = 2
102
- dense_2 = dense.view(-1, k // ksparse, ksparse)
103
- m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1)
104
- meta_ncols = k // (ksparse * quadbits_per_meta_elem)
105
-
106
- # Encoding quadruples of True/False values as follows:
107
- # [True, True, False, False] -> 0b0100
108
- # [True, False, True, False] -> 0b1000
109
- # [False, True, True, False] -> 0b1001
110
- # [True, False, False, True ] -> 0b1100
111
- # [False, True, False, True ] -> 0b1101
112
- # [False, False, True, True ] -> 0b1110
113
- # Thus, lower two bits in the encoding are index of the True value
114
- # at the lowest index in the quadruple, and the higher two bits in
115
- # the encoding are index of the other True value in the quadruple.
116
- # In case there are less than two True values, than False value or
117
- # values at some index or indices are considered True for the
118
- # encoding. In case there are more than two True values, then the
119
- # excess True value(s) at some indices are considered False for
120
- # the encoding. The exact encodings used for these cases are as
121
- # follows:
122
- # [False, False, False, False] -> 0b1110
123
- # [False, False, False, True ] -> 0b1110
124
- # [False, False, True, False] -> 0b1110
125
- # [False, True, False, False] -> 0b1001
126
- # [False, True, True, True ] -> 0b1101
127
- # [True, False, False, False] -> 0b1000
128
- # [True, False, True, True ] -> 0b1100
129
- # [True, True, False, True ] -> 0b0100
130
- # [True, True, True, False] -> 0b0100
131
- # [True, True, True, True ] -> 0b0100
132
- # These particular encodings are chosen, with the help of Espresso
133
- # logic minimizer software, for the purpose of minimization of
134
- # corresponding Boolean functions, that translate non-zero flags
135
- # into encoding bits. Note also possible choices for the first
136
- # and last of these encodings were limited only to (0b0100,
137
- # 0b1110), in order to produce valid encodings for 1:2 sparsity
138
- # case.
139
-
140
- expr0 = m0 & m1
141
- expr1 = ~m0 & m1
142
- expr2 = ~m0 & ~m1
143
- bit0 = expr1
144
- bit1 = expr2
145
- bit2 = expr0 | expr2 | m3
146
- bit3 = expr1 | ~m1
147
- idxs0 = bit0 | (bit1.to(torch.int64) << 1)
148
- idxs1 = bit2 | (bit3.to(torch.int64) << 1)
149
-
150
- if dense.dtype != torch.float:
151
- sparse0 = dense_4.gather(
152
- -1, idxs0.unsqueeze(-1)
153
- ) # type: ignore[possibly-undefined]
154
- sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
155
- sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
156
- else:
157
- sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(
158
- m, k // 2
159
- ) # type: ignore[possibly-undefined]
160
-
161
- meta_4 = idxs0 | (idxs1 << 2)
162
- meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)
163
-
164
- if quadbits_per_meta_elem == 4:
165
- meta = (
166
- meta_n[:, :, 0]
167
- | (meta_n[:, :, 1] << 4)
168
- | (meta_n[:, :, 2] << 8)
169
- | (meta_n[:, :, 3] << 12)
170
- )
171
- elif quadbits_per_meta_elem == 8:
172
- meta = (
173
- meta_n[:, :, 0]
174
- | (meta_n[:, :, 1] << 4)
175
- | (meta_n[:, :, 2] << 8)
176
- | (meta_n[:, :, 3] << 12)
177
- | (meta_n[:, :, 4] << 16)
178
- | (meta_n[:, :, 5] << 20)
179
- | (meta_n[:, :, 6] << 24)
180
- | (meta_n[:, :, 7] << 28)
181
- )
182
-
183
- # Reorder meta tensor elements.
184
- meta_reordered = meta.new_empty(
185
- (m * meta_ncols,)
186
- ) # type: ignore[possibly-undefined]
187
- meta_offsets = _calculate_meta_reordering_scatter_offsets(
188
- m, meta_ncols, meta_dtype, device
189
- )
190
- meta_reordered.scatter_(0, meta_offsets, meta.view(-1))
191
-
192
- return (sparse, meta_reordered.view(m, meta_ncols))
193
-
194
-
195
- # This function performs reverse of the function above - it
196
- # reconstructs dense matrix from a pair of "compressed" matrix, given
197
- # in the layout used by CUTLASS backend, and accompanying metadata
198
- # matrix.
199
- def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
200
- if sparse.dim() != 2:
201
- raise RuntimeError(
202
- f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501
203
- )
204
-
205
- m, k = sparse.shape
206
- device = sparse.device
207
-
208
- if meta_reordered.dim() != 2:
209
- raise RuntimeError(
210
- f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501
211
- )
212
- if meta_reordered.device != device:
213
- raise RuntimeError(
214
- f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501
215
- )
216
-
217
- meta_dtype = meta_reordered.dtype
218
- if meta_dtype not in (torch.int16, torch.int32):
219
- raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix")
220
- quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
221
-
222
- ksparse = 4 if sparse.dtype != torch.float else 2
223
-
224
- meta_nrows, meta_ncols = meta_reordered.shape
225
- if meta_nrows != m:
226
- raise RuntimeError(
227
- f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501
228
- )
229
- if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k:
230
- raise RuntimeError(
231
- f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501
232
- "expected according to the number of columns of meta matrix"
233
- )
234
-
235
- # Undo meta tensor elements reordering.
236
- meta_offsets = _calculate_meta_reordering_scatter_offsets(
237
- m, meta_ncols, meta_dtype, device
238
- )
239
- meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols)
240
-
241
- # Unpack sparse tensor back to original dense tensor, using
242
- # information provided by meta tensor. Note that torch.float
243
- # datatype is handled pretty much the same as
244
- # torch.half/torch.bfloat16, as metadata for a pair of torch.float
245
- # value is encoded as if underlying 8 bytes contain four
246
- # torch.half/torch.bfloat16 values, where either first two or last
247
- # two are zeros.
248
- meta_2 = torch.empty(
249
- (m, meta_ncols, 2 * quadbits_per_meta_elem),
250
- dtype=meta_dtype,
251
- device=device,
252
- )
253
- if quadbits_per_meta_elem == 4:
254
- meta_2[:, :, 0] = meta & 0b11
255
- meta_2[:, :, 1] = (meta >> 2) & 0b11
256
- meta_2[:, :, 2] = (meta >> 4) & 0b11
257
- meta_2[:, :, 3] = (meta >> 6) & 0b11
258
- meta_2[:, :, 4] = (meta >> 8) & 0b11
259
- meta_2[:, :, 5] = (meta >> 10) & 0b11
260
- meta_2[:, :, 6] = (meta >> 12) & 0b11
261
- meta_2[:, :, 7] = (meta >> 14) & 0b11
262
- elif quadbits_per_meta_elem == 8:
263
- meta_2[:, :, 0] = meta & 0b11
264
- meta_2[:, :, 1] = (meta >> 2) & 0b11
265
- meta_2[:, :, 2] = (meta >> 4) & 0b11
266
- meta_2[:, :, 3] = (meta >> 6) & 0b11
267
- meta_2[:, :, 4] = (meta >> 8) & 0b11
268
- meta_2[:, :, 5] = (meta >> 10) & 0b11
269
- meta_2[:, :, 6] = (meta >> 12) & 0b11
270
- meta_2[:, :, 7] = (meta >> 14) & 0b11
271
- meta_2[:, :, 8] = (meta >> 16) & 0b11
272
- meta_2[:, :, 9] = (meta >> 18) & 0b11
273
- meta_2[:, :, 10] = (meta >> 20) & 0b11
274
- meta_2[:, :, 11] = (meta >> 22) & 0b11
275
- meta_2[:, :, 12] = (meta >> 24) & 0b11
276
- meta_2[:, :, 13] = (meta >> 26) & 0b11
277
- meta_2[:, :, 14] = (meta >> 28) & 0b11
278
- meta_2[:, :, 15] = (meta >> 30) & 0b11
279
-
280
- dense_offsets = meta_2.view(-1) + (
281
- torch.arange(0, 2 * m * k // ksparse, device=device) * 4
282
- ).view(-1, 1).repeat(1, 2).view(-1)
283
-
284
- dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device)
285
- if sparse.dtype != torch.float:
286
- # dense.scatter_(0, dense_offsets, sparse.view(-1))
287
- dense.scatter_(0, dense_offsets, sparse.reshape(-1))
288
- else:
289
- dense.view(torch.half).scatter_(
290
- 0, dense_offsets, sparse.view(torch.half).view(-1)
291
- )
292
-
293
- return dense.view(m, 2 * k)
294
-
295
-
296
- def mask_creator(tensor):
297
- """
298
- Creates an N:M sparsity mask for the given tensor.
299
- Masks will be created using the N:M ratio, where for every block of
300
- M weights, the M - N smallest-magnitude weights will be pruned. Each mask
301
- will correspond to the given tensor.
302
-
303
- :param N: The number of weights in a group to keep
304
- :param M: The size of a weight group
305
- """
306
- N = 2
307
- M = 4
308
-
309
- mask = None
310
- # for i, tensor in enumerate(tensors):
311
- if tensor.numel() % M != 0:
312
- raise ValueError(
313
- f"Tensor of size {tensor.shape} can't be evenly divided into " f"{M} groups"
314
- )
315
-
316
- num_groups = tensor.numel() // M
317
-
318
- # N:M sparsity for linear layers
319
- tensor_temp = tensor.detach().abs().reshape(num_groups, M)
320
- index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)]
321
-
322
- w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device)
323
- mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape)
324
-
325
- return mask
326
-
327
-
328
- def inject_24(w, size_k, size_n):
329
- assert w.shape == (size_k, size_n)
330
-
331
- mask = mask_creator(w.t()).t().cuda().bool()
332
-
333
- return (mask * w).contiguous(), mask.contiguous()
334
-
335
-
336
- def check_24(w, num_rows_to_sample=50, _verbose=False):
337
- BLOCK_SIZE = 4
338
- MAX_NON_ZEROS = 2
339
-
340
- w = w.t().contiguous()
341
-
342
- print("check_24: w.shape = {}".format(w.shape))
343
-
344
- num_rows, num_cols = w.shape
345
- sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample)
346
- if _verbose:
347
- print(f"Sampled row idxs = {sampled_row_idxs}")
348
-
349
- total_segments = 0
350
- non_24_segments = 0
351
- for i in sampled_row_idxs:
352
- for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE):
353
- total_segments += 1
354
- block = w[i, j : j + BLOCK_SIZE]
355
- num_nonzero = torch.count_nonzero(block)
356
- if num_nonzero > MAX_NON_ZEROS:
357
- print("i = {} j = {} block = {}".format(i, j, block))
358
- non_24_segments += 1
359
-
360
- print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.")
361
-
362
-
363
- def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType):
364
- assert q_24.shape == (size_k, size_n)
365
-
366
- # Remove bias to normalize over 0
367
- q_24_no_zp = q_24 - wtype.bias
368
-
369
- # Compress
370
- q_24_no_zp = q_24_no_zp.t().contiguous()
371
- q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(q_24_no_zp)
372
- q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous()
373
-
374
- # Restore bias
375
- q_24_comp = q_24_no_zp_comp + wtype.bias
376
-
377
- # Resize meta to its actual shape (without moving any data)
378
- meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)
379
-
380
- return q_24_comp, meta
381
-
382
-
383
- def get_scale_perms_24():
384
- scale_perm: List[int] = []
385
- for i in range(8):
386
- scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]])
387
- scale_perm_single: List[int] = []
388
- for i in range(8):
389
- scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]])
390
- return scale_perm, scale_perm_single
391
-
392
-
393
- def get_weight_perm_24(num_bits: int):
394
- perm_list: List[int] = []
395
- for i in range(32):
396
- perm1: List[int] = []
397
- col = i // 4
398
- col_o = col // 2
399
- for block in [0, 1]:
400
- for row in [
401
- 2 * (i % 4),
402
- 2 * (i % 4) + 1,
403
- 2 * (i % 4 + 4),
404
- 2 * (i % 4 + 4) + 1,
405
- ]:
406
- perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block)
407
- for j in range(4):
408
- perm_list.extend([p + 1 * j for p in perm1])
409
- perm = numpy.array(perm_list)
410
-
411
- if num_bits == 4:
412
- interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
413
- elif num_bits == 8:
414
- interleave = numpy.array([0, 2, 1, 3])
415
- else:
416
- raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))
417
-
418
- perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
419
- perm = torch.from_numpy(perm)
420
- return perm
421
-
422
-
423
- def marlin_permute_scales_24(
424
- s: torch.Tensor, size_k: int, size_n: int, group_size: int
425
- ) -> torch.Tensor:
426
-
427
- scale_perm, scale_perm_single = get_scale_perms_24()
428
- if group_size < size_k and group_size != -1:
429
- s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
430
- else:
431
- s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
432
- s = s.reshape((-1, size_n)).contiguous()
433
-
434
- return s
435
-
436
-
437
- def marlin_24_quantize(
438
- w: torch.Tensor,
439
- quant_type: ScalarType,
440
- group_size: int,
441
- ):
442
- size_k, size_n = w.shape
443
-
444
- # Normalize group_size
445
- if group_size == -1:
446
- group_size = size_k
447
- assert group_size <= size_k
448
-
449
- # Inject 2:4 sparsity
450
- w_24, mask_24 = inject_24(w, size_k, size_n)
451
-
452
- # Quantize
453
- w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights(
454
- w_24, quant_type, group_size, act_order=False
455
- )
456
-
457
- # Compress quantized weight
458
- q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, quant_type)
459
- size_k_comp = size_k // 2
460
-
461
- # Reformat to marlin
462
- weight_perm = get_weight_perm_24(quant_type.size_bits)
463
- marlin_24_q_w_comp = marlin_weights(
464
- q_w_24_comp, size_k_comp, size_n, quant_type.size_bits, weight_perm
465
- )
466
- marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size)
467
-
468
- # Create result
469
- res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s]
470
- for i in range(len(res_list)):
471
- res_list[i] = res_list[i].to(w.device)
472
-
473
- return res_list
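mask_creator above realizes the 2:4 pattern by zeroing the two smallest-magnitude weights in every group of four. A compact standalone equivalent of that masking step, for illustration only:

import torch

tensor = torch.randn(4, 8)
M, N = 4, 2                                             # keep N of every M weights
groups = tensor.detach().abs().reshape(-1, M)
drop = torch.argsort(groups, dim=1)[:, : M - N]         # two smallest magnitudes per group
mask = torch.ones_like(groups).scatter_(1, drop, 0.0).reshape(tensor.shape)

pruned = tensor * mask
# every length-4 block now has at most 2 non-zeros
assert (pruned.reshape(-1, M).count_nonzero(dim=1) <= N).all()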
build/torch26-cxx11-rocm62-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py DELETED
@@ -1,125 +0,0 @@
1
- from typing import List
2
-
3
- import numpy
4
- import torch
5
-
6
- from .marlin_utils_test import marlin_permute_weights
7
- from .quant_utils import get_pack_factor, qqq_quantize_weights
8
-
9
-
10
- def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size):
11
- # Permute
12
- q_w = marlin_permute_weights(q_w, size_k, size_n, perm)
13
-
14
- # Pack
15
- pack_factor = get_pack_factor(num_bits)
16
- orig_device = q_w.device
17
-
18
- q_w = q_w.cpu().numpy().astype(numpy.uint32)
19
-
20
- q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor),
21
- dtype=numpy.uint32)
22
- if group_size == size_k:
23
- for i in range(pack_factor):
24
- q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i
25
- else:
26
- for i in range(pack_factor):
27
- q_packed |= q_w[:, i::pack_factor] << num_bits * i
28
-
29
- q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device)
30
-
31
- return q_packed
32
-
33
-
34
- def get_qqq_scale_perms():
35
- scale_perm: List[int] = []
36
- for i in range(8):
37
- scale_perm.extend([i + 8 * j for j in range(8)])
38
- scale_perm_single: List[int] = []
39
- for i in range(4):
40
- scale_perm_single.extend(
41
- [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
42
- return scale_perm, scale_perm_single
43
-
44
-
45
- # NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. # noqa: E501
46
- def get_qqq_weight_perm(num_bits: int, quant_type: str):
47
- perm_list: List[int] = []
48
- for i in range(32):
49
- perm1: List[int] = []
50
- col = i // 4
51
- for block in [0, 1]:
52
- for row in [
53
- 4 * (i % 4),
54
- 4 * (i % 4) + 1,
55
- 4 * (i % 4) + 2,
56
- 4 * (i % 4) + 3,
57
- ]:
58
- perm1.append(16 * row + col + 8 * block)
59
- for j in range(4):
60
- perm_list.extend([p + 256 * j for p in perm1])
61
-
62
- perm = numpy.array(perm_list)
63
-
64
- assert quant_type in ["per-channel",
65
- "per-group"], "not supported quantization type"
66
- if num_bits == 4:
67
- if quant_type == "per-channel":
68
- interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3])
69
- else:
70
- interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
71
- else:
72
- raise Exception("num_bits must be 4, got {}".format(num_bits))
73
-
74
- perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
75
- perm = torch.from_numpy(perm)
76
- return perm
77
-
78
-
79
- def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size):
80
- scale_perm, scale_perm_single = get_qqq_scale_perms()
81
- if group_size < size_k and group_size != -1:
82
- s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm]
83
- s_channel = s_channel.reshape(
84
- (-1, len(scale_perm_single)))[:, scale_perm_single]
85
- s_group = s_group.reshape((-1, size_n)).contiguous()
86
- else:
87
- s_channel = s_channel.reshape(
88
- (-1, len(scale_perm_single)))[:, scale_perm_single]
89
- s_channel = s_channel.reshape((-1, size_n)).contiguous()
90
-
91
- return s_group, s_channel
92
-
93
-
94
- def marlin_qqq_quantize(
95
- w: torch.Tensor,
96
- num_bits: int,
97
- group_size: int,
98
- ):
99
- size_k, size_n = w.shape
100
-
101
- # Normalize group_size
102
- if group_size == -1:
103
- group_size = size_k
104
- assert group_size <= size_k
105
- quant_type = "per-channel" if group_size == size_k else "per-group"
106
-
107
- # Quantize
108
- w_ref, q_w, s_group, s_channel = qqq_quantize_weights(
109
- w, num_bits, group_size)
110
-
111
- # Reformat to marlin_qqq
112
- weight_perm = get_qqq_weight_perm(num_bits, quant_type)
113
- marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits,
114
- weight_perm, group_size)
115
- marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales(
116
- s_group, s_channel, size_k, size_n, group_size)
117
-
118
- # Create result
119
- res_list = [
120
- w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel
121
- ]
122
- for i in range(len(res_list)):
123
- res_list[i] = res_list[i].to(w.device)
124
-
125
- return res_list
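As with the GPTQ helpers, the deleted QQQ test utilities are driven through a single entry point; 4-bit is the only width accepted (get_qqq_weight_perm above raises for anything else). A hedged usage sketch under the same importability assumption as above:

import torch

from quantization.utils.marlin_utils_test_qqq import marlin_qqq_quantize

w = torch.randn(256, 512, dtype=torch.half)
# group_size=128 exercises the per-group path; group_size=-1 (or size_k) the per-channel path
w_ref, marlin_qqq_q_w, s_group, s_channel = marlin_qqq_quantize(w, num_bits=4, group_size=128)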
build/torch26-cxx11-rocm62-x86_64-linux/quantization/utils/quant_utils.py DELETED
@@ -1,470 +0,0 @@
1
- """This file is used for /tests and /benchmarks"""
2
-
3
- from typing import List, Optional
4
-
5
- import numpy
6
- import torch
7
-
8
- from quantization.scalar_type import ScalarType, scalar_types
9
-
10
- SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
11
- SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
12
-
13
- MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]
14
-
15
- # Note: this is a hack. We should update each model to register the
16
- # stacked params and get it from there instead in a future PR.
17
- # fused_name: List[shard_name]
18
- FUSED_LAYER_NAME_MAPPING = {
19
- "qkv_proj": ["q_proj", "k_proj", "v_proj"],
20
- "gate_up_proj": ["gate_proj", "up_proj"],
21
- }
22
-
23
-
24
- def pack_quantized_values_into_int32(
25
- w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
26
- ):
27
- # move dim to pack to the end
28
- perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
29
- inv_perm = tuple(perm.index(i) for i in range(len(perm)))
30
- w_q_perm = w_q.permute(perm)
31
-
32
- pack_factor = 32 // wtype.size_bits
33
- mask = (1 << wtype.size_bits) - 1
34
-
35
- new_shape_perm = list(w_q_perm.shape)
36
- assert w_q_perm.shape[-1] % pack_factor == 0
37
- new_shape_perm[-1] //= pack_factor
38
-
39
- res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
40
- for i in range(pack_factor):
41
- res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i
42
-
43
- return res.permute(inv_perm)
44
-
45
-
46
- def unpack_quantized_values_into_int32(
47
- w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
48
- ):
49
- # move dim to pack to the end
50
- perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
51
- inv_perm = tuple(perm.index(i) for i in range(len(perm)))
52
- w_q_perm = w_q.permute(perm)
53
-
54
- pack_factor = 32 // wtype.size_bits
55
- mask = (1 << wtype.size_bits) - 1
56
-
57
- new_shape_perm = list(w_q_perm.shape)
58
- new_shape_perm[-1] *= pack_factor
59
-
60
- res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
61
- for i in range(pack_factor):
62
- res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask
63
-
64
- return res.permute(inv_perm)
65
-
66
-
67
- def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
68
- # prefix: model.layers.0.self_attn.q_proj
69
- # proj_name: q_proj
70
- proj_name = prefix.split(".")[-1]
71
- if proj_name in FUSED_LAYER_NAME_MAPPING:
72
- shard_prefixes = [
73
- prefix.replace(proj_name, shard_proj_name)
74
- for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name]
75
- ]
76
-
77
- is_skipped = None
78
- for shard_prefix in shard_prefixes:
79
- is_shard_skipped = shard_prefix in ignored_layers
80
-
81
- if is_skipped is None:
82
- is_skipped = is_shard_skipped
83
- elif is_shard_skipped != is_skipped:
84
- raise ValueError(
85
- f"Detected some but not all shards of {prefix} "
86
- "are quantized. All shards of fused layers "
87
- "to have the same precision."
88
- )
89
- else:
90
- is_skipped = prefix in ignored_layers
91
-
92
- assert is_skipped is not None
93
- return is_skipped
94
-
95
-
96
- def get_pack_factor(num_bits):
97
- assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
98
- return 32 // num_bits
99
-
100
-
101
- def permute_rows(
102
- q_w: torch.Tensor,
103
- w_ref: torch.Tensor,
104
- group_size: int,
105
- test_perm: Optional[torch.Tensor] = None,
106
- ):
107
- assert q_w.shape == w_ref.shape
108
-
109
- orig_device = q_w.device
110
- k_size, _ = q_w.shape
111
-
112
- g_idx = torch.zeros((k_size,), dtype=torch.int32)
113
- for i in range(k_size):
114
- g_idx[i] = i // group_size
115
-
116
- # Simulate act_order by doing a random permutation on K
117
- rand_perm = test_perm if test_perm is not None else torch.randperm(k_size)
118
-
119
- g_idx = g_idx[rand_perm].contiguous()
120
- q_w = q_w[rand_perm, :].contiguous()
121
- w_ref = w_ref[rand_perm, :].contiguous()
122
-
123
- return (
124
- w_ref.to(device=orig_device),
125
- q_w.to(device=orig_device),
126
- g_idx.to(device=orig_device),
127
- rand_perm.to(device=orig_device),
128
- )
129
-
130
-
131
- def quantize_weights(
132
- w: torch.Tensor,
133
- quant_type: ScalarType,
134
- group_size: Optional[int],
135
- zero_points: bool = False,
136
- ref_zero_points_after_scales: bool = False,
137
- ):
138
- assert (
139
- quant_type.is_integer()
140
- ), "Floating point quantization may work but has not been tested"
141
- assert not zero_points or group_size is not None, (
142
- "to have group zero points, group_size must be provided "
143
- "(-1 group_size is channelwise)"
144
- )
145
-
146
- orig_device = w.device
147
- orig_type = w.dtype
148
- size_k, size_n = w.shape
149
-
150
- assert w.is_floating_point(), "w must be float"
151
-
152
- if group_size == -1:
153
- group_size = size_k
154
-
155
- # Reshape to [groupsize, -1]
156
- if group_size is not None and group_size < size_k:
157
- w = w.reshape((-1, group_size, size_n))
158
- w = w.permute(1, 0, 2)
159
- w = w.reshape((group_size, -1))
160
-
161
- # Compute scale for each group
162
- max_val = torch.max(w, 0, keepdim=True).values
163
- min_val = torch.min(w, 0, keepdim=True).values
164
-
165
- max_q_val = quant_type.max()
166
- min_q_val = quant_type.min()
167
-
168
- w_s = torch.Tensor([1.0]).to(w.device) # unscaled case
169
- maybe_w_zp = None
170
- if group_size is not None:
171
- if zero_points:
172
- assert not quant_type.is_signed() and quant_type.max() > 0
173
- w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
174
- maybe_w_zp = (
175
- torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
176
- )
177
- else:
178
- # If the bias is such that there are no possible negative/positive
179
- # values, set the max value to inf to avoid divide by 0
180
- w_s = torch.max(
181
- abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
182
- abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
183
- )
184
-
185
- # Quantize
186
- w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
187
- w_q = torch.clamp(w_q, min_q_val, max_q_val)
188
-
189
- # Compute ref (dequantized)
190
- # For some kernels (namely Machete) the zero-points are applied after the
191
- # scales are applied, for this case computing the reference in similar way
192
- # allows us to use tighter error tolerances in our unit tests.
193
- if ref_zero_points_after_scales and maybe_w_zp is not None:
194
- w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
195
- else:
196
- w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
197
-
198
- if quant_type.has_bias():
199
- w_q += quant_type.bias
200
-
201
- # Restore original shapes
202
- if group_size is not None and group_size < size_k:
203
-
204
- def reshape_w(w):
205
- w = w.reshape((group_size, -1, size_n))
206
- w = w.permute(1, 0, 2)
207
- w = w.reshape((size_k, size_n)).contiguous()
208
- return w
209
-
210
- w_q = reshape_w(w_q)
211
- w_ref = reshape_w(w_ref)
212
- w_s = w_s.reshape((-1, size_n)).contiguous()
213
-
214
- if maybe_w_zp is not None:
215
- maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
216
- maybe_w_zp = maybe_w_zp.to(device=orig_device)
217
-
218
- return (
219
- w_ref.to(device=orig_device),
220
- w_q.to(device=orig_device),
221
- w_s if group_size is not None else None,
222
- maybe_w_zp,
223
- )
224
-
225
-
226
- def gptq_quantize_weights(
227
- w: torch.Tensor,
228
- quant_type: ScalarType,
229
- group_size: int,
230
- act_order: bool,
231
- test_perm: Optional[torch.Tensor] = None,
232
- ):
233
- size_k, _ = w.shape
234
-
235
- assert w.is_floating_point(), "w must be float"
236
- assert (
237
- quant_type in SUPPORTED_GPTQ_QUANT_TYPES
238
- ), f"Unsupported gptq type = {quant_type}"
239
- assert group_size in SUPPORTED_GROUP_SIZES + [
240
- size_k
241
- ], f"Unsupported groupsize = {group_size}"
242
-
243
- w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size)
244
-
245
- # Apply act_order
246
- g_idx = torch.empty(0, dtype=torch.int, device=w.device)
247
- rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
248
- if act_order:
249
- assert (
250
- group_size < size_k
251
- ), "For act_order, groupsize = {} must be less than size_k = {}".format(
252
- group_size, size_k
253
- )
254
-
255
- w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm)
256
-
257
- return w_ref, w_q, w_s, g_idx, rand_perm
258
-
259
-
260
- # QQQ employs different quant schemes for per-group and
261
- # per-channel quantization.
262
- def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int):
263
- orig_device = w.device
264
- size_k, size_n = w.shape
265
-
266
- assert w.is_floating_point(), "w must be float"
267
- assert (
268
- num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS
269
- ), f"Unsupported num_bits = {num_bits}"
270
- assert group_size in SUPPORTED_GROUP_SIZES + [
271
- size_k
272
- ], f"Unsupported groupsize = {group_size}"
273
-
274
- if group_size == -1:
275
- group_size = size_k
276
- assert group_size <= size_k
277
-
278
- if group_size < size_k:
279
- # Reshape to [groupsize, -1]
280
- w = w.reshape((-1, group_size, size_n))
281
- w = w.permute(1, 0, 2)
282
- w = w.reshape((group_size, -1))
283
-
284
- max_q_val = 2**num_bits - 1
285
- half_q_val = (max_q_val + 1) // 2
286
-
287
- # Compute scale for each group
288
- s_group = torch.max(torch.abs(w), 0, keepdim=True)[0]
289
- s_group *= 2 / max_q_val # 2 => symmetric
290
-
291
- # Quantize
292
- q_w = torch.round(w / s_group).int()
293
- q_w += half_q_val
294
- q_w = torch.clamp(q_w, 0, max_q_val)
295
- # Compute ref (dequantized)
296
- w_ref = (q_w - half_q_val).half() * s_group
297
-
298
- # Restore original shapes
299
- def reshape_w(w):
300
- w = w.reshape((group_size, -1, size_n))
301
- w = w.permute(1, 0, 2)
302
- w = w.reshape((size_k, size_n)).contiguous()
303
- return w
304
-
305
- q_w = reshape_w(q_w)
306
- w_ref = reshape_w(w_ref)
307
-
308
- # Compute int8 quantization scale for each channel
309
- s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0]
310
- s_channel /= 127.0
311
- t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8)
312
- w_ref = t_int8.half() * s_channel
313
- s_channel = s_channel.reshape(1, -1).to(dtype=torch.float)
314
-
315
- # Fuse scales
316
- s_group = (s_group.reshape(-1, size_n).contiguous() / s_channel).to(
317
- dtype=torch.half
318
- )
319
- else:
320
- max_q_val = 2 ** (num_bits - 1) - 1
321
-
322
- # Compute scale for each channel
323
- s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0]
324
- s_channel /= max_q_val
325
-
326
- # Quantize
327
- q_w = torch.round(w / s_channel).int()
328
- q_w = torch.clamp(q_w, -max_q_val, max_q_val)
329
- # Compute ref (dequantized)
330
- w_ref = q_w.half() * s_channel
331
-
332
- s_group = torch.tensor([], dtype=torch.half)
333
- # div 2 ** (8 - self.bits)) to offset right shift in unpacking
334
- s_channel /= 2 ** (8 - num_bits)
335
- s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float)
336
-
337
- return (
338
- w_ref.to(device=orig_device),
339
- q_w.to(device=orig_device),
340
- s_group.to(device=orig_device),
341
- s_channel.to(device=orig_device),
342
- )
343
-
344
-
345
- def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
346
- orig_device = q_w.device
347
-
348
- sort_indices = torch.argsort(g_idx).to(dtype=torch.int32) # Sort based on g_idx
349
-
350
- g_idx = g_idx[sort_indices].contiguous()
351
- q_w = q_w[sort_indices, :].contiguous()
352
-
353
- return (
354
- q_w.to(device=orig_device),
355
- g_idx.to(device=orig_device),
356
- sort_indices.to(device=orig_device),
357
- )
358
-
359
-
360
- def pack_rows(
361
- q_w: torch.Tensor,
362
- num_bits: int,
363
- size_k: int,
364
- size_n: int,
365
- ):
366
- assert q_w.shape == (size_k, size_n)
367
-
368
- pack_factor = get_pack_factor(num_bits)
369
- assert size_k % pack_factor == 0
370
-
371
- orig_device = q_w.device
372
-
373
- q_w = q_w.cpu().numpy().astype(numpy.uint32)
374
-
375
- q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)
376
-
377
- for i in range(pack_factor):
378
- q_res |= q_w[i::pack_factor, :] << num_bits * i
379
-
380
- q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
381
- return q_res
382
-
383
-
384
- def pack_cols(
385
- q_w: torch.Tensor,
386
- num_bits: int,
387
- size_k: int,
388
- size_n: int,
389
- ):
390
- assert q_w.shape == (size_k, size_n)
391
-
392
- pack_factor = get_pack_factor(num_bits)
393
- assert size_n % pack_factor == 0
394
-
395
- orig_device = q_w.device
396
-
397
- q_w = q_w.cpu().numpy().astype(numpy.uint32)
398
-
399
- q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
400
-
401
- for i in range(pack_factor):
402
- q_res |= q_w[:, i::pack_factor] << num_bits * i
403
-
404
- q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
405
- q_res = q_res.contiguous()
406
-
407
- return q_res
408
-
409
-
410
- def unpack_cols(
411
- packed_q_w: torch.Tensor,
412
- num_bits: int,
413
- size_k: int,
414
- size_n: int,
415
- ):
416
- pack_factor = get_pack_factor(num_bits)
417
- assert size_n % pack_factor == 0
418
- assert packed_q_w.shape == (
419
- size_k,
420
- size_n // pack_factor,
421
- ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
422
- packed_q_w.shape, size_k, size_n, pack_factor
423
- )
424
-
425
- orig_device = packed_q_w.device
426
-
427
- packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
428
- q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
429
-
430
- mask = (1 << num_bits) - 1
431
- for i in range(pack_factor):
432
- vals = packed_q_w_cpu & mask
433
- packed_q_w_cpu >>= num_bits
434
- q_res[:, i::pack_factor] = vals
435
-
436
- q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
437
- q_res = q_res.contiguous()
438
-
439
- return q_res
440
-
441
-
442
- def gptq_pack(
443
- q_w: torch.Tensor,
444
- num_bits: int,
445
- size_k: int,
446
- size_n: int,
447
- ):
448
- return pack_rows(q_w, num_bits, size_k, size_n)
449
-
450
-
451
- def awq_pack(
452
- q_w: torch.Tensor,
453
- num_bits: int,
454
- size_k: int,
455
- size_n: int,
456
- ):
457
- assert q_w.shape == (size_k, size_n)
458
-
459
- # Interleave column dim (for the dequantize code) and pack it to int32
460
- if num_bits == 4:
461
- interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
462
- elif num_bits == 8:
463
- interleave = numpy.array([0, 2, 1, 3])
464
- else:
465
- raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
466
-
467
- q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel()
468
- q_w = q_w.reshape((-1, size_n)).contiguous()
469
-
470
- return pack_cols(q_w, num_bits, size_k, size_n)
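pack_rows, pack_cols and unpack_cols above all share the same little-endian bit packing driven by get_pack_factor. A tiny self-contained round trip of the column variant for num_bits=4 (eight values per int32), mirroring the loops above:

import numpy
import torch

num_bits = 4
pack_factor = 32 // num_bits                   # get_pack_factor(4) == 8
q_w = torch.randint(0, 2**num_bits, (16, 32), dtype=torch.int32)

w = q_w.numpy().astype(numpy.uint32)
packed = numpy.zeros((w.shape[0], w.shape[1] // pack_factor), dtype=numpy.uint32)
for i in range(pack_factor):
    packed |= w[:, i::pack_factor] << num_bits * i

unpacked = numpy.zeros_like(w)
mask = (1 << num_bits) - 1
tmp = packed.copy()
for i in range(pack_factor):
    unpacked[:, i::pack_factor] = tmp & mask
    tmp >>= num_bits

assert (unpacked == w).all()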
build/torch26-cxx98-cu118-x86_64-linux/quantization/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _quantization_0435ccb
3
- ops = torch.ops._quantization_0435ccb
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_quantization_0435ccb::{op_name}"
 
1
  import torch
2
+ from . import _quantization_85bad96
3
+ ops = torch.ops._quantization_85bad96
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_quantization_85bad96::{op_name}"
build/torch26-cxx98-cu118-x86_64-linux/quantization/_quantization_0435ccb.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:3885094b146b702d2ac23780b0f102500c30f46e53cfaf42bf527d708485979a
3
- size 63479104
 
 
 
 
build/torch26-cxx98-cu118-x86_64-linux/quantization/_quantization_85bad96.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d3054bac8793d647ff85596779b3388f54728e817c26f5584ef8e73c817bd144
3
+ size 87821400
build/torch26-cxx98-cu124-x86_64-linux/quantization/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _quantization_0435ccb
3
- ops = torch.ops._quantization_0435ccb
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_quantization_0435ccb::{op_name}"
 
1
  import torch
2
+ from . import _quantization_85bad96
3
+ ops = torch.ops._quantization_85bad96
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_quantization_85bad96::{op_name}"
build/torch26-cxx98-cu124-x86_64-linux/quantization/_quantization_0435ccb.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:50d412800766f33cd706a8126bc47e8119001bcd1cadd96a8e408be114b3b1b7
3
- size 67509408
 
 
 
 
build/torch26-cxx98-cu124-x86_64-linux/quantization/_quantization_85bad96.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e9d84de56121fca6315c107147c426de05703359f1347ca5f882933e227a7aa9
3
+ size 93711616
build/torch26-cxx98-cu126-x86_64-linux/quantization/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _quantization_0435ccb
3
- ops = torch.ops._quantization_0435ccb
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_quantization_0435ccb::{op_name}"
 
1
  import torch
2
+ from . import _quantization_85bad96
3
+ ops = torch.ops._quantization_85bad96
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_quantization_85bad96::{op_name}"
build/torch26-cxx98-cu126-x86_64-linux/quantization/_quantization_0435ccb.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:acf7c34de931bd4ab88454fad93df1afeccc84b8158a410fb9153ba42e5e82bc
3
- size 68271904
 
 
 
 
build/torch26-cxx98-cu126-x86_64-linux/quantization/_quantization_85bad96.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:765fb0e44ed12d4c9bdc7d99efde5707271024bafc621b168039de0947f01db7
3
+ size 94506888
build/torch27-cxx11-cu118-x86_64-linux/quantization/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _quantization_0435ccb
3
- ops = torch.ops._quantization_0435ccb
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_quantization_0435ccb::{op_name}"
 
1
  import torch
2
+ from . import _quantization_85bad96
3
+ ops = torch.ops._quantization_85bad96
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_quantization_85bad96::{op_name}"
build/torch27-cxx11-cu118-x86_64-linux/quantization/_quantization_0435ccb.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:0a1650df5eeb0e1494932eec92425751998ea2181b86995671edb757ccf6aeb5
3
- size 63484856
 
 
 
 
build/torch27-cxx11-cu118-x86_64-linux/quantization/_quantization_85bad96.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ce9bf91bfd616af0cf3887e3a46fe40a96638b122a8aedfcc95a8d461f87801
3
+ size 87827240
build/torch27-cxx11-cu126-x86_64-linux/quantization/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _quantization_0435ccb
3
- ops = torch.ops._quantization_0435ccb
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_quantization_0435ccb::{op_name}"
 
1
  import torch
2
+ from . import _quantization_85bad96
3
+ ops = torch.ops._quantization_85bad96
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_quantization_85bad96::{op_name}"